In [7]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
from src.models.modules import *
from src.models.loss import L1_clamp_loss
from dataclasses import dataclass
import torch

@dataclass
class SDFTransformerConfig:
    """Hyper-parameters read by ``SDFTransformer`` when it builds its layers.

    The first four fields have no default and must be supplied; the remaining
    fields are optional. Field order is part of the positional constructor
    signature generated by ``@dataclass``.
    """
    # First positional argument of the first ISAB encoder block
    # (presumably the per-element feature size of the context set - TODO confirm).
    dim_context: int
    # `in_features` of the linear layer in `SDFTransformer.input_proj`.
    dim_input: int
    # Third positional argument of the PMA block in the decoder
    # (presumably the number of pooled output vectors - TODO confirm).
    num_outputs: int
    # `out_features` of the last linear layer of the regressor head.
    dim_output: int
    # Third argument passed to `L1_clamp_loss` when labels are given
    # (presumably the clamping threshold - TODO confirm).
    delta: float = 0.1
    # Fourth positional argument of both ISAB encoder blocks
    # (presumably the number of inducing points - TODO confirm).
    num_inds: int = 32
    # Hidden width shared by the encoder, cross-attention, decoder and regressor.
    dim_hidden: int = 128
    # Head count handed to every attention block (ISAB, MAB, PMA, SAB).
    num_heads: int = 4
    # Forwarded as the `ln=` keyword to every attention block
    # (presumably toggles layer normalisation - TODO confirm).
    ln: bool = False

