In [19]:
import os

import torch
import torchvision
import torchvision.transforms as transforms

from lenet import LeNet
from vgg import VGGNet
from resnet import ResNet
from utils import Trainer

In [20]:
# configurable parameters, change as needed

# skip if loading existing model file
skip_training = True

# seed torch's global RNG so weight initialisation and the training-set
# shuffle order are reproducible across Restart & Run All
seed = 0
torch.manual_seed(seed)

# choose only from: 'lenet', 'vggnet' or 'resnet'
network_archi = 'resnet'
assert network_archi in ('lenet', 'vggnet', 'resnet'), \
    "network_archi must be 'lenet', 'vggnet' or 'resnet', got {!r}".format(network_archi)

# set data & model dir paths; the checkpoint is named after the architecture
data_dir = 'data/'
model_dir = 'models/'
save_path = os.path.join(model_dir, network_archi + '.pth')

In [21]:
# additional settings

# evaluation-only runs (skip_training) stay on the CPU; training runs use the
# first GPU when CUDA is available and fall back to the CPU otherwise
use_gpu = (not skip_training) and torch.cuda.is_available()
device_type = 'cuda:0' if use_gpu else 'cpu'

# build the torch device object and report where we are running
device = torch.device(device_type)
print("Running in {}".format(device_type))

Running in cpu


In [22]:
# initialize the network based on selected network archi type

# map each supported archi name to its model class
network_classes = {'lenet': LeNet, 'vggnet': VGGNet, 'resnet': ResNet}
if network_archi not in network_classes:
    # fail loudly instead of silently keeping a `net` left over from an earlier run
    raise ValueError("unknown network_archi {!r}; choose from {}".format(
        network_archi, sorted(network_classes)))

# presumably (in_channels, num_classes): 1 grayscale channel and the
# 10 Fashion-MNIST classes -- TODO confirm against the model constructors
net = network_classes[network_archi](1, 10)

In [23]:
# prepare dataset

# preprocessing: ToTensor converts images to tensors in [0, 1], then
# Normalize computes (x - 0.5) / 0.5, mapping pixel values into [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Fashion-MNIST ships with torchvision; both splits are downloaded into data_dir if missing
dataset_kwargs = dict(root=data_dir, download=True, transform=transform)
trainset = torchvision.datasets.FashionMNIST(train=True, **dataset_kwargs)
testset = torchvision.datasets.FashionMNIST(train=False, **dataset_kwargs)
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# mini-batch iterators: training data is reshuffled every epoch, test data keeps its order
train_batch_size = 32
test_batch_size = 8
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False)

In [24]:
# run training by calling the training method from utils

# only train when no pre-trained checkpoint is being used; the Trainer stays
# bound to `train` so the evaluation cell can reuse it
if not skip_training:
    train = Trainer(net, trainloader, testloader, device_type,
                    archi_type=network_archi)
    train.train()


In [25]:
# save the model

if not skip_training:
    # torch.save does not create missing parent directories, so make sure the
    # model directory exists before writing the checkpoint
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # persist only the learned parameters (state_dict), not the module object
    torch.save(net.state_dict(), save_path)
    print('Model saved to: {}'.format(save_path))

In [26]:
# load the model
if skip_training:
    # deserialize the checkpoint straight onto `device` (always the CPU when
    # skip_training is set); NOTE(review): torch.load unpickles the file, so
    # only load checkpoints from a trusted source
    state_dict = torch.load(save_path, map_location=device)
    net.load_state_dict(state_dict)
    print('Model loaded from: {}'.format(save_path))
    # move to the target device and switch to inference mode for evaluation
    net.to(device).eval()

Model loaded from: models/resnet.pth


In [27]:
# evaluate and compute accuracy

# when training was skipped no Trainer exists yet, so build one around the loaded
# model with the same arguments the training cell uses (device_type string and
# archi_type); otherwise reuse the Trainer created during training
if skip_training:
    train = Trainer(net, trainloader, testloader, device_type, archi_type=network_archi)
accuracy = train.test()
print('Accuracy of the network on the test images: {}'.format(accuracy))

Accuracy of the network on the test images: 0.9275
