<a href="https://colab.research.google.com/github/AakashAhuja30/Pytorch/blob/main/CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transform
import numpy as np

In [2]:
class CNN(nn.Module):
  def __init__(self, in_channels=1, num_classes=10):
    super(CNN,self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1,out_channels=8,kernel_size=(3,3),stride=(1,1),padding=(1,1))
    self.pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
    self.conv2 = nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3,3),stride=(1,1),padding=(1,1))
    self.fc1= nn.Linear(16*7*7, num_classes)

  def forward(self,x):
    x = F.relu(self.conv1(x))
    x = self.pool(x)
    x = F.relu(self.conv2(x))
    x = self.pool(x)
    x = x.reshape(x.shape[0],-1)
    x = self.fc1(x)

    return x

In [3]:
def save_checkpoint(state, filename = 'mycheckpoint.pth.tar'):
  print("Saving Checkpoint")
  torch.save(state,filename)

def load_checkpoint(checkpoint):
  print('Loading Checkpoint')
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])

In [4]:
#Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [5]:
#Hyperparameters
input_channels = 1
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 10
load_model = True

In [6]:
#Load Data
train_dataset = datasets.MNIST(root='dataset/',train = True, transform=transform.ToTensor(), download=True)
train_loader = DataLoader(dataset = train_dataset , batch_size=batch_size,shuffle=True)
test_dataset = datasets.MNIST(root='dataset/',train = False, transform=transform.ToTensor(), download=True)
test_loader = DataLoader(dataset = test_dataset , batch_size=batch_size,shuffle=True)

In [7]:
model= CNN().to(device)
#Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=learning_rate)

In [8]:
def check_accuracy(loader,model):
  if loader.dataset.train:
    print('Checking accuracy on Train data')
  else:
    print('Checking accuracy on Test data')

  num_samples =  0
  num_correct = 0
  model.eval()

  with torch.no_grad():
    for x,y in loader:
      x = x.to(device = device)
      y = y.to(device = device)

      scores= model(x)
      _,pred=torch.max(scores,1)
      num_correct += (pred ==y).sum()
      num_samples += pred.size(0)

    print(f'Accuracy is :{(num_correct/num_samples)*100}')

  model.train()  

In [9]:
if load_model:
  load_checkpoint(torch.load('mycheckpoint.pth.tar'))

for epoch in range(num_epochs):

  losses = []
  checkpoint = {'state_dict': model.state_dict(),'optimizer':optimizer.state_dict()}

  if epoch %3 ==0:
    save_checkpoint(checkpoint)


  for batch_idx, (data,target) in enumerate(train_loader):
    data = data.to(device = device)
    target = target.to(device = device)
    
    #Forward pass
    output = model(data)
    loss = criterion(output,target)
    losses.append(loss.item())

    #Backward pass
    loss.backward()

    #Gradient descent
    optimizer.step()
    optimizer.zero_grad()

  loss_mean = float(np.mean(losses))
  print(f'epoch{epoch+1} , Loss: {loss_mean}')
  check_accuracy(train_loader,model)
  check_accuracy(test_loader,model)

Loading Checkpoint
Saving Checkpoint
epoch1 , Loss: 0.0383143088793251
Checking accuracy on Train data
Accuracy is :98.78333282470703
Checking accuracy on Test data
Accuracy is :98.43999481201172
epoch2 , Loss: 0.034775494160301894
Checking accuracy on Train data
Accuracy is :99.07167053222656
Checking accuracy on Test data
Accuracy is :98.72999572753906
epoch3 , Loss: 0.032180779007958994
Checking accuracy on Train data
Accuracy is :99.20833587646484
Checking accuracy on Test data
Accuracy is :98.64999389648438
Saving Checkpoint
epoch4 , Loss: 0.028942582640312374
Checking accuracy on Train data
Accuracy is :99.03333282470703
Checking accuracy on Test data
Accuracy is :98.5199966430664
epoch5 , Loss: 0.0272563511928256
Checking accuracy on Train data
Accuracy is :99.23999786376953
Checking accuracy on Test data
Accuracy is :98.64999389648438
epoch6 , Loss: 0.025852051654099317
Checking accuracy on Train data
Accuracy is :99.41500091552734
Checking accuracy on Test data
Accuracy is :98