In [1]:
'''Import packages'''

import numpy as np
import time
import argparse
import os.path
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn as nn
import wandb ##weight and bias
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [2]:
def _load_data(DATA_PATH, batch_size):
    ## for training
    rotation = 15
    train_trans = transforms.Compose([transforms.RandomRotation(rotation),\
                                      transforms.RandomHorizontalFlip(),\
                                      transforms.ToTensor(),\
                                      transforms.Normalize((0.5), (0.5))])
    train_dataset = torchvision.datasets.MNIST(root=DATA_PATH, download=True,\
                                               train=True, transform=train_trans)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,\
                              shuffle=True, num_workers=0)
    ## for testing
    test_trans = transforms.Compose([transforms.ToTensor(),\
                                     transforms.Normalize((0.5), (0.5))])
    test_dataset = torchvision.datasets.MNIST(root=DATA_PATH,\
                                              download=True, train=False, transform=test_trans)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,\
                             shuffle=False, num_workers=0)
    
    return train_loader, test_loader

In [11]:
'''Fun: write the MLP model'''
class MLPModel(nn.Module):
    """docstring for ClassName"""
    def __init__(self,):
        super(MLPModel, self).__init__()
        ##-----------------------------------------------------------
        ## define the model architecture here
        ## MNIST image input size batch * 28 * 28 (one input channel)
        ##-----------------------------------------------------------
        
        ## Write code about three MLP layers below
        self.mlp = nn.Sequential(nn.Linear(28*28,100),
                                nn.ReLU(),
                                nn.Dropout(0.2),
                                nn.Linear(100,50),
                                nn.ReLU(),
                                nn.Linear(50,10)
                                )
    '''feed features to the model'''
    def forward(self, x):
        ## write flatten tensor code below [I have done it]
        x = torch.flatten(x,1)
        ## ---------------------------------------------------
        ## write code about MLP predict results
        ## ---------------------------------------------------
        result = self.mlp(x)
        
        return result

In [12]:
## compute accuracy of training and testing
def _compute_counts(y_pred, y_batch, mode='train'):
    return (y_pred==y_batch).sum().item()

In [13]:
def adjust_learning_rate(learning_rate, optimizer, epoch, decay):
    """initial LR decayed by 1/10 every args.lr epochs"""
    lr = learning_rate
    if (epoch > 5):
        lr = 0.001
    if (epoch >= 10):
        lr = 0.0001
    if (epoch > 20):
        lr = 0.00001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [14]:
def _save_checkpoint(ckp_path, model, epoch, optimizer, global_step):
    ## save checkpoint to ckp_path: 'checkpoint/step_100.pt'
    ckp_path = ckp_path + 'ckp_{}.pt'.format(epoch+1) 
    checkpoint = {'epoch': epoch,
                  'global_step': global_step,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict()}
    torch.save(checkpoint, ckp_path)

