In [6]:
import os
import sys
# Put the parent of the current working directory on the import path so that
# the `src.*` imports used further down can be resolved.
module_path = os.path.abspath(os.pardir)
already_registered = module_path in sys.path
if not already_registered:
    sys.path.append(module_path)

In [7]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

# process_models(source_dir, target_dir, train_ratio=0.8, context_size=64, batch_size=128, num_samples=300_032)

@dataclass
class SDFTransformerConfig:
    """Hyper-parameters consumed by ``SDFTransformer``.

    The first four fields have no default and must be supplied by the caller;
    the remaining ones fall back to the defaults declared here. The notebook
    also serialises an instance to JSON via ``dataclasses.asdict``, so every
    field is kept to a plain JSON-friendly type.
    """
    # Feature width of the `context` tensor; first argument of the first ISAB encoder block.
    dim_context: int
    # Feature width of the query tensor `x`; in_features of `SDFTransformer.input_proj`.
    dim_input: int
    # Third positional argument of the PMA decoder block
    # (presumably the number of pooled output vectors - confirm in src.models.modules).
    num_outputs: int
    # out_features of the last Linear layer in the regression head.
    dim_output: int
    # Third argument passed to L1_clamp_loss (presumably the clamping distance - TODO confirm).
    delta: float = 0.1
    # Fourth positional argument of both ISAB blocks
    # (presumably the number of inducing points - confirm in src.models.modules).
    num_inds: int = 32
    # Hidden width shared by the encoder, cross block, decoder and regression head.
    dim_hidden: int = 128
    # Head count handed to every ISAB / MAB / PMA / SAB block.
    num_heads: int = 4
    # Forwarded as `ln=` to every ISAB / MAB / PMA / SAB block
    # (presumably toggles layer normalisation - confirm in src.models.modules).
    ln: bool = False

