In [1]:
import sys
from pathlib import Path

# Find the project root: the closest directory (starting at the current
# working directory and moving upwards) that contains `.git` or
# `requirements.txt`.
root_markers = ('.git', 'requirements.txt')
start_dir = Path.cwd()
for current in (start_dir, *start_dir.parents):
    if any((current / marker).exists() for marker in root_markers):
        break
else:
    # Walked all the way up to the filesystem root without seeing a marker.
    raise FileNotFoundError("Project root not found")

# Make project-local packages (e.g. `utils`, `trainers`) importable.
# Append only when missing, so re-running this cell does not pile up
# duplicate sys.path entries.
project_root = str(current)
if project_root not in sys.path:
    sys.path.append(project_root)
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
from utils.io import load_chkpt
import torch
from trainers.basetrainer import BaseTrainer, create_trainer_from_chkpt
from utils.io import print_chkpt_info
import pprint

In [3]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Load the test checkpoint. The path is relative, so it is resolved against
# the kernel's current working directory.
test_chkpt_path = 'checkpoints/tests/test-epoch-1.pt'
chkpt_test = load_chkpt(test_chkpt_path)

Checkpoint loaded from checkpoints/tests/test-epoch-1.pt on original.


In [5]:
# Show the top-level sections stored in the checkpoint.
# NOTE(review): the recorded output is a single unwrapped line, i.e. pprint
# does not wrap a dict_keys view here; passing list(...) instead would give
# wrapped output.
chkpt_sections = chkpt_test.keys()
pprint.pprint(chkpt_sections)

dict_keys(['train_id', 'train_step_unit', 'train_step_num', 'train_determinism_kwargs', 'train_device', 'use_torch_compile', 'model', 'loss', 'optimizer', 'lr_scheduler', 'dataset', 'dataloader', 'metrics'])


In [6]:
# Print a human-readable summary of the checkpoint: training id and step,
# determinism settings, and the model / loss / dataset / dataloader /
# optimizer configuration.
print_chkpt_info(chkpt_test)

Checkpoint Information:
  Train ID: 256fbce3-a727-4155-a554-1a67b118e42a
  Train Step Unit: epoch
  Train Step Number: 1
  Train determinism kwargs: {'seed': 0, 'use_cuda_deterministic': None, 'cublas_workspace_config': None}
  Use Torch Compile: False
#### Model ####
  Model Name: vae_locatello
  Model kwargs: {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}
#### Loss ####
  Loss Name: betavae
  Loss kwargs: {'beta': 16, 'log_kl_components': True, 'rec_dist': 'gaussian'}
#### Dataset ####
  Dataset Name: shapes3d
  Dataset kwargs: {'selected_factors': 'all', 'not_selected_factors_index_value': None, 'root': 'data/shapes3d/', 'subset': 1}
#### Dataloader ####
  Dataloader kwargs: {'batch_size': 64, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'seed': 0, 'persistent_workers': True, 'in_order': True, 'snapshot_every_n_steps': 1}
#### Optimizer ####
  Optimizer Name: Adam
  LR Scheduler Name: 

In [7]:
# Inspect the dataloader state saved in the checkpoint; the bare expression
# on the last line is what the cell displays.
dataloader_chkpt = chkpt_test['dataloader']
dataloader_chkpt['state_dict']

{'_snapshot': {'_snapshot_step': 7500,
  '_last_yielded_worker_id': 3,
  '_main_snapshot': {'_num_workers': 4,
   '_sampler_iter_state': {'samples_yielded': 480000,
    'sampler_iter_state': {'yielded': 480000,
     'generator': tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8)}},
   '_index_sampler_state': None,
   '_sampler_iter_yielded': 7500,
   '_IterableDataset_len_called': None,
   '_shared_seed': None,
   '_base_seed': 1559385941720436641},
  '_worker_snapshots': {'worker_0': {'worker_id': 0,
    'dataset_state': None,
    'fetcher_state': None},
   'worker_1': {'worker_id': 1, 'dataset_state': None, 'fetcher_state': None},
   'worker_2': {'worker_id': 2, 'dataset_state': None, 'fetcher_state': None},
   'worker_3': {'worker_id': 3,
    'dataset_state': None,
    'fetcher_state': None}}},
 '_steps_since_snapshot': 0,
 '_iterator_finished': True}

# Training tests

In [8]:
# Rebuild a trainer from the loaded checkpoint, passing one extra trainer
# option on top of what the checkpoint provides.
extra_trainer_kwargs = {'return_log_loss': True}
trainer = create_trainer_from_chkpt(
    chkpt_test,
    additional_trainer_kwargs=extra_trainer_kwargs,
)

In [9]:
# Run training with max_steps=1 and display the returned value.
# In the recorded output the result is a dict whose 'logs' entry is a list of
# loss records ('loss', 'rec_loss', 'kl_loss', per-dimension 'kl_loss_<i>',
# and 'iteration').
# NOTE(review): presumably max_steps is counted in the checkpoint's step unit
# (epoch), since the logs span many iterations — confirm against the trainer.
trainer.train(max_steps=1)

                                                                                                                                                                                                                                                                        

{'logs': [{'loss': 333.8346920776367,
   'rec_loss': 170.54660873413087,
   'kl_loss': 10.205505208969116,
   'kl_loss_0': 0.03722667889669538,
   'kl_loss_1': 1.7660546350479125,
   'kl_loss_2': 1.3966964161396027,
   'kl_loss_3': 1.6041313362121583,
   'kl_loss_4': 0.144887069016695,
   'kl_loss_5': 1.432068418264389,
   'kl_loss_6': 0.1330850360542536,
   'kl_loss_7': 1.5386587953567505,
   'kl_loss_8': 1.778789998292923,
   'kl_loss_9': 0.37390682101249695,
   'iteration': 100},
  {'loss': 333.25275939941406,
   'rec_loss': 169.3267138671875,
   'kl_loss': 10.245377855300903,
   'kl_loss_0': 0.034741947017610075,
   'kl_loss_1': 1.7705136632919312,
   'kl_loss_2': 1.4220470046997071,
   'kl_loss_3': 1.5784772610664368,
   'kl_loss_4': 0.1489968989044428,
   'kl_loss_5': 1.4310874330997467,
   'kl_loss_6': 0.12981923237442972,
   'kl_loss_7': 1.5826399290561677,
   'kl_loss_8': 1.7678543627262115,
   'kl_loss_9': 0.37920013546943665,
   'iteration': 200},
  {'loss': 332.773806762695