In [1]:
import sys
from pathlib import Path

# Find project root by looking for .git or requirements.txt
current = Path.cwd()
while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
    if current.parent == current:
        raise FileNotFoundError("Project root not found")
    current = current.parent

sys.path.append(str(current))
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
from utils.io import load_chkpt
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils.visualize
from trainers.basetrainer import BaseTrainer, create_trainer_from_chkpt
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers
from metrics.utils import MetricAggregator
from utils.io import print_chkpt_info
import pprint

In [3]:
chkpt_test = load_chkpt('checkpoints/tests/test-epoch-1.pt', device='cuda')

Checkpoint loaded from checkpoints/tests/test-epoch-1.pt on cuda


In [4]:
pprint.pprint(chkpt_test.keys())

dict_keys(['train_id', 'train_step_unit', 'train_step_num', 'train_seed', 'train_determinism_type', 'use_torch_compile', 'model', 'loss', 'optimizer', 'lr_scheduler', 'dataset', 'dataloader', 'metrics'])


In [5]:
print_chkpt_info(chkpt_test)

Checkpoint Information:
  Train ID: fa9a7fdb-ff47-4458-b630-9d96595ae66e
  Train Step Unit: epoch
  Train Step Number: 1
  Train Seed: 0
  Train Determinism Type: full
  Use Torch Compile: False
#### Model ####
  Model Name: vae_locatello
  Model kwargs: {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}
#### Loss ####
  Loss Name: betavae
  Loss kwargs: {'beta': 16, 'log_kl_components': True, 'rec_dist': 'gaussian'}
#### Dataset ####
  Dataset Name: shapes3d
  Dataset kwargs: {'selected_factors': 'all', 'not_selected_factors_index_value': None, 'root': 'data/shapes3d/', 'subset': 1}
#### Dataloader ####
  Dataloader kwargs: {'batch_size': 64, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'seed': 0, 'persistent_workers': True, 'in_order': True, 'snapshot_every_n_steps': 1}
#### Optimizer ####
  Optimizer Name: Adam
  LR Scheduler Name: ConstantLR


In [8]:
chkpt_test['dataloader']['state_dict']

{'_snapshot': {'_snapshot_step': 7500,
  '_last_yielded_worker_id': 3,
  '_main_snapshot': {'_num_workers': 4,
   '_sampler_iter_state': {'samples_yielded': 480000,
    'sampler_iter_state': {'yielded': 480000,
     'generator': tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0', dtype=torch.uint8)}},
   '_index_sampler_state': None,
   '_sampler_iter_yielded': 7500,
   '_IterableDataset_len_called': None,
   '_shared_seed': None,
   '_base_seed': 1559385941720436641},
  '_worker_snapshots': {'worker_0': {'worker_id': 0,
    'dataset_state': None,
    'fetcher_state': None},
   'worker_1': {'worker_id': 1, 'dataset_state': None, 'fetcher_state': None},
   'worker_2': {'worker_id': 2, 'dataset_state': None, 'fetcher_state': None},
   'worker_3': {'worker_id': 3,
    'dataset_state': None,
    'fetcher_state': None}}},
 '_steps_since_snapshot': 0,
 '_iterator_finished': True}

# Training tests

In [6]:
trainer = create_trainer_from_chkpt(chkpt_test)

In [7]:
trainer.train(max_steps=1)

                                                 

TypeError: RNG state must be a torch.ByteTensor