In [13]:
# import sys
# from pathlib import Path

# # Find project root by looking for .git or requirements.txt
# current = Path.cwd()
# while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
#     if current.parent == current:
#         raise FileNotFoundError("Project root not found")
#     current = current.parent

# sys.path.append(str(current))
# print(f"Added project root: {current}")


In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from utils.io import load_chkpt, create_trainer_from_chkpt, get_dataloader_from_chkpt
from metrics.utils import MetricAggregator
from utils.visualize import Visualizer
from datasets import get_dataset


In [2]:
# Path to the checkpoint to evaluate (relative to the notebook's working directory).
chkpt_path = 'checkpoints/tests/b-vae-16-epoch-20-gaussian.pt'

# Load the raw checkpoint object from disk.
chkpt = load_chkpt(chkpt_path)

# Rebuild the trainer from the checkpoint; `create_exact=True` is passed through
# to the project helper unchanged.
trainer = create_trainer_from_chkpt(chkpt, create_exact=True)

# The metrics and visualization cells below use a global `model`, which was never
# defined here before (the `load_model_chkpt` call was commented out and that helper
# is not imported), so they failed with NameError on a fresh kernel.
# NOTE(review): assumes the trainer exposes the wrapped VAE as `trainer.model` — TODO confirm.
model = trainer.model

print(f"Model loaded from {chkpt_path}")


Checkpoint loaded from checkpoints/tests/b-vae-16-epoch-20-gaussian.pt on original.
Determinism settings applied from checkpoint: {'seed': 0, 'use_cuda_det': True, 'enforce_det': False, 'cublas_workspace_config': None}


AttributeError: 'BaseTrainer' object has no attribute 'chkpt_viz'

In [None]:
# Re-instantiate the dataset recorded in the checkpoint: its registry name selects the
# class via `get_dataset`, and the stored keyword arguments are forwarded to it.
dataset_info = chkpt['dataset']
dataset_name, dataset_kwargs = dataset_info['name'], dataset_info['kwargs']
dataset_class = get_dataset(dataset_name)
dataset = dataset_class(**dataset_kwargs)

# Separately build the dataloader described by the same checkpoint
# (this is what the metric computation consumes).
dataloader = get_dataloader_from_chkpt(chkpt)

print(f"Dataset {dataset_name} loaded with {len(dataset)} samples.")


In [None]:
# Metric specifications handed to MetricAggregator: each has a registry `name` and the
# `args` forwarded to that metric. The values below are example settings.
# NOTE(review): 'dci_d' / 'mig' presumably mean DCI disentanglement and Mutual
# Information Gap — confirm against the metrics registry.
metrics_to_compute = [
    dict(name='dci_d', args=dict(num_train=5000, num_test=1000)),
    dict(name='mig', args=dict(num_bins=100, mi_method='numpy', entropy_method='numpy')),
]

metric_aggregator = MetricAggregator(metrics=metrics_to_compute)

# Evaluate every configured metric for `model` (which must be defined by an earlier cell)
# over the checkpoint's dataloader, on the trainer's device.
print("\n===== Computing Metrics =====")
metrics_results = metric_aggregator.compute(
    model=model,
    data_loader=dataloader,
    device=trainer.device,
)
print("Metrics Results:", metrics_results)


In [None]:
visualizer = Visualizer(vae_model=model, dataset=dataset)

# Each step pairs the banner printed before a figure with a zero-argument callable
# that draws it; every figure is shown immediately after it is drawn.
visualization_steps = [
    ("\n===== Visualizing Reconstructions =====",
     lambda: visualizer.plot_random_reconstructions(10, mode='mean')),
    ("\n===== Visualizing Latent Traversals =====",
     lambda: visualizer.plot_all_latent_traversals(num_samples=20)),
]

for banner, draw_figure in visualization_steps:
    print(banner)
    draw_figure()
    plt.show()
