<a href="https://colab.research.google.com/github/Pranav-Bhatlapenumarthi/Deep_Learning/blob/main/NN_Optimisers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm # to show a progress bar
import random

import torch
import torch.nn as nn
import torch.optim as optim


import torch.nn.functional as F # provides functions that are used in building NN layers
from torch.utils.data import DataLoader # helps load data in batches
from torchvision import datasets, transforms # used to access datasets and apply the necessary transformations to images

!pip install torchmetrics
import torchmetrics # provides performance metrics for model evaluation

In [17]:
def set_seed(seed=42):
    """Seed every random number generator this notebook uses.

    The same ``seed`` is handed to PyTorch (CPU generator and all CUDA
    devices), NumPy's legacy global generator and Python's ``random`` module.
    The cuDNN backend flags are then set to ``deterministic=True`` and
    ``benchmark=False``.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs with the default seed before anything stochastic is created.
set_seed()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [18]:
# Mini-batch size used for training. The training loader below reads this
# variable (it previously hard-coded 64, so changing the value here had no effect).
batch_size = 100

# Training split: downloaded into datasets/ on first use, images converted to
# tensors, and reshuffled on every pass over the loader.
train_dataset = datasets.MNIST(root='datasets/', train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: no gradients are needed at evaluation time, so a larger fixed
# batch is used, and the order is kept stable (no shuffling).
test_dataset = datasets.MNIST(root='datasets/', train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [19]:
class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10 logits per image.

    ``features`` holds two Conv2d -> ReLU -> MaxPool2d stages (1 -> 32 -> 64
    channels, 3x3 kernels with padding 1, 2x2 pooling). ``classifier``
    flattens the result and applies Linear(64 * 7 * 7, 128) -> ReLU ->
    Linear(128, 10). The 64 * 7 * 7 input size of the first linear layer
    corresponds to a 28x28 single-channel input halved twice by pooling.
    """

    def __init__(self):
        super().__init__()

        # Convolutional stages are created first, then the dense head, so
        # parameters are initialised in the same order as they appear.
        stage_one = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        stage_two = [
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*stage_one, *stage_two)

        head = [
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the raw class logits for the image batch ``x``."""
        feature_maps = self.features(x)
        logits = self.classifier(feature_maps)
        return logits


In [20]:
# Build the network, move its parameters to the selected device, and print
# the layer structure as a sanity check.
model = CNN()
model = model.to(device)
print(model)


CNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=3136, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)


In [21]:
# Cross-entropy is applied to the raw logits returned by the model.
loss_function = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [22]:
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches. It does not
            need to define ``__len__``; batches are counted while iterating.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            does not depend on a notebook-global ``device`` variable. If the
            model has no parameters, batches are left where they are.

    Returns:
        The unweighted mean of the per-batch losses as a float (a monitoring
        metric), or ``nan`` when ``loader`` yields no batches.
    """
    if device is None:
        # Follow the model's own placement instead of relying on global state.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.train()  # Switches to training mode.
    total_loss = 0.0
    num_batches = 0

    for x, y in loader:
        if device is not None:
            x, y = x.to(device), y.to(device)

        optimizer.zero_grad()  # clears old gradients as PyTorch accumulates gradients by default
        out = model(x)
        loss = criterion(out, y)
        loss.backward()  # computes gradients via backpropagation
        optimizer.step()  # updates model parameters

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing was processed: the mean is undefined, so report NaN rather
        # than failing with a ZeroDivisionError.
        return float("nan")

    return total_loss / num_batches



In [23]:
def evaluate(model, loader, criterion, device=None):
    """Compute the mean batch loss and classification accuracy over ``loader``.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(inputs, targets)`` batches. It does not
            need to define ``__len__``; batches are counted while iterating.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so the function
            does not depend on a notebook-global ``device`` variable. If the
            model has no parameters, batches are left where they are.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the unweighted mean of
        the per-batch losses and ``accuracy`` is the fraction of samples whose
        arg-max prediction (over dim 1) equals the target. Both are ``nan``
        when ``loader`` yields no batches.
    """
    if device is None:
        # Follow the model's own placement instead of relying on global state.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    model.eval()  # Switches to evaluation mode
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():  # Disables gradient tracking
        for x, y in loader:
            if device is not None:
                x, y = x.to(device), y.to(device)

            out = model(x)
            total_loss += criterion(out, y).item()
            num_batches += 1

            preds = out.argmax(dim=1)  # Converts logits to the predicted class
            correct += (preds == y).sum().item()  # Number of correct predictions in this batch
            total += y.size(0)

    # Guard both ratios so an empty loader reports NaN instead of raising.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy
