In [1]:
import torch

class TinyModel(torch.nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_dim, output_dim)
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


In [2]:
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import Dataset, DataLoader

class MNISTFlattenData(Dataset):
    
    def __init__(self, path):
        mnist_trainset = datasets.MNIST(root=path, train=True, download=True, transform=None)
        self.num_samples = len(mnist_trainset)
        self.X = np.stack([np.array(mnist_trainset[i][0]).flatten() for i in range(self.num_samples)], axis=0) / 255.
        self.X = self.X.astype(np.float32)
        self.Y = [mnist_trainset[i][1] for i in range(self.num_samples)]
        
    def __getitem__(self, index):
        return self.X[index], self.Y[index]
    
    def __len__(self):
        return self.num_samples
    
dataset = MNISTFlattenData('./data')

In [8]:
def accuracy(model, dataset):
    pred_ind = np.argmax(model(torch.from_numpy(dataset.X)).detach().cpu().numpy(), axis=1)
    return (pred_ind == dataset.Y).mean()


def train_one_epoch(model, dataloader, optimizer, loss_fn, epoch_index):
    
    running_loss = 0.
    num_samples = 0
    for i, data in enumerate(dataloader):
        # Preparing the step
        inputs, labels = data
        optimizer.zero_grad()
        # Performing forward and backward pass
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Gather data
        running_loss += loss.item() * inputs.shape[0]
        num_samples += inputs.shape[0]
        
    last_loss = running_loss / num_samples # loss per batch
    return last_loss


def train(model, dataset, batch_size, learning_rate, num_epochs):
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    cur_accuracy = accuracy(model, dataset)
    print(f'Iteration #{0}: Accuracy={cur_accuracy:.3f}.')
    for epoch_index in range(num_epochs):
        cur_epoch_loss = train_one_epoch(model, iter(dataloader), optimizer, loss_fn, epoch_index)
        cur_accuracy = accuracy(model, dataset)
        print(f'Iteration #{epoch_index + 1}: Loss={cur_epoch_loss:.5f}, Accuracy={cur_accuracy:.3f}.')

In [12]:
tinymodel = TinyModel(784, 64, 10)
train(tinymodel, dataset, 20, 0.02, 20)

Iteration #0: Accuracy=0.107.
Iteration #1: Loss=2.01160, Accuracy=0.843.
Iteration #2: Loss=1.95965, Accuracy=0.859.
Iteration #3: Loss=1.95337, Accuracy=0.872.
Iteration #4: Loss=1.95098, Accuracy=0.888.
Iteration #5: Loss=1.94807, Accuracy=0.901.
Iteration #6: Loss=1.94931, Accuracy=0.912.
Iteration #7: Loss=1.94627, Accuracy=0.919.
Iteration #8: Loss=1.94378, Accuracy=0.930.
Iteration #9: Loss=1.94427, Accuracy=0.934.
Iteration #10: Loss=1.94199, Accuracy=0.939.
Iteration #11: Loss=1.94057, Accuracy=0.945.
Iteration #12: Loss=1.94162, Accuracy=0.946.
Iteration #13: Loss=1.94259, Accuracy=0.950.
Iteration #14: Loss=1.93948, Accuracy=0.951.
Iteration #15: Loss=1.94082, Accuracy=0.954.
Iteration #16: Loss=1.93969, Accuracy=0.957.
Iteration #17: Loss=1.94010, Accuracy=0.957.
Iteration #18: Loss=1.93938, Accuracy=0.959.
Iteration #19: Loss=1.93866, Accuracy=0.960.
Iteration #20: Loss=1.93880, Accuracy=0.961.
