In [19]:
'''

Problem statement:

    Build a neural network that satisfies the following requirements:

    - Made using PyTorch
    - Model should be able to recognize handwritten digits
    - Dataloader initialized from the MNIST training dataset
    - Network uses only linear (fully connected) layers
    - Implement only up to the training phase

'''

In [20]:
# library imports
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from torchvision import transforms
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

In [21]:
# dataset creation
# Seed PyTorch's global RNG so the shuffle order below (and, when the cells
# are run in order, the weight initialisation in the model cell) is
# reproducible across Restart & Run All.
SEED = 0
BATCH_SIZE = 10
torch.manual_seed(SEED)

# MNIST training split stored under the current directory; download=True
# fetches it if it is not already present. Images are converted to tensors
# with transforms.ToTensor().
train_dataset = torchvision.datasets.MNIST(root='.', train=True, transform=transforms.ToTensor(), download=True)

# dataloader initialization: mini-batches of BATCH_SIZE, reshuffled every epoch
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [22]:
# neural network definition
# Fully connected classifier: 784 inputs (one flattened 28x28 MNIST image)
# -> 10 outputs, one raw score (logit) per digit class.
# No activation follows the last Linear layer on purpose: a trailing ReLU
# would clamp every negative score to 0, zeroing the gradient for those
# outputs and preventing the network from expressing "this digit is unlikely".
model = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 128), nn.ReLU(),
    nn.Linear(128, 32), nn.ReLU(),
    nn.Linear(32, 10),
)

In [23]:
# GPU enable
# Prefer the first CUDA device when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# nn.Module.to moves the parameters and returns the module itself; as the
# cell's last expression, the model summary is displayed.
model.to(device)

In [24]:
# neural network training
#
# Loss choice: the model emits 10 scores per image and each label is a single
# class index, so this is a classification problem. nn.CrossEntropyLoss takes
# (batch, 10) scores and (batch,) integer class indices directly.
# The previous nn.MSELoss compared the (batch, 10) output against a (batch,)
# float label tensor. PyTorch broadcast the labels across the class dimension
# (emitting only a UserWarning), so the loss never measured per-sample digit
# error.

optimizer = optim.SGD(model.parameters(), lr=0.001)
e_func = nn.CrossEntropyLoss()

epochs = 32
model.train()  # explicit training mode; harmless for Linear/ReLU, required if dropout/batchnorm are added
for epoch in range(epochs):

    running_loss = 0.0
    num_batches = 0

    for images, labels in train_dataloader:

        # move the batch to the same device as the model
        images, labels = images.to(device), labels.to(device)

        # flatten each image into one row so it matches the first Linear layer's
        # input; labels are left as the integer class indices the dataloader
        # yields, which CrossEntropyLoss requires (no float cast)
        images = images.view(images.shape[0], -1)

        # forward pass
        optimizer.zero_grad()
        output = model(images)
        loss = e_func(output, labels)

        # backpropagation
        loss.backward()
        optimizer.step()

        # accumulate for the per-epoch average loss
        running_loss += loss.item()
        num_batches += 1

    # max(..., 1) avoids ZeroDivisionError if the dataloader yields no batches
    print('Loss : ', running_loss / max(num_batches, 1))

print('\n\n-------- Training complete --------\n\n')