# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import os

# 1. Define the AlexNet Model

class AlexNet(nn.Module):
    """AlexNet image classifier (Krizhevsky et al., 2012), single-tower variant.

    Five convolutional layers (with optional local response normalization and
    overlapping 3x3 / stride-2 max pooling) are followed by three fully
    connected layers, the first two preceded by dropout.

    For a 224x224 RGB input the feature extractor produces a (256, 6, 6) map
    (spatial sizes 55 -> 27 -> 13 -> 6). The adaptive average pool then leaves
    it unchanged. Inputs of other sizes are pooled to 6x6, so the classifier
    always receives 256 * 6 * 6 features.

    Args:
        num_classes: number of output logits (1000 for ImageNet).
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            # conv1: padding=2 turns a 224x224 input into a 55x55 map. Without
            # it the map is 54x54, the last feature map shrinks to 5x5, and
            # AdaptiveAvgPool2d would have to stretch it back up to 6x6.
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # conv1
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),  # lrn1 - Optional, can be removed for modern implementations
            nn.MaxPool2d(kernel_size=3, stride=2),  # pool1
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # conv2 (padding=2 keeps spatial size at stride 1)
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2),  # lrn2 - Optional
            nn.MaxPool2d(kernel_size=3, stride=2),  # pool2
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # conv3 (padding=1 keeps spatial size at stride 1)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # conv4
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # conv5
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # pool3
        )
        # Guarantees a 6x6 map for the classifier regardless of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),  # dropout before fc6
            nn.Linear(256 * 6 * 6, 4096),  # fc6
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),  # dropout before fc7
            nn.Linear(4096, 4096),  # fc7
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # fc8: raw class logits
        )

    def forward(self, x):
        """Return logits of shape (N, num_classes) for an image batch x of shape (N, 3, H, W)."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten C*H*W
        return self.classifier(x)

# 2. Data Loading and Preprocessing (Example - using a dummy dataset)

class DummyImageDataset(Dataset):  # Replace with your actual ImageNet or dataset loading
    """Synthetic classification dataset backed by solid-red 256x256 JPEG files.

    On construction, ``num_samples`` images named ``image_{i}.jpg`` are
    written to ``root_dir`` (files that already exist are reused). Each sample
    is assigned one random label in ``[0, num_classes)`` at construction time,
    so reading the same index twice returns the same label.

    Args:
        root_dir: directory holding the dummy images; created if missing.
        transform: optional callable applied to the loaded PIL image.
        num_samples: number of samples (and image files).
        num_classes: labels are drawn uniformly from ``range(num_classes)``.
    """

    def __init__(self, root_dir, transform=None, num_samples=1000, num_classes=1000):
        self.root_dir = root_dir
        self.transform = transform
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.image_paths = [os.path.join(root_dir, f'image_{i}.jpg') for i in range(num_samples)]
        # Draw labels once. Drawing them in __getitem__ gave the same image a
        # different label on every access, so nothing consistent could be learned.
        # The global NumPy RNG is used, so np.random.seed() still controls them.
        self.labels = np.random.randint(0, num_classes, size=num_samples)
        os.makedirs(root_dir, exist_ok=True)
        for path in self.image_paths:
            if not os.path.exists(path):
                Image.new('RGB', (256, 256), color='red').save(path)  # Create dummy image if not exist

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        # convert() loads the pixel data into a new image, so the file can be closed here.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = int(self.labels[idx])

        if self.transform:
            image = self.transform(image)

        return image, label


# Data Augmentation Transforms
# NOTE: RandomResizedCrop is the Inception-style random area/aspect-ratio crop,
# not the exact AlexNet scheme. The paper resized images to 256x256 and took
# random 224x224 patches plus horizontal reflections. A closer match is
# transforms.Resize((256, 256)) followed by transforms.RandomCrop(224).
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Random-area, random-aspect crop, resized to 224x224
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.ToTensor(),  # PIL image -> float tensor, scaled to [0, 1]
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet normalization - if needed for real ImageNet
])

