<a href="https://colab.research.google.com/github/IHMilon/CNN-Image-Classification-Portfolio/blob/main/VGG11_CIFAR10/CNN_SHIFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install torchprofile (provides profile_macs, used for the MAC count further down);
# pip's stdout is discarded to keep the cell output short.
!pip install torchprofile 1>/dev/null

In [None]:
import os
import random
from collections import OrderedDict, defaultdict
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import torch
import torchvision
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import Dataset,DataLoader
from torchprofile import profile_macs
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import OneCycleLR

In [None]:
# Seed every RNG this notebook touches (Python, NumPy, torch CPU, all CUDA devices)
# with the same value so runs are repeatable.
seed_value = 0
for seed_fn in (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed_all):
    seed_fn(seed_value)

# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:

# Training pipeline: light geometric augmentation, then conversion to a float tensor.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),  # pad 4 px on each side, take a random 32x32 crop
        transforms.RandomHorizontalFlip(),      # mirror left/right at random
        transforms.ToTensor(),                  # PIL image -> float CHW tensor
    ]
)

# Evaluation pipeline: no augmentation, tensor conversion only.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
# Build the CIFAR-10 train and test splits (downloaded to ./data on first use),
# each paired with its own transform pipeline. Train is constructed first, then test.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root="./data", train=is_train_split, download=True, transform=split_transform
    )
    for is_train_split, split_transform in (
        (True, train_transform),
        (False, test_transform),
    )
)

In [None]:
batch_size = 64

# Page-locked (pinned) host memory only speeds up host->GPU copies. On a CPU-only
# runtime it has no benefit and PyTorch emits a warning, so enable it only when
# the selected device is a GPU.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle every epoch
    pin_memory=(device.type == "cuda"),
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # fixed order for evaluation
    pin_memory=(device.type == "cuda"),
)

In [None]:
# Pull one training batch and print its shapes as a quick sanity check.
images, labels = next(iter(train_loader))
print(
    f"Image batch shape: {images.shape}",  # expected (B, C, H, W)
    f"Label batch shape: {labels.shape}",  # expected (B,)
    sep="\n",
)

Image batch shape: torch.Size([64, 3, 32, 32])
Label batch shape: torch.Size([64])


