In [1]:
import sys
from pathlib import Path

# Find project root by looking for .git or requirements.txt
current = Path.cwd()
while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
    if current.parent == current:
        raise FileNotFoundError("Project root not found")
    current = current.parent

sys.path.append(str(current))
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
from utils.io import load_chkpt
import torch
import numpy as np
import torch.utils
import matplotlib.pyplot as plt
import torch.optim as optim

import utils.visualize
from trainers import UnsupervisedTrainer
import losses
import vae_models
from datasets import get_dataset
from utils.io import find_optimal_num_workers
from metrics.utils import MetricAggregator

In [3]:
chkpt_test = load_chkpt('checkpoints/tests/test-epoch-1.pt')

Checkpoint loaded from checkpoints/tests/test-epoch-1.pt on cpu


In [4]:
print(chkpt_test)

{'train_id': UUID('481b4941-2740-41c6-89fd-64304cfda922'), 'train_step_unit': 'epoch', 'train_step_num': 1, 'train_seed': 0, 'train_determinism_type': 'full', 'model': {'name': 'vae_locatello', 'kwargs': {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}, 'state_dict': OrderedDict([('encoder.conv1.weight', tensor([[[[-0.1480, -0.0353,  0.3251,  0.3058],
          [-0.2813,  0.0724,  0.0100, -0.2668],
          [ 0.1278, -0.3329,  0.1221,  0.1720],
          [ 0.2957,  0.3238, -0.0738, -0.1262]],

         [[ 0.0516, -0.0599,  0.1857,  0.1622],
          [-0.0689,  0.2102, -0.1859,  0.1181],
          [ 0.3399, -0.2358,  0.0280, -0.2517],
          [-0.2110,  0.2585, -0.2816, -0.0625]],

         [[-0.2162, -0.1501, -0.1544, -0.2677],
          [ 0.2151,  0.2385, -0.0063, -0.2413],
          [-0.1040,  0.1970, -0.2248,  0.3017],
          [-0.1625, -0.2742, -0.0315, -0.0098]]],


        [[[ 0.3099, -0.

In [5]:
print(chkpt_test['model'])

{'name': 'vae_locatello', 'kwargs': {'img_size': torch.Size([3, 64, 64]), 'latent_dim': 10, 'encoder_decay': 0.0, 'decoder_decay': 0.0, 'decoder_output_dist': 'bernoulli'}, 'state_dict': OrderedDict([('encoder.conv1.weight', tensor([[[[-0.1480, -0.0353,  0.3251,  0.3058],
          [-0.2813,  0.0724,  0.0100, -0.2668],
          [ 0.1278, -0.3329,  0.1221,  0.1720],
          [ 0.2957,  0.3238, -0.0738, -0.1262]],

         [[ 0.0516, -0.0599,  0.1857,  0.1622],
          [-0.0689,  0.2102, -0.1859,  0.1181],
          [ 0.3399, -0.2358,  0.0280, -0.2517],
          [-0.2110,  0.2585, -0.2816, -0.0625]],

         [[-0.2162, -0.1501, -0.1544, -0.2677],
          [ 0.2151,  0.2385, -0.0063, -0.2413],
          [-0.1040,  0.1970, -0.2248,  0.3017],
          [-0.1625, -0.2742, -0.0315, -0.0098]]],


        [[[ 0.3099, -0.3008, -0.1853,  0.2841],
          [ 0.2697,  0.1907, -0.2942, -0.1456],
          [ 0.2054,  0.1350, -0.3039,  0.1154],
          [ 0.3156,  0.0944, -0.0060,  0.0267]]

In [6]:
print(chkpt_test['loss'])

{'name': 'beta_vae', 'kwargs': {'beta': 16, 'log_kl_components': True, 'mode': 'post_forward', 'rec_dist': 'gaussian'}, 'state_dict': None}


In [6]:
print(chkpt_test['model']['state_dict'])

OrderedDict([('encoder.conv1.weight', tensor([[[[-0.1480, -0.0353,  0.3251,  0.3058],
          [-0.2813,  0.0724,  0.0100, -0.2668],
          [ 0.1278, -0.3329,  0.1221,  0.1720],
          [ 0.2957,  0.3238, -0.0738, -0.1262]],

         [[ 0.0516, -0.0599,  0.1857,  0.1622],
          [-0.0689,  0.2102, -0.1859,  0.1181],
          [ 0.3399, -0.2358,  0.0280, -0.2517],
          [-0.2110,  0.2585, -0.2816, -0.0625]],

         [[-0.2162, -0.1501, -0.1544, -0.2677],
          [ 0.2151,  0.2385, -0.0063, -0.2413],
          [-0.1040,  0.1970, -0.2248,  0.3017],
          [-0.1625, -0.2742, -0.0315, -0.0098]]],


        [[[ 0.3099, -0.3008, -0.1853,  0.2841],
          [ 0.2697,  0.1907, -0.2942, -0.1456],
          [ 0.2054,  0.1350, -0.3039,  0.1154],
          [ 0.3156,  0.0944, -0.0060,  0.0267]],

         [[ 0.0446, -0.1919, -0.2988,  0.1521],
          [ 0.3096, -0.0261,  0.3548, -0.2707],
          [ 0.2501, -0.0230,  0.0663, -0.3196],
          [-0.1394,  0.0457, -0.2498,  0

# Training tests

In [None]:
# Load 3D Shapes
Shapes3D = get_dataset("shapes3d")
shapes3d_dataset = Shapes3D(selected_factors='all', not_selected_factors_index_value=None)

# num_workers_3dshapes = find_optimal_num_workers(shapes3d_dataset, batch_size=batch_size, num_batches_to_test='all')
num_workers_3dshapes = 4

if deterministic_run:
    shapes3d_dataloader = get_deterministic_dataloader(dataset=shapes3d_dataset, 
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers_3dshapes,
                                                   seed=seed,
                                                   pin_memory=True)
else:
    shapes3d_dataloader = torch.utils.data.DataLoader(shapes3d_dataset, 
                                                      batch_size=batch_size, 
                                                      num_workers=num_workers_3dshapes, 
                                                      shuffle=True, 
                                                      pin_memory=True)


print(f"Loaded 3D Shapes dataset with {len(shapes3d_dataset)} samples.")