val_transform = transforms.Compose([
    transforms.Resize(256),  # Resize the SHORTER side to 256, keeping aspect ratio (not 256x256)
    transforms.CenterCrop(224),  # Center crop to 224x224
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet normalization - if needed for real ImageNet
])


# Create Dummy Dataset and DataLoaders (Replace with your actual data)
num_classes = 1000  # ImageNet-1k label count

# Each dataset creates its image folder and any missing dummy images on construction.
train_dataset = DummyImageDataset(
    'dummy_train_images',
    transform=train_transform,
    num_samples=1000,
    num_classes=num_classes,
)
val_dataset = DummyImageDataset(
    'dummy_val_images',
    transform=val_transform,
    num_samples=200,
    num_classes=num_classes,
)

batch_size = 128  # Mini-batch size quoted from the AlexNet paper
# Two worker processes load batches in parallel. Only the training loader
# shuffles; the validation order stays fixed.
# NOTE(review): with spawn-based multiprocessing (e.g. Windows, macOS), workers
# must be able to import DummyImageDataset. Classes defined inside a notebook
# may fail there; fall back to num_workers=0 if so.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


# 3. Set up Training Parameters, Optimizer, and Learning Rate Scheduler

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network and move its parameters to `device`. nn.Module.to()
# returns the module itself, so the call can be chained.
model = AlexNet(num_classes=num_classes).to(device)

# Classification loss computed on the raw logits returned by AlexNet.forward.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay. `lr` is only the starting value;
# adjust_learning_rate() rewrites it at the start of every epoch.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)


# Learning Rate Scheduling (Manual decay as in paper)
def adjust_learning_rate(optimizer, epoch, initial_lr=0.01, milestones=(30, 60, 80), factor=10):
    """Apply a fixed step-decay schedule to ``optimizer`` and return the new rate.

    The rate starts at ``initial_lr`` and is divided by ``factor`` once for
    every milestone epoch that ``epoch`` has reached. The schedule is fixed in
    advance and does not look at validation error. To mimic the paper's
    "decay on plateau", choose ``milestones`` from observed validation curves.

    Args:
        optimizer: optimizer whose param groups all receive the new rate.
        epoch: current zero-based epoch index.
        initial_lr: learning rate before any decay.
        milestones: epochs at which the rate is divided by ``factor``.
        factor: divisor applied for each milestone reached.

    Returns:
        The learning rate written to every param group.
    """
    lr = initial_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr /= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


# 4. Training Loop

num_epochs = 90  # As mentioned in the paper (roughly)
# Read the starting LR from the optimizer so the schedule cannot silently
# disagree with the value passed to optim.SGD.
initial_lr = optimizer.defaults['lr']


def evaluate(model, loader, criterion, device):
    """Return ``(mean per-sample loss, accuracy in percent)`` of ``model`` on ``loader``.

    Each batch's loss is weighted by its size, so a short final batch does not
    skew the mean. Assumes ``criterion`` averages over the batch (the default
    ``reduction='mean'``). Returns ``(nan, nan)`` for an empty loader instead
    of dividing by zero. Leaves ``model`` in eval mode.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd bookkeeping during validation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size_seen = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size_seen
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size_seen
    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100.0 * correct / total


for epoch in range(num_epochs):
    current_lr = adjust_learning_rate(optimizer, epoch, initial_lr)  # Set this epoch's learning rate
    model.train()  # enable dropout
    running_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if i % 10 == 9:  # Print every 10 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10:.3f} LR: {current_lr}')
            running_loss = 0.0

    # Epoch summary: with few batches per epoch (8 for the dummy data) the
    # every-10-batches print above never fires.
    if num_batches:
        print(f'Epoch {epoch + 1} Training Loss: {epoch_loss / num_batches:.3f}')

    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch+1} Validation Loss: {val_loss:.3f}, Validation Accuracy: {val_acc:.2f}%')


print('Finished Training')