In [None]:
class CPSMBlock(nn.Module):
    """
    Co-Prime Shift Mixer (CPSM) block: spatial mixing via shifts + a 1x1 convolution.

    The input is circularly rolled by every (dy, dx) offset in ``shifts``. Pixels
    pushed off one edge wrap around to the opposite edge. The rolled copies are
    concatenated along the channel axis, and a 1x1 conv -> BatchNorm -> ReLU
    mixes them into ``out_channels`` channels.

    With ``use_residual=True`` the output is ``residual + alpha * mixed``, where
    ``residual`` is the input (or a 1x1 projection of it when the channel count
    changes). ``alpha`` is a learnable scalar initialised to 0, so at
    initialisation the block passes its input through unchanged (or only projected).

    Expects 4-D input of shape (B, C, H, W): the shifts are applied on dims 2 and 3.
    """

    def __init__(self, in_channels, out_channels=None, shifts=None, use_residual=True):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int, optional): Number of output channels (default: in_channels).
            shifts (sequence of (dy, dx) pairs, optional): Spatial offsets to roll by.
                Default: [(0, 0), (0, 1), (1, 0), (1, 1)]. The sequence is copied.
            use_residual (bool): Whether to add a residual connection.

        Raises:
            ValueError: If ``shifts`` is empty (the mixer would have no input channels).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels
        self.use_residual = use_residual

        # Copy the caller's sequence so mutating it later cannot change this block.
        if shifts is None:
            self.shifts = [(0, 0), (0, 1), (1, 0), (1, 1)]
        else:
            self.shifts = list(shifts)
        if not self.shifts:
            raise ValueError("CPSMBlock requires at least one (dy, dx) shift")

        # 1x1 mixer over the concatenated shifted copies (in_channels * len(shifts) inputs).
        # Submodules are created in the same order as before so seeded init is unchanged.
        self.mixer = nn.Conv2d(
            in_channels * len(self.shifts),
            self.out_channels,
            kernel_size=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(self.out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Learnable residual scale, zero-initialised so the mixing branch starts "off".
        self.alpha = nn.Parameter(torch.zeros(1))

        # 1x1 projection on the residual path, needed only when channel counts differ.
        if self.use_residual and in_channels != self.out_channels:
            self.res_proj = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, bias=False)
        else:
            self.res_proj = None

    def forward(self, x):
        """Shift -> concat -> 1x1 mix -> BN -> ReLU, plus the optional scaled residual."""
        # Circular roll over (H, W); every copy keeps all C channels.
        stacked = torch.cat(
            [torch.roll(x, shifts=(dy, dx), dims=(2, 3)) for dy, dx in self.shifts],
            dim=1,
        )  # B x (C * len(shifts)) x H x W

        mixed = self.relu(self.bn(self.mixer(stacked)))

        if not self.use_residual:
            return mixed
        residual = x if self.res_proj is None else self.res_proj(x)
        return residual + self.alpha * mixed

In [None]:
class CPSMNetCIFAR10(nn.Module):
    """Small CIFAR-10 classifier built from three stages of CPSMBlock layers.

    Layout:
        3x3 conv stem (3 -> 64) + BN + ReLU
        stage1: CPSMBlock 64 -> 128, 128 -> 128; then stride-2 3x3 conv
        stage2: CPSMBlock 128 -> 256, 256 -> 256; then stride-2 3x3 conv
        stage3: CPSMBlock 256 -> 256, 256 -> 256
        global average pool -> Linear(256, 256) + BN + ReLU + Dropout(0.4)
        -> Linear(256, num_classes)

    Each block in each stage uses its own list of (dy, dx) spatial shifts.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stem (full resolution).
        self.stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1, then halve H and W.
        self.stage1 = self._stage(
            [(64, 128), (128, 128)],
            [[(0, 0), (1, 0), (0, 1), (1, 1)], [(0, 0), (2, 0), (0, 2), (1, 1)]],
        )
        self.down1 = self._downsample(128)

        # Stage 2, then halve H and W again.
        self.stage2 = self._stage(
            [(128, 256), (256, 256)],
            [[(0, 0), (1, 0), (0, 2), (2, 1)], [(0, 0), (3, 0), (0, 3), (2, 2)]],
        )
        self.down2 = self._downsample(256)

        # Stage 3 (no further downsampling).
        self.stage3 = self._stage(
            [(256, 256), (256, 256)],
            [[(0, 0), (1, 0), (0, 4), (3, 2)], [(0, 0), (4, 0), (0, 5), (3, 3)]],
        )

        # Head: pool to a 256-d vector, then a two-layer MLP classifier.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _stage(channel_pairs, shift_sets):
        """Build an nn.Sequential of CPSMBlocks, one per (in, out) pair, in order."""
        return nn.Sequential(
            *(
                CPSMBlock(c_in, c_out, shifts=block_shifts)
                for (c_in, c_out), block_shifts in zip(channel_pairs, shift_sets)
            )
        )

    @staticmethod
    def _downsample(channels):
        """Stride-2 3x3 conv that keeps the channel count and halves H and W."""
        return nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x):
        features = self.relu(self.bn1(self.stem(x)))
        for layer in (self.stage1, self.down1, self.stage2, self.down2, self.stage3):
            features = layer(features)
        pooled = torch.flatten(self.pool(features), 1)
        return self.fc(pooled)

In [None]:
# Instantiate the network and move its parameters/buffers to the selected device
# (Module.to returns the same module, so the result can be bound directly).
model = CPSMNetCIFAR10(num_classes=10).to(device)

In [None]:
# Total number of trainable scalars (parameters with requires_grad=True).
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("#Params:", num_params / 1000000, "Million")

#Params: 1.867344 Million


In [None]:
# Count multiply-accumulate ops for one 3x32x32 input. eval() keeps BatchNorm and
# Dropout in inference behaviour during profiling; training code must call
# model.train() again afterwards.
model.eval()
# Create the dummy input on `device` rather than calling .cuda(), so this cell
# also works on a CPU-only runtime (where `device` falls back to "cpu").
num_macs = profile_macs(model, torch.zeros(1, 3, 32, 32, device=device))
print("#MACs:", num_macs/1000000,"Millions")

#MACs: 329.911296 Millions




In [None]:
# Training hyper-parameters.
EPOCHS = 20
LR = 0.001

criterion = nn.CrossEntropyLoss()

# NOTE(review): weight_decay applies to every parameter returned by
# model.parameters(), including BatchNorm affine terms and each CPSMBlock's
# residual scale `alpha`. Confirm that is intended.
optimizer = AdamW(params=model.parameters(), lr=LR, weight_decay=1e-2)

# StepLR multiplies the LR by 0.1 every 5 scheduler steps. The training loop
# steps it once per epoch, so the LR is 1e-3 for epochs 1-5, 1e-4 for 6-10,
# 1e-5 for 11-15 and 1e-6 for 16-20.
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.1)

In [None]:

def _run_epoch(loader, desc, train):
    """Run one pass over ``loader``; return (mean per-sample loss, accuracy in %).

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and ``device``.
    With ``train=True`` the model is put in training mode and updated after every
    batch. Otherwise it runs in eval mode with autograd disabled.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train(train)  # train(False) is the same as model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for images, labels in tqdm(loader, desc=desc):
            # non_blocking pairs with pin_memory in the DataLoaders.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # criterion returns a per-batch mean. Weight it by batch size, because
            # the last batch can be smaller than batch_size (drop_last is not set)
            # and a plain average of batch means would over-weight it.
            n = labels.size(0)
            total_loss += loss.item() * n
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += n

    if total == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    return total_loss / total, 100.0 * correct / total


for epoch in range(EPOCHS):
    train_loss, train_acc = _run_epoch(
        train_loader, f"[Train] Epoch {epoch+1}/{EPOCHS}", train=True
    )
    # NOTE: test_loader wraps the CIFAR-10 *test* split (train=False), so the
    # "Val" numbers are test-set metrics. Use a held-out split for model selection.
    val_loss, val_acc = _run_epoch(
        test_loader, f"[Val] Epoch {epoch+1}/{EPOCHS}", train=False
    )
    scheduler.step()  # StepLR is advanced once per epoch

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val   Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")