In [6]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

@dataclass
class SDFTransformerConfig:
    dim_context: int
    dim_input: int
    num_outputs: int
    dim_output: int
    delta: float = 0.1
    num_inds: int = 32
    dim_hidden: int = 64
    num_heads: int = 4
    ln: bool = False

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.enc = nn.Sequential(
            ISAB(config.dim_context, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln),
            nn.SiLU(),
            ISAB(config.dim_hidden, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln),
            nn.SiLU(),
            PMA(config.dim_hidden, config.num_heads, config.num_outputs, ln=config.ln),
            nn.SiLU(),
        )
        self.input_proj = nn.Linear(config.dim_input, config.dim_hidden)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
        )
        self.regr = nn.Sequential(
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        y = self.enc(context)           # [batch_size, context_size, dim_hidden]
        x = self.input_proj(x)          # [batch_size, num_outputs, dim_hidden]
        x = x.repeat(1, y.shape[1], 1)  # [batch_size, context_size, dim_hidden]
        y = self.cross(x, y)            # [batch_size, context_size, dim_hidden]
        y = self.dec(y)                 # [batch_size, num_outputs, dim_hidden]
        y = self.regr(y)                # [batch_size, num_outputs, dim_output]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(y, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': y}

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = SDFTransformerConfig(dim_context=4, dim_input=3, num_outputs=1, dim_output=1)
model = SDFTransformer(config).to(device)
print(device)

cuda


In [8]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [9]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2024_12_05_more_models'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_12_05_more_models-2024-12-06-15-42-17


In [10]:
from transformers import Trainer, TrainingArguments

batch_size = 40
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5}
]

