In [19]:
import os

import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Grayscale, Resize
from torchvision.transforms import ToTensor
from tqdm import tqdm

from AlexNet import AlexNet

In [2]:
# Instantiate the network and print its layer layout.
# NOTE(review): the positional argument is presumably the number of output
# classes — confirm against AlexNet.py.
num_classes = 10
model = AlexNet(num_classes)
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): ReLU(inplace=True)
    (6): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=F

### Downloading the Dataset

In [6]:
# FashionMNIST samples are single-channel 28x28 images, while the model's first
# layer is Conv2d(3, 96, kernel_size=11, stride=4). Feeding the raw images would
# fail on the channel mismatch (and the feature map would vanish in the pooling
# stack), so each image is enlarged and its grey channel replicated to 3 channels
# before conversion to a tensor.
# NOTE(review): 224x224 assumes the classifier expects the usual 256*6*6
# flattened features — confirm against AlexNet.py.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=Compose([
        Resize((224, 224)),
        Grayscale(num_output_channels=3),
        ToTensor(),
    ]),
)

100%|██████████| 26.4M/26.4M [00:48<00:00, 544kB/s] 
100%|██████████| 29.5k/29.5k [00:00<00:00, 172kB/s]
100%|██████████| 4.42M/4.42M [00:09<00:00, 464kB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 16.3MB/s]


In [7]:
# Held-out split. The images get the same preprocessing the network's first
# layer requires: FashionMNIST is 1x28x28, but the model starts with
# Conv2d(3, 96, kernel_size=11, stride=4), so images are resized and the grey
# channel is replicated to 3 channels before conversion to a tensor.
# NOTE(review): 224x224 assumes the classifier expects the usual 256*6*6
# flattened features — confirm against AlexNet.py.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=Compose([
        Resize((224, 224)),
        Grayscale(num_output_channels=3),
        ToTensor(),
    ]),
)

### Loading the Dataset

In [8]:
# Mini-batch iterators: the training split is reshuffled every epoch, the test
# split is read in a fixed order.
batch_size = 64
train_loader = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

### Optimizer and Hyperparameters

In [20]:
# Loss and optimizer. `optim` comes from the notebook's import cell rather than
# being imported mid-notebook, so a fresh Restart & Run All sees every
# dependency up front.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)

In [12]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the model to the chosen device, then wrap it with torch.compile;
# `model` is rebound to the compiled wrapper from here on.
model = torch.compile(model.to(device))

In [13]:
# Training-run settings.
best_model_path = "best_model.pth"  # where the best checkpoint is written
num_epochs = 10                     # full passes over the training loader
best_acc = 0.0                      # best validation accuracy (percent) so far

### Compiled Train and Test Loop

In [None]:
@torch.compile
def train_step(images, labels):
    """Run one optimisation step on a single mini-batch.

    Clears the gradients held by the notebook-level ``optimizer``, runs
    ``model`` on ``images``, scores the result against ``labels`` with
    ``criterion``, back-propagates and applies the parameter update.

    Args:
        images: Input batch passed straight to ``model``.
        labels: Targets passed as the second argument to ``criterion``.

    Returns:
        The loss value produced by ``criterion`` for this batch.
    """
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss

@torch.compile
def validate_step(images, labels):
    """Count the correct predictions in one mini-batch.

    The previous version accumulated into a local ``correct`` that was never
    initialised, which raises ``UnboundLocalError`` on the first call. The
    per-batch count is now computed directly and returned; the caller is
    responsible for summing across batches.

    Args:
        images: Input batch passed straight to ``model``.
        labels: Ground-truth class indices compared element-wise with the
            predicted indices.

    Returns:
        int: Number of samples whose highest-scoring class equals the label.
    """
    with torch.no_grad():
        outputs = model(images)
        # Index of the largest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        return (predicted == labels).sum().item()

In [None]:
# Epoch loop: one training pass, one validation pass, then checkpoint the model
# whenever validation accuracy improves on the best seen so far.
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)
        loss = train_step(images, labels)  # compiled single-batch update
        running_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    total_val_correct, total_val_samples = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            total_val_correct += validate_step(images, labels)  # compiled per-batch count
            total_val_samples += labels.size(0)

    # Accuracy is reported as a percentage of all validation samples.
    epoch_acc = 100 * total_val_correct / total_val_samples
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {epoch_acc:.2f}%")

    # ---- checkpoint on improvement ----
    # NOTE(review): `model` was wrapped by torch.compile in an earlier cell, so
    # the saved keys presumably carry that wrapper's prefix — confirm before
    # loading this file into an uncompiled AlexNet.
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with accuracy: {best_acc:.2f}%")

print(f"Training finished. Best accuracy: {best_acc:.2f}%")

### Testing the Model

In [None]:
# Restore the best checkpoint and measure accuracy on the test loader.
# map_location places the tensors on the device chosen for this session, so a
# checkpoint written on a GPU still loads on a CPU-only machine.
# weights_only=True restricts unpickling to tensor data instead of arbitrary
# Python objects.
state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Final Test Accuracy: {100 * correct / total:.2f}%")