In [None]:
import pickle
import os.path
import argparse
from matplotlib import  pyplot as plt
from utils import *

import torch
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import torch.backends.cudnn as cudnn


# Project root; every working directory used by this notebook hangs off it.
projPath = '.'
# In order: dataset root handed to torchvision, folder passed to build() as
# modelDir (also used for checkpoint paths), and folder for Logger files.
dataDir, modelDir, logDir = (f'{projPath}/{sub}' for sub in ('db', 'model', 'log'))
# # General setups
# parser = argparse.ArgumentParser(description='PyTorch Training')
# parser.add_argument('--total_epochs', '-e', default=10, type=int, help='total training epoch')
# parser.add_argument('--batch_size', '-b', default=32, type=int, help='batch size')
# parser.add_argument('--checkpoint', '-c', default=None, type=str, help='resume from checkpoint')
# args = parser.parse_args()
# batch_size = args.batch_size
# total_epochs = args.total_epochs

# MNIST
## Data preparation

In [2]:
# Preprocessing applied to every MNIST image: tensor conversion followed by
# normalisation with mean 0.5 and std 0.5.
# The statistics are written as 1-tuples, (0.5,): a bare (0.5) is just the
# float 0.5, whereas Normalize documents a sequence with one entry per channel.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# download=False: the MNIST files must already be present under dataDir.
trainset = torchvision.datasets.MNIST(dataDir, train=True,  download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(dataDir, train=False,  download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# The raw array is unpacked as (samples, rows, columns), i.e. height comes
# before width. The printed text is unchanged because the images are square.
nTrainSamples, height, width = trainset.data.shape
nTestSamples, height, width = testset.data.shape
print(f'# train samples: {nTrainSamples} | # test samples:{nTestSamples}')
print(f'per image size: {width}*{height}')

# train samples: 60000 | # test samples:10000
per image size: 28*28


## Train neural nets

In [5]:
# Tag shared by the log file and the checkpoint file of this run.
netname = 'mnist-cnn'
logFilePath = f'{logDir}/{netname}'
# Built for convenience only: build() below is called with checkpointPath=None.
checkpointPath = f'{modelDir}/{netname}-checkpoint.pth.tar'

# Model, optimiser (SGD, momentum 0.9, weight decay 5e-4) and loss.
net = CNNMNIST()
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)
criterion = torch.nn.CrossEntropyLoss()
logger = Logger(logFilePath)

# Hand everything to the project's TrainAndTest helper and run 20 epochs
# starting from epoch 0.
netclf = TrainAndTest(
    net,
    trainloader,
    testloader,
    criterion,
    optimizer,
    netname=netname,
)
netclf.build(
    start_epoch=0,
    total_epochs=20,
    checkpointPath=None,
    logger=logger,
    modelDir=modelDir,
    vectorize=False,
)

# CIFAR-10

In [2]:
# Preprocessing applied to every CIFAR-10 image: tensor conversion followed
# by per-channel normalisation with mean 0.5 and std 0.5 on all three channels.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# download=False: the CIFAR-10 files must already be present under dataDir.
trainset = torchvision.datasets.CIFAR10(dataDir, train=True,  download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(dataDir, train=False,  download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

# The raw array is unpacked as (samples, rows, columns, channels), following
# torchvision's HWC layout for CIFAR-10, so height comes before width. The
# printed text is unchanged because the images are square.
nTrainSamples, height, width, channel = trainset.data.shape
nTestSamples, height, width, channel = testset.data.shape
print(f'# train samples: {nTrainSamples} | # test samples:{nTestSamples}')
print(f'per image size: {width}*{height} | per image channel:{channel}')

# train samples: 50000 | # test samples:10000
per image size: 32*32 | per image channel:3


In [4]:
# Tag shared by the log file and the checkpoint file of this run.
netname = 'cifar10-cnn'
logFilePath = f'{logDir}/{netname}'
# Only used by the disabled alternative call at the bottom of this cell.
checkpointPath = f'{modelDir}/{netname}-checkpoint.pth.tar'

# Model, optimiser (SGD, momentum 0.9, weight decay 5e-4) and loss.
net = CNNCIFAR10()
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=5e-4,
)
criterion = torch.nn.CrossEntropyLoss()
logger = Logger(logFilePath)

# Hand everything to the project's TrainAndTest helper and run 50 epochs
# starting from epoch 0, without a checkpoint.
netclf = TrainAndTest(
    net,
    trainloader,
    testloader,
    criterion,
    optimizer,
    netname=netname,
)
netclf.build(
    start_epoch=0,
    total_epochs=50,
    checkpointPath=None,
    logger=logger,
    modelDir=modelDir,
    vectorize=False,
)
# Disabled alternative that passes checkpointPath with 70 total epochs
# (presumably resumes from the saved checkpoint -- confirm against
# TrainAndTest.build):
# netclf.build(start_epoch=0, total_epochs=70, checkpointPath=checkpointPath,
#                            logger=logger, modelDir=modelDir, vectorize=False)

In [5]:
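# Disabled smoke test: push one small batch through a fresh CNNCIFAR10.
# NOTE: uncommenting this rebinds `trainloader` and `net` from the cells above.
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
#                                           shuffle=True, num_workers=2)
# x = next(iter(trainloader))
# net = CNNCIFAR10()
# net(x[0])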