In [1]:
import os, shutil, time, sys, pdb
import numpy as np
from collections import defaultdict , OrderedDict

import torch, torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

In [2]:
from torch_trainer.engine import Model
from torch_trainer.metrics import accuracy
from torch_trainer.callbacks import CyclicalLearningRate, ModelCheckpoint
from torch_trainer.callbacks.logger import NeptuneLogger
from torch_trainer.utils.lr_finder import LRFinder
from torch_trainer.optimizers import RAdam, Lookahead

In [3]:
# Tunable hyperparameters, kept in a dict so they can be reported/logged as a group.
PARAMS = {'HIDDEN_NUM_UNITS': 128}

# Seed torch's global RNG so that weight initialisation and DataLoader shuffling
# are reproducible across a fresh "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

input_num_units = 28*28  # each 28x28 MNIST image is flattened to one vector
hidden_num_units = PARAMS['HIDDEN_NUM_UNITS']
output_num_units = 10    # one logit per digit class
batch_size = 512

In [4]:
# Mean / std passed to Normalize (the widely used MNIST pixel statistics).
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
DATA_ROOT = './'  # where torchvision stores / looks for the MNIST files

# One transform shared by both splits: tensor conversion, then standardisation.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])


def _mnist_loader(train: bool, shuffle: bool) -> DataLoader:
    """Build a DataLoader over the MNIST train (``train=True``) or test split.

    ``download=True`` lets torchvision fetch the files into DATA_ROOT if missing.
    Uses the notebook-level ``batch_size``.
    """
    dataset = torchvision.datasets.MNIST(DATA_ROOT, train=train,
                                         transform=mnist_transform,
                                         target_transform=None, download=True)
    return DataLoader(dataset, batch_size, shuffle=shuffle)


# Only the training split is shuffled; evaluation order has no reason to vary,
# and a fixed order keeps per-epoch validation batches identical.
train_iter = _mnist_loader(train=True, shuffle=True)
test_iter = _mnist_loader(train=False, shuffle=False)

In [5]:
class MnistPytorch(Model):
    """Two-layer MLP for MNIST: flatten -> Linear -> ReLU -> Linear.

    Returns unnormalised logits (no softmax), which is what the
    ``nn.CrossEntropyLoss`` used with this model expects. Layer sizes are read
    from the notebook-level ``input_num_units``, ``hidden_num_units`` and
    ``output_num_units`` globals.
    """

    def __init__(self):
        super().__init__()
        # Keep these attribute names: they become the parameter names in the
        # state_dict, so renaming them would break saved weights.
        self.linear_1 = nn.Linear(input_num_units, hidden_num_units)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(hidden_num_units, output_num_units)

    def forward(self, z):
        """Map a batch of inputs to logits of shape (N, output_num_units)."""
        # Flatten each sample to a vector of input_num_units values; assumes
        # inputs such as (N, 1, 28, 28) MNIST tensors — TODO confirm for other callers.
        flat = z.reshape(-1, input_num_units)
        hidden = self.relu(self.linear_1(flat))
        return self.linear_2(hidden)

In [6]:
# Wire up the classifier: cross-entropy over the logits, Adam with its default
# hyperparameters, and accuracy as the reported metric.
# `clipnorm=3.0` is presumably gradient-norm clipping inside torch_trainer — confirm.
criterion = nn.CrossEntropyLoss()
model = MnistPytorch()
opt = optim.Adam(model.parameters())
model.compile(
    optimizer=opt,
    loss=criterion,
    metrics=[accuracy],
    clipnorm=3.0,
)

In [None]:
# Checkpointing: the kwargs suggest only the best weights (lowest val_loss)
# are written — confirm against torch_trainer's ModelCheckpoint.
# NOTE(review): a '.h5' suffix for PyTorch weights is unusual; check what format is written.
cpkt = ModelCheckpoint(
    'test.h5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
# Cyclical learning rate between 1e-4 and 1e-2. auto_find_lr=False presumably
# means these bounds are used as given rather than searched for — confirm.
cyclic_lr = CyclicalLearningRate(
    base_lr=1e-4,
    max_lr=1e-2,
    auto_find_lr=False,
)

In [None]:
# Train for 5 epochs, with the test split passed in as validation data.
# NOTE(review): the checkpoint selects weights by val_loss, so the test set also
# drives model selection; a separate validation split would avoid that leak.
train_callbacks = [cyclic_lr, cpkt]  # same order as before
hist = model.fit(
    train_iter,
    epochs=5,
    val_dataloader=test_iter,
    callbacks=train_callbacks,
)