In [9]:
import copy
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.models import resnet50, ResNet50_Weights

In [10]:
# Per-channel normalisation constants shared by the train and test pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (with 4px padding) and horizontal flip as
# augmentation, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# CIFAR-10 train and test sets, downloaded into ./data when missing.
cifar_kwargs = dict(root='./data', download=True)
train_dataset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **cifar_kwargs)

# The test loader is shared by every split, so it is built once here.
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-50 with the default weight set.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Freeze every existing parameter; only the new head created below is trainable.
model.requires_grad_(False)

# Swap the classifier head for a fresh 10-way linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

model = model.to(device)

In [12]:
def train(model, train_loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable of ``(inputs, targets)`` batches. Must also
            expose ``.dataset`` (used only for the progress message).
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            Assumed to return the *mean* loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch index, used only in the progress message.

    Returns:
        Tuple of the per-sample mean loss over the epoch and the accuracy in
        percent.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = targets.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch-mean loss by its sample count so a smaller final
        # batch does not get the same weight as a full one.
        running_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            # Samples consumed before this batch; stays correct even when the
            # current batch is shorter than the others.
            seen = total - batch_size
            print(f'Epoch: {epoch} [{seen}/{len(train_loader.dataset)}] Loss: {loss.item():.4f} Acc: {100. * correct / total:.2f}%')

    return running_loss / total, 100. * correct / total

# Evaluation function
def evaluate(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        data_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            Assumed to return the *mean* loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple of the per-sample mean loss and the accuracy in percent.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = targets.size(0)

            outputs = model(inputs)
            # Weight the batch-mean loss by the batch size so that a short
            # final batch does not skew the dataset-level average.
            loss += criterion(outputs, targets).item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    return loss / total, 100. * correct / total

# Function to set random seeds
def set_seed(seed):
    """Seed every random number generator this notebook may draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available).

    Args:
        seed: Integer seed applied to each generator.
    """
    # The stdlib generator is seeded as well so that any code drawing from
    # ``random`` is reproducible too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [13]:
# Experiment configuration: number of random train/val splits, the epoch
# budget per split, and the early-stopping patience (epochs without a new
# best validation loss).
num_splits = 10
num_epochs = 50
patience = 5

# Per-split metrics, each taken at the checkpointed (best validation) epoch.
all_train_losses = []
all_train_accs = []
all_val_losses = []
all_val_accs = []
all_test_losses = []
all_test_accs = []

# Validation view of the training images: a shallow copy shares the underlying
# data but swaps the augmenting transform for the deterministic preprocessing
# used by the test loader, so validation loss is not perturbed by random
# crops/flips.
val_base_dataset = copy.copy(train_dataset)
val_base_dataset.transform = test_loader.dataset.transform

for split in range(num_splits):
    set_seed(split)

    # Split the training dataset into training and validation sets (80/20).
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset_split, val_index_split = random_split(train_dataset, [train_size, val_size])
    # Re-use the held-out indices on the augmentation-free dataset view.
    val_dataset_split = Subset(val_base_dataset, val_index_split.indices)

    train_loader = DataLoader(train_dataset_split, batch_size=128, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset_split, batch_size=100, shuffle=False, num_workers=2)

    # Reload the pretrained ResNet-50 model for each split so that no state
    # leaks from one split to the next.
    # NOTE(review): train() calls model.train(), so BatchNorm running
    # statistics in the frozen backbone still update — confirm this is intended.
    model = resnet50(weights=weights)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Linear(num_ftrs, 10)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Only the new classifier head is optimised.
    optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

    checkpoint_path = f'best_model_split_{split}.pth'
    best_val_loss = float('inf')
    best_metrics = None  # (train_loss, train_acc, val_loss, val_acc) at the best epoch
    patience_counter = 0

    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, epoch)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        end_time = time.time()

        print(f'Split {split}, Epoch {epoch} took {(end_time - start_time):.2f} seconds')
        print(f'Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}%')

        # Check for early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_metrics = (train_loss, train_acc, val_loss, val_acc)
            patience_counter = 0
            # Save the best model
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Without this guard a checkpoint left over from an earlier run could be
    # loaded silently when no epoch ever improved (e.g. NaN validation loss).
    if best_metrics is None:
        raise RuntimeError(f'Split {split}: no epoch produced a finite best validation loss')

    # Load the best model before testing
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Evaluate on the test set
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f'Split {split}, Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}%')

    # Store metrics from the checkpointed epoch so that train/val numbers
    # describe the same weights as the test numbers.
    best_train_loss, best_train_acc, best_epoch_val_loss, best_epoch_val_acc = best_metrics
    all_train_losses.append(best_train_loss)
    all_train_accs.append(best_train_acc)
    all_val_losses.append(best_epoch_val_loss)
    all_val_accs.append(best_epoch_val_acc)
    all_test_losses.append(test_loss)
    all_test_accs.append(test_acc)

# Print average metrics over all splits
print(f'Average Train Loss: {np.mean(all_train_losses):.4f} Train Acc: {np.mean(all_train_accs):.2f}%')
print(f'Average Val Loss: {np.mean(all_val_losses):.4f} Val Acc: {np.mean(all_val_accs):.2f}%')
print(f'Average Test Loss: {np.mean(all_test_losses):.4f} Test Acc: {np.mean(all_test_accs):.2f}%')

Epoch: 0 [0/40000] Loss: 2.3613 Acc: 11.72%
Epoch: 0 [12800/40000] Loss: 2.2366 Acc: 15.12%


KeyboardInterrupt: 