In [0]:
#@title Default title text

import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision import datasets,transforms
import torch.nn as nn
import torch.optim as optim
import seaborn as sns
import mlflow
import mlflow.pytorch

#PARAMS


In [0]:
class Params(object):
    """Plain container for the run configuration used throughout the notebook.

    Attributes:
        batch_size: Value stored as given; read when the data loaders are built.
        epochs: Value stored as given.
        seed: Value stored as given.
        log_interval: Value stored as given; read by the training loop to
            decide how often to log.
    """

    def __init__(self, batch_size, epochs, seed, log_interval):
        # Store every constructor argument unchanged under the same name.
        self.batch_size, self.epochs = batch_size, epochs
        self.seed, self.log_interval = seed, log_interval


# Configuration shared by the cells below.
args = Params(batch_size=256, epochs=4, seed=0, log_interval=20)

#DATASET


In [14]:
# Pre-processing applied to every image: convert to a tensor, then normalise
# the single channel with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Train split first, then test split; both are downloaded to ../data if absent.
trainset, testset = (
    datasets.MNIST('../data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=args.batch_size, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=args.batch_size, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


 99%|█████████████████████████████████████████████████████████████████████████████▍| 9.84M/9.91M [00:44<00:00, 616kB/s]

Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz



0.00B [00:00, ?B/s]
  0%|                                                                                      | 0.00/28.9k [00:01<?, ?B/s]
 57%|███████████████████████████████████████████▋                                 | 16.4k/28.9k [00:01<00:00, 54.4kB/s]
32.8kB [00:01, 22.4kB/s]                                                                                               

Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz



0.00B [00:00, ?B/s]
  0%|                                                                                      | 0.00/1.65M [00:00<?, ?B/s]
  1%|▊                                                                            | 16.4k/1.65M [00:01<00:32, 50.6kB/s]
  2%|█▌                                                                           | 32.8k/1.65M [00:01<00:30, 52.2kB/s]
  3%|██▎                                                                          | 49.2k/1.65M [00:01<00:29, 54.6kB/s]
  4%|███                                                                          | 65.5k/1.65M [00:01<00:28, 54.9kB/s]
  5%|███▊                                                                         | 81.9k/1.65M [00:02<00:29, 52.7kB/s]
  6%|████▌                                                                        | 98.3k/1.65M [00:02<00:23, 65.4kB/s]
  6%|█████                                                                         | 106k/1.65M [00:02<00:29, 53.0kB/s]
  7%|█████▊        

Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz




0.00B [00:00, ?B/s]

  0%|                                                                                      | 0.00/4.54k [00:00<?, ?B/s]

8.19kB [00:00, 8.49kB/s]                                                                                               

Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!


#MODEL


In [0]:
class Model(nn.Module):
    """Two-layer fully connected classifier: 784 inputs -> ``nH`` -> 10 outputs.

    Args:
        nH: Width of the hidden layer (default 32).
    """

    def __init__(self, nH=32):
        super(Model, self).__init__()
        # Layers are created in forward order: Linear -> ReLU -> Linear.
        layers = [nn.Linear(784, nH), nn.ReLU(), nn.Linear(nH, 10)]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample in the batch and return the classifier output."""
        # Keep the batch dimension; collapse everything else into one axis.
        flat = x.view(x.shape[0], -1)
        return self.classifier(flat)

    

In [0]:
def train(epoch):
    """Run one pass over ``trainloader``, updating ``model`` after every batch.

    Uses the notebook-level globals ``model``, ``opt``, ``loss_fn``,
    ``trainloader`` and ``args``. Every ``args.log_interval`` batches the
    current batch loss is logged to MLflow as ``train_loss`` and a progress
    line is printed.

    Args:
        epoch: Number of the current epoch. Shown in the progress line and used
            to derive the step under which ``train_loss`` is logged.
    """
    model.train()
    batches_per_epoch = len(trainloader)

    for batch_id, (inputs, labels) in enumerate(trainloader):
        opt.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        if batch_id % args.log_interval == 0:
            batch_loss = loss.item()
            # Global batch counter, so points from successive epochs are
            # logged at increasing steps instead of all without a step.
            pos = epoch * batches_per_epoch + batch_id
            mlflow.log_metric('train_loss', batch_loss / len(inputs) * 1000, step=pos)

            print('Train Epoch: {} [{}/{} ( {:0.2f} % )] \t Loss : {:.3f}'.format(
                epoch,
                batch_id * len(inputs),
                len(trainloader.dataset),
                100. * batch_id / batches_per_epoch,
                batch_loss))

In [0]:
def test(epoch):
    """Evaluate ``model`` on ``testloader`` and log the results to MLflow.

    Uses the notebook-level globals ``model``, ``loss_fn``, ``testloader``,
    ``trainloader``, ``args`` and ``expt_id``. Logs ``test_loss`` and
    ``test accuracy`` and prints a summary line. When ``epoch`` equals
    ``args.epochs`` it also renders the confusion matrix, saves it to
    ``images/<expt_id>.png`` and logs that file as an MLflow artifact.

    Args:
        epoch: Number of the epoch that has just been trained.
    """
    import os  # local import: only needed to create the image directory

    model.eval()
    test_loss = 0
    correct = 0
    # Rows are indexed by the true label and columns by the predicted label,
    # matching the axis labels of the plot below.
    confusion_matrix = np.zeros([10, 10])

    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = model(inputs)
            test_loss += loss_fn(outputs, labels).item()
            pred = outputs.argmax(dim=1)
            correct += pred.eq(labels).sum().item()

            # np.add.at accumulates correctly when the same (true, pred) pair
            # occurs more than once in a batch.
            np.add.at(confusion_matrix, (labels.numpy(), pred.numpy()), 1)
        test_loss /= len(testloader.dataset)
        test_accuracy = 100.0 * correct / len(testloader.dataset)

        # Same step scale as the training metrics: batches seen so far.
        pos = (epoch + 1) * len(trainloader)
        mlflow.log_metric('test_loss', test_loss * 1000, step=pos)
        mlflow.log_metric('test accuracy', test_accuracy, step=pos)

        print("\n Test Set: Average Loss: {:.4f}, Accuracy : {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(testloader.dataset), test_accuracy))

        # Only the final epoch produces the confusion-matrix figure.
        if epoch == args.epochs:
            classes = np.arange(10)
            fig, ax = plt.subplots()
            im = ax.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues)
            ax.figure.colorbar(im, ax=ax)
            ax.set(xticks=np.arange(confusion_matrix.shape[1]),
                   yticks=np.arange(confusion_matrix.shape[0]),
                   xticklabels=classes, yticklabels=classes,
                   ylabel='True Label',
                   xlabel='Predicted Label',
                   title="Epoch %d" % epoch)
            # Switch the text colour on dark cells so the counts stay readable.
            thres = confusion_matrix.max() / 2
            for i in range(confusion_matrix.shape[0]):
                for j in range(confusion_matrix.shape[1]):
                    ax.text(j, i, int(confusion_matrix[i, j]),
                            ha="center", va="center",
                            color="white" if confusion_matrix[i, j] > thres else "black")
            fig.tight_layout()
            image_path = "images/%s.png" % (expt_id)
            # savefig does not create missing directories.
            image_dir = os.path.dirname(image_path)
            if image_dir and not os.path.isdir(image_dir):
                os.makedirs(image_dir)
            fig.savefig(image_path)
            # The figure is a file, so it is logged as an artifact, not a metric.
            mlflow.log_artifact(image_path)
            plt.close(fig)