In [2]:
from typing import Optional, Callable
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torch.optim as Optimizer
import logging

In [30]:
#Simple lenet model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y


In [36]:
#Encapsulate LeNet
class LeNet(object):
    def __init__(self, n_classes = 10, device = None):
        self.n_classes = n_classes
        if device is None:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = Model()
        
    def __train_one_epoch__(self, train_loader, optimizer, criterion,
                        valid_loader = None, epoch = 0, each_batch_idx = 0):
        train_loss = 0
        data_size = 0
        
        for batch_idx, (img, label) in enumerate(train_loader):

            # zero the parameter gradients
            optimizer.zero_grad()

            # run forward
            pred_prob = self.model(img)

            # calculate loss
            loss = criterion(pred_prob, label)

            # calculate gradient (backprop)
            loss.backward()

            # total train loss
            train_loss += loss.item()
            data_size += label.size(0)

            # update weights
            optimizer.step()
            
        if valid_loader:
            acc = self.evaluate(test_loader=valid_loader)
            print('Accuracy on the valid dataset {}'.format(acc))
            
    def train(self, epochs, train_loader, valid_loader = None):
        self.model.train()
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=0.001, momentum=0.9)

        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            self.__train_one_epoch__(train_loader=train_loader,
                                   optimizer=optimizer,
                                   criterion=criterion,
                                   valid_loader=valid_loader,
                                   epoch=epoch
                                   )

    def evaluate(self, test_loader):
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (data,labels) in enumerate(test_loader):
                data = data.float()
                outputs = self.model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return 100 * correct / total

    def predict(self, test_loader):
        self.model.eval()
        predict_results = np.empty(shape=(0, 10))
        with torch.no_grad():
            for batch_idx,  (img, label) in enumerate(test_loader):
                outputs = self.model(img)
                outputs = softmax(outputs)
                predict_results = np.concatenate(
                    (predict_results, outputs.cpu().numpy()))
        return predict_results


In [9]:
import torchvision.datasets as datasets
from torchvision import transforms

In [10]:
#Obtain mnist dataset
mnist_train = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), target_transform=None, download=False)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), target_transform=None, download=False)

In [13]:
#Dataloader of MNIST
from torch.utils.data.sampler import SubsetRandomSampler
random_seed = 123
validation_split = 0.1  
batch_size = 16
dataset_size = len(mnist_train)

indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, sampler=train_sampler)
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, sampler=valid_sampler)


In [37]:
test_model = LeNet(n_classes = 10)
test_model.train(epochs = 1, train_loader = dl_train, valid_loader = None)

In [38]:
acc = test_model.evaluate(test_loader = dl_test)
print(acc)

89.71666666666667
