In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon
from dataclasses import dataclass
from torch.functional import F
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int = 4
    dim_input: int = 3
    num_outputs: int = 1
    dim_output: int = 1
    delta: float = 0.1
    dim_hidden: int = 128
    num_ctx_seeds: int = 32
    num_x_seeds: int = 32
    num_heads: int = 1

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.proj_x = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj_ctx = nn.Linear(config.dim_context, config.dim_hidden)
        self.pool_ctx = PMA(config.dim_hidden, config.num_heads, config.num_ctx_seeds)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            PMA(config.dim_hidden, config.num_heads, config.num_outputs),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_hidden),
            nn.ReLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = self.proj_x(x)                                  # [B, 1, H]
        x = x.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        y = self.pool_ctx(self.proj_ctx(context))           # [B, Y, H]
        xy = self.cross(x, y)                               # [B, X, H]
        xy = xy + x                                         # [B, X, H]
        xy = self.dec(xy)

        loss = None
        if labels is not None:
            loss = L1_epsilon(xy, labels, self.epsilon, self.config.delta)
        return {'loss': loss, 'logits': xy}

config = SDFTransformerConfig()
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2025_01_20_L_epsilon'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_20_L_epsilon-2025-01-20-15-05-38


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 64
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    'learning_rate': 1e-5, 'resolution': 256},
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=256)
train_dataset.close()
val_dataset.close()

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0036, 'grad_norm': 0.03864193335175514, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.021532291546463966, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.0, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0032, 'grad_norm': 0.1723002940416336, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.07762668281793594, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0042, 'grad_norm': 0.040684618055820465, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.013352339155972004, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0041, 'grad_norm': 0.07696598768234253, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.09507149457931519, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0042, 'grad_norm': 0.04015125334262848, 'learning_ra

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00027321826200932264, 'eval_runtime': 102.0436, 'eval_samples_per_second': 1469.96, 'eval_steps_per_second': 22.971, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.07110368460416794, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.0435212105512619, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.02861190028488636, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.0562865324318409, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.05618368834257126, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.06423118710517883, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.0, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0002, 'grad_norm': 0.08854155242443085, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0001344066549791023, 'eval_runtime': 101.2993, 'eval_samples_per_second': 1480.76, 'eval_steps_per_second': 23.139, 'epoch': 2.0}
{'train_runtime': 1672.7702, 'train_samples_per_second': 896.716, 'train_steps_per_second': 14.011, 'train_loss': 0.0006224449491205313, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0039, 'grad_norm': 0.7119138836860657, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0057, 'grad_norm': 0.7144072651863098, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.6308208703994751, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.34049853682518005, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.20532654225826263, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0016, 'grad_norm': 0.24118825793266296, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0014, 'grad_norm': 0.44205325841903687, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.002, 'grad_norm': 0.40371665358543396, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.30726295709609985, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0014, 'grad_norm': 0.13199633359909058, 'l

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0003830253263004124, 'eval_runtime': 115.602, 'eval_samples_per_second': 1297.555, 'eval_steps_per_second': 20.276, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.11466019600629807, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.097388356924057, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.07467695325613022, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.13530464470386505, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.08897264301776886, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.08422817289829254, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.09511748701334, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0003, 'grad_norm': 0.11041709035634995, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0002496231463737786, 'eval_runtime': 103.2005, 'eval_samples_per_second': 1453.481, 'eval_steps_per_second': 22.713, 'epoch': 2.0}
{'train_runtime': 1867.4884, 'train_samples_per_second': 803.218, 'train_steps_per_second': 12.551, 'train_loss': 0.00042445712552046944, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0022, 'grad_norm': 1.0858551263809204, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.4486837387084961, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0025, 'grad_norm': 0.46062251925468445, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0026, 'grad_norm': 0.5175108909606934, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.6645110845565796, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0022, 'grad_norm': 0.458798348903656, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.38690152764320374, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.319939523935318, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0014, 'grad_norm': 0.6521835327148438, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0018, 'grad_norm': 0.5160829424858093, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00044778233859688044, 'eval_runtime': 104.7199, 'eval_samples_per_second': 1432.392, 'eval_steps_per_second': 22.384, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.12372158467769623, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.0814693421125412, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.06535962969064713, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.1478525996208191, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.11447437107563019, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.08184505999088287, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.08121153712272644, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_norm': 0.08427944779396057, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.000314940232783556, 'eval_runtime': 100.958, 'eval_samples_per_second': 1485.766, 'eval_steps_per_second': 23.218, 'epoch': 2.0}
{'train_runtime': 1757.2312, 'train_samples_per_second': 853.616, 'train_steps_per_second': 13.338, 'train_loss': 0.000469636406198116, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0013, 'grad_norm': 0.48977404832839966, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0022, 'grad_norm': 0.4372447729110718, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.37552350759506226, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.003, 'grad_norm': 0.7630776762962341, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0029, 'grad_norm': 0.3329988420009613, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0029, 'grad_norm': 0.523750364780426, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.8352086544036865, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0017, 'grad_norm': 0.24219970405101776, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0017, 'grad_norm': 0.2791759967803955, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0019, 'grad_norm': 0.31317317485809326, 'learn

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0007001618505455554, 'eval_runtime': 105.5894, 'eval_samples_per_second': 1420.598, 'eval_steps_per_second': 22.199, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.2731851637363434, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.08025228977203369, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.09422027319669724, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.18294377624988556, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.24052882194519043, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.15587671101093292, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.1592019647359848, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0005, 'grad_norm': 0.14560195803642273, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss'

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00046045196359045804, 'eval_runtime': 119.3661, 'eval_samples_per_second': 1256.638, 'eval_steps_per_second': 19.637, 'epoch': 2.0}
{'train_runtime': 1956.0607, 'train_samples_per_second': 766.847, 'train_steps_per_second': 11.982, 'train_loss': 0.0006486485632391487, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.003, 'grad_norm': 0.41412198543548584, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0038, 'grad_norm': 0.666930615901947, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.6466337442398071, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.004, 'grad_norm': 0.27089622616767883, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.003, 'grad_norm': 0.2861522138118744, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.26824167370796204, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0027, 'grad_norm': 0.21653331816196442, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.2786013185977936, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.3453538119792938, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0027, 'grad_norm': 0.20090137422084808, 'learni

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0017469769809395075, 'eval_runtime': 107.8504, 'eval_samples_per_second': 1390.816, 'eval_steps_per_second': 21.734, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.30806729197502136, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0017, 'grad_norm': 0.26281818747520447, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.25280576944351196, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.2749772071838379, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.3625491261482239, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.10507752001285553, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0017, 'grad_norm': 0.30172619223594666, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0015, 'grad_norm': 0.3182053864002228, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss':

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.001318948925472796, 'eval_runtime': 109.3625, 'eval_samples_per_second': 1371.585, 'eval_steps_per_second': 21.433, 'epoch': 2.0}
{'train_runtime': 1971.557, 'train_samples_per_second': 760.82, 'train_steps_per_second': 11.888, 'train_loss': 0.0016155595075326682, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)