In [1]:
# # If running in a notebook, install missing dependencies with pip before import
# import sys
# import subprocess

# def install_and_import(package, import_name=None):
#     import importlib
#     try:
#         if import_name is None:
#             import_name = package
#         importlib.import_module(import_name)
#     except ImportError:
#         print(f"Installing {package} ...")
#         subprocess.check_call([sys.executable, "-m", "pip", "install", package])
#         # Optionally try to import again
#         importlib.invalidate_caches()
#         importlib.import_module(import_name)

# # List of (pip_package, import_name) pairs
# packages = [
#     ("torch", "torch"),
#     ("torchvision", "torchvision"),
#     ("datasets", "datasets"),
#     ("numpy", "numpy"),
# ]

# for pip_name, import_name in packages:
#     install_and_import(pip_name, import_name)


In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from ResNet import ResNet
import numpy as np
import gc
import os
import glob

# Pick the compute backend: Apple Metal (MPS) first, then CUDA, then CPU.
# `is_available()` is used rather than `is_built()`: `is_built()` only says the
# installed PyTorch binary was compiled with MPS support, which can be true on
# a machine that has no usable MPS device.
device = "mps" if torch.backends.mps.is_available() \
    else "cuda" if torch.cuda.is_available() else "cpu"

print(device)

# Fix the global PyTorch RNG seed so runs are repeatable.
torch.manual_seed(3)

mps


<torch._C.Generator at 0x1160cfc70>

In [3]:
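# Alternative: manually load a specific checkpoint.
# Uncomment the lines below and edit the path. Run them only after `model`, `optimizer`,
# `scheduler` and `load_checkpoint` have been defined further down in this notebook;
# executed at this position on a fresh kernel they raise NameError.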

# checkpoint_path = 'checkpoints/resnet_epoch_50.pth'  # Specify the checkpoint you want to load
# start_epoch, train_losses, train_accuracies, val_accuracies = load_checkpoint(
#     model, optimizer, scheduler, checkpoint_path
# )
# print(f"Loaded checkpoint from epoch {start_epoch}")


## __Initial training of ResNet-56 on CIFAR10__

#### Neural network model

In [4]:
# Build the network with 10 output classes, then move its parameters to `device`.
# NOTE(review): n=9 presumably selects the ResNet-56 depth named in the section
# title — confirm against ResNet.py.
model = ResNet(num_classes=10, n=9)
model = model.to(device)

#### Load dataset

In [5]:
# Define standard data transforms for CIFAR10

# Per-channel normalization statistics. Both pipelines must normalize with the
# same values, so they are defined once here instead of being repeated.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])


In [6]:
# Load the CIFAR10 train and test splits, downloading them into ./data when
# they are not already present. Each split gets its own transform pipeline.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

#### Define basic params

In [None]:
# Training hyperparameters
num_epochs = 270  # total number of epochs; passed to train_with_resume
batch_size = 1000  # samples per batch for both the train and test DataLoaders
initial_lr = 0.1  # base SGD learning rate; LambdaLR scales it by lr_lambda(epoch)

# Learning rate schedule: warmup for 15 epochs, then step down
def lr_lambda(epoch):
    """Return the learning-rate multiplier for a zero-based `epoch`.

    Epochs 0-14 ramp linearly from 1/15 up to 1.0 (warmup). After that the
    multiplier is piecewise constant: 1.0 until epoch 90, 0.1 until epoch 180,
    0.01 until epoch 240, and 0.001 from epoch 240 onwards.
    """
    warmup_epochs = 15
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs

    # (exclusive end epoch, multiplier) pairs, checked in ascending order.
    plateaus = ((90, 1.0), (180, 0.1), (240, 0.01))
    for end_epoch, factor in plateaus:
        if epoch < end_epoch:
            return factor

    return 0.001

# Classification loss, SGD-with-momentum optimizer, and a scheduler that
# multiplies the base learning rate by lr_lambda(epoch) on each scheduler step.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=initial_lr,
    momentum=0.9,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lr_lambda,
)


In [8]:
# Reshuffle the training set every epoch: without shuffle=True the DataLoader
# yields samples in dataset order, so every epoch would present SGD with the
# exact same sequence of batches. Evaluation accuracy does not depend on the
# order, so the test loader stays sequential.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [9]:
# Checkpoint loading functionality
def find_latest_checkpoint(checkpoint_dir='checkpoints'):
    """Find the checkpoint file with the highest epoch number.

    Only files named ``resnet_epoch_<digits>.pth`` are considered. Files that
    match the glob but carry a non-numeric suffix (for example
    ``resnet_epoch_best.pth``) are skipped instead of raising ValueError.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The path of the checkpoint with the largest epoch number, or None when
        the directory does not exist or holds no numbered checkpoint.
    """
    if not os.path.exists(checkpoint_dir):
        return None

    prefix, suffix = 'resnet_epoch_', '.pth'

    def extract_epoch_number(filename):
        # 'resnet_epoch_50.pth' -> 50; anything without a plain integer -> None
        basename = os.path.basename(filename)
        epoch_str = basename[len(prefix):-len(suffix)]
        if epoch_str.isascii() and epoch_str.isdigit():
            return int(epoch_str)
        return None

    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, 'resnet_epoch_*.pth')):
        epoch_number = extract_epoch_number(path)
        if epoch_number is not None:
            numbered.append((epoch_number, path))

    if not numbered:
        return None

    # Compare on the epoch number only, so ties resolve to the first match.
    return max(numbered, key=lambda item: item[0])[1]

def load_checkpoint(model, optimizer, scheduler, checkpoint_path):
    """Restore training state from a checkpoint file.

    Loads the saved state dicts into `model`, `optimizer` and `scheduler`
    in place, mapping tensors onto the notebook-level `device`.

    Returns:
        A 4-tuple ``(start_epoch, train_losses, train_accuracies,
        val_accuracies)``. The three history lists default to ``[]`` when the
        checkpoint does not contain them. If `checkpoint_path` does not exist,
        nothing is loaded and ``(0, [], [], [])`` is returned.
    """
    checkpoint_missing = not os.path.exists(checkpoint_path)
    if checkpoint_missing:
        print(f"Checkpoint not found: {checkpoint_path}")
        return 0, [], [], []

    print(f"Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Restore model, then optimizer, then scheduler from their saved entries.
    restore_plan = (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for target, state_key in restore_plan:
        target.load_state_dict(checkpoint[state_key])

    # Number of epochs already completed when the checkpoint was written.
    start_epoch = checkpoint['epoch']

    print(f"Resuming from epoch {start_epoch}")
    print(f"Previous metrics - Loss: {checkpoint['train_loss']:.4f}, "
          f"Train Acc: {checkpoint['train_acc']:.4f}, Val Acc: {checkpoint['val_acc']:.4f}")

    history_keys = ('train_losses', 'train_accuracies', 'val_accuracies')
    history = [checkpoint.get(key, []) for key in history_keys]
    return (start_epoch, *history)


In [10]:
# Enhanced training loop with checkpoint resuming
def _run_train_epoch(model, optimizer, criterion, loader, device):
    """Run one optimisation pass over `loader`; return (mean batch loss, accuracy)."""
    model.train()
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        num_batches += 1
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Batches are counted while iterating, so an empty loader gives 0.0
    # instead of a ZeroDivisionError.
    avg_loss = loss_sum / num_batches if num_batches > 0 else 0.
    accuracy = correct / total if total > 0 else 0.
    return avg_loss, accuracy


def _evaluate(model, loader, device):
    """Return the classification accuracy of `model` on `loader` (no gradients)."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total if total > 0 else 0.


def train_with_resume(model, optimizer, scheduler, criterion, train_loader, test_loader,
                     num_epochs, device, resume_from_checkpoint=True,
                     checkpoint_dir='checkpoints', save_every=10):
    """
    Training loop that can resume from the latest checkpoint.

    Each epoch trains on `train_loader`, measures accuracy on `test_loader`,
    records the metrics and steps the scheduler. A checkpoint holding the
    model, optimizer and scheduler state plus the metric histories is written
    every `save_every` epochs and after the final epoch, so the result of a
    run is kept even when `num_epochs` is not a multiple of `save_every`.

    Args:
        model, optimizer, scheduler, criterion: Training objects, updated in place.
        train_loader: Iterable of (images, labels) batches used for optimisation.
        test_loader: Iterable of (images, labels) batches used for validation accuracy.
        num_epochs: Total number of epochs, counting epochs already completed
            by a resumed checkpoint.
        device: Device string or torch.device (for example "cuda", "cuda:0", "mps", "cpu").
        resume_from_checkpoint: When True, load the latest checkpoint found in
            `checkpoint_dir`, if any, and continue after it.
        checkpoint_dir: Directory checkpoints are written to and resumed from.
        save_every: Positive checkpoint interval in epochs.

    Returns:
        (train_losses, train_accuracies, val_accuracies), one entry per epoch,
        including entries restored from a resumed checkpoint.

    Raises:
        ValueError: If `save_every` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    # Initialize metrics lists
    train_losses = []
    train_accuracies = []
    val_accuracies = []

    # Create directory for saving weights
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Try to resume from checkpoint if requested
    start_epoch = 0
    if resume_from_checkpoint:
        latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
        if latest_checkpoint:
            start_epoch, train_losses, train_accuracies, val_accuracies = load_checkpoint(
                model, optimizer, scheduler, latest_checkpoint
            )
            print(f"Resuming training from epoch {start_epoch + 1}")
        else:
            print("No checkpoint found, starting from epoch 1")
    else:
        print("Starting fresh training from epoch 1")

    # Normalise to the backend name so "cuda:0" and torch.device objects are
    # also recognised when clearing the allocator cache below.
    device_type = torch.device(device).type

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        avg_loss, train_acc = _run_train_epoch(model, optimizer, criterion, train_loader, device)
        val_acc = _evaluate(model, test_loader, device)

        train_losses.append(avg_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        scheduler.step()

        # Save weights and print every `save_every`-th epoch and on the last epoch
        completed = epoch + 1
        if completed % save_every == 0 or completed == num_epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f'resnet_epoch_{completed}.pth')
            torch.save({
                'epoch': completed,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': avg_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'train_losses': train_losses,
                'train_accuracies': train_accuracies,
                'val_accuracies': val_accuracies,
            }, checkpoint_path)

            print(f"Epoch [{completed}/{num_epochs}] Loss: {avg_loss:.4f} Train Acc: {train_acc:.4f} Val Acc: {val_acc:.4f}")
            print(f"Checkpoint saved to {checkpoint_path}")

        # Clear memory between epochs
        if device_type == "cuda":
            torch.cuda.empty_cache()
        elif device_type == "mps":
            torch.mps.empty_cache()
        # For CPU, only the garbage collector is run
        gc.collect()

    return train_losses, train_accuracies, val_accuracies


In [None]:
# Usage example: run training through train_with_resume.
# As configured here (resume_from_checkpoint=False) training starts fresh from epoch 1.
# Set the flag to True to find and load the latest saved checkpoint, if one exists, and continue from it.
train_losses, train_accuracies, val_accuracies = train_with_resume(
    model=model,
    optimizer=optimizer, 
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    device=device,
    resume_from_checkpoint=False  # False = start fresh; True = resume from the latest checkpoint
)


Starting fresh training from epoch 1


In [None]:
# # Enhanced training loop with weight saving and memory clearing
# train_losses = []
# train_accuracies = []
# val_accuracies = []

# # Create directory for saving weights
# os.makedirs('checkpoints', exist_ok=True)

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0.0
#     correct = 0
#     total = 0

#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)

#         optimizer.zero_grad()
#         output = model(images)
#         loss = criterion(output, labels)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         _, predicted = torch.max(output, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
    
#     avg_loss = total_loss / len(train_loader)
#     train_acc = correct / total if total > 0 else 0.

#     model.eval()
#     correct_val = 0
#     total_val = 0

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             output = model(images)
#             _, predicted = torch.max(output, 1)
#             total_val += labels.size(0)
#             correct_val += (predicted == labels).sum().item()
#         val_acc = correct_val / total_val if total_val > 0 else 0.

#     train_losses.append(avg_loss)
#     train_accuracies.append(train_acc)
#     val_accuracies.append(val_acc)

#     scheduler.step()

#     # Save weights and print every 10th epoch
#     if (epoch + 1) % 10 == 0:
#         checkpoint_path = f'checkpoints/resnet_epoch_{epoch+1}.pth'
#         torch.save({
#             'epoch': epoch + 1,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler_state_dict': scheduler.state_dict(),
#             'train_loss': avg_loss,
#             'train_acc': train_acc,
#             'val_acc': val_acc,
#         }, checkpoint_path)
        
#         print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Train Acc: {train_acc:.4f} Val Acc: {val_acc:.4f}")
#         print(f"Checkpoint saved to {checkpoint_path}")
    
#     # Clear memory between epochs
#     if device == "cuda":
#         torch.cuda.empty_cache()
#     elif device == "mps":
#         torch.mps.empty_cache()
#     gc.collect()


In [None]:
# train_losses = []
# train_accuracies = []
# val_accuracies = []

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0.0
#     correct = 0
#     total = 0

#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)

#         optimizer.zero_grad()
#         output = model(images)
#         loss = criterion(output, labels)
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()
#         _, predicted = torch.max(output, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
    
#     avg_loss = total_loss / len(train_loader)
#     train_acc = correct / total if total > 0 else 0.

#     model.eval()
#     correct_val = 0
#     total_val = 0

#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             output = model(images)
#             _, predicted = torch.max(output, 1)
#             total_val += labels.size(0)
#             correct_val += (predicted == labels).sum().item()
#         val_acc = correct_val / total_val if total_val > 0 else 0.

#     train_losses.append(avg_loss)
#     train_accuracies.append(train_acc)
#     val_accuracies.append(val_acc)

#     scheduler.step()

#     # Print every 10th epoch
#     if (epoch + 1) % 10 == 0:
#         print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Train Acc: {train_acc:.4f} Val Acc: {val_acc:.4f}")


In [None]:
# from datasets import Dataset
# Dataset.cleanup_cache_files

In [None]:
# from datasets import load_dataset
# ds = load_dataset("hirundo-io/Noisy-CIFAR-100")

In [None]:
# `ds` is only assigned by the commented-out `load_dataset(...)` cell above, so
# evaluating it on a fresh kernel raises NameError. Re-enable this line
# together with that cell.
# ds