# AlexNet Implementation

### 1. Loading AlexNet

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [3]:
transform = transforms.Compose([
    transforms.Resize((112, 112)),                # Smaller than 224x224 for faster CPU training
    transforms.Grayscale(num_output_channels=3), # Convert 1 channel → 3 channels
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])

In [4]:
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [5]:
full_dataset = torch.utils.data.ConcatDataset([dataset, test_dataset])
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_data, val_data, test_data = random_split(full_dataset, [train_size, val_size, test_size])

In [6]:
batch_size = 64

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [7]:
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}, Labels shape: {labels.shape}")

Batch shape: torch.Size([64, 3, 112, 112]), Labels shape: torch.Size([64])


### 2. AlexNet Transfer Learning

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
from torchvision.models import AlexNet_Weights

# Use the default pretrained weights
model = models.alexnet(weights=AlexNet_Weights.DEFAULT)

In [11]:
# Freeze feature extractor (convolutional layers)
for param in model.features.parameters():
    param.requires_grad = False

In [12]:
# Replace classifier for 10 MNIST classes
model.classifier[6] = nn.Linear(4096, 10)
model = model.to(device)

In [13]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)


In [14]:
num_epochs = 3  # MNIST converges quickly

for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print("Training complete!")

Epoch 1/3, Loss: 0.2577
Epoch 2/3, Loss: 0.1829
Epoch 3/3, Loss: 0.1571
Training complete!


In [15]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Validation Accuracy: {100 * correct / total:.2f}%")

Validation Accuracy: 97.90%


In [68]:
from PIL import Image

preprocess = transforms.Compose([
    transforms.Resize((112,112)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])

def predict_digit(image_path):
    image = Image.open(image_path).convert("L")
    input_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
        _, predicted = torch.max(outputs, 1)
    return predicted.item()

# Example usage:
image_path = 'test_imgs/9.png'  # Replace with your image path
predicted_digit = predict_digit(image_path)
print(f"Predicted Digit: {predicted_digit}")

Predicted Digit: 0
