# Imports

In [1]:
from __future__ import print_function

import json
from pathlib import Path

import torch

from data_loader import load_cifar10, get_class_names
from model import SmallerComparableCNN
from training_utils import train_model, continue_training

# CUDA

In [None]:
print(f"Is CUDA available? {torch.cuda.is_available()}")
!nvcc --version

In [3]:
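# Optional: install the CUDA 12.4 PyTorch wheels into this kernel's environment
# (%pip, unlike !pip, targets the running kernel), then restart the kernel.
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124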

# Initialization

In [4]:
# Pick the compute device once for the whole notebook: the default CUDA GPU
# when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Data-pipeline settings, named so they are easy to find and change.
BATCH_SIZE = 64
SEED = 42

# Seed PyTorch's global RNG (CPU and all CUDA devices). The seed passed to
# load_cifar10 is handled inside that helper; whether it also seeds torch is
# not visible here. PyTorch's default layer initialisation in the training
# cells below draws from this global RNG, so seed it explicitly for
# reproducible runs.
torch.manual_seed(SEED)

train_loader, test_loader, X_train, X_test, Y_train, Y_test = load_cifar10(batch_size=BATCH_SIZE, seed=SEED)
class_names = get_class_names()

# Train CNN

In [None]:
# Build a fresh model on the device chosen in the Initialization section.
# SmallerComparableCNN comes from the Imports cell, so the "Continue training"
# section does not depend on this (long-running) cell having been executed.
model = SmallerComparableCNN().to(device)  # Module.to() moves parameters and returns the module

print(f"Model device: {next(model.parameters()).device}")

# Train from scratch. model_name='cnn' presumably determines where checkpoints
# are written ('cnn_checkpoints', read by "Continue training") — confirm in training_utils.
metrics = train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='cnn',
    num_epochs=2000,
    device=device
)

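# Continue training

Resume the run from the files saved in `cnn_checkpoints/`. This section needs only the Imports and Initialization cells, so it can be run on a fresh kernel without re-running the 2000-epoch training cell above.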

In [None]:
# Rebuild the same architecture on the selected device. Restoring the saved
# weights is presumably done by continue_training from checkpoint_dir —
# confirm in training_utils.
model = SmallerComparableCNN().to(device)

# Single source of truth for the checkpoint location: it is inspected here and
# passed to continue_training below, so the two can never silently diverge.
checkpoint_dir = 'cnn_checkpoints'

# Report how far the previous run got, using the metrics file it saved.
# Kept under its own name so it is not confused with the training metrics.
with open(Path(checkpoint_dir) / 'training_metrics.json', 'r', encoding='utf-8') as f:
    saved_metrics = json.load(f)
print(f"Last completed epoch: {saved_metrics['current_epoch']}")

# Resume training. target_epochs is presumably the total epoch count (matching
# num_epochs in the training cell), not additional epochs — confirm in training_utils.
metrics = continue_training(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    model_name='cnn',
    checkpoint_dir=checkpoint_dir,
    target_epochs=2000,
    device=device
)