In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
%load_ext autoreload
%autoreload 2

In [None]:
from src.models.modules import *
from src.models.loss import L1_epsilon_lambda
from dataclasses import dataclass
import torch

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@dataclass
class SDFTransformerConfig:
    dim_context: int
    dim_input: int
    num_outputs: int
    dim_output: int
    num_seeds: int = 128
    delta: float = 0.1
    dim_hidden: int = 128
    num_inds: int = 10
    num_heads: int = 8

class SDFTransformer(nn.Module):
    def __init__(self, config: SDFTransformerConfig):
        super(SDFTransformer, self).__init__()
        self.config = config
        self.epsilon = None
        self.lambdaa = None
        self.proj1 = nn.Linear(config.dim_input, config.dim_hidden)
        self.proj2 = nn.Linear(config.dim_context, config.dim_hidden)
        self.cross = MAB(config.dim_hidden, config.dim_hidden, config.dim_hidden, config.num_heads)
        self.pma = PMA(config.dim_hidden, config.num_heads, config.num_seeds)

        self.sab1 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.silu1 = nn.SiLU()
        self.sab2 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.silu2 = nn.SiLU()
        self.sab3 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.silu3 = nn.SiLU()
        self.sab4 = SAB(config.dim_hidden, config.dim_hidden, config.num_heads)
        self.silu4 = nn.SiLU()
        self.final = nn.Sequential(
            PMA(config.dim_hidden, config.num_heads, config.num_outputs),
            nn.SiLU(),
            nn.Linear(config.dim_hidden, config.dim_output),
            nn.Tanh()
        )

    def forward(self, context: torch.Tensor, x: torch.Tensor, labels: torch.Tensor = None):
        x = x.expand(-1, context.shape[1], -1)  # [batch_size, context_size, dim_input]
        x = self.proj1(x)                       # [batch_size, context_size, dim_hidden]
        context = self.proj2(context)           # [batch_size, context_size, dim_hidden]
        y = self.cross(x, context)              # [batch_size, context_size, dim_hidden]
        y = self.pma(y)

        residual = y
        y = self.sab1(y) + residual
        y = self.silu1(y)              # [batch_size, context_size, dim_hidden]

        residual = y
        y = self.sab2(y) + residual
        y = self.silu2(y)              # [batch_size, context_size, dim_hidden]

        residual = y
        y = self.sab3(y) + residual
        y = self.silu3(y)              # [batch_size, context_size, dim_hidden]

        residual = y
        y = self.sab4(y) + residual
        y = self.silu4(y)              # [batch_size, context_size, dim_hidden]

        y = self.final(y)              # [batch_size, num_outputs, dim_output]

        loss = None
        if labels is not None:
            loss = L1_epsilon_lambda(y, labels, self.epsilon, self.lambdaa, self.config.delta)
        return {'loss': loss, 'logits': y}

config = SDFTransformerConfig(dim_context=4, dim_input=3, num_outputs=1, dim_output=1)
model = SDFTransformer(config).to(device)
print(device)

cuda


In [3]:
from src.models.dataset import LazySampleDataset
from pathlib import Path

project_dir = Path(os.path.abspath('')).resolve().parent
procesed_dir = project_dir / 'data' / 'processed'

train_files = list(procesed_dir.rglob('*_train.hdf5'))
val_files = list(procesed_dir.rglob('*_val.hdf5'))

train_dataset = LazySampleDataset(train_files)
val_dataset = LazySampleDataset(val_files)

In [4]:
from src.data.load_data import get_results_dir
from datetime import datetime

notebook_name = '2024_12_30_enc_pooling'
current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
folder_name = f"{notebook_name}-{current_date}"
result_dir = get_results_dir() / folder_name
result_dir.mkdir(parents=True, exist_ok=True)
print(result_dir)

C:\_prog\vm_shared\attention-sdf\results\2024_12_30_enc_pooling-2025-01-04-11-56-50


In [5]:
from transformers import Trainer, TrainingArguments

