In [1]:
import os
import sys
# Put the parent of the working directory on the import path so that the
# project's `src` package can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
from torch.functional import F
import torch

# Seed PyTorch's RNG first so everything created afterwards is repeatable, then
# pick the GPU when one is visible and fall back to the CPU otherwise.
torch.manual_seed(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@dataclass
class SDFTransformerConfig:
    """Hyperparameters consumed by ``SDFTransformer``.

    Every field has a default, so ``SDFTransformerConfig()`` yields a complete
    configuration.
    """
    # Size of the last dimension of `context`; input width of the context projection.
    dim_context: int = 4
    # Size of the last dimension of `x`; input width of the query projection.
    dim_input: int = 3
    # Third argument of the decoder's PMA block
    # (presumably the number of pooled output vectors; confirm against PMA).
    num_outputs: int = 1
    # Output width of the decoder's final linear layer.
    dim_output: int = 1
    # Forwarded unchanged to L1_epsilon_lambda as its last argument.
    delta: float = 0.1
    # Width shared by every projection and attention block in the model.
    dim_hidden: int = 128
    # Third argument of the context-pooling PMA
    # (presumably its seed count; confirm against PMA).
    num_ctx_seeds: int = 32
    # Size the projected `x` is expanded to along dimension 1 in forward().
    num_x_seeds: int = 32
    # Passed to every PMA / MAB / SAB block the model builds.
    num_heads: int = 1

class SDFTransformer(nn.Module):
    """Set-attention network mapping a context set and a query input to an output.

    The query is projected and replicated, the context is projected and pooled,
    both are combined by a cross block with a residual connection, and a decoder
    stack ending in ``tanh`` produces the result.

    ``epsilon`` and ``lambdaa`` start as ``None`` and are forwarded to
    ``L1_epsilon_lambda``; they are meant to be assigned on the instance before
    ``forward`` is called with ``labels``.
    """

    def __init__(self, config: SDFTransformerConfig):
        super().__init__()
        hidden = config.dim_hidden
        heads = config.num_heads

        self.config = config
        # Loss hyperparameters live on the instance rather than in the config.
        self.epsilon = None
        self.lambdaa = None

        # Sub-modules are created in the same order as before so that seeded
        # parameter initialisation is unaffected.
        self.proj_x = nn.Linear(config.dim_input, hidden)
        self.proj_ctx = nn.Linear(config.dim_context, hidden)
        self.pool_ctx = PMA(hidden, heads, config.num_ctx_seeds)
        self.cross = MAB(hidden, hidden, hidden, heads)

        decoder_layers = [
            SAB(hidden, hidden, heads),
            nn.SiLU(),
            SAB(hidden, hidden, heads),
            nn.SiLU(),
            PMA(hidden, heads, config.num_outputs),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.dim_output),
            nn.Tanh(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the model and, when ``labels`` is given, compute the loss.

        Args:
            context: Context tensor whose last dimension is ``config.dim_context``.
            x: Query tensor whose last dimension is ``config.dim_input``; its
                dimension 1 must be expandable to ``config.num_x_seeds``.
            labels: Optional targets passed to ``L1_epsilon_lambda``.

        Returns:
            A dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'`` (the decoder output).
        """
        query = self.proj_x(x)                                      # [B, 1, H]
        query = query.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        pooled_ctx = self.pool_ctx(self.proj_ctx(context))          # [B, Y, H]

        # Cross block output plus a residual connection back to the query.
        fused = self.cross(query, pooled_ctx) + query               # [B, X, H]
        prediction = self.dec(fused)

        if labels is None:
            loss = None
        else:
            loss = L1_epsilon_lambda(
                prediction, labels, self.epsilon, self.lambdaa, self.config.delta
            )
        return {'loss': loss, 'logits': prediction}

# Build the model from the default hyperparameters, move it to the selected
# device, and report which device that is.
config = SDFTransformerConfig()
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# Locate <project>/data/processed, taking the project root to be the parent of
# the kernel's working directory.
project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

# Sort the matches: rglob yields paths in filesystem order, which can differ
# between machines, so sorting keeps the file order handed to the datasets
# reproducible from run to run.
train_files = sorted(procesed_dir.rglob('*_train.hdf5'))
val_files = sorted(procesed_dir.rglob('*_val.hdf5'))

# Fail fast and name the searched directory instead of silently building an
# empty dataset (rglob on a missing directory simply yields nothing).
if not train_files:
    raise FileNotFoundError(f"No '*_train.hdf5' files found under {procesed_dir}")
if not val_files:
    raise FileNotFoundError(f"No '*_val.hdf5' files found under {procesed_dir}")

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Each run writes into its own folder: <results>/<notebook name>-<timestamp>.
notebook_name = '2025_01_22_no_curriculum'
current_date = format(datetime.now(), "%Y-%m-%d-%H-%M-%S")
folder_name = notebook_name + "-" + current_date
result_dir = get_results_dir() / folder_name
# Create the folder (and any missing parents); an existing folder is not an error.
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_22_no_curriculum-2025-01-23-07-46-53


In [5]:
from transformers import Trainer, TrainingArguments

# Per-device batch size, shared by training and evaluation.
batch_size = 64

# Hugging Face training configuration; checkpoints and logs are written below
# result_dir.
training_args = TrainingArguments(
    # Output locations
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    # Batching
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Run length, evaluation / logging cadence, checkpoint retention
    num_train_epochs=1,
    eval_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    # Regularisation and reproducibility
    weight_decay=0.01,
    seed=42,
)

trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# One entry per training stage: epoch count, loss parameters (epsilon, lambda),
# learning rate, and the resolution used for mesh generation.
curriculum_schedule = [
    {
        "epochs": 10,
        "epsilon": 0.0,
        "lambda": 0.0,
        "learning_rate": 5e-5,
        "resolution": 256,
    },
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
# "{name}" is left as a literal placeholder inside the string handed to generate_meshes.
format_string_base = "{name}-" + current_date + "-curriculum-"

# Train every curriculum stage in order and export meshes after each one.
# The datasets are closed in `finally` so they are released even when training
# or mesh generation raises (or the cell is interrupted).
try:
    for i, stage in enumerate(curriculum_schedule):
        # The model reads these two attributes when it computes its loss.
        model.epsilon = stage['epsilon']
        model.lambdaa = stage['lambda']
        trainer.args.num_train_epochs = stage['epochs']
        # NOTE(review): on a second call to train() the Trainer appears to reuse
        # the optimizer it already created, so a different learning_rate in a
        # later stage may not take effect. Confirm before adding more stages.
        trainer.args.learning_rate = stage['learning_rate']
        trainer.train()
        format_string = f"{format_string_base}{i}.obj"
        generate_meshes(model, obj_dir, result_dir, format_string, device,
            batch_size, resolution=stage['resolution'], context_size=256)
finally:
    train_dataset.close()
    val_dataset.close()

  0%|          | 0/117190 [00:00<?, ?it/s]

{'loss': 0.0108, 'grad_norm': 0.34355127811431885, 'learning_rate': 4.999573342435362e-05, 'epoch': 0.0}
{'loss': 0.0078, 'grad_norm': 0.9027493000030518, 'learning_rate': 4.999146684870723e-05, 'epoch': 0.0}
{'loss': 0.0064, 'grad_norm': 0.7875605821609497, 'learning_rate': 4.998720027306084e-05, 'epoch': 0.0}
{'loss': 0.0062, 'grad_norm': 0.11942391842603683, 'learning_rate': 4.998293369741446e-05, 'epoch': 0.0}
{'loss': 0.0069, 'grad_norm': 0.4873679578304291, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0076, 'grad_norm': 0.6002193689346313, 'learning_rate': 4.997440054612169e-05, 'epoch': 0.01}
{'loss': 0.0052, 'grad_norm': 0.33612123131752014, 'learning_rate': 4.9970133970475295e-05, 'epoch': 0.01}
{'loss': 0.0075, 'grad_norm': 0.1183251217007637, 'learning_rate': 4.9965867394828916e-05, 'epoch': 0.01}
{'loss': 0.0055, 'grad_norm': 0.030841905623674393, 'learning_rate': 4.9961600819182524e-05, 'epoch': 0.01}
{'loss': 0.0075, 'grad_norm': 0.05888752639293671, '

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.003118994180113077, 'eval_runtime': 98.3995, 'eval_samples_per_second': 1524.398, 'eval_steps_per_second': 23.821, 'epoch': 1.0}
{'loss': 0.0032, 'grad_norm': 0.046507928520441055, 'learning_rate': 4.4999573342435364e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.27737534046173096, 'learning_rate': 4.499530676678897e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.19756516814231873, 'learning_rate': 4.499104019114259e-05, 'epoch': 1.0}
{'loss': 0.0033, 'grad_norm': 0.1946389079093933, 'learning_rate': 4.49867736154962e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.1391858011484146, 'learning_rate': 4.498250703984982e-05, 'epoch': 1.0}
{'loss': 0.0029, 'grad_norm': 0.31913432478904724, 'learning_rate': 4.497824046420343e-05, 'epoch': 1.0}
{'loss': 0.0032, 'grad_norm': 0.4175320565700531, 'learning_rate': 4.497397388855705e-05, 'epoch': 1.01}
{'loss': 0.0029, 'grad_norm': 0.681201696395874, 'learning_rate': 4.496970731291066e-05, 'epoch': 1.01}
{'loss': 0.0036, 

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.002621446270495653, 'eval_runtime': 103.0947, 'eval_samples_per_second': 1454.973, 'eval_steps_per_second': 22.736, 'epoch': 2.0}
{'loss': 0.0023, 'grad_norm': 0.17209471762180328, 'learning_rate': 3.9999146684870726e-05, 'epoch': 2.0}
{'loss': 0.0025, 'grad_norm': 0.1535900980234146, 'learning_rate': 3.9994880109224334e-05, 'epoch': 2.0}
{'loss': 0.0027, 'grad_norm': 0.13433434069156647, 'learning_rate': 3.9990613533577955e-05, 'epoch': 2.0}
{'loss': 0.0035, 'grad_norm': 0.02677236683666706, 'learning_rate': 3.998634695793156e-05, 'epoch': 2.0}
{'loss': 0.0028, 'grad_norm': 0.3018716871738434, 'learning_rate': 3.998208038228518e-05, 'epoch': 2.0}
{'loss': 0.0022, 'grad_norm': 0.16853326559066772, 'learning_rate': 3.997781380663879e-05, 'epoch': 2.0}
{'loss': 0.0029, 'grad_norm': 0.3556843101978302, 'learning_rate': 3.997354723099241e-05, 'epoch': 2.01}
{'loss': 0.0017, 'grad_norm': 0.28045329451560974, 'learning_rate': 3.996928065534602e-05, 'epoch': 2.01}
{'loss': 0.0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.002071832539513707, 'eval_runtime': 98.2848, 'eval_samples_per_second': 1526.177, 'eval_steps_per_second': 23.849, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm': 0.17079675197601318, 'learning_rate': 3.499872002730609e-05, 'epoch': 3.0}
{'loss': 0.0017, 'grad_norm': 0.16091923415660858, 'learning_rate': 3.4994453451659695e-05, 'epoch': 3.0}
{'loss': 0.0018, 'grad_norm': 0.40460702776908875, 'learning_rate': 3.4990186876013316e-05, 'epoch': 3.0}
{'loss': 0.002, 'grad_norm': 0.4370710551738739, 'learning_rate': 3.4985920300366924e-05, 'epoch': 3.0}
{'loss': 0.0024, 'grad_norm': 0.19679684937000275, 'learning_rate': 3.4981653724720545e-05, 'epoch': 3.0}
{'loss': 0.0022, 'grad_norm': 0.24111050367355347, 'learning_rate': 3.497738714907415e-05, 'epoch': 3.0}
{'loss': 0.0017, 'grad_norm': 0.25863102078437805, 'learning_rate': 3.4973120573427773e-05, 'epoch': 3.01}
{'loss': 0.0023, 'grad_norm': 0.2026134729385376, 'learning_rate': 3.496885399778138e-05, 'epoch': 3.01}
{'loss': 0.

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.002039511688053608, 'eval_runtime': 100.8989, 'eval_samples_per_second': 1486.636, 'eval_steps_per_second': 23.231, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.3225826621055603, 'learning_rate': 2.9998293369741446e-05, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.22357790172100067, 'learning_rate': 2.999402679409506e-05, 'epoch': 4.0}
{'loss': 0.0014, 'grad_norm': 0.13347648084163666, 'learning_rate': 2.9989760218448675e-05, 'epoch': 4.0}
{'loss': 0.0016, 'grad_norm': 0.11706611514091492, 'learning_rate': 2.9985493642802286e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.37065058946609497, 'learning_rate': 2.9981227067155903e-05, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.049810752272605896, 'learning_rate': 2.9976960491509514e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.061001308262348175, 'learning_rate': 2.9972693915863132e-05, 'epoch': 4.01}
{'loss': 0.0021, 'grad_norm': 0.406019002199173, 'learning_rate': 2.9968427340216743e-05, 'epoch': 4.01}
{'loss': 

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0019171599997207522, 'eval_runtime': 99.8359, 'eval_samples_per_second': 1502.465, 'eval_steps_per_second': 23.479, 'epoch': 5.0}
{'loss': 0.002, 'grad_norm': 0.3689191937446594, 'learning_rate': 2.499786671217681e-05, 'epoch': 5.0}
{'loss': 0.0022, 'grad_norm': 0.4799303114414215, 'learning_rate': 2.499360013653042e-05, 'epoch': 5.0}
{'loss': 0.002, 'grad_norm': 0.22163288295269012, 'learning_rate': 2.4989333560884033e-05, 'epoch': 5.0}
{'loss': 0.0023, 'grad_norm': 0.21702170372009277, 'learning_rate': 2.4985066985237648e-05, 'epoch': 5.0}
{'loss': 0.0018, 'grad_norm': 0.32039546966552734, 'learning_rate': 2.4980800409591262e-05, 'epoch': 5.0}
{'loss': 0.0027, 'grad_norm': 0.22097912430763245, 'learning_rate': 2.4976533833944876e-05, 'epoch': 5.0}
{'loss': 0.0025, 'grad_norm': 0.04018888249993324, 'learning_rate': 2.497226725829849e-05, 'epoch': 5.01}
{'loss': 0.002, 'grad_norm': 0.09526459872722626, 'learning_rate': 2.4968000682652105e-05, 'epoch': 5.01}
{'loss': 0.0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0018245121464133263, 'eval_runtime': 100.8093, 'eval_samples_per_second': 1487.958, 'eval_steps_per_second': 23.252, 'epoch': 6.0}
{'loss': 0.0014, 'grad_norm': 0.032219793647527695, 'learning_rate': 1.9997440054612167e-05, 'epoch': 6.0}
{'loss': 0.0025, 'grad_norm': 0.28854188323020935, 'learning_rate': 1.999317347896578e-05, 'epoch': 6.0}
{'loss': 0.0015, 'grad_norm': 0.13179390132427216, 'learning_rate': 1.9988906903319395e-05, 'epoch': 6.0}
{'loss': 0.002, 'grad_norm': 0.10278695076704025, 'learning_rate': 1.998464032767301e-05, 'epoch': 6.0}
{'loss': 0.002, 'grad_norm': 0.12786605954170227, 'learning_rate': 1.9980373752026624e-05, 'epoch': 6.0}
{'loss': 0.0015, 'grad_norm': 0.16653800010681152, 'learning_rate': 1.9976107176380238e-05, 'epoch': 6.0}
{'loss': 0.0013, 'grad_norm': 0.09481091052293777, 'learning_rate': 1.9971840600733852e-05, 'epoch': 6.01}
{'loss': 0.0014, 'grad_norm': 0.06686746329069138, 'learning_rate': 1.9967574025087466e-05, 'epoch': 6.01}
{'loss

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0017822867957875133, 'eval_runtime': 101.3725, 'eval_samples_per_second': 1479.691, 'eval_steps_per_second': 23.123, 'epoch': 7.0}
{'loss': 0.002, 'grad_norm': 0.04649193957448006, 'learning_rate': 1.499701339704753e-05, 'epoch': 7.0}
{'loss': 0.0016, 'grad_norm': 0.24555325508117676, 'learning_rate': 1.4992746821401143e-05, 'epoch': 7.0}
{'loss': 0.0016, 'grad_norm': 0.026532934978604317, 'learning_rate': 1.4988480245754757e-05, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.21801543235778809, 'learning_rate': 1.4984213670108371e-05, 'epoch': 7.0}
{'loss': 0.0017, 'grad_norm': 0.09409346431493759, 'learning_rate': 1.4979947094461986e-05, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.042280081659555435, 'learning_rate': 1.49756805188156e-05, 'epoch': 7.0}
{'loss': 0.0019, 'grad_norm': 0.043570782989263535, 'learning_rate': 1.4971413943169212e-05, 'epoch': 7.01}
{'loss': 0.0012, 'grad_norm': 0.046280406415462494, 'learning_rate': 1.4967147367522827e-05, 'epoch': 7.01}
{'l

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0016770290676504374, 'eval_runtime': 99.2115, 'eval_samples_per_second': 1511.922, 'eval_steps_per_second': 23.626, 'epoch': 8.0}
{'loss': 0.0019, 'grad_norm': 0.1213863343000412, 'learning_rate': 9.99658673948289e-06, 'epoch': 8.0}
{'loss': 0.0017, 'grad_norm': 0.1531178057193756, 'learning_rate': 9.992320163836505e-06, 'epoch': 8.0}
{'loss': 0.0018, 'grad_norm': 0.037459541112184525, 'learning_rate': 9.988053588190119e-06, 'epoch': 8.0}
{'loss': 0.0017, 'grad_norm': 0.09789890050888062, 'learning_rate': 9.983787012543733e-06, 'epoch': 8.0}
{'loss': 0.0018, 'grad_norm': 0.1468200534582138, 'learning_rate': 9.979520436897346e-06, 'epoch': 8.0}
{'loss': 0.0019, 'grad_norm': 0.1309654265642166, 'learning_rate': 9.97525386125096e-06, 'epoch': 8.0}
{'loss': 0.0019, 'grad_norm': 0.26912495493888855, 'learning_rate': 9.970987285604574e-06, 'epoch': 8.01}
{'loss': 0.0014, 'grad_norm': 0.14607146382331848, 'learning_rate': 9.966720709958189e-06, 'epoch': 8.01}
{'loss': 0.0017, 

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0016455453587695956, 'eval_runtime': 99.9706, 'eval_samples_per_second': 1500.441, 'eval_steps_per_second': 23.447, 'epoch': 9.0}
{'loss': 0.0015, 'grad_norm': 0.14851370453834534, 'learning_rate': 4.996160081918252e-06, 'epoch': 9.0}
{'loss': 0.0014, 'grad_norm': 0.09763351082801819, 'learning_rate': 4.991893506271867e-06, 'epoch': 9.0}
{'loss': 0.0014, 'grad_norm': 0.05155812203884125, 'learning_rate': 4.98762693062548e-06, 'epoch': 9.0}
{'loss': 0.002, 'grad_norm': 0.23904816806316376, 'learning_rate': 4.983360354979094e-06, 'epoch': 9.0}
{'loss': 0.0015, 'grad_norm': 0.16483327746391296, 'learning_rate': 4.979093779332708e-06, 'epoch': 9.0}
{'loss': 0.0015, 'grad_norm': 0.16533324122428894, 'learning_rate': 4.974827203686322e-06, 'epoch': 9.01}
{'loss': 0.002, 'grad_norm': 0.09620021283626556, 'learning_rate': 4.970560628039935e-06, 'epoch': 9.01}
{'loss': 0.0021, 'grad_norm': 0.226893350481987, 'learning_rate': 4.966294052393549e-06, 'epoch': 9.01}
{'loss': 0.0018,

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.001608179067261517, 'eval_runtime': 101.6483, 'eval_samples_per_second': 1475.676, 'eval_steps_per_second': 23.06, 'epoch': 10.0}
{'train_runtime': 8174.036, 'train_samples_per_second': 917.539, 'train_steps_per_second': 14.337, 'train_loss': 0.0022108236364165165, 'epoch': 10.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

# NOTE: this rebinds `current_date` to a date-only string; the earlier cell that
# builds result_dir defined it as a full timestamp.
current_date = format(datetime.now(), "%Y-%m-%d")
model_name = current_date + "-model"
config_name = current_date + "-config.json"

# Save the trained model through the Trainer.
trainer.save_model(result_dir / model_name)

# Store the hyperparameters next to the model as JSON.
with open(result_dir / config_name, 'w') as f:
    f.write(json.dumps(asdict(config)))