In [1]:
import sys
from pathlib import Path

# Find project root by looking for .git or requirements.txt, walking upward
# from the current working directory.
current = Path.cwd()
while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
    if current.parent == current:
        # Reached the filesystem root without finding either marker.
        raise FileNotFoundError("Project root not found")
    current = current.parent

# Put the project root at the FRONT of sys.path so local packages (e.g.
# `datasets`, `losses`, `utils`) win over identically named installed
# distributions. Drop any earlier copy first so re-running this cell keeps
# exactly one entry instead of appending a duplicate each time.
project_root_entry = str(current)
sys.path[:] = [entry for entry in sys.path if entry != project_root_entry]
sys.path.insert(0, project_root_entry)
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
from utils.io import load_chkpt
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils.visualize
from trainers.basetrainer import BaseTrainer, create_trainer_from_chkpt
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers
from metrics.utils import MetricAggregator
from utils.io import print_chkpt_info
import pprint

In [None]:
# Use the default CUDA device when torch reports CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Test checkpoint saved at epoch 1 (see the checkpoint info printed below).
# It is loaded onto the CPU so it can be inspected without a GPU.
# The path is relative to the current working directory (presumably the project root).
chkpt_test_path = 'checkpoints/tests/test-epoch-1.pt'
chkpt_test = load_chkpt(chkpt_test_path, device='cpu')

Checkpoint loaded from checkpoints/tests/test-epoch-1.pt on cpu.


In [None]:
# List the top-level sections stored in the checkpoint, in insertion order.
# A `dict_keys` view is printed by pprint as one long unwrapped line here, so
# it is converted to a plain list, which pprint wraps one key per line.
pprint.pprint(list(chkpt_test.keys()))

dict_keys(['train_id', 'train_step_unit', 'train_step_num', 'train_seed', 'train_determinism_type', 'train_device', 'use_torch_compile', 'model', 'loss', 'optimizer', 'lr_scheduler', 'dataset', 'dataloader', 'metrics'])


In [6]:
# Print a human-readable summary of the run metadata (id, step unit/number, seed,
# determinism) and the model / loss / dataset / dataloader / optimizer settings
# stored in the checkpoint.
print_chkpt_info(chkpt_test)

Checkpoint Information:
  Train ID: 657b5a99-5f00-413c-907f-b4609ede4e95
  Train Step Unit: epoch
  Train Step Number: 1
  Train Seed: 0
  Train Determinism Type: full
  Use Torch Compile: False
#### Model ####
  Model Name: vae_locatello
  Model kwargs: {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}
#### Loss ####
  Loss Name: betavae
  Loss kwargs: {'beta': 16, 'log_kl_components': True, 'rec_dist': 'gaussian'}
#### Dataset ####
  Dataset Name: shapes3d
  Dataset kwargs: {'selected_factors': 'all', 'not_selected_factors_index_value': None, 'root': 'data/shapes3d/', 'subset': 1}
#### Dataloader ####
  Dataloader kwargs: {'batch_size': 64, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'seed': 0, 'persistent_workers': True, 'in_order': True, 'snapshot_every_n_steps': 1}
#### Optimizer ####
  Optimizer Name: Adam
  LR Scheduler Name: ConstantLR


In [7]:
# Saved dataloader state: the sampler position, base seed and per-worker snapshots.
# This is presumably what lets a resumed run continue the data stream where it stopped.
# In the output, samples_yielded (480000) == _sampler_iter_yielded (7500) x batch_size (64).
dataloader_state = chkpt_test['dataloader']['state_dict']
dataloader_state

{'_snapshot': {'_snapshot_step': 7500,
  '_last_yielded_worker_id': 3,
  '_main_snapshot': {'_num_workers': 4,
   '_sampler_iter_state': {'samples_yielded': 480000,
    'sampler_iter_state': {'yielded': 480000,
     'generator': tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.uint8)}},
   '_index_sampler_state': None,
   '_sampler_iter_yielded': 7500,
   '_IterableDataset_len_called': None,
   '_shared_seed': None,
   '_base_seed': 1559385941720436641},
  '_worker_snapshots': {'worker_0': {'worker_id': 0,
    'dataset_state': None,
    'fetcher_state': None},
   'worker_1': {'worker_id': 1, 'dataset_state': None, 'fetcher_state': None},
   'worker_2': {'worker_id': 2, 'dataset_state': None, 'fetcher_state': None},
   'worker_3': {'worker_id': 3,
    'dataset_state': None,
    'fetcher_state': None}}},
 '_steps_since_snapshot': 0,
 '_iterator_finished': True}

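# Training tests

Rebuild a trainer from the test checkpoint and run a short training pass to check that training can be resumed from it.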

In [8]:
# Rebuild the trainer from the checkpoint and pass one extra trainer kwarg.
# `return_log_loss` presumably makes train() return its loss logs (see the output of the next cell).
extra_trainer_kwargs = {'return_log_loss': True}
trainer = create_trainer_from_chkpt(
    chkpt_test,
    additional_trainer_kwargs=extra_trainer_kwargs,
)

In [9]:
# Train for max_steps=1. Because the checkpoint's train_step_unit is 'epoch', this is
# presumably one epoch. The returned dict of loss logs is displayed below.
train_result = trainer.train(max_steps=1)
train_result

                                                                                                                                                                                                                                                                        

{'logs': [{'loss': 333.4381298828125,
   'rec_loss': 169.05564193725587,
   'kl_loss': 10.27390549659729,
   'kl_loss_0': 0.030646619405597448,
   'kl_loss_1': 1.7619849276542663,
   'kl_loss_2': 1.4231387066841126,
   'kl_loss_3': 1.5545745992660522,
   'kl_loss_4': 0.22029912158846854,
   'kl_loss_5': 1.4461258447170258,
   'kl_loss_6': 0.20769223108887672,
   'kl_loss_7': 1.5592448616027832,
   'kl_loss_8': 1.7783469414711,
   'kl_loss_9': 0.2918515999615192,
   'iteration': 100},
  {'loss': 333.32938720703123,
   'rec_loss': 169.48809616088866,
   'kl_loss': 10.240080699920654,
   'kl_loss_0': 0.02988694705069065,
   'kl_loss_1': 1.7572722828388214,
   'kl_loss_2': 1.4250492870807647,
   'kl_loss_3': 1.5400812697410584,
   'kl_loss_4': 0.2319072449207306,
   'kl_loss_5': 1.4295591914653778,
   'kl_loss_6': 0.20516439527273178,
   'kl_loss_7': 1.5601705741882324,
   'kl_loss_8': 1.7603284156322478,
   'kl_loss_9': 0.3006611043214798,
   'iteration': 200},
  {'loss': 332.560613403320