# Image Classification with ResNet

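This notebook implements a custom ResNet-style convolutional neural network (CNN) in PyTorch for image classification, featuring:
- Custom ResNet architecture (three stages of residual blocks)
- Training and evaluation on the CIFAR-10 dataset
- A Streamlit interface template for interactive inference
- Model explainability with Grad-CAM
- Performance-oriented training setup: multi-worker data loading with pinned memory, a cosine learning-rate schedule, and timing of training and evaluation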

Author: ShaLese

In [None]:
# Install required packages
!pip install torch torchvision streamlit captum gradcam

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from typing import List, Tuple, Dict

# Device-agnostic setup: prefer the GPU when CUDA is available, otherwise run on CPU.
# Later cells read this module-level `device` name.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Custom ResNet Block
class ResBlock(nn.Module):
    """Basic residual block: two 3x3 conv+BatchNorm layers plus a shortcut.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut); values other than 1 downsample spatially.

    When the stride or the channel count changes, the identity shortcut is
    replaced by a 1x1 convolution + BatchNorm projection so both branches
    have matching shapes and can be summed.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        # Main branch. The convolutions carry no bias because a BatchNorm follows each one.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        projection_layers = (
            [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
            if needs_projection
            else []
        )
        # An empty nn.Sequential simply returns its input, i.e. the identity shortcut.
        self.shortcut = nn.Sequential(*projection_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        branch += self.shortcut(x)
        return F.relu(branch)

# Custom ResNet
class CustomResNet(nn.Module):
    """Compact ResNet for 3-channel images (trained on CIFAR-10 in this notebook).

    Layout: 3x3 stem convolution (64 channels) -> three stages of two
    ResBlocks each (64, 128 and 256 channels; stages 2 and 3 downsample with
    stride 2) -> global average pooling -> linear classifier.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Running channel count; make_layer reads it and advances it stage by stage.
        self.in_channels = 64

        # Stem: stride-1 3x3 convolution, so the input resolution is kept.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self.make_layer(64, 2, stride=1)
        self.layer2 = self.make_layer(128, 2, stride=2)
        self.layer3 = self.make_layer(256, 2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def make_layer(self, out_channels: int, num_blocks: int, stride: int) -> nn.Sequential:
        """Build one stage of ResBlocks and advance ``self.in_channels``.

        The first block uses ``stride`` and maps ``self.in_channels`` to
        ``out_channels``. Then ``num_blocks - 1`` stride-1 blocks follow.
        At least one block is always built.
        """
        head = ResBlock(self.in_channels, out_channels, stride)
        self.in_channels = out_channels
        tail = [ResBlock(out_channels, out_channels, 1) for _ in range(num_blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch ``x``."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.avg_pool(features)
        # Flatten (N, 256, 1, 1) -> (N, 256) for the linear head.
        return self.fc(pooled.view(pooled.size(0), -1))

In [None]:
# Data loading and preprocessing
def get_data_loaders(batch_size: int = 128, num_workers: int = 4,
                     root: str = './data') -> Tuple[DataLoader, DataLoader]:
    """Build CIFAR-10 train and test DataLoaders (downloading the data if needed).

    Training images are augmented with random 32x32 crops (4-pixel padding)
    and horizontal flips. Both splits are converted to tensors and normalized
    with the same per-channel mean/std.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes used by each DataLoader.
        root: Directory where CIFAR-10 is stored or downloaded to.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    # Keep these identical to any inference-time preprocessing of this model.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Page-locked buffers only speed up host-to-GPU copies. Without CUDA they
    # are wasted work, and PyTorch warns about them.
    pin_memory = torch.cuda.is_available()

    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers, pin_memory=pin_memory)

    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return trainloader, testloader

# Training function with timing and progress tracking
def train_model(model: nn.Module, trainloader: DataLoader, epochs: int = 10,
                lr: float = 0.001) -> List[float]:
    """Train ``model`` with Adam and a per-epoch cosine-annealed learning rate.

    The model is moved to the module-level ``device`` and trained in place.
    Every batch is moved to the same device.

    Args:
        model: Network to train; its parameters are updated in place.
        trainloader: Iterable of ``(inputs, targets)`` batches.
        epochs: Number of passes over ``trainloader``; also the scheduler's ``T_max``.
        lr: Initial Adam learning rate.

    Returns:
        Mean training loss of each epoch, in order.

    Raises:
        ValueError: If ``trainloader`` yields no batches in an epoch.
    """
    # Move the model before building the optimizer, as the torch.optim docs recommend.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Stepped once per epoch; with the default eta_min=0 the LR anneals toward 0 over `epochs`.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        start_time = time.time()

        progress_bar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}')
        for inputs, targets in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            # Mean loss over the batches seen so far this epoch. The old code divided
            # by the total batch count, which understated the loss until the epoch ended.
            progress_bar.set_postfix({'loss': running_loss / num_batches})

        if num_batches == 0:
            raise ValueError('trainloader yielded no batches')

        epoch_loss = running_loss / num_batches
        losses.append(epoch_loss)
        scheduler.step()

        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1} completed in {epoch_time:.2f} seconds. Loss: {epoch_loss:.4f}')

    return losses

In [None]:
# Model evaluation
def evaluate_model(model: nn.Module, testloader: DataLoader) -> Tuple[float, float]:
    """Compute the top-1 accuracy of ``model`` on ``testloader``.

    Batches are sent to the device that holds the model's parameters.
    A model that was never moved to the global ``device`` (for example, weights
    freshly loaded on CPU) is evaluated where it lives instead of failing
    with a device mismatch. Parameter-free models fall back to the
    module-level ``device``. The model is switched to eval mode and left there.

    Args:
        model: Classifier returning one logit per class along dim 1.
        testloader: Iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(accuracy_percent, eval_seconds)``; the time is wall-clock.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    model.eval()
    correct = 0
    total = 0
    start_time = time.time()

    with torch.no_grad():
        for inputs, targets in tqdm(testloader, desc='Evaluating'):
            inputs, targets = inputs.to(eval_device), targets.to(eval_device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        raise ValueError('testloader yielded no samples')

    accuracy = 100. * correct / total
    eval_time = time.time() - start_time
    return accuracy, eval_time

In [None]:
# GradCAM implementation for model explainability
class GradCAM:
    """Grad-CAM heatmaps computed at ``target_layer`` of ``model``.

    A forward hook on ``target_layer`` caches the layer's output activations.
    It also attaches a tensor hook to that output, so the gradient flowing back
    into it is captured; the deprecated ``Module.register_backward_hook`` is not used.
    ``generate_cam`` weights each activation channel by its averaged
    gradient, averages over channels, applies ReLU and rescales to a peak of 1.

    The forward hook stays registered on ``target_layer`` until ``remove()``
    is called.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # gradient w.r.t. target_layer's output, from the last backward
        self.features = None   # target_layer's output from the last forward (detached)
        self._handles = [self.target_layer.register_forward_hook(self.save_features)]

    def save_features(self, module: nn.Module, input: torch.Tensor, output: torch.Tensor) -> None:
        """Forward hook: cache the activations and hook their gradient."""
        self.features = output.detach()
        # Outputs produced without grad tracking (e.g. plain inference) cannot be hooked.
        if output.requires_grad:
            output.register_hook(self._store_gradient)

    def _store_gradient(self, grad: torch.Tensor) -> None:
        """Tensor hook: keep the gradient of the target layer's output."""
        self.gradients = grad.detach()

    def save_gradients(self, module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor) -> None:
        """Module-backward-hook style callback, kept for compatibility (no longer registered)."""
        self.gradients = grad_output[0].detach()

    def remove(self) -> None:
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate_cam(self, input_image: torch.Tensor, target_class: 'int | None' = None) -> np.ndarray:
        """Return the Grad-CAM heatmap of ``target_class`` for ``input_image``.

        Args:
            input_image: Model input. Only ``output[0, target_class]`` is
                back-propagated, so a batch of one is expected.
            target_class: Class index to explain; ``None`` explains the
                highest-scoring class of the first sample.

        Returns:
            The heatmap as a NumPy array with singleton dims squeezed and
            values in ``[0, 1]``. It is all zeros when no location contributes
            positively; the old code returned NaN (0/0) in that case.

        Raises:
            RuntimeError: If ``target_layer`` recorded no activations or
                gradients, e.g. because it is not part of the forward pass.
        """
        self.model.eval()
        # Drop results from earlier calls so stale tensors are never reused.
        self.features = None
        self.gradients = None

        # Backprop is required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_image)
            if target_class is None:
                target_class = output[0].argmax().item()
            self.model.zero_grad()
            output[0, target_class].backward()

        if self.features is None or self.gradients is None:
            raise RuntimeError('target_layer recorded no activations/gradients; '
                               'is it part of the forward pass?')

        # Channel weights: gradient averaged over batch and spatial dims (N, H, W).
        weights = self.gradients.mean(dim=(0, 2, 3))
        weighted = self.features * weights.view(1, -1, 1, 1)
        heatmap = F.relu(weighted.mean(dim=1).squeeze())

        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap.cpu().numpy()

In [None]:
# Main execution
def main():
    """Train CustomResNet on CIFAR-10, report test accuracy and save the weights.

    Writes the trained ``state_dict`` to ``resnet_cifar10.pth`` in the working
    directory.

    Returns:
        ``(model, losses)``: the trained model and its per-epoch training losses.
    """
    print(f'Using device: {device}')

    # Network and CIFAR-10 loaders with their default settings.
    model = CustomResNet()
    trainloader, testloader = get_data_loaders()

    print('Starting training...')
    train_started = time.time()
    losses = train_model(model, trainloader)
    elapsed = time.time() - train_started
    print(f'Training completed in {elapsed:.2f} seconds')

    accuracy, eval_time = evaluate_model(model, testloader)
    print(f'Test Accuracy: {accuracy:.2f}%')
    print(f'Evaluation completed in {eval_time:.2f} seconds')

    # Persist only the learned parameters/buffers, not the pickled module object.
    torch.save(model.state_dict(), 'resnet_cifar10.pth')

    return model, losses

if __name__ == '__main__':
    model, losses = main()

In [None]:
# Streamlit interface code. The app source is bound to APP_SOURCE so it can be
# written out for deployment, e.g. pathlib.Path('app.py').write_text(APP_SOURCE).
# The app imports CustomResNet and GradCAM from model.py and loads the weights
# from resnet_cifar10.pth.
APP_SOURCE = '''
import time

import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import CustomResNet, GradCAM  # assuming both classes are saved in model.py

# CIFAR-10 label names in torchvision's class-index order.
CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


# Load the model and attach GradCAM once per server process rather than on
# every rerun, so weights are not re-read and hooks do not pile up.
@st.cache_resource
def load_model():
    model = CustomResNet()
    model.load_state_dict(torch.load('resnet_cifar10.pth', map_location=torch.device('cpu')))
    model.eval()
    gradcam = GradCAM(model, model.layer3[-1])
    return model, gradcam


# Preprocess image
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    # The network expects 3 channels; PNGs may be RGBA/palette and some JPEGs grayscale.
    return transform(image.convert('RGB')).unsqueeze(0)


# Main Streamlit app
def main():
    st.title('Image Classification with ResNet')
    st.write('Upload an image for classification')

    uploaded_file = st.file_uploader('Choose an image...', type=['jpg', 'jpeg', 'png'])

    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        # Make prediction
        model, gradcam = load_model()
        processed_image = preprocess_image(image)

        with torch.no_grad():
            start_time = time.time()
            outputs = model(processed_image)
            prediction_time = time.time() - start_time

        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
        predicted_class = torch.argmax(probabilities).item()

        # GradCAM back-propagates, so it must run outside torch.no_grad().
        cam = gradcam.generate_cam(processed_image, predicted_class)

        # Display results
        st.write(f'Predicted class: {CLASSES[predicted_class]} ({predicted_class})')
        st.write(f'Confidence: {probabilities[predicted_class]:.2f}')
        st.write(f'Prediction time: {prediction_time:.3f} seconds')

        # Display GradCAM
        st.image(cam, caption='GradCAM Visualization', use_column_width=True)


if __name__ == '__main__':
    main()
'''