In [1]:
import os
import sys
# Make the parent of the current working directory (normally the folder that
# contains this notebook) importable, so the `src.*` packages below resolve.
# Note: os.path.join('..') with a single argument is simply '..'.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    # NOTE(review): append() places the project root *last* on sys.path, so an
    # installed package that is also named `src` would take precedence — confirm.
    sys.path.append(module_path)
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

# process_models(source_dir, target_dir, train_ratio=0.8, context_size=100, batch_size=100100, num_samples=500501)

@dataclass
class SDFTransformerConfig:
    """Hyper-parameters for ``SDFTransformer``.

    The whole config is written to JSON via ``dataclasses.asdict`` when the
    model is saved, so field names and order form that file's schema.
    """

    # Last-dimension size of the `context` tensor (input width of context_proj).
    dim_context: int
    # Last-dimension size of the query tensor `x` (input width of input_proj).
    dim_input: int
    # Passed to PMA as its pooled-output count; presumably the number of output
    # tokens per sample — confirm against PMA's signature.
    num_outputs: int
    # Width of the final regression layer, i.e. last dim of the returned logits.
    dim_output: int
    # Passed to L1_clamp_loss (presumably its clamp threshold — confirm).
    delta: float = 0.1
    # Passed to each ISAB encoder layer (its inducing-point count, per the usual
    # ISAB signature).
    num_inds: int = 32
    # Shared hidden width of the projections, attention blocks and MLP head.
    dim_hidden: int = 128
    # Number of attention heads used by every MAB/ISAB/PMA/SAB block.
    num_heads: int = 4
    # Forwarded as `ln=` to every attention block (presumably toggles LayerNorm).
    ln: bool = False

