# MNIST classifier using logistic regression

In [21]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import cv2 as cv
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
import numpy as np

### Dataloaders


In [28]:
### Downloading the data ###
# Downloads the MNIST dataset and converts the PIL images in it to tensors.
dataset = MNIST(root='data/', download=True, transform=transforms.ToTensor())
# NOTE(review): despite its name this requests the train=False split and applies no
# transform; nothing visible in this notebook uses it. Kept so existing references
# to `train_dataset` still resolve.
train_dataset = MNIST(root='data', train=False, download=True)

### Splitting the validation and training sets ###
from torch.utils.data import random_split
from torch.utils.data import DataLoader

batch_size = 128
# Total number of images is 60k: split into 50k for training and 10k for validation.
train_ds, val_ds = random_split(dataset, [50000, 10000])
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# Validation only aggregates metrics, so reshuffling buys nothing; a fixed order keeps
# the batch composition (and therefore the averaged metrics) reproducible across epochs.
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

### Helper functions and variables

In [32]:
import torch.nn as nn
# Size constants shared by the model definition below.
num_classes = 10       # number of output classes
input_size = 28 * 28   # length of one flattened 28x28 image

In [38]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of target class indices, one per row of ``outputs``.

    Returns:
        A 0-dim tensor holding ``correct / number_of_rows``.
    """
    predicted = outputs.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [33]:
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train ``model`` for ``epochs`` passes and collect per-epoch validation results.

    Args:
        epochs: number of passes over ``train_loader``.
        lr: learning rate handed to ``opt_func``.
        model: object exposing ``parameters()``, ``training_step(batch)`` and
            ``epoch_end(epoch, result)``, plus whatever ``evaluate`` calls on it.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, passed to ``evaluate``.
        opt_func: optimizer factory called as ``opt_func(model.parameters(), lr)``.

    Returns:
        A list with one ``evaluate`` result per epoch, in epoch order.
    """
    optimizer = opt_func(model.parameters(), lr)
    epoch_results = []

    for epoch_idx in range(epochs):
        # Training phase: one optimizer update per batch, gradients cleared afterwards.
        for batch in train_loader:
            model.training_step(batch).backward()
            optimizer.step()
            optimizer.zero_grad()

        # Validation phase: measure, report, record.
        metrics = evaluate(model, val_loader)
        model.epoch_end(epoch_idx, metrics)
        epoch_results.append(metrics)

    return epoch_results

In [34]:
def evaluate(model, val_loader):
    """Run ``model`` over every batch in ``val_loader`` and aggregate the results.

    Args:
        model: object exposing ``validation_step(batch)`` and
            ``validation_epoch_end(outputs)``.
        val_loader: iterable of batches.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the list of
        per-batch ``validation_step`` results.
    """
    # Evaluation never backpropagates, so disable autograd: otherwise each batch
    # builds a graph that stays alive for as long as its loss tensor is held in
    # the list below.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

### Model

In [36]:
class MnistModel(nn.Module):
    """Logistic-regression classifier: a single linear layer over flattened inputs.

    Besides ``forward``, the class provides the hooks the training helpers call:
    ``training_step``, ``validation_step``, ``validation_epoch_end`` and
    ``epoch_end``.
    """

    def __init__(self, in_features=None, out_classes=None):
        """Create the linear layer.

        Args:
            in_features: length of one flattened input; defaults to the
                module-level ``input_size``.
            out_classes: number of output scores; defaults to the module-level
                ``num_classes``.
        """
        super().__init__()
        # Resolve the defaults at call time so ``MnistModel()`` keeps following
        # the module-level constants.
        if in_features is None:
            in_features = input_size
        if out_classes is None:
            out_classes = num_classes
        self.linear = nn.Linear(in_features, out_classes)

    def forward(self, xb):
        """Flatten ``xb`` to ``(-1, in_features)`` and return the linear scores."""
        # Use the layer's own width rather than a hard-coded 784 so the reshape
        # always agrees with the layer that consumes it.
        xb = xb.reshape(-1, self.linear.in_features)
        return self.linear(xb)

    def training_step(self, batch):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)                   # Generate predictions
        return F.cross_entropy(out, labels)  # Calculate loss

    def validation_step(self, batch):
        """Return ``{'val_loss', 'val_acc'}`` tensors for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        # The loss is only reported, never backpropagated: detach it so the
        # stored value does not keep the autograd graph alive.
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results from ``validation_step``.

        Every batch is weighted equally, regardless of how many samples it held.

        Returns:
            ``{'val_loss': float, 'val_acc': float}``.
        """
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()  # Combine losses
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()    # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the validation loss and accuracy recorded for ``epoch``."""
        print(f"Epoch [{epoch}], val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
    
# Create the classifier instance that the training and prediction cells below operate on.
model = MnistModel()

### Running the model

In [39]:
# Train for 20 epochs at learning rate 0.001 (fit's default optimizer is SGD) and keep
# the per-epoch validation results returned by fit.
history1 = fit(20, 0.001, model, train_loader, val_loader)

In [40]:
# Bare expression: shows the recorded per-epoch results as the notebook cell's output.
history1

In [47]:
# Validation accuracy of every epoch, in training order.
accuracies = [epoch_result['val_acc'] for epoch_result in history1]

# One 'x' marker per epoch, joined by a line.
plt.plot(accuracies, '-x')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.title('Acc vs epochs plot')

### Testing out the model

In [48]:
# Load the train=False split with images converted to tensors. No download flag is passed
# here, so this presumably relies on the files already fetched by the earlier
# download=True calls.
test_dataset = MNIST(root='data/',train=False,transform=transforms.ToTensor())


In [49]:
def predict_image(img, model):
    """Return the index of the highest-scoring class ``model`` gives for one image.

    Args:
        img: a single, unbatched input tensor.
        model: callable mapping a batch of inputs to per-class scores.

    Returns:
        The winning class index as a Python int.
    """
    # The model works on batches, so wrap the image in a batch of one.
    scores = model(img.unsqueeze(0))
    return scores.argmax(dim=1)[0].item()

In [50]:
# Take the first test sample, show it, and compare its label with the model's prediction.
img, label = test_dataset[0]
# img[0] drops the leading dimension (presumably the channel axis) so imshow gets a 2-D array.
plt.imshow(img[0], cmap='gray')
print('Label:', label, ', Predicted:', predict_image(img, model))

In [51]:
# Same check for the test sample at index 20.
img, label = test_dataset[20]
# img[0] drops the leading dimension (presumably the channel axis) so imshow gets a 2-D array.
plt.imshow(img[0], cmap='gray')
print('Label:', label, ', Predicted:', predict_image(img, model))

In [52]:
# Evaluate the trained model on the whole test set, in batches of 256 and without shuffling.
test_loader = DataLoader(test_dataset, batch_size=256)
result = evaluate(model, test_loader)
# Bare expression: shows the loss/accuracy dict as the notebook cell's output.
result

### Saving the weights

In [54]:
# Persist only the model's state dict (not the whole module object) to the working directory.
torch.save(model.state_dict(), './mnist-logistic.pth')

In [58]:
# Build a second model and evaluate it before loading anything, as a baseline to compare
# against the result obtained after the saved weights are restored.
model2 = MnistModel()
result1 = evaluate(model2, test_loader)
print(result1)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model2.load_state_dict(torch.load('./mnist-logistic.pth'))
# Rebuilds a loader identical to the one created earlier, then evaluates the restored model.
test_loader = DataLoader(test_dataset, batch_size=256)
result = evaluate(model2, test_loader)
# Bare expression: shows the restored model's loss/accuracy dict as the cell's output.
result                       