In [15]:
def main():
    ## choose cpu or gpu
    seed = 1
    torch.manual_seed(seed)
    ## numpy.rand(1), 1.1
    ## choose GPU id
    gpu_id = 0  ## 1, 2, 3,4
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device('cuda', gpu_id)
    else:
        device = torch.device('cpu')
    print("device: ", device)
    ## random seed for cuda
    if use_cuda:
        torch.cuda.manual_seed(72)
    
    ## initialize hyper-parameters
    num_epoches = 10
    decay = 0.01
    learning_rate = 0.0001
    batch_size = 50 #100
    ckp_path = 'checkpoint/'
    
    ## step 1: Data loader to load MNIST data
    DATA_PATH = "./data/"
    train_loader, test_loader=_load_data(DATA_PATH, batch_size)
    ##-------------------------------------------------------
    ## Step 2: load the MLP model in model.py file
    ##-------------------------------------------------------
    model =  MLPModel()
    ## load model to gpu or cpu
    model.to(device)
    
    ## --------------------------------------------------
    ## Step 3: define the Opimization method and LOSS FUNCTION: cross-entropy
    ## --------------------------------------------------
    optimizer = optim.Adam(model.parameters(),lr=learning_rate)  ## optimizer
    loss_fun = nn.CrossEntropyLoss()    ## cross entropy loss
    
    ## ---------------------------------------
    ## load checkpoint below
    ## ---------------------------------------
    
    ##  model training
    iteration = 0
    if True:
        model = model.train() ## model training
        for epoch in range(num_epoches): #10-50
            ## learning rate
            adjust_learning_rate(learning_rate, optimizer, epoch, decay)
            for batch_id, (x_batch,y_labels) in enumerate(train_loader):
                
                iteration += 1
                x_batch,y_labels = Variable(x_batch).to(device), Variable(y_labels).to(device)
                
                ## feed input data x into model
                output_y = model(x_batch)
                ##--------------------------------------------------------------
                ## Step 4: compute loss between ground truth and predicted result
                ##---------------------------------------------------------------
                loss = loss_fun(output_y, y_labels)
                
                ##----------------------------------------------
                ## Step 5: write back propagation steps below
                ##----------------------------------------------
                optimizer.zero_grad()
                loss.backward()
                optimizer.step() # update params
                
                ##---------------------------------------------------------
                ## Step 6: get the predict result and then compute accuracy
                ##---------------------------------------------------------
                y_pred = torch.argmax(output_y.data, 1)
                accy = _compute_counts(y_pred, y_labels)/batch_size
                ##----------------------------------------------------------
                ## Step 7: print loss values [I have done it]
                ##----------------------------------------------------------
                if iteration%10==0:
                    print('iter: {} loss: {}, accy: {}'.format(iteration, loss.item(), accy))
                    wandb.log({'iter': iteration, 'loss': loss.item()})
                    wandb.log({'iter': iteration, 'accy': accy})
                    
            ##---------------------------------------------------
            ##    save checkpoint below
            ##---------------------------------------------------
            _save_checkpoint(ckp_path, model, epoch, optimizer, iteration)
    
    ##------------------------------------
    ##    model testing code below
    ##------------------------------------
    total = 0
    accy_count = 0
    model.eval() ##test
    with torch.no_grad(): ## no gradient update
        for batch_id, (x_batch,y_labels) in enumerate(test_loader):
            x_batch, y_labels = Variable(x_batch).to(device), Variable(y_labels).to(device)
            ##---------------------------------------
            ## Step 8: write the predict result below
            ##---------------------------------------
            output_y = model(x_batch)
            y_pred = torch.argmax(output_y.data, 1)
            
            ##--------------------------------------------------
            ## Step 9: computing the test accuracy
            ##---------------------------------------------------
            total += len(y_labels)
            accy_count += _compute_counts(y_pred, y_labels)
    accy = accy_count/total
    print("testing accy: ", accy)
            

In [None]:
with wandb.init(project='MLP', name='MLP_demo'):
    main()

device:  cpu
iter: 10 loss: 2.2467539310455322, accy: 0.16
iter: 20 loss: 2.26986026763916, accy: 0.16
iter: 30 loss: 2.2397360801696777, accy: 0.24
iter: 40 loss: 2.248927593231201, accy: 0.22
iter: 50 loss: 2.2449376583099365, accy: 0.2
iter: 60 loss: 2.188424825668335, accy: 0.24
iter: 70 loss: 2.1766555309295654, accy: 0.32
iter: 80 loss: 2.1025702953338623, accy: 0.4
iter: 90 loss: 2.14996337890625, accy: 0.26
iter: 100 loss: 2.0548348426818848, accy: 0.38
iter: 110 loss: 2.0163094997406006, accy: 0.4
iter: 120 loss: 2.0130138397216797, accy: 0.42
iter: 130 loss: 1.9852347373962402, accy: 0.42
iter: 140 loss: 1.854904294013977, accy: 0.5
iter: 150 loss: 1.8025190830230713, accy: 0.5
iter: 160 loss: 1.8375836610794067, accy: 0.42
iter: 170 loss: 1.8970354795455933, accy: 0.4
iter: 180 loss: 1.7094659805297852, accy: 0.64
iter: 190 loss: 1.7193937301635742, accy: 0.58
iter: 200 loss: 1.8723993301391602, accy: 0.34
iter: 210 loss: 1.589950442314148, accy: 0.66
iter: 220 loss: 1.53352

