In [12]:
import os
import sys
# Put the parent of the current working directory on the import path so the
# project-local `src.*` packages used by later cells can be imported.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    # Appended (not prepended) so already-listed locations keep priority.
    sys.path.insert(len(sys.path), module_path)

In [13]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

# process_models(source_dir, target_dir, train_ratio=0.8, context_size=64, batch_size=512, num_samples=131_072)

@dataclass
class SDFTransformerConfig:
    """Hyperparameters read by SDFEncoder, SDFDecoder and SDFTransformer.

    The first four fields are required; the remaining ones have defaults.
    """
    # Feature width of the context tensor: input size of the first ISAB in SDFEncoder.
    dim_context: int
    # Feature width of the query tensor `x`: in_features of SDFEncoder.input_proj.
    dim_input: int
    # Third positional argument to PMA in SDFEncoder
    # (presumably the number of pooled output vectors -- confirm against PMA).
    num_outputs: int
    # out_features of the final Linear layer in SDFDecoder.
    dim_output: int
    # Third argument to L1_clamp_loss in SDFTransformer.forward
    # (presumably the clamp threshold -- confirm against the loss implementation).
    delta: float = 0.1
    # Passed to both ISAB blocks in SDFEncoder
    # (presumably the number of inducing points -- confirm against ISAB).
    num_inds: int = 32
    # Width shared by every attention block and by the decoder's hidden layers.
    dim_hidden: int = 128
    # Head count passed to every ISAB / MAB / PMA / SAB block.
    num_heads: int = 4
    # Forwarded as `ln=` to every attention block
    # (presumably toggles layer normalisation -- confirm against the modules).
    ln: bool = False

class SDFEncoder(nn.Module):
    """Encode a context set and a query into a pooled latent.

    Pipeline (see ``forward``): the context goes through two stacked ISAB
    blocks, the query is linearly projected to ``dim_hidden`` and tiled, the
    tiled query and the encoded context are combined by an MAB, and the
    result is passed through a PMA followed by two SAB blocks.
    """

    def __init__(self, config: SDFTransformerConfig):
        super().__init__()
        self.config = config

        hidden = config.dim_hidden
        heads = config.num_heads
        inducing = config.num_inds
        use_ln = config.ln

        # Context branch: dim_context -> dim_hidden, then dim_hidden -> dim_hidden.
        context_blocks = [
            ISAB(config.dim_context, hidden, heads, inducing, ln=use_ln),
            ISAB(hidden, hidden, heads, inducing, ln=use_ln),
        ]
        self.enc = nn.Sequential(*context_blocks)

        # Query branch: lift the raw query features to the hidden width.
        self.input_proj = nn.Linear(config.dim_input, hidden)

        # Combines the tiled query (first argument) with the encoded context.
        self.cross = MAB(hidden, hidden, hidden, heads, ln=use_ln)

        # Pooling head: PMA receives num_outputs, then two SAB blocks.
        pooling_blocks = [
            PMA(hidden, heads, config.num_outputs, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
        ]
        self.dec = nn.Sequential(*pooling_blocks)

    def forward(self, context: torch.Tensor, x: torch.Tensor):
        """Return the pooled latent for ``context`` conditioned on query ``x``.

        Both inputs are assumed to be 3-D (batch, set, features) -- the
        ``repeat`` call below requires exactly three dimensions on ``x``.
        """
        encoded = self.enc(context)
        queries = self.input_proj(x)
        # Tile the projected query along dim 1 once per encoded context
        # element; the result has as many rows as the context only when the
        # query has a single row on that axis.
        tiled = queries.repeat(1, encoded.shape[1], 1)
        return self.dec(self.cross(tiled, encoded))


class SDFDecoder(nn.Module):
    """Four-layer ReLU MLP mapping ``dim_hidden`` features to ``dim_output``.

    Three ``dim_hidden -> dim_hidden`` linear layers, each followed by a
    ReLU, and a final ``dim_hidden -> dim_output`` linear layer with no
    activation.
    """

    def __init__(self, config: SDFTransformerConfig):
        super().__init__()
        self.config = config

        width = config.dim_hidden
        layers = []
        # Hidden stack: (Linear, ReLU) x 3, created in order so the
        # Sequential indices (0, 2, 4 for the linears) stay the same.
        for _ in range(3):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
        # Output head at index 6, left without an activation.
        layers.append(nn.Linear(width, config.dim_output))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)


