In [None]:
import torch
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
import torch.nn as nn
import pandas as pd
import numpy as np


# Pick the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# Both branches carry index 0, so the CPU case prints as "cpu:0".
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu", 0)

print(device)

# MNIST train split first, then test split; both stored under ./data and
# downloaded on first use. Each split gets its own ToTensor transform.
train_dataset, test_dataset = (
    datasets.MNIST(root="data", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

# Mini-batch size shared by both loaders (and by the training loop's logging).
batch_size = 64

# Training batches are reshuffled every epoch; the test loader keeps dataset order.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size)



class DRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu=nn.ReLU()
        self.conv1=nn.Conv2d(in_channels=1,out_channels=8, kernel_size=(3,3), stride=1, padding=0)
        self.bn1=nn.BatchNorm2d(8)
        self.mp1=nn.MaxPool2d(kernel_size=(2,2),stride=2,padding=0)

        self.conv2=nn.Conv2d(in_channels=8,out_channels=16, kernel_size=(3,3), stride=1, padding=0)
        self.bn2=nn.BatchNorm2d(16)

        self.conv3=nn.Conv2d(in_channels=16,out_channels=32, kernel_size=(3,3), stride=1, padding=0)
        self.bn3=nn.BatchNorm2d(32)

        self.conv4=nn.Conv2d(in_channels=32,out_channels=64, kernel_size=(3,3), stride=1, padding=0)
        self.bn4=nn.BatchNorm2d(64)

        self.flatten=nn.Flatten()

        self.lin1=nn.Linear(in_features=3136, out_features=10)
        self.bn5=nn.BatchNorm1d(num_features=10)

    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.mp1(x)

        x=self.conv2(x)
        x=self.bn2(x)
        x=self.relu(x)

        x=self.conv3(x)
        x=self.bn3(x)
        x=self.relu(x)

        x=self.conv4(x)
        x=self.bn4(x)
        x=self.relu(x)

        x=self.flatten(x)

        x=self.lin1(x)
        output=self.bn5(x)

        return output

def train_one_epoch(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader`` and return its metrics.

    Every 100th batch (starting with the first) prints the running loss and
    accuracy accumulated so far in the epoch.

    Args:
        dataloader: Sized iterable of ``(imgs, labels)`` batches.
        model: Module to optimise; switched to training mode here.
        loss_fn: Criterion called as ``loss_fn(pred, labels)``. Its result is
            treated as a per-batch *mean*, which is the default reduction of
            ``nn.CrossEntropyLoss``. NOTE(review): confirm this if a loss with
            another reduction is ever passed.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter. In that case a parameterless model
            raises ``StopIteration``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        accuracy in percent over the whole epoch, each rounded to 2 decimals.
        Returns ``(0.0, 0.0)`` if the dataloader yields no batches.
    """
    if device is None:
        # Batches must live on the same device as the weights.
        device = next(model.parameters()).device

    model.train()
    num_batches = len(dataloader)
    # Accumulate per-sample totals, using the real size of every batch. This
    # weights a short final batch correctly and works for any batch size.
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for i, (imgs, labels) in enumerate(dataloader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        pred = model(imgs)

        loss = loss_fn(pred, labels)
        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n  # undo the per-batch mean
        num_correct += (torch.argmax(pred, dim=1) == labels).sum().item()
        num_seen += batch_n

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if i % 100 == 0:
            running_loss = round(loss_sum / num_seen, 2)
            running_acc = round((num_correct / num_seen) * 100, 2)
            print("Batch:", i+1, "/", num_batches, "Running Loss:", running_loss, "Running Accuracy:", running_acc)

    if num_seen == 0:
        # Nothing to average over; avoid a division by zero.
        return 0.0, 0.0

    epoch_loss = round(loss_sum / num_seen, 2)
    epoch_acc = round((num_correct / num_seen) * 100, 2)
    return epoch_loss, epoch_acc

def eval(dataloader, model, loss_fn, device=None):  # NOTE: shadows the builtin eval; name kept for callers
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Prints the accuracy as ``Accuracy: XX.XX%``.

    Args:
        dataloader: Iterable of ``(imgs, labels)`` batches.
        model: Module to evaluate; switched to eval mode here.
        loss_fn: Criterion called as ``loss_fn(pred, labels)``. Its result is
            treated as a per-batch mean, the default for
            ``nn.CrossEntropyLoss``. TODO confirm for other losses.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        ``(test_loss, test_acc)``: the sample-weighted mean loss and the
        accuracy in percent, each rounded to 2 decimals. Returns
        ``(0.0, 0.0)`` if the dataloader yields no batches.
    """
    if device is None:
        # Batches must live on the same device as the weights.
        device = next(model.parameters()).device

    model.eval()

    total_samples = 0
    correct_predictions = 0
    loss_sum = 0.0

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            pred = model(imgs)
            batch_n = labels.size(0)
            loss_sum += loss_fn(pred, labels).item() * batch_n  # undo the per-batch mean

            # Compare on the device; only the scalar count goes to the host.
            correct_predictions += (torch.argmax(pred, dim=1) == labels).sum().item()
            total_samples += batch_n

    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    print(f'Accuracy: {accuracy * 100:.2f}%')

    if total_samples == 0:
        return 0.0, 0.0
    return round(loss_sum / total_samples, 2), round(accuracy * 100, 2)

# Build the network directly on the selected device.
model = DRNN().to(device)
loss_fn = nn.CrossEntropyLoss()

# Adam with a 1e-3 learning rate. Plain SGD over the same parameters
# (torch.optim.SGD(params=model.parameters(), lr=lr)) is the alternative tried.
lr = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
n_epochs = 3

# One training pass per epoch, with a summary line and a separator after each.
for i in range(n_epochs):
    print("Epoch No:", i + 1)
    train_epoch_loss, train_epoch_acc = train_one_epoch(
        dataloader=train_dl,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
    )
    print("Training:", "Epoch Loss:", train_epoch_loss, "Epoch Accuracy:", train_epoch_acc)
    print("--------------------------------------------------")


cuda:0
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 11399531.81it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 353912.32it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3202735.70it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4307150.98it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Epoch No: 1
Batch: 1 / 938 Running Loss: 2.81 Running Accuracy: 4.69
Batch: 101 / 938 Running Loss: 0.7 Running Accuracy: 88.83
Batch: 201 / 938 Running Loss: 0.57 Running Accuracy: 92.63
Batch: 301 / 938 Running Loss: 0.5 Running Accuracy: 94.17
Batch: 401 / 938 Running Loss: 0.45 Running Accuracy: 95.06
Batch: 501 / 938 Running Loss: 0.41 Running Accuracy: 95.6
Batch: 601 / 938 Running Loss: 0.38 Running Accuracy: 96.05
Batch: 701 / 938 Running Loss: 0.36 Running Accuracy: 96.36
Batch: 801 / 938 Running Loss: 0.34 Running Accuracy: 96.58
Batch: 901 / 938 Running Loss: 0.32 Running Accuracy: 96.79
Training: Epoch Loss: 0.31 Epoch Accuracy: 96.84
--------------------------------------------------
Epoch No: 2
Batch: 1 / 938 Running Loss: 0.19 Running Accuracy: 95.31
Batch: 101 / 938 Running Loss: 0.15 Running Accuracy: 98.78
Batch: 201 / 938 Running Loss: 0.14 Running Accuracy: 98.68
Batch: 301 / 938 Running Loss: 0.

In [None]:
# Report test-set accuracy for the trained model.
eval(dataloader=test_dl, model=model, loss_fn=loss_fn)

Accuracy: 99.21%
