In [1]:
import os
import sys
# Make the notebook's parent directory (the project root) importable so that the
# `src.*` imports in later cells resolve. Guarded so re-running the cell does not
# append a duplicate entry.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)
# Automatically re-import edited project modules (e.g. `src.*`) before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class SDFTransformerConfig:
    """Hyper-parameters needed to build an ``SDFTransformer``.

    The four required fields fix the data widths; the defaulted fields size the
    loss and the attention stack. This notebook serialises instances with
    ``dataclasses.asdict`` next to the saved model weights.
    """

    # --- required: data dimensions ------------------------------------------
    dim_context: int  # last-dim width of ``context`` (input size of context_proj)
    dim_input: int    # last-dim width of the query ``x`` (input size of input_proj)
    num_outputs: int  # passed to PMA as its output-set size (presumably number of seeds)
    dim_output: int   # output width of the final Linear in ``regr``

    # --- optional: loss and architecture ------------------------------------
    delta: float = 0.1     # forwarded to L1_clamp_loss as its third argument
    num_inds: int = 32     # passed to each ISAB block (inducing-point count)
    dim_hidden: int = 128  # shared width of the projections and attention blocks
    num_heads: int = 4     # heads used by every MAB/ISAB/PMA/SAB block
    ln: bool = False       # ``ln`` flag forwarded to every attention block

class SDFTransformer(nn.Module):
    """Attention model mapping a query and a context set to ``num_outputs`` predictions.

    Built from the ``MAB``/``ISAB``/``PMA``/``SAB`` blocks imported from
    ``src.models.modules``. ``forward`` does the following:

    1. Projects both inputs to ``dim_hidden``.
    2. Cross-attends the tiled query against the context (``cross1``).
    3. Refines the result with two ISAB blocks (``enc``).
    4. Cross-attends the query against the refined set (``cross2``).
    5. Pools/decodes with PMA followed by two SABs (``dec``).
    6. Maps to ``dim_output`` values in ``[-1, 1]`` via ReLU -> Linear -> Tanh (``regr``).

    ``forward`` returns a dict with ``'loss'`` and ``'logits'`` keys, which is the
    output form the Hugging Face ``Trainer`` used in this notebook reads.
    """

    def __init__(self, config: SDFTransformerConfig):
        """Build all sub-modules from ``config``; the config is kept for the loss ``delta``."""
        super().__init__()
        self.config = config
        self.input_proj = nn.Sequential(nn.Linear(config.dim_input, config.dim_hidden), nn.ReLU())
        self.context_proj = nn.Sequential(nn.Linear(config.dim_context, config.dim_hidden), nn.ReLU())
        self.cross1 = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.enc = nn.Sequential(
            ISAB(config.dim_hidden, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln),
            ISAB(config.dim_hidden, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln)
        )
        self.cross2 = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.dec = nn.Sequential(
            PMA(config.dim_hidden, config.num_heads, config.num_outputs, ln=config.ln),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
        )
        self.regr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(
        self,
        context: torch.Tensor,
        x: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ):
        """Predict values for the query ``x`` given ``context``; optionally compute the loss.

        Args:
            context: context set whose last dimension is ``config.dim_context``.
                Presumably ``[batch_size, context_size, dim_context]``; TODO confirm
                against ``LazySampleDataset``.
            x: query whose last dimension is ``config.dim_input``. It is tiled along
                dim 1 below, which only produces ``context_size`` rows when ``x`` is
                ``[batch_size, 1, dim_input]`` (assumed; ``num_outputs=1`` here).
            labels: optional targets. When given, they are compared with the
                prediction by ``L1_clamp_loss`` using ``config.delta``.

        Returns:
            ``{'loss': loss_or_None, 'logits': prediction}``. The prediction is the
            ``Tanh`` output, presumably ``[batch_size, num_outputs, dim_output]``.
        """
        query = self.input_proj(x)
        hidden = self.context_proj(context)
        # Tile the projected query along dim 1 so there is one copy per context
        # element. NOTE(review): with x.shape[1] > 1 the tiled length becomes
        # x.shape[1] * context_size rather than context_size.
        query = query.repeat(1, hidden.shape[1], 1)
        hidden = self.cross1(query, hidden)  # first arg presumably the attention query
        hidden = self.enc(hidden)
        hidden = self.cross2(query, hidden)  # query attends to the refined set (assumed)
        hidden = self.dec(hidden)            # PMA pools to config.num_outputs rows (assumed)
        logits = self.regr(hidden)           # values in [-1, 1] due to the final Tanh

        loss = None
        if labels is not None:
            loss = L1_clamp_loss(logits, labels, self.config.delta)
        return {'loss': loss, 'logits': logits}

