# This is an example of automatically training a PyTorch model

In [1]:
import os
import sys
sys.path.append('../')

import numpy
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms

from tools.modelTool import ModelTool

In [2]:
def make_layers(cfg, batch_norm=False, *, in_channels=1):
    """Build the convolutional feature extractor of a VGG-style network.

    Args:
        cfg: Sequence describing the layers in order. An integer ``v`` adds a
            3x3 convolution (padding 1) with ``v`` output channels followed by
            a ReLU (with ``nn.BatchNorm2d`` in between when ``batch_norm`` is
            true); the string ``'M'`` adds a 2x2 max-pool with stride 2.
        batch_norm: Insert ``nn.BatchNorm2d`` after every convolution.
        in_channels: Channel count of the input images. Defaults to 1
            (single-channel / grayscale input); pass 3 for RGB data.

    Returns:
        An ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        # The next convolution consumes this one's output channels.
        in_channels = v
    return nn.Sequential(*layers)


# VGG layer layouts for make_layers: an integer is a 3x3 conv with that many
# output channels, 'M' is a 2x2 max-pool. Every layout contains five pools,
# so the spatial size is divided by 2**5 = 32. These match the layouts
# torchvision uses for VGG-11 / 13 / 16 / 19.
cfgs = dict(
    # 8 conv layers
    A=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    # 10 conv layers
    B=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
       512, 512, 'M'],
    # 13 conv layers
    D=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
       512, 512, 512, 'M'],
    # 16 conv layers
    E=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
       512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)

In [3]:
# VGG with the 'A' layer layout and 10 output classes (one per MNIST digit).
model = models.VGG(make_layers(cfgs['A']), num_classes=10)
model_name = 'vgg'
file_path = '../tmp/vgg.pth'
mt = ModelTool(model, model_name, file_path)

# MNIST images are 28x28. The 'A' layout has five 2x2 max-pools, which take
# 28 -> 14 -> 7 -> 3 -> 1 and then fail on the fifth pool. Resizing to 32
# gives 32 -> 16 -> 8 -> 4 -> 2 -> 1 instead.
train_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])

train_set = datasets.MNIST(
    '../tmp/dataset/mnist',
    train=True,
    transform=train_transform,
    download=True,
)
test_set = datasets.MNIST(
    '../tmp/dataset/mnist',
    train=False,
    transform=test_transform,
    download=True,
)

batch_size = 128
# Reshuffle the training data every epoch so mini-batches differ between
# epochs; evaluation order does not matter, so the test loader stays ordered.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

In [4]:
# Train for five epochs. Judging from the captured output below, ModelTool
# saves to file_path after every epoch and also writes an extra
# "<name>-epoch-<n>.pth" snapshot at epoch 5 (presumably every `save_epoch`
# epochs — confirm in tools.modelTool).
mt.auto_train(
    train_loader,
    test_loader,
    save_epoch=5,
    epoch_max=5,
)

Epoch:  1
===> Save model to  ../tmp/vgg.pth
Epoch:  2
===> Save model to  ../tmp/vgg.pth
Epoch:  3
===> Save model to  ../tmp/vgg.pth
Epoch:  4
===> Save model to  ../tmp/vgg.pth
Epoch:  5
===> Save model to  ../tmp/vgg.pth
===> Save model to  ../tmp/vgg-epoch-5.pth
