In [1]:
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import v2

**Use CUDA if available, otherwise fall back to the CPU**

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [3]:
# Report which GPU was selected; nothing is printed on a CPU-only machine.
on_gpu = device.type == 'cuda'
if on_gpu:
    print(torch.cuda.get_device_name(0))

NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [4]:
# Event files go to runs/<datetime>_<hostname>_alexnet_imagenette, so this experiment is
# easy to tell apart from other runs in TensorBoard.
writer = SummaryWriter(comment='_alexnet_imagenette')

**Use Imagenette (a 10-class subset of ImageNet) instead of ImageNet**

In [5]:
# Training-time augmentation, loosely following AlexNet:
# - resize to 256x256, take a random 224x224 crop, and mirror it half of the time.
# - ToImage() converts the PIL image to a uint8 tensor up front, so resize/crop run on
#   tensors (the recommended fast path in transforms.v2).
# - ToDtype(..., scale=True) then yields float32 values in [0, 1], the same as the
#   legacy v1 ToTensor() this replaces.
img_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize((256, 256)),
    v2.RandomCrop((224, 224)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
])

In [6]:
# Validation must be deterministic. Use the same resize as training, but a centre crop
# and no random flip, so every epoch evaluates exactly the same pixels.
eval_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize((256, 256)),
    v2.CenterCrop((224, 224)),
    v2.ToDtype(torch.float32, scale=True),
])

imagenette_train = torchvision.datasets.Imagenette(root='data', split='train', download=True, transform=img_transforms)
# Both splits come from the same archive, which the call above already fetched.
imagenette_test = torchvision.datasets.Imagenette(root='data', split='val', download=False, transform=eval_transforms)
len(imagenette_train), len(imagenette_test)

(9469, 3925)

In [7]:
# Sanity-check one training sample: the AlexNet below is sized for 3x224x224 float32
# images, and the transforms should produce values in [0, 1].
sample_image, sample_label = imagenette_train[0]
assert sample_image.shape == (3, 224, 224), sample_image.shape
assert sample_image.dtype == torch.float32, sample_image.dtype
assert 0.0 <= sample_image.min() <= sample_image.max() <= 1.0
sample_image.dtype

torch.float32

**Define the AlexNet architecture (the paper's two GPU halves joined into a single network)**

In [None]:
class AlexNet(nn.Module):
    """AlexNet (Krizhevsky et al., 2012) built as one network instead of two GPU halves.

    Args:
        num_classes: number of outputs of the final layer (10 for Imagenette).
        image_size: side length of the square input images. It is used only to size
            the first fully connected layer (224 -> 256 * 5 * 5 = 6400 features).

    All layers live in a single ``nn.Sequential`` stored as ``self.model``. With the
    default arguments, the layer indices (and therefore ``state_dict`` keys) match the
    previous hard-coded version.
    """

    def __init__(self, num_classes: int = 10, image_size: int = 224):
        super().__init__()
        features = [
            nn.Conv2d(3, 96, 11, stride=4),      # conv1
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(96, 256, 5, padding=2),    # conv2
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(256, 384, 3, padding=1),   # conv3
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),   # conv4
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),   # conv5
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        ]
        # Derive the flattened feature size from a dummy forward pass instead of
        # hard-coding 6400, so other input resolutions also work.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, image_size, image_size)
            flat_features = nn.Sequential(*features)(dummy).numel()

        self.model = nn.Sequential(
            *features,
            nn.Flatten(),
            nn.Linear(flat_features, 4096),      # fc6
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),               # fc7
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),        # fc8
        )
        self._init_as_in_paper()

    def _init_as_in_paper(self):
        """Initialise parameters as described in the AlexNet paper.

        - Weights are drawn from N(0, 0.01^2).
        - Biases are 1 in conv2, conv4, conv5 and the two hidden fully connected
          layers, which gives their ReLUs positive inputs early in training.
        - Biases are 0 in every other layer.
        """
        convs = [m for m in self.model if isinstance(m, nn.Conv2d)]
        linears = [m for m in self.model if isinstance(m, nn.Linear)]
        for layer in convs + linears:
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            nn.init.zeros_(layer.bias)
        for layer in (convs[1], convs[3], convs[4], linears[0], linears[1]):
            nn.init.ones_(layer.bias)

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, image_size, image_size)``.
        """
        return self.model(x)

**Instantiate the model and move it to the selected device**

In [9]:
# Build the network, then place its parameters on the device chosen above.
model = AlexNet()
model = model.to(device)
model

AlexNet(
  (model): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (13): Flatten(start_dim=1, end_dim=-1)
    (14): Linear(in_features=6400, out_features=4096, bias=True)
    (15): ReLU()
    (16): Dropout(p=0.5, inplace=False)
    (17): Linear(in_features=4096, out_features=4096, bias=True)
    (18): ReLU()
    (19): Dropout(p

**Hyperparameters**

In [None]:
# Number of passes over the training set.
epochs = 50

# SGD settings: batch size, initial learning rate, momentum and weight decay as given in
# the AlexNet paper.
batch_size = 128
learning_rate = 1e-2
momentum = 0.9
weight_decay = 5e-4

**DataLoaders, Loss Function + Optimizer**

In [11]:
# Shuffle only the training data. A fixed validation order makes evaluation repeatable
# from one epoch to the next.
# Page-locked (pinned) host memory speeds up host-to-GPU copies; it only matters on CUDA.
use_pinned_memory = device.type == 'cuda'
dataloader_train = DataLoader(imagenette_train, batch_size=batch_size, shuffle=True, pin_memory=use_pinned_memory)
dataloader_test = DataLoader(imagenette_test, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)

num_train_batches = len(dataloader_train)
num_test_batches = len(dataloader_test)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
num_train_batches, num_test_batches

(74, 31)

**Train + Validation**

In [12]:
def run_epoch(model, dataloader, loss_fn, device, optimizer=None):
    """Make one full pass over ``dataloader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in training mode and updated after
    every batch. Otherwise it is put in evaluation mode and gradients are disabled.

    Both metrics are averaged over samples rather than batches, so a short final
    batch does not count as much as a full one.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval(): disables dropout

    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for features, labels in dataloader:
            features = features.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(features)
            batch_loss = loss_fn(logits, labels)

            if training:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            n = labels.shape[0]
            # Assumes loss_fn averages over the batch (reduction='mean', the
            # CrossEntropyLoss default); multiplying by n turns it back into a sum.
            total_loss += batch_loss.item() * n
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_seen += n

    if total_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(epochs):
    train_loss, train_acc = run_epoch(model, dataloader_train, loss, device, optimizer)
    val_loss, val_acc = run_epoch(model, dataloader_test, loss, device)

    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)

    writer.add_scalar("Loss/val", val_loss, epoch)
    writer.add_scalar('Accuracy/val', val_acc, epoch)
    writer.flush()  # make each epoch visible in TensorBoard right away

    # One line per epoch, so a long run can be monitored (and interrupted) from here.
    print(f"epoch {epoch + 1}/{epochs}: "
          f"train loss {train_loss:.4f} acc {train_acc:.3f} | "
          f"val loss {val_loss:.4f} acc {val_acc:.3f}")

KeyboardInterrupt: 

In [None]:
# close() flushes any pending events and releases the event file. If the writer is used
# again afterwards, SummaryWriter lazily opens a new event file in the same log_dir.
writer.close()