class SDFTransformer(nn.Module):
    """Regression network assembled from the ISAB / MAB / PMA / SAB blocks.

    Pipeline, in the order the sub-modules are applied in ``forward``:
    ``enc`` (two ISAB blocks) on the context, ``input_proj`` on the query,
    ``cross`` (one MAB) between the two, ``dec`` (PMA followed by two SABs)
    and finally ``regr`` (a four-layer MLP ending in ``Tanh``).
    """

    def __init__(self, config: SDFTransformerConfig):
        """Build all sub-modules from ``config`` (kept on ``self.config``)."""
        super().__init__()
        self.config = config

        # Short aliases for the values that nearly every block needs.
        hidden = config.dim_hidden
        heads = config.num_heads
        inds = config.num_inds
        use_ln = config.ln

        # NOTE: sub-modules are created in the same order as before so that
        # seeded weight initialisation and state_dict keys stay unchanged.
        self.enc = nn.Sequential(
            ISAB(config.dim_context, hidden, heads, inds, ln=use_ln),
            ISAB(hidden, hidden, heads, inds, ln=use_ln),
        )
        self.input_proj = nn.Linear(config.dim_input, hidden)
        self.cross = MAB(hidden, hidden, hidden, heads, ln=use_ln)
        self.dec = nn.Sequential(
            PMA(hidden, heads, config.num_outputs, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
        )

        # Regression head: three hidden Linear+ReLU stages, then a Linear to
        # dim_output squashed by Tanh (so outputs lie in [-1, 1]).
        head_layers = []
        for _ in range(3):
            head_layers.append(nn.Linear(hidden, hidden))
            head_layers.append(nn.ReLU())
        head_layers.append(nn.Linear(hidden, config.dim_output))
        head_layers.append(nn.Tanh())
        self.regr = nn.Sequential(*head_layers)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the network and, when ``labels`` is given, compute the loss.

        Args:
            context: tensor fed to the encoder. The original author's notes
                give the encoder output as [batch_size, context_size,
                dim_hidden] - input layout not visible here, TODO confirm.
            x: query tensor; projected to ``dim_hidden`` and then tiled along
                dim 1 once per encoded context element.
            labels: optional targets; when omitted no loss is computed.

        Returns:
            dict with ``'loss'`` (``None`` without labels, otherwise
            ``L1_clamp_loss(logits, labels, config.delta)``) and ``'logits'``
            (the output of the regression head).
        """
        encoded = self.enc(context)
        queries = self.input_proj(x)
        # Tile the projected query along dim 1, once per encoded element.
        # With a single query per sample this yields the same length as the
        # encoded context - assumes x.shape[1] == 1, TODO confirm.
        queries = queries.repeat(1, encoded.shape[1], 1)

        attended = self.cross(queries, encoded)
        logits = self.regr(self.dec(attended))

        if labels is None:
            loss = None
        else:
            loss = L1_clamp_loss(logits, labels, self.config.delta)
        return {'loss': loss, 'logits': logits}

# Seed PyTorch's RNG before the model is constructed.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Only the four required fields are set; everything else uses the dataclass defaults.
config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [8]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# Repository root: parent of the directory the notebook is executed from.
project_dir = Path.cwd().resolve().parent
processed_dir = project_dir / 'data' / 'processed'

# Directory traversal order depends on the filesystem, so sort the matches:
# the files are then handed to the dataset in the same order on every
# machine and every run, which the seeded training below relies on.
train_files = sorted(processed_dir.rglob('*_train.hdf5'))
val_files = sorted(processed_dir.rglob('*_val.hdf5'))

# Fail here with a readable message instead of deep inside the dataset or,
# worse, at the first evaluation after a full training epoch.
if not train_files:
    raise FileNotFoundError(f"No '*_train.hdf5' files found under {processed_dir}")
if not val_files:
    raise FileNotFoundError(f"No '*_val.hdf5' files found under {processed_dir}")

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [9]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Every artefact of this run lands in <results>/<notebook_name>-<YYYY-MM-DD>.
notebook_name = '2024_11_11_set_transformer'
current_date = f"{datetime.now():%Y-%m-%d}"
folder_name = notebook_name + "-" + current_date

result_dir = get_results_dir() / folder_name
# Re-running on the same day reuses the existing folder instead of failing.
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_11_set_transformer-2024-11-11


In [10]:
from transformers import Trainer, TrainingArguments

# Shared by training and evaluation here, and reused later for mesh generation.
batch_size = 40

training_args = TrainingArguments(
    # --- output locations (both inside this run's result folder) ---
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    # --- schedule: 5 epochs, evaluate once per epoch, log every 10 steps ---
    num_train_epochs=5,
    eval_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    # --- optimisation ---
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    seed=42,
)

# The model's forward() returns a dict with a 'loss' entry, which is what the
# Trainer optimises; no custom loss or collator is passed here.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [11]:
# Run the training loop with the Trainer built in the previous cell. The
# returned TrainOutput (global step, mean training loss, runtime metrics) is
# the cell's last expression, so it is displayed below the log lines.
trainer.train()

  0%|          | 0/90000 [00:00<?, ?it/s]

{'loss': 0.0187, 'grad_norm': 1.239774465560913, 'learning_rate': 4.9994444444444446e-05, 'epoch': 0.0}
{'loss': 0.0087, 'grad_norm': 0.7377306222915649, 'learning_rate': 4.998888888888889e-05, 'epoch': 0.0}
{'loss': 0.0091, 'grad_norm': 0.7301199436187744, 'learning_rate': 4.998333333333334e-05, 'epoch': 0.0}
{'loss': 0.0071, 'grad_norm': 0.8633569478988647, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0057, 'grad_norm': 0.6636713743209839, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.006, 'grad_norm': 0.5335493087768555, 'learning_rate': 4.996666666666667e-05, 'epoch': 0.0}
{'loss': 0.0059, 'grad_norm': 0.9268941879272461, 'learning_rate': 4.9961111111111114e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 0.596102237701416, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0069, 'grad_norm': 0.07587701082229614, 'learning_rate': 4.995e-05, 'epoch': 0.01}
{'loss': 0.0058, 'grad_norm': 0.264676958322525, 'learning_rate': 4.9944444

  0%|          | 0/4503 [00:00<?, ?it/s]

{'eval_loss': 0.004777882713824511, 'eval_runtime': 117.5172, 'eval_samples_per_second': 1532.507, 'eval_steps_per_second': 38.318, 'epoch': 1.0}
{'loss': 0.0058, 'grad_norm': 0.48081842064857483, 'learning_rate': 3.999444444444445e-05, 'epoch': 1.0}
{'loss': 0.0055, 'grad_norm': 0.660835862159729, 'learning_rate': 3.998888888888889e-05, 'epoch': 1.0}
{'loss': 0.0052, 'grad_norm': 0.14352993667125702, 'learning_rate': 3.9983333333333334e-05, 'epoch': 1.0}
{'loss': 0.0051, 'grad_norm': 0.32098639011383057, 'learning_rate': 3.997777777777778e-05, 'epoch': 1.0}
{'loss': 0.005, 'grad_norm': 0.061732593923807144, 'learning_rate': 3.997222222222222e-05, 'epoch': 1.0}
{'loss': 0.0053, 'grad_norm': 0.24115508794784546, 'learning_rate': 3.996666666666667e-05, 'epoch': 1.0}
{'loss': 0.0044, 'grad_norm': 0.11292140930891037, 'learning_rate': 3.996111111111111e-05, 'epoch': 1.0}
{'loss': 0.0055, 'grad_norm': 0.6018632650375366, 'learning_rate': 3.995555555555556e-05, 'epoch': 1.0}
{'loss': 0.0046,

  0%|          | 0/4503 [00:00<?, ?it/s]

{'eval_loss': 0.003455488011240959, 'eval_runtime': 116.5293, 'eval_samples_per_second': 1545.5, 'eval_steps_per_second': 38.643, 'epoch': 2.0}
{'loss': 0.0051, 'grad_norm': 0.03370257094502449, 'learning_rate': 2.9994444444444448e-05, 'epoch': 2.0}
{'loss': 0.0028, 'grad_norm': 0.06067823991179466, 'learning_rate': 2.9988888888888888e-05, 'epoch': 2.0}
{'loss': 0.0033, 'grad_norm': 0.17996826767921448, 'learning_rate': 2.9983333333333335e-05, 'epoch': 2.0}
{'loss': 0.0037, 'grad_norm': 0.35691550374031067, 'learning_rate': 2.997777777777778e-05, 'epoch': 2.0}
{'loss': 0.0041, 'grad_norm': 0.3199816346168518, 'learning_rate': 2.9972222222222225e-05, 'epoch': 2.0}
{'loss': 0.003, 'grad_norm': 0.3896274268627167, 'learning_rate': 2.9966666666666672e-05, 'epoch': 2.0}
{'loss': 0.0039, 'grad_norm': 0.21552802622318268, 'learning_rate': 2.9961111111111112e-05, 'epoch': 2.0}
{'loss': 0.0036, 'grad_norm': 0.5770612955093384, 'learning_rate': 2.995555555555556e-05, 'epoch': 2.0}
{'loss': 0.003

  0%|          | 0/4503 [00:00<?, ?it/s]

{'eval_loss': 0.003135876962915063, 'eval_runtime': 114.315, 'eval_samples_per_second': 1575.436, 'eval_steps_per_second': 39.391, 'epoch': 3.0}
{'loss': 0.0032, 'grad_norm': 0.11939902603626251, 'learning_rate': 1.9994444444444445e-05, 'epoch': 3.0}
{'loss': 0.0028, 'grad_norm': 0.29717957973480225, 'learning_rate': 1.998888888888889e-05, 'epoch': 3.0}
{'loss': 0.0016, 'grad_norm': 0.009218208491802216, 'learning_rate': 1.9983333333333336e-05, 'epoch': 3.0}
{'loss': 0.0038, 'grad_norm': 0.29685819149017334, 'learning_rate': 1.997777777777778e-05, 'epoch': 3.0}
{'loss': 0.0035, 'grad_norm': 0.14982913434505463, 'learning_rate': 1.9972222222222223e-05, 'epoch': 3.0}
{'loss': 0.0021, 'grad_norm': 0.12164395302534103, 'learning_rate': 1.9966666666666666e-05, 'epoch': 3.0}
{'loss': 0.003, 'grad_norm': 0.11885754019021988, 'learning_rate': 1.996111111111111e-05, 'epoch': 3.0}
{'loss': 0.0036, 'grad_norm': 0.3633973002433777, 'learning_rate': 1.9955555555555557e-05, 'epoch': 3.0}
{'loss': 0.

  0%|          | 0/4503 [00:00<?, ?it/s]

{'eval_loss': 0.0027346634306013584, 'eval_runtime': 118.5158, 'eval_samples_per_second': 1519.595, 'eval_steps_per_second': 37.995, 'epoch': 4.0}
{'loss': 0.0019, 'grad_norm': 0.2773076295852661, 'learning_rate': 9.994444444444444e-06, 'epoch': 4.0}
{'loss': 0.0019, 'grad_norm': 0.11902415007352829, 'learning_rate': 9.98888888888889e-06, 'epoch': 4.0}
{'loss': 0.0034, 'grad_norm': 0.2866610884666443, 'learning_rate': 9.983333333333333e-06, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.2551330029964447, 'learning_rate': 9.977777777777778e-06, 'epoch': 4.0}
{'loss': 0.0035, 'grad_norm': 0.062253549695014954, 'learning_rate': 9.972222222222224e-06, 'epoch': 4.0}
{'loss': 0.0035, 'grad_norm': 0.3464434742927551, 'learning_rate': 9.966666666666667e-06, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.03199583664536476, 'learning_rate': 9.96111111111111e-06, 'epoch': 4.0}
{'loss': 0.0051, 'grad_norm': 0.24499431252479553, 'learning_rate': 9.955555555555556e-06, 'epoch': 4.0}
{'loss': 0.0038, 'g

  0%|          | 0/4503 [00:00<?, ?it/s]

{'eval_loss': 0.002559139858931303, 'eval_runtime': 124.8817, 'eval_samples_per_second': 1442.132, 'eval_steps_per_second': 36.058, 'epoch': 5.0}
{'train_runtime': 4557.2326, 'train_samples_per_second': 789.953, 'train_steps_per_second': 19.749, 'train_loss': 0.0038134187745137347, 'epoch': 5.0}


TrainOutput(global_step=90000, training_loss=0.0038134187745137347, metrics={'train_runtime': 4557.2326, 'train_samples_per_second': 789.953, 'train_steps_per_second': 19.749, 'total_flos': 0.0, 'train_loss': 0.0038134187745137347, 'epoch': 5.0})

In [12]:
import json
from dataclasses import asdict

# The date stamp is recomputed at save time, so it may differ from the date in
# the result folder's name if training ran past midnight.
current_date = f"{datetime.now():%Y-%m-%d}"
model_name = current_date + "-model"
config_name = current_date + "-config.json"

# Save the trained model through the Trainer into this run's result folder.
trainer.save_model(result_dir / model_name)

# Store the hyper-parameters as JSON next to the model so it can be rebuilt.
with (result_dir / config_name).open('w') as config_file:
    config_file.write(json.dumps(asdict(config)))

In [13]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Source mesh passed to generate_mesh, and the path it is given for the mesh
# inferred by the model (presumably written there - see
# src.visualization.generate_mesh for the exact contract).
obj_path = get_data_dir() / 'intermediate' / 'bunny' / 'stanford-bunny.obj'
inferred_obj_path = result_dir / (current_date + ".obj")

# Reuses the training batch size for inference.
generate_mesh(
    model,
    obj_path,
    inferred_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [14]:
import importlib
import src.visualization.generate_mesh
# Re-execute the module so edits to generate_mesh.py are picked up without
# restarting the kernel (and losing the trained model).
importlib.reload(src.visualization.generate_mesh)

# reload() updates the module object in place, but a name that was bound with
# `from ... import generate_mesh` still points at the old function object.
# Rebind it so later calls in this notebook use the freshly loaded code.
from src.visualization.generate_mesh import generate_mesh

# Display the reloaded module (shows which file it was loaded from).
src.visualization.generate_mesh

<module 'src.visualization.generate_mesh' from 'c:\\_prog\\vm_shared\\attention-sdf\\src\\visualization\\generate_mesh.py'>