# Seed before the model is built so its weight initialisation is reproducible.
torch.manual_seed(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# The notebook runs from a sub-directory of the project, so the project root
# is the parent of the current working directory.
project_dir = Path.cwd().resolve().parent
processed_dir = project_dir / 'data' / 'processed'

# rglob yields files in filesystem-dependent order. Sort so the file order handed
# to LazySampleDataset (and presumably its sample indexing) is reproducible.
train_files = sorted(processed_dir.rglob('*_train.hdf5'))
val_files = sorted(processed_dir.rglob('*_val.hdf5'))

# Fail fast instead of silently building empty datasets from a wrong path.
if not train_files:
    raise FileNotFoundError(f"No '*_train.hdf5' files found under {processed_dir}")
if not val_files:
    raise FileNotFoundError(f"No '*_val.hdf5' files found under {processed_dir}")

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [None]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Every run gets its own timestamped results folder: "<notebook_name>-<YYYY-mm-dd-HH-MM-SS>".
notebook_name = '2024_11_18_cross_first_and_middle'
current_date = f"{datetime.now():%Y-%m-%d-%H-%M-%S}"
folder_name = "-".join([notebook_name, current_date])
result_dir = get_results_dir() / folder_name
result_dir.mkdir(exist_ok=True, parents=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_18_less_regression-2024-11-18-19-39-48


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 40  # per-device train/eval batch size; the mesh-generation calls reuse it

# Evaluate once per epoch, log every 10 steps, and keep at most 3 checkpoints on disk.
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    num_train_epochs=8,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    eval_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [6]:
# Release the datasets' resources (presumably open HDF5 file handles) even when
# training raises or is interrupted with KeyboardInterrupt. Otherwise the handles
# would stay open in the kernel and may keep the .hdf5 files locked.
try:
    trainer.train()
finally:
    train_dataset.close()
    val_dataset.close()

  0%|          | 0/90000 [00:00<?, ?it/s]

{'loss': 0.0295, 'grad_norm': 3.1588540077209473, 'learning_rate': 4.9994444444444446e-05, 'epoch': 0.0}
{'loss': 0.016, 'grad_norm': 2.327871084213257, 'learning_rate': 4.998888888888889e-05, 'epoch': 0.0}
{'loss': 0.0106, 'grad_norm': 2.251917600631714, 'learning_rate': 4.998333333333334e-05, 'epoch': 0.0}
{'loss': 0.01, 'grad_norm': 2.046504497528076, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.009, 'grad_norm': 1.304001808166504, 'learning_rate': 4.997222222222223e-05, 'epoch': 0.0}
{'loss': 0.0077, 'grad_norm': 2.039996385574341, 'learning_rate': 4.996666666666667e-05, 'epoch': 0.01}
{'loss': 0.0086, 'grad_norm': 0.6675971746444702, 'learning_rate': 4.9961111111111114e-05, 'epoch': 0.01}
{'loss': 0.0076, 'grad_norm': 2.5327000617980957, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.01}
{'loss': 0.0074, 'grad_norm': 2.987104892730713, 'learning_rate': 4.995e-05, 'epoch': 0.01}
{'loss': 0.0098, 'grad_norm': 1.5433958768844604, 'learning_rate': 4.99444444444

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0031096613965928555, 'eval_runtime': 99.7281, 'eval_samples_per_second': 1504.09, 'eval_steps_per_second': 37.602, 'epoch': 1.0}
{'loss': 0.0046, 'grad_norm': 0.11416604369878769, 'learning_rate': 4.374444444444445e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 1.3413219451904297, 'learning_rate': 4.3738888888888893e-05, 'epoch': 1.0}
{'loss': 0.004, 'grad_norm': 0.5014130473136902, 'learning_rate': 4.373333333333334e-05, 'epoch': 1.0}
{'loss': 0.0028, 'grad_norm': 0.9064767956733704, 'learning_rate': 4.372777777777778e-05, 'epoch': 1.0}
{'loss': 0.0033, 'grad_norm': 0.24005761742591858, 'learning_rate': 4.3722222222222224e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.1974848061800003, 'learning_rate': 4.371666666666667e-05, 'epoch': 1.01}
{'loss': 0.0043, 'grad_norm': 0.67525714635849, 'learning_rate': 4.371111111111111e-05, 'epoch': 1.01}
{'loss': 0.0029, 'grad_norm': 0.4251466393470764, 'learning_rate': 4.3705555555555555e-05, 'epoch': 1.01}
{'loss': 0.0024, '

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0031866563949733973, 'eval_runtime': 102.4579, 'eval_samples_per_second': 1464.016, 'eval_steps_per_second': 36.6, 'epoch': 2.0}
{'loss': 0.0038, 'grad_norm': 1.109066367149353, 'learning_rate': 3.749444444444445e-05, 'epoch': 2.0}
{'loss': 0.0034, 'grad_norm': 1.4462045431137085, 'learning_rate': 3.748888888888889e-05, 'epoch': 2.0}
{'loss': 0.0025, 'grad_norm': 0.2516457438468933, 'learning_rate': 3.7483333333333334e-05, 'epoch': 2.0}
{'loss': 0.0028, 'grad_norm': 0.6007432341575623, 'learning_rate': 3.747777777777778e-05, 'epoch': 2.0}
{'loss': 0.0017, 'grad_norm': 0.6436458826065063, 'learning_rate': 3.747222222222223e-05, 'epoch': 2.0}
{'loss': 0.0033, 'grad_norm': 0.4597722589969635, 'learning_rate': 3.7466666666666665e-05, 'epoch': 2.01}
{'loss': 0.0023, 'grad_norm': 1.028822898864746, 'learning_rate': 3.7461111111111115e-05, 'epoch': 2.01}
{'loss': 0.0036, 'grad_norm': 0.5689339637756348, 'learning_rate': 3.745555555555555e-05, 'epoch': 2.01}
{'loss': 0.0032, 'g

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0024676795583218336, 'eval_runtime': 106.503, 'eval_samples_per_second': 1408.411, 'eval_steps_per_second': 35.21, 'epoch': 3.0}
{'loss': 0.0028, 'grad_norm': 0.2588590383529663, 'learning_rate': 3.124444444444445e-05, 'epoch': 3.0}
{'loss': 0.0017, 'grad_norm': 0.1300494223833084, 'learning_rate': 3.123888888888889e-05, 'epoch': 3.0}
{'loss': 0.0017, 'grad_norm': 0.09502961486577988, 'learning_rate': 3.123333333333334e-05, 'epoch': 3.0}
{'loss': 0.0019, 'grad_norm': 0.3792899549007416, 'learning_rate': 3.1227777777777775e-05, 'epoch': 3.0}
{'loss': 0.0023, 'grad_norm': 0.6633902788162231, 'learning_rate': 3.1222222222222225e-05, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm': 0.4482229948043823, 'learning_rate': 3.121666666666667e-05, 'epoch': 3.01}
{'loss': 0.0021, 'grad_norm': 0.3182070255279541, 'learning_rate': 3.121111111111111e-05, 'epoch': 3.01}
{'loss': 0.002, 'grad_norm': 0.11141405999660492, 'learning_rate': 3.1205555555555556e-05, 'epoch': 3.01}
{'loss': 0.0017,

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0022131421137601137, 'eval_runtime': 102.9554, 'eval_samples_per_second': 1456.942, 'eval_steps_per_second': 36.424, 'epoch': 4.0}
{'loss': 0.0022, 'grad_norm': 0.08853283524513245, 'learning_rate': 2.4994444444444445e-05, 'epoch': 4.0}
{'loss': 0.0016, 'grad_norm': 0.24838799238204956, 'learning_rate': 2.498888888888889e-05, 'epoch': 4.0}
{'loss': 0.0019, 'grad_norm': 0.08376079797744751, 'learning_rate': 2.4983333333333335e-05, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.1622951179742813, 'learning_rate': 2.497777777777778e-05, 'epoch': 4.0}
{'loss': 0.0023, 'grad_norm': 0.16591361165046692, 'learning_rate': 2.4972222222222226e-05, 'epoch': 4.0}
{'loss': 0.0016, 'grad_norm': 0.1229059025645256, 'learning_rate': 2.496666666666667e-05, 'epoch': 4.01}
{'loss': 0.002, 'grad_norm': 0.420931339263916, 'learning_rate': 2.4961111111111113e-05, 'epoch': 4.01}
{'loss': 0.0023, 'grad_norm': 0.36335819959640503, 'learning_rate': 2.4955555555555556e-05, 'epoch': 4.01}
{'loss': 0

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.002139003248885274, 'eval_runtime': 106.6943, 'eval_samples_per_second': 1405.885, 'eval_steps_per_second': 35.147, 'epoch': 5.0}
{'loss': 0.0023, 'grad_norm': 0.23870186507701874, 'learning_rate': 1.8744444444444445e-05, 'epoch': 5.0}
{'loss': 0.0021, 'grad_norm': 0.47209247946739197, 'learning_rate': 1.873888888888889e-05, 'epoch': 5.0}
{'loss': 0.002, 'grad_norm': 0.44062328338623047, 'learning_rate': 1.8733333333333332e-05, 'epoch': 5.0}
{'loss': 0.0016, 'grad_norm': 0.23614591360092163, 'learning_rate': 1.8727777777777776e-05, 'epoch': 5.0}
{'loss': 0.0022, 'grad_norm': 0.35470524430274963, 'learning_rate': 1.8722222222222223e-05, 'epoch': 5.0}
{'loss': 0.0029, 'grad_norm': 0.7083866596221924, 'learning_rate': 1.871666666666667e-05, 'epoch': 5.01}
{'loss': 0.0026, 'grad_norm': 0.19944386184215546, 'learning_rate': 1.8711111111111113e-05, 'epoch': 5.01}
{'loss': 0.0021, 'grad_norm': 0.6269383430480957, 'learning_rate': 1.8705555555555557e-05, 'epoch': 5.01}
{'loss':

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0019739249255508184, 'eval_runtime': 114.496, 'eval_samples_per_second': 1310.089, 'eval_steps_per_second': 32.752, 'epoch': 6.0}
{'loss': 0.0023, 'grad_norm': 0.19304610788822174, 'learning_rate': 1.2494444444444444e-05, 'epoch': 6.0}
{'loss': 0.0012, 'grad_norm': 0.41775524616241455, 'learning_rate': 1.248888888888889e-05, 'epoch': 6.0}
{'loss': 0.0029, 'grad_norm': 0.07973397523164749, 'learning_rate': 1.2483333333333335e-05, 'epoch': 6.0}
{'loss': 0.0022, 'grad_norm': 0.01660463958978653, 'learning_rate': 1.2477777777777778e-05, 'epoch': 6.0}
{'loss': 0.0026, 'grad_norm': 0.04342953860759735, 'learning_rate': 1.2472222222222223e-05, 'epoch': 6.0}
{'loss': 0.0018, 'grad_norm': 0.23170316219329834, 'learning_rate': 1.2466666666666667e-05, 'epoch': 6.01}
{'loss': 0.0018, 'grad_norm': 0.3787147104740143, 'learning_rate': 1.2461111111111112e-05, 'epoch': 6.01}
{'loss': 0.0017, 'grad_norm': 0.45814672112464905, 'learning_rate': 1.2455555555555556e-05, 'epoch': 6.01}
{'los

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.001908026752062142, 'eval_runtime': 105.2784, 'eval_samples_per_second': 1424.793, 'eval_steps_per_second': 35.62, 'epoch': 7.0}
{'loss': 0.0013, 'grad_norm': 0.23185977339744568, 'learning_rate': 6.244444444444445e-06, 'epoch': 7.0}
{'loss': 0.0013, 'grad_norm': 0.263492614030838, 'learning_rate': 6.238888888888889e-06, 'epoch': 7.0}
{'loss': 0.0014, 'grad_norm': 0.11268585175275803, 'learning_rate': 6.2333333333333335e-06, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.08103465288877487, 'learning_rate': 6.227777777777778e-06, 'epoch': 7.0}
{'loss': 0.0021, 'grad_norm': 0.07957027852535248, 'learning_rate': 6.222222222222222e-06, 'epoch': 7.0}
{'loss': 0.0027, 'grad_norm': 0.04563117027282715, 'learning_rate': 6.2166666666666676e-06, 'epoch': 7.01}
{'loss': 0.0017, 'grad_norm': 0.2623458802700043, 'learning_rate': 6.211111111111111e-06, 'epoch': 7.01}
{'loss': 0.002, 'grad_norm': 0.08248522877693176, 'learning_rate': 6.205555555555556e-06, 'epoch': 7.01}
{'loss': 0.001

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0018459202256053686, 'eval_runtime': 110.6357, 'eval_samples_per_second': 1355.801, 'eval_steps_per_second': 33.895, 'epoch': 8.0}
{'train_runtime': 5052.2895, 'train_samples_per_second': 712.548, 'train_steps_per_second': 17.814, 'train_loss': 0.0024398397880399394, 'epoch': 8.0}


In [7]:
import json
from dataclasses import asdict

# Save the trained weights plus the config needed to rebuild the architecture,
# both stamped with today's date. NOTE: this rebinds ``current_date`` at day
# precision; the mesh-generation cells use it in their output file names.
current_date = datetime.now().strftime("%Y-%m-%d")
model_name, config_name = f"{current_date}-model", f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

config_path = result_dir / config_name
with config_path.open('w') as config_file:
    json.dump(asdict(config), config_file)

In [8]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Reconstruct a mesh with the trained model, using the Stanford bunny .obj as the source.
# With resolution=100 and batch_size=40 the progress bar reports 25000 batches
# (100**3 / 40), so generate_mesh presumably evaluates a 100^3 grid.
bunny_dir = get_data_dir() / 'intermediate' / 'bunny'
obj_path = bunny_dir / 'stanford-bunny.obj'
infered_obj_path = result_dir / f"bunny-{current_date}-generalization.obj"
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [9]:
# Reconstruct a mesh with the trained model, this time passing the processed
# bunny training HDF5 file to generate_mesh as the source.
processed_bunny_dir = get_data_dir() / 'processed' / 'bunny'
hdf5_path = processed_bunny_dir / 'stanford-bunny_train.hdf5'
infered_obj_path = result_dir / f"{current_date}-train.obj"
generate_mesh(
    model,
    hdf5_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]