In [None]:
from dataset import ZIFAR10
from torch.utils.data import DataLoader
from torch import optim
from model import alexnet
import torch.nn as nn
import torch

In [None]:
# Build both splits with the project's ZIFAR10 dataset class (imported from
# the local `dataset` module); download=True is passed for each split.
train_dataset = ZIFAR10(
    root='data',
    train=True,
    download=True,
)
test_dataset = ZIFAR10(
    root='data',
    train=False,
    download=True,
)

# Mini-batches of 8 samples; only the training loader shuffles.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=False,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Pull a single batch from the training loader and display its first image.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    image, label = first_batch
    # permute(1, 2, 0) moves the leading axis to the end before plotting
    # (assumes each image is stored channel-first, i.e. (C, H, W) — TODO confirm).
    img = image[0].permute(1, 2, 0).numpy()
    plt.imshow(img)
    plt.show()

In [None]:
# Number of passes over the training data.
num_epochs = 10

# Use CUDA when PyTorch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Network built by the project's `model` module and moved to the chosen device.
# The meaning of the arguments (3, 10) is defined there
# (presumably input channels and number of classes — TODO confirm).
model = alexnet(3, 10).to(device)

# SGD with momentum and weight decay, paired with a cross-entropy criterion.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
)
criterion = nn.CrossEntropyLoss()

In [None]:
def train_one_episode(epoch):
    """Run one full pass over ``train_loader``, stepping ``optimizer`` per batch.

    Relies on the notebook-level globals ``model``, ``train_loader``,
    ``optimizer``, ``criterion`` and ``device``.

    Args:
        epoch: Epoch identifier; only used to label the printed log lines.

    Returns:
        float: Mean of the per-batch losses seen during this pass (each batch
        counted equally), or ``nan`` when ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0.0
    num_train_batches = 0
    for data, target in train_loader:
        num_train_batches += 1
        data, target = data.float().to(device), target.long().to(device)

        optimizer.zero_grad()
        pred = model(data)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both the running total and the log.
        batch_loss = loss.item()
        train_loss += batch_loss

        # Progress line every 100 batches, showing the loss of the current batch.
        if num_train_batches % 100 == 0:
            print(f'Epoch {epoch}, Train Loss: {batch_loss}')

    if num_train_batches == 0:
        # Nothing was trained on, so there is no meaningful average to report.
        return float('nan')

    mean_loss = train_loss / num_train_batches
    print(f'Epoch {epoch}, Mean Train Loss: {mean_loss}')
    return mean_loss

def validate(epoch):
    """Evaluate ``model`` on ``test_loader`` with gradient tracking disabled.

    Relies on the notebook-level globals ``model``, ``test_loader``,
    ``criterion`` and ``device``. The model is switched to eval mode and is
    left in that mode when the function returns.

    Args:
        epoch: Epoch identifier; only used to label the printed log line.

    Returns:
        float: Mean of the per-batch losses (each batch counted equally), or
        ``nan`` when ``test_loader`` yields no batches.
    """
    cum_loss = 0.0
    num_val_batches = 0

    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            num_val_batches += 1
            data, target = data.float().to(device), target.long().to(device)

            pred = model(data)
            cum_loss += criterion(pred, target).item()

    if num_val_batches == 0:
        # Avoid dividing by zero when the loader is empty.
        print(f"Epoch {epoch}, Validation skipped: test_loader yielded no batches")
        return float('nan')

    mean_loss = cum_loss / num_val_batches
    print(f"Epoch {epoch}, Validation Loss: {mean_loss}")
    return mean_loss

In [None]:
# Main driver: for every epoch, run one training pass over train_loader and
# then evaluate on test_loader. Both helpers print their own progress lines.
for epoch in range(num_epochs):
    train_one_episode(epoch)
    validate(epoch)