iter: 1770 loss: 0.6567776203155518, accy: 0.74
iter: 1780 loss: 0.6194265484809875, accy: 0.82
iter: 1790 loss: 0.6990736126899719, accy: 0.76
iter: 1800 loss: 0.6239172220230103, accy: 0.76
iter: 1810 loss: 0.6141528487205505, accy: 0.82
iter: 1820 loss: 0.6803627610206604, accy: 0.8
iter: 1830 loss: 0.6330241560935974, accy: 0.82
iter: 1840 loss: 0.6751653552055359, accy: 0.74
iter: 1850 loss: 0.9091954231262207, accy: 0.74
iter: 1860 loss: 0.7397760152816772, accy: 0.78
iter: 1870 loss: 0.5643408298492432, accy: 0.76
iter: 1880 loss: 0.6540694236755371, accy: 0.78
iter: 1890 loss: 0.6912649273872375, accy: 0.72
iter: 1900 loss: 0.8282075524330139, accy: 0.72
iter: 1910 loss: 0.7204115986824036, accy: 0.74
iter: 1920 loss: 0.9447862505912781, accy: 0.6
iter: 1930 loss: 0.7248539924621582, accy: 0.84
iter: 1940 loss: 0.8394641280174255, accy: 0.76
iter: 1950 loss: 0.5565263628959656, accy: 0.82
iter: 1960 loss: 0.653151273727417, accy: 0.78
iter: 1970 loss: 0.666596531867981, accy: 0

iter: 3510 loss: 0.4595865309238434, accy: 0.9
iter: 3520 loss: 0.5014176368713379, accy: 0.82
iter: 3530 loss: 0.835774302482605, accy: 0.76
iter: 3540 loss: 0.4900963604450226, accy: 0.86
iter: 3550 loss: 0.7871101498603821, accy: 0.76
iter: 3560 loss: 0.4468817412853241, accy: 0.9
iter: 3570 loss: 0.8099664449691772, accy: 0.74
iter: 3580 loss: 0.53631192445755, accy: 0.82
iter: 3590 loss: 0.528367817401886, accy: 0.86
iter: 3600 loss: 0.7351015210151672, accy: 0.76
iter: 3610 loss: 0.6437206864356995, accy: 0.84
iter: 3620 loss: 0.41012537479400635, accy: 0.84
iter: 3630 loss: 0.6580483913421631, accy: 0.82
iter: 3640 loss: 0.5768964886665344, accy: 0.86
iter: 3650 loss: 0.5376125574111938, accy: 0.78
iter: 3660 loss: 0.6230846643447876, accy: 0.82
iter: 3670 loss: 0.6075738668441772, accy: 0.88
iter: 3680 loss: 0.5069350004196167, accy: 0.84
iter: 3690 loss: 0.6758291721343994, accy: 0.8
iter: 3700 loss: 0.46932682394981384, accy: 0.86
iter: 3710 loss: 0.625130295753479, accy: 0.8

iter: 5230 loss: 0.5712704658508301, accy: 0.8
iter: 5240 loss: 0.58003830909729, accy: 0.78
iter: 5250 loss: 0.5547457933425903, accy: 0.86
iter: 5260 loss: 0.3831980228424072, accy: 0.9
iter: 5270 loss: 0.42537567019462585, accy: 0.8
iter: 5280 loss: 0.5064650177955627, accy: 0.86
iter: 5290 loss: 0.4596623182296753, accy: 0.84
iter: 5300 loss: 0.43958580493927, accy: 0.8
iter: 5310 loss: 0.38412028551101685, accy: 0.86
iter: 5320 loss: 0.5923264026641846, accy: 0.78
iter: 5330 loss: 0.5382394790649414, accy: 0.82
iter: 5340 loss: 0.4790576696395874, accy: 0.84
iter: 5350 loss: 0.4102603495121002, accy: 0.9
iter: 5360 loss: 0.42622727155685425, accy: 0.84
iter: 5370 loss: 0.4234161078929901, accy: 0.9
iter: 5380 loss: 0.4767766296863556, accy: 0.82
iter: 5390 loss: 0.3377172350883484, accy: 0.92
iter: 5400 loss: 0.6707130670547485, accy: 0.8
iter: 5410 loss: 0.3497048318386078, accy: 0.9
iter: 5420 loss: 0.6402935981750488, accy: 0.8
iter: 5430 loss: 0.5620410442352295, accy: 0.78
it