<a href="https://colab.research.google.com/github/amaydixit11/research_papers/blob/main/AlexNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
val_dataset = datasets.CIFAR10(root='./data', train=False, transform=val_transform, download=True)

indices = list(range(len(train_dataset)))
train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)

# samplers
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# loaders
train_loader = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=2)
# val_loader = DataLoader(val_dataset, batch_size=128, sampler=val_sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:19<00:00, 8.95MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
class LRN(nn.Module):
    def __init__(self, size=5, alpha=1e-4, beta=0.75, k=2):
        super(LRN, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, x):
        div = x.pow(2).mean(dim=1, keepdim=True)
        div = (self.k + self.alpha * div).pow(self.beta)
        return x / div


In [4]:
class AlexNet(nn.Module):
  def __init__(self, num_classes):
    super(AlexNet, self).__init__()
    self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
    self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
    self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
    self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
    self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)

    self.fc6 = nn.Linear(256 * 6 * 6, 4096)
    self.fc7 = nn.Linear(4096, 4096)
    self.fc8 = nn.Linear(4096, num_classes)

    self.maxPool = nn.MaxPool2d(kernel_size = 3, stride = 2)
    self.global_pool = nn.AdaptiveAvgPool2d((6, 6))
    self.do = nn.Dropout(0.5)
    self.relu = nn.ReLU(inplace=True)

    self.features = nn.Sequential(
        self.conv1, self.relu, LRN(), self.maxPool,
        self.conv2, self.relu, LRN(), self.maxPool,
        self.conv3, self.relu,
        self.conv4, self.relu,
        self.conv5, self.relu, self.maxPool
    )
    self.classifier = nn.Sequential(
        self.fc6, self.relu, self.do,
        self.fc7, self.relu, self.do,
        self.fc8
        )

  def forward(self, x):
    out = self.features(x)
    out = self.global_pool(out)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

  def initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight, mean=0, std=0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 1)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, 1)



In [5]:
num_classes = 10
num_epochs = 30
batch_size = 64
learning_rate = 0.01

model = AlexNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.0005, momentum = 0.9)

# Train the model
total_step = len(train_loader)

In [6]:
from tqdm import tqdm

total_step = len(train_loader)

for epoch in range(num_epochs):
    # Training loop
    pbar = tqdm(enumerate(train_loader), total=total_step, desc=f"Epoch {epoch+1}/{num_epochs}")
    for i, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar
        pbar.set_postfix({"Loss": loss.item()})

    # Validation loop
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        torch.cuda.empty_cache()

    print(f"Validation Accuracy on {total} images: {100 * correct / total:.2f} %")


Epoch 1/30: 100%|██████████| 313/313 [01:35<00:00,  3.28it/s, Loss=2.09]


Validation Accuracy on 10000 images: 22.80 %


Epoch 2/30: 100%|██████████| 313/313 [01:32<00:00,  3.38it/s, Loss=1.76]


Validation Accuracy on 10000 images: 38.00 %


Epoch 3/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=1.5]


Validation Accuracy on 10000 images: 45.00 %


Epoch 4/30: 100%|██████████| 313/313 [01:32<00:00,  3.37it/s, Loss=1.2]


Validation Accuracy on 10000 images: 54.50 %


Epoch 5/30: 100%|██████████| 313/313 [01:37<00:00,  3.19it/s, Loss=1.14]


Validation Accuracy on 10000 images: 62.15 %


Epoch 6/30: 100%|██████████| 313/313 [01:32<00:00,  3.40it/s, Loss=1.17]


Validation Accuracy on 10000 images: 63.47 %


Epoch 7/30: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=0.915]


Validation Accuracy on 10000 images: 69.95 %


Epoch 8/30: 100%|██████████| 313/313 [01:33<00:00,  3.35it/s, Loss=0.867]


Validation Accuracy on 10000 images: 73.24 %


Epoch 9/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=0.799]


Validation Accuracy on 10000 images: 75.17 %


Epoch 10/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.655]


Validation Accuracy on 10000 images: 75.44 %


Epoch 11/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=0.494]


Validation Accuracy on 10000 images: 77.80 %


Epoch 12/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.384]


Validation Accuracy on 10000 images: 77.12 %


Epoch 13/30: 100%|██████████| 313/313 [01:42<00:00,  3.05it/s, Loss=0.456]


Validation Accuracy on 10000 images: 77.79 %


Epoch 14/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.463]


Validation Accuracy on 10000 images: 78.77 %


Epoch 15/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=0.554]


Validation Accuracy on 10000 images: 80.74 %


Epoch 16/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.433]


Validation Accuracy on 10000 images: 80.97 %


Epoch 17/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.41]


Validation Accuracy on 10000 images: 81.28 %


Epoch 18/30: 100%|██████████| 313/313 [01:32<00:00,  3.38it/s, Loss=0.447]


Validation Accuracy on 10000 images: 80.39 %


Epoch 19/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=0.359]


Validation Accuracy on 10000 images: 81.79 %


Epoch 20/30: 100%|██████████| 313/313 [01:34<00:00,  3.32it/s, Loss=0.449]


Validation Accuracy on 10000 images: 83.37 %


Epoch 21/30: 100%|██████████| 313/313 [01:32<00:00,  3.38it/s, Loss=0.452]


Validation Accuracy on 10000 images: 84.13 %


Epoch 22/30: 100%|██████████| 313/313 [01:33<00:00,  3.35it/s, Loss=0.235]


Validation Accuracy on 10000 images: 83.53 %


Epoch 23/30: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=0.258]


Validation Accuracy on 10000 images: 83.20 %


Epoch 24/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=0.239]


Validation Accuracy on 10000 images: 84.02 %


Epoch 25/30: 100%|██████████| 313/313 [01:37<00:00,  3.22it/s, Loss=0.162]


Validation Accuracy on 10000 images: 83.90 %


Epoch 26/30: 100%|██████████| 313/313 [01:32<00:00,  3.38it/s, Loss=0.245]


Validation Accuracy on 10000 images: 84.81 %


Epoch 27/30: 100%|██████████| 313/313 [01:34<00:00,  3.31it/s, Loss=0.213]


Validation Accuracy on 10000 images: 84.33 %


Epoch 28/30: 100%|██████████| 313/313 [01:36<00:00,  3.24it/s, Loss=0.136]


Validation Accuracy on 10000 images: 84.31 %


Epoch 29/30: 100%|██████████| 313/313 [01:34<00:00,  3.30it/s, Loss=0.238]


Validation Accuracy on 10000 images: 83.67 %


Epoch 30/30: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.182]


Validation Accuracy on 10000 images: 84.17 %
