In [22]:
import os
import numpy as np
import pandas as pd

import torch 
import torchvision



In [23]:
train_data = torchvision.datasets.MNIST(root="./../0_sample_data/", train=True, transform=torchvision.transforms.ToTensor(), download=False)
test_data = torchvision.datasets.MNIST(root="./../0_sample_data/", train=False, transform=torchvision.transforms.ToTensor(), download=False)

batch_size=64
train_dataload = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_dataload = torch.utils.data.DataLoader(train_data, shuffle=False, batch_size=batch_size)

In [24]:
def init_weights(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        
def accuracy(y_hat, y):
    return (y_hat.argmax(axis=1) == y).sum()


def evaluate_acc(test_dataload):
    
    accs = []
    lengs = []

    for batch_X, batch_y in test_dataload:
        
        batch_y_hat = model(batch_X)
        acc = (batch_y_hat.argmax(axis=1) == batch_y).sum()
        leng = len(batch_y_hat)
        
        accs.append(acc)
        lengs.append(leng)
        
    return sum(accs)/sum(lengs)


def train_epoch(train_dataload, test_dataload):

    for batch_X, batch_y in train_dataload:
        
        batch_y_hat = model(batch_X)
        batch_loss = loss(batch_y_hat, batch_y)
        
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


def train(epoch_num):
    for i in range(epoch_num):
        train_epoch(train_dataload, test_dataload)
        accuracy = evaluate_acc(test_dataload)
        print("Epoch {}, Accuracy: {:.2f}%".format(i, accuracy*100))



In [25]:
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
    )
model.apply(init_weights)

loss = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adagrad(params=model.parameters(), lr=0.1)

train(epoch_num=10)


Epoch 0, Accuracy: 96.55%
Epoch 1, Accuracy: 97.82%
Epoch 2, Accuracy: 98.57%
Epoch 3, Accuracy: 98.91%
Epoch 4, Accuracy: 99.17%
Epoch 5, Accuracy: 99.36%
Epoch 6, Accuracy: 99.50%
Epoch 7, Accuracy: 99.57%
Epoch 8, Accuracy: 99.73%
Epoch 9, Accuracy: 99.78%


In [26]:
torch.save(model.state_dict(),"./pytorch_model.pkl")
