In [1]:
import pickle
import os.path
import argparse
from matplotlib import  pyplot as plt
from utils import *
from cnnMNIST import CNNMNIST
import random

import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import torch.backends.cudnn as cudnn


# Filesystem locations used throughout the notebook (relative to the working directory):
#   dataDir  - root passed to the torchvision dataset loader
#   modelDir - where checkpoints are written
#   logDir   - where training logs are written
dataDir, modelDir, logDir = '../db', './model', './log'

In [2]:
# Fix the torch and Python `random` seeds so repeated runs start from the same RNG state.
torch.manual_seed(1)
random.seed(1)

# Normalisation shared by the train and test pipelines.
# (0.1307, 0.3081) are presumably the MNIST pixel mean / std - TODO confirm.
normalize = transforms.Normalize((0.1307,), (0.3081,))

# Training-time augmentation: pad each image by 4 px and take a random 28x28 crop,
# i.e. small random translations.
# NOTE: RandomHorizontalFlip is intentionally NOT used. Handwritten digits are not
# mirror-symmetric, so flipped samples are not valid members of their class and
# never occur in the (un-augmented) test set.
# reference for RandomCrop:
# https://stackoverflow.com/questions/51677788/data-augmentation-in-pytorch
transform_train = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.ToTensor(),
    normalize,
])

# Evaluation uses no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# download=False: the dataset must already be present under dataDir.
trainset = torchvision.datasets.MNIST(dataDir, train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(dataDir, train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# The raw image tensor is laid out as (samples, height, width).
nTrainSamples = trainset.data.shape[0]
nTestSamples, height, width = testset.data.shape
print(f'# train samples: {nTrainSamples} | # test samples:{nTestSamples}')
print(f'per image size: {width}*{height}')

# train samples: 60000 | # test samples:10000
per image size: 28*28


In [3]:
# Model and loss. CNNMNIST is the project-local network imported from cnnMNIST.
net = CNNMNIST()
criterion = torch.nn.CrossEntropyLoss()

# Run name; used below to derive the log file and checkpoint file names.
netname = 'mnist-cnn'

# choose optimizer: SGD with Nesterov momentum and L2 weight decay
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True)

# Logger / TrainAndTest are not imported explicitly - presumably they come from
# `from utils import *`; verify against utils.
logFilePath = f'{logDir}/{netname}'
logger = Logger(logFilePath)
checkpointPath = f'{modelDir}/{netname}-checkpoint.pth.tar'

netclf = TrainAndTest(net, trainloader, testloader,
                      criterion, optimizer, netname=netname)
# Launch the run from epoch 0 for 50 epochs.
# NOTE(review): the meaning of `vectorize=False` is defined by TrainAndTest.build - confirm there.
netclf.build(start_epoch=0, total_epochs=50, checkpointPath=checkpointPath,
             logger=logger, modelDir=modelDir, vectorize=False)