In [11]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=100, context_size=200)
train_dataset.close()
val_dataset.close()

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0271, 'grad_norm': 1.2878047227859497, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0046, 'grad_norm': 0.0, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0049, 'grad_norm': 0.10745882987976074, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0043, 'grad_norm': 0.10782333463430405, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0022, 'grad_norm': 0.034805770963430405, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0046, 'grad_norm': 0.10465945303440094, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0035, 'grad_norm': 0.03636068105697632, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0034, 'grad_norm': 0.13969896733760834, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.03685876727104187, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0058, 'grad_norm': 0.042724404484033585, 'learning_rate': 4.9722

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0011229101801291108, 'eval_runtime': 27.3769, 'eval_samples_per_second': 1461.086, 'eval_steps_per_second': 36.527, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.053860850632190704, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.17638817429542542, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.07176236063241959, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.11345639079809189, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.08604974299669266, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0007, 'grad_norm': 0.10372030735015869, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0003, 'grad_norm': 0.0, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0009, 'grad_norm': 0.13673171401023865, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.0008, 'g

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0009351725457236171, 'eval_runtime': 28.8174, 'eval_samples_per_second': 1388.05, 'eval_steps_per_second': 34.701, 'epoch': 2.0}
{'train_runtime': 940.2414, 'train_samples_per_second': 765.761, 'train_steps_per_second': 19.144, 'train_loss': 0.001290088912161688, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0045, 'grad_norm': 1.4580222368240356, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.004, 'grad_norm': 1.2658226490020752, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0041, 'grad_norm': 0.2875961661338806, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.6707594990730286, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0017, 'grad_norm': 0.21787814795970917, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.1913316249847412, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0022, 'grad_norm': 0.23537257313728333, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0014, 'grad_norm': 0.4996550381183624, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0016, 'grad_norm': 0.3220706582069397, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0034, 'grad_norm': 0.26485469937324524, 'learning_rate':

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.001456697005778551, 'eval_runtime': 28.7045, 'eval_samples_per_second': 1393.509, 'eval_steps_per_second': 34.838, 'epoch': 1.0}
{'loss': 0.0018, 'grad_norm': 0.0, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0025, 'grad_norm': 0.036005035042762756, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0017, 'grad_norm': 0.2882101833820343, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 0.05101699382066727, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.16484110057353973, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.05747819319367409, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0004, 'grad_norm': 0.05501934885978699, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.05620568245649338, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0.0012, 'grad_

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0012066704221069813, 'eval_runtime': 27.8248, 'eval_samples_per_second': 1437.565, 'eval_steps_per_second': 35.939, 'epoch': 2.0}
{'train_runtime': 989.871, 'train_samples_per_second': 727.367, 'train_steps_per_second': 18.184, 'train_loss': 0.0014507469252631482, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0018, 'grad_norm': 0.8392648696899414, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.32329851388931274, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.6142975091934204, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.5987160205841064, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0014, 'grad_norm': 0.24469268321990967, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0021, 'grad_norm': 0.6136938333511353, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0021, 'grad_norm': 0.6768694519996643, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0012, 'grad_norm': 0.22401100397109985, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0015, 'grad_norm': 0.2976487874984741, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0028, 'grad_norm': 0.6518856883049011, 'learning_rate'

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0016334812389686704, 'eval_runtime': 27.2094, 'eval_samples_per_second': 1470.082, 'eval_steps_per_second': 36.752, 'epoch': 1.0}
{'loss': 0.0016, 'grad_norm': 0.06603258103132248, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0026, 'grad_norm': 0.21361497044563293, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0015, 'grad_norm': 0.07147286087274551, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.03594113141298294, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.2506066858768463, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0011, 'grad_norm': 0.1071372702717781, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0005, 'grad_norm': 0.053651291877031326, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.057893283665180206, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0012997296871617436, 'eval_runtime': 27.037, 'eval_samples_per_second': 1479.455, 'eval_steps_per_second': 36.986, 'epoch': 2.0}
{'train_runtime': 989.0258, 'train_samples_per_second': 727.989, 'train_steps_per_second': 18.2, 'train_loss': 0.0014984050681394162, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0018, 'grad_norm': 0.7475482225418091, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.7477930784225464, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.893632709980011, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0016, 'grad_norm': 0.6727256774902344, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0014, 'grad_norm': 0.8092762231826782, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0024, 'grad_norm': 0.03989649564027786, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.0021, 'grad_norm': 1.1117490530014038, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0014, 'grad_norm': 0.17513813078403473, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0019, 'grad_norm': 0.9924823045730591, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0029, 'grad_norm': 0.5238434076309204, 'learning_rate': 

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.002083503408357501, 'eval_runtime': 25.9026, 'eval_samples_per_second': 1544.248, 'eval_steps_per_second': 38.606, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.23414826393127441, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0028, 'grad_norm': 0.21893629431724548, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.5335991978645325, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0035, 'grad_norm': 0.2536373734474182, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.33933350443840027, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0012, 'grad_norm': 0.2743685245513916, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0007, 'grad_norm': 0.313859760761261, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0013, 'grad_norm': 0.035471320152282715, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.0015826119342818856, 'eval_runtime': 27.9071, 'eval_samples_per_second': 1433.328, 'eval_steps_per_second': 35.833, 'epoch': 2.0}
{'train_runtime': 934.9274, 'train_samples_per_second': 770.113, 'train_steps_per_second': 19.253, 'train_loss': 0.0017810083483733859, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/18000 [00:00<?, ?it/s]

{'loss': 0.0034, 'grad_norm': 0.30827248096466064, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0043, 'grad_norm': 0.6106992363929749, 'learning_rate': 4.994444444444445e-05, 'epoch': 0.0}
{'loss': 0.004, 'grad_norm': 0.7480981349945068, 'learning_rate': 4.991666666666667e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.49208876490592957, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0032, 'grad_norm': 0.4056619703769684, 'learning_rate': 4.986111111111111e-05, 'epoch': 0.01}
{'loss': 0.0039, 'grad_norm': 0.30604851245880127, 'learning_rate': 4.9833333333333336e-05, 'epoch': 0.01}
{'loss': 0.003, 'grad_norm': 0.7237206101417542, 'learning_rate': 4.9805555555555554e-05, 'epoch': 0.01}
{'loss': 0.0035, 'grad_norm': 0.3991623818874359, 'learning_rate': 4.977777777777778e-05, 'epoch': 0.01}
{'loss': 0.0035, 'grad_norm': 1.2180330753326416, 'learning_rate': 4.975e-05, 'epoch': 0.01}
{'loss': 0.0046, 'grad_norm': 1.0519479513168335, 'learning_rate': 

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.003353700740262866, 'eval_runtime': 28.0878, 'eval_samples_per_second': 1424.108, 'eval_steps_per_second': 35.603, 'epoch': 1.0}
{'loss': 0.0035, 'grad_norm': 1.1253081560134888, 'learning_rate': 2.4972222222222226e-05, 'epoch': 1.0}
{'loss': 0.0038, 'grad_norm': 0.4303174614906311, 'learning_rate': 2.4944444444444447e-05, 'epoch': 1.0}
{'loss': 0.0034, 'grad_norm': 0.5376290678977966, 'learning_rate': 2.4916666666666668e-05, 'epoch': 1.0}
{'loss': 0.0051, 'grad_norm': 0.3129482567310333, 'learning_rate': 2.488888888888889e-05, 'epoch': 1.0}
{'loss': 0.0026, 'grad_norm': 0.26800772547721863, 'learning_rate': 2.4861111111111114e-05, 'epoch': 1.01}
{'loss': 0.0028, 'grad_norm': 0.26594939827919006, 'learning_rate': 2.4833333333333335e-05, 'epoch': 1.01}
{'loss': 0.0021, 'grad_norm': 0.22119638323783875, 'learning_rate': 2.4805555555555556e-05, 'epoch': 1.01}
{'loss': 0.0025, 'grad_norm': 0.3098735213279724, 'learning_rate': 2.477777777777778e-05, 'epoch': 1.01}
{'loss': 0

  0%|          | 0/1000 [00:00<?, ?it/s]

{'eval_loss': 0.002900592051446438, 'eval_runtime': 27.7962, 'eval_samples_per_second': 1439.044, 'eval_steps_per_second': 35.976, 'epoch': 2.0}
{'train_runtime': 960.1417, 'train_samples_per_second': 749.889, 'train_steps_per_second': 18.747, 'train_loss': 0.003134594148542318, 'epoch': 2.0}


Processing models:   0%|          | 0/1 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [12]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)