In [1]:
import sys
from pathlib import Path

# Locate the project root: starting from the current working directory, walk
# up the directory tree until a directory containing `.git` or
# `requirements.txt` is found.
current = Path.cwd()
while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
    # At the filesystem root a path is its own parent, so there is nowhere
    # left to search.
    if current.parent == current:
        raise FileNotFoundError("Project root not found")
    current = current.parent

# Put the project root on the import path. The membership check keeps this
# cell idempotent: re-running it does not stack duplicate entries on sys.path.
if str(current) not in sys.path:
    sys.path.append(str(current))
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
from utils.io import load_chkpt
import torch
from trainers.basetrainer import BaseTrainer
from utils.io import create_trainer_from_chkpt
from utils.io import print_chkpt_info
import pprint

In [3]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Checkpoint inspected throughout the rest of this notebook. The path is
# relative — presumably resolved against the working directory; verify
# against load_chkpt.
CHKPT_PATH = 'checkpoints/tests/test-epoch-1.pt'
chkpt_test = load_chkpt(CHKPT_PATH)

Checkpoint loaded from checkpoints/tests/test-epoch-1.pt on original.


In [5]:
# Show the top-level keys stored in the checkpoint.
# NOTE(review): the recorded output is a single-line `dict_keys([...])` repr,
# i.e. pprint did not wrap it. Passing `list(chkpt_test.keys())` would likely
# give a wrapped listing — confirm before changing.
pprint.pprint(chkpt_test.keys())

dict_keys(['train_id', 'train_iter_num', 'train_epoch_num', 'train_determinism_kwargs', 'train_device', 'use_torch_compile', 'torch_compile_kwargs', 'model', 'loss', 'optimizer', 'lr_scheduler', 'dataset', 'dataloader', 'chkpt', 'logging', 'logs'])


In [6]:
# Print a human-readable summary of the checkpoint using the project helper
# `print_chkpt_info` (imported from utils.io).
print_chkpt_info(chkpt_test)

Checkpoint Information:
  Train ID: 0cfc8791-c05a-4f28-93c8-b687beb19b7f
  Train Iteration Number: 7500
  Train Epoch Number: 1.0
  Train determinism kwargs: {'seed': 0, 'use_cuda_det': True, 'enforce_det': False, 'cublas_workspace_config': None}
  Use Torch Compile: False
#### Model ####
  Model Name: vae_locatello
  Model kwargs: {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}
#### Loss ####
  Loss Name: betavae
  Loss kwargs: {'beta': 16, 'log_kl_components': True, 'rec_dist': 'gaussian'}
#### Dataset ####
  Dataset Name: shapes3d
  Dataset kwargs: {'selected_factors': 'all', 'not_selected_factors_index_value': None, 'root': 'data/shapes3d/', 'subset': 1}
#### Dataloader ####
  Dataloader kwargs: {'batch_size': 64, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'seed': 0, 'persistent_workers': True, 'in_order': True, 'snapshot_every_n_steps': 1}
#### Optimizer ####
  Optimizer Name: Adam


In [7]:
# Display the 'train' entry of the checkpoint's 'logs'. In the recorded output
# this is a dict whose 'loss_results' list holds one dict of loss values per
# logged step, each tagged with 'iter' and 'epoch'.
# NOTE(review): the full structure is dumped and the saved output is cut off;
# consider displaying only a slice.
chkpt_test['logs']['train']

{'loss_results': [{'loss': 1700.9673791503906,
   'rec_loss': 1698.1490899658204,
   'kl_loss': 0.17614290852099657,
   'kl_loss_0': 0.01968884836882353,
   'kl_loss_1': 0.014443990183062852,
   'kl_loss_2': 0.023195226844400167,
   'kl_loss_3': 0.015036357745993883,
   'kl_loss_4': 0.016282279854640366,
   'kl_loss_5': 0.01725930684711784,
   'kl_loss_6': 0.01393447668524459,
   'kl_loss_7': 0.018066423193085938,
   'kl_loss_8': 0.01947542678564787,
   'kl_loss_9': 0.018760571100283414,
   'iter': 200,
   'epoch': 0.02666666666666667},
  {'loss': 1529.9993939208985,
   'rec_loss': 1528.207576904297,
   'kl_loss': 0.11198866175487637,
   'kl_loss_0': 0.014522448838688433,
   'kl_loss_1': 0.006994565953500569,
   'kl_loss_2': 0.018559690262191,
   'kl_loss_3': 0.006991295355837792,
   'kl_loss_4': 0.009086287969257683,
   'kl_loss_5': 0.0087951132748276,
   'kl_loss_6': 0.006859964779578149,
   'kl_loss_7': 0.008521333516109736,
   'kl_loss_8': 0.02548780293436721,
   'kl_loss_9': 0.006

In [8]:
# Display the logging settings saved in the checkpoint under the 'logging' key.
chkpt_test['logging']

{'is_progress_bar': True,
 'progress_bar_log_iter_interval': 50,
 'log_loss_interval_type': 'iter',
 'use_train_logging': True,
 'log_loss_iter_interval': 200,
 'return_log_loss': True}

# Training tests

In [9]:
# Build a trainer object from the loaded checkpoint.
# NOTE(review): `create_exact=True` presumably restores the saved training
# state exactly — confirm against create_trainer_from_chkpt in utils.io.
trainer = create_trainer_from_chkpt(chkpt_test, create_exact=True)

In [10]:
# Display the trainer's loss log. In the recorded output the visible entries
# carry the same values as chkpt_test['logs']['train']['loss_results'].
trainer.train_losses_log

[{'loss': 1700.9673791503906,
  'rec_loss': 1698.1490899658204,
  'kl_loss': 0.17614290852099657,
  'kl_loss_0': 0.01968884836882353,
  'kl_loss_1': 0.014443990183062852,
  'kl_loss_2': 0.023195226844400167,
  'kl_loss_3': 0.015036357745993883,
  'kl_loss_4': 0.016282279854640366,
  'kl_loss_5': 0.01725930684711784,
  'kl_loss_6': 0.01393447668524459,
  'kl_loss_7': 0.018066423193085938,
  'kl_loss_8': 0.01947542678564787,
  'kl_loss_9': 0.018760571100283414,
  'iter': 200,
  'epoch': 0.02666666666666667},
 {'loss': 1529.9993939208985,
  'rec_loss': 1528.207576904297,
  'kl_loss': 0.11198866175487637,
  'kl_loss_0': 0.014522448838688433,
  'kl_loss_1': 0.006994565953500569,
  'kl_loss_2': 0.018559690262191,
  'kl_loss_3': 0.006991295355837792,
  'kl_loss_4': 0.009086287969257683,
  'kl_loss_5': 0.0087951132748276,
  'kl_loss_6': 0.006859964779578149,
  'kl_loss_7': 0.008521333516109736,
  'kl_loss_8': 0.02548780293436721,
  'kl_loss_9': 0.006170158155728131,
  'iter': 400,
  'epoch': 0

In [11]:
# Run training with step_unit='epoch' and max_steps=1.
# NOTE(review): the recorded run of this cell failed with
# "TypeError: object of type 'NoneType' has no len()". The cause is not
# visible in this notebook — TODO investigate in the trainer code before
# relying on this cell.
trainer.train(step_unit='epoch', max_steps=1)

TypeError: object of type 'NoneType' has no len()

In [19]:
# Inspect the trainer's loss-logging interval type (recorded value: 'iter').
# NOTE(review): this cell's execution count (19) is out of sequence with the
# preceding cell (11), so it was run out of order; re-check after a fresh
# Restart & Run All.
trainer.log_loss_interval_type

'iter'