In [19]:
from IPython.core.debugger import set_trace

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torchvision.datasets as datasets
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision.transforms import Compose

from src.classification.models.alexnet import AlexNet


In [2]:
# Interactive ipympl figures, and automatic reloading of edited modules
# (e.g. src.classification.models.alexnet) before each cell runs.
%matplotlib widget
%load_ext autoreload
%autoreload 2

## Model

### Create Model

In [3]:
# MNIST has 10 digit classes (0-9). Presumably AlexNet's first argument is the
# number of output classes -- TODO confirm against src/classification/models/alexnet.py.
NUM_CLASSES = 10
alexnet = AlexNet(NUM_CLASSES)

In [4]:
# Prefer the third GPU (index 2) as originally intended, but fall back to the
# default GPU or the CPU so the notebook still runs on smaller machines.
# The variable keeps the name `cuda` because later cells refer to it.
if torch.cuda.device_count() > 2:
    cuda = torch.device("cuda:2")
elif torch.cuda.is_available():
    cuda = torch.device("cuda")
else:
    cuda = torch.device("cpu")

In [5]:
# nn.Module.to moves parameters/buffers and returns the module itself;
# re-binding the name keeps the cell explicit about where the model lives.
alexnet = alexnet.to(device=cuda)

## Dataset 

### Define Transforms

In [6]:
def to_rgb(img):
    """Replicate a single-channel image into three identical channels.

    Args:
        img: array-like of shape (H, W) or (H, W, 1). An (H, W, 3) input is
            returned unchanged.

    Returns:
        np.ndarray of shape (H, W, 3) with the same dtype as ``img``.

    Raises:
        ValueError: if ``img`` has any other shape.
    """
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[2] == 3:
        return img  # already three-channel
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[..., 0]  # drop the singleton channel axis before repeating
    if img.ndim != 2:
        raise ValueError(
            f"expected an (H, W), (H, W, 1) or (H, W, 3) image, got shape {img.shape}"
        )
    # Add a trailing channel axis and copy the gray plane into it three times.
    return np.repeat(img[..., None], 3, axis=2)

def to_numpy(img):
    """Return a freshly allocated NumPy array holding the pixels of ``img``.

    Used as the first transform stage, to turn the PIL image yielded by
    ``datasets.MNIST`` into an ndarray.
    """
    pixels = np.array(img, copy=True)
    return pixels

def resize(img, size=(224, 224)):
    """Resize ``img`` with OpenCV's default interpolation.

    Note: OpenCV reads ``size`` as ``(width, height)``, not NumPy's
    ``(rows, cols)`` order; the square default makes the distinction moot.
    """
    resized_img = cv2.resize(img, size)
    return resized_img


def to_torch(img):
    """Convert a channels-last 3-D image array into a float32 channels-first tensor.

    The axes are reordered from (H, W, C) to (C, H, W). Values are cast to
    float32 but not rescaled, so uint8 pixels stay in [0, 255].
    """
    channels_first = np.transpose(img, (2, 0, 1))
    as_float = channels_first.astype(np.float32)
    return torch.from_numpy(as_float)

In [7]:
# PIL image -> ndarray -> 3 channels -> 224x224 -> float32 (C, H, W) tensor
img_transform = Compose([to_numpy, to_rgb, resize, to_torch])
def target_transform(label):
    """Wrap an integer class label as a 0-d ``torch.long`` tensor.

    ``nn.CrossEntropyLoss`` expects integer class indices. Note that
    ``torch.Tensor(label)`` would instead treat an int as a *size* and return
    an uninitialised float tensor with ``label`` elements.
    """
    return torch.as_tensor(label, dtype=torch.long)

### Create Dataset

In [8]:
# Both splits share one cache directory and one preprocessing pipeline;
# only the `train` flag differs.
MNIST_ROOT = './resources/classification/'

# NOTE(review): `target_transform` is not passed here; labels arrive as ints
# and the DataLoader's default collation turns each batch into a long tensor.
mnist_testset = datasets.MNIST(
    root=MNIST_ROOT,
    train=False,
    download=True,
    transform=img_transform,
)
mnist_trainset = datasets.MNIST(
    root=MNIST_ROOT,
    train=True,
    download=True,
    transform=img_transform,
)

