In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int = 4
    dim_input: int = 3
    num_outputs: int = 1
    dim_output: int = 1
    delta: float = 0.1
    dim_hidden: int = 128
    num_ctx_seeds: int = 32
    num_x_seeds: int = 32
    num_heads: int = 4

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.proj_x = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj_ctx = nn.Linear(config.dim_context, config.dim_hidden)
        self.pool_ctx = PMA(config.dim_hidden, config.num_heads, config.num_ctx_seeds)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.sab1 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.sab2 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.pool_final = nn.AdaptiveAvgPool1d(config.num_outputs)
        self.proj_final = nn.Linear(config.dim_hidden, config.dim_output)

        self.silu = nn.SiLU()

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = self.proj_x(x)                                  # [B, 1, H]
        x = x.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        y = self.pool_ctx(self.proj_ctx(context))           # [B, Y, H]
        xy = self.cross(x, y)                               # [B, X, H]
        xy = xy + x                                         # [B, X, H]
        xy = self.silu(self.sab1(xy))                       # [B, X, H]
        xy = self.silu(self.sab2(xy))                       # [B, X, H]
        xy = xy.permute(0, 2, 1)                            # [B, H, X]
        xy = self.pool_final(xy)                            # [B, H, F]
        xy = xy.permute(0, 2, 1)                            # [B, F, H]
        xy = self.proj_final(xy)                            # [B, F, O]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(xy, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': xy}