class SDFTransformer(nn.Module):
    """Full model: SDFEncoder latent, re-joined with the raw query, then SDFDecoder.

    ``forward`` encodes ``(context, x)``, projects the latent down to
    ``dim_hidden - dim_input`` features, concatenates the raw query ``x`` on
    the last axis (bringing the width back to ``dim_hidden``) and decodes.
    """

    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        # Leave exactly dim_input free feature slots for the raw query that is
        # concatenated in forward(); the decoder's first layer expects
        # dim_hidden features. (Previously hard-coded as `- 3`, which only
        # matched dim_input == 3.)
        self.latent_proj = nn.Linear(config.dim_hidden, config.dim_hidden - config.dim_input)
        self.enc = SDFEncoder(config)
        self.dec = SDFDecoder(config)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the model and optionally compute the training loss.

        Args:
            context: Context tensor passed to the encoder.
            x: Query tensor; passed to the encoder and also concatenated,
                unmodified, onto the projected latent. Its last dimension must
                be ``config.dim_input`` and its leading dimensions must match
                the encoder output for the concatenation to succeed.
            labels: Optional targets. When given, the loss is
                ``L1_clamp_loss(sdf, labels, config.delta)``.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'`` (the decoder output).
        """
        y = self.enc(context, x)    # encoder latent, last dim = dim_hidden
        y = self.latent_proj(y)     # last dim = dim_hidden - dim_input
        y = torch.cat((y, x), -1)   # last dim back to dim_hidden
        sdf = self.dec(y)

        loss = None
        if labels is not None:
            loss = L1_clamp_loss(sdf, labels, self.config.delta)
        return {'loss': loss, 'logits': sdf}

# Seed PyTorch's RNG before the model is built so weight initialisation is
# repeatable from run to run.
torch.manual_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [14]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# The project root is taken to be the parent of the notebook's working directory.
project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

# rglob() yields matches in filesystem order, which can differ between
# machines and runs. Sort the paths so the order handed to the datasets (and
# therefore sample indexing under the fixed training seed) is reproducible.
train_files = sorted(procesed_dir.rglob('*_train.hdf5'))
val_files = sorted(procesed_dir.rglob('*_val.hdf5'))

# Fail here with a clear message rather than later inside the dataset/trainer
# when the processed data has not been generated or the path is wrong.
if not train_files or not val_files:
    raise FileNotFoundError(
        f"Expected '*_train.hdf5' and '*_val.hdf5' files under {procesed_dir}; "
        f"found {len(train_files)} train and {len(val_files)} val file(s)."
    )

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [15]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Outputs of this notebook are grouped in one folder per notebook and per
# calendar day: <results dir>/<notebook name>-<YYYY-MM-DD>.
notebook_name = '2024_11_08_set_transformer'
current_date = datetime.now().strftime("%Y-%m-%d")
folder_name = f"{notebook_name}-{current_date}"

result_dir = get_results_dir().joinpath(folder_name)
# Create the folder (and any missing parents); re-running the cell is harmless.
result_dir.mkdir(exist_ok=True, parents=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_08_set_transformer-2024-11-08


In [23]:
from transformers import Trainer, TrainingArguments

# Per-device batch size, shared by training, evaluation and (later) mesh generation.
batch_size = 40

training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    save_total_limit=3,
    # Schedule: five epochs, evaluating once per epoch, logging every 10 steps.
    num_train_epochs=5,
    eval_strategy="epoch",
    logging_steps=10,
    # Batching.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Optimisation and reproducibility.
    weight_decay=0.01,
    seed=42,
)

trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [24]:
# Run training with the Trainer built in the previous cell. As the cell's last
# expression, the returned TrainOutput is displayed after the progress logs.
trainer.train()

  0%|          | 0/68545 [00:00<?, ?it/s]

{'loss': 0.0046, 'grad_norm': 0.4959447979927063, 'learning_rate': 4.999270552191991e-05, 'epoch': 0.0}
{'loss': 0.0065, 'grad_norm': 0.2919958531856537, 'learning_rate': 4.998541104383981e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.46707943081855774, 'learning_rate': 4.9978116565759725e-05, 'epoch': 0.0}
{'loss': 0.0069, 'grad_norm': 0.08566362410783768, 'learning_rate': 4.997082208767963e-05, 'epoch': 0.0}
{'loss': 0.0058, 'grad_norm': 0.5266549587249756, 'learning_rate': 4.996352760959953e-05, 'epoch': 0.0}
{'loss': 0.0067, 'grad_norm': 0.030128000304102898, 'learning_rate': 4.995623313151944e-05, 'epoch': 0.0}
{'loss': 0.0066, 'grad_norm': 0.5259796380996704, 'learning_rate': 4.994893865343935e-05, 'epoch': 0.01}
{'loss': 0.0037, 'grad_norm': 0.22008076310157776, 'learning_rate': 4.9941644175359256e-05, 'epoch': 0.01}
{'loss': 0.0051, 'grad_norm': 0.1177249401807785, 'learning_rate': 4.9934349697279163e-05, 'epoch': 0.01}
{'loss': 0.0061, 'grad_norm': 0.29232075810432434, 'l

  0%|          | 0/3495 [00:00<?, ?it/s]

{'eval_loss': 0.004848572891205549, 'eval_runtime': 121.1171, 'eval_samples_per_second': 1154.057, 'eval_steps_per_second': 28.856, 'epoch': 1.0}
{'loss': 0.0071, 'grad_norm': 0.4796831011772156, 'learning_rate': 3.999927055219199e-05, 'epoch': 1.0}
{'loss': 0.0052, 'grad_norm': 0.29228702187538147, 'learning_rate': 3.99919760741119e-05, 'epoch': 1.0}
{'loss': 0.006, 'grad_norm': 0.08753100037574768, 'learning_rate': 3.99846815960318e-05, 'epoch': 1.0}
{'loss': 0.004, 'grad_norm': 0.3172089457511902, 'learning_rate': 3.9977387117951715e-05, 'epoch': 1.0}
{'loss': 0.0064, 'grad_norm': 0.6654368042945862, 'learning_rate': 3.997009263987162e-05, 'epoch': 1.0}
{'loss': 0.0053, 'grad_norm': 0.23126035928726196, 'learning_rate': 3.9962798161791524e-05, 'epoch': 1.0}
{'loss': 0.0068, 'grad_norm': 0.1443922370672226, 'learning_rate': 3.995550368371143e-05, 'epoch': 1.0}
{'loss': 0.0055, 'grad_norm': 0.4627682566642761, 'learning_rate': 3.994820920563134e-05, 'epoch': 1.01}
{'loss': 0.0049, 'gr

  0%|          | 0/3495 [00:00<?, ?it/s]

{'eval_loss': 0.004431820008903742, 'eval_runtime': 113.7745, 'eval_samples_per_second': 1228.536, 'eval_steps_per_second': 30.719, 'epoch': 2.0}
{'loss': 0.0037, 'grad_norm': 0.05793626606464386, 'learning_rate': 2.9998541104383983e-05, 'epoch': 2.0}
{'loss': 0.0042, 'grad_norm': 0.4118248522281647, 'learning_rate': 2.9991246626303887e-05, 'epoch': 2.0}
{'loss': 0.0044, 'grad_norm': 0.3749372959136963, 'learning_rate': 2.9983952148223798e-05, 'epoch': 2.0}
{'loss': 0.0035, 'grad_norm': 0.08614091575145721, 'learning_rate': 2.9976657670143706e-05, 'epoch': 2.0}
{'loss': 0.0044, 'grad_norm': 0.192571759223938, 'learning_rate': 2.996936319206361e-05, 'epoch': 2.0}
{'loss': 0.004, 'grad_norm': 0.23121118545532227, 'learning_rate': 2.9962068713983514e-05, 'epoch': 2.0}
{'loss': 0.0062, 'grad_norm': 0.08718898892402649, 'learning_rate': 2.995477423590342e-05, 'epoch': 2.0}
{'loss': 0.0045, 'grad_norm': 0.37485066056251526, 'learning_rate': 2.9947479757823325e-05, 'epoch': 2.01}
{'loss': 0.0

  0%|          | 0/3495 [00:00<?, ?it/s]

{'eval_loss': 0.0034337386023253202, 'eval_runtime': 113.7115, 'eval_samples_per_second': 1229.216, 'eval_steps_per_second': 30.736, 'epoch': 3.0}
{'loss': 0.0034, 'grad_norm': 0.03121889755129814, 'learning_rate': 1.9997811656575973e-05, 'epoch': 3.0}
{'loss': 0.0041, 'grad_norm': 0.10910124331712723, 'learning_rate': 1.999051717849588e-05, 'epoch': 3.0}
{'loss': 0.0029, 'grad_norm': 0.2589733302593231, 'learning_rate': 1.9983222700415785e-05, 'epoch': 3.0}
{'loss': 0.0037, 'grad_norm': 0.08652739971876144, 'learning_rate': 1.9975928222335692e-05, 'epoch': 3.0}
{'loss': 0.0036, 'grad_norm': 0.08705448359251022, 'learning_rate': 1.99686337442556e-05, 'epoch': 3.0}
{'loss': 0.0045, 'grad_norm': 0.22767654061317444, 'learning_rate': 1.9961339266175504e-05, 'epoch': 3.0}
{'loss': 0.0033, 'grad_norm': 0.17310458421707153, 'learning_rate': 1.9954044788095415e-05, 'epoch': 3.0}
{'loss': 0.0029, 'grad_norm': 0.03052341379225254, 'learning_rate': 1.994675031001532e-05, 'epoch': 3.01}
{'loss': 

  0%|          | 0/3495 [00:00<?, ?it/s]

{'eval_loss': 0.003013582667335868, 'eval_runtime': 121.4366, 'eval_samples_per_second': 1151.02, 'eval_steps_per_second': 28.78, 'epoch': 4.0}
{'loss': 0.0042, 'grad_norm': 0.2022276371717453, 'learning_rate': 9.997082208767963e-06, 'epoch': 4.0}
{'loss': 0.0032, 'grad_norm': 0.4341629147529602, 'learning_rate': 9.989787730687871e-06, 'epoch': 4.0}
{'loss': 0.003, 'grad_norm': 0.5241408348083496, 'learning_rate': 9.982493252607777e-06, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.10354064404964447, 'learning_rate': 9.975198774527683e-06, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.12383738160133362, 'learning_rate': 9.96790429644759e-06, 'epoch': 4.0}
{'loss': 0.0034, 'grad_norm': 0.01428354810923338, 'learning_rate': 9.960609818367496e-06, 'epoch': 4.0}
{'loss': 0.0016, 'grad_norm': 0.0585518479347229, 'learning_rate': 9.953315340287404e-06, 'epoch': 4.0}
{'loss': 0.0033, 'grad_norm': 0.17776042222976685, 'learning_rate': 9.94602086220731e-06, 'epoch': 4.01}
{'loss': 0.0041, 'grad

  0%|          | 0/3495 [00:00<?, ?it/s]

{'eval_loss': 0.002845516661182046, 'eval_runtime': 121.3438, 'eval_samples_per_second': 1151.901, 'eval_steps_per_second': 28.802, 'epoch': 5.0}
{'train_runtime': 4336.6165, 'train_samples_per_second': 632.235, 'train_steps_per_second': 15.806, 'train_loss': 0.003977885318570694, 'epoch': 5.0}


TrainOutput(global_step=68545, training_loss=0.003977885318570694, metrics={'train_runtime': 4336.6165, 'train_samples_per_second': 632.235, 'train_steps_per_second': 15.806, 'total_flos': 0.0, 'train_loss': 0.003977885318570694, 'epoch': 5.0})

In [25]:
import json
from dataclasses import asdict

# Re-read the date so the artefact names reflect when training finished
# (the value is also reused by the mesh-generation cell).
current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"

# Persist the trained weights via the Trainer ...
trainer.save_model(result_dir.joinpath(model_name))

# ... and the hyperparameters next to them as plain JSON.
with result_dir.joinpath(config_name).open('w') as f:
    json.dump(asdict(config), f)

In [26]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Source mesh handed to generate_mesh, and the destination for its output
# inside this run's result folder.
obj_path = get_data_dir().joinpath('intermediate', 'bunny', 'stanford-bunny.obj')
infered_obj_path = result_dir.joinpath(f"{current_date}.obj")

# Presumably evaluates the trained model and writes the reconstructed mesh to
# infered_obj_path -- see src.visualization.generate_mesh for the details.
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [40]:
import importlib
import src.visualization.generate_mesh
# Re-execute the module so changes to its source file are picked up without
# restarting the kernel. NOTE: reload() refreshes the module object only; the
# `generate_mesh` name bound earlier with
# `from src.visualization.generate_mesh import generate_mesh` keeps pointing at
# the old function until that import statement is run again.
importlib.reload(src.visualization.generate_mesh)

<module 'src.visualization.generate_mesh' from 'c:\\_prog\\vm_shared\\attention-sdf\\src\\visualization\\generate_mesh.py'>