![Banner](https://i.imgur.com/a3uAqnb.png)

# Cell Classification using ViT + Swin Transformers (Sliding-Window Approach)

In this homework, we will classify biomedical cell images using two Vision Transformer architectures:
- **ViT-B/16**
- **Swin-T**

Both backbones require inputs of size **224×224**, which is smaller than the actual image sizes. Instead of resizing (which may distort the cell structure), we adopt a **sliding-window** approach:
- **Training**: we randomly crop 224×224 windows
- **Validation**: we center crop 224×224
- **Inference**: we slide a window over the full image and average the probabilities across windows

The sliding-window approach is especially useful when images are very large or when resolution varies across images.
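
To get a feel for how many windows one image produces, the sketch below lists the top-left offsets along one axis using the pattern `range(0, length - window + 1, stride)`. The helper name `window_origins`, its `cover_edge` option, and the 704×520 image size are illustrative assumptions, not part of the homework or the dataset description.

```python
def window_origins(length, window=224, stride=112, cover_edge=False):
    """Return top-left offsets of windows along one axis.

    With cover_edge=False this mirrors range(0, length - window + 1, stride).
    With cover_edge=True (an assumed extension) one extra window is appended,
    flush with the far edge, so no margin is left unseen.
    """
    if length < window:
        raise ValueError("image side is smaller than the window")
    origins = list(range(0, length - window + 1, stride))
    last = length - window
    if cover_edge and origins[-1] != last:
        origins.append(last)
    return origins


# Hypothetical 704x520 image (width x height), chosen only for illustration.
xs = window_origins(704)
ys = window_origins(520)
print("x offsets:", xs)
print("y offsets:", ys)
print("windows per image:", len(xs) * len(ys))
print("x offsets with edge coverage:", window_origins(704, cover_edge=True))
```

Without `cover_edge`, a strip along the right and bottom edges can go unseen whenever `length - window` is not a multiple of the stride. Whether that matters depends on how much of the cell content sits near the borders.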

In [None]:
import os
import pandas as pd
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [None]:
import kagglehub
path = kagglehub.dataset_download("mohammad2012191/cells-types")
print("Path to dataset files:", path)

## 1️⃣ Load Data & Prepare Splits

**Task**: Load the `data.csv` file, extract labels, and perform stratified train/val split.

**ToDo**:
- Read the CSV file and cast `cell_type` to string
- Extract class names and build `label2idx` dictionary
- Perform stratified split with `train_test_split`


In [None]:
# 1. Load and split data
csv_path = path + "/data.csv"  # expects id,cell_type
df = pd.read_csv(csv_path)
df['cell_type'] = df['cell_type'].astype(str)
# Create label->index mapping
types = sorted(df['cell_type'].unique())
label2idx = {c: i for i, c in enumerate(types)}

# Stratified train/val split
train_df, val_df = train_test_split(
    df, test_size=0.2, stratify=df['cell_type'], random_state=42
)
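
A stratified split should leave each class with roughly the same share in both subsets. The following is a minimal sketch for checking that. `class_fractions` is a hypothetical helper, and the toy label lists are made up for illustration.

```python
from collections import Counter


def class_fractions(labels):
    """Map each class name to its share of the given labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {name: counts[name] / total for name in sorted(counts)}


# Toy labels standing in for a `cell_type` column (made up for illustration).
toy_train = ["astro"] * 8 + ["cort"] * 4 + ["shsy5y"] * 4
toy_val = ["astro"] * 2 + ["cort"] * 1 + ["shsy5y"] * 1
print(class_fractions(toy_train))
print(class_fractions(toy_val))
```

In this notebook, the same check would be `class_fractions(train_df['cell_type'])` against `class_fractions(val_df['cell_type'])`. The two dictionaries should be close but not necessarily identical, since class counts rarely divide evenly.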

## 2️⃣ Data Preprocessing

**Task**: Define image transformations and implement a custom dataset class.

**ToDo**:
- Don't use Resize
- Apply `RandomCrop(224)` during training
- Apply `CenterCrop(224)` during validation (a single-crop approximation; the full image is covered later by sliding-window inference)
- Normalize using ImageNet stats
- Load images from the `images/` folder

In [None]:
# Transforms (no resizing; enforce 224x224 via crops)
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
train_transform = T.Compose([
    T.RandomCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean, std),
])
val_transform = T.Compose([
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean, std),
])

# Dataset class
doc_dir = os.path.join(path, "images")  # folder with the <id>.png image files
class CellDataset(Dataset):
    def __init__(self, df, img_dir, transform, label2idx):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.label2idx = label2idx

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_id = self.df.loc[idx, 'id']
        label = self.df.loc[idx, 'cell_type']
        path = os.path.join(self.img_dir, f"{img_id}.png")
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.label2idx[label]
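
Because no resizing is applied, every image has to be at least 224 pixels on each side: `RandomCrop(224)` raises an error on smaller inputs unless padding is enabled. The sketch below flags such images from a mapping of ids to sizes. `find_too_small` and the example sizes are hypothetical, used only for illustration.

```python
def find_too_small(sizes, crop=224):
    """Return ids whose (width, height) cannot hold a crop x crop window."""
    return [img_id for img_id, (w, h) in sizes.items() if w < crop or h < crop]


# Made-up ids and sizes for illustration; in the notebook these would come
# from Image.open(...).size for each file in the images folder.
sizes = {"a": (704, 520), "b": (200, 520), "c": (224, 224)}
print(find_too_small(sizes))
```

If the check reports any ids, one option is to pad those images before cropping rather than resizing them, which keeps the cell scale unchanged.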

## 3️⃣ Create DataLoaders

**Task**: Load datasets using `DataLoader`.

**ToDo**:
- Use `shuffle=True` for training
- Use `shuffle=False` for validation
- Set batch size and workers
- Define the device

In [38]:
batch_size = 32
train_ds = CellDataset(train_df, doc_dir, train_transform, label2idx)
val_ds = CellDataset(val_df, doc_dir, val_transform, label2idx)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
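
With the default `drop_last=False`, a `DataLoader` yields `ceil(N / batch_size)` batches per epoch, and the final batch is smaller whenever `N` is not a multiple of the batch size. A minimal sketch of that arithmetic follows; the helper names and the sample count of 1000 are assumptions for illustration.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a loader produces in one pass over the data."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


def last_batch_size(n_samples, batch_size):
    """Size of the final batch for a non-empty dataset (drop_last=False)."""
    remainder = n_samples % batch_size
    return remainder if remainder else batch_size


# Made-up dataset size, for illustration only.
n = 1000
print(batches_per_epoch(n, 32), last_batch_size(n, 32))
```

The smaller final batch is the reason per-batch losses should be weighted by batch size when averaging over an epoch.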

## 4️⃣ Build ViT + Swin Combined Model

**Task**: Create a model that extracts features from both backbones and concatenates them.

**ToDo**:
- Load ViT-B/16 (models.vit_b_16) and Swin-T (models.swin_t) with pretrained weights
- Replace their heads with `nn.Identity` (i.e. remove the classifier heads)
- Concatenate features and pass to a linear layer

In [39]:
# Model definition
class VitSwinConcat(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
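        # ViT backbone (ImageNet-pretrained; `weights=` replaces the deprecated `pretrained=True`)
        self.vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        # remove classification head
        self.vit.heads = nn.Identity()
        # Swin backbone (ImageNet-pretrained)
        self.swin = models.swin_t(weights=models.Swin_T_Weights.DEFAULT)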
        self.swin.head = nn.Identity()
        # both backbones output 768-d features
        self.classifier = nn.Linear(768 * 2, num_classes)

    def forward(self, x):
        f1 = self.vit(x)
        f2 = self.swin(x)
        f = torch.cat([f1, f2], dim=1)
        return self.classifier(f)

# instantiate
n_classes = len(types)
model = VitSwinConcat(n_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

## 5️⃣ Train & Validate

**Task**: Train the model and evaluate accuracy on the validation set.

**ToDo**:
- Write training and validation loops
- Track training/validation loss and accuracy
- Save the model at the end

In [40]:
# Training and validation loop
epochs = 3
for epoch in range(epochs):
    # training
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)

    # validation
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
    val_loss /= len(val_loader.dataset)
    accuracy = correct / len(val_loader.dataset)
    print(f"Epoch {epoch}: Train Loss {epoch_loss:.4f} | Val Loss {val_loss:.4f} | Val Acc {accuracy:.4f}")

# save weights
torch.save(model.state_dict(), "vit_swin_concat.pth")

Epoch 0: Train Loss 0.2313 | Val Loss 0.0310 | Val Acc 0.9903


Epoch 1: Train Loss 0.0304 | Val Loss 0.0123 | Val Acc 0.9981


Epoch 2: Train Loss 0.0072 | Val Loss 0.0079 | Val Acc 0.9961
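
The training loop accumulates `loss.item() * imgs.size(0)` and divides by the dataset length instead of averaging the per-batch losses directly. The sketch below shows why, using made-up loss values: when the last batch is smaller, a plain mean over batches gives it too much weight. `epoch_mean_loss` is a hypothetical helper name.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Average per-sample loss over an epoch from per-batch mean losses."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Made-up numbers: two full batches of 32 and a final batch of 8.
losses = [0.50, 0.30, 0.90]
sizes = [32, 32, 8]
print("sample-weighted mean:", epoch_mean_loss(losses, sizes))
print("naive mean over batches:", sum(losses) / len(losses))
```

The two values agree only when every batch has the same size.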


## 6️⃣ Sliding-Window Inference

**Task**: Write a function to classify a full image using sliding windows.

**ToDo**:
- Slide a 224×224 window with stride (e.g. 112)
- Average softmax probabilities
- Print individual patch predictions and final class
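
The averaging step can be sketched in plain Python: apply softmax to each window's logits, average the resulting probabilities class by class, and take the argmax of the average. The logits below are made up, and `softmax` and `average_probs` are illustrative helpers rather than the notebook's implementation, which uses `torch.softmax` and NumPy.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def average_probs(per_window_logits):
    """Softmax each window's logits, then average class-wise."""
    probs = [softmax(window) for window in per_window_logits]
    n = len(probs)
    return [sum(p[c] for p in probs) / n for c in range(len(probs[0]))]


# Made-up logits for three windows and three classes.
logits = [[0.1, 3.0, -1.0], [0.5, 2.0, 0.0], [2.5, 2.0, -0.5]]
avg = average_probs(logits)
final_idx = max(range(len(avg)), key=avg.__getitem__)
print(avg, final_idx)
```

In this toy case the third window on its own favours class 0, but the averaged probabilities still select class 1. Averaging lets confident windows outweigh an ambiguous one.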


In [None]:
# Inference with sliding window
def inference_sliding_window(model, img_path, window_size=224, stride=112, device=device):
    model.eval()
    img = Image.open(img_path).convert('RGB')
    w, h = img.size
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean, std)
    probs = []
    # slide
    for y in range(0, h - window_size + 1, stride):
        for x in range(0, w - window_size + 1, stride):
            patch = img.crop((x, y, x + window_size, y + window_size))
            tensor = normalize(to_tensor(patch)).unsqueeze(0).to(device)
            with torch.no_grad():
                out = model(tensor)
                p = torch.softmax(out, dim=1).cpu().numpy()[0]
            probs.append(p)
    probs = np.stack(probs, axis=0)
    avg_prob = probs.mean(axis=0)
    final_idx = int(avg_prob.argmax())
    # print per-patch and average
    for i, p in enumerate(probs):
        print(f"Patch {i}: ", {types[j]: float(p[j]) for j in range(len(types))})
    print("Average: ", {types[j]: float(avg_prob[j]) for j in range(len(types))})
    print("Final class: ", types[final_idx])


inference_sliding_window(model, os.path.join(path, "images", "5.png"))


Patch 0:  {'astro': 0.0003629255515988916, 'cort': 0.9995890259742737, 'shsy5y': 4.803628326044418e-05}
Patch 1:  {'astro': 0.002050854032859206, 'cort': 0.9978287816047668, 'shsy5y': 0.00012039497960358858}
Patch 2:  {'astro': 0.0003413844096940011, 'cort': 0.9996126294136047, 'shsy5y': 4.605785943567753e-05}
Patch 3:  {'astro': 0.0004650430055335164, 'cort': 0.9994938373565674, 'shsy5y': 4.10635257139802e-05}
Patch 4:  {'astro': 0.001034583430737257, 'cort': 0.9988666772842407, 'shsy5y': 9.87577805062756e-05}
Patch 5:  {'astro': 0.0003107638331130147, 'cort': 0.999643087387085, 'shsy5y': 4.611604526871815e-05}
Patch 6:  {'astro': 0.00019685731967911124, 'cort': 0.9997627139091492, 'shsy5y': 4.044066372443922e-05}
Patch 7:  {'astro': 0.00017888678121380508, 'cort': 0.9997997879981995, 'shsy5y': 2.136872717528604e-05}
Patch 8:  {'astro': 0.00012940005399286747, 'cort': 0.999840259552002, 'shsy5y': 3.0311841328511946e-05}
Patch 9:  {'astro': 0.0006890001241117716, 'cort': 0.999224543571