<a href="https://colab.research.google.com/github/AnuruddhaPaul/Alex-Net-from-scratch/blob/main/Alex_Net_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install torchinfo



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import time
from torchinfo import summary
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
num_epochs = 5
num_classes = 10
batch_size = 64
learning_rate = 0.001

transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # MNIST Mean & Std
])

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

class AlexNet(nn.Module):
  def __init__(self, num_classes):
    super(AlexNet,self).__init__()
    self.features=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=96,kernel_size=11,stride=4,padding=0),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3,stride=2),

        nn.Conv2d(96,256,kernel_size=5,stride=1,padding=2),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3,stride=2),


        nn.Conv2d(256, 384, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),

        # Layer 4
        nn.Conv2d(384, 384, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),

        # Layer 5
        nn.Conv2d(384, 256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2),
    )
    # Adaptive pooling allows us to handle slight variations if needed,
    # ensuring the output is always 6x6 per filter before flattening.
    self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

    self.classifier = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(256 * 6 * 6, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )
  def forward(self,x):
    x=self.features(x)
    x=self.avgpool(x)
    x=torch.flatten(x,1)
    x=self.classifier(x)
    return x


model = AlexNet(num_classes=num_classes).to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

summary(model, input_size=(batch_size, 1, 227, 227))

print("Starting Training...")
total_step = len(train_loader)
start_time = time.time()

for epoch in range(num_epochs):
    model.train() # Set model to training mode (enables Dropout)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')

print(f"Training finished in {(time.time() - start_time)/60:.2f} minutes.")

# ---------------------------------------------------------
# 5. Testing / Evaluation
# ---------------------------------------------------------
model.eval() # Set model to evaluation mode (disables Dropout)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy of the model on the 10,000 test images: {100 * correct / total:.2f}%')

Using device: cuda
Starting Training...
Epoch [1/5], Step [100/938], Loss: 0.5497
Epoch [1/5], Step [200/938], Loss: 0.1472
Epoch [1/5], Step [300/938], Loss: 0.0885
Epoch [1/5], Step [400/938], Loss: 0.0561
Epoch [1/5], Step [500/938], Loss: 0.0843
Epoch [1/5], Step [600/938], Loss: 0.1273
Epoch [1/5], Step [700/938], Loss: 0.1546
Epoch [1/5], Step [800/938], Loss: 0.1071
Epoch [1/5], Step [900/938], Loss: 0.0576
Epoch [2/5], Step [100/938], Loss: 0.0234
Epoch [2/5], Step [200/938], Loss: 0.0608
Epoch [2/5], Step [300/938], Loss: 0.1127
Epoch [2/5], Step [400/938], Loss: 0.0129
Epoch [2/5], Step [500/938], Loss: 0.0683
Epoch [2/5], Step [600/938], Loss: 0.1542
Epoch [2/5], Step [700/938], Loss: 0.0635
Epoch [2/5], Step [800/938], Loss: 0.1348
Epoch [2/5], Step [900/938], Loss: 0.1242
Epoch [3/5], Step [100/938], Loss: 0.0255
Epoch [3/5], Step [200/938], Loss: 0.1226
Epoch [3/5], Step [300/938], Loss: 0.0272
Epoch [3/5], Step [400/938], Loss: 0.0542
Epoch [3/5], Step [500/938], Loss: 0

In [4]:
from torchinfo import summary
summary(model, input_size=(batch_size, 1, 227, 227))

Layer (type:depth-idx)                   Output Shape              Param #
AlexNet                                  [64, 10]                  --
├─Sequential: 1-1                        [64, 256, 6, 6]           --
│    └─Conv2d: 2-1                       [64, 96, 55, 55]          11,712
│    └─ReLU: 2-2                         [64, 96, 55, 55]          --
│    └─MaxPool2d: 2-3                    [64, 96, 27, 27]          --
│    └─Conv2d: 2-4                       [64, 256, 27, 27]         614,656
│    └─ReLU: 2-5                         [64, 256, 27, 27]         --
│    └─MaxPool2d: 2-6                    [64, 256, 13, 13]         --
│    └─Conv2d: 2-7                       [64, 384, 13, 13]         885,120
│    └─ReLU: 2-8                         [64, 384, 13, 13]         --
│    └─Conv2d: 2-9                       [64, 384, 13, 13]         1,327,488
│    └─ReLU: 2-10                        [64, 384, 13, 13]         --
│    └─Conv2d: 2-11                      [64, 256, 13, 13]      