class SDFTransformer(nn.Module):
    """Set-attention model that maps a context set and a query input to a bounded value.

    Stages, in the order they are applied by ``forward``:

    1. ``enc``        - two stacked ISAB blocks over the context.
    2. ``input_proj`` - linear layer + ReLU over the query input.
    3. ``cross``      - MAB between the tiled query and the encoded context.
    4. ``dec``        - PMA followed by two SAB blocks.
    5. ``regr``       - four linear layers (ReLU between them) ending in ``Tanh``.
    """

    def __init__(self, config: SDFTransformerConfig):
        """Build all sub-modules from ``config`` and keep ``config`` on the instance."""
        super().__init__()
        self.config = config

        hidden = config.dim_hidden
        heads = config.num_heads
        inducing = config.num_inds
        use_ln = config.ln

        # Keep the construction order stable: parameter initialisation draws
        # from the global RNG in the order the layers are created.
        self.enc = nn.Sequential(
            ISAB(config.dim_context, hidden, heads, inducing, ln=use_ln),
            ISAB(hidden, hidden, heads, inducing, ln=use_ln),
        )
        self.input_proj = nn.Sequential(
            nn.Linear(config.dim_input, hidden),
            nn.ReLU(),
        )
        self.cross = MAB(hidden, hidden, hidden, heads, ln=use_ln)
        self.dec = nn.Sequential(
            PMA(hidden, heads, config.num_outputs, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
            SAB(hidden, hidden, heads, ln=use_ln),
        )

        # Three hidden -> hidden layers with ReLU, then the output layer
        # squashed by Tanh.
        regr_layers = []
        for _ in range(3):
            regr_layers.append(nn.Linear(hidden, hidden))
            regr_layers.append(nn.ReLU())
        regr_layers.append(nn.Linear(hidden, config.dim_output))
        regr_layers.append(nn.Tanh())
        self.regr = nn.Sequential(*regr_layers)

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        """Run the model and, when ``labels`` is given, compute the clamped L1 loss.

        Args:
            context: Tensor fed to the ISAB encoder.
            x: Query tensor fed to ``input_proj``; it is tiled along dim 1 by the
                size of the encoded context's dim 1.
            labels: Optional targets. When ``None`` no loss is computed.

        Returns:
            A dict with keys ``'loss'`` (``None`` without labels, otherwise the
            result of ``L1_clamp_loss`` with ``config.delta``) and ``'logits'``
            (the regressor output).
        """
        # presumably [batch_size, context_size, dim_hidden] - TODO confirm against ISAB
        encoded = self.enc(context)

        # Last dim becomes dim_hidden; dim 1 is then tiled `encoded.shape[1]` times.
        # assumes x carries a single query per batch item, giving
        # [batch_size, context_size, dim_hidden] after the repeat - TODO confirm
        query = self.input_proj(x)
        query = query.repeat(1, encoded.shape[1], 1)

        attended = self.cross(query, encoded)
        # presumably [batch_size, num_outputs, dim_hidden] after PMA pooling - TODO confirm
        pooled = self.dec(attended)
        # Last dim becomes dim_output; values lie in [-1, 1] because of the final Tanh.
        logits = self.regr(pooled)

        if labels is None:
            loss = None
        else:
            loss = L1_clamp_loss(logits, labels, self.config.delta)
        return {'loss': loss, 'logits': logits}

# Seed PyTorch's global RNG before the model is constructed.
torch.manual_seed(42)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

config = SDFTransformerConfig(
    dim_context=4,
    dim_input=3,
    num_outputs=1,
    dim_output=1,
)
model = SDFTransformer(config)
model = model.to(device)
print(device)

cuda


In [9]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

# The project root is taken to be the parent of the current working directory.
project_dir = Path(os.path.abspath('')).resolve().parent
# NOTE(review): `procesed_dir` is a misspelling of "processed"; the name is kept
# unchanged because other cells may refer to it.
procesed_dir = project_dir / 'data' / 'processed'

# `rglob` yields matches in file-system order, which can differ between
# machines and runs. Sorting pins the order of the file lists so that whatever
# index -> sample mapping `LazySampleDataset` derives from them is stable and
# the seeded training run is reproducible.
train_files = sorted(procesed_dir.rglob('*_train.hdf5'))
val_files = sorted(procesed_dir.rglob('*_val.hdf5'))

# Both datasets are closed explicitly once training has finished.
train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [10]:
from src.data.load_data import get_results_dir
from datetime import datetime

# Every run gets its own folder: "<notebook name>-<timestamp to the second>".
notebook_name = '2024_11_13_use_proj_regr'

run_started_at = datetime.now()
current_date = run_started_at.strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"

# Create the folder (and any missing parents); an existing folder is not an error.
result_dir = get_results_dir() / folder_name
result_dir.mkdir(exist_ok=True, parents=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_11_13_use_proj_regr-2024-11-16-09-08-51


In [11]:
from transformers import Trainer, TrainingArguments

# Shared by training, evaluation and the mesh-generation cells.
batch_size = 40

training_args = TrainingArguments(
    # Output locations, all inside this run's result folder.
    output_dir=result_dir / "results",
    logging_dir=result_dir / "logs",
    save_total_limit=3,
    # Schedule: evaluate once per epoch, log every 10 steps.
    num_train_epochs=8,
    eval_strategy="epoch",
    logging_steps=10,
    # Batching.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Optimisation and reproducibility.
    weight_decay=0.01,
    seed=42,
)

trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [12]:
# Run training. The datasets are closed in `finally` blocks so that they are
# released (presumably the underlying HDF5 file handles - TODO confirm) even
# when training raises or the long-running cell is interrupted.
try:
    trainer.train()
finally:
    # Nested so that a failure while closing the training set does not
    # prevent the validation set from being closed.
    try:
        train_dataset.close()
    finally:
        val_dataset.close()

  0%|          | 0/240000 [00:00<?, ?it/s]

{'loss': 0.0205, 'grad_norm': 1.1560782194137573, 'learning_rate': 4.999791666666667e-05, 'epoch': 0.0}
{'loss': 0.0088, 'grad_norm': 0.9420439004898071, 'learning_rate': 4.999583333333333e-05, 'epoch': 0.0}
{'loss': 0.008, 'grad_norm': 0.6023547649383545, 'learning_rate': 4.999375e-05, 'epoch': 0.0}
{'loss': 0.0055, 'grad_norm': 0.2019413411617279, 'learning_rate': 4.999166666666667e-05, 'epoch': 0.0}
{'loss': 0.0082, 'grad_norm': 0.4030986726284027, 'learning_rate': 4.9989583333333337e-05, 'epoch': 0.0}
{'loss': 0.0076, 'grad_norm': 0.26906415820121765, 'learning_rate': 4.99875e-05, 'epoch': 0.0}
{'loss': 0.0068, 'grad_norm': 0.7335054874420166, 'learning_rate': 4.998541666666667e-05, 'epoch': 0.0}
{'loss': 0.0064, 'grad_norm': 0.02046429179608822, 'learning_rate': 4.998333333333334e-05, 'epoch': 0.0}
{'loss': 0.0073, 'grad_norm': 0.6006406545639038, 'learning_rate': 4.998125e-05, 'epoch': 0.0}
{'loss': 0.0051, 'grad_norm': 0.7326744198799133, 'learning_rate': 4.997916666666667e-05, 

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0024574368726462126, 'eval_runtime': 129.084, 'eval_samples_per_second': 1549.378, 'eval_steps_per_second': 38.734, 'epoch': 1.0}
{'loss': 0.0042, 'grad_norm': 0.44488587975502014, 'learning_rate': 4.3747916666666665e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.2966344356536865, 'learning_rate': 4.374583333333334e-05, 'epoch': 1.0}
{'loss': 0.0015, 'grad_norm': 0.11855438351631165, 'learning_rate': 4.374375e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.41567957401275635, 'learning_rate': 4.374166666666667e-05, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 0.4744091033935547, 'learning_rate': 4.3739583333333334e-05, 'epoch': 1.0}
{'loss': 0.0028, 'grad_norm': 0.031479187309741974, 'learning_rate': 4.3737500000000006e-05, 'epoch': 1.0}
{'loss': 0.003, 'grad_norm': 0.8706613183021545, 'learning_rate': 4.3735416666666665e-05, 'epoch': 1.0}
{'loss': 0.0028, 'grad_norm': 0.208566814661026, 'learning_rate': 4.373333333333334e-05, 'epoch': 1.0}
{'loss': 0.0035, 'grad_no

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.002387440763413906, 'eval_runtime': 133.7555, 'eval_samples_per_second': 1495.265, 'eval_steps_per_second': 37.382, 'epoch': 2.0}
{'loss': 0.002, 'grad_norm': 0.29052743315696716, 'learning_rate': 3.749791666666667e-05, 'epoch': 2.0}
{'loss': 0.0024, 'grad_norm': 0.5811926126480103, 'learning_rate': 3.7495833333333334e-05, 'epoch': 2.0}
{'loss': 0.002, 'grad_norm': 0.11604063957929611, 'learning_rate': 3.749375e-05, 'epoch': 2.0}
{'loss': 0.0035, 'grad_norm': 0.6026103496551514, 'learning_rate': 3.749166666666667e-05, 'epoch': 2.0}
{'loss': 0.0038, 'grad_norm': 0.029780138283967972, 'learning_rate': 3.748958333333333e-05, 'epoch': 2.0}
{'loss': 0.0047, 'grad_norm': 0.08713049441576004, 'learning_rate': 3.74875e-05, 'epoch': 2.0}
{'loss': 0.0023, 'grad_norm': 0.2030973881483078, 'learning_rate': 3.748541666666667e-05, 'epoch': 2.0}
{'loss': 0.0014, 'grad_norm': 0.3193824589252472, 'learning_rate': 3.7483333333333334e-05, 'epoch': 2.0}
{'loss': 0.0029, 'grad_norm': 0.3740

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0022999930661171675, 'eval_runtime': 127.2647, 'eval_samples_per_second': 1571.527, 'eval_steps_per_second': 39.288, 'epoch': 3.0}
{'loss': 0.0024, 'grad_norm': 0.14391179382801056, 'learning_rate': 3.1247916666666666e-05, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm': 0.028905276209115982, 'learning_rate': 3.124583333333334e-05, 'epoch': 3.0}
{'loss': 0.0022, 'grad_norm': 0.17226538062095642, 'learning_rate': 3.124375e-05, 'epoch': 3.0}
{'loss': 0.0019, 'grad_norm': 0.17237836122512817, 'learning_rate': 3.124166666666667e-05, 'epoch': 3.0}
{'loss': 0.003, 'grad_norm': 0.11531303077936172, 'learning_rate': 3.1239583333333335e-05, 'epoch': 3.0}
{'loss': 0.0037, 'grad_norm': 0.3159581124782562, 'learning_rate': 3.12375e-05, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm': 0.08640685677528381, 'learning_rate': 3.1235416666666666e-05, 'epoch': 3.0}
{'loss': 0.0022, 'grad_norm': 0.058215364813804626, 'learning_rate': 3.123333333333334e-05, 'epoch': 3.0}
{'loss': 0.0026, 'grad_norm':

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0021700120996683836, 'eval_runtime': 131.7685, 'eval_samples_per_second': 1517.814, 'eval_steps_per_second': 37.945, 'epoch': 4.0}
{'loss': 0.0014, 'grad_norm': 0.14311973750591278, 'learning_rate': 2.4997916666666667e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.7194696068763733, 'learning_rate': 2.4995833333333336e-05, 'epoch': 4.0}
{'loss': 0.0028, 'grad_norm': 0.6768826246261597, 'learning_rate': 2.499375e-05, 'epoch': 4.0}
{'loss': 0.0019, 'grad_norm': 0.006506660487502813, 'learning_rate': 2.499166666666667e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.22844423353672028, 'learning_rate': 2.4989583333333335e-05, 'epoch': 4.0}
{'loss': 0.0017, 'grad_norm': 0.059157826006412506, 'learning_rate': 2.49875e-05, 'epoch': 4.0}
{'loss': 0.0022, 'grad_norm': 0.19999556243419647, 'learning_rate': 2.4985416666666666e-05, 'epoch': 4.0}
{'loss': 0.0025, 'grad_norm': 0.14273694157600403, 'learning_rate': 2.4983333333333335e-05, 'epoch': 4.0}
{'loss': 0.0023, 'grad_norm

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.001862390898168087, 'eval_runtime': 130.1655, 'eval_samples_per_second': 1536.506, 'eval_steps_per_second': 38.413, 'epoch': 5.0}
{'loss': 0.0014, 'grad_norm': 0.25742974877357483, 'learning_rate': 1.8747916666666667e-05, 'epoch': 5.0}
{'loss': 0.0015, 'grad_norm': 0.05738663300871849, 'learning_rate': 1.8745833333333336e-05, 'epoch': 5.0}
{'loss': 0.0015, 'grad_norm': 0.42844581604003906, 'learning_rate': 1.874375e-05, 'epoch': 5.0}
{'loss': 0.0023, 'grad_norm': 0.036751482635736465, 'learning_rate': 1.8741666666666667e-05, 'epoch': 5.0}
{'loss': 0.0025, 'grad_norm': 0.2010982483625412, 'learning_rate': 1.8739583333333336e-05, 'epoch': 5.0}
{'loss': 0.0016, 'grad_norm': 0.3495792746543884, 'learning_rate': 1.87375e-05, 'epoch': 5.0}
{'loss': 0.0017, 'grad_norm': 0.05753815174102783, 'learning_rate': 1.8735416666666667e-05, 'epoch': 5.0}
{'loss': 0.0013, 'grad_norm': 0.16243354976177216, 'learning_rate': 1.8733333333333332e-05, 'epoch': 5.0}
{'loss': 0.0022, 'grad_norm'

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0017644831677898765, 'eval_runtime': 131.1669, 'eval_samples_per_second': 1524.774, 'eval_steps_per_second': 38.119, 'epoch': 6.0}
{'loss': 0.0016, 'grad_norm': 0.14249075949192047, 'learning_rate': 1.2497916666666668e-05, 'epoch': 6.0}
{'loss': 0.0018, 'grad_norm': 0.20000126957893372, 'learning_rate': 1.2495833333333335e-05, 'epoch': 6.0}
{'loss': 0.0016, 'grad_norm': 0.17168352007865906, 'learning_rate': 1.249375e-05, 'epoch': 6.0}
{'loss': 0.0021, 'grad_norm': 0.17248734831809998, 'learning_rate': 1.2491666666666668e-05, 'epoch': 6.0}
{'loss': 0.0015, 'grad_norm': 0.11469561606645584, 'learning_rate': 1.2489583333333333e-05, 'epoch': 6.0}
{'loss': 0.0027, 'grad_norm': 0.02967328391969204, 'learning_rate': 1.24875e-05, 'epoch': 6.0}
{'loss': 0.0024, 'grad_norm': 0.1719534695148468, 'learning_rate': 1.2485416666666667e-05, 'epoch': 6.0}
{'loss': 0.002, 'grad_norm': 0.20005375146865845, 'learning_rate': 1.2483333333333335e-05, 'epoch': 6.0}
{'loss': 0.0014, 'grad_norm'

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.001690576784312725, 'eval_runtime': 128.5867, 'eval_samples_per_second': 1555.371, 'eval_steps_per_second': 38.884, 'epoch': 7.0}
{'loss': 0.0012, 'grad_norm': 0.31485480070114136, 'learning_rate': 6.2479166666666675e-06, 'epoch': 7.0}
{'loss': 0.002, 'grad_norm': 0.11539436131715775, 'learning_rate': 6.245833333333334e-06, 'epoch': 7.0}
{'loss': 0.0014, 'grad_norm': 0.197199746966362, 'learning_rate': 6.24375e-06, 'epoch': 7.0}
{'loss': 0.0015, 'grad_norm': 0.05785413458943367, 'learning_rate': 6.241666666666667e-06, 'epoch': 7.0}
{'loss': 0.0016, 'grad_norm': 0.03138536959886551, 'learning_rate': 6.239583333333334e-06, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.062295254319906235, 'learning_rate': 6.2375e-06, 'epoch': 7.0}
{'loss': 0.0017, 'grad_norm': 0.14336006343364716, 'learning_rate': 6.235416666666667e-06, 'epoch': 7.0}
{'loss': 0.0017, 'grad_norm': 0.11551631987094879, 'learning_rate': 6.2333333333333335e-06, 'epoch': 7.0}
{'loss': 0.0018, 'grad_norm': 0.086

  0%|          | 0/5000 [00:00<?, ?it/s]

{'eval_loss': 0.0016539497300982475, 'eval_runtime': 126.1587, 'eval_samples_per_second': 1585.305, 'eval_steps_per_second': 39.633, 'epoch': 8.0}
{'train_runtime': 11233.0667, 'train_samples_per_second': 854.62, 'train_steps_per_second': 21.365, 'train_loss': 0.0022443765548523517, 'epoch': 8.0}


In [13]:
import json
from dataclasses import asdict

# NOTE: this rebinds `current_date`. An earlier cell set it to a
# second-resolution timestamp; from here on it holds a date-only stamp, which
# is what the artefact file names below (and in later cells) use.
current_date = datetime.now().strftime("%Y-%m-%d")
model_name, config_name = f"{current_date}-model", f"{current_date}-config.json"

# Persist the trained model into this run's result folder.
trainer.save_model(result_dir / model_name)

# Store the model hyper-parameters next to it as JSON.
config_path = result_dir / config_name
config_json = json.dumps(asdict(config))
with open(config_path, 'w') as config_file:
    config_file.write(config_json)

In [14]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir

# Stanford bunny: run `generate_mesh` with the trained model on the source mesh
# and write the resulting .obj into this run's result folder.
bunny_dir = get_data_dir() / 'intermediate' / 'bunny'
obj_path = bunny_dir / 'stanford-bunny.obj'
infered_obj_path = result_dir / f"bunny-{current_date}-generalization.obj"
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [15]:
# XYZ RGB dragon: run `generate_mesh` with the trained model on the source mesh
# and write the resulting .obj into this run's result folder.
dragon_dir = get_data_dir() / 'intermediate' / 'dragon'
obj_path = dragon_dir / 'xyzrgb_dragon.obj'
infered_obj_path = result_dir / f"dragon-{current_date}-generalization.obj"
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [17]:
# Lucy: run `generate_mesh` with the trained model on the source mesh and write
# the resulting .obj into this run's result folder.
lucy_dir = get_data_dir() / 'intermediate' / 'lucy'
obj_path = lucy_dir / 'lucy.obj'
infered_obj_path = result_dir / f"lucy-{current_date}-generalization.obj"
generate_mesh(
    model,
    obj_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [18]:
# Same call as for the .obj meshes, but fed with the processed bunny training
# file (HDF5) instead of a source mesh; the output lands in the result folder.
bunny_processed_dir = get_data_dir() / 'processed' / 'bunny'
hdf5_path = bunny_processed_dir / 'stanford-bunny_train.hdf5'
infered_obj_path = result_dir / f"{current_date}-train.obj"
generate_mesh(
    model,
    hdf5_path,
    infered_obj_path,
    device,
    batch_size,
    resolution=100,
)

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]