<a href="https://colab.research.google.com/github/Vaishnavi-Hegde17/Deep_Learning-and-Gen_AI-Lab/blob/main/Week6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Week-6: Transfer Learning**

Implement the standard LeNet, AlexNet, and VGG CNN architectures to classify a multi-category image dataset.

MNIST handwritten digits (0-9)

Note down the accuracies obtained after 5, 50, and 250 epochs.



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm


In [2]:
# Every image is resized to 32x32 (the spatial size the LeNet, AlexNet and VGG
# definitions in this notebook are laid out for) and then converted to a tensor.
resize_to_32 = transforms.Resize((32, 32))
transform = transforms.Compose([resize_to_32, transforms.ToTensor()])
del resize_to_32  # keep the notebook namespace identical to the original cell

# Training split first, then the held-out test split; both use the same
# preprocessing and are downloaded into ./data on first use.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Mini-batches of 64; only the training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 59.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.73MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.97MB/s]


In [3]:
class LeNet(nn.Module):
    """LeNet-style classifier: 1-channel 32x32 images in, 10 class scores out."""

    def __init__(self):
        super().__init__()

        # Two conv -> tanh -> average-pool stages.
        # Spatial size: 32x32 -> 28x28 -> 14x14 -> 10x10 -> 5x5.
        feature_layers = [
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
        ]

        # Fully connected head on the flattened 16x5x5 feature maps.
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        ]

        # One flat Sequential so the layer indices (and therefore the
        # state_dict keys, e.g. "model.0.weight") stay as they were.
        self.model = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Run the batch ``x`` through the network and return the raw class scores."""
        scores = self.model(x)
        return scores


In [4]:
class AlexNet(nn.Module):
    """AlexNet-style CNN scaled down for 1-channel 32x32 images, 10 class scores out.

    All convolutions are 3x3 with stride 1 and padding 1, so only the three
    max-pool layers change the spatial size: 32x32 -> 16x16 -> 8x8 -> 4x4.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # Size-preserving 3x3 convolution followed by its ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ]

        def halve():
            # 2x2 max-pool with stride 2: halves height and width.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        layers = []

        # Stage 1: 32x32 -> 16x16
        layers += conv_relu(1, 64)
        layers.append(halve())

        # Stage 2: 16x16 -> 8x8
        layers += conv_relu(64, 192)
        layers.append(halve())

        # Stage 3: three stacked convolutions at 8x8, then 8x8 -> 4x4
        layers += conv_relu(192, 384)
        layers += conv_relu(384, 256)
        layers += conv_relu(256, 256)
        layers.append(halve())

        # Classifier on the flattened 256x4x4 feature maps.
        layers += [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 10),
        ]

        # Kept as one flat Sequential so every layer index is unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the batch ``x`` through the network and return the raw class scores."""
        scores = self.model(x)
        return scores

In [5]:
class VGG11(nn.Module):
    """VGG-11 style CNN for 1-channel 32x32 images, 10 class scores out.

    Eight 3x3 convolutions (padding 1) interleaved with five 2x2 max-pools,
    followed by a three-layer fully connected classifier.
    """

    def __init__(self):
        super().__init__()

        # Layer plan: an int is the output width of a conv+ReLU pair,
        # "pool" is a 2x2 max-pool that halves the spatial size.
        plan = [64, "pool", 128, "pool", 256, 256, "pool",
                512, 512, "pool", 512, 512, "pool"]

        feature_layers = []
        channels = 1  # single-channel input
        for step in plan:
            if step == "pool":
                feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            feature_layers.append(nn.Conv2d(channels, step, kernel_size=3, padding=1))
            feature_layers.append(nn.ReLU())
            channels = step
        self.features = nn.Sequential(*feature_layers)

        # The first linear layer takes 512 features, i.e. the 512 channels
        # left after five poolings of a 32x32 input (32 -> 1).
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 10),
        )

    def forward(self, x):
        """Extract features, flatten each sample, and return the raw class scores."""
        feature_maps = self.features(x)
        batch_size = feature_maps.shape[0]
        flat = feature_maps.view(batch_size, -1)
        return self.classifier(flat)


In [8]:
def train_and_evaluate(model, epochs, model_name, eval_epochs=(5, 50, 250),
                       train_data=None, test_data=None):
    """Train ``model`` and report its test accuracy at selected epochs.

    The model is optimised with Adam (lr=0.001) on a cross-entropy loss,
    on the GPU when one is available and on the CPU otherwise.

    Args:
        model: ``nn.Module`` that maps a batch of inputs to per-class scores.
        epochs: Number of full passes over the training batches.
        model_name: Label used as the prefix of each printed accuracy line.
        eval_epochs: 1-based epoch numbers after which test accuracy is
            measured. Entries larger than ``epochs`` are simply never reached.
            The default covers the 5 / 50 / 250 checkpoints this notebook's
            task asks for.
        train_data: Iterable of ``(inputs, targets)`` batches to train on.
            Defaults to the notebook-level ``train_loader``.
        test_data: Iterable of ``(inputs, targets)`` batches to evaluate on.
            Defaults to the notebook-level ``test_loader``.

    Returns:
        dict mapping each evaluated epoch number to the test accuracy in
        percent (0-100).
    """
    # Fall back to the notebook's global loaders only when no override is
    # given, so existing three-argument calls behave exactly as before.
    train_batches = train_loader if train_data is None else train_data
    test_batches = test_loader if test_data is None else test_data
    checkpoints = set(eval_epochs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    accuracies = {}

    for epoch in range(1, epochs + 1):
        # Training pass (train mode enables dropout).
        model.train()
        for data, targets in train_batches:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), targets)
            loss.backward()
            optimizer.step()

        if epoch not in checkpoints:
            continue

        # Evaluation pass: dropout disabled, no gradient tracking.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in test_batches:
                data, targets = data.to(device), targets.to(device)
                # Predicted class = index of the largest score per sample.
                predicted = model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        accuracies[epoch] = accuracy
        print(f"{model_name} Epoch {epoch}: Accuracy = {accuracy:.2f}%")

    return accuracies


In [None]:
# Train a fresh LeNet for 50 epochs and keep its per-checkpoint test accuracies.
lenet_acc = train_and_evaluate(model=LeNet(), epochs=50, model_name="LeNet")

LeNet Epoch 5: Accuracy = 98.39%


In [None]:
# Train a fresh AlexNet for 50 epochs and keep its per-checkpoint test accuracies.
alexnet_acc = train_and_evaluate(model=AlexNet(), epochs=50, model_name="AlexNet")

In [None]:
# Train a fresh VGG11 for 50 epochs and keep its per-checkpoint test accuracies.
vgg_acc = train_and_evaluate(model=VGG11(), epochs=50, model_name="VGG11")