In [5]:
import numpy as np
import torch
import os

In [6]:
# Relative location of the gzip-compressed, pickled MNIST archive.
path = os.path.join(os.path.join('data', 'mnist'), 'mnist.pkl.gz')

In [9]:
# Load the raw splits from the gzip-compressed pickle at `path`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever point `path` at a trusted copy of the archive.
import pickle
import gzip

with gzip.open(path, mode='rb') as f:
    # The pickle holds three (inputs, labels) pairs. The third one is not
    # used by the visible cells -- presumably the test split, TODO confirm.
    train_split, valid_split, unknown = pickle.load(f, encoding='latin-1')
    x_train, y_train = train_split
    x_valid, y_valid = valid_split

In [25]:
# Convert the loaded arrays (presumably NumPy arrays from the pickle) into
# torch tensors; torch.tensor always copies its input.
x_train, y_train, x_valid, y_valid = (
    torch.tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
)

In [22]:
class Mnist_model(torch.nn.Module):
    """Fully connected classifier: two ReLU hidden layers and a linear output.

    The defaults (784 -> 128 -> 256 -> 10) reproduce the original hard-coded
    architecture, so ``Mnist_model()`` builds exactly the same network; the
    keyword arguments allow other layer widths.

    Args:
        n_in: number of input features, i.e. the size of the input's last
            dimension (784 presumably a flattened 28x28 image -- TODO confirm).
        n_hidden1: width of the first hidden layer.
        n_hidden2: width of the second hidden layer.
        n_out: number of output logits (one per class).
    """

    def __init__(self, n_in=784, n_hidden1=128, n_hidden2=256, n_out=10):
        super().__init__()
        # Attribute names are kept unchanged: they determine the state_dict keys.
        self.hidden1 = torch.nn.Linear(n_in, n_hidden1)
        self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)
        self.out = torch.nn.Linear(n_hidden2, n_out)

    def forward(self, x):
        """Map a batch ``x`` (last dim == n_in) to raw logits (last dim == n_out).

        No softmax is applied here; the notebook pairs this model with
        CrossEntropyLoss, which applies log-softmax internally.
        """
        x = torch.nn.functional.relu(self.hidden1(x))
        x = torch.nn.functional.relu(self.hidden2(x))
        return self.out(x)


In [26]:
# Wrap the tensors in datasets and batch them. The second pair is the
# validation split (x_valid / y_valid), not a test set.
train_ds, valid_ds = (
    torch.utils.data.TensorDataset(x_train, y_train),
    torch.utils.data.TensorDataset(x_valid, y_valid),
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=128)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=256)

In [28]:
# Mean cross-entropy over the batch; takes raw logits and integer class labels.
loss_func = torch.nn.CrossEntropyLoss(reduction='mean')

In [39]:
def get_data(train_ds, valid_ds, train_bs=128, valid_bs=256):
    """Build the training and validation DataLoaders.

    Args:
        train_ds: dataset of training examples; iterated in shuffled order.
        valid_ds: dataset of validation examples; iterated in fixed order.
        train_bs: training batch size (default 128, the original value).
        valid_bs: validation batch size (default 256, the original value;
            presumably larger because evaluation needs no gradients).

    Returns:
        Tuple ``(train_dl, valid_dl)`` of torch DataLoaders.
    """
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=train_bs, shuffle=True)
    valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=valid_bs)
    return train_dl, valid_dl

def get_model():
    """Return a newly constructed Mnist_model (fresh random weights per call)."""
    return Mnist_model()

def get_optim(model, lr=0.001):
    """Create a plain SGD optimizer over all of ``model``'s parameters.

    Args:
        model: module whose ``parameters()`` will be optimized.
        lr: learning rate (default 0.001, the original hard-coded value).

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    return torch.optim.SGD(model.parameters(), lr=lr)

def loss_batch(model, loss_func, xb, yb, opt = None):
    """Compute the loss on one batch and, if an optimizer is given, take a step.

    Args:
        model: module mapping ``xb`` to predictions.
        loss_func: callable(predictions, targets) -> scalar loss tensor.
        xb: input batch.
        yb: target batch.
        opt: optimizer; when None the call only evaluates (no backward pass).

    Returns:
        ``(loss, n)``: the batch loss as a Python number (computed before any
        parameter update) and ``len(xb)``, the number of examples.
    """
    batch_loss = loss_func(model(xb), yb)
    batch_size = len(xb)
    if opt is None:
        return batch_loss.item(), batch_size

    # Backpropagate, apply the update, then clear gradients for the next batch.
    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item(), batch_size

def fit(steps, model, loss_func, train_dl, valid_dl, opt=None):
    """Train ``model`` for ``steps`` epochs, printing validation loss after each.

    Args:
        steps: number of full passes (epochs) over ``train_dl``.
        model: the torch module to train; its parameters are updated in place.
        loss_func: callable(predictions, targets) -> scalar loss tensor.
        train_dl: iterable of ``(xb, yb)`` training batches.
        valid_dl: iterable of ``(xb, yb)`` validation batches.
        opt: optimizer to use. If None, one is created once via
            ``get_optim(model)`` and reused for the whole run.
    """
    # Build the optimizer ONCE. Re-creating it for every batch (as before)
    # discards optimizer state such as momentum buffers or Adam moments,
    # and reallocates the parameter groups on every step.
    if opt is None:
        opt = get_optim(model)

    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt=opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(*[loss_batch(model, loss_func, xb, yb)
                                 for xb, yb in valid_dl])
            # Weight each batch's mean loss by its size so a smaller final
            # batch does not skew the per-example average.
            valid_loss = np.sum(np.multiply(losses, nums) / np.sum(nums))

        print('Step: {}, loss: {}'.format(step, valid_loss))

In [40]:
# Rebuild the loaders, create a freshly initialised model, and train it for
# 25 epochs, printing the validation loss after each one.
train_dl, valid_dl = get_data(train_ds, valid_ds)
model = get_model()
fit(steps=25, model=model, loss_func=loss_func,
    train_dl=train_dl, valid_dl=valid_dl)

Step: 0, loss: 2.2936904724121097
Step: 1, loss: 2.2830302268981937
Step: 2, loss: 2.271281169128418
Step: 3, loss: 2.2573572257995607
Step: 4, loss: 2.240075008010864
Step: 5, loss: 2.21809147644043
Step: 6, loss: 2.189735460281372
Step: 7, loss: 2.1529442615509033
Step: 8, loss: 2.104997978210449
Step: 9, loss: 2.042851379776001
Step: 10, loss: 1.9635436347961428
Step: 11, loss: 1.8654697584152222
Step: 12, loss: 1.7495662075042726
Step: 13, loss: 1.6200777475357055
Step: 14, loss: 1.483881675720215
Step: 15, loss: 1.349219493675232
Step: 16, loss: 1.2235611444473267
Step: 17, loss: 1.1116140420913696
Step: 18, loss: 1.0146974981307983
Step: 19, loss: 0.9322955263137818
Step: 20, loss: 0.8629566877365112
Step: 21, loss: 0.8038945945739746
Step: 22, loss: 0.7538351017951965
Step: 23, loss: 0.7107317025184632
Step: 24, loss: 0.6738278924942016


In [41]:
# Display the training DataLoader; its default repr shows only the class and
# memory address, not the data or batch settings.
train_dl

<torch.utils.data.dataloader.DataLoader at 0x12f6fafd0>