In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((227,227)),
    transforms.ToTensor(),
    transforms.Normalize(
        (0.485,0.456,0.406),
        (0.229,0.224,0.225)
    )
])

train_ds = CIFAR100(root="./data", train=True, download=True, transform=transform)
test_ds  = CIFAR100(root="./data", train=False, download=True, transform=transform)

  entry = pickle.load(f, encoding="latin1")


In [3]:
train_dataloader = DataLoader(
    train_ds,
    batch_size = 128,
    shuffle = True,
    num_workers=3
)
test_dataloader = DataLoader(
    test_ds,
    batch_size = 128,
    shuffle = False,
    num_workers=3
)

In [4]:
class AlexNet(nn.Module):
    def __init__(self, num_classes = 100):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride =2, padding=0),
            
            
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            
            
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)    
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=256*6*6, out_features=4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            
            nn.Linear(in_features=4096, out_features=num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x,1)
        x = self.classifier(x)
        return x

In [5]:
model = AlexNet()
x = torch.randn(1,3,227,227)
print(model.features(x).shape)

torch.Size([1, 256, 6, 6])


In [6]:
loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4
)

EPOCHS  = 30

In [8]:
from tqdm import tqdm
train_losses = []
train_accs = []

for EPOCH in range(EPOCHS):

    model.train()

    running_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_dataloader, desc=f"Epoch {EPOCH+1}/{EPOCHS}"):

        optimizer.zero_grad()

        outputs = model(images)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(train_dataloader)
    epoch_acc = correct / total

    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)

    print(f"Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")

Epoch 1/30:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1/30: 100%|██████████| 391/391 [16:00<00:00,  2.46s/it]


Loss: 4.5595 | Acc: 0.0157


Epoch 2/30: 100%|██████████| 391/391 [15:40<00:00,  2.41s/it]


Loss: 3.9912 | Acc: 0.0811


Epoch 3/30: 100%|██████████| 391/391 [15:48<00:00,  2.42s/it]


Loss: 3.5522 | Acc: 0.1541


Epoch 4/30: 100%|██████████| 391/391 [15:53<00:00,  2.44s/it]


Loss: 3.1796 | Acc: 0.2226


Epoch 5/30: 100%|██████████| 391/391 [15:37<00:00,  2.40s/it]


Loss: 2.8747 | Acc: 0.2846


Epoch 6/30: 100%|██████████| 391/391 [15:30<00:00,  2.38s/it]


Loss: 2.5875 | Acc: 0.3419


Epoch 7/30: 100%|██████████| 391/391 [15:21<00:00,  2.36s/it]


Loss: 2.3323 | Acc: 0.3961


Epoch 8/30: 100%|██████████| 391/391 [15:46<00:00,  2.42s/it]


Loss: 2.1160 | Acc: 0.4407


Epoch 9/30: 100%|██████████| 391/391 [15:25<00:00,  2.37s/it]


Loss: 1.9067 | Acc: 0.4860


Epoch 10/30: 100%|██████████| 391/391 [15:31<00:00,  2.38s/it]


Loss: 1.7116 | Acc: 0.5317


Epoch 11/30: 100%|██████████| 391/391 [15:36<00:00,  2.40s/it]


Loss: 1.5314 | Acc: 0.5728


Epoch 12/30: 100%|██████████| 391/391 [15:50<00:00,  2.43s/it]


Loss: 1.3589 | Acc: 0.6145


Epoch 13/30: 100%|██████████| 391/391 [15:34<00:00,  2.39s/it]


Loss: 1.2281 | Acc: 0.6474


Epoch 14/30: 100%|██████████| 391/391 [15:36<00:00,  2.39s/it]


Loss: 1.0929 | Acc: 0.6809


Epoch 15/30: 100%|██████████| 391/391 [15:42<00:00,  2.41s/it]


Loss: 0.9762 | Acc: 0.7106


Epoch 16/30: 100%|██████████| 391/391 [15:52<00:00,  2.44s/it]


Loss: 0.8742 | Acc: 0.7415


Epoch 17/30: 100%|██████████| 391/391 [15:40<00:00,  2.41s/it]


Loss: 0.7939 | Acc: 0.7637


Epoch 18/30: 100%|██████████| 391/391 [15:35<00:00,  2.39s/it]


Loss: 0.7302 | Acc: 0.7785


Epoch 19/30: 100%|██████████| 391/391 [15:41<00:00,  2.41s/it]


Loss: 0.6656 | Acc: 0.8000


Epoch 20/30: 100%|██████████| 391/391 [15:48<00:00,  2.42s/it]


Loss: 0.6031 | Acc: 0.8178


Epoch 21/30: 100%|██████████| 391/391 [15:46<00:00,  2.42s/it]


Loss: 0.5575 | Acc: 0.8318


Epoch 22/30: 100%|██████████| 391/391 [15:38<00:00,  2.40s/it]


Loss: 0.5308 | Acc: 0.8404


Epoch 23/30: 100%|██████████| 391/391 [15:58<00:00,  2.45s/it]


Loss: 0.4833 | Acc: 0.8541


Epoch 24/30: 100%|██████████| 391/391 [15:47<00:00,  2.42s/it]


Loss: 0.4611 | Acc: 0.8614


Epoch 25/30: 100%|██████████| 391/391 [15:55<00:00,  2.44s/it]


Loss: 0.4212 | Acc: 0.8725


Epoch 26/30: 100%|██████████| 391/391 [15:50<00:00,  2.43s/it]


Loss: 0.4241 | Acc: 0.8725


Epoch 27/30: 100%|██████████| 391/391 [16:10<00:00,  2.48s/it]


Loss: 0.3899 | Acc: 0.8813


Epoch 28/30: 100%|██████████| 391/391 [16:02<00:00,  2.46s/it]


Loss: 0.3801 | Acc: 0.8848


Epoch 29/30: 100%|██████████| 391/391 [15:16<00:00,  2.34s/it]


Loss: 0.3599 | Acc: 0.8913


Epoch 30/30: 100%|██████████| 391/391 [15:21<00:00,  2.36s/it]

Loss: 0.3362 | Acc: 0.8986





In [9]:
from tqdm import tqdm

model.eval()   # switch to evaluation mode

test_loss = 0
correct = 0
total = 0

with torch.no_grad():   # disable gradients (important)

    for images, labels in tqdm(test_dataloader, desc="Testing"):

        outputs = model(images)

        loss = loss_fn(outputs, labels)
        test_loss += loss.item()

        _, preds = torch.max(outputs, 1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)


test_loss /= len(test_dataloader)
test_acc = correct / total

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Acc : {test_acc:.4f}")

Testing: 100%|██████████| 79/79 [01:29<00:00,  1.14s/it]

Test Loss: 2.1505
Test Acc : 0.5315





###Model is overfitting !!
Goal was to implement AlexNet from scratch on CIFAR-100 using PyTorch