In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.transforms import ToTensor
import torchvision.datasets as datasets
import torchvision.transforms as transform


In [2]:
class CNN_classifier(nn.Module):
    def __init__(self):
        super(CNN_classifier,self).__init__()
        self.net=nn.Sequential(nn.Conv2d(1,64,3),
                              nn.ReLU(),
                              nn.MaxPool2d((2,2),stride=2),
                              nn.Conv2d(64,128,3),
                              nn.ReLU(),
                              nn.MaxPool2d((2,2),stride=2),
                              nn.Conv2d(128,64,3),
                              nn.ReLU(),
                              nn.MaxPool2d((2,2),stride=2),
                              )
        self.classification_head = nn.Sequential(nn.Linear(64,20,bias=True),
                                                nn.ReLU(),
                                                nn.Linear(20,10,bias=True),)
    def forward(self,x):
        features=self.net(x)
        return self.classification_head(features.view(batch_size,-1))


In [3]:
train_dataset=datasets.MNIST(root="./data",download=True,train=True,transform=ToTensor())
train_loader=DataLoader(train_dataset,batch_size=50,shuffle=True)
test_dataset=datasets.MNIST(root="./data",download=True,train=False,transform=ToTensor())
test_loader=DataLoader(test_dataset,batch_size=50,shuffle=True)

In [4]:
model=CNN_classifier()
loss_fxn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
epochs=10
batch_size=50

In [5]:
total_parameters=0
for name,param in model.named_parameters():
    params=param.numel()
    total_parameters+=params
total_parameters

149798

In [6]:
patience=5
curr_patience=0
best_validation_loss=float('inf')

In [7]:
#training
for epoch in range(epochs):
    running_loss=0.0
    model.train()
    for i,data in enumerate(train_loader,0):
        inputs,labels=data[0],data[1]
        optimizer.zero_grad()
        outputs=model(inputs)
        loss=loss_fxn(outputs,labels)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        if i%100==99:
            print(f"epoch {epoch+1} iter {i+1} - > loss={running_loss/100}")
            running_loss=0.0
    #validation
    model.eval()
    with torch.no_grad():
        validation_loss=0.0
        for inputs,label in test_loader:
            inputs,labels=inputs,labels
            outputs=model(inputs)
            loss=loss_fxn(outputs,labels)
            validation_loss+=loss.item()
        validation_loss=validation_loss/len(test_loader)
        print(f"epoch {epoch+1} validation_loss={validation_loss}")
        
        if(validation_loss<best_validation_loss):
            best_validation_loss=validation_loss
            curr_patience=0
        else:
            curr_patience+=1
            if curr_patience>patience:
                print(f"early stopping {patience} reached")
print(f"training and validaion done..bvloss={best_validation_loss}")

        

epoch 1 iter 100 - > loss=2.3044142818450926
epoch 1 iter 200 - > loss=2.2972498893737794
epoch 1 iter 300 - > loss=2.2899226188659667
epoch 1 iter 400 - > loss=2.2806390953063964
epoch 1 iter 500 - > loss=2.268038468360901
epoch 1 iter 600 - > loss=2.250133376121521
epoch 1 iter 700 - > loss=2.2144155645370485
epoch 1 iter 800 - > loss=2.1560411071777343
epoch 1 iter 900 - > loss=2.039823812246323
epoch 1 iter 1000 - > loss=1.7747356545925141
epoch 1 iter 1100 - > loss=1.347001891732216
epoch 1 iter 1200 - > loss=1.044227167367935
epoch 1 validation_loss=4.460157349109649
epoch 2 iter 100 - > loss=0.8818890464305877
epoch 2 iter 200 - > loss=0.7391111186146736
epoch 2 iter 300 - > loss=0.6315556785464287
epoch 2 iter 400 - > loss=0.5685233515501023
epoch 2 iter 500 - > loss=0.47973887860774994
epoch 2 iter 600 - > loss=0.4769908121228218
epoch 2 iter 700 - > loss=0.4165441270172596
epoch 2 iter 800 - > loss=0.3826052460074425
epoch 2 iter 900 - > loss=0.36228183552622795
epoch 2 iter 