batch_size = 40
training_args = TrainingArguments(
    output_dir=result_dir / "results",
    eval_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    logging_dir=result_dir / "logs",
    logging_steps=10,
    weight_decay=0.01,
    save_total_limit=3,
    seed=42
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

curriculum_schedule = [
    {"epochs": 2, "epsilon": 0.02,   "lambda": 0.0,  'learning_rate': 5e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0075, "lambda": 0.15, 'learning_rate': 4e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.004,  "lambda": 0.3,  'learning_rate': 3e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.002,  "lambda": 0.4,  'learning_rate': 2e-5, 'resolution': 100},
    {"epochs": 2, "epsilon": 0.0,    "lambda": 0.5,  'learning_rate': 1e-5, 'resolution': 250}
]

In [6]:
from src.visualization.generate_mesh import generate_meshes
from src.data.load_data import get_data_dir

obj_dir = get_data_dir() / 'intermediate'
format_string_base = "{name}-" + current_date + "-curriculum-"

for i, stage in enumerate(curriculum_schedule):
    model.epsilon = stage['epsilon']
    model.lambdaa = stage['lambda']
    trainer.args.num_train_epochs = stage['epochs']
    trainer.args.learning_rate = stage['learning_rate']
    trainer.train()
    format_string = format_string_base + str(i) + ".obj"
    generate_meshes(model, obj_dir, result_dir, format_string, device,
        batch_size, resolution=stage['resolution'], context_size=200)
train_dataset.close()
val_dataset.close()

  0%|          | 0/37500 [00:00<?, ?it/s]

{'loss': 0.0048, 'grad_norm': 0.08449310064315796, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.041189663112163544, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0048, 'grad_norm': 0.04594554752111435, 'learning_rate': 4.996e-05, 'epoch': 0.0}
{'loss': 0.0035, 'grad_norm': 0.03996404632925987, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.04094928875565529, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0043, 'grad_norm': 0.04317784681916237, 'learning_rate': 4.992e-05, 'epoch': 0.0}
{'loss': 0.0035, 'grad_norm': 0.07793587446212769, 'learning_rate': 4.990666666666667e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.03746110200881958, 'learning_rate': 4.989333333333334e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.07724099606275558, 'learning_rate': 4.9880000000000004e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.0795799121260643, 'learning_rate': 4.9866666

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.000753439380787313, 'eval_runtime': 110.6516, 'eval_samples_per_second': 1355.607, 'eval_steps_per_second': 33.89, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.0, 'learning_rate': 2.4986666666666666e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.052078016102313995, 'learning_rate': 2.4973333333333334e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.09420109540224075, 'learning_rate': 2.496e-05, 'epoch': 1.0}
{'loss': 0.0014, 'grad_norm': 0.31952592730522156, 'learning_rate': 2.494666666666667e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.04346800968050957, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}
{'loss': 0.0013, 'grad_norm': 0.05943591520190239, 'learning_rate': 2.4920000000000002e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.051782816648483276, 'learning_rate': 2.4906666666666666e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.0, 'learning_rate': 2.4893333333333334e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.162512868642807, 'l

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0002543237351346761, 'eval_runtime': 119.9497, 'eval_samples_per_second': 1250.525, 'eval_steps_per_second': 31.263, 'epoch': 2.0}
{'train_runtime': 2151.3932, 'train_samples_per_second': 697.223, 'train_steps_per_second': 17.431, 'train_loss': 0.0009822263236985904, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/37500 [00:00<?, ?it/s]

{'loss': 0.0026, 'grad_norm': 0.6090831756591797, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.0031, 'grad_norm': 0.3823338747024536, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.25000736117362976, 'learning_rate': 4.996e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.05601964145898819, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 0.0021, 'grad_norm': 0.16791026294231415, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.21148572862148285, 'learning_rate': 4.992e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.2003089338541031, 'learning_rate': 4.990666666666667e-05, 'epoch': 0.0}
{'loss': 0.0015, 'grad_norm': 0.253536194562912, 'learning_rate': 4.989333333333334e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.09144023805856705, 'learning_rate': 4.9880000000000004e-05, 'epoch': 0.0}
{'loss': 0.0013, 'grad_norm': 0.04289136081933975, 'learning_rate': 4.98666666666666

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0006534381536766887, 'eval_runtime': 108.3543, 'eval_samples_per_second': 1384.347, 'eval_steps_per_second': 34.609, 'epoch': 1.0}
{'loss': 0.0002, 'grad_norm': 0.0, 'learning_rate': 2.4986666666666666e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.0, 'learning_rate': 2.4973333333333334e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.10837516188621521, 'learning_rate': 2.496e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.04873130843043327, 'learning_rate': 2.494666666666667e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.0621168427169323, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}
{'loss': 0.0012, 'grad_norm': 0.06710529327392578, 'learning_rate': 2.4920000000000002e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.06275910139083862, 'learning_rate': 2.4906666666666666e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_norm': 0.16088691353797913, 'learning_rate': 2.4893333333333334e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.16446247696876526, '

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.00035120954271405935, 'eval_runtime': 123.0434, 'eval_samples_per_second': 1219.082, 'eval_steps_per_second': 30.477, 'epoch': 2.0}
{'train_runtime': 2239.9111, 'train_samples_per_second': 669.669, 'train_steps_per_second': 16.742, 'train_loss': 0.000711833354877308, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/37500 [00:00<?, ?it/s]

{'loss': 0.0014, 'grad_norm': 0.3196641504764557, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.2454535812139511, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.20906956493854523, 'learning_rate': 4.996e-05, 'epoch': 0.0}
{'loss': 0.0023, 'grad_norm': 0.3227406442165375, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 0.0013, 'grad_norm': 0.28416019678115845, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0013, 'grad_norm': 0.1682884693145752, 'learning_rate': 4.992e-05, 'epoch': 0.0}
{'loss': 0.0013, 'grad_norm': 0.1662338674068451, 'learning_rate': 4.990666666666667e-05, 'epoch': 0.0}
{'loss': 0.0016, 'grad_norm': 0.25210702419281006, 'learning_rate': 4.989333333333334e-05, 'epoch': 0.0}
{'loss': 0.0012, 'grad_norm': 0.14098139107227325, 'learning_rate': 4.9880000000000004e-05, 'epoch': 0.0}
{'loss': 0.0019, 'grad_norm': 0.506744921207428, 'learning_rate': 4.986666666666667

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0006114892894402146, 'eval_runtime': 111.0521, 'eval_samples_per_second': 1350.717, 'eval_steps_per_second': 33.768, 'epoch': 1.0}
{'loss': 0.0004, 'grad_norm': 0.031418558210134506, 'learning_rate': 2.4986666666666666e-05, 'epoch': 1.0}
{'loss': 0.0008, 'grad_norm': 0.06309918314218521, 'learning_rate': 2.4973333333333334e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.4285775423049927, 'learning_rate': 2.496e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.09602037817239761, 'learning_rate': 2.494666666666667e-05, 'epoch': 1.0}
{'loss': 0.0003, 'grad_norm': 0.11582169681787491, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.0551363080739975, 'learning_rate': 2.4920000000000002e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.08951546251773834, 'learning_rate': 2.4906666666666666e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.0, 'learning_rate': 2.4893333333333334e-05, 'epoch': 1.0}
{'loss': 0.0005, 'grad_norm': 0.2726

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.000360486883437261, 'eval_runtime': 107.8858, 'eval_samples_per_second': 1390.359, 'eval_steps_per_second': 34.759, 'epoch': 2.0}
{'train_runtime': 2173.4506, 'train_samples_per_second': 690.147, 'train_steps_per_second': 17.254, 'train_loss': 0.0007437365419433142, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/37500 [00:00<?, ?it/s]

{'loss': 0.0005, 'grad_norm': 0.3268005847930908, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.0013, 'grad_norm': 0.3546949326992035, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0022, 'grad_norm': 0.8790349960327148, 'learning_rate': 4.996e-05, 'epoch': 0.0}
{'loss': 0.002, 'grad_norm': 0.6683709621429443, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 0.0011, 'grad_norm': 0.1291065216064453, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0015, 'grad_norm': 0.24150680005550385, 'learning_rate': 4.992e-05, 'epoch': 0.0}
{'loss': 0.0016, 'grad_norm': 0.2967110574245453, 'learning_rate': 4.990666666666667e-05, 'epoch': 0.0}
{'loss': 0.0022, 'grad_norm': 0.5078425407409668, 'learning_rate': 4.989333333333334e-05, 'epoch': 0.0}
{'loss': 0.0015, 'grad_norm': 0.11441376060247421, 'learning_rate': 4.9880000000000004e-05, 'epoch': 0.0}
{'loss': 0.0017, 'grad_norm': 0.12426350265741348, 'learning_rate': 4.986666666666667

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0008834605105221272, 'eval_runtime': 108.5713, 'eval_samples_per_second': 1381.581, 'eval_steps_per_second': 34.54, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.2671215534210205, 'learning_rate': 2.4986666666666666e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.20100809633731842, 'learning_rate': 2.4973333333333334e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.5473935604095459, 'learning_rate': 2.496e-05, 'epoch': 1.0}
{'loss': 0.0007, 'grad_norm': 0.11832874268293381, 'learning_rate': 2.494666666666667e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.29597118496894836, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}
{'loss': 0.0011, 'grad_norm': 0.3091585338115692, 'learning_rate': 2.4920000000000002e-05, 'epoch': 1.0}
{'loss': 0.001, 'grad_norm': 0.2578931748867035, 'learning_rate': 2.4906666666666666e-05, 'epoch': 1.0}
{'loss': 0.0006, 'grad_norm': 0.06677522510290146, 'learning_rate': 2.4893333333333334e-05, 'epoch': 1.0}
{'loss': 0.0009, 'grad_no

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0006091163377277553, 'eval_runtime': 112.292, 'eval_samples_per_second': 1335.803, 'eval_steps_per_second': 33.395, 'epoch': 2.0}
{'train_runtime': 2168.4807, 'train_samples_per_second': 691.729, 'train_steps_per_second': 17.293, 'train_loss': 0.0009804468535383542, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/25000 [00:00<?, ?it/s]

  0%|          | 0/37500 [00:00<?, ?it/s]

{'loss': 0.0022, 'grad_norm': 0.16228075325489044, 'learning_rate': 4.9986666666666674e-05, 'epoch': 0.0}
{'loss': 0.0034, 'grad_norm': 0.6336237788200378, 'learning_rate': 4.997333333333333e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.6088259816169739, 'learning_rate': 4.996e-05, 'epoch': 0.0}
{'loss': 0.0027, 'grad_norm': 0.19537419080734253, 'learning_rate': 4.994666666666667e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.23564021289348602, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 0.0026, 'grad_norm': 0.8731986880302429, 'learning_rate': 4.992e-05, 'epoch': 0.0}
{'loss': 0.0024, 'grad_norm': 0.8727937340736389, 'learning_rate': 4.990666666666667e-05, 'epoch': 0.0}
{'loss': 0.0033, 'grad_norm': 0.7255638241767883, 'learning_rate': 4.989333333333334e-05, 'epoch': 0.0}
{'loss': 0.0036, 'grad_norm': 0.9163646697998047, 'learning_rate': 4.9880000000000004e-05, 'epoch': 0.0}
{'loss': 0.0037, 'grad_norm': 0.7639572620391846, 'learning_rate': 4.98666666666666

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.002313608303666115, 'eval_runtime': 108.3927, 'eval_samples_per_second': 1383.856, 'eval_steps_per_second': 34.596, 'epoch': 1.0}
{'loss': 0.0023, 'grad_norm': 1.0666489601135254, 'learning_rate': 2.4986666666666666e-05, 'epoch': 1.0}
{'loss': 0.0024, 'grad_norm': 0.07823537290096283, 'learning_rate': 2.4973333333333334e-05, 'epoch': 1.0}
{'loss': 0.0021, 'grad_norm': 0.11029788851737976, 'learning_rate': 2.496e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.46448779106140137, 'learning_rate': 2.494666666666667e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.6968131065368652, 'learning_rate': 2.4933333333333334e-05, 'epoch': 1.0}
{'loss': 0.0027, 'grad_norm': 0.21236149966716766, 'learning_rate': 2.4920000000000002e-05, 'epoch': 1.0}
{'loss': 0.0022, 'grad_norm': 0.7125865817070007, 'learning_rate': 2.4906666666666666e-05, 'epoch': 1.0}
{'loss': 0.0019, 'grad_norm': 0.102008156478405, 'learning_rate': 2.4893333333333334e-05, 'epoch': 1.0}
{'loss': 0.0023, 'grad_no

  0%|          | 0/3750 [00:00<?, ?it/s]

{'eval_loss': 0.0019353643292561173, 'eval_runtime': 100.621, 'eval_samples_per_second': 1490.743, 'eval_steps_per_second': 37.269, 'epoch': 2.0}
{'train_runtime': 2127.1845, 'train_samples_per_second': 705.157, 'train_steps_per_second': 17.629, 'train_loss': 0.0023490064614017803, 'epoch': 2.0}


Processing models:   0%|          | 0/3 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/390625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/390625 [00:00<?, ?it/s]

Processing batches:   0%|          | 0/390625 [00:00<?, ?it/s]

In [7]:
import json
from dataclasses import asdict

current_date = datetime.now().strftime("%Y-%m-%d")
model_name = f"{current_date}-model"
config_name = f"{current_date}-config.json"
trainer.save_model(result_dir / model_name)

with open(result_dir / config_name, 'w') as f:
    json.dump(asdict(config), f)