In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [5]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
input_size = 784  # 28x28 pixels per image once flattened
num_classes = 10
batch_size = 100
# Read by the optimizer cell below; it was previously undefined, which made a
# fresh "Restart & Run All" fail with NameError. 0.001 is torch.optim.Adam's
# documented default step size.
learning_rate = 0.001

In [6]:
# MNIST training split; the files are downloaded into the current directory
# when they are not already there. ToTensor converts each image to a tensor.
train_dataset = torchvision.datasets.MNIST(
    root='.',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST test split, read from the same root directory (no download requested).
test_dataset = torchvision.datasets.MNIST(
    root='.',
    train=False,
    transform=transforms.ToTensor(),
)

# Mini-batch iterators: the training batches are reshuffled on every pass,
# the test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class FeedForwardNeuralNet(nn.Module):
    """Fully connected classifier with two hidden layers and LeakyReLU activations.

    Args:
        input_size: Number of features in each flattened input sample.
        num_classes: Number of output scores produced per sample.
        hidden_size: Width of both hidden layers. Defaults to 600, the value
            that used to be hard-coded, so existing two-argument calls build
            the same architecture as before.
    """

    def __init__(self, input_size, num_classes, hidden_size=600):
        super().__init__()
        # Attribute names and creation order are kept as-is so that saved
        # state_dicts and seeded initialisation stay compatible.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, num_classes)
        self.leakyRelu = nn.LeakyReLU()

    def forward(self, x):
        """Map inputs whose last dimension is ``input_size`` to raw class scores.

        Returns:
            Tensor whose last dimension is ``num_classes``. No activation or
            softmax is applied to the final layer.
        """
        x = self.leakyRelu(self.linear1(x))
        x = self.leakyRelu(self.linear2(x))
        # Raw scores (logits) are returned on purpose: a loss such as
        # nn.CrossEntropyLoss applies log-softmax internally.
        return self.linear3(x)

In [None]:
# Build the network, then place its parameters on the selected device.
model = FeedForwardNeuralNet(input_size, num_classes)
model = model.to(device)

# Cross-entropy loss, and an Adam optimiser over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def train_model(train_loader, num_epochs, model, loss_fn=None, opt=None, log_every=200):
    """Train ``model`` in place with mini-batch gradient descent.

    Args:
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        num_epochs: Number of full passes over ``train_loader``.
        model: Module to optimise. Batches are moved to the device of its
            first parameter before the forward pass.
        loss_fn: Callable taking ``(outputs, labels)`` and returning a scalar
            loss. Defaults to the notebook-level ``criterion``.
        opt: Optimiser that updates ``model``'s parameters. Defaults to the
            notebook-level ``optimizer``.
        log_every: Print the current batch loss every this many steps; a
            falsy value disables logging.

    Returns:
        None. The model's parameters are updated as a side effect.
    """
    # Fall back to the objects created in the setup cell so that existing
    # three-argument calls keep working unchanged.
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    # Send each batch to wherever the model lives instead of trusting a global
    # to match; the notebook-level ``device`` is only used for a model that
    # has no parameters.
    first_param = next(model.parameters(), None)
    target_device = device if first_param is None else first_param.device

    # Make sure mode-dependent layers (dropout, batch norm) are in training
    # mode, e.g. after an evaluation cell has called model.eval().
    model.train()

    total_steps = len(train_loader)
    for epoch in range(1, num_epochs + 1):
        for step, (images, labels) in enumerate(train_loader, start=1):
            # Flatten every sample to one vector; the feature count is taken
            # from the batch itself rather than a hard-coded 784.
            images = images.reshape(images.size(0), -1).to(target_device)
            labels = labels.to(target_device)

            # Forward pass
            loss = loss_fn(model(images), labels)

            # Backward pass and parameter update
            opt.zero_grad()
            loss.backward()
            opt.step()

            if log_every and step % log_every == 0:
                print(f'Epoch: {epoch} out of {num_epochs}, Step: {step} out of {total_steps}, Loss: {loss.item():.4f}')
