In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

In [None]:
from torch.utils.data import random_split

batch_size = 128
random_seed = 1447
# Seed torch's global RNG so random_split below and the shuffling done by
# train_loader start from a known state.
torch.manual_seed(random_seed)

# Per-image preprocessing: resize to 32x32, convert to a tensor, then
# normalise with a fixed mean/std.
# NOTE(review): 0.1309 / 0.2893 are presumably statistics of the resized
# training images -- confirm how they were obtained.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.1309, std=0.2893)
])

# MNIST training split; it is divided into train/validation subsets below.
dataset = torchvision.datasets.MNIST(root="./data/", train=True,
                                     download=True, transform=transform)
# Held-out MNIST test split. This has to be train=False: with train=True the
# "test" set would be the very images the model is trained on.
test_set = torchvision.datasets.MNIST(root="./data/", train=False,
                                      download=True, transform=transform)

# Fraction of the training split reserved for validation. The split size is
# derived from this single constant so the two can no longer disagree.
val_per = 0.2
val_size = int(len(dataset) * val_per)
train_size = len(dataset) - val_size
train_set, val_set = random_split(dataset, [train_size, val_size])

# Only the training loader shuffles; validation/test order is irrelevant to
# the metrics and is kept fixed.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                          shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                         shuffle=False, num_workers=4)

print("Train set:      ", len(train_set))
print("Validation set: ", len(val_set))
print("Test set:       ", len(test_set))

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def show_batch(dl):
    """Draw the first batch produced by `dl` as a single image grid.

    Args:
        dl: iterable yielding ``(images, labels)`` pairs; only the first pair
            is consumed and its labels are ignored.

    Nothing is drawn when `dl` yields no batches.
    """
    first_batch = next(iter(dl), None)
    if first_batch is None:
        return

    images, _ = first_batch
    _, ax = plt.subplots(figsize=(12, 6))
    # Hide the tick marks: pixel coordinates carry no meaning for the grid.
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images, nrow=16)
    # Axes are reordered (1, 2, 0) before plotting -- presumably
    # channels-first -> channels-last, which is what imshow draws.
    ax.imshow(grid.permute(1, 2, 0))

show_batch(train_loader)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Small convolutional classifier: two conv layers and three linear layers.

    With no padding and 5x5 kernels, a 1x32x32 input shrinks to 6x28x28 after
    conv1, 6x14x14 after pooling, 16x10x10 after conv2 and 16x5x5 after the
    second pooling -- which is the 16*5*5 feature count fc1 expects.
    The network returns 10 raw scores per sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images to a batch of 10 class scores."""
        # conv -> ReLU -> 2x2 max-pool, twice.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse everything but the batch dimension for the linear layers.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Device: ", device)

In [None]:
# Instantiate the network and move its parameters to the selected device.
model = CNN()
model = model.to(device)
# Criterion shared by the training, validation and test loops.
loss_fn = F.cross_entropy

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of samples whose highest-scoring class is correct.

    Args:
        outputs: score tensor reduced along dim 1 -- presumably
            ``(batch, num_classes)``.
        labels: tensor of target class indices, compared element-wise with
            the predicted indices.

    Returns:
        A 0-dim tensor built from a Python float (so it is not on the device
        of `outputs`); callers read it with ``.item()``.
    """
    predicted = outputs.argmax(dim=1)
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))

