<a href="https://colab.research.google.com/github/amaydixit11/research_papers/blob/main/AlexNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

# Use CUDA when PyTorch reports it as available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# --- Preprocessing -----------------------------------------------------------
# NOTE(review): these per-channel statistics look like the commonly used ImageNet
# values rather than CIFAR-10's own; kept unchanged so results stay comparable.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training: resize, then a random 224x224 crop and random horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize])

# Evaluation: same resize but a deterministic centre crop and no flipping.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize])

# --- Datasets ----------------------------------------------------------------
# `train_dataset` and `val_dataset` wrap the SAME training images and differ only
# in preprocessing, so the held-out validation samples are never augmented.
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
val_dataset = datasets.CIFAR10(root='./data', train=True, transform=val_transform, download=True)
# The official test split is kept separate for a final, untouched evaluation.
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=val_transform, download=True)

# --- Train / validation split --------------------------------------------------
# 80/20 split of the training indices; the fixed random_state makes it reproducible.
indices = list(range(len(train_dataset)))
train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)

# samplers
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# loaders
train_loader = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=2)
# Validation now really uses the held-out 20% of the training set. The indices are
# valid here because `val_dataset` has the same length and ordering as `train_dataset`
# (applying them to the 10,000-image test split would go out of range).
val_loader = DataLoader(val_dataset, batch_size=128, sampler=val_sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 35.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
class LRN(nn.Module):
    """Local response normalization across neighbouring channels.

    Every activation is divided by ``(k + alpha * m) ** beta`` where ``m`` is the
    mean of the squared activations inside a window of ``size`` adjacent channels
    around the activation's own channel (the channel axis is zero-padded at both
    ends, and the mean is always taken over ``size`` entries).

    Args:
        size: number of adjacent channels in the normalization window.
        alpha: scale applied to the windowed mean of squares.
        beta: exponent of the denominator.
        k: additive constant of the denominator.

    Shape:
        Input ``(N, C, *)`` with at least three dimensions; output has the same shape.
    """

    def __init__(self, size=5, alpha=1e-4, beta=0.75, k=2):
        super(LRN, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, x):
        # The earlier implementation averaged the squares over *all* channels
        # (`mean(dim=1)`), which made the normalization global and left `size`
        # without any effect. Delegating to PyTorch's functional LRN restores the
        # sliding window over `size` neighbouring channels.
        return F.local_response_norm(x, self.size, self.alpha, self.beta, self.k)


In [None]:
class AlexNet(nn.Module):
    """AlexNet-style image classifier: five convolutions followed by three linear layers.

    The convolutional stack (`features`) interleaves ReLU, LRN and max pooling;
    its output is adaptively average-pooled to 6x6 so that the flattened vector
    always has the 256 * 6 * 6 entries `fc6` expects. The fully connected stack
    (`classifier`) applies dropout after each of the first two linear layers.

    Args:
        num_classes: number of output logits produced by the last linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Convolutions (in_channels, out_channels, kernel, stride, padding).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)

        # Fully connected layers.
        self.fc6 = nn.Linear(in_features=256 * 6 * 6, out_features=4096)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096)
        self.fc8 = nn.Linear(in_features=4096, out_features=num_classes)

        # Parameter-free layers; a single instance of each is reused at every
        # position where it appears in the stacks below.
        self.maxPool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.global_pool = nn.AdaptiveAvgPool2d((6, 6))
        self.do = nn.Dropout(0.5)
        self.relu = nn.ReLU(inplace=True)

        feature_layers = [
            self.conv1, self.relu, LRN(), self.maxPool,
            self.conv2, self.relu, LRN(), self.maxPool,
            self.conv3, self.relu,
            self.conv4, self.relu,
            self.conv5, self.relu, self.maxPool,
        ]
        self.features = nn.Sequential(*feature_layers)

        classifier_layers = [
            self.fc6, self.relu, self.do,
            self.fc7, self.relu, self.do,
            self.fc8,
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        """Return the class logits for a batch of images `x`."""
        pooled = self.global_pool(self.features(x))
        # Collapse everything except the batch dimension before the linear layers.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

    def initialize_weights(self):
        """Re-initialize conv/linear weights from N(0, 0.01) and set their biases to 1."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.normal_(layer.weight, mean=0, std=0.01)
            # Only convolutions are checked for a missing bias; linear layers are
            # assumed to have one.
            if isinstance(layer, nn.Linear) or layer.bias is not None:
                nn.init.constant_(layer.bias, 1)



In [None]:
# Training configuration
num_classes = 10
num_epochs = 30
# The mini-batch size is fixed where the DataLoader is built, so read it back from
# the loader instead of hard-coding a second value (the old literal 64 did not
# match the batch size the loader actually uses).
batch_size = train_loader.batch_size
learning_rate = 0.005

model = AlexNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.005, momentum=0.9)

