In [1]:
import importlib
import data_handler  # Ensure the module is imported
importlib.reload(data_handler)  # Reload the module
from data_handler import DataHandler  # Re-import the class

import models  # Ensure the module is imported
importlib.reload(models)  # Reload the module
from models import Wide_ResNet28


import MixMo  # Ensure the module is imported
importlib.reload(MixMo)  # Reload the module
from MixMo import linear_mixmo,cut_mixmo  # Re-import the class


In [2]:
# Initialize DataHandler. The instance is deliberately NOT named `data_handler`:
# that name is the module imported/reloaded in the first cell, and shadowing it
# makes re-running that cell fail (importlib.reload rejects non-module objects).
BATCH_SIZE = 32
cifar_data_handler = DataHandler(data_root='./data')

# Load CIFAR-10 train/test loaders (which augmentations are applied is decided
# inside DataHandler.get_cifar10 — confirm there before describing them).
train_loader, test_loader = cifar_data_handler.get_cifar10(batch_size=BATCH_SIZE)


In [4]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.special import log_softmax, softmax
from torchmetrics.functional import accuracy

# Seed every RNG that influences weight init, dropout and (presumably)
# DataLoader shuffling so runs are reproducible. cuDNN kernels may still
# introduce some run-to-run nondeterminism on GPU.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Create the Wide-ResNet-28 baseline (widen factor 10). `train` below feeds the
# same batch to both model inputs and applies no MixMo mixing.
model = Wide_ResNet28(
    widen_factor=10,      # Controls network width
    dropout_rate=0.3,     # Dropout rate for regularization
    num_classes=10,       # Number of output classes for CIFAR-10
)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer with specified hyperparameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-4)  # Using 2e-4 as specified for baseline

# Learning rate scheduler: multiply LR by 0.1 at each milestone epoch
milestones = [100, 150]  # Common milestones for ResNet on CIFAR
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

# Training function
def train(epoch, log_interval=20):
    """Run one training epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``train_loader``, ``optimizer``,
    ``criterion`` and ``device``.

    Args:
        epoch: Epoch index; used only in the progress log.
        log_interval: Print running loss/accuracy every ``log_interval``
            batches (must be a positive integer).

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch, or
        ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # The model takes two inputs and returns a pair; for this baseline the
        # same batch is passed to both and only the first output is trained on.
        outputs, _ = model(inputs, inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % log_interval == 0:
            print(f'Epoch: {epoch} | Batch: {batch_idx} | Loss: {running_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}%')

    if total == 0:
        return 0.0, 0.0
    return running_loss / num_batches, 100. * correct / total

# Calculate negative log-likelihood
def calculate_nll(outputs, targets):
    """Mean negative log-likelihood of ``targets`` under softmax(``outputs``).

    Args:
        outputs: Logits tensor; assumed shape ``(N, C)`` — softmax is taken
            over dim 1.
        targets: Integer class-index tensor with one entry per row of
            ``outputs``.

    Returns:
        The mean NLL as a NumPy float64 scalar.
    """
    # Work in float64 with log_softmax (log-sum-exp) instead of
    # log(softmax(x) + eps): the epsilon capped every sample's NLL at ~23, and
    # float32 softmax underflows to 0 for confident mistakes — both understate NLL.
    # detach() lets this accept tensors that still track gradients.
    logits = outputs.detach().cpu().numpy().astype(np.float64)
    target_indices = targets.detach().cpu().numpy()
    log_probs = log_softmax(logits, axis=1)
    return -np.mean(log_probs[np.arange(len(target_indices)), target_indices])

# Testing function with Top-1, Top-5, and NLL metrics
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and print/return summary metrics.

    Uses the notebook-level ``model``, ``test_loader``, ``criterion`` and
    ``device``.

    Args:
        epoch: Epoch index; used only in the log line.

    Returns:
        ``(top1_acc, top5_acc, nll)``. The accuracies are Python floats
        holding fractions (printed ×100 as percentages). ``nll`` is the value
        returned by ``calculate_nll`` over the whole test set.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    logits_batches = []
    target_batches = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs, _ = model(inputs, inputs)
            loss = criterion(outputs, targets)

            # criterion returns a per-batch mean; weight it by batch size so a
            # smaller final batch does not skew the epoch-level average.
            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size
            logits_batches.append(outputs.cpu())
            target_batches.append(targets.cpu())

    # Concatenate all batches
    all_outputs = torch.cat(logits_batches)
    all_targets = torch.cat(target_batches)
    num_classes = all_outputs.size(1)  # taken from the logits instead of hard-coding 10

    # Calculate metrics (converted to floats so callers don't hold tensors)
    top1_acc = accuracy(all_outputs, all_targets, task="multiclass", num_classes=num_classes, top_k=1).item()
    top5_acc = accuracy(all_outputs, all_targets, task="multiclass", num_classes=num_classes, top_k=5).item()
    nll = calculate_nll(all_outputs, all_targets)

    print(f'Test | Epoch: {epoch} | Loss: {loss_sum/num_samples:.3f}')
    print(f'Top-1 Accuracy: {top1_acc*100:.2f}%')
    print(f'Top-5 Accuracy: {top5_acc*100:.2f}%')
    print(f'NLL: {nll:.4f}')

    return top1_acc, top5_acc, nll

# Train the model for 200 epochs as specified for baseline models.
# NOTE(review): the "best" checkpoint is chosen on the test set, so the best
# metrics reported below are optimistically biased; use a held-out validation
# split if an unbiased estimate is needed.
total_epochs = 200  # 200 epochs for baseline models as specified
checkpoint_path = 'best_vanilla_cifar10_run1.pth'
best_acc = 0.0
best_epoch = None
best_metrics = None

for epoch in range(total_epochs):
    train(epoch)
    top1_acc, top5_acc, nll = test(epoch)
    # Stepped once per epoch, after that epoch's optimizer updates.
    scheduler.step()

    # Save the best model based on Top-1 accuracy. float() stores plain
    # numbers whether test() returns floats or 0-d tensors.
    if top1_acc > best_acc:
        best_acc = float(top1_acc)
        best_epoch = epoch
        best_metrics = (best_acc, float(top5_acc), float(nll))
        print(f'Saving best model with Top-1 accuracy: {best_acc*100:.2f}%')
        torch.save(model.state_dict(), checkpoint_path)

# Print final best results
if best_metrics is not None:
    best_top1, best_top5, best_nll = best_metrics
    print("\n=== Best Model Results ===")
    print(f'Best epoch: {best_epoch}')
    print(f'Top-1 Accuracy: {best_top1*100:.2f}%')
    print(f'Top-5 Accuracy: {best_top5*100:.2f}%')
    print(f'NLL: {best_nll:.4f}')

| Wide-ResNet 28x10 with dual encoders using none augmentation
Epoch: 0 | Batch: 0 | Loss: 2.334 | Acc: 6.250%
Epoch: 0 | Batch: 20 | Loss: 2.844 | Acc: 15.476%
Epoch: 0 | Batch: 40 | Loss: 2.644 | Acc: 15.320%


KeyboardInterrupt: 