In [1]:
import torch
from pathlib import Path

# Load the best checkpoint and expose it for subsequent notebook cells as exactly one of:
#   best_model            -> the full model object (when the whole model was pickled)
#   best_model_state_dict -> a name -> tensor mapping (when only weights were saved)
# The relative path assumes the notebook runs from the project root.
MODEL_PATH = Path("models/best_model.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not MODEL_PATH.exists():
    raise FileNotFoundError(f"{MODEL_PATH} not found")

# NOTE(review): torch.load unpickles the file, which can run arbitrary code -- only load
# checkpoints you trust. Recent PyTorch releases (2.6+) default to weights_only=True, under
# which a pickled full model object will fail to load; pass weights_only explicitly if you
# need to pin the behavior across versions.
checkpoint = torch.load(MODEL_PATH, map_location=device)

best_model = None
best_model_state_dict = None

if isinstance(checkpoint, dict):
    # Training scripts often wrap the weights under one of these keys (checked in this order).
    for state_key in ("model_state_dict", "state_dict"):
        if state_key in checkpoint:
            best_model_state_dict = checkpoint[state_key]
            break
    else:
        # A bare state_dict maps names directly to tensors; inspect the first value.
        first_val = next(iter(checkpoint.values()), None)
        if isinstance(first_val, torch.Tensor):
            best_model_state_dict = checkpoint
    if best_model_state_dict is None:
        # Fail loudly instead of leaving both variables None and mis-reporting a model object.
        raise ValueError(
            f"Unrecognized checkpoint dict in {MODEL_PATH}; "
            f"top-level keys (first 20): {list(checkpoint)[:20]}"
        )
else:
    # Not a dict: treat it as a full model object saved with torch.save(model).
    best_model = checkpoint

if best_model_state_dict is not None:
    # A state_dict holds both parameters and buffers (e.g. BatchNorm running stats),
    # so its length is an entry count, not a parameter count.
    print(f"Loaded state_dict with {len(best_model_state_dict)} entries (parameters and buffers). Define your model architecture and run: model.load_state_dict(best_model_state_dict)")
else:
    print(f"Loaded model object of type: {type(best_model)}")

Loaded state_dict with 39 parameters. Define your model architecture and run: model.load_state_dict(best_model_state_dict)


In [None]:
import os
import torch
import torch.nn as nn
from tqdm import tqdm
import multiprocessing
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from utils import CIFAR10_dataset, count_parameters, ICCNN

# Per-channel CIFAR-10 mean/std used for normalization.
# NOTE(review): presumably identical to the training-time values -- verify, since a mismatch
# silently degrades test accuracy.
mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

test_transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean, std)])


test_dataset = CIFAR10_dataset(partition="test", transform=test_transform)

batch_size = 100

# multiprocessing.cpu_count() reports all CPUs on the host, not the ones this process may use
# (containers, CPU affinity); prefer the affinity mask where the OS exposes it.
if hasattr(os, "sched_getaffinity"):
    n_usable_cpus = len(os.sched_getaffinity(0))
else:
    n_usable_cpus = multiprocessing.cpu_count()
# Leave one core for the main process; 0 means "load batches in the main process".
num_workers = max(n_usable_cpus - 1, 0)
print("Num workers", num_workers)

# Pinned host memory speeds up host-to-GPU copies; it only helps when CUDA is present.
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=num_workers,
                             pin_memory=torch.cuda.is_available())