In [149]:
import torch
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as fun
import torchvision.transforms as transforms
import numpy as np


In [150]:
# Fetch the MNIST files into data/ (download=True) so later cells can read them from disk.
dataset_1 = MNIST(
    root='data/',
    download=True,
)


In [151]:
# Test split (train=False) with no transform; the test-evaluation cell further
# down rebinds this name to a copy that uses ToTensor().
test_dataset = MNIST(root='data/', train=False)


In [152]:
from torch.utils.data import random_split

# Training split of MNIST; ToTensor() is applied to each image on access.
dataset = MNIST(root='data/', train=True, transform=transforms.ToTensor())

# Hold out a fixed-size validation subset and train on the remainder.
# The split is driven by its own seeded generator so the same samples land in
# train/validation on every run, without touching the global RNG state.
VAL_SIZE = 5000
SPLIT_SEED = 42
train_ds, val_ds = random_split(
    dataset,
    [len(dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [153]:
# Mini-batch iterators over both splits: 10 samples per batch, reshuffled on each pass.
trainloader = DataLoader(train_ds, batch_size=10, shuffle=True)
validationloader = DataLoader(val_ds, batch_size=10, shuffle=True)

In [154]:
def accuracy_fn(outputs, labels):
    """Return the fraction of rows of `outputs` whose arg-max equals `labels`.

    Args:
        outputs: Per-class scores; the prediction is the index of the largest
            value along dim 1.
        labels: Target class indices, compared element-wise with the predictions.

    Returns:
        A 0-dim tensor holding (number of matches) / (number of predictions).
    """
    predicted = outputs.argmax(dim=1)
    hits = (predicted == labels).sum().item()
    return torch.tensor(hits / predicted.shape[0])

In [155]:
class MnistModel(nn.Module):
    """Single linear layer mapping flattened inputs to per-class scores.

    Args:
        input_size: Number of features per flattened sample (default 28*28).
        num_classes: Number of scores produced per sample (default 10).
    """

    def __init__(self, input_size=28*28, num_classes=10):
        super().__init__()
        # The attribute name `linear` determines the state-dict keys
        # ('linear.weight', 'linear.bias'), so it must stay as is.
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        """Flatten `xb` to (-1, input_size) and return the linear layer's output."""
        # Take the width from the layer itself so the reshape always matches
        # the weights, instead of repeating the 784 literal.
        xb = xb.reshape(-1, self.linear.in_features)
        return self.linear(xb)
    
# Build the classifier, then restore previously trained parameters from disk.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model = MnistModel()
saved_state = torch.load('model_state_dict')
model.load_state_dict(saved_state)


<All keys matched successfully>

In [156]:
# Plain stochastic gradient descent over all model parameters, learning rate 0.001.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [157]:
# Module-level loss used by batch_loss_calc: functional cross-entropy.
loss_fn = nn.functional.cross_entropy

In [158]:
def batch_loss_calc(model, xb, yb, opt=None, accuracy_fn=None):
    """Run one batch through `model` and optionally take an optimisation step.

    The loss is computed with the module-level `loss_fn`.

    Args:
        model: Callable applied to `xb` to produce predictions.
        xb: Batch of inputs.
        yb: Batch of targets passed to `loss_fn` and `accuracy_fn`.
        opt: Optional optimiser. When given, the loss is back-propagated, the
            optimiser steps, and its gradients are cleared. When None the
            batch is only evaluated.
        accuracy_fn: Optional metric called as `accuracy_fn(preds, yb)`.

    Returns:
        Tuple `(loss value as a Python float, len(xb), metric or None)`.
    """
    training = opt is not None

    # Autograd bookkeeping is only needed when a backward pass follows, so
    # evaluation-only calls skip building the graph.
    with torch.set_grad_enabled(training):
        preds = model(xb)
        loss = loss_fn(preds, yb)

    if training:
        loss.backward()
        opt.step()
        # Cleared after the step so the next batch starts from zero gradients.
        opt.zero_grad()

    accuracy = None
    if accuracy_fn is not None:
        accuracy = accuracy_fn(preds, yb)
    return loss.item(), len(xb), accuracy




In [159]:
def validate(model, valid_dl):
    """Evaluate `model` on every batch of `valid_dl` and aggregate the results.

    Each batch goes through `batch_loss_calc` without an optimiser, using the
    module-level `accuracy_fn` as the metric.

    Args:
        model: Model handed to `batch_loss_calc`.
        valid_dl: Iterable yielding `(xb, yb)` batches.

    Returns:
        Tuple `(average loss, total sample count, average metric)`, where both
        averages are weighted by the per-batch sample counts.
    """
    per_batch = []
    for xb, yb in valid_dl:
        per_batch.append(batch_loss_calc(model, xb, yb, accuracy_fn=accuracy_fn))

    # Transpose [(loss, size, metric), ...] into three parallel tuples.
    losses, sizes, metrics = zip(*per_batch)
    total = np.sum(sizes)

    def weighted_mean(values):
        # Weight each batch by its size so a short final batch counts proportionally.
        return np.sum(np.multiply(values, sizes)) / total

    return weighted_mean(losses), total, weighted_mean(metrics)

In [160]:
def trainer(epochs, model, loss_fn, opt, trainloader, validationloader, accuracy=None):
    """Train `model` for `epochs` passes and validate after each pass.

    Args:
        epochs: Number of passes over `trainloader`.
        model: Model to optimise.
        loss_fn: Accepted but not forwarded by this function; `batch_loss_calc`
            is called without it.
        opt: Optimiser handed to `batch_loss_calc` for every training batch.
        trainloader: Iterable of `(xb, yb)` training batches.
        validationloader: Iterable passed to `validate` at the end of each epoch.
        accuracy: Optional metric function; only its `__name__` is used here,
            to label the metric in the printed progress line.

    Returns:
        List containing the validation metric recorded after each epoch.
    """
    history = []
    for epoch in range(epochs):
        # One optimisation step per training batch.
        for xb, yb in trainloader:
            batch_loss_calc(model, xb, yb, opt)

        current_loss, _, current_accuracy = validate(model, validationloader)

        if accuracy is not None:
            print(f'Epoch[{epoch+1}] , loss : {current_loss} , {accuracy.__name__} : {current_accuracy}')
        else:
            print(f'Epoch[{epoch+1}] , loss : {current_loss}')
        history.append(current_accuracy)

    return history






In [161]:
# Score the loaded model on the validation split before any further training.
# validate() returns (average loss, total sample count, average metric).
baseline_metrics = validate(model, validationloader)
batch_loss, batch_size, batch_accuracy = baseline_metrics
print(*baseline_metrics)


0.22824129552021621 5000 0.9349999940395355


In [162]:
# Two further epochs of training; the returned per-epoch accuracies are the cell output.
trainer(
    epochs=2,
    model=model,
    loss_fn=fun.cross_entropy,
    opt=optimizer,
    trainloader=trainloader,
    validationloader=validationloader,
    accuracy=accuracy_fn,
)

Epoch[1] , loss : 0.2287135124793276 , accuracy_fn : 0.9349999922513962
Epoch[2] , loss : 0.2291640617484227 , accuracy_fn : 0.9355999926328659


[0.9349999922513962, 0.9355999926328659]

In [163]:
# Five more epochs; keep the per-epoch validation accuracies for the summary below.
temp = trainer(
    epochs=5,
    model=model,
    loss_fn=fun.cross_entropy,
    opt=optimizer,
    trainloader=trainloader,
    validationloader=validationloader,
    accuracy=accuracy_fn,
)

Epoch[1] , loss : 0.2295256270878017 , accuracy_fn : 0.9353999933004379
Epoch[2] , loss : 0.22967477270495146 , accuracy_fn : 0.9355999925136567
Epoch[3] , loss : 0.22993027312215417 , accuracy_fn : 0.9353999923467636
Epoch[4] , loss : 0.23030841226503254 , accuracy_fn : 0.934999993443489
Epoch[5] , loss : 0.230582824844867 , accuracy_fn : 0.935199991941452


In [164]:
# Summarise the per-epoch validation accuracies collected in `temp`.
# The result is stored as `max_accuracy` rather than `max` so the built-in
# max() is not shadowed for the rest of the kernel session.
max_accuracy = max(temp)
average = sum(temp) / len(temp)
print(f'Max accuracy : {max_accuracy}')
print(f'Average accuracy : {average}')


Max accuracy : 0.9355999925136567
Average accuracy : 0.9353199927091598


In [165]:
# Test split with images converted to tensors so they can be fed to the model.
test_dataset = MNIST(root='data/', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=256)

# Count correct predictions batch by batch; no gradients are needed for scoring.
result_num = 0
with torch.no_grad():
    for images, labels in test_loader:
        predictions = torch.argmax(model(images), dim=1)
        result_num += (predictions == labels).sum().item()

# Divide by the actual dataset length instead of a hard-coded 10000.
print(f'Average accuracy for Test data set : {result_num/len(test_dataset)}')


Average accuracy for Test data set : 0.9256


In [166]:
# Save the model parameters: uncomment the line below to write the 'model_state_dict' checkpoint file.
#torch.save(model.state_dict(),'model_state_dict' )