@torch.no_grad()
def evaluate(val_loader):
    """Compute the mean loss and accuracy of the global `model` on a loader.

    Both metrics are weighted by batch size, so every sample counts equally
    even when the final batch is smaller than the rest. This matches how the
    validation phase inside `train` computes its numbers.

    Args:
        val_loader: iterable of ``(images, labels)`` batches.

    Returns:
        dict with keys ``'val_loss'`` and ``'val_acc'`` (Python floats).

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()

    running_loss = 0.0
    running_accuracy = 0.0
    num_samples = 0

    # Gradient tracking is already disabled by the decorator above.
    for images, labels in tqdm(val_loader, desc="Validation loop"):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        batch_count = labels.size(0)

        # loss_fn / accuracy return per-batch means; scale them back to sums
        # so that dividing by the sample count gives a per-sample average.
        running_loss += loss_fn(outputs, labels).item() * batch_count
        running_accuracy += accuracy(outputs, labels).item() * batch_count
        num_samples += batch_count

    val_loss = running_loss / num_samples
    val_acc = running_accuracy / num_samples

    return {'val_loss': val_loss, 'val_acc': val_acc}

evaluate(val_loader)

In [None]:
import timeit

# Training hyper-parameters.
num_epochs = 5
lr = 1e-3

# Optimizer used for this run. Drop-in alternatives, each constructed with
# the same `(model.parameters(), lr=lr)` arguments:
#   torch.optim.SGD, torch.optim.Adafactor, torch.optim.AdamW,
#   torch.optim.RAdam, torch.optim.NAdam
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

def train(epochs, train_loader, val_loader, optimizer):
    """Train the global `model` and validate it after every epoch.

    Args:
        epochs: number of passes over `train_loader`.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        optimizer: optimizer already bound to the model's parameters.

    Returns:
        ``(history, batch_losses)`` where `history` holds one dict per epoch
        with keys ``'epoch'`` (0-based), ``'train_loss'``, ``'val_loss'`` and
        ``'val_acc'``, and `batch_losses` holds one entry per training batch:
        that batch's mean loss multiplied by its size.

    Side effects:
        Prints one summary line per epoch and, at the end, the elapsed
        wall-clock time in seconds.
    """
    start = timeit.default_timer()

    history = []
    batch_losses = []

    for epoch in range(epochs):
        # ---- Training phase ----
        model.train()
        running_loss = 0.0
        seen = 0

        for images, labels in tqdm(train_loader, desc="Training loop", disable=True):
            images, labels = images.to(device), labels.to(device)

            # Clear gradients BEFORE the backward pass so that anything left
            # over from an interrupted run or an earlier backward() cannot
            # leak into this update.
            optimizer.zero_grad()

            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the epoch total
            # and the per-batch log.
            batch_count = labels.size(0)
            batch_loss = loss.item() * batch_count
            running_loss += batch_loss
            batch_losses.append(batch_loss)
            seen += batch_count

        # Average over the samples actually processed this epoch.
        train_loss = running_loss / seen

        # ---- Validation phase ----
        model.eval()

        running_loss = 0.0
        running_accuracy = 0.0
        seen = 0

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc="Validation loop", disable=True):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                batch_count = labels.size(0)

                # Per-batch means are scaled back to sums so the final
                # division yields per-sample averages.
                running_loss += loss_fn(outputs, labels).item() * batch_count
                running_accuracy += accuracy(outputs, labels).item() * batch_count
                seen += batch_count

        val_loss = running_loss / seen
        val_acc = running_accuracy / seen

        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch + 1, train_loss, val_loss, val_acc))

        history.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_acc': val_acc
        })

    stop = timeit.default_timer()
    execution_time = stop - start
    print(execution_time)
    return history, batch_losses

history, batch_losses = train(num_epochs, train_loader, val_loader, optimizer)

In [None]:
def plot_batch_losses(batch_losses, path=None):
    """Plot the training loss recorded for each batch and optionally save it.

    Args:
        batch_losses: sequence of loss values, one per training batch, in the
            order they were recorded.
        path: optional ``str`` or ``os.PathLike``. When given, the figure is
            written to that location. Any other value (including the default
            ``None``) means the figure is only displayed.
    """
    import os  # local import: only needed for the PathLike check below

    fig, ax = plt.subplots()
    fig.set_figwidth(15)
    ax.plot(batch_losses, '-b')
    # One point per batch, so the x-axis counts batches, not epochs.
    ax.set_xlabel('batch')
    ax.set_ylabel('loss')
    ax.set_title('Training Loss')

    # Save BEFORE showing: with inline / non-interactive backends the figure
    # is released by show(), and a later savefig would write a blank image.
    if isinstance(path, (str, os.PathLike)):
        fig.savefig(path)

    plt.show()

plot_batch_losses(batch_losses)

In [None]:
@torch.no_grad()
def test(model, test_loader):
    """Return the fraction of samples in `test_loader` that `model` classifies correctly.

    Correct predictions are counted across the whole loader and divided by
    the number of samples seen, so a smaller final batch does not receive
    the same weight as a full one.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()

    num_correct = 0
    num_samples = 0

    # Gradient tracking is already disabled by the decorator above.
    for images, labels in tqdm(test_loader, desc="Test loop"):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)

        # Predicted class = index of the highest score along dim 1.
        predicted = outputs.argmax(dim=1)
        num_correct += (predicted == labels).sum().item()
        num_samples += labels.size(0)

    return num_correct / num_samples

# Final score: evaluate on the test loader rather than the validation loader,
# whose accuracy is already reported after every training epoch.
# NOTE(review): this is only a genuine held-out score if test_set wraps the
# MNIST train=False split -- verify the dataset cell.
pred_accuracy = test(model, test_loader)
print(pred_accuracy)