In [29]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

@dataclass
class SDFTransformerConfig:
    dim_context: int
    dim_input: int
    num_outputs: int
    dim_output: int
    delta: float = 0.1
    num_inds: int = 32
    dim_hidden: int = 128
    num_heads: int = 4
    ln: bool = False

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.enc = nn.Sequential(
            ISAB(config.dim_context, config.dim_hidden, config.num_heads, config.num_inds, ln=config.ln),
            nn.SiLU(),
            PMA(config.dim_hidden, config.num_heads, config.num_outputs, ln=config.ln),
            nn.SiLU(),
        )
        self.input_proj = nn.Linear(config.dim_input, config.dim_hidden)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln)
        self.dec = nn.Sequential(
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
            SAB(config.dim_hidden, config.dim_hidden, config.num_heads, ln=config.ln),
            nn.SiLU(),
        )
        self.regr = nn.Sequential(
            nn.Linear(config.dim_hidden, config.dim_hidden),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        y = self.enc(context)           # [batch_size, context_size, dim_hidden]
        x = self.input_proj(x)          # [batch_size, num_outputs, dim_hidden]
        x = x.repeat(1, y.shape[1], 1)  # [batch_size, context_size, dim_hidden]
        y = self.cross(x, y)            # [batch_size, context_size, dim_hidden]
        y = self.dec(y)                 # [batch_size, num_outputs, dim_hidden]
        y = self.regr(y)                # [batch_size, num_outputs, dim_output]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(y, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': y}

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

config = SDFTransformerConfig(dim_context=4, dim_input=3, num_outputs=1, dim_output=1)
model = SDFTransformer(config).to(device)
print(device)

cuda


In [31]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [32]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2024_12_04_increase_acc'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_12_04_increase_acc-2024-12-04-19-51-38


In [33]:
from transformers import Trainer, TrainingArguments

batch_size = 40
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 1, "epsilon": 0.025,  "lambda": 0.0,  'learning_rate': 5e-5},
    {"epochs": 1, "epsilon": 0.02,   "lambda": 0.05, 'learning_rate': 5e-5},
    {"epochs": 1, "epsilon": 0.01,   "lambda": 0.1,  'learning_rate': 4e-5},
    {"epochs": 1, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5},
    {"epochs": 1, "epsilon": 0.005,  "lambda": 0.25, 'learning_rate': 3e-5},
    {"epochs": 1, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5},
    {"epochs": 1, "epsilon": 0.003,  "lambda": 0.35, 'learning_rate': 2e-5},
    {"epochs": 1, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5},
    {"epochs": 1, "epsilon": 0.001,  "lambda": 0.45, 'learning_rate': 1e-5},
    {"epochs": 1, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5}
]

In [34]:
from src.visualization.generate_mesh import generate_mesh
from src.data.load_data import get_data_dir


# obj_path_bunny = get_data_dir() / 'intermediate' / 'bunny' / 'stanford-bunny.obj'
obj_path_lucy = get_data_dir() / 'intermediate' / 'lucy' / 'lucy.obj'

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    # infered_bunny_path = result_dir / f"bunny-{current_date}-currciulum-{i}.obj"
    infered_lucy_path = result_dir / f"lucy-{current_date}-currciulum-{i}.obj"
    # generate_mesh(model, obj_path_bunny, infered_bunny_path, device, batch_size, resolution=100)
    generate_mesh(model, obj_path_lucy, infered_lucy_path, device, batch_size, resolution=100)
