<a href="https://colab.research.google.com/github/AnuruddhaPaul/MOBILE_NET_Form_Scratch/blob/main/MOBILE_NET_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 1. Device Configuration
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# 2. Depthwise Separable Convolution Block
# Standard Conv: 3x3 filter looking at ALL channels at once.
# MobileNet Split:
#   Part A (Depthwise): 3x3 filter looking at ONE channel at a time.
#   Part B (Pointwise): 1x1 filter combining the results.
class Block(nn.Module):
    """Depthwise separable convolution block.

    The block runs two stages in sequence, each followed by batch
    normalisation and an in-place ReLU:

    1. ``depthwise``: a 3x3 convolution with ``groups=in_channels``, so every
       input channel is filtered independently of the others.
    2. ``pointwise``: a 1x1 convolution that mixes the channels and maps
       ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the pointwise stage.
        stride: Stride of the depthwise 3x3 convolution. The pointwise
            convolution always uses stride 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # One 3x3 filter per input channel (groups == in_channels); padding=1
        # keeps the spatial size unchanged when stride is 1.
        per_channel_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        self.depthwise = nn.Sequential(
            per_channel_conv,
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )

        # Ordinary 1x1 convolution that recombines the per-channel results.
        channel_mixing_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.pointwise = nn.Sequential(
            channel_mixing_conv,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage, to ``x``."""
        return self.pointwise(self.depthwise(x))

# 3. MobileNet V1 Architecture
class MobileNet(nn.Module):
    """MobileNet-V1-style classifier built from depthwise separable blocks.

    Layout: a standard 3x3 stem convolution (stride 1), thirteen ``Block``
    layers, global average pooling over the remaining spatial extent, and a
    final linear classifier on the 1024 pooled features.

    Args:
        num_classes: Size of the output logit vector.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()

        # Stem: ordinary convolution; stride 1 keeps the full input resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        # (in_channels, out_channels, stride) for each separable block.
        # Only five stride-2 stages are used so a small image does not shrink
        # to nothing; the size notes assume a 32x32 input.
        block_specs = (
            (32, 64, 1),
            (64, 128, 2),      # 32x32 -> 16x16
            (128, 128, 1),
            (128, 256, 2),     # 16x16 -> 8x8
            (256, 256, 1),
            (256, 512, 2),     # 8x8 -> 4x4
            # five 512 -> 512 blocks
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 512, 1),
            (512, 1024, 2),    # 4x4 -> 2x2
            (1024, 1024, 2),   # 2x2 -> 1x1
        )
        self.layers = nn.Sequential(
            *[Block(c_in, c_out, stride=step) for c_in, c_out, step in block_specs]
        )

        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.layers(self.conv1(x))

        # Global average pooling: the kernel spans whatever spatial size is left.
        pooled = nn.functional.avg_pool2d(features, features.size()[2:])
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# 4. Data Preparation
# Images are resized to 32x32 so the five stride-2 stages halve them cleanly
# (32 -> 16 -> 8 -> 4 -> 2 -> 1), then converted to tensors and normalised.
transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same storage location, preprocessing and download flag.
_mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 5. Initialize Model
model = MobileNet(num_classes=10, in_channels=1)
model = model.to(device)

# 6. Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 7. Training Loop
print("Starting MobileNet Training...")
num_epochs = 3

for epoch in range(num_epochs):
    # Training mode: batch norm uses batch statistics and updates running stats.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Report the current batch loss on every 100th step only.
        if (i + 1) % 100 != 0:
            continue
        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# 8. Evaluation
# Evaluation mode: batch norm uses its running statistics instead of batch stats.
model.eval()
correct = 0
total = 0
# Gradients are not needed for inference, so autograd tracking is disabled.
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit in each row. argmax
        # replaces torch.max(outputs.data, 1): the legacy `.data` accessor is
        # unnecessary under no_grad and the max values themselves were unused.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of MobileNet on MNIST: {100 * correct / total:.2f}%')

# Save Model
# Only the parameters and buffers (state_dict) are written, not the module object.
torch.save(model.state_dict(), 'mobilenet_mnist.pth')

Using device: cuda


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.01MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.7MB/s]


Starting MobileNet Training...
Epoch [1/3], Step [100/938], Loss: 0.6695
Epoch [1/3], Step [200/938], Loss: 0.3322
Epoch [1/3], Step [300/938], Loss: 0.3108
Epoch [1/3], Step [400/938], Loss: 0.1733
Epoch [1/3], Step [500/938], Loss: 0.1545
Epoch [1/3], Step [600/938], Loss: 0.2208
Epoch [1/3], Step [700/938], Loss: 0.0359
Epoch [1/3], Step [800/938], Loss: 0.1271
Epoch [1/3], Step [900/938], Loss: 0.0655
Epoch [2/3], Step [100/938], Loss: 0.0445
Epoch [2/3], Step [200/938], Loss: 0.0078
Epoch [2/3], Step [300/938], Loss: 0.0339
Epoch [2/3], Step [400/938], Loss: 0.0204
Epoch [2/3], Step [500/938], Loss: 0.0365
Epoch [2/3], Step [600/938], Loss: 0.0543
Epoch [2/3], Step [700/938], Loss: 0.0224
Epoch [2/3], Step [800/938], Loss: 0.0029
Epoch [2/3], Step [900/938], Loss: 0.0981
Epoch [3/3], Step [100/938], Loss: 0.0266
Epoch [3/3], Step [200/938], Loss: 0.0130
Epoch [3/3], Step [300/938], Loss: 0.0447
Epoch [3/3], Step [400/938], Loss: 0.0194
Epoch [3/3], Step [500/938], Loss: 0.0528
Epo