In [None]:
!pip install torch==2.5.0 torchvision --index-url https://download.pytorch.org/whl/cpu

# 1. Setting Up

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split, Dataset

import numpy as np
import cv2
import os
import random
import matplotlib.pyplot as plt


import torchvision.transforms.v2 as T
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models import resnet50, ResNet50_Weights

# For bounding box manipulations (if needed)
from torchvision.ops import box_convert, box_iou

# If using Albumentations for data augmentations
# pip install albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# If you want to use a CRF post-processing library or morphological ops
# pip install pydensecrf, opencv-python, etc.
# import pydensecrf.densecrf as dcrf

# Import the data loading procedure you have written.
# from dataset_load import train_dataset, val_dataset, train_loader, val_loader
# OR define your train_loader and val_loader below similarly.

In [None]:
# Module-level compute device: CUDA when available, otherwise CPU.
# NOTE(review): train_classifier / train_segmentation_model choose their own
# device internally by default and do not read this variable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Load the Oxford-IIIT Pet Dataset (Classification Labels)

For the classifier we only need image–class pairs, so we load the dataset with `target_types="category"` and build the train/val loaders:

In [5]:
from torchvision.datasets import OxfordIIITPet

# Define transforms.
# v2.ToTensor is deprecated; ToImage + ToDtype(float32, scale=True) is its documented
# replacement and likewise produces float tensors in [0, 1] before Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: mild geometric + photometric augmentation, then ImageNet normalization
# (matches the statistics the pretrained ResNet-50 was trained with).
train_transform = T.Compose([
    T.Resize((256, 256)),
    T.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation: deterministic resize only.
val_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Custom wrapper to apply transform on-the-fly
class TransformDataset(Dataset):
    """
    Wrap an (image, target) dataset and apply ``transform`` to the image lazily.

    Targets are passed through untouched. When ``transform`` is falsy
    (e.g. None), images are returned as stored in ``subset``.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, label = self.subset[idx]
        out = self.transform(sample) if self.transform else sample
        return out, label

# Load dataset without transform first; per-split transforms are applied by TransformDataset.
base_dataset = OxfordIIITPet(
    root="./oxford_iiit_data",
    download=True,
    target_types="category",  # classification labels only
    split="trainval",
    transform=None,
)

# Train/val split. Seeded so the same images land in val on every re-run
# (otherwise validation accuracy is not comparable across runs).
SPLIT_SEED = 42
train_size = int(0.85 * len(base_dataset))
val_size = len(base_dataset) - train_size
train_subset, val_subset = random_split(
    base_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Wrap subsets with respective transforms
train_ds = TransformDataset(train_subset, train_transform)
val_ds = TransformDataset(val_subset, val_transform)

# Data loaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

num_classes = 37  # There are 37 pet breeds
assert len(base_dataset.classes) == num_classes, "unexpected number of Oxford-IIIT classes"

100%|██████████| 792M/792M [00:37<00:00, 20.9MB/s]
100%|██████████| 19.2M/19.2M [00:02<00:00, 8.74MB/s]


## 3. Model Initialization (Classifier)

We fine-tune an ImageNet-pretrained ResNet-50, replacing its final fully-connected layer with a 37-way breed classifier.

In [6]:
def get_resnet50_classifier_model(num_classes=37):
    """
    Build a ResNet-50 with ImageNet (IMAGENET1K_V2) weights and a fresh
    ``num_classes``-way fully-connected head.

    Oxford-IIIT has 37 categories (pet breeds) for classification, hence the default.
    """
    backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
    return backbone

## 4. Training the Classifier

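The training loop includes:
  *	Learning-rate scheduling (StepLR as an example)
  *	Weight decay (passed to Adam, i.e. an L2 penalty added to the gradients)
  *	Gradient clipping (optional; disabled when `clip_grad_norm=None`)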

In [11]:
def train_classifier(model, train_loader, val_loader,
                     num_epochs=10, lr=1e-3,
                     weight_decay=1e-4, clip_grad_norm=None, save_path="classifier.pth",
                     scheduler_step_size=5, scheduler_gamma=0.1, device=None):
    """
    Fine-tune a classifier with Adam + StepLR and report validation accuracy per epoch.

    Args:
        model: nn.Module mapping a batch of images to class logits.
        train_loader, val_loader: iterables of (images, labels) batches.
            ``len(train_loader.dataset)`` is used to average the epoch loss.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (an L2 penalty added to the gradients).
        clip_grad_norm: max global gradient norm; None disables clipping.
        save_path: file to save ``model.state_dict()`` to after training;
            a falsy value (e.g. None) skips saving.
        scheduler_step_size: StepLR period in epochs.
        scheduler_gamma: StepLR multiplicative LR decay.
        device: torch.device to train on; None picks CUDA when available, else CPU.

    Returns:
        The trained model (moved to ``device``).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    num_train_samples = len(train_loader.dataset)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:  # bboxes not loaded into train loader so not used in training here
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping (optional)
            if clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)

            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean even
            # when the last batch is smaller.
            running_loss += loss.item() * images.size(0)

        # StepLR is epoch-based, so step once per epoch.
        scheduler.step()

        epoch_loss = running_loss / num_train_samples
        val_acc = evaluate_classifier(model, val_loader, device)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss:.4f} | Val Acc: {val_acc:.4f}")

    # SAVE the classifier model after training
    if save_path:
        torch.save(model.state_dict(), save_path)
        print(f"Classifier model saved to {save_path}")

    return model

def evaluate_classifier(model, val_loader, device):
    """
    Top-1 accuracy of ``model`` over ``val_loader``.

    Switches the model to eval mode, runs without autograd, and returns
    (#correct predictions) / (#samples) as a float.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_images).max(dim=1).indices
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    return n_correct / n_seen

## 5. Generating Raw CAMs (Grad-CAM)

A straightforward Grad-CAM approach:

In [None]:
class GradCAM:
    """
    Simple Grad-CAM for ResNet-based networks.

    A forward hook on the target sub-module records its output activations
    and attaches a tensor hook that captures the gradient of the selected
    class scores with respect to those activations. Channel weights are the
    spatially averaged gradients; the CAM is the ReLU of the weighted channel
    sum, min-max scaled to [0, 1] per image.

    Call ``remove()`` when finished to detach the hook from ``model``.
    """

    def __init__(self, model, target_layer_name="layer4"):
        """
        model: network to explain (put into eval mode here).
        target_layer_name: name of the sub-module to hook, as reported by
            ``model.named_modules()`` (e.g. "layer4" or a nested "layer4.2").

        Raises ValueError if no sub-module has that name.
        """
        self.model = model
        self.model.eval()

        # named_modules() resolves top-level names like "layer4" and nested ones.
        # The empty name (the root module itself) is deliberately not accepted.
        modules = dict(self.model.named_modules())
        self.target_layer = modules.get(target_layer_name) if target_layer_name else None

        if self.target_layer is None:
            raise ValueError(f"Layer {target_layer_name} not found in model")

        self.gradients = None
        self.activations = None

        def save_gradient(grad):
            self.gradients = grad

        def forward_hook(module, inputs, output):
            self.activations = output
            # A tensor hook captures d(score)/d(output) directly. It replaces the
            # deprecated Module.register_backward_hook, which is not guaranteed
            # to report gradients of the module output for multi-op modules.
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(save_gradient)

        self._handles = [self.target_layer.register_forward_hook(forward_hook)]

    def remove(self):
        """Detach the hooks installed on the target layer (safe to call twice)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, x, class_idx=None):
        """
        x: input image tensor of shape [B, C, H, W]
        class_idx: class(es) to explain. None -> argmax of the logits per image;
            an int -> the same class for every image; a sequence/tensor of
            length B -> one class per image.
        Returns: detached CAMs of shape [B, 1, H', W'] with values in [0, 1],
            where H', W' is the spatial size of the target layer's output.

        Raises RuntimeError if the hooks did not fire (e.g. after ``remove()``
        or if the target layer is not used in the forward pass).
        """
        # Clear state from a previous call so stale tensors can never be reused.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = self.model(x)  # forward pass
            if class_idx is None:
                class_idx = torch.argmax(logits, dim=1)
            class_idx = torch.as_tensor(class_idx, dtype=torch.long, device=logits.device)
            if class_idx.dim() == 0:
                class_idx = class_idx.expand(logits.size(0))

            # One-hot selector: backprop only the chosen class score of each image.
            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, class_idx.reshape(-1, 1), 1.0)

            self.model.zero_grad()
            # No retain_graph: the graph is not reused, and retaining it would
            # keep every intermediate buffer alive.
            logits.backward(gradient=one_hot)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is the target layer used in the forward pass?")

        gradients = self.gradients.detach()      # [B, C, H', W']
        activations = self.activations.detach()  # [B, C, H', W']

        # Global-average-pool the gradients -> one weight per channel
        weights = gradients.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]

        # Weighted sum of activations, then ReLU (keep positive evidence only)
        cams = F.relu((weights * activations).sum(dim=1, keepdim=True))  # [B, 1, H', W']

        # Normalize each CAM to [0, 1]; epsilon guards all-zero maps
        batch = cams.size(0)
        cams = cams - cams.flatten(1).min(dim=1)[0].view(batch, 1, 1, 1)
        cams = cams / (cams.flatten(1).max(dim=1)[0].view(batch, 1, 1, 1) + 1e-8)

        return cams

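**Bounding Box Usage:**
*	If you have bounding boxes for each image, you can mask out or scale down CAM values outside the box, which reduces activations on background regions. Below we take the simplest approach: zero out CAM values outside the bounding box. The box must be expressed in the same coordinate frame as the (upsampled) CAM, so boxes given in original-image pixels have to be rescaled to the transformed input size first.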

In [None]:
def apply_bbox_to_cam(cam, bbox, H, W):
    """
    Zero every CAM value outside a bounding box.

    cam: 2D np.ndarray [H, W], assumed already upsampled to the coordinate
        frame the box is expressed in.
    bbox: [x_min, y_min, x_max, y_max] in pixel coordinates of that frame.
        Coordinates are truncated with int(); x_max / y_max are exclusive ends.
    H, W: height and width of ``cam``.

    Returns a new array ``cam * mask``, where mask is 1 inside the box and 0 elsewhere.
    """
    x_min, y_min, x_max, y_max = (int(v) for v in bbox)

    # Clamp every coordinate into [0, W] / [0, H]. Slice ends are exclusive, so
    # the upper bound is W/H (not W-1/H-1) to keep the last column/row. Clamping
    # from below also stops negative values from wrapping around as Python
    # negative indices.
    x_min, x_max = (min(max(v, 0), W) for v in (x_min, x_max))
    y_min, y_max = (min(max(v, 0), H) for v in (y_min, y_max))

    # Zero out outside the bounding box (an empty/inverted box masks everything)
    mask = np.zeros((H, W), dtype=np.float32)
    mask[y_min:y_max, x_min:x_max] = 1.0
    return cam * mask

def generate_cams(model, data_loader, gradcam, device,
                  apply_bbox=True, output_dir="cams_out"):
    """
    Generate Grad-CAMs for every sample in ``data_loader`` and save them as
    ``<output_dir>/cam_<k>.npy``, where k counts samples in loader order.

    Batches must be ``(images, labels)`` or ``(images, labels, bboxes)``. The CAM
    for each image's label is upsampled (bilinear) to the input resolution and,
    when ``apply_bbox`` is True, masked with that image's box via ``apply_bbox_to_cam``.

    NOTE: k is the position in the loader's iteration order. Use a non-shuffled
    loader if the files must line up with dataset indices.
    NOTE: boxes must be in the coordinate frame of the transformed input
    (e.g. 224x224), not the original image. Rescale them beforehand if needed.

    Raises ValueError if ``apply_bbox`` is True but the batches carry no boxes.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()
    gradcam.model.eval()

    sample_idx = 0  # running counter; robust to loaders without a fixed batch_size
    with torch.no_grad():
        for batch in data_loader:
            images = batch[0].to(device)
            labels = batch[1]
            bboxes = batch[2] if len(batch) > 2 else None
            if apply_bbox and bboxes is None:
                raise ValueError(
                    "apply_bbox=True requires batches of (images, labels, bboxes)")

            # GradCAM needs a forward and backward pass, so re-enable autograd locally.
            with torch.enable_grad():
                cams_batch = gradcam(images, class_idx=labels.to(device))

            # cams_batch: [B, 1, H', W'] -> upsample to the input size [B, 1, H, W]
            upsampled_cams = F.interpolate(cams_batch, size=tuple(images.shape[-2:]),
                                           mode='bilinear', align_corners=False)
            upsampled_cams = upsampled_cams.squeeze(1).detach().cpu().numpy()  # [B, H, W]

            for b, cam_2d in enumerate(upsampled_cams):
                if apply_bbox:
                    H, W = cam_2d.shape
                    cam_2d = apply_bbox_to_cam(cam_2d, bboxes[b], H, W)

                cam_path = os.path.join(output_dir, f"cam_{sample_idx}.npy")
                np.save(cam_path, cam_2d)
                sample_idx += 1

    print("CAM generation complete!")

## 6. Apply ReCAM (Refinement / Expansion)

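Below is a heavily simplified heuristic loosely inspired by the ReCAM paper; it is not the ReCAM method itself. Refinement of this kind usually involves:
  *	Re-scoring the CAM to ensure more complete coverage of the object.
  *	Possibly iterative expansions (e.g. random erasing, multi-scale expansions).
  *	The version below implements only the simplest form of the first step: when too small a fraction of the CAM is above the threshold, it scales the whole map up (clipped to [0, 1]). Iterative expansions are left as an extension.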

In [None]:
def recam_refinement(cam, expansion_factor=1.2, threshold=0.3, min_coverage=0.1):
    """
    Simplified coverage-based CAM rescaling (loosely inspired by ReCAM).

    Computes the fraction of pixels with ``cam >= threshold``. If that fraction
    is below ``min_coverage``, the whole map is multiplied by ``expansion_factor``
    and clipped to [0, 1], so more of it survives later thresholding. Otherwise
    the CAM is returned unchanged.

    cam: np.ndarray with values in [0, 1] (typically 2D [H, W]).
    expansion_factor: multiplier applied to under-covered CAMs.
    threshold: activation level counted as "covered".
    min_coverage: minimum covered fraction below which the CAM is expanded.

    Returns the (possibly rescaled) CAM; empty input is returned as-is.
    """
    cam = np.asarray(cam)
    if cam.size == 0:
        return cam

    coverage = np.count_nonzero(cam >= threshold) / cam.size

    if coverage < min_coverage:
        cam = np.clip(cam * expansion_factor, 0, 1)

    return cam

def refine_cams_with_recam(cam_dir, refined_dir="cams_refined",
                           threshold=0.3, expansion_factor=1.2):
    """
    Run ``recam_refinement`` on every ``.npy`` CAM in ``cam_dir`` and write the
    result under the same file name into ``refined_dir`` (created if missing).
    """
    os.makedirs(refined_dir, exist_ok=True)

    for name in os.listdir(cam_dir):
        if not name.endswith('.npy'):
            continue
        raw_cam = np.load(os.path.join(cam_dir, name))
        refined = recam_refinement(raw_cam, expansion_factor=expansion_factor,
                                   threshold=threshold)
        np.save(os.path.join(refined_dir, name), refined)

    print("ReCAM refinement complete!")

## 7. Pseudo-Label Filtering

A simple approach:
  *	Binarize each refined CAM with a fixed threshold (keeping the top-k% of pixels is a common alternative).
  *	For example, with threshold t=0.5, pixels with CAM ≥ 0.5 are foreground and all others are background.

In [None]:
def generate_pseudo_masks(refined_cam_dir, output_mask_dir="pseudo_masks", threshold=0.5):
    """
    Binarize every ``.npy`` CAM in ``refined_cam_dir`` into a PNG pseudo mask.

    Pixels with ``cam >= threshold`` become foreground (255), all others
    background (0). Each mask is written as ``<stem>.png`` into
    ``output_mask_dir`` (created if missing).

    Raises OSError if OpenCV fails to write a mask.
    """
    os.makedirs(output_mask_dir, exist_ok=True)

    cam_files = [f for f in os.listdir(refined_cam_dir) if f.endswith('.npy')]

    for cfile in cam_files:
        cam = np.load(os.path.join(refined_cam_dir, cfile))

        # Binarize
        pseudo_mask = (cam >= threshold).astype(np.uint8)

        # splitext swaps only the extension; str.replace would also rewrite a
        # '.npy' occurring elsewhere in the file name.
        mask_path = os.path.join(output_mask_dir, os.path.splitext(cfile)[0] + '.png')
        # cv2.imwrite signals failure through its return value instead of raising.
        if not cv2.imwrite(mask_path, pseudo_mask * 255):
            raise OSError(f"Failed to write pseudo mask to {mask_path}")

    print("Pseudo-label generation complete!")

## 8. Train Segmentation Model (DeepLabV3)
Now we use the generated pseudo masks as “ground truth” for training. We’ll assume you have a new dataset that loads:
  *	(Image, PseudoMask)
  *	Possibly ignoring bounding boxes at this stage.

In [None]:
class PseudoSegDataset(torch.utils.data.Dataset):
    """
    (image, pseudo-mask) pairs read from disk with OpenCV.

    image_paths / mask_paths: parallel sequences of file paths; mask_paths[i]
        is the pseudo mask for image_paths[i].
    transform: optional callable invoked as ``transform(image=..., mask=...)``
        that returns a dict with "image" and "mask" keys (the Albumentations
        calling convention). torchvision transforms such as ``train_transform``
        do not follow this convention.

    Without a transform, ``__getitem__`` returns an RGB uint8 array [H, W, 3]
    and a uint8 mask [H, W] with values in {0, 1}.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(mask_paths)} masks")
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # 0/255
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = (mask > 127).astype(np.uint8)  # binarize as 0/1

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask

def get_deeplab_v3(num_classes=2, pretrained=True):
    """
    DeepLabV3 (ResNet-50 backbone) with its prediction head(s) resized to ``num_classes``.

    pretrained: if True, load torchvision's pretrained segmentation weights
        (``DeepLabV3_ResNet50_Weights.DEFAULT``). If False, no segmentation
        weights are loaded, and the backbone uses torchvision's default
        ``weights_backbone``.

    For num_classes=2 the outputs are (background, foreground) logits.
    """
    # Local import: the weights enum is only needed here; `pretrained=` is deprecated.
    from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights

    weights = DeepLabV3_ResNet50_Weights.DEFAULT if pretrained else None
    model = deeplabv3_resnet50(weights=weights)

    # Replace the final 1x1 conv of the main head.
    head_conv = model.classifier[4]
    model.classifier[4] = nn.Conv2d(head_conv.in_channels, num_classes, kernel_size=1)

    # Pretrained weights also build an auxiliary head sized for the original
    # label set. Resize it too so model(x)['aux'] matches num_classes.
    if model.aux_classifier is not None:
        aux_conv = model.aux_classifier[4]
        model.aux_classifier[4] = nn.Conv2d(aux_conv.in_channels, num_classes, kernel_size=1)

    return model

**Training loop for segmentation.** TODO: make sure the trained weights are saved (e.g. with `torch.save(seg_model.state_dict(), path)`).

In [None]:
def train_segmentation_model(seg_model, seg_train_loader, seg_val_loader,
                            num_epochs=10, lr=1e-3, weight_decay=1e-4,
                            save_path=None, scheduler_step_size=5,
                            scheduler_gamma=0.1, device=None):
    """
    Train a torchvision-style segmentation model (outputs a dict with 'out').

    Args:
        seg_model: model whose forward returns ``{'out': logits[B, C, H, W], ...}``.
        seg_train_loader, seg_val_loader: iterables of (images, masks) batches,
            masks holding integer class ids per pixel. ``len(loader.dataset)``
            is used to average the loss.
        num_epochs, lr, weight_decay: training length and Adam hyper-parameters.
        save_path: if truthy, ``seg_model.state_dict()`` is saved there after training.
        scheduler_step_size, scheduler_gamma: StepLR period (epochs) and decay factor.
        device: torch.device to train on; None picks CUDA when available, else CPU.

    Returns:
        The trained model (moved to ``device``).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seg_model = seg_model.to(device)

    optimizer = optim.Adam(seg_model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    # Masks are integer class ids (0 = background, 1 = foreground for the
    # 2-class head), so use multi-class cross-entropy over the output channels.
    criterion = nn.CrossEntropyLoss()

    num_train_samples = len(seg_train_loader.dataset)

    for epoch in range(num_epochs):
        seg_model.train()
        total_loss = 0

        for images, masks in seg_train_loader:
            images = images.to(device)
            masks = masks.long().to(device)  # CrossEntropyLoss needs int64 targets
            optimizer.zero_grad()

            outputs = seg_model(images)['out']  # DeepLab outputs a dict
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)

        scheduler.step()
        train_loss = total_loss / num_train_samples
        val_loss = evaluate_segmentation(seg_model, seg_val_loader, criterion, device)

        print(f"[{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if save_path:
        torch.save(seg_model.state_dict(), save_path)
        print(f"Segmentation model saved to {save_path}")

    return seg_model

def evaluate_segmentation(seg_model, seg_val_loader, criterion, device):
    """
    Average ``criterion`` loss of ``seg_model`` over ``seg_val_loader``.

    The model's forward is expected to return a dict with an 'out' logits
    entry. Each batch loss is weighted by its batch size and the sum is divided
    by ``len(seg_val_loader.dataset)``. Runs in eval mode without autograd.
    """
    seg_model.eval()
    loss_sum = 0
    with torch.no_grad():
        for batch_imgs, batch_masks in seg_val_loader:
            batch_imgs = batch_imgs.to(device)
            targets = batch_masks.long().to(device)
            logits = seg_model(batch_imgs)['out']
            loss_sum += criterion(logits, targets).item() * batch_imgs.size(0)

    return loss_sum / len(seg_val_loader.dataset)

## 9. Post-Processing (Optional: CRF / Morphological Ops)

You can apply a dense CRF or morphological opening/closing to each predicted segmentation mask. Here is a simple morphological example:

In [None]:
import cv2
import numpy as np

def morphological_refinement(pred_mask, kernel_size=3):
    """
    Morphological closing (dilation then erosion) of a binary mask to fill small holes.

    pred_mask: np.array of shape [H, W], values in {0,1}. Integer or bool
        dtypes are accepted (e.g. int64 from ``torch.argmax(...).numpy()``).
    kernel_size: side length of the square structuring element.

    Returns the closed mask with the same dtype as ``pred_mask``.
    """
    mask = np.asarray(pred_mask)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    # OpenCV morphology does not accept int64/bool arrays (argmax yields int64),
    # so operate on uint8 and cast back to the caller's dtype.
    refined = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    return refined.astype(mask.dtype, copy=False)

**CRF usage typically looks like:**

In [None]:
# Outline (not a fully functional snippet):
# def apply_crf(image, pred_mask_prob):
#     # Setup the dense CRF using pydensecrf
#     d = dcrf.DenseCRF2D(W, H, 2)  # 2 classes
#     # set unary potentials from the pred_mask_prob
#     # set pairwise potentials
#     # run inference
#     # return refined mask
#     pass

## 10. Evaluation on Segmentation Metrics

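To evaluate, we typically compute mIoU or Dice against ground-truth masks. Oxford-IIIT Pet ships ground-truth segmentation trimaps (1 = pet, 2 = background, 3 = border/unclassified), which must be mapped to the {0, 1} labels used here before comparison (for example by treating the border as foreground or ignoring it). We can do something like: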

In [None]:
def compute_mIoU(pred_mask, gt_mask, num_classes=2, ignore_index=None):
    """
    Mean intersection-over-union between a predicted and a ground-truth mask.

    pred_mask, gt_mask: array-likes with the same number of elements, holding
        class ids in {0, ..., num_classes-1} (0/1 for binary segmentation).
    num_classes: number of classes averaged over.
    ignore_index: if given, pixels whose ground truth equals this value
        (e.g. an "unlabelled border" id) are dropped before scoring.

    A class absent from both masks scores IoU 1.0 (perfect agreement), which
    keeps the mean defined but can inflate scores on images lacking a class.

    Raises ValueError if the masks have different sizes.
    """
    pred_flat = np.asarray(pred_mask).ravel()
    gt_flat = np.asarray(gt_mask).ravel()
    if pred_flat.size != gt_flat.size:
        raise ValueError(
            f"pred_mask has {pred_flat.size} elements but gt_mask has {gt_flat.size}")

    if ignore_index is not None:
        keep = gt_flat != ignore_index
        pred_flat, gt_flat = pred_flat[keep], gt_flat[keep]

    ious = []
    for c in range(num_classes):
        pred_c = pred_flat == c
        gt_c = gt_flat == c
        union = np.logical_or(pred_c, gt_c).sum()
        if union == 0:
            # Class absent from both masks (so the intersection is empty too).
            ious.append(1.0)
        else:
            ious.append(np.logical_and(pred_c, gt_c).sum() / union)
    return np.mean(ious)

def evaluate_on_test(seg_model, test_loader, device):
    """
    Average per-image mIoU of ``seg_model`` on ``test_loader``.

    Each batch is (images, masks). The prediction is the argmax over the
    model's 'out' logits, and each image is scored with
    ``compute_mIoU(..., num_classes=2)`` before averaging over all images.
    Masks are converted with ``.numpy()``, so they must be CPU tensors.
    NOTE(review): ground-truth masks are assumed to already hold {0, 1} ids —
    raw Oxford-IIIT trimaps need remapping first; confirm with the test loader.
    """
    seg_model.eval()
    per_image_scores = []
    with torch.no_grad():
        for images, masks in test_loader:
            logits = seg_model(images.to(device))['out']
            predicted = logits.argmax(dim=1).cpu().numpy()
            for b, pred_mask in enumerate(predicted):
                score = compute_mIoU(pred_mask, masks[b].numpy(), num_classes=2)
                per_image_scores.append(score)
    return np.mean(per_image_scores)

# Main

In [12]:
def main():
    """
    Run the weakly-supervised segmentation workflow.

    Only step 2 is active: a ResNet-50 classifier is trained on the
    module-level ``train_loader`` / ``val_loader`` (built in the data-loading
    cell) and its weights are saved to ``classifier.pth``. Steps 3-8 are a
    commented-out outline of the remaining pipeline.
    """
    # 1. Load your train/val data
    # (train_loader / val_loader currently come from the module-level
    # data-loading cell, which must run before main()).
    # from dataset_load import train_loader, val_loader  # custom
    # train_loader, val_loader = ...

    # 2. Create a classifier, train it
    classifier = get_resnet50_classifier_model(num_classes=37)
    classifier = train_classifier(
        model=classifier,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10, lr=1e-3, weight_decay=1e-4, clip_grad_norm=None, save_path="classifier.pth"
    ) # If gradient clipping is needed, pass a suitable clip_grad_norm value, e.g. 5.0

    ## Load trainval dataset with Bounding Boxes
    # NOTE(review): train_loader yields (image, label) pairs only, while
    # generate_cams expects (images, labels, bboxes) when apply_bbox=True, so
    # step 3 needs a loader that also returns boxes. train_loader is also
    # shuffled and randomly augmented, so CAM file indices would not map to
    # dataset indices; a non-shuffled loader with val_transform is safer.

    # 3. Generate raw CAMs
    # gradcam = GradCAM(classifier, target_layer_name="layer4")  # for ResNet
    # generate_cams(
    #     model=classifier,
    #     data_loader=train_loader,   # or a combined trainval loader if you want CAM for all
    #     gradcam=gradcam,
    #     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    #     apply_bbox=True,
    #     output_dir="cams_out"
    # )

    # # 4. Apply ReCAM refinement
    # refine_cams_with_recam("cams_out", refined_dir="cams_refined",
    #                        threshold=0.3, expansion_factor=1.2)

    # # 5. Generate pseudo masks
    # generate_pseudo_masks(refined_cam_dir="cams_refined",
    #                       output_mask_dir="pseudo_masks",
    #                       threshold=0.5)

    # # 6. Train a segmentation model (DeepLab) with these pseudo masks
    # # Build a dataset from your images and "pseudo_masks"
    # images_list = ...  # list of paths to the original images
    # pseudo_masks_list = ...  # matching list of paths in "pseudo_masks"

    # NOTE: PseudoSegDataset calls transform(image=..., mask=...) (Albumentations
    # style), which the torchvision train_transform used for classification does
    # not follow; use a joint image/mask transform for segmentation instead.

    # seg_dataset = PseudoSegDataset(images_list, pseudo_masks_list, transform=train_transform)
    # seg_train_loader = DataLoader(seg_dataset, batch_size=4, shuffle=True, num_workers=2)

    # # If you have a separate validation set with pseudo masks, define seg_val_loader similarly
    # seg_val_loader = ...

    # seg_model = get_deeplab_v3(num_classes=2, pretrained=True)
    # seg_model = train_segmentation_model(
    #     seg_model, seg_train_loader, seg_val_loader,
    #     num_epochs=10, lr=1e-3, weight_decay=1e-4
    # )

    # 7. Optional: Post-processing (CRF/morphological ops) inside your inference loop.

    # 8. Testing / Final evaluation
    # test_loader -> uses ground-truth masks
    # test_miou = evaluate_on_test(seg_model, test_loader, device)  # avoid hard-coding "cuda"
    # print(f"Test mIoU: {test_miou}")

    print("Workflow complete!")

# In a Jupyter/IPython kernel __name__ is "__main__", so executing this cell
# runs main(), which trains and saves the classifier.
if __name__ == "__main__":
    main()

Epoch [1/10] | Loss: 1.5643 | Val Acc: 0.5326
Epoch [2/10] | Loss: 0.8846 | Val Acc: 0.5236
Epoch [3/10] | Loss: 0.6401 | Val Acc: 0.7391
Epoch [4/10] | Loss: 0.5205 | Val Acc: 0.6993
Epoch [5/10] | Loss: 0.4552 | Val Acc: 0.7192
Epoch [6/10] | Loss: 0.2327 | Val Acc: 0.8533
Epoch [7/10] | Loss: 0.1211 | Val Acc: 0.8822
Epoch [8/10] | Loss: 0.0925 | Val Acc: 0.8822
Epoch [9/10] | Loss: 0.0671 | Val Acc: 0.8822
Epoch [10/10] | Loss: 0.0562 | Val Acc: 0.8859
Classifier model saved to classifier.pth
Workflow complete!