### Create Dataloader

In [39]:
BATCH_SIZE = 32

# Shuffle the training split so each epoch presents batches in a new order
# (without it the same sample order repeats every epoch). The test split
# stays in a fixed order so evaluation is deterministic.
train_dl = DataLoader(mnist_trainset, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False)

## Train

### Define Train Loop

In [35]:
def _evaluate(model, dataloader, loss_fn, device):
    """Return (mean of per-batch losses, accuracy) over ``dataloader`` without gradients."""
    model.eval()
    total_loss, n_correct, n_seen, n_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            preds = model(inputs)
            total_loss += loss_fn(preds, targets).item()
            # assumes preds are (batch, num_classes) logits and targets are
            # class indices -- TODO confirm for non-classification use
            n_correct += (preds.argmax(dim=1) == targets).sum().item()
            n_seen += targets.shape[0]
            n_batches += 1
    if n_batches == 0:
        return float("nan"), float("nan")
    return total_loss / n_batches, n_correct / n_seen


def train(model, optimizer, train_dataloader, val_dataloader, loss_fn, epochs, device=None):
    """Train ``model`` for ``epochs`` epochs and evaluate after each one.

    Args:
        model: the ``nn.Module`` to optimise.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: iterable of ``(inputs, targets)`` training batches.
        val_dataloader: iterable of ``(inputs, targets)`` validation batches,
            or ``None`` to skip evaluation.
        loss_fn: callable ``loss_fn(predictions, targets) -> scalar tensor``.
        epochs: number of passes over ``train_dataloader``.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        list of dicts, one per epoch. Each holds ``epoch``, ``train_loss``
        (whatever ``train_one_epoch`` returns), and, when ``val_dataloader``
        is given, ``val_loss`` and ``val_accuracy``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    history = []
    for epoch in range(epochs):
        # Pass arguments by keyword so they cannot land in the wrong slot.
        train_loss = train_one_epoch(
            model=model,
            optimizer=optimizer,
            train_dataloader=train_dataloader,
            loss_fn=loss_fn,
            device=device,
        )
        record = {"epoch": epoch, "train_loss": train_loss}
        if val_dataloader is not None:
            record["val_loss"], record["val_accuracy"] = _evaluate(
                model, val_dataloader, loss_fn, device
            )
        history.append(record)
    return history

In [36]:
def train_one_epoch(model, optimizer, train_dataloader, loss_fn, device):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        model: ``nn.Module`` being trained; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: iterable of ``(inputs, targets)`` batches; must
            support ``len()`` for the progress-bar total.
        loss_fn: callable ``loss_fn(predictions, targets) -> scalar tensor``.
        device: device each batch is moved to before the forward pass.

    Returns:
        float: mean of the per-batch losses, or ``nan`` for an empty loader.
    """
    model.train()
    running_loss, n_batches = 0.0, 0
    # tqdm wraps the iterable and advances itself once per batch.
    with tqdm(train_dataloader, total=len(train_dataloader), desc="train") as tbar:
        for imgs, target_labels in tbar:
            optimizer.zero_grad(set_to_none=True)
            pred_labels = model(imgs.to(device))
            loss_value = loss_fn(pred_labels, target_labels.to(device))
            loss_value.backward()
            optimizer.step()
            # .item() yields a plain float; passing the tensor itself would
            # render its full repr (device, grad_fn) in the progress bar.
            batch_loss = loss_value.item()
            running_loss += batch_loss
            n_batches += 1
            tbar.set_postfix(loss=f"{batch_loss:.4f}")
    return running_loss / n_batches if n_batches else float("nan")
    

In [37]:
# Plain SGD (no momentum / weight decay) over every AlexNet parameter,
# with cross-entropy on class-index targets.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=alexnet.parameters(), lr=1e-3)

In [40]:
train_one_epoch(
    model=alexnet,
    optimizer=optimizer,
    train_dataloader=train_dl,
    loss_fn=criterion,
    device=cuda,
)

  0%|          | 0/1875 [00:00<?, ?it/s]