In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
from torch.functional import F
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int = 4
    dim_input: int = 3
    num_outputs: int = 1
    dim_output: int = 1
    delta: float = 0.1
    dim_hidden: int = 128
    num_ctx_seeds: int = 32
    num_x_seeds: int = 32
    num_heads: int = 1

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.proj_x = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj_ctx = nn.Linear(config.dim_context, config.dim_hidden)
        self.pool_ctx = PMA(config.dim_hidden, config.num_heads, config.num_ctx_seeds)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads),
            nn.SiLU(),
            PMA(config.dim_hidden, config.num_heads, config.num_outputs),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_hidden),
            nn.ReLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = self.proj_x(x)                                  # [B, 1, H]
        x = x.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        y = self.pool_ctx(self.proj_ctx(context))           # [B, Y, H]
        xy = self.cross(x, y)                               # [B, X, H]
        xy = xy + x                                         # [B, X, H]
        xy = self.dec(xy)

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(xy, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': xy}

config = SDFTransformerConfig()
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2025_01_19_fast_enc'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_19_fast_enc-2025-01-22-08-57-47


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 64
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5, 'resolution': 256}
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=256)
train_dataset.close()
val_dataset.close()

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0036, 'grad_norm': 0.03864193335175514, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.021532291546463966, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.0, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0032, 'grad_norm': 0.1723002940416336, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.07762668281793594, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0042, 'grad_norm': 0.040684618055820465, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.013352339155972004, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0041, 'grad_norm': 0.07696598768234253, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.09507149457931519, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0042, 'grad_norm': 0.04015125334262848, 'learning_ra

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00027321826200932264, 'eval_runtime': 98.3506, 'eval_samples_per_second': 1525.156, 'eval_steps_per_second': 23.833, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.07110368460416794, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.0435212105512619, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.02861190028488636, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.0562865324318409, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.05618368834257126, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.06423118710517883, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.0, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0002, 'grad_norm': 0.08854155242443085, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0001344066549791023, 'eval_runtime': 101.3344, 'eval_samples_per_second': 1480.248, 'eval_steps_per_second': 23.131, 'epoch': 2.0}
{'train_runtime': 1656.6799, 'train_samples_per_second': 905.425, 'train_steps_per_second': 14.148, 'train_loss': 0.0006224449491205313, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0047, 'grad_norm': 1.1401360034942627, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0056, 'grad_norm': 0.5574620962142944, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0038, 'grad_norm': 0.2585609555244446, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0027, 'grad_norm': 0.2112899124622345, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.3607332110404968, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0018, 'grad_norm': 0.09031083434820175, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0013, 'grad_norm': 0.19117268919944763, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.002, 'grad_norm': 0.13816705346107483, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0013, 'grad_norm': 0.30283430218696594, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0013, 'grad_norm': 0.1974918246269226, 'lear

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0004565678536891937, 'eval_runtime': 96.6078, 'eval_samples_per_second': 1552.67, 'eval_steps_per_second': 24.263, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.13126741349697113, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.0640251561999321, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.0723448395729065, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.161955788731575, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.28551167249679565, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.0, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.12682466208934784, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0003, 'grad_norm': 0.15540531277656555, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0005, 'grad_norm

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0002774254244286567, 'eval_runtime': 99.7496, 'eval_samples_per_second': 1503.765, 'eval_steps_per_second': 23.499, 'epoch': 2.0}
{'train_runtime': 1711.9462, 'train_samples_per_second': 876.196, 'train_steps_per_second': 13.691, 'train_loss': 0.00047287225160943543, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0017, 'grad_norm': 0.6666340231895447, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.3599577844142914, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0038, 'grad_norm': 0.6687119603157043, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0029, 'grad_norm': 0.2600783407688141, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0021, 'grad_norm': 0.3934599459171295, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0028, 'grad_norm': 0.40570610761642456, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0017, 'grad_norm': 0.38190242648124695, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.2744452953338623, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0018, 'grad_norm': 0.45259329676628113, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0018, 'grad_norm': 0.7275049686431885, 'lear

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0006018340936861932, 'eval_runtime': 97.9521, 'eval_samples_per_second': 1531.361, 'eval_steps_per_second': 23.93, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.23760820925235748, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.18874411284923553, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.09456150978803635, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.1328350007534027, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.20804530382156372, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.1492864489555359, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.09326635301113129, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_norm': 0.10606340318918228, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.00038267276249825954, 'eval_runtime': 97.038, 'eval_samples_per_second': 1545.787, 'eval_steps_per_second': 24.155, 'epoch': 2.0}
{'train_runtime': 1711.2439, 'train_samples_per_second': 876.555, 'train_steps_per_second': 13.696, 'train_loss': 0.0005803873345597538, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0021, 'grad_norm': 1.0888373851776123, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.003, 'grad_norm': 0.2910907566547394, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.7916288375854492, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.6645312905311584, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0025, 'grad_norm': 0.27876484394073486, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.466959685087204, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0022, 'grad_norm': 0.315621942281723, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.3030073344707489, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0022, 'grad_norm': 0.41138797998428345, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0027, 'grad_norm': 0.9482737183570862, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.000829285301733762, 'eval_runtime': 101.4025, 'eval_samples_per_second': 1479.254, 'eval_steps_per_second': 23.116, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.27311110496520996, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.29023292660713196, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.07545241713523865, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.2874329686164856, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.26387348771095276, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.11010493338108063, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.1923186480998993, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0007, 'grad_norm': 0.18703047931194305, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0005916773807257414, 'eval_runtime': 98.6275, 'eval_samples_per_second': 1520.874, 'eval_steps_per_second': 23.766, 'epoch': 2.0}
{'train_runtime': 1725.1159, 'train_samples_per_second': 869.507, 'train_steps_per_second': 13.586, 'train_loss': 0.000859019253424803, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0032, 'grad_norm': 0.611083447933197, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0039, 'grad_norm': 0.41270574927330017, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.45496171712875366, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.005, 'grad_norm': 1.1577825546264648, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0048, 'grad_norm': 0.4620605409145355, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0042, 'grad_norm': 0.5883559584617615, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0036, 'grad_norm': 0.461180180311203, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0036, 'grad_norm': 0.3595559597015381, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.004, 'grad_norm': 0.6656192541122437, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0042, 'grad_norm': 0.42889049649238586, 'learning

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0023564945440739393, 'eval_runtime': 108.3601, 'eval_samples_per_second': 1384.273, 'eval_steps_per_second': 21.632, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 0.21017028391361237, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.187651589512825, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.002, 'grad_norm': 0.21938736736774445, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0022, 'grad_norm': 0.24856777489185333, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0021, 'grad_norm': 0.37287867069244385, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0021, 'grad_norm': 0.4141024351119995, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.4271658658981323, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0021, 'grad_norm': 0.3081442415714264, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0017882755491882563, 'eval_runtime': 110.902, 'eval_samples_per_second': 1352.546, 'eval_steps_per_second': 21.136, 'epoch': 2.0}
{'train_runtime': 1836.5725, 'train_samples_per_second': 816.739, 'train_steps_per_second': 12.762, 'train_loss': 0.0022582734517126255, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)