In [1]:
# libs 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.cuda.amp import GradScaler, autocast
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader

In [2]:
# Define the neural network architecture
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by a two-layer MLP classifier.

    With the default arguments this is the MNIST setup: 1x28x28 inputs and
    10 output logits.

    Args:
        in_channels: number of channels in the input images (1 for MNIST).
        num_classes: number of output logits.
        image_size: height/width of the square input images. The 3x3 convs use
            padding=1 and keep the spatial size; each of the two 2x2 max-pools
            halves it (rounding down), so fc1 receives 64 * (image_size // 4)**2
            features.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        pooled_size = image_size // 4  # spatial size after the two max-pools
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Map a (batch, in_channels, image_size, image_size) tensor to (batch, num_classes) logits."""
        x = self.maxpool(self.relu(self.conv1(x)))
        x = self.maxpool(self.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. A `view(-1, ...)` here
        # would silently fold extra spatial elements from wrongly-sized input
        # into the batch axis; flatten keeps the batch size and lets fc1 raise
        # a shape error instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


In [3]:
# Transformations needed for MNIST: convert images to float tensors, then
# standardize them with the commonly used MNIST mean/std.
# The trailing commas make these real one-element (per-channel) tuples;
# `(0.1307)` without a comma is just a float in parentheses.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

In [4]:
# Download (if needed) and load the MNIST train/test splits with `transform`.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Loader settings.
# NOTE(review): at this batch size the progress bars show only 3 training
# batches per epoch, i.e. 3 optimizer steps, which explains the low accuracy
# after 2 epochs. Activations for a batch this large are likely several GB in
# fp32, which may not fit the 6 GB GPU reported by nvidia-smi. A smaller value
# is probably needed; confirm.
BATCH_SIZE = 20000
NUM_WORKERS = 2
# Page-locked host memory can speed up host-to-GPU copies; it is only useful
# when a CUDA device is present.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)


In [5]:
# Device checking: use the GPU whenever PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [6]:
!Nvidia-smi

Sun May  5 22:05:53 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 526.98       Driver Version: 526.98       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   56C    P8    12W /  N/A |      0MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
# Initialize the model, optimizer, and loss function.
# Seed PyTorch's global RNG first so the initial weights are reproducible. The
# DataLoader's shuffling also draws from this RNG because no explicit generator
# is passed.
torch.manual_seed(0)
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Normal (full fp32) training loop.
# Trains the `model`/`optimizer` created in the setup cell. Re-running this cell
# continues training from the current weights rather than starting over.

num_epochs = 2

for epoch in tqdm(range(num_epochs), desc='epochs'):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader, desc='train', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # `criterion` (default CrossEntropyLoss) averages over the batch, so
        # weight by the batch size. This makes the epoch loss a true per-sample
        # mean even if the last batch is smaller.
        num_samples = labels.size(0)
        running_loss += loss.item() * num_samples
        predicted = outputs.argmax(dim=1)
        total += num_samples
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / total
    train_acc = 100. * correct / total

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')

    # Evaluation loop: switch to eval mode and skip gradient tracking.
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='test', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f'Test Accuracy: {test_acc:.2f}%')


In [8]:
# Automatic Mixed precision training Loop

# Start from fresh, seeded weights. Otherwise, running the notebook top to
# bottom would continue training the model already trained by the normal loop
# above, and the fp32-vs-AMP comparison would be meaningless.
torch.manual_seed(0)
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# This setup only applies autocast and loss scaling to CUDA. On CPU, both are
# disabled explicitly so the loop degrades to plain fp32 training, rather than
# relying on the library to warn and disable them.
use_amp = device == 'cuda'
scaler = GradScaler(enabled=use_amp)

# Training loop
num_epochs = 2
for epoch in tqdm(range(num_epochs), desc='epochs'):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader, desc='train', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Mixed precision: run the forward pass and loss under autocast, so
        # eligible ops (convolutions, matmuls) execute in float16.
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Scale the loss before backward so small float16 gradients don't
        # underflow. scaler.step() unscales the gradients and skips the
        # optimizer step if any of them are inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        # Adjust the scale factor for the next iteration: it is reduced after a
        # step with inf/NaN gradients and grown after a run of finite ones.
        scaler.update()

        # Weight the batch-mean loss by batch size for a per-sample epoch mean.
        num_samples = labels.size(0)
        running_loss += loss.item() * num_samples
        predicted = outputs.argmax(dim=1)
        total += num_samples
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / total
    train_acc = 100. * correct / total

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')

    # Evaluation loop (full precision, no gradient tracking).
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='test', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f'Test Accuracy: {test_acc:.2f}%')


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch [1/2], Loss: 2.1605, Accuracy: 34.44%


  0%|          | 0/1 [00:00<?, ?it/s]

Test Accuracy: 73.30%


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch [2/2], Loss: 1.5403, Accuracy: 75.85%


  0%|          | 0/1 [00:00<?, ?it/s]

Test Accuracy: 80.18%
