In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import os
import copy
import time
import numpy as np
import random

# Set the random seed for reproducibility
def set_seed(seed=71):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

In [8]:
# For training: resize, then random crop, and apply horizontal flip and normalization.
train_transforms = transforms.Compose([
    transforms.Resize(256),                    # Resize to 256 on the shorter side
    transforms.RandomCrop(224),                # Random crop to 224x224
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Use ImageNet normalization if using a pretrained AlexNet
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# For validation: resize, then center crop, and apply normalization.
test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load CIFAR10 dataset using torchvision.datasets
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)

# Split the dataset into 80% training and 20% validation using a fixed seed
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size],
                                            generator=torch.Generator().manual_seed(10))
# Update validation dataset to use test transforms
val_dataset.dataset.transform = test_transforms

# Create DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)

dataloaders = {'train': train_loader, 'val': val_loader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}

Files already downloaded and verified


In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load AlexNet model; use pretrained weights (or set pretrained=False to train from scratch)
model = models.alexnet(pretrained=True)
# Modify the final fully connected layer to output 10 classes (for CIFAR10)
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 10)
model = model.to(device)

# Define loss criterion and optimizer (using Adam)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler: Reduce learning rate on plateau
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

In [10]:
# Early stopping parameters
patience = 5  # epochs to wait without improvement
best_loss = float('inf')
best_model_wts = copy.deepcopy(model.state_dict())
epochs_no_improve = 0
num_epochs = 30  # maximum number of epochs

since = time.time()
print("Starting training ...")

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)
    
    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # set model to training mode
        else:
            model.eval()   # set model to evaluation mode
        
        running_loss = 0.0
        running_corrects = 0
        
        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass; track gradients only in train phase
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]
        
        print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Validate and update scheduler
        if phase == 'val':
            scheduler.step(epoch_loss)
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
                print("Validation loss decreased, saving model ...")
            else:
                epochs_no_improve += 1
    
    # Early stopping check
    if epochs_no_improve == patience:
        print("Early stopping triggered.")
        break
    
    print()

time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best Validation Loss: {best_loss:.4f}')

# Load best model weights
model.load_state_dict(best_model_wts)

Starting training ...
Epoch 1/30
----------
Train Loss: 1.5629 Acc: 0.4244
Val Loss: 1.2736 Acc: 0.5500
Validation loss decreased, saving model ...

Epoch 2/30
----------
Train Loss: 1.0190 Acc: 0.6454
Val Loss: 0.9539 Acc: 0.6636
Validation loss decreased, saving model ...

Epoch 3/30
----------
Train Loss: 0.8283 Acc: 0.7144
Val Loss: 0.7318 Acc: 0.7500
Validation loss decreased, saving model ...

Epoch 4/30
----------
Train Loss: 0.7154 Acc: 0.7551
Val Loss: 0.7511 Acc: 0.7468

Epoch 5/30
----------
Train Loss: 0.6316 Acc: 0.7849
Val Loss: 0.7066 Acc: 0.7606
Validation loss decreased, saving model ...

Epoch 6/30
----------
Train Loss: 0.5763 Acc: 0.8031
Val Loss: 0.6459 Acc: 0.7859
Validation loss decreased, saving model ...

Epoch 7/30
----------
Train Loss: 0.5397 Acc: 0.8164
Val Loss: 0.6511 Acc: 0.7832

Epoch 8/30
----------
Train Loss: 0.4905 Acc: 0.8319
Val Loss: 0.6236 Acc: 0.7954
Validation loss decreased, saving model ...

Epoch 9/30
----------
Train Loss: 0.4737 Acc: 0.83

<All keys matched successfully>

In [11]:
# Save the best model's weights
torch.save(model.state_dict(), 'alexnet.pth')
print("Model saved as alexnet.pth")

Model saved as alexnet.pth
