In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int = 4
    dim_input: int = 3
    num_outputs: int = 1
    dim_output: int = 1
    delta: float = 0.1
    dim_hidden: int = 128
    num_ctx_seeds: int = 32
    num_x_seeds: int = 32
    num_heads: int = 4

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.proj_x = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj_ctx = nn.Linear(config.dim_context, config.dim_hidden)
        self.pool_ctx = PMA(config.dim_hidden, config.num_heads, config.num_ctx_seeds)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.sab1 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.sab2 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.pool_final = PMA(config.dim_hidden, config.num_heads, config.num_outputs)
        self.proj_final = nn.Linear(config.dim_hidden, config.dim_output)

        self.silu = torch.sin

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = self.proj_x(x)                                  # [B, 1, H]
        x = x.expand(-1, self.config.num_x_seeds, -1)       # [B, X, H]
        y = self.pool_ctx(self.proj_ctx(context))           # [B, Y, H]
        xy = self.silu(self.cross(x, y))                    # [B, X, H]
        xy = xy + x                                         # [B, X, H]
        xy = self.silu(self.sab1(xy))                       # [B, X, H]
        xy = self.silu(self.sab2(xy))                       # [B, X, H]
        xy = self.pool_final(xy)                            # [B, F, H]
        xy = self.proj_final(xy)                            # [B, F, O]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(xy, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': xy}

config = SDFTransformerConfig()
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2025_01_12_sin_enc'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2025_01_12_sin_enc-2025-01-12-20-54-46


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 64
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5, 'resolution': 256}
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=256)
train_dataset.close()
val_dataset.close()

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0316, 'grad_norm': 1.0437781810760498, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0098, 'grad_norm': 0.6846656203269958, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0051, 'grad_norm': 0.3105813264846802, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0039, 'grad_norm': 0.7690013647079468, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0045, 'grad_norm': 0.09465508908033371, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0047, 'grad_norm': 0.24180790781974792, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.0031, 'grad_norm': 0.27468323707580566, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0047, 'grad_norm': 0.4331199526786804, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0033, 'grad_norm': 0.5561445355415344, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0048, 'grad_norm': 0.17116418480873108, 'lea

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0012093420373275876, 'eval_runtime': 108.6712, 'eval_samples_per_second': 1380.31, 'eval_steps_per_second': 21.57, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.22686095535755157, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.23054669797420502, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.15317486226558685, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.16489280760288239, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.14699119329452515, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.08565917611122131, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.06652416288852692, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.0009, 'grad_norm': 0.30495983362197876, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss'

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0008256315486505628, 'eval_runtime': 107.1853, 'eval_samples_per_second': 1399.446, 'eval_steps_per_second': 21.869, 'epoch': 2.0}
{'train_runtime': 1780.343, 'train_samples_per_second': 842.534, 'train_steps_per_second': 13.165, 'train_loss': 0.001400636687165673, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.0093, 'grad_norm': 2.201678991317749, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0079, 'grad_norm': 1.4017647504806519, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0072, 'grad_norm': 2.0599100589752197, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0071, 'grad_norm': 1.9323419332504272, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.008, 'grad_norm': 1.6693137884140015, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}
{'loss': 0.0085, 'grad_norm': 1.3324469327926636, 'learning_rate': 4.9872002730608416e-05, 'epoch': 0.01}
{'loss': 0.006, 'grad_norm': 0.7269073128700256, 'learning_rate': 4.985066985237649e-05, 'epoch': 0.01}
{'loss': 0.0039, 'grad_norm': 1.2201075553894043, 'learning_rate': 4.982933697414455e-05, 'epoch': 0.01}
{'loss': 0.0038, 'grad_norm': 0.5010647177696228, 'learning_rate': 4.980800409591262e-05, 'epoch': 0.01}
{'loss': 0.0036, 'grad_norm': 1.5592877864837646, 'learning_r

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0010931903962045908, 'eval_runtime': 110.7321, 'eval_samples_per_second': 1354.62, 'eval_steps_per_second': 21.168, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.2090212106704712, 'learning_rate': 2.499786671217681e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.2294575423002243, 'learning_rate': 2.4976533833944876e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.22859032452106476, 'learning_rate': 2.4955200955712947e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.4152574837207794, 'learning_rate': 2.4933868077481015e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.25703075528144836, 'learning_rate': 2.4912535199249083e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.1479860544204712, 'learning_rate': 2.4891202321017154e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.19755344092845917, 'learning_rate': 2.4869869442785222e-05, 'epoch': 1.01}
{'loss': 0.001, 'grad_norm': 0.18037448823451996, 'learning_rate': 2.484853656455329e-05, 'epoch': 1.01}
{'loss': 0.0

  0%|          | 0/2344 [00:00<?, ?it/s]

{'eval_loss': 0.0006436426774598658, 'eval_runtime': 112.6418, 'eval_samples_per_second': 1331.655, 'eval_steps_per_second': 20.809, 'epoch': 2.0}
{'train_runtime': 1933.6481, 'train_samples_per_second': 775.736, 'train_steps_per_second': 12.121, 'train_loss': 0.0012205000819815353, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/15625 [00:00<?, ?it/s]

  0%|          | 0/23438 [00:00<?, ?it/s]

{'loss': 0.004, 'grad_norm': 1.1085772514343262, 'learning_rate': 4.997866712176807e-05, 'epoch': 0.0}
{'loss': 0.0058, 'grad_norm': 2.5193164348602295, 'learning_rate': 4.995733424353614e-05, 'epoch': 0.0}
{'loss': 0.0045, 'grad_norm': 0.9270623326301575, 'learning_rate': 4.993600136530421e-05, 'epoch': 0.0}
{'loss': 0.0064, 'grad_norm': 1.4410475492477417, 'learning_rate': 4.991466848707228e-05, 'epoch': 0.0}
{'loss': 0.0058, 'grad_norm': 2.3266303539276123, 'learning_rate': 4.9893335608840345e-05, 'epoch': 0.0}


KeyboardInterrupt: 

In [11]:
generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=256, context_size=256)

Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/262144 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)