# Train

In [1]:
import sys
sys.path.append("../")
from image_classification.preprocessors.basic import preprocessor
from image_classification.datasets.fashion_mnist import FashionMNIST
from image_classification.models.vgg import vgg11_bn

In [2]:
import random
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import random_split

## Config

In [3]:
seed = 42
batch_size = 128
data_dir = "../data/fashion_mnist"

## Seed

In [4]:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Load Data

In [5]:
dataset = FashionMNIST(is_train=True, data_dir=data_dir, transform=preprocessor)
train_dataset, eval_dataset = random_split(dataset, [50000, 10000])

In [6]:
test_dataset = FashionMNIST(is_train=False, data_dir=data_dir, transform=preprocessor)

In [7]:
len(train_dataset), len(eval_dataset),len(test_dataset)

(50000, 10000, 10000)

In [8]:
train_dl = DataLoader(train_dataset, batch_size, shuffle=True)
eval_dl = DataLoader(eval_dataset, batch_size*2, shuffle=False)
test_dl = DataLoader(test_dataset, batch_size*2, shuffle=False)

In [9]:
batch, labels = next(iter(train_dl))
batch.shape, labels.shape

(torch.Size([128, 1, 28, 28]), torch.Size([128]))

## Load Model

In [10]:
class CNN(nn.Module):
    
    def __init__(self, in_channels: int, out_dim: int, device: torch.device):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        self.drop = nn.Dropout2d(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=out_dim)
        self.device = device
        self.to(device)
        
    def forward(self, x):
        x = x.to(self.device)
        out = self.conv1(x)
        out = self.conv2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
model = CNN(in_channels=1, out_dim=10, device=device)

In [13]:
model(batch).shape

torch.Size([128, 10])

## Training

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001, betas=(0.9,0.999), eps=1e-9)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 4, gamma=0.1, verbose=False)
###Best accuracy 0.9349, tensor(0.9385)

In [15]:
num_epochs = 2
examples_seen = 0
eval_interval = 100
model.train()
model.zero_grad()

In [16]:
for epoch in range(num_epochs):
    for i, (x, y) in enumerate(train_dl):
        examples_seen += batch_size
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i%eval_interval==0:
            losses, accuracies = [], []
            for x, y in eval_dl:
                with torch.no_grad():
                    y_pred = model(x)
                    loss = criterion(y_pred, y).item()
                    accuracy = (y_pred.argmax(axis=1)==y).float().mean().item()
                losses.append(loss)
                accuracies.append(accuracy)
            eval_loss = np.mean(losses)
            eval_accuracy = np.mean(accuracies)
            print(
                f"Iteration#{i} "
                f"Examples seen: {examples_seen}\t"
                f"Eval accuracy: {eval_accuracy}\t"
                f"Eval loss: {eval_loss}\t"
            )

Iteration#0 Examples seen: 128	Eval accuracy: 0.20107421875	Eval loss: 3.389490455389023	
Iteration#100 Examples seen: 12928	Eval accuracy: 0.84892578125	Eval loss: 0.4264045380055904	
Iteration#200 Examples seen: 25728	Eval accuracy: 0.83671875	Eval loss: 0.47547646760940554	
Iteration#300 Examples seen: 38528	Eval accuracy: 0.86201171875	Eval loss: 0.380259071290493	
Iteration#0 Examples seen: 50176	Eval accuracy: 0.87236328125	Eval loss: 0.3612319914624095	
Iteration#100 Examples seen: 62976	Eval accuracy: 0.88583984375	Eval loss: 0.3280178092420101	
Iteration#200 Examples seen: 75776	Eval accuracy: 0.862890625	Eval loss: 0.37957348823547366	
Iteration#300 Examples seen: 88576	Eval accuracy: 0.8919921875	Eval loss: 0.31033915579319	
