<a href="https://colab.research.google.com/github/AnuruddhaPaul/RES_NET_From_Scratch/blob/main/RES_NET_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 1. Device Configuration
# Prefer the GPU when PyTorch reports one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# 2. The Residual Block (BasicBlock)
# This is the heart of ResNet: Input + f(Input)
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions: ``relu(f(x) + shortcut(x))``.

    ``f`` is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. The
    shortcut is the identity (an empty ``nn.Sequential``) unless the stride
    or the channel count changes, in which case it becomes a 1x1 convolution
    followed by BatchNorm so both branches have matching shapes for the sum.

    Args:
        in_channels: number of channels in the incoming feature map.
        out_channels: number of channels produced by both 3x3 convolutions.
        stride: stride of the first convolution and of the projection
            shortcut (when one is created).
    """

    # The block emits expansion * out_channels channels. It is 1 here;
    # the deeper variants (ResNet-50+) use an expansion of 4.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch, first half: may downsample spatially via `stride`.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # One in-place ReLU module is reused after bn1 and after the sum.
        self.relu = nn.ReLU(inplace=True)

        # Main branch, second half: always keeps the spatial size.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip connection: plain identity when the shapes already agree,
        # otherwise a 1x1 projection that matches stride and channels.
        block_channels = self.expansion * out_channels
        shapes_agree = stride == 1 and in_channels == block_channels
        if shapes_agree:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, block_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(block_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Add the (possibly projected) input back onto the residual branch.
        residual += self.shortcut(x)
        return self.relu(residual)

# 3. ResNet Architecture (Generic)
class ResNet(nn.Module):
    """Generic four-stage ResNet: stem -> layer1..layer4 -> avg-pool -> linear.

    Args:
        block: block class (or factory) called as
            ``block(in_channels, out_channels, stride)``; must expose an
            ``expansion`` attribute giving its output-channel multiplier.
        num_blocks: indexable with at least four entries; entry ``k`` is the
            number of blocks in stage ``k + 1``.
        num_classes: size of the final linear layer's output.
        in_channels: number of channels of the input images.
    """

    def __init__(self, block, num_blocks, num_classes=10, in_channels=1):
        super().__init__()
        # Running channel count fed to the next block; _make_layer updates it.
        self.in_channels = 64

        # Stem (modified for MNIST): a 3x3 / stride-1 convolution instead of
        # the 7x7 / stride-2 one, so the spatial size is preserved.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # (width, stride of the first block) for layer1..layer4; every stage
        # after the first halves the resolution and doubles the width.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index],
                                     stride=first_stride)
            setattr(self, f'layer{index + 1}', stage)

        # Classifier head: global average pool to 1x1, then a linear layer.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        Only the first block receives ``stride``; the remaining ones use
        stride 1. ``self.in_channels`` is advanced after every block so the
        next block (and the next stage) sees the right input width.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages and the classifier head on ``x``."""
        features = self.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        # Collapse everything except the batch dimension before the linear layer.
        flat = torch.flatten(pooled, 1)
        return self.fc(flat)

# Helper function to create ResNet-18
def ResNet18(num_classes=10, in_channels=1):
    """Build a ResNet from ``BasicBlock`` with two blocks in each of the four stages.

    Args:
        num_classes: size of the classifier output.
        in_channels: number of channels of the input images.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage,
                  num_classes=num_classes, in_channels=in_channels)

# 4. Training Setup
# MNIST images are 28x28; upscaling to 32x32 lets the three stride-2 stages
# halve the resolution evenly (32 -> 16 -> 8 -> 4).
transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    # Single-channel mean / std used for normalisation
    # (presumably the MNIST training-set statistics -- confirm).
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same preprocessing and are downloaded on first use.
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform,
)

# Shuffle only the training batches; evaluation order does not matter.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Initialize model, loss and optimizer
model = ResNet18()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# 5. Training Loop
print("Starting ResNet-18 Training...")
num_epochs = 3

for epoch in range(num_epochs):
    # Training mode: BatchNorm layers normalise with per-batch statistics.
    model.train()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear stale gradients, then forward / backward / parameter update.
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Report the current batch loss once every 100 steps.
        if (i+1) % 100 != 0:
            continue
        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# 6. Evaluation
# Evaluation mode: BatchNorm layers switch to their running statistics.
model.eval()
correct = 0
total = 0
# No gradients are needed for inference, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit. The legacy `.data`
        # accessor is unnecessary here: nothing is tracked under no_grad.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of ResNet-18 on MNIST: {100 * correct / total:.2f}%')

# Persist only the learned parameters (state dict), not the whole module.
torch.save(model.state_dict(), 'resnet18_mnist.pth')

Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 19.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 475kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 15.2MB/s]


Starting ResNet-18 Training...
Epoch [1/3], Step [100/938], Loss: 0.0640
Epoch [1/3], Step [200/938], Loss: 0.1908
Epoch [1/3], Step [300/938], Loss: 0.0370
Epoch [1/3], Step [400/938], Loss: 0.0340
Epoch [1/3], Step [500/938], Loss: 0.0657
Epoch [1/3], Step [600/938], Loss: 0.0149
Epoch [1/3], Step [700/938], Loss: 0.0458
Epoch [1/3], Step [800/938], Loss: 0.1255
Epoch [1/3], Step [900/938], Loss: 0.1369
Epoch [2/3], Step [100/938], Loss: 0.1272
Epoch [2/3], Step [200/938], Loss: 0.0576
Epoch [2/3], Step [300/938], Loss: 0.0143
Epoch [2/3], Step [400/938], Loss: 0.0081
Epoch [2/3], Step [500/938], Loss: 0.0452
Epoch [2/3], Step [600/938], Loss: 0.0057
Epoch [2/3], Step [700/938], Loss: 0.0179
Epoch [2/3], Step [800/938], Loss: 0.0102
Epoch [2/3], Step [900/938], Loss: 0.0169
Epoch [3/3], Step [100/938], Loss: 0.1211
Epoch [3/3], Step [200/938], Loss: 0.0428
Epoch [3/3], Step [300/938], Loss: 0.0205
Epoch [3/3], Step [400/938], Loss: 0.1309
Epoch [3/3], Step [500/938], Loss: 0.0581
Epo