class SDFTransformer(nn.Module):
    """Cross-attention-first Set Transformer regressor.

    A single query ``x`` attends to a ``context`` set. The result is refined by
    two ISAB layers, pooled by PMA followed by two SAB layers, and mapped to
    ``config.dim_output`` values by a Tanh-bounded MLP, so every logit lies in
    ``[-1, 1]``. Per the SDF naming, the outputs are presumably signed-distance
    values — confirm against the dataset.

    ``forward`` returns a ``{'loss': ..., 'logits': ...}`` dict, the format
    consumed by ``transformers.Trainer``.
    """

    def __init__(self, config: "SDFTransformerConfig"):
        super().__init__()
        self.config = config
        hidden, heads, ln = config.dim_hidden, config.num_heads, config.ln

        # Lift query points and context elements into the shared hidden width.
        self.input_proj = nn.Sequential(nn.Linear(config.dim_input, hidden), nn.ReLU())
        self.context_proj = nn.Sequential(nn.Linear(config.dim_context, hidden), nn.ReLU())
        # Cross-attention: queries first, context second (presumably keys/values).
        self.cross = MAB(hidden, hidden, hidden, heads, ln=ln)
        self.enc = nn.Sequential(
            ISAB(hidden, hidden, heads, config.num_inds, ln=ln),
            ISAB(hidden, hidden, heads, config.num_inds, ln=ln),
        )
        self.dec = nn.Sequential(
            PMA(hidden, heads, config.num_outputs, ln=ln),
            SAB(hidden, hidden, heads, ln=ln),
            SAB(hidden, hidden, heads, ln=ln),
        )
        self.regr = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.dim_output),
            nn.Tanh(),
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the model and, when ``labels`` are given, compute the loss.

        Args:
            context: Context set with last dim ``config.dim_context``; assumed
                ``[batch_size, context_size, dim_context]`` — TODO confirm
                against the dataset.
            x: Query tensor of shape ``[batch_size, 1, dim_input]``. Exactly one
                query per sample is supported.
            labels: Optional targets, compared with the logits by
                ``L1_clamp_loss``.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'``, the output of the regression head.

        Raises:
            ValueError: if ``x`` is not of shape ``[batch_size, 1, dim_input]``.
        """
        # `repeat` would silently tile a multi-point query (or scramble a 2-D
        # one) into a longer sequence, so reject such inputs explicitly.
        if x.dim() != 3 or x.shape[1] != 1:
            raise ValueError(
                "expected x of shape [batch_size, 1, dim_input], "
                f"got {tuple(x.shape)}"
            )
        query = self.input_proj(x)          # [batch_size, 1, dim_hidden]
        ctx = self.context_proj(context)    # [batch_size, context_size, dim_hidden]
        # Broadcast the single query along the context length. `expand` is a
        # view (no copy) and back-propagates exactly like `repeat` (gradients sum).
        query = query.expand(-1, ctx.shape[1], -1)  # [batch_size, context_size, dim_hidden]
        # NOTE(review): every query row is the same point. Assuming the standard
        # Set Transformer MAB/ISAB/PMA (no masking), all context_size rows out of
        # `cross` are therefore identical, and the encoder sees copies of a
        # single token. Confirm this is intended.
        hidden = self.cross(query, ctx)     # [batch_size, context_size, dim_hidden]
        hidden = self.enc(hidden)           # [batch_size, context_size, dim_hidden]
        hidden = self.dec(hidden)           # presumably [batch_size, num_outputs, dim_hidden]
        logits = self.regr(hidden)          # [..., dim_output]

        loss = None
        if labels is not None:
            loss = L1_clamp_loss(logits, labels, self.config.delta)
        return {'loss': loss, 'logits': logits}

# Seed torch *before* building the model so weight initialisation is reproducible.
torch.manual_seed(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# Locate the processed HDF5 splits under <project>/data/processed.
project_dir = Path(os.path.abspath('')).resolve().parent
processed_dir = project_dir / 'data' / 'processed'

# rglob() yields entries in filesystem-dependent (arbitrary) order. Sort the
# lists so the dataset's index -> sample mapping, and hence the data order
# under a fixed training seed, is the same on every machine and run.
train_files = sorted(processed_dir.rglob('*_train.hdf5'))
val_files = sorted(processed_dir.rglob('*_val.hdf5'))
# Fail early with a clear message instead of an obscure error inside Trainer.
if not train_files or not val_files:
    raise FileNotFoundError(
        f"no *_train.hdf5 / *_val.hdf5 files found under {processed_dir}"
    )

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

# One results folder per run: <results>/<notebook_name>-<YYYY-mm-dd-HH-MM-SS>.
# A dedicated name (not `current_date`) is used because the save cell later
# defines `current_date` with a different, date-only format.
notebook_name = '2024_11_13_cross_first'
run_timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{run_timestamp}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_13_cross_first-2024-11-13-18-59-57


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 40  # also reused later as the mesh-generation batch size

# Evaluate once per epoch, log every 10 steps, and retain at most 3 checkpoints.
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    num_train_epochs=8,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    logging_steps=10,
    save_total_limit=3,
    weight_decay=0.01,
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [6]:
# Run the full training loop; the returned TrainOutput is the cell's displayed result.
trainer.train()

  0%|          | 0/150000 [00:00<?, ?it/s]

{'loss': 0.0183, 'grad_norm': 0.6731148362159729, 'learning_rate': 4.999666666666667e-05, 'epoch': 0.0}
{'loss': 0.0082, 'grad_norm': 0.7922601103782654, 'learning_rate': 4.9993333333333335e-05, 'epoch': 0.0}
{'loss': 0.0069, 'grad_norm': 0.12218088656663895, 'learning_rate': 4.999e-05, 'epoch': 0.0}
{'loss': 0.0071, 'grad_norm': 0.24234633147716522, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.0057, 'grad_norm': 0.060814812779426575, 'learning_rate': 4.998333333333334e-05, 'epoch': 0.0}
{'loss': 0.0088, 'grad_norm': 0.2415844202041626, 'learning_rate': 4.9980000000000006e-05, 'epoch': 0.0}
{'loss': 0.0074, 'grad_norm': 0.30283230543136597, 'learning_rate': 4.997666666666667e-05, 'epoch': 0.0}
{'loss': 0.0073, 'grad_norm': 0.1814224272966385, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0075, 'grad_norm': 0.060970257967710495, 'learning_rate': 4.997e-05, 'epoch': 0.0}
{'loss': 0.0064, 'grad_norm': 0.2421507090330124, 'learning_rate': 4.996666666

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.004160473123192787, 'eval_runtime': 169.7276, 'eval_samples_per_second': 1472.948, 'eval_steps_per_second': 36.824, 'epoch': 1.0}
{'loss': 0.0045, 'grad_norm': 0.34863024950027466, 'learning_rate': 4.374666666666667e-05, 'epoch': 1.0}
{'loss': 0.0036, 'grad_norm': 0.31931203603744507, 'learning_rate': 4.374333333333334e-05, 'epoch': 1.0}
{'loss': 0.0038, 'grad_norm': 0.10520172119140625, 'learning_rate': 4.3740000000000005e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.029643818736076355, 'learning_rate': 4.373666666666667e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.0865369439125061, 'learning_rate': 4.373333333333334e-05, 'epoch': 1.0}
{'loss': 0.0038, 'grad_norm': 0.013561271131038666, 'learning_rate': 4.373e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.02915349043905735, 'learning_rate': 4.372666666666667e-05, 'epoch': 1.0}
{'loss': 0.0031, 'grad_norm': 0.08746315538883209, 'learning_rate': 4.3723333333333335e-05, 'epoch': 1.0}
{'loss': 0.0043, 'grad_no

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.0032997415401041508, 'eval_runtime': 169.1196, 'eval_samples_per_second': 1478.244, 'eval_steps_per_second': 36.956, 'epoch': 2.0}
{'loss': 0.0025, 'grad_norm': 0.08559174090623856, 'learning_rate': 3.749666666666667e-05, 'epoch': 2.0}
{'loss': 0.0033, 'grad_norm': 0.05706758052110672, 'learning_rate': 3.7493333333333336e-05, 'epoch': 2.0}
{'loss': 0.0035, 'grad_norm': 0.02853216417133808, 'learning_rate': 3.749e-05, 'epoch': 2.0}
{'loss': 0.0046, 'grad_norm': 0.45647749304771423, 'learning_rate': 3.748666666666667e-05, 'epoch': 2.0}
{'loss': 0.0042, 'grad_norm': 3.682918071746826, 'learning_rate': 3.7483333333333334e-05, 'epoch': 2.0}
{'loss': 0.0041, 'grad_norm': 0.2853260338306427, 'learning_rate': 3.748000000000001e-05, 'epoch': 2.0}
{'loss': 0.0031, 'grad_norm': 0.05708394944667816, 'learning_rate': 3.747666666666667e-05, 'epoch': 2.0}
{'loss': 0.0028, 'grad_norm': 0.313849538564682, 'learning_rate': 3.747333333333333e-05, 'epoch': 2.0}
{'loss': 0.0042, 'grad_norm'

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.003302820725366473, 'eval_runtime': 164.6459, 'eval_samples_per_second': 1518.411, 'eval_steps_per_second': 37.96, 'epoch': 3.0}
{'loss': 0.0029, 'grad_norm': 0.11310239136219025, 'learning_rate': 3.124666666666667e-05, 'epoch': 3.0}
{'loss': 0.003, 'grad_norm': 1.0893789529800415, 'learning_rate': 3.124333333333333e-05, 'epoch': 3.0}
{'loss': 0.0032, 'grad_norm': 0.2262103408575058, 'learning_rate': 3.1240000000000006e-05, 'epoch': 3.0}
{'loss': 0.0028, 'grad_norm': 0.339316725730896, 'learning_rate': 3.123666666666667e-05, 'epoch': 3.0}
{'loss': 0.0023, 'grad_norm': 0.3393224775791168, 'learning_rate': 3.123333333333334e-05, 'epoch': 3.0}
{'loss': 0.0028, 'grad_norm': 0.508935272693634, 'learning_rate': 3.1230000000000004e-05, 'epoch': 3.0}
{'loss': 0.0034, 'grad_norm': 0.22620069980621338, 'learning_rate': 3.122666666666667e-05, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm': 0.3675951361656189, 'learning_rate': 3.122333333333333e-05, 'epoch': 3.0}
{'loss': 0.0033, 'grad

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.002995244227349758, 'eval_runtime': 168.1741, 'eval_samples_per_second': 1486.555, 'eval_steps_per_second': 37.164, 'epoch': 4.0}
{'loss': 0.0023, 'grad_norm': 0.4226841628551483, 'learning_rate': 2.4996666666666667e-05, 'epoch': 4.0}
{'loss': 0.0031, 'grad_norm': 0.25363728404045105, 'learning_rate': 2.4993333333333337e-05, 'epoch': 4.0}
{'loss': 0.0021, 'grad_norm': 0.169075146317482, 'learning_rate': 2.4990000000000003e-05, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.14091017842292786, 'learning_rate': 2.4986666666666666e-05, 'epoch': 4.0}
{'loss': 0.0027, 'grad_norm': 0.5354135632514954, 'learning_rate': 2.4983333333333335e-05, 'epoch': 4.0}
{'loss': 0.0043, 'grad_norm': 0.028180204331874847, 'learning_rate': 2.498e-05, 'epoch': 4.0}
{'loss': 0.0032, 'grad_norm': 0.2818049490451813, 'learning_rate': 2.4976666666666668e-05, 'epoch': 4.0}
{'loss': 0.0037, 'grad_norm': 0.42267364263534546, 'learning_rate': 2.4973333333333334e-05, 'epoch': 4.0}
{'loss': 0.0023, 'grad_n

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.002923607360571623, 'eval_runtime': 172.5992, 'eval_samples_per_second': 1448.442, 'eval_steps_per_second': 36.211, 'epoch': 5.0}
{'loss': 0.003, 'grad_norm': 4.863276892308477e-09, 'learning_rate': 1.8746666666666668e-05, 'epoch': 5.0}
{'loss': 0.0022, 'grad_norm': 0.25614747405052185, 'learning_rate': 1.8743333333333334e-05, 'epoch': 5.0}
{'loss': 0.0034, 'grad_norm': 0.3935626149177551, 'learning_rate': 1.8740000000000004e-05, 'epoch': 5.0}
{'loss': 0.0025, 'grad_norm': 0.14055797457695007, 'learning_rate': 1.8736666666666666e-05, 'epoch': 5.0}
{'loss': 0.003, 'grad_norm': 0.3935420513153076, 'learning_rate': 1.8733333333333332e-05, 'epoch': 5.0}
{'loss': 0.0034, 'grad_norm': 0.14055457711219788, 'learning_rate': 1.8730000000000002e-05, 'epoch': 5.0}
{'loss': 0.0022, 'grad_norm': 0.22488826513290405, 'learning_rate': 1.8726666666666668e-05, 'epoch': 5.0}
{'loss': 0.0027, 'grad_norm': 0.2248765081167221, 'learning_rate': 1.8723333333333334e-05, 'epoch': 5.0}
{'loss': 

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.002873950870707631, 'eval_runtime': 171.5774, 'eval_samples_per_second': 1457.068, 'eval_steps_per_second': 36.427, 'epoch': 6.0}
{'loss': 0.0032, 'grad_norm': 0.05608152225613594, 'learning_rate': 1.2496666666666668e-05, 'epoch': 6.0}
{'loss': 0.0028, 'grad_norm': 2.9871383144808306e-09, 'learning_rate': 1.2493333333333333e-05, 'epoch': 6.0}
{'loss': 0.0028, 'grad_norm': 0.0841216966509819, 'learning_rate': 1.249e-05, 'epoch': 6.0}
{'loss': 0.0024, 'grad_norm': 0.02804020419716835, 'learning_rate': 1.2486666666666667e-05, 'epoch': 6.0}
{'loss': 0.0028, 'grad_norm': 0.14020352065563202, 'learning_rate': 1.2483333333333335e-05, 'epoch': 6.0}
{'loss': 0.0032, 'grad_norm': 0.23690302670001984, 'learning_rate': 1.248e-05, 'epoch': 6.0}
{'loss': 0.0038, 'grad_norm': 0.36453261971473694, 'learning_rate': 1.2476666666666667e-05, 'epoch': 6.0}
{'loss': 0.0024, 'grad_norm': 0.16824021935462952, 'learning_rate': 1.2473333333333335e-05, 'epoch': 6.0}
{'loss': 0.0034, 'grad_norm': 

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.0027004729490727186, 'eval_runtime': 172.6693, 'eval_samples_per_second': 1447.854, 'eval_steps_per_second': 36.196, 'epoch': 7.0}
{'loss': 0.0024, 'grad_norm': 0.028018372133374214, 'learning_rate': 6.2466666666666664e-06, 'epoch': 7.0}
{'loss': 0.0053, 'grad_norm': 0.2241506278514862, 'learning_rate': 6.243333333333333e-06, 'epoch': 7.0}
{'loss': 0.0017, 'grad_norm': 0.11207615584135056, 'learning_rate': 6.24e-06, 'epoch': 7.0}
{'loss': 0.0021, 'grad_norm': 0.1681102216243744, 'learning_rate': 6.236666666666667e-06, 'epoch': 7.0}
{'loss': 0.0019, 'grad_norm': 0.11207246780395508, 'learning_rate': 6.2333333333333335e-06, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.16810980439186096, 'learning_rate': 6.2300000000000005e-06, 'epoch': 7.0}
{'loss': 0.0024, 'grad_norm': 0.05603798106312752, 'learning_rate': 6.226666666666667e-06, 'epoch': 7.0}
{'loss': 0.003, 'grad_norm': 0.05603867396712303, 'learning_rate': 6.223333333333334e-06, 'epoch': 7.0}
{'loss': 0.0029, 'grad_no

  0%|          | 0/6250 [00:00<?, ?it/s]

{'eval_loss': 0.0026104876305907965, 'eval_runtime': 166.3201, 'eval_samples_per_second': 1503.126, 'eval_steps_per_second': 37.578, 'epoch': 8.0}
{'train_runtime': 8086.8179, 'train_samples_per_second': 741.948, 'train_steps_per_second': 18.549, 'train_loss': 0.003053902931082994, 'epoch': 8.0}


TrainOutput(global_step=150000, training_loss=0.003053902931082994, metrics={'train_runtime': 8086.8179, 'train_samples_per_second': 741.948, 'train_steps_per_second': 18.549, 'total_flos': 0.0, 'train_loss': 0.003053902931082994, 'epoch': 8.0})

In [7]:
import json
from dataclasses import asdict

# Date-stamped names for the saved weights directory and its JSON config.
current_date = datetime.now().strftime("%Y-%m-%d")
model_name = current_date + "-model"
config_name = current_date + "-config.json"

trainer.save_model(result_dir / model_name)

config_path = result_dir / config_name
with config_path.open('w') as f:
    json.dump(asdict(config), f)

In [8]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Build a mesh from the raw bunny .obj and save it with a "-generalization" suffix.
obj_path = get_data_dir().joinpath('intermediate', 'bunny', 'stanford-bunny.obj')
infered_obj_path = result_dir / (current_date + "-generalization.obj")
generate_mesh(model, obj_path, infered_obj_path, device, batch_size, resolution=100)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [9]:
# Training is finished; close the datasets (presumably releasing their HDF5 handles).
train_dataset.close()
val_dataset.close()

# Build a mesh from the bunny's training-split HDF5 and save it with a "-train" suffix.
hdf5_path = get_data_dir().joinpath('processed', 'bunny', 'stanford-bunny_train.hdf5')
infered_obj_path = result_dir / (current_date + "-train.obj")
generate_mesh(model, hdf5_path, infered_obj_path, device, batch_size, resolution=100)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [10]:
import importlib
import src.visualization.generate_mesh

# importlib.reload() re-executes the module but does NOT rebind names imported
# earlier via `from ... import`. The `generate_mesh` name bound in the mesh cell
# would keep calling the old code, so rebind it to the reloaded function.
generate_mesh_module = importlib.reload(src.visualization.generate_mesh)
generate_mesh = generate_mesh_module.generate_mesh
generate_mesh_module

<module 'src.visualization.generate_mesh' from 'c:\\_prog\\vm_shared\\attention-sdf\\src\\visualization\\generate_mesh.py'>