In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int
    dim_input: int
    num_outputs: int
    dim_output: int
    delta: float = 0.1
    dim_hidden: int = 64
    num_inds: int = 32
    num_heads: int = 4
    ln: bool = False

class SDFEncoder(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFEncoder, self).__init__()
        self.config = config

    def forward(self, context: torch.Tensor):
        return torch.ones(context.shape[0], config.num_outputs, config.dim_hidden).to(device)

class SDFDecoder(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFDecoder, self).__init__()
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor):
        return self.dec(context)

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.enc = SDFEncoder(config)
        self.input_proj = nn.Linear(config.dim_input, config.dim_hidden)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.dec = SDFDecoder(config)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        y = self.enc(context)           # [batch_size, context_size, dim_hidden]
        x = self.input_proj(x)          # [batch_size, num_outputs, dim_hidden]
        x = x.repeat(1, y.shape[1], 1)  # [batch_size, context_size, dim_hidden]
        y = self.cross(x, y)            # [batch_size, context_size, dim_hidden]
        y = self.dec(y)                 # [batch_size, num_outputs, dim_output]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(y, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': y}

config = SDFTransformerConfig(dim_context=4, dim_input=3, num_outputs=1, dim_output=1)
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [None]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2024_12_06_no_enc'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_12_06_enhanced_encoder-2024-12-08-20-17-19


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 40
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5}
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=100, context_size=200)
train_dataset.close()
val_dataset.close()

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0028, 'grad_norm': 0.08662301301956177, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0037, 'grad_norm': 0.0, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.043719545006752014, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.08595050871372223, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.04164993762969971, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0039, 'grad_norm': 0.010827702470123768, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.003, 'grad_norm': 0.04617235064506531, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0032, 'grad_norm': 0.17770326137542725, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.08496013283729553, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0056, 'grad_norm': 0.2159336656332016, 'learning_rate': 4.972222

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0012926532654091716, 'eval_runtime': 21.9197, 'eval_samples_per_second': 1824.84, 'eval_steps_per_second': 45.621, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.06525403261184692, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.002, 'grad_norm': 0.035227082669734955, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.07954756915569305, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.06144788861274719, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.11638706922531128, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0008, 'grad_norm': 0.1395125538110733, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_norm': 0.0, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0011, 'grad_norm': 0.0727742537856102, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.0012, 'grad_

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0011219135485589504, 'eval_runtime': 22.3222, 'eval_samples_per_second': 1791.938, 'eval_steps_per_second': 44.798, 'epoch': 2.0}
{'train_runtime': 643.5343, 'train_samples_per_second': 1118.821, 'train_steps_per_second': 27.971, 'train_loss': 0.0015026671464761926, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0048, 'grad_norm': 1.4188967943191528, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0053, 'grad_norm': 0.9823011159896851, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0053, 'grad_norm': 0.3913842439651489, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0035, 'grad_norm': 0.4480286240577698, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.003, 'grad_norm': 0.16823925077915192, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0035, 'grad_norm': 0.22585733234882355, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.6725784540176392, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0026, 'grad_norm': 0.4929681718349457, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.18137302994728088, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0045, 'grad_norm': 0.3133786618709564, 'learning_rate':

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.002028378425166011, 'eval_runtime': 24.2434, 'eval_samples_per_second': 1649.934, 'eval_steps_per_second': 41.248, 'epoch': 1.0}
{'loss': 0.0025, 'grad_norm': 0.25925931334495544, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.46241986751556396, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0022, 'grad_norm': 0.24147020280361176, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0026, 'grad_norm': 0.15927264094352722, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.002, 'grad_norm': 0.13504035770893097, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0013, 'grad_norm': 0.22268374264240265, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0008, 'grad_norm': 0.1361384093761444, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0016, 'grad_norm': 0.08514687418937683, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss':

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.001704422407783568, 'eval_runtime': 22.4632, 'eval_samples_per_second': 1780.694, 'eval_steps_per_second': 44.517, 'epoch': 2.0}
{'train_runtime': 663.4948, 'train_samples_per_second': 1085.163, 'train_steps_per_second': 27.129, 'train_loss': 0.002042381049845264, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0027, 'grad_norm': 0.5947365760803223, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.004, 'grad_norm': 1.3721195459365845, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.005, 'grad_norm': 0.715655505657196, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0025, 'grad_norm': 0.2822189927101135, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0025, 'grad_norm': 0.8649483323097229, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.3772972822189331, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.003, 'grad_norm': 1.1475276947021484, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 2.003664016723633, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 1.7691174745559692, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0048, 'grad_norm': 0.4846031367778778, 'learning_rate': 4.9722

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0022926838137209415, 'eval_runtime': 23.1125, 'eval_samples_per_second': 1730.669, 'eval_steps_per_second': 43.267, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.6558509469032288, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0034, 'grad_norm': 0.7799602150917053, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.2856532633304596, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0029, 'grad_norm': 0.46211913228034973, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0022, 'grad_norm': 0.17348141968250275, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0015, 'grad_norm': 0.07942144572734833, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.27515318989753723, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0018, 'grad_norm': 0.171432226896286, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.002068569418042898, 'eval_runtime': 22.7143, 'eval_samples_per_second': 1761.005, 'eval_steps_per_second': 44.025, 'epoch': 2.0}
{'train_runtime': 668.5842, 'train_samples_per_second': 1076.902, 'train_steps_per_second': 26.923, 'train_loss': 0.0023477528783906665, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0026, 'grad_norm': 0.8219172954559326, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0048, 'grad_norm': 1.5379457473754883, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0054, 'grad_norm': 0.39420390129089355, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0026, 'grad_norm': 0.29164227843284607, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 1.5596445798873901, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0049, 'grad_norm': 0.7429008483886719, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.005, 'grad_norm': 2.1244776248931885, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0024, 'grad_norm': 1.001949429512024, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.004, 'grad_norm': 1.5317336320877075, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.005, 'grad_norm': 0.6926044821739197, 'learning_rate': 4.9

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.002738921670243144, 'eval_runtime': 21.5157, 'eval_samples_per_second': 1859.104, 'eval_steps_per_second': 46.478, 'epoch': 1.0}
{'loss': 0.0037, 'grad_norm': 0.5912764668464661, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0041, 'grad_norm': 0.9999126195907593, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0032, 'grad_norm': 0.3943593502044678, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0037, 'grad_norm': 0.1793772578239441, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0029, 'grad_norm': 0.38469985127449036, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.002, 'grad_norm': 0.7193138599395752, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0014, 'grad_norm': 0.9302263259887695, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0023, 'grad_norm': 0.3464461863040924, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.00

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0025375294499099255, 'eval_runtime': 22.087, 'eval_samples_per_second': 1811.016, 'eval_steps_per_second': 45.275, 'epoch': 2.0}
{'train_runtime': 678.8842, 'train_samples_per_second': 1060.564, 'train_steps_per_second': 26.514, 'train_loss': 0.002825036188044275, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0043, 'grad_norm': 1.9113374948501587, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0061, 'grad_norm': 1.2296949625015259, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0063, 'grad_norm': 0.48761123418807983, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0042, 'grad_norm': 0.851062536239624, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.8769205808639526, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0055, 'grad_norm': 0.8515201807022095, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0054, 'grad_norm': 1.5227725505828857, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0046, 'grad_norm': 0.8711882829666138, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0049, 'grad_norm': 1.4964649677276611, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0072, 'grad_norm': 1.5540419816970825, 'learning_rate': 4

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.004394472111016512, 'eval_runtime': 23.5115, 'eval_samples_per_second': 1701.292, 'eval_steps_per_second': 42.532, 'epoch': 1.0}
{'loss': 0.0053, 'grad_norm': 0.2139846384525299, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0058, 'grad_norm': 0.5842447280883789, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0048, 'grad_norm': 0.5438944101333618, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0057, 'grad_norm': 0.27059099078178406, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0047, 'grad_norm': 0.5067827701568604, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0035, 'grad_norm': 0.4927249252796173, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0026, 'grad_norm': 0.41908150911331177, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0038, 'grad_norm': 1.4401429891586304, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0040221489034593105, 'eval_runtime': 22.7268, 'eval_samples_per_second': 1760.035, 'eval_steps_per_second': 44.001, 'epoch': 2.0}
{'train_runtime': 698.5507, 'train_samples_per_second': 1030.705, 'train_steps_per_second': 25.768, 'train_loss': 0.004408453373031484, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)