train_dataset.close()
val_dataset.close()

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0159, 'grad_norm': 0.09022717922925949, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0044, 'grad_norm': 0.014274707064032555, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0037, 'grad_norm': 0.030520357191562653, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0028, 'grad_norm': 0.09145455807447433, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.0923508033156395, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0045, 'grad_norm': 0.03267191722989082, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0027, 'grad_norm': 0.031095322221517563, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0028, 'grad_norm': 0.009401927702128887, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0026, 'grad_norm': 0.1239437460899353, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0028, 'grad_norm': 0.18335862457752228, 

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00011210032243980095, 'eval_runtime': 54.6254, 'eval_samples_per_second': 1830.651, 'eval_steps_per_second': 45.766, 'epoch': 1.0}
{'train_runtime': 940.5594, 'train_samples_per_second': 956.877, 'train_steps_per_second': 23.922, 'train_loss': 0.00028339929932526, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0002, 'grad_norm': 0.09181372821331024, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.0, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0008, 'grad_norm': 0.11468783020973206, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.03967037796974182, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.0, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.06570356339216232, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.12421321868896484, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.08762187510728836, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.04964500293135643, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0001, 'grad_norm': 0.0, 'learning_rate': 4.977777777777778e-05, 'epoch': 0

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00010585736890789121, 'eval_runtime': 61.8218, 'eval_samples_per_second': 1617.551, 'eval_steps_per_second': 40.439, 'epoch': 1.0}
{'train_runtime': 1014.214, 'train_samples_per_second': 887.387, 'train_steps_per_second': 22.185, 'train_loss': 0.00017697582694541628, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0009, 'grad_norm': 0.24218149483203888, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0008, 'grad_norm': 0.09788159281015396, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.1412358582019806, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.037668175995349884, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.03722699359059334, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.0, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.14341506361961365, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.09608035534620285, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.10340768098831177, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.1338556855916977, 'learning_rate': 4.

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00016172266623470932, 'eval_runtime': 61.9636, 'eval_samples_per_second': 1613.85, 'eval_steps_per_second': 40.346, 'epoch': 1.0}
{'train_runtime': 962.5205, 'train_samples_per_second': 935.045, 'train_steps_per_second': 23.376, 'train_loss': 0.00023154844160512515, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0002, 'grad_norm': 0.04581836611032486, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.0, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.1033308282494545, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.03347720205783844, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.07099959999322891, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0008, 'grad_norm': 0.14334772527217865, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.07041116058826447, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0003, 'grad_norm': 0.0361415259540081, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.0871063768863678, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.10505471378564835, 'learning_rate': 4.97

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00018044149328488857, 'eval_runtime': 61.1578, 'eval_samples_per_second': 1635.115, 'eval_steps_per_second': 40.878, 'epoch': 1.0}
{'train_runtime': 959.9556, 'train_samples_per_second': 937.543, 'train_steps_per_second': 23.439, 'train_loss': 0.00024183624571045482, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0003, 'grad_norm': 0.09778520464897156, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0003, 'grad_norm': 0.12896372377872467, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0003, 'grad_norm': 0.021979933604598045, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.03220590204000473, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.001, 'grad_norm': 0.016033001244068146, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.028127865865826607, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.21219925582408905, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.1336585432291031, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.2558695077896118, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0003, 'grad_norm': 0.04306492954492569, 'l

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00021685568208340555, 'eval_runtime': 60.6553, 'eval_samples_per_second': 1648.66, 'eval_steps_per_second': 41.216, 'epoch': 1.0}
{'train_runtime': 944.9656, 'train_samples_per_second': 952.416, 'train_steps_per_second': 23.81, 'train_loss': 0.0002827535301566564, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0003, 'grad_norm': 0.04616709053516388, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.037300027906894684, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.07655368745326996, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.07910392433404922, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.001, 'grad_norm': 0.12033037841320038, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.3344085216522217, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0008, 'grad_norm': 0.12582708895206451, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.13040506839752197, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.0805378332734108, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0004, 'grad_norm': 0.06183067709207535, 'lea

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.00024876274983398616, 'eval_runtime': 61.2795, 'eval_samples_per_second': 1631.867, 'eval_steps_per_second': 40.797, 'epoch': 1.0}
{'train_runtime': 941.1624, 'train_samples_per_second': 956.264, 'train_steps_per_second': 23.907, 'train_loss': 0.00031950801065928924, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0004, 'grad_norm': 0.061309657990932465, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0002, 'grad_norm': 0.038905296474695206, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.1894746720790863, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.3205696940422058, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.15195107460021973, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.27876394987106323, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.20867560803890228, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.21661797165870667, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.18501155078411102, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.05979115888476372, 'l

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.0003343481512274593, 'eval_runtime': 61.0943, 'eval_samples_per_second': 1636.813, 'eval_steps_per_second': 40.92, 'epoch': 1.0}
{'train_runtime': 939.9467, 'train_samples_per_second': 957.501, 'train_steps_per_second': 23.938, 'train_loss': 0.0004113520976440567, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0006, 'grad_norm': 0.04102453961968422, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.04328245669603348, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.1424359530210495, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0008, 'grad_norm': 0.12173876166343689, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.17215897142887115, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.5649046301841736, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0005, 'grad_norm': 0.19372986257076263, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0007, 'grad_norm': 0.20066571235656738, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.08338965475559235, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0006, 'grad_norm': 0.11726898699998856, 'lea

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.0005213418044149876, 'eval_runtime': 61.9206, 'eval_samples_per_second': 1614.972, 'eval_steps_per_second': 40.374, 'epoch': 1.0}
{'train_runtime': 939.8552, 'train_samples_per_second': 957.594, 'train_steps_per_second': 23.94, 'train_loss': 0.0006064030581878292, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.001, 'grad_norm': 0.05113563314080238, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0009, 'grad_norm': 0.1342361718416214, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.21435756981372833, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0014, 'grad_norm': 0.2374485433101654, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0015, 'grad_norm': 0.23829999566078186, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.0012, 'grad_norm': 0.28549450635910034, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0015, 'grad_norm': 0.11873136460781097, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0012, 'grad_norm': 0.3130949139595032, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0016, 'grad_norm': 0.06345627456903458, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.2768242359161377, 'learni

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.0009110842947848141, 'eval_runtime': 61.2223, 'eval_samples_per_second': 1633.392, 'eval_steps_per_second': 40.835, 'epoch': 1.0}
{'train_runtime': 939.5146, 'train_samples_per_second': 957.941, 'train_steps_per_second': 23.949, 'train_loss': 0.0010066576908652981, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/22500 [00:00<?, ?it/s]

{'loss': 0.0019, 'grad_norm': 0.09865419566631317, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.23214419186115265, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 0.0021, 'grad_norm': 0.5407238006591797, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.27176403999328613, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.0}
{'loss': 0.0026, 'grad_norm': 0.09770296514034271, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.6619182229042053, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.43282681703567505, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.0}
{'loss': 0.0021, 'grad_norm': 0.1731690913438797, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.04089946672320366, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.28429698944091797, 'learni

  0%|          | 0/2500 [00:00<?, ?it/s]

{'eval_loss': 0.0018537439173087478, 'eval_runtime': 59.7354, 'eval_samples_per_second': 1674.05, 'eval_steps_per_second': 41.851, 'epoch': 1.0}
{'train_runtime': 981.7417, 'train_samples_per_second': 916.738, 'train_steps_per_second': 22.918, 'train_loss': 0.0019598631338940725, 'epoch': 1.0}


Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

In [35]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)

In [36]:
infered_lucy_path = result_dir / f"lucy-{current_date}-currciulum-{i}-hd.obj"
generate_mesh(model, obj_path_lucy, infered_lucy_path, device, batch_size, resolution=200)

Processing batches:   0%|          | 0/200000 [00:00<?, ?it/s]