In [1]:
import os
import sys
# Make the project root (the parent of the notebook's working directory)
# importable so the `src.*` packages used below resolve without installation.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

@dataclass
class SDFTransformerConfig:
    """Hyper-parameters for the SDFTransformer model.

    Instances are validated on construction so a bad configuration fails here
    with a readable message, not later as a tensor shape error inside the
    attention blocks. The dataclass fields are unchanged, so ``asdict`` output
    (used to persist the config) keeps the same keys.

    Raises:
        ValueError: if a size/count field is not positive, or if
            ``dim_hidden`` is not divisible by ``num_heads``.
    """
    dim_input: int          # features per context point (input size of the first encoder block)
    num_outputs: int        # number of pooled outputs produced by the PMA decoder stage
    dim_output: int         # output size of the final Linear layer
    delta: float = 0.1      # threshold passed to L1_clamp_loss
    num_inds: int = 32      # passed to every ISAB (presumably the number of inducing points)
    dim_hidden: int = 128   # hidden width shared by all attention blocks
    num_heads: int = 4      # attention heads used by every attention block
    ln: bool = False        # forwarded as `ln=` to every attention block (presumably enables LayerNorm)

    def __post_init__(self):
        positive_fields = ("dim_input", "num_outputs", "dim_output",
                           "num_inds", "dim_hidden", "num_heads")
        for name in positive_fields:
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        # Assumes the attention blocks split dim_hidden evenly across heads
        # (as the standard Set Transformer MAB does); an uneven split would
        # otherwise fail later with an opaque shape mismatch.
        if self.dim_hidden % self.num_heads != 0:
            raise ValueError(
                f"dim_hidden ({self.dim_hidden}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

class SDFTransformer(nn.Module):
    """Set-Transformer model that regresses a signed distance for a query point.

    The context point set is encoded by two ISAB blocks. The projected query
    point is combined with the encoded context by a MAB block. The result is
    pooled by PMA, refined by two SABs and mapped to ``dim_output`` values by a
    final linear layer. The loss (when labels are given) is ``L1_clamp_loss``
    with ``config.delta``.
    """

    def __init__(self, config: SDFTransformerConfig):
        super().__init__()
        self.config = config
        self.enc = nn.Sequential(
            ISAB(config.dim_input, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln),
            ISAB(config.dim_hidden, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln)
        )
        # Projects the 3-coordinate *query* point into the hidden space. The
        # attribute keeps its historical name `ctx_proj` so previously saved
        # state_dicts (keys "ctx_proj.weight" / "ctx_proj.bias") still load.
        self.ctx_proj = nn.Linear(3, config.dim_hidden)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.dec = nn.Sequential(
            PMA(config.dim_hidden, config.num_heads, config.num_outputs, ln=config.ln),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.Linear(config.dim_hidden, config.dim_output)
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Predict SDF value(s) for the query ``x`` given the ``context`` point set.

        Args:
            context: context points. The first ISAB takes ``config.dim_input``
                features per point, so presumably
                ``[batch_size, context_size, dim_input]``.
            x: query point with 3 coordinates in the last dimension,
                ``[batch_size, 1, 3]``. A 2-D ``[batch_size, 3]`` query is also
                accepted and treated as one query per sample.
            labels: optional target SDF values. If they have exactly as many
                elements as the prediction but a different shape (e.g.
                ``[batch_size]`` or ``[batch_size, 1]``), they are reshaped to
                the prediction's shape before the loss is computed.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is not given) and
            ``'logits'`` (the predicted SDF, presumably
            ``[batch_size, num_outputs, dim_output]``).
        """
        if x.dim() == 2:
            # [batch_size, 3] -> [batch_size, 1, 3]. Without this, `repeat`
            # below would prepend a dimension and fold the batch into the set
            # axis ([1, batch_size * context_size, dim_hidden]).
            x = x.unsqueeze(1)

        y = self.enc(context)           # [batch_size, context_size, dim_hidden]
        x = self.ctx_proj(x)            # [batch_size, 1, dim_hidden]
        # NOTE(review): every repeated row is the same projected query, so the
        # cross-attention output rows are likely identical before PMA pools
        # them; confirm against MAB before removing this repeat.
        x = x.repeat(1, y.shape[1], 1)  # [batch_size, context_size, dim_hidden]

        y = self.cross(x, y)
        sdf = self.dec(y)

        loss = None
        if labels is not None:
            if labels.shape != sdf.shape and labels.numel() == sdf.numel():
                # Align targets with predictions. If L1_clamp_loss compares
                # elementwise, [B] or [B, 1] labels against [B, 1, 1]
                # predictions would broadcast into a pairwise B x B comparison
                # and silently corrupt the loss.
                labels = labels.reshape(sdf.shape)
            loss = L1_clamp_loss(sdf, labels, self.config.delta)
        return {'loss': loss, 'logits': sdf}

# Seed torch before the model is constructed so weight initialisation is reproducible.
torch.manual_seed(42)

config = SDFTransformerConfig(dim_input=4, num_outputs=1, dim_output=1)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
processed_dir = project_dir / 'data' / 'processed'

# rglob yields files in filesystem order, which is not guaranteed and differs
# between machines/OSes. Sort so the file list handed to the dataset is stable.
# LazySampleDataset presumably maps indices across files in list order, so a
# seeded run then sees the same samples.
train_files = sorted(processed_dir.rglob('*_train.hdf5'))
val_files = sorted(processed_dir.rglob('*_val.hdf5'))

# Fail fast with a clear message instead of building empty datasets.
if not train_files or not val_files:
    raise FileNotFoundError(
        f"No *_train.hdf5 / *_val.hdf5 files found under {processed_dir} "
        f"(train={len(train_files)}, val={len(val_files)})"
    )

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [8]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Every artefact of this run goes to <results>/<notebook_name>-<YYYY-MM-DD>/.
notebook_name = '2024_10_30_set_transformer'
current_date = datetime.now().strftime("%Y-%m-%d")
folder_name = "-".join((notebook_name, current_date))

result_dir = get_results_dir().joinpath(folder_name)
result_dir.mkdir(exist_ok=True, parents=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_10_30_set_transformer-2024-11-07


In [4]:
from transformers import Trainer, TrainingArguments

batch_size = 10

# Trainer hyper-parameters: checkpoints go to <result_dir>/results (at most 3
# kept), logs go to <result_dir>/logs, and evaluation runs after every epoch.
training_kwargs = dict(
    seed=42,
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    num_train_epochs=5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
)
training_args = TrainingArguments(**training_kwargs)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [5]:
# Run the training loop configured above (5 epochs, with an evaluation pass at
# the end of each epoch per eval_strategy="epoch"). The returned TrainOutput is
# left as the cell's displayed result.
trainer.train()

  0%|          | 0/216000 [00:00<?, ?it/s]

{'loss': 0.035, 'grad_norm': 3.8627631664276123, 'learning_rate': 4.999768518518519e-05, 'epoch': 0.0}
{'loss': 0.0336, 'grad_norm': 4.662616729736328, 'learning_rate': 4.999537037037037e-05, 'epoch': 0.0}
{'loss': 0.026, 'grad_norm': 0.8776947259902954, 'learning_rate': 4.999305555555556e-05, 'epoch': 0.0}
{'loss': 0.0141, 'grad_norm': 2.6966798305511475, 'learning_rate': 4.9990740740740746e-05, 'epoch': 0.0}
{'loss': 0.0124, 'grad_norm': 0.9936151504516602, 'learning_rate': 4.998842592592593e-05, 'epoch': 0.0}
{'loss': 0.0061, 'grad_norm': 1.1760140657424927, 'learning_rate': 4.9986111111111115e-05, 'epoch': 0.0}
{'loss': 0.0081, 'grad_norm': 4.287510871887207, 'learning_rate': 4.9983796296296296e-05, 'epoch': 0.0}
{'loss': 0.01, 'grad_norm': 2.0086071491241455, 'learning_rate': 4.998148148148148e-05, 'epoch': 0.0}
{'loss': 0.0083, 'grad_norm': 1.8179482221603394, 'learning_rate': 4.997916666666667e-05, 'epoch': 0.0}
{'loss': 0.0086, 'grad_norm': 1.7934503555297852, 'learning_rate': 

  0%|          | 0/10800 [00:00<?, ?it/s]

{'eval_loss': 0.005603063385933638, 'eval_runtime': 115.5833, 'eval_samples_per_second': 934.391, 'eval_steps_per_second': 93.439, 'epoch': 1.0}
{'loss': 0.0065, 'grad_norm': 0.04135667905211449, 'learning_rate': 3.9997685185185184e-05, 'epoch': 1.0}
{'loss': 0.0022, 'grad_norm': 0.5703533291816711, 'learning_rate': 3.999537037037037e-05, 'epoch': 1.0}
{'loss': 0.0045, 'grad_norm': 0.42849647998809814, 'learning_rate': 3.999305555555556e-05, 'epoch': 1.0}
{'loss': 0.0035, 'grad_norm': 0.28820276260375977, 'learning_rate': 3.999074074074075e-05, 'epoch': 1.0}
{'loss': 0.0053, 'grad_norm': 0.2853175699710846, 'learning_rate': 3.998842592592593e-05, 'epoch': 1.0}
{'loss': 0.0061, 'grad_norm': 0.6524989008903503, 'learning_rate': 3.9986111111111116e-05, 'epoch': 1.0}
{'loss': 0.0054, 'grad_norm': 1.13398015499115, 'learning_rate': 3.9983796296296296e-05, 'epoch': 1.0}
{'loss': 0.0067, 'grad_norm': 0.28377723693847656, 'learning_rate': 3.9981481481481484e-05, 'epoch': 1.0}
{'loss': 0.0047, 

  0%|          | 0/10800 [00:00<?, ?it/s]

{'eval_loss': 0.005064943805336952, 'eval_runtime': 107.5726, 'eval_samples_per_second': 1003.973, 'eval_steps_per_second': 100.397, 'epoch': 2.0}
{'loss': 0.0035, 'grad_norm': 0.8131235241889954, 'learning_rate': 2.999768518518519e-05, 'epoch': 2.0}
{'loss': 0.0079, 'grad_norm': 0.2929728329181671, 'learning_rate': 2.9995370370370373e-05, 'epoch': 2.0}
{'loss': 0.0052, 'grad_norm': 0.4122086465358734, 'learning_rate': 2.9993055555555554e-05, 'epoch': 2.0}
{'loss': 0.0028, 'grad_norm': 0.5468218922615051, 'learning_rate': 2.999074074074074e-05, 'epoch': 2.0}
{'loss': 0.005, 'grad_norm': 0.8080326914787292, 'learning_rate': 2.998842592592593e-05, 'epoch': 2.0}
{'loss': 0.0057, 'grad_norm': 0.5767069458961487, 'learning_rate': 2.9986111111111116e-05, 'epoch': 2.0}
{'loss': 0.0049, 'grad_norm': 0.6858397722244263, 'learning_rate': 2.9983796296296297e-05, 'epoch': 2.0}
{'loss': 0.0057, 'grad_norm': 0.05946849286556244, 'learning_rate': 2.998148148148148e-05, 'epoch': 2.0}
{'loss': 0.0023, 

  0%|          | 0/10800 [00:00<?, ?it/s]

{'eval_loss': 0.003693122649565339, 'eval_runtime': 111.4726, 'eval_samples_per_second': 968.848, 'eval_steps_per_second': 96.885, 'epoch': 3.0}
{'loss': 0.0054, 'grad_norm': 0.18069368600845337, 'learning_rate': 1.9997685185185186e-05, 'epoch': 3.0}
{'loss': 0.0055, 'grad_norm': 0.6658015251159668, 'learning_rate': 1.9995370370370374e-05, 'epoch': 3.0}
{'loss': 0.0028, 'grad_norm': 0.08630233258008957, 'learning_rate': 1.9993055555555558e-05, 'epoch': 3.0}
{'loss': 0.0014, 'grad_norm': 0.2955077588558197, 'learning_rate': 1.9990740740740742e-05, 'epoch': 3.0}
{'loss': 0.0043, 'grad_norm': 0.18993699550628662, 'learning_rate': 1.9988425925925926e-05, 'epoch': 3.0}
{'loss': 0.0036, 'grad_norm': 0.8285302519798279, 'learning_rate': 1.998611111111111e-05, 'epoch': 3.0}
{'loss': 0.0061, 'grad_norm': 0.08143265545368195, 'learning_rate': 1.9983796296296298e-05, 'epoch': 3.0}
{'loss': 0.0027, 'grad_norm': 0.4031091332435608, 'learning_rate': 1.9981481481481482e-05, 'epoch': 3.0}
{'loss': 0.0

  0%|          | 0/10800 [00:00<?, ?it/s]

{'eval_loss': 0.0033602765761315823, 'eval_runtime': 108.9366, 'eval_samples_per_second': 991.403, 'eval_steps_per_second': 99.14, 'epoch': 4.0}
{'loss': 0.0015, 'grad_norm': 0.799359917640686, 'learning_rate': 9.997685185185187e-06, 'epoch': 4.0}
{'loss': 0.0045, 'grad_norm': 0.2805522382259369, 'learning_rate': 9.995370370370371e-06, 'epoch': 4.0}
{'loss': 0.0049, 'grad_norm': 0.29578906297683716, 'learning_rate': 9.993055555555555e-06, 'epoch': 4.0}
{'loss': 0.0016, 'grad_norm': 0.1463368833065033, 'learning_rate': 9.990740740740741e-06, 'epoch': 4.0}
{'loss': 0.0013, 'grad_norm': 0.1777307689189911, 'learning_rate': 9.988425925925925e-06, 'epoch': 4.0}
{'loss': 0.0052, 'grad_norm': 0.19785714149475098, 'learning_rate': 9.986111111111111e-06, 'epoch': 4.0}
{'loss': 0.0021, 'grad_norm': 0.5508069396018982, 'learning_rate': 9.983796296296297e-06, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.28854215145111084, 'learning_rate': 9.981481481481482e-06, 'epoch': 4.0}
{'loss': 0.005, 'grad

  0%|          | 0/10800 [00:00<?, ?it/s]

{'eval_loss': 0.0030985488556325436, 'eval_runtime': 107.9719, 'eval_samples_per_second': 1000.26, 'eval_steps_per_second': 100.026, 'epoch': 5.0}
{'train_runtime': 5741.8086, 'train_samples_per_second': 376.188, 'train_steps_per_second': 37.619, 'train_loss': 0.004502062981607634, 'epoch': 5.0}


TrainOutput(global_step=216000, training_loss=0.004502062981607634, metrics={'train_runtime': 5741.8086, 'train_samples_per_second': 376.188, 'train_steps_per_second': 37.619, 'total_flos': 0.0, 'train_loss': 0.004502062981607634, 'epoch': 5.0})

In [11]:
import json
from dataclasses import asdict

# Reuse the run date from the setup cell (`current_date`, already part of
# `result_dir`) instead of re-reading the clock. Training takes well over an
# hour, so a fresh datetime.now() can roll over to the next day and stamp the
# model/config (and the mesh written later) differently from the run folder.
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

# SDFTransformer is a plain nn.Module, so its hyper-parameters are persisted
# separately; they are needed to rebuild the model before loading the weights.
with open(result_dir / config_name, 'w', encoding='utf-8') as f:
    json.dump(asdict(config), f, indent=2)

In [44]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

obj_path = get_data_dir() / 'intermediate' / 'bunny' / 'stanford-bunny.obj'
inferred_obj_path = result_dir / f"{current_date}.obj"

# Batch size used only for mesh inference. It defaults to the training batch
# size, which made the resolution=100 run show 100,000 batches in its progress
# bar; raise it if GPU memory allows.
mesh_batch_size = batch_size

# Don't rely on whatever train/eval mode the Trainer left the module in.
model.eval()
generate_mesh(model, obj_path, inferred_obj_path, device, mesh_batch_size, resolution=100)

Processing batches:   0%|          | 0/100000 [00:00<?, ?it/s]

In [40]:
import importlib
import src.visualization.generate_mesh
# Reload picks up edits to generate_mesh.py without restarting the kernel.
# `from ... import generate_mesh` bound the *old* function object, and
# importlib.reload does not rebind such names, so rebind it explicitly;
# otherwise later calls keep running the stale code.
reloaded_mesh_module = importlib.reload(src.visualization.generate_mesh)
generate_mesh = reloaded_mesh_module.generate_mesh
reloaded_mesh_module

<module 'src.visualization.generate_mesh' from 'c:\\_prog\\vm_shared\\attention-sdf\\src\\visualization\\generate_mesh.py'>