config = SDFTransformerConfig()
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2025_01_17_linear_pool'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_17_linear_pool-2025-01-17-15-25-13


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 64
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5, 'resolution': 256}
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=256)
train_dataset.close()
val_dataset.close()

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0107, 'grad_norm': 0.2368687242269516, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0071, 'grad_norm': 0.6764666438102722, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.004, 'grad_norm': 0.2363852858543396, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.06197431683540344, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.2931247055530548, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0046, 'grad_norm': 0.41990432143211365, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.0744590163230896, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0044, 'grad_norm': 0.17291218042373657, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.003, 'grad_norm': 0.19013716280460358, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0045, 'grad_norm': 0.08347204327583313, 'lear

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0012235671747475863, 'eval_runtime': 102.3947, 'eval_samples_per_second': 1464.919, 'eval_steps_per_second': 22.892, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.12317471206188202, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.11186361312866211, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.04912436380982399, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.18640238046646118, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.08866454660892487, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.0, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.06577179580926895, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.20276540517807007, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0015, 'grad

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0008529024198651314, 'eval_runtime': 103.4448, 'eval_samples_per_second': 1450.049, 'eval_steps_per_second': 22.659, 'epoch': 2.0}
{'train_runtime': 1672.9912, 'train_samples_per_second': 896.598, 'train_steps_per_second': 14.01, 'train_loss': 0.0014186613744288322, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0061, 'grad_norm': 1.4955447912216187, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0066, 'grad_norm': 0.7177596688270569, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 1.1489784717559814, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.8052846193313599, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0039, 'grad_norm': 0.5788162350654602, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.46000775694847107, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0023, 'grad_norm': 0.17737837135791779, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0033, 'grad_norm': 0.7390291094779968, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0034, 'grad_norm': 0.7998262047767639, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0034, 'grad_norm': 1.0661702156066895, 'learn

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0012623862130567431, 'eval_runtime': 106.0411, 'eval_samples_per_second': 1414.546, 'eval_steps_per_second': 22.105, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.34667107462882996, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.23841606080532074, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.1564759761095047, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.440687894821167, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.2744790017604828, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.11733964085578918, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.208428755402565, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.2153368890285492, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.00

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0007675801170989871, 'eval_runtime': 103.9196, 'eval_samples_per_second': 1443.423, 'eval_steps_per_second': 22.556, 'epoch': 2.0}
{'train_runtime': 1783.0226, 'train_samples_per_second': 841.268, 'train_steps_per_second': 13.145, 'train_loss': 0.0012738748032916618, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0051, 'grad_norm': 0.9935641288757324, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0074, 'grad_norm': 1.6314836740493774, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0094, 'grad_norm': 1.4370635747909546, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0086, 'grad_norm': 1.4049928188323975, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0066, 'grad_norm': 0.7450999617576599, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0046, 'grad_norm': 1.0614683628082275, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0043, 'grad_norm': 2.4709579944610596, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0066, 'grad_norm': 2.207554340362549, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0062, 'grad_norm': 1.4932390451431274, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0053, 'grad_norm': 0.8718714118003845, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0013660972472280264, 'eval_runtime': 106.0172, 'eval_samples_per_second': 1414.865, 'eval_steps_per_second': 22.11, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.6023945808410645, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.25636380910873413, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.6304897665977478, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0015, 'grad_norm': 0.4834327697753906, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0015, 'grad_norm': 0.43618300557136536, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0015, 'grad_norm': 1.0669609308242798, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.26953622698783875, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0012, 'grad_norm': 0.4035002887248993, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.000897457473911345, 'eval_runtime': 107.6383, 'eval_samples_per_second': 1393.556, 'eval_steps_per_second': 21.777, 'epoch': 2.0}
{'train_runtime': 1764.0265, 'train_samples_per_second': 850.327, 'train_steps_per_second': 13.287, 'train_loss': 0.001336941752676644, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0039, 'grad_norm': 0.8673357367515564, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0048, 'grad_norm': 0.9976819753646851, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0045, 'grad_norm': 1.1128045320510864, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0051, 'grad_norm': 1.4304403066635132, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0054, 'grad_norm': 1.0114601850509644, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0062, 'grad_norm': 1.2489769458770752, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0047, 'grad_norm': 1.137641191482544, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0059, 'grad_norm': 1.0346568822860718, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0044, 'grad_norm': 1.1676950454711914, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0046, 'grad_norm': 1.0253239870071411, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0024351852480322123, 'eval_runtime': 111.209, 'eval_samples_per_second': 1348.812, 'eval_steps_per_second': 21.077, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 1.085972785949707, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.31803494691848755, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.372996985912323, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.6695573329925537, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.8067996501922607, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0017, 'grad_norm': 1.1370949745178223, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.38339632749557495, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0015, 'grad_norm': 0.43963032960891724, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.00

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0011903065023943782, 'eval_runtime': 108.8937, 'eval_samples_per_second': 1377.49, 'eval_steps_per_second': 21.526, 'epoch': 2.0}
{'train_runtime': 1811.2079, 'train_samples_per_second': 828.177, 'train_steps_per_second': 12.941, 'train_loss': 0.0017127697508431204, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0047, 'grad_norm': 2.225177049636841, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0055, 'grad_norm': 1.035247564315796, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0046, 'grad_norm': 1.0986747741699219, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 1.4892646074295044, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0051, 'grad_norm': 0.4717409014701843, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 1.633090615272522, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0052, 'grad_norm': 0.8207352757453918, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0054, 'grad_norm': 1.1879756450653076, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0051, 'grad_norm': 0.7891991138458252, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0051, 'grad_norm': 0.7020358443260193, 'learning_r

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.003651614999398589, 'eval_runtime': 110.5184, 'eval_samples_per_second': 1357.24, 'eval_steps_per_second': 21.209, 'epoch': 1.0}
{'loss': 0.0035, 'grad_norm': 1.053330659866333, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.7420674562454224, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.7744703888893127, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0034, 'grad_norm': 0.5555905699729919, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0032, 'grad_norm': 0.9988405108451843, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0033, 'grad_norm': 0.768137514591217, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0032, 'grad_norm': 1.4600985050201416, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0033, 'grad_norm': 1.2168679237365723, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0036, 

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.002517012180760503, 'eval_runtime': 108.9419, 'eval_samples_per_second': 1376.881, 'eval_steps_per_second': 21.516, 'epoch': 2.0}
{'train_runtime': 1841.0223, 'train_samples_per_second': 814.765, 'train_steps_per_second': 12.731, 'train_loss': 0.0033077353088509275, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [8]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)