<a href="https://colab.research.google.com/github/KwonDoRyoung/ABRLaboratory/blob/main/0722/3_run_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import modules

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.datasets as datasets
import torch.cuda as cuda
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Define MLP network

In [None]:
# define MLP network with two hidden layers (784 -> 512 -> 256 -> 10 by default)
class MLP(nn.Module):
    """Fully connected classifier: flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        in_features: number of values per sample after flattening (default 28*28).
        hidden1: width of the first hidden activation (default 512).
        hidden2: width of the second hidden activation (default 256).
        num_classes: number of output logits (default 10).

    The defaults reproduce the original fixed-size network, so `MLP()` yields
    the same parameter names and shapes as before and existing checkpoints
    remain loadable.
    """

    def __init__(self, in_features=28*28, hidden1=512, hidden2=256, num_classes=10):
        super().__init__()
        # nn.Flatten() keeps dim 0 (batch) and flattens the rest, e.g. (N, 1, 28, 28) -> (N, 784)
        self.flatten = nn.Flatten()
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the order in which random initial weights are drawn.
        self.input_layer = nn.Linear(in_features, hidden1)
        self.hidden_layer = nn.Linear(hidden1, hidden2)
        self.output_layer = nn.Linear(hidden2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits (no softmax applied) for a batch `x`."""
        x = self.flatten(x)
        h = self.relu(self.input_layer(x))
        h = self.relu(self.hidden_layer(h))
        y = self.output_layer(h)
        return y

# Define run_mnist function

In [None]:
def run_mnist(
    training=True,
    save_path='model.pt',
    epochs=10,
    random_seed=0xAB,
    valid_ratio=0.1
):
    """Train the MLP on MNIST, or evaluate a saved checkpoint on the MNIST test split.

    Args:
        training: if True, train and write the best checkpoint to `save_path`;
            if False, load `save_path` and print the test accuracy.
        save_path: file holding the model `state_dict`. In training mode an
            existing file is loaded first and used as the starting weights.
        epochs: number of passes over the training subset (training mode only).
        random_seed: seed for the NumPy and PyTorch global RNGs (training mode only).
        valid_ratio: fraction of the `train=True` MNIST split held out for
            validation. The first `int(valid_ratio * N)` indices are used, so it
            must yield at least one validation sample: the validation loss is
            averaged over the number of validation batches.

    Returns:
        None. Progress and results are printed; in evaluation mode a missing
        checkpoint prints a message and calls `sys.exit()`.
    """

    device = 'cuda' if cuda.is_available() else 'cpu'

    if training:
        # Seed BEFORE the model is constructed so the initial weights (and later
        # the sampler's shuffling) are both controlled by `random_seed`.
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    myMLP = MLP().to(device)

    # train mode
    if training:

        # load checkpoint (map_location lets a GPU-saved file load on a CPU-only machine)
        if os.path.exists(save_path):
            myMLP.load_state_dict(torch.load(save_path, map_location=device))

        # load or download MNIST datasets
        train_data = datasets.MNIST(root='./data', train=True, download=True, transform=T.ToTensor())
        valid_data = datasets.MNIST(root='./data', train=True, download=True, transform=T.ToTensor())

        # split given train data into train and valid dataset
        num_train = len(train_data)
        indices = list(range(num_train))
        num_valid = int(valid_ratio * num_train)

        train_idx, valid_idx = indices[num_valid:], indices[:num_valid]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # create data loaders to feed data into our model
        batch_size = 64
        train_dataloader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
        valid_dataloader = DataLoader(valid_data, batch_size=batch_size, sampler=valid_sampler)

        # define a loss function and an optimizer
        loss_fn = nn.CrossEntropyLoss(reduction='mean')
        optimizer = torch.optim.SGD(myMLP.parameters(), lr=5e-3)

        report_interval = 100
        max_valid_acc = 0

        # Only the sampled subset is trained on, so use its size (not the full
        # dataset length) as the denominator of the progress report.
        ndata = len(train_idx)
        print('training starts!')

        for e in range(epochs):
            print(f'\nepoch {e+1}\n------------------------------')
            myMLP.train() # train mode

            for b, (X, y) in enumerate(train_dataloader):
                X, y = X.to(device), y.to(device) # input and target to device(gpu)

                prediction = myMLP(X) # forward pass
                train_loss = loss_fn(prediction, y) # calculate the loss

                optimizer.zero_grad() # clear gradients
                train_loss.backward() # backpropagation
                optimizer.step() # update the parameters

                if b % report_interval == 0: # track the training
                    train_loss_value, current = train_loss.item(), b * len(X)
                    print(f'[{current:>5d}/{ndata:>5d}]  train loss: {train_loss_value:>7f}  ', end="")

                    myMLP.eval()
                    with torch.no_grad():
                        valid_ndata = 0
                        valid_nbatch = len(valid_dataloader)
                        valid_loss, valid_correct = 0, 0
                        # separate names so the training batch is not shadowed
                        for vX, vy in valid_dataloader:
                            vX, vy = vX.to(device), vy.to(device)
                            valid_pred = myMLP(vX)
                            valid_loss += loss_fn(valid_pred, vy).item() # add up the loss
                            valid_correct += (valid_pred.argmax(1) == vy).type(torch.float).sum().item() # add up the correct predictions
                            valid_ndata += len(vX)
                        valid_loss /= valid_nbatch
                        valid_correct /= valid_ndata
                        print(f"valid accuracy: {(100*valid_correct):>0.1f}%, valid loss: {valid_loss:>8f}")

                    # keep only the checkpoint with the best validation accuracy seen in this call
                    if max_valid_acc < valid_correct:
                        torch.save(myMLP.state_dict(), save_path)
                        max_valid_acc = valid_correct

                    # switch back to train mode; otherwise the rest of the epoch
                    # would run with the model still in eval mode
                    myMLP.train()

        print('\ntraining is finished!')

    # evaluation mode
    else:

        if os.path.exists(save_path):
            myMLP.load_state_dict(torch.load(save_path, map_location=device))
        else:
            print(f'There is no checkpoint filename {save_path}')
            sys.exit()

        # load or download MNIST datasets
        test_data = datasets.MNIST(root='./data', train=False, download=True, transform=T.ToTensor())

        # create data loaders to feed data into our model
        batch_size = 64
        test_dataloader = DataLoader(test_data, batch_size=batch_size)

        # test the MLP model
        ndata = len(test_dataloader.dataset)
        myMLP.eval() # test mode
        correct = 0 # initialize

        with torch.no_grad():
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)
                prediction = myMLP(X)
                correct += (prediction.argmax(1) == y).type(torch.float).sum().item() # add up the correct predictions
        correct /= ndata # accuracy for all data
        print(f"Test accuracy: {(100*correct):>0.1f}%")

# Training

In [None]:
# One training epoch with seed 0xAB; the checkpoint is written to 'model.pt'.
run_mnist(training=True, save_path='model.pt', epochs=1, random_seed=0xAB, valid_ratio=0.1)

training starts!

epoch 1
------------------------------
[    0/60000]  train loss: 2.301766  valid accuracy: 12.1%, valid loss: 2.301702
[ 6400/60000]  train loss: 2.286403  valid accuracy: 29.4%, valid loss: 2.281255
[12800/60000]  train loss: 2.252296  valid accuracy: 43.5%, valid loss: 2.258803
[19200/60000]  train loss: 2.221064  valid accuracy: 52.7%, valid loss: 2.231353
[25600/60000]  train loss: 2.196161  valid accuracy: 60.3%, valid loss: 2.196320
[32000/60000]  train loss: 2.176315  valid accuracy: 62.7%, valid loss: 2.149452
[38400/60000]  train loss: 2.094090  valid accuracy: 64.8%, valid loss: 2.087826
[44800/60000]  train loss: 2.024556  valid accuracy: 67.4%, valid loss: 2.002415
[51200/60000]  train loss: 1.911689  valid accuracy: 68.6%, valid loss: 1.893471

training is finished!


In [None]:
# Two one-epoch training runs, seeded 0 and 1, each saved to its own
# checkpoint file ('model0.pt', 'model1.pt').
for i in (0, 1):
    run_mnist(training=True, save_path=f'model{i}.pt', epochs=1, random_seed=i, valid_ratio=0.1)

training starts!

epoch 1
------------------------------
[    0/60000]  train loss: 2.304593  valid accuracy: 9.9%, valid loss: 2.304173
[ 6400/60000]  train loss: 2.281907  valid accuracy: 24.6%, valid loss: 2.282630
[12800/60000]  train loss: 2.259383  valid accuracy: 46.2%, valid loss: 2.259212
[19200/60000]  train loss: 2.237567  valid accuracy: 55.7%, valid loss: 2.230620
[25600/60000]  train loss: 2.208418  valid accuracy: 61.5%, valid loss: 2.195406
[32000/60000]  train loss: 2.166166  valid accuracy: 66.3%, valid loss: 2.148903
[38400/60000]  train loss: 2.081265  valid accuracy: 67.2%, valid loss: 2.084625
[44800/60000]  train loss: 1.993808  valid accuracy: 68.5%, valid loss: 1.998226
[51200/60000]  train loss: 1.961805  valid accuracy: 71.0%, valid loss: 1.885194

training is finished!
training starts!

epoch 1
------------------------------
[    0/60000]  train loss: 2.303794  valid accuracy: 8.2%, valid loss: 2.307722
[ 6400/60000]  train loss: 2.286244  valid accuracy: 12

# Evaluation

In [None]:
# Load 'model.pt' and print its accuracy on the MNIST test split.
run_mnist(training=False, save_path='model.pt')

Test accuracy: 67.7%
