In [9]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
from torch.functional import F
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int = 4
    dim_input: int = 3
    num_outputs: int = 1
    dim_output: int = 1
    delta: float = 0.1
    dim_hidden: int = 128
    num_ctx_seeds: int = 32
    num_x_seeds: int = 32
    num_heads: int = 1

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.proj_x = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj_ctx = nn.Linear(config.dim_context, config.dim_hidden)
        self.pool_ctx = PMA(config.dim_hidden, config.num_heads, config.num_ctx_seeds)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            PMA(config.dim_hidden, config.num_heads, config.num_outputs),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = self.proj_x(x)                                  # [B, 1, H]
        x = x.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        y = self.pool_ctx(self.proj_ctx(context))           # [B, Y, H]
        xy = self.cross(x, y)                               # [B, X, H]
        xy = xy + x                                         # [B, X, H]
        xy = self.dec(xy)

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(xy, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': xy}

config = SDFTransformerConfig()
model = SDFTransformer(config).to(device)
print(device)

cuda


In [11]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [12]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2025_01_18_fast_enc'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_18_fast_enc-2025-01-18-16-19-46


In [13]:
from transformers import Trainer, TrainingArguments

batch_size = 64
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5, 'resolution': 256}
]

In [14]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=256)
train_dataset.close()
val_dataset.close()

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0043, 'grad_norm': 0.38274845480918884, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.12820535898208618, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.0, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0035, 'grad_norm': 0.27489152550697327, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0039, 'grad_norm': 0.12391854077577591, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0043, 'grad_norm': 0.038271140307188034, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.04791709780693054, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0042, 'grad_norm': 0.1218525692820549, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.12031153589487076, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0043, 'grad_norm': 0.05315001308917999, 'learning_rate

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0005897946539334953, 'eval_runtime': 98.2346, 'eval_samples_per_second': 1526.956, 'eval_steps_per_second': 23.861, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.1570439487695694, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.08252749592065811, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.03145088255405426, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.05386516824364662, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.10807931423187256, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.0, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.052605219185352325, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0005, 'grad_norm': 0.11466821283102036, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0008, 'grad

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0002441282122163102, 'eval_runtime': 97.5706, 'eval_samples_per_second': 1537.348, 'eval_steps_per_second': 24.024, 'epoch': 2.0}
{'train_runtime': 1614.726, 'train_samples_per_second': 928.95, 'train_steps_per_second': 14.515, 'train_loss': 0.0009088358613125892, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.008, 'grad_norm': 0.8071458339691162, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0063, 'grad_norm': 0.6693010330200195, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.005, 'grad_norm': 1.146103858947754, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0043, 'grad_norm': 0.31141120195388794, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.3937087953090668, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0028, 'grad_norm': 0.3163921535015106, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0024, 'grad_norm': 0.8482304215431213, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0034, 'grad_norm': 0.2575301229953766, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0024, 'grad_norm': 0.49950772523880005, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0022, 'grad_norm': 0.4498053193092346, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0006744459387846291, 'eval_runtime': 97.777, 'eval_samples_per_second': 1534.104, 'eval_steps_per_second': 23.973, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.1338411271572113, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.13055090606212616, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.09154541045427322, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.1252441257238388, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.09717487543821335, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.09577464312314987, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.10971226543188095, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0006, 'grad_norm': 0.1423894762992859, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00042540524736978114, 'eval_runtime': 98.5712, 'eval_samples_per_second': 1521.743, 'eval_steps_per_second': 23.78, 'epoch': 2.0}
{'train_runtime': 1619.2414, 'train_samples_per_second': 926.36, 'train_steps_per_second': 14.475, 'train_loss': 0.0007346390547120852, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0032, 'grad_norm': 0.6569888591766357, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0059, 'grad_norm': 0.5777238607406616, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0055, 'grad_norm': 0.7365223169326782, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.476275771856308, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.4690331518650055, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.8529994487762451, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0022, 'grad_norm': 0.2701254189014435, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0033, 'grad_norm': 0.504292368888855, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.6160906553268433, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0025, 'grad_norm': 0.37540775537490845, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0009037701529450715, 'eval_runtime': 96.7948, 'eval_samples_per_second': 1549.67, 'eval_steps_per_second': 24.216, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.27453359961509705, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.12637463212013245, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.07949529588222504, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.2102363407611847, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.1142333447933197, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.11637802422046661, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.18517780303955078, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0006, 'grad_norm': 0.0721496194601059, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0005566065083257854, 'eval_runtime': 93.2845, 'eval_samples_per_second': 1607.984, 'eval_steps_per_second': 25.127, 'epoch': 2.0}
{'train_runtime': 1608.2706, 'train_samples_per_second': 932.679, 'train_steps_per_second': 14.573, 'train_loss': 0.0008514715463347983, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0025, 'grad_norm': 1.5852794647216797, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0052, 'grad_norm': 0.8927856683731079, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0059, 'grad_norm': 0.7166507840156555, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0046, 'grad_norm': 0.3967350423336029, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0038, 'grad_norm': 0.7144489288330078, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.4748205542564392, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0035, 'grad_norm': 0.31124573945999146, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0036, 'grad_norm': 0.4788019061088562, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.3962222933769226, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.003, 'grad_norm': 0.29199790954589844, 'learni

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00128228182438761, 'eval_runtime': 98.7224, 'eval_samples_per_second': 1519.412, 'eval_steps_per_second': 23.743, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.31090158224105835, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.1855137199163437, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.24113574624061584, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.25039881467819214, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.20859836041927338, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.3769609034061432, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.5086906552314758, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.24002708494663239, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.001

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0007550538866780698, 'eval_runtime': 98.7627, 'eval_samples_per_second': 1518.792, 'eval_steps_per_second': 23.734, 'epoch': 2.0}
{'train_runtime': 1629.5909, 'train_samples_per_second': 920.476, 'train_steps_per_second': 14.383, 'train_loss': 0.0011155878883327235, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0036, 'grad_norm': 0.5984553694725037, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0048, 'grad_norm': 0.31148332357406616, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0042, 'grad_norm': 0.5380585193634033, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.005, 'grad_norm': 1.0963183641433716, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0057, 'grad_norm': 0.3686167895793915, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 0.5306471586227417, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.005, 'grad_norm': 0.9360674619674683, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0046, 'grad_norm': 0.3181281089782715, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0044, 'grad_norm': 0.6550936698913574, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0042, 'grad_norm': 0.4783981442451477, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.002597587648779154, 'eval_runtime': 99.17, 'eval_samples_per_second': 1512.554, 'eval_steps_per_second': 23.636, 'epoch': 1.0}
{'loss': 0.0026, 'grad_norm': 0.27911707758903503, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0028, 'grad_norm': 0.6459172964096069, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 0.29929181933403015, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0025, 'grad_norm': 0.2049856185913086, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.38882654905319214, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 0.23633958399295807, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.5139310359954834, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0024, 'grad_norm': 0.26130807399749756, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.001984893810003996, 'eval_runtime': 99.2406, 'eval_samples_per_second': 1511.479, 'eval_steps_per_second': 23.619, 'epoch': 2.0}
{'train_runtime': 1637.5266, 'train_samples_per_second': 916.016, 'train_steps_per_second': 14.313, 'train_loss': 0.0025156151022682907, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [15]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)