In [11]:
import os
import sys
# Make the parent of the working directory importable so `src.*` resolves.
# The membership check keeps repeated runs of this cell from adding duplicates.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

# process_models(source_dir, target_dir, train_ratio=0.75, context_size=100, batch_size=250100, num_samples=4)

@dataclass
class SDFTransformerConfig:
    """Hyperparameters consumed by ``SDFTransformer``.

    The first four fields have no default and must be supplied by the caller;
    the remaining fields fall back to the defaults declared here.
    """
    # Input feature size handed to the first ISAB encoder block.
    dim_context: int
    # in_features of the Linear layer that projects the query tensor `x`.
    dim_input: int
    # Passed to the PMA block; also the repeat count applied to the projected
    # queries along dim 1 in SDFTransformer.forward.
    num_outputs: int
    # out_features of the last Linear layer of the regression head.
    dim_output: int
    # Third argument to L1_clamp_loss when labels are provided.
    delta: float = 0.1
    # Passed to both ISAB encoder blocks.
    num_inds: int = 32
    # Width shared by the attention blocks, the query projection and the MLP head.
    dim_hidden: int = 128
    # Head count passed to every ISAB / PMA / SAB / MAB block.
    num_heads: int = 4
    # Forwarded as the `ln=` keyword of every attention block
    # (presumably toggles layer normalisation - confirm in src.models.modules).
    ln: bool = False

class SDFTransformer(nn.Module):
    """Attention model that regresses a value per query from a context set.

    The context is passed through two ISAB blocks, then through PMA followed
    by two SAB blocks. The query input is projected with Linear + ReLU,
    combined with the processed context by a MAB block (queries as the first
    argument), and fed to a four-layer MLP head that ends in Tanh.
    """

    def __init__(self, config: SDFTransformerConfig):
        """Build all submodules from ``config`` and keep it on ``self.config``."""
        super().__init__()
        self.config = config
        hidden, heads, use_ln = config.dim_hidden, config.num_heads, config.ln

        # Submodules are created in the same order and under the same attribute
        # names as before, so parameter registration order and state_dict keys
        # are unchanged.
        self.enc = nn.Sequential(
            ISAB(config.dim_context, hidden, heads, config.num_inds, ln=use_ln),
            ISAB(hidden, hidden, heads, config.num_inds, ln=use_ln),
        )
        self.dec = nn.Sequential(
            PMA(hidden, heads, config.num_outputs, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
        )
        self.input_proj = nn.Sequential(nn.Linear(config.dim_input, hidden), nn.ReLU())
        self.cross = MAB(hidden, hidden, hidden, heads, ln=use_ln)

        # Regression head: three hidden Linear+ReLU stages, then the output
        # Linear squashed by Tanh (indices 0..7 inside the Sequential).
        regr_layers = []
        for _ in range(3):
            regr_layers += [nn.Linear(hidden, hidden), nn.ReLU()]
        regr_layers += [nn.Linear(hidden, config.dim_output), nn.Tanh()]
        self.regr = nn.Sequential(*regr_layers)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the model and optionally compute the clamped L1 loss.

        Args:
            context: Context set fed to the encoder; presumably
                [batch_size, context_size, dim_context] - TODO confirm.
            x: Query input; its last dimension must equal ``dim_input``.
            labels: Optional targets. When given, the loss is
                ``L1_clamp_loss(prediction, labels, config.delta)``.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``) and
            ``'logits'`` (the output of the regression head).
        """
        # Encode then pool the context; per the original annotations the result
        # is [batch_size, num_outputs, dim_hidden].
        shape_code = self.dec(self.enc(context))

        # Project queries to the hidden width, then tile them along dim 1.
        # repeat(1, k, 1) yields num_inputs * k rows, so this is a no-op only
        # for num_outputs == 1 (the value used in this notebook).
        queries = self.input_proj(x)
        queries = queries.repeat(1, self.config.num_outputs, 1)

        # Queries are the first MAB argument, the pooled context the second;
        # the head then maps the last dimension from dim_hidden to dim_output.
        prediction = self.regr(self.cross(queries, shape_code))

        if labels is None:
            loss = None
        else:
            loss = L1_clamp_loss(prediction, labels, self.config.delta)
        return {'loss': loss, 'logits': prediction}

# Seed torch's RNG before the model is constructed below.
torch.manual_seed(42)

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [13]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# Project root is taken to be the parent of the current working directory
# (assumes the notebook is run from a folder directly under the root - confirm).
project_dir = Path(os.path.abspath('')).resolve().parent
processed_dir = project_dir / 'data' / 'processed'

# Directory traversal order depends on the filesystem; sort the matches so the
# dataset ordering is the same on every machine and every run.
train_files = sorted(processed_dir.rglob('*_train.hdf5'))
val_files = sorted(processed_dir.rglob('*_val.hdf5'))

# Fail fast with a clear message instead of building empty datasets when the
# processed data is missing or the working directory is not what we expect.
if not train_files or not val_files:
    raise FileNotFoundError(
        f"Expected '*_train.hdf5' and '*_val.hdf5' files under {processed_dir}; "
        f"found {len(train_files)} train and {len(val_files)} val file(s)."
    )

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [14]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Results folder is named "<notebook_name>-<YYYY-MM-DD>", so runs started on
# the same calendar day share one folder.
notebook_name = '2024_11_14_cross_last'
current_date = datetime.now().strftime("%Y-%m-%d")
folder_name = f"{notebook_name}-{current_date}"

result_dir = get_results_dir().joinpath(folder_name)
# Create missing parents; an already existing folder is not an error.
result_dir.mkdir(exist_ok=True, parents=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_14_cross_last-2024-11-14


In [15]:
from transformers import Trainer, TrainingArguments

# Per-device batch size for training and evaluation; later cells also pass it
# to generate_mesh.
batch_size = 40

training_args = TrainingArguments(
    # Output locations and checkpoint retention.
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    save_total_limit=3,
    # Schedule: 8 epochs, evaluate after each epoch, log every 10 steps.
    num_train_epochs=8,
    eval_strategy="epoch",
    logging_steps=10,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Optimisation and reproducibility.
    weight_decay=0.01,
    seed=42,
)

trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [16]:
# Run the full training loop with the Trainer built in the previous cell.
# The recorded run below took 120000 steps and roughly 6382 s (train_runtime).
trainer.train()

  0%|          | 0/120000 [00:00<?, ?it/s]

{'loss': 0.0229, 'grad_norm': 1.1760578155517578, 'learning_rate': 4.999583333333333e-05, 'epoch': 0.0}
{'loss': 0.0093, 'grad_norm': 0.9476481676101685, 'learning_rate': 4.999166666666667e-05, 'epoch': 0.0}
{'loss': 0.0063, 'grad_norm': 0.8865270614624023, 'learning_rate': 4.99875e-05, 'epoch': 0.0}
{'loss': 0.0077, 'grad_norm': 0.07775304466485977, 'learning_rate': 4.998333333333334e-05, 'epoch': 0.0}
{'loss': 0.0071, 'grad_norm': 0.8240468502044678, 'learning_rate': 4.997916666666667e-05, 'epoch': 0.0}
{'loss': 0.0065, 'grad_norm': 0.6842063665390015, 'learning_rate': 4.9975e-05, 'epoch': 0.0}
{'loss': 0.0068, 'grad_norm': 0.20630376040935516, 'learning_rate': 4.997083333333333e-05, 'epoch': 0.0}
{'loss': 0.0066, 'grad_norm': 0.4806221127510071, 'learning_rate': 4.996666666666667e-05, 'epoch': 0.01}
{'loss': 0.0083, 'grad_norm': 0.8168231844902039, 'learning_rate': 4.99625e-05, 'epoch': 0.01}
{'loss': 0.007, 'grad_norm': 0.8227691054344177, 'learning_rate': 4.995833333333333e-05, 'e

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.002675057388842106, 'eval_runtime': 133.2496, 'eval_samples_per_second': 1500.943, 'eval_steps_per_second': 37.524, 'epoch': 1.0}
{'loss': 0.0034, 'grad_norm': 0.44292980432510376, 'learning_rate': 4.374583333333334e-05, 'epoch': 1.0}
{'loss': 0.0026, 'grad_norm': 0.5053950548171997, 'learning_rate': 4.374166666666667e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.07385949790477753, 'learning_rate': 4.3737500000000006e-05, 'epoch': 1.0}
{'loss': 0.0035, 'grad_norm': 0.1770777702331543, 'learning_rate': 4.373333333333334e-05, 'epoch': 1.0}
{'loss': 0.0029, 'grad_norm': 0.12899088859558105, 'learning_rate': 4.372916666666667e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.4433104693889618, 'learning_rate': 4.3725000000000006e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.2573915123939514, 'learning_rate': 4.372083333333333e-05, 'epoch': 1.0}
{'loss': 0.0025, 'grad_norm': 0.224347323179245, 'learning_rate': 4.371666666666667e-05, 'epoch': 1.01}
{'loss': 0.004, '

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0024623614735901356, 'eval_runtime': 132.8797, 'eval_samples_per_second': 1505.121, 'eval_steps_per_second': 37.628, 'epoch': 2.0}
{'loss': 0.002, 'grad_norm': 0.03974504396319389, 'learning_rate': 3.7495833333333334e-05, 'epoch': 2.0}
{'loss': 0.0019, 'grad_norm': 0.26374930143356323, 'learning_rate': 3.749166666666667e-05, 'epoch': 2.0}
{'loss': 0.0031, 'grad_norm': 0.160506471991539, 'learning_rate': 3.74875e-05, 'epoch': 2.0}
{'loss': 0.0029, 'grad_norm': 0.03963220864534378, 'learning_rate': 3.7483333333333334e-05, 'epoch': 2.0}
{'loss': 0.0039, 'grad_norm': 0.16082164645195007, 'learning_rate': 3.747916666666667e-05, 'epoch': 2.0}
{'loss': 0.0023, 'grad_norm': 0.2532998323440552, 'learning_rate': 3.7475e-05, 'epoch': 2.0}
{'loss': 0.002, 'grad_norm': 0.12883669137954712, 'learning_rate': 3.7470833333333334e-05, 'epoch': 2.0}
{'loss': 0.0032, 'grad_norm': 0.06386885046958923, 'learning_rate': 3.7466666666666665e-05, 'epoch': 2.01}
{'loss': 0.0023, 'grad_norm': 0.61

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0019938042387366295, 'eval_runtime': 137.5411, 'eval_samples_per_second': 1454.111, 'eval_steps_per_second': 36.353, 'epoch': 3.0}
{'loss': 0.0018, 'grad_norm': 0.0327165462076664, 'learning_rate': 3.124583333333334e-05, 'epoch': 3.0}
{'loss': 0.0024, 'grad_norm': 0.12383757531642914, 'learning_rate': 3.124166666666667e-05, 'epoch': 3.0}
{'loss': 0.0027, 'grad_norm': 0.2164842039346695, 'learning_rate': 3.12375e-05, 'epoch': 3.0}
{'loss': 0.0032, 'grad_norm': 0.31993216276168823, 'learning_rate': 3.123333333333334e-05, 'epoch': 3.0}
{'loss': 0.0019, 'grad_norm': 0.1263100951910019, 'learning_rate': 3.122916666666667e-05, 'epoch': 3.0}
{'loss': 0.0029, 'grad_norm': 0.1570439636707306, 'learning_rate': 3.122500000000001e-05, 'epoch': 3.0}
{'loss': 0.0013, 'grad_norm': 0.0845729410648346, 'learning_rate': 3.122083333333333e-05, 'epoch': 3.0}
{'loss': 0.0022, 'grad_norm': 0.07238538563251495, 'learning_rate': 3.121666666666667e-05, 'epoch': 3.01}
{'loss': 0.0023, 'grad_norm

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0021788356825709343, 'eval_runtime': 132.7203, 'eval_samples_per_second': 1506.929, 'eval_steps_per_second': 37.673, 'epoch': 4.0}
{'loss': 0.002, 'grad_norm': 0.09811139106750488, 'learning_rate': 2.4995833333333336e-05, 'epoch': 4.0}
{'loss': 0.0025, 'grad_norm': 0.3262749910354614, 'learning_rate': 2.499166666666667e-05, 'epoch': 4.0}
{'loss': 0.0014, 'grad_norm': 0.4010801315307617, 'learning_rate': 2.49875e-05, 'epoch': 4.0}
{'loss': 0.0026, 'grad_norm': 0.09367120265960693, 'learning_rate': 2.4983333333333335e-05, 'epoch': 4.0}
{'loss': 0.0018, 'grad_norm': 0.5536375641822815, 'learning_rate': 2.4979166666666666e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.2688104808330536, 'learning_rate': 2.4975e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.3142812252044678, 'learning_rate': 2.4970833333333335e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.17753471434116364, 'learning_rate': 2.496666666666667e-05, 'epoch': 4.01}
{'loss': 0.0012, 'grad_norm': 0.428

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0017535408260300756, 'eval_runtime': 133.0779, 'eval_samples_per_second': 1502.879, 'eval_steps_per_second': 37.572, 'epoch': 5.0}
{'loss': 0.002, 'grad_norm': 0.18918859958648682, 'learning_rate': 1.8745833333333336e-05, 'epoch': 5.0}
{'loss': 0.0012, 'grad_norm': 0.5160457491874695, 'learning_rate': 1.8741666666666667e-05, 'epoch': 5.0}
{'loss': 0.0017, 'grad_norm': 0.034882787615060806, 'learning_rate': 1.87375e-05, 'epoch': 5.0}
{'loss': 0.0018, 'grad_norm': 0.18600419163703918, 'learning_rate': 1.8733333333333332e-05, 'epoch': 5.0}
{'loss': 0.0017, 'grad_norm': 0.3147324025630951, 'learning_rate': 1.8729166666666667e-05, 'epoch': 5.0}
{'loss': 0.0015, 'grad_norm': 0.10472051799297333, 'learning_rate': 1.8725e-05, 'epoch': 5.0}
{'loss': 0.0027, 'grad_norm': 0.18434591591358185, 'learning_rate': 1.8720833333333335e-05, 'epoch': 5.0}
{'loss': 0.0014, 'grad_norm': 0.21627026796340942, 'learning_rate': 1.871666666666667e-05, 'epoch': 5.01}
{'loss': 0.0022, 'grad_norm': 

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0017805879469960928, 'eval_runtime': 135.2403, 'eval_samples_per_second': 1478.849, 'eval_steps_per_second': 36.971, 'epoch': 6.0}
{'loss': 0.0017, 'grad_norm': 0.28740590810775757, 'learning_rate': 1.2495833333333335e-05, 'epoch': 6.0}
{'loss': 0.0016, 'grad_norm': 0.0775669738650322, 'learning_rate': 1.2491666666666668e-05, 'epoch': 6.0}
{'loss': 0.0015, 'grad_norm': 0.06404750794172287, 'learning_rate': 1.24875e-05, 'epoch': 6.0}
{'loss': 0.0013, 'grad_norm': 0.19354648888111115, 'learning_rate': 1.2483333333333335e-05, 'epoch': 6.0}
{'loss': 0.002, 'grad_norm': 0.21486331522464752, 'learning_rate': 1.2479166666666667e-05, 'epoch': 6.0}
{'loss': 0.0016, 'grad_norm': 0.35671329498291016, 'learning_rate': 1.2475e-05, 'epoch': 6.0}
{'loss': 0.0025, 'grad_norm': 0.27761417627334595, 'learning_rate': 1.2470833333333334e-05, 'epoch': 6.0}
{'loss': 0.0012, 'grad_norm': 0.15700678527355194, 'learning_rate': 1.2466666666666667e-05, 'epoch': 6.01}
{'loss': 0.0015, 'grad_norm':

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0016772671369835734, 'eval_runtime': 134.1928, 'eval_samples_per_second': 1490.393, 'eval_steps_per_second': 37.26, 'epoch': 7.0}
{'loss': 0.0015, 'grad_norm': 0.07303564995527267, 'learning_rate': 6.245833333333334e-06, 'epoch': 7.0}
{'loss': 0.0011, 'grad_norm': 0.4271428883075714, 'learning_rate': 6.241666666666667e-06, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.07173795998096466, 'learning_rate': 6.2375e-06, 'epoch': 7.0}
{'loss': 0.0019, 'grad_norm': 0.07286680489778519, 'learning_rate': 6.2333333333333335e-06, 'epoch': 7.0}
{'loss': 0.0017, 'grad_norm': 0.12461357563734055, 'learning_rate': 6.229166666666667e-06, 'epoch': 7.0}
{'loss': 0.0015, 'grad_norm': 0.04122012108564377, 'learning_rate': 6.2250000000000005e-06, 'epoch': 7.0}
{'loss': 0.0019, 'grad_norm': 0.2466820329427719, 'learning_rate': 6.220833333333333e-06, 'epoch': 7.0}
{'loss': 0.0019, 'grad_norm': 0.444598913192749, 'learning_rate': 6.2166666666666676e-06, 'epoch': 7.01}
{'loss': 0.0017, 'grad_no

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0016453731805086136, 'eval_runtime': 133.2453, 'eval_samples_per_second': 1500.991, 'eval_steps_per_second': 37.525, 'epoch': 8.0}
{'train_runtime': 6382.1998, 'train_samples_per_second': 752.092, 'train_steps_per_second': 18.802, 'train_loss': 0.0021715190480463206, 'epoch': 8.0}


TrainOutput(global_step=120000, training_loss=0.0021715190480463206, metrics={'train_runtime': 6382.1998, 'train_samples_per_second': 752.092, 'train_steps_per_second': 18.802, 'total_flos': 0.0, 'train_loss': 0.0021715190480463206, 'epoch': 8.0})

In [17]:
import json
from dataclasses import asdict

# From here on `current_date` is a second-resolution timestamp; it prefixes
# every artefact written by this cell and by the mesh cells that follow.
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"

trainer.save_model(result_dir / model_name)

# Store the hyperparameters next to the weights as JSON.
with (result_dir / config_name).open('w') as f:
    f.write(json.dumps(asdict(config)))

In [18]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Bunny: source .obj under data/intermediate, output mesh written to result_dir.
obj_path = get_data_dir().joinpath('intermediate', 'bunny', 'stanford-bunny.obj')
infered_obj_path = result_dir.joinpath(f"bunny-{current_date}-generalization.obj")
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [None]:
# Dragon: source .obj under data/intermediate, output mesh written to result_dir.
obj_path = get_data_dir().joinpath('intermediate', 'dragon', 'xyzrgb_dragon.obj')
infered_obj_path = result_dir.joinpath(f"dragon-{current_date}-generalization.obj")
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

In [None]:
# Lucy: source .obj under data/intermediate, output mesh written to result_dir.
obj_path = get_data_dir().joinpath('intermediate', 'lucy', 'lucy.obj')
infered_obj_path = result_dir.joinpath(f"lucy-{current_date}-generalization.obj")
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

In [19]:
# Close both datasets before the bunny training file is handed to
# generate_mesh (presumably to release the open .hdf5 handles - TODO confirm).
train_dataset.close()
val_dataset.close()

# Same call as the .obj cells, but the input is the processed training file.
hdf5_path = get_data_dir().joinpath('processed', 'bunny', 'stanford-bunny_train.hdf5')
infered_obj_path = result_dir.joinpath(f"{current_date}-train.obj")
generate_mesh(
    model,
    hdf5_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [20]:
import importlib
import src.visualization.generate_mesh
# importlib.reload() re-executes the module, but names copied out of it with
# `from ... import ...` keep pointing at the old objects. Rebind the
# notebook-level `generate_mesh` so later calls use the reloaded function.
# NOTE(review): with `%autoreload 2` active this manual reload is likely
# redundant - consider removing the cell.
_reloaded_module = importlib.reload(src.visualization.generate_mesh)
generate_mesh = _reloaded_module.generate_mesh
# Keep the module as the cell's last expression so it is still displayed.
_reloaded_module

<module 'src.visualization.generate_mesh' from 'c:\\_prog\\vm_shared\\attention-sdf\\src\\visualization\\generate_mesh.py'>