In [177]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn

In [178]:
# Report whether PyTorch can see a CUDA-capable GPU.
print(torch.cuda.is_available())
# Pick the GPU only when one is actually available; otherwise fall back to the
# CPU so that anything placed on `device` still works on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

True


In [179]:
# Preprocessing applied to every image when it is read from the dataset.
transform = transforms.ToTensor()

# MNIST training and test splits, stored under ./data (fetched if not already there).
training = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
testing = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

In [180]:
from torch.utils.data import random_split

# Hold out 10,000 samples of `training` as a dev set and keep 50,000 for
# training (the two lengths must add up to len(training)).
# A dedicated, seeded generator makes the split identical on every
# Restart & Run All, so dev-set accuracy is comparable between runs.
split_generator = torch.Generator().manual_seed(0)
train_set, dev_set = random_split(training, [50000, 10000], generator=split_generator)

In [181]:
# Mini-batch iterators (32 samples per batch) over the three splits.
# Only the training split is shuffled; the evaluation splits are read in order.
training_loader = DataLoader(dataset=train_set, batch_size=32, shuffle=True)
dev_loader = DataLoader(dataset=dev_set, batch_size=32, shuffle=False)
testing_loader = DataLoader(dataset=testing, batch_size=32, shuffle=False)

In [182]:
class NeuralNetwork(nn.Module):
    """Fully connected feed-forward network with ReLU between consecutive layers.

    Args:
        input_size: Number of input features per sample.
        hidden_layers: Sizes of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...). It may be empty, in which case
            the network is a single linear layer from input to output.
        output_size: Number of output units.

    Attributes:
        layers: List of all layer widths, ``[input_size, *hidden_layers, output_size]``.
        network_layers: ``nn.ModuleList`` with one ``nn.Linear`` per pair of
            consecutive widths in ``layers``.
        relu: The activation applied after every layer except the last.
    """

    def __init__(self, input_size, hidden_layers, output_size):
        super().__init__()
        # list(...) so tuples and other iterables work too; `[..] + tuple` would raise TypeError.
        self.layers = [input_size] + list(hidden_layers) + [output_size]
        self.network_layers = nn.ModuleList()
        self.relu = nn.ReLU()

        # One linear layer for every pair of consecutive widths.
        for fan_in, fan_out in zip(self.layers[:-1], self.layers[1:]):
            self.network_layers.append(nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Map ``x`` (last dimension of size ``input_size``) to the raw output-layer values.

        No activation is applied after the final layer.
        """
        # Every layer except the output layer is followed by ReLU.
        for layer in self.network_layers[:-1]:
            x = self.relu(layer(x))
        # Output layer: returned without activation.
        return self.network_layers[-1](x)

In [183]:
# 784 input features -> one hidden layer of 256 units -> 10 outputs.
# .to(device) rather than a hard-coded .cuda() so that model placement follows
# the notebook's `device` setting.
net = NeuralNetwork(784, [256], 10).to(device)
print(net.layers)

[784, 256, 10]


In [184]:
def dataset_accuracy(model, loader):
    """Return the fraction of samples from ``loader`` that ``model`` classifies correctly.

    Each batch is flattened to ``(batch, features)``, moved to the device the
    model's parameters live on, and scored by comparing the arg-max of the
    model output with the labels. Evaluation runs in eval mode with gradients
    disabled; the model's previous train/eval mode is restored afterwards,
    even if evaluation raises.

    Args:
        model: An ``nn.Module`` that maps a ``(batch, features)`` tensor to
            per-class scores of shape ``(batch, classes)``.
        loader: An iterable that yields ``(X, y)`` batches, where ``y`` holds
            one integer class label per sample.

    Returns:
        Accuracy as a float in ``[0, 1]``, computed over the samples actually
        yielded by ``loader``. Returns ``0.0`` if ``loader`` yields nothing.
    """
    # Evaluate wherever the model lives instead of assuming a GPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    # Remember the current mode so it can be restored instead of forcing train mode.
    was_training = model.training
    model.eval()

    correct_predictions = 0
    total_samples = 0
    try:
        with torch.no_grad():
            for X, y in loader:
                # Flatten everything after the batch dimension into one feature axis.
                X = X.flatten(start_dim=1)
                if model_device is not None:
                    X = X.to(model_device)
                    y = y.to(model_device)

                output = model(X)
                correct_predictions += (output.argmax(dim=1) == y).sum().item()
                # Count what was actually seen: stays correct with samplers,
                # drop_last, or plain lists of batches.
                total_samples += y.numel()
    finally:
        model.train(was_training)

    if total_samples == 0:
        return 0.0
    return correct_predictions / total_samples

In [185]:
import torch.optim

# Multi-class classification loss on the raw network outputs, optimised with Adam.
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)


epochs = 10
for epoch in range(epochs):

    # Put the model in training mode explicitly at the start of every epoch
    # rather than relying on what the evaluation helper leaves behind.
    net.train()

    epoch_loss = 0
    correct_predictions = 0

    for X, y in training_loader:

        # Flatten each image to a 784-long vector and place the batch on the
        # notebook's `device` instead of a hard-coded .cuda().
        X = X.reshape(-1, 784).to(device)
        y = y.to(device)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        output = net(X)
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()

        # Running totals for this epoch's statistics.
        epoch_loss += loss.item()
        correct_predictions += torch.sum(torch.max(output, dim=1)[1] == y).item()

    # Loss is averaged over batches; accuracy is averaged over samples.
    loss_score = epoch_loss/len(training_loader)
    training_accuracy = correct_predictions/len(training_loader.dataset)
    dev_accuracy = dataset_accuracy(net, dev_loader)

    print(f"epoch: {epoch+1}  loss = {loss_score}, training accuracy = {training_accuracy}, dev accuracy = {dev_accuracy}")

# Final check on the held-out test split, once training has finished.
print('------------------')
print(f'Test set accuracy = {dataset_accuracy(net, testing_loader)}')

epoch: 1  loss = 0.2744993265618118, training accuracy = 0.92132, dev accuracy = 0.9579
epoch: 2  loss = 0.11222708852239482, training accuracy = 0.96638, dev accuracy = 0.9672
epoch: 3  loss = 0.07280453393814974, training accuracy = 0.97858, dev accuracy = 0.9748
epoch: 4  loss = 0.05470272211834785, training accuracy = 0.98312, dev accuracy = 0.9767
epoch: 5  loss = 0.039427189923428425, training accuracy = 0.98774, dev accuracy = 0.9785
epoch: 6  loss = 0.02934659968609731, training accuracy = 0.99094, dev accuracy = 0.9784
epoch: 7  loss = 0.022920542332746034, training accuracy = 0.99298, dev accuracy = 0.9779
epoch: 8  loss = 0.017281454610703118, training accuracy = 0.99518, dev accuracy = 0.9808
epoch: 9  loss = 0.015822990708298176, training accuracy = 0.99516, dev accuracy = 0.9804
epoch: 10  loss = 0.011320926309271529, training accuracy = 0.99646, dev accuracy = 0.979
------------------
Test set accuracy = 0.9776


In [None]:
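# NOTE: disabled, unfinished draft of an MNIST visualisation helper. Before re-enabling:
#   - a DataLoader cannot be indexed (`loader[choice]`); index the underlying dataset instead
#   - indices are drawn from 1..59999, but the call below passes the test split, which is smaller
#   - `imshow` is given the flattened (1, 784) tensor, which will not render as a 28x28 image
#   - every subplot shows the last `X` rather than `features_list[i]`; `labels_list` is unused
# import matplotlib.pyplot as plt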
# import numpy as np

# def visualize_mnist(loader, num_images = 5):

#     random_choices = np.random.randint(1, 60000, size=num_images)

#     features_list = []
#     labels_list = []

#     for choice in random_choices:
#         X, y = loader[choice]
#         X = X.reshape(-1, 784)

#         features_list.append(X)
#         labels_list.append(y)

#     fig, axes = plt.subplots(2, num_images, figsize=(20, 6))
#     for i in range(num_images):
#         axes[0, i].imshow(X, cmap='gray')

# visualize_mnist(testing_loader)