# Number of mini-batches per training epoch (used for progress reporting)
total_step = len(train_loader)

In [None]:
from tqdm import tqdm

# Number of mini-batches per epoch, used as the progress-bar length.
total_step = len(train_loader)

for epoch in range(num_epochs):
    # Training loop
    # Dropout has to be active while fitting. Validation below switches the model
    # to eval mode, so training mode must be restored at the start of every epoch.
    model.train()
    pbar = tqdm(train_loader, total=total_step, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar with the loss of the current mini-batch
        pbar.set_postfix({"Loss": loss.item()})

    # Validation loop
    # eval() disables dropout so the reported accuracy is deterministic and not
    # degraded by randomly zeroed activations; no_grad() skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        torch.cuda.empty_cache()

    print(f"Validation Accuracy on {total} images: {100 * correct / total:.2f} %")


Epoch 1/30: 100%|██████████| 313/313 [01:33<00:00,  3.36it/s, Loss=2.3]


Validation Accuracy on 10000 images: 10.91 %


Epoch 2/30: 100%|██████████| 313/313 [01:42<00:00,  3.07it/s, Loss=2.12]


Validation Accuracy on 10000 images: 22.56 %


Epoch 3/30: 100%|██████████| 313/313 [01:35<00:00,  3.26it/s, Loss=1.85]


Validation Accuracy on 10000 images: 28.86 %


Epoch 4/30: 100%|██████████| 313/313 [01:34<00:00,  3.32it/s, Loss=1.54]


Validation Accuracy on 10000 images: 36.09 %


Epoch 5/30: 100%|██████████| 313/313 [01:34<00:00,  3.30it/s, Loss=1.44]


Validation Accuracy on 10000 images: 40.84 %


Epoch 6/30:  12%|█▏        | 38/313 [00:11<02:02,  2.24it/s, Loss=1.63]

Epoch 1/20: 100%|██████████| 313/313 [01:35<00:00,  3.27it/s, Loss=2.3]
Validation Accuracy on 10000 images: 12.23 %
Epoch 2/20: 100%|██████████| 313/313 [01:34<00:00,  3.32it/s, Loss=2.12]
Validation Accuracy on 10000 images: 21.94 %
Epoch 3/20: 100%|██████████| 313/313 [01:35<00:00,  3.26it/s, Loss=1.81]
Validation Accuracy on 10000 images: 28.26 %
Epoch 4/20: 100%|██████████| 313/313 [01:35<00:00,  3.27it/s, Loss=1.56]
Validation Accuracy on 10000 images: 34.84 %
Epoch 5/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=1.48]
Validation Accuracy on 10000 images: 38.61 %
Epoch 6/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=1.52]
Validation Accuracy on 10000 images: 46.43 %
Epoch 7/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=1.4]
Validation Accuracy on 10000 images: 50.36 %
Epoch 8/20: 100%|██████████| 313/313 [01:35<00:00,  3.27it/s, Loss=1.53]
Validation Accuracy on 10000 images: 53.06 %
Epoch 9/20: 100%|██████████| 313/313 [01:35<00:00,  3.26it/s, Loss=1.15]
Validation Accuracy on 10000 images: 56.58 %
Epoch 10/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=1.31]
Validation Accuracy on 10000 images: 53.98 %
Epoch 11/20: 100%|██████████| 313/313 [01:34<00:00,  3.30it/s, Loss=1.28]
Validation Accuracy on 10000 images: 58.68 %
Epoch 12/20: 100%|██████████| 313/313 [01:34<00:00,  3.33it/s, Loss=0.983]
Validation Accuracy on 10000 images: 60.24 %
Epoch 13/20: 100%|██████████| 313/313 [01:34<00:00,  3.31it/s, Loss=0.963]
Validation Accuracy on 10000 images: 59.22 %
Epoch 14/20: 100%|██████████| 313/313 [01:35<00:00,  3.26it/s, Loss=1.09]
Validation Accuracy on 10000 images: 63.26 %
Epoch 15/20: 100%|██████████| 313/313 [01:35<00:00,  3.28it/s, Loss=1.05]
Validation Accuracy on 10000 images: 66.60 %
Epoch 16/20: 100%|██████████| 313/313 [01:33<00:00,  3.34it/s, Loss=0.973]
Validation Accuracy on 10000 images: 67.23 %
Epoch 17/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=0.821]
Validation Accuracy on 10000 images: 68.21 %
Epoch 18/20: 100%|██████████| 313/313 [01:35<00:00,  3.27it/s, Loss=0.829]
Validation Accuracy on 10000 images: 68.63 %
Epoch 19/20: 100%|██████████| 313/313 [01:35<00:00,  3.29it/s, Loss=0.78]
Validation Accuracy on 10000 images: 68.33 %
Epoch 20/20: 100%|██████████| 313/313 [01:35<00:00,  3.27it/s, Loss=0.993]
Validation Accuracy on 10000 images: 71.56 %