![Banner](https://i.imgur.com/a3uAqnb.png)

# Cell Classification using ViT + Swin Transformers (Sliding-Window Approach)

In this homework, we will classify biomedical cell images using two Vision Transformer architectures:
- **ViT-B/16**
- **Swin-T**

The images are 700×500 pixels. Both backbones expect **224×224** inputs, which is smaller than the full images. Instead of resizing (which can distort the cell structure), we adopt a **sliding-window** approach:
- **Training**: we randomly crop 224×224 windows
- **Validation**: we center crop 224×224
- **Inference**: we slide a window over the full image and average the probabilities across windows to classify the full image.

The sliding-window approach is especially useful when images are very large or when resolution varies across images.
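
To make the inference step concrete, the sketch below enumerates the top-left corners of the windows for one image. It assumes that 700×500 means width 700 and height 500, and it uses a stride of 112; the function name `window_origins` is hypothetical.

```python
def window_origins(width, height, window=224, stride=112):
    """Return (left, top) corners of every window that fits fully inside the image."""
    xs = range(0, width - window + 1, stride)
    ys = range(0, height - window + 1, stride)
    return [(x, y) for y in ys for x in xs]

origins = window_origins(700, 500)
print(len(origins))   # 15 windows: 5 across, 3 down
print(origins[:3])    # [(0, 0), (112, 0), (224, 0)]
```

With these values the last window ends at x = 672 and y = 448, so a 28-pixel strip on the right and a 52-pixel strip at the bottom are not covered. Whether to add an extra edge-aligned window is a design choice.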

In [1]:
import os
import pandas as pd
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("mohammad2012191/cells-types")

print("Path to dataset files:", path)

## 1️⃣ Load Data & Prepare Splits

**Task**: Load the `data.csv` file, extract the labels, and perform a stratified train/val split.

**ToDo**:
- Read the CSV file and cast `cell_type` to string
- Extract class names and build `label2idx` dictionary
- Perform stratified split with `train_test_split`
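
As a rough illustration of these steps, the sketch below runs them on a toy DataFrame rather than the real `data.csv`. The `id` column, the 25% validation fraction, and `random_state=42` are assumptions made for the example; only the `cell_type` column is taken from the task description.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for data.csv (hypothetical contents).
df = pd.DataFrame({
    "id": range(12),
    "cell_type": ["astro"] * 4 + ["cort"] * 4 + ["shsy5y"] * 4,
})
df["cell_type"] = df["cell_type"].astype(str)

# Sorted class names give a stable label -> index mapping.
classes = sorted(df["cell_type"].unique())
label2idx = {name: i for i, name in enumerate(classes)}

# stratify keeps the class proportions the same in both splits.
train_df, val_df = train_test_split(
    df, test_size=0.25, stratify=df["cell_type"], random_state=42
)
print(label2idx)
print(val_df["cell_type"].value_counts().to_dict())
```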


In [2]:
# TODO: Load the CSV file and cast 'cell_type' column to string
# TODO: Extract all unique classes and build a label2idx dictionary
# TODO: Perform stratified train/validation split based on cell_type

## 2️⃣ Data Preprocessing

**Task**: Define image transformations and implement a custom dataset class.

**ToDo**:
- Do not use `Resize`
- Apply `RandomCrop(224)` during training
- Apply `CenterCrop(224)` during validation (a single crop is the best we can do here; the sliding window over the full image is applied at inference)
- Convert to a tensor with `ToTensor`, then normalize using ImageNet stats
- Load images from the `images/` folder

In [3]:
# TODO: Define training transforms (RandomCrop(224), RandomHorizontalFlip, ToTensor, Normalize)
# TODO: Define validation transforms (CenterCrop(224), ToTensor, Normalize)
# TODO: Implement CellDataset class:
#       - Load image using PIL
#       - Apply transforms
#       - Map label string to index using label2idx
#       - Return image and label

## 3️⃣ Create DataLoaders

**Task**: Load datasets using `DataLoader`.

**ToDo**:
- Use `shuffle=True` for training
- Use `shuffle=False` for validation
- Set the batch size and the number of workers
- Define the device
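
The sketch below shows the loader setup with a `TensorDataset` of random tensors standing in for `CellDataset`. The batch size of 4 and `num_workers=0` are placeholder values, not recommendations.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 random "images" with labels in {0, 1, 2}.
dummy_ds = TensorDataset(torch.randn(10, 3, 224, 224), torch.randint(0, 3, (10,)))

train_loader = DataLoader(dummy_ds, batch_size=4, shuffle=True, num_workers=0)
val_loader = DataLoader(dummy_ds, batch_size=4, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images, labels = next(iter(train_loader))
print(images.shape, labels.shape, device)
```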

In [4]:
# TODO: Instantiate CellDataset for training and validation
# TODO: Wrap both in DataLoaders (shuffle=True for train, False for val)
# TODO: Check device availability and print

## 4️⃣ Build ViT + Swin Combined Model

**Task**: Create a model that extracts features from both backbones and concatenates them.

**ToDo**:
- Load ViT-B/16 (`models.vit_b_16`) and Swin-T (`models.swin_t`) with pretrained weights
- Replace their heads with `nn.Identity` (i.e. remove the classifier heads)
- Concatenate the features and pass them to a linear layer

In [5]:
# TODO: Create VitSwinConcat class that inherits from nn.Module
# TODO: In __init__():
#       - Load ViT-B/16 and Swin-T from torchvision (pretrained)
#       - Replace their classification heads with Identity()
#       - Define a Linear classifier on the concatenated features (768+768 inputs)
# TODO: In forward(x):
#       - Get features from both backbones
#       - Concatenate along dim=1
#       - Pass through classifier
# TODO: Instantiate the model, define CrossEntropy loss and Adam optimizer

## 5️⃣ Train & Validate

**Task**: Train the model and evaluate accuracy on the validation set.

**ToDo**:
- Write training and validation loops
- Track training/validation loss and accuracy
- Save the model at the end
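
The loop structure can be sketched on a tiny stand-in model and random data so that it runs in seconds. The helper `run_epoch`, the learning rate, the epoch count, and the checkpoint path are all hypothetical; the real model and loaders would replace the stand-ins.

```python
import os
import tempfile
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Tiny stand-ins for the real model and data.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 3))
data = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 3, (32,)))
train_loader = DataLoader(data, batch_size=8, shuffle=True)
val_loader = DataLoader(data, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def run_epoch(loader, train):
    """One pass over loader; returns (average loss, accuracy)."""
    model.train(train)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):   # gradients are off during validation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * y.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += y.size(0)
    return total_loss / seen, correct / seen

for epoch in range(2):
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    va_loss, va_acc = run_epoch(val_loader, train=False)
    print(f"epoch {epoch}: train loss {tr_loss:.4f} acc {tr_acc:.2f} | "
          f"val loss {va_loss:.4f} acc {va_acc:.2f}")

# Hypothetical checkpoint location.
ckpt_path = os.path.join(tempfile.gettempdir(), "vit_swin_concat.pt")
torch.save(model.state_dict(), ckpt_path)
```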

In [6]:
# TODO: Loop for each epoch:
#       - Training: Zero grad → Forward → Loss → Backward → Step
#       - Track and print average training loss
# TODO: In evaluation:
#       - Disable gradient, forward pass
#       - Calculate average validation loss and accuracy
# TODO: Save model checkpoint after training

## 6️⃣ Sliding-Window Inference

**Task**: Write a function that classifies a full image (not a cropped one) using sliding windows.

**ToDo**:
- Slide a 224×224 window over the image with a fixed stride (e.g. 112)
- Average softmax probabilities
- Print individual patch predictions and final class
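
The averaging step might look like the sketch below, which takes an already normalized image tensor and a tiny stand-in classifier instead of the trained model. The function name `sliding_window_probs` and `toy_model` are hypothetical, and the three class names are taken from the example output in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_probs(model, image, window=224, stride=112):
    """image: normalized tensor (3, H, W). Returns (per-patch probs, mean probs)."""
    model.eval()
    _, h, w = image.shape
    patches = [
        image[:, top:top + window, left:left + window]
        for top in range(0, h - window + 1, stride)
        for left in range(0, w - window + 1, stride)
    ]
    batch = torch.stack(patches)                 # (num_patches, 3, window, window)
    per_patch = F.softmax(model(batch), dim=1)   # probabilities for each patch
    return per_patch, per_patch.mean(dim=0)      # average over patches

# Stand-in classifier; the trained model would go here.
toy_model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 3))
classes = ["astro", "cort", "shsy5y"]

per_patch, avg = sliding_window_probs(toy_model, torch.randn(3, 500, 700))
for i, p in enumerate(per_patch):
    print(f"Patch {i}: ", dict(zip(classes, p.tolist())))
print("Average: ", dict(zip(classes, avg.tolist())))
print("Final class: ", classes[avg.argmax().item()])
```

Stacking all patches into one batch keeps the sketch short; with a large model or many windows, predicting patch by patch or in small chunks may be needed to fit in memory.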

In [None]:
# TODO: Define inference_sliding_window(model, img_path):
#       - Load full-size image
#       - Slide 224×224 window (stride=112)
#       - For each patch: normalize → predict → store softmax
# TODO: Average probabilities across patches
# TODO: Print:
#       - Probabilities for each patch
#       - Averaged probabilities
#       - Final predicted class for full image (argmax of average)

# TODO: Call inference_sliding_window(model, "images/5.png") or any other test image


# Example output:

# Patch 0:  {'astro': 0.0003629255515988916, 'cort': 0.9995890259742737, 'shsy5y': 4.803628326044418e-05}
# Patch 1:  {'astro': 0.002050854032859206, 'cort': 0.9978287816047668, 'shsy5y': 0.00012039497960358858}
# Patch 2:  {'astro': 0.0003413844096940011, 'cort': 0.9996126294136047, 'shsy5y': 4.605785943567753e-05}
# Patch 3:  {'astro': 0.0004650430055335164, 'cort': 0.9994938373565674, 'shsy5y': 4.10635257139802e-05}
# ...
# Patch 12:  {'astro': 0.00020842120284214616, 'cort': 0.9997585415840149, 'shsy5y': 3.304508572909981e-05}
# Patch 13:  {'astro': 0.0001657304965192452, 'cort': 0.9998027682304382, 'shsy5y': 3.147909228573553e-05}
# Patch 14:  {'astro': 0.0004247387987561524, 'cort': 0.9994537234306335, 'shsy5y': 0.00012151007103966549}
# Average:  {'astro': 0.0005312784924171865, 'cort': 0.9994047284126282, 'shsy5y': 6.39661229797639e-05}
# Final class:  cort

## 📝 Evaluation Criteria

Your Cell Classification homework will be evaluated based on:

1. **Implementation Correctness (70%)**
   - Proper stratified splitting and label encoding
   - Correct use of ViT-B/16 and Swin-T architectures
   - Model correctly concatenates features from both backbones
   - Sliding-window inference correctly averages predictions

2. **Training and Results (20%)**
   - Model trains and converges on the validation set
   - Accuracy and loss values are properly printed each epoch
   - Trained weights saved successfully and reused for inference

3. **Code Quality and Structure (10%)**
   - Code follows modular structure with appropriate class/function design
   - Clean, readable code with `ToDo` comments per cell
   - Results are printed clearly, including per-patch and average inference
