In this Jupyter notebook, I train two models: one using weakly supervised semantic segmentation (WSSS) and the other with a fully supervised framework.
Issue: the `cv2` import requires OpenCV, which has to be installed with one of the three OpenCV pip packages (e.g. `pip install opencv-python`).

TODO: look into data augmentation and check whether it changes the results.

## Training a CAM model from scratch (WSSS)

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

# 1. Setup and Data Preparation
# Resize to 224x224, convert to a [0, 1] float tensor, then standardise with
# the ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Oxford-IIIT Pet dataset with image-level (category) labels only; these labels
# are the weak supervision the CAM model is trained from (no pixel masks here).
dataset = datasets.OxfordIIITPet(root='./data', split='trainval', target_types='category', transform=transform, download=True)

# Number of classes (labels are provided as numbers starting from 0)
num_classes = len(dataset.classes)
print(f"Number of classes: {num_classes}")

# Create a DataLoader
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# 2. Model Definition
# Initialise ResNet18 without pre-trained weights. `weights=None` is the
# current torchvision spelling of the deprecated `pretrained=False`.
model = models.resnet18(weights=None)
# Replace the final fully connected layer to match the number of classes.
# Its per-class weight rows are later used to weight the layer4 feature maps
# when building the CAM.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 3. Loss and Optimiser
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 4. Training Loop
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}")

# Save the trained model (parameters only, via state_dict)
torch.save(model.state_dict(), "pet_cam_model.pth")
print("Training complete and model saved.")

# 5. Generate CAM from the Trained Model
# Module-level slot that the forward hook below overwrites on every forward pass.
feature_maps = None


def hook_feature(module, input, output):
    """PyTorch forward hook: store ``output`` in the global ``feature_maps``.

    The stored tensor is detached, so it carries no autograd history.
    ``module`` and ``input`` belong to the hook signature and are ignored.
    """
    global feature_maps
    captured = output.detach()
    feature_maps = captured

# Attach the capture hook to layer4; from now on every forward pass of `model`
# refreshes the global `feature_maps` with layer4's output.
model.layer4.register_forward_hook(hook=hook_feature)

def generate_cam(model, img_tensor, target_class=None):
    """Compute a Class Activation Map (CAM) for a single image.

    Runs ``model`` on ``img_tensor`` so the forward hook on ``model.layer4``
    stores that layer's output in the global ``feature_maps``. The CAM is then
    the sum of those feature maps weighted by the ``model.fc`` weight row of
    ``target_class``.

    Args:
        model: Classifier with ``hook_feature`` registered on ``model.layer4``
            and a linear ``model.fc`` head.
        img_tensor: Normalised batch containing exactly one image,
            shape (1, 3, H, W).
        target_class: Class index to explain. Defaults to the predicted class.

    Returns:
        ``(cam, target_class)``: ``cam`` is a float32 (H, W) array rescaled to
        [0, 1] (left all zeros when there is no positive activation).

    Raises:
        RuntimeError: if the forward pass did not populate ``feature_maps``
            (i.e. the layer4 hook is not registered on this model).
    """
    global feature_maps
    # Clear any capture left over from a previous call, so a missing hook is
    # detected instead of silently reusing stale feature maps.
    feature_maps = None
    model.eval()
    with torch.no_grad():
        output = model(img_tensor.to(device))
    if feature_maps is None:
        raise RuntimeError(
            "feature_maps was not captured; register hook_feature on model.layer4 first"
        )
    if target_class is None:
        target_class = output.argmax(dim=1).item()

    # fc weight row for the target class: one weight per layer4 channel, shape (C,).
    class_weights = model.fc.weight.detach()[target_class].cpu().numpy()
    # layer4 activations of the (only) image in the batch: shape (C, h, w).
    fmap = feature_maps[0].cpu().numpy()

    # Weighted sum over channels as one vectorised contraction: (C,) x (C, h, w) -> (h, w).
    cam = np.tensordot(class_weights, fmap, axes=1).astype(np.float32)

    # Keep positive evidence only, upsample to the input resolution
    # (cv2.resize takes (width, height)), then rescale to [0, 1].
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (img_tensor.size(3), img_tensor.size(2)))
    cam -= np.min(cam)
    cam_max = np.max(cam)
    if cam_max != 0:
        cam /= cam_max
    return cam, target_class

# 6. Visualise CAM on a Sample Image
def visualize_cam(original_img, cam, target_class):
    """Show ``original_img`` beside a copy with the ``cam`` heat map overlaid.

    Args:
        original_img: Image to display (anything ``plt.imshow`` accepts).
        cam: 2-D heat map, drawn with the 'jet' colormap at 50% opacity.
        target_class: Class index printed in the overlay panel's title.
    """
    _, (ax_plain, ax_overlay) = plt.subplots(1, 2, figsize=(10, 5))

    ax_plain.imshow(original_img)
    ax_plain.set_title("Original Image")
    ax_plain.axis("off")

    ax_overlay.imshow(original_img)
    ax_overlay.imshow(cam, cmap='jet', alpha=0.5)  # heat map drawn over the image
    ax_overlay.set_title(f"CAM Overlay (Class: {target_class})")
    ax_overlay.axis("off")

    plt.show()

# Pick a sample image from the dataset (already resized and normalised).
sample_img, _ = dataset[0]
# Keep the raw image (original resolution, no normalisation) for reference.
# NOTE(review): relies on the private `_images` path list of torchvision's
# OxfordIIITPet; verify against the installed torchvision version. If it is
# missing, fall back to un-normalising the sample tensor for display.
if hasattr(dataset, '_images'):
    original_pil = Image.open(dataset._images[0]).convert('RGB')
else:
    _mean = np.array([0.485, 0.456, 0.406])
    _std = np.array([0.229, 0.224, 0.225])
    original_pil = np.clip(sample_img.permute(1, 2, 0).numpy() * _std + _mean, 0, 1)
sample_tensor = sample_img.unsqueeze(0)

cam, predicted_class = generate_cam(model, sample_tensor)
print(f"Predicted Class for Sample Image: {predicted_class}")

# Convert tensor to PIL image for visualisation (undo normalisation)
def tensor_to_pil(tensor):
    """Turn a normalised (1, 3, H, W) or (3, H, W) tensor back into a PIL image.

    Inverts the ImageNet mean/std normalisation, clamps to [0, 1] and scales
    the result to 8-bit RGB.
    """
    means = (0.485, 0.456, 0.406)
    stds = (0.229, 0.224, 0.225)
    # Normalize(mean=-m/s, std=1/s) computes x * s + m, undoing (x - m) / s.
    undo_norm = transforms.Normalize(
        mean=[-m / s for m, s in zip(means, stds)],
        std=[1 / s for s in stds],
    )
    chw = undo_norm(tensor.squeeze(0)).clamp(0, 1)
    hwc = chw.cpu().permute(1, 2, 0).numpy()
    return Image.fromarray((hwc * 255).astype(np.uint8))

# Un-normalise the sample so it can be shown underneath its CAM.
original_img = tensor_to_pil(sample_tensor)
visualize_cam(original_img=original_img, cam=cam, target_class=predicted_class)


100%|██████████| 792M/792M [00:43<00:00, 18.1MB/s] 
100%|██████████| 19.2M/19.2M [00:01<00:00, 17.6MB/s]


Number of classes: 37




Epoch [1/10] - Loss: 3.5969
Epoch [2/10] - Loss: 3.2959
Epoch [3/10] - Loss: 3.0917


## Training a segmentation model with full supervision (DeepLabV3, fine-tuned from pretrained weights)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# 1. Data Preparation
class SegmentationMaskTransform:
    """Resize a PIL trimap and turn it into a LongTensor of class indices.

    The Oxford-IIIT Pet trimaps store pixel labels 1, 2 and 3; they are shifted
    down by one to {0, 1, 2}, the 0-based targets ``nn.CrossEntropyLoss`` needs.

    Args:
        size: Output size handed to ``PIL.Image.resize``, i.e. (width, height).
            ``transforms.Resize`` on the image side takes (height, width), so
            keep this square (or swap the order) for image and mask to line up.
    """

    def __init__(self, size=(224, 224)):
        self.size = size

    def __call__(self, mask):
        # Nearest-neighbour so resizing never blends labels into invalid values.
        resized = mask.resize(self.size, resample=Image.NEAREST)
        labels = np.array(resized).astype(np.int64) - 1
        return torch.from_numpy(labels)

# Image preprocessing: resize to 224x224, convert to a [0, 1] float tensor,
# then standardise each channel with the ImageNet mean/std.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset wrapper that always requests segmentation trimaps as targets.
class OxfordPetSegmentationDataset(datasets.OxfordIIITPet):
    """Oxford-IIIT Pet dataset yielding (image, segmentation trimap) pairs.

    Identical to ``datasets.OxfordIIITPet`` except that ``target_types`` is
    fixed to ``"segmentation"``.
    """

    def __init__(self, root, split='trainval', transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            split=split,
            target_types="segmentation",
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

# Create training dataset and loader.
dataset = OxfordPetSegmentationDataset(
    root='./data',
    split='trainval',
    transform=image_transform,
    target_transform=SegmentationMaskTransform(size=(224, 224)),
    download=True
)

batch_size = 8
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# 2. Model Definition (DeepLabV3)
# Start from DeepLabV3-ResNet50 pretrained on COCO (VOC label subset). Passing
# the weights enum explicitly is the non-deprecated equivalent of the legacy
# `pretrained=True` flag.
model = models.segmentation.deeplabv3_resnet50(
    weights=models.segmentation.DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
)
# The classifier head is a DeepLabHead: replace it with a fresh one predicting
# the 3 trimap classes from the 2048-channel ResNet-50 backbone features.
model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, 3)
# NOTE(review): with pretrained weights torchvision appears to also build an
# auxiliary head (`model.aux_classifier`); training below only uses ['out'], so
# that head is computed but never trained. Confirm and consider replacing/removing it.

# Move model to device (GPU if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 3. Loss and Optimiser
# CrossEntropyLoss over per-pixel logits (B, 3, H, W) vs. index masks (B, H, W).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 4. Training Loop
num_epochs = 10

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``; return the mean batch loss.

    ``model(images)`` must return a dict whose ``'out'`` entry holds per-pixel
    logits of shape (B, C, H, W); ``masks`` are (B, H, W) LongTensor class
    indices. The model is left in training mode.
    """
    model.train()
    batch_losses = []
    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        logits = model(images)['out']  # DeepLabV3 returns a dict; 'out' is the main head
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(dataloader)

# Train for `num_epochs` full passes, logging the mean loss after each one.
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(
        model=model,
        dataloader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {epoch_loss:.4f}")

# Persist the trained parameters (state_dict only, not the module object).
torch.save(model.state_dict(), "fully_supervised_pet_segmentation.pth")
print("Fully supervised model training complete and saved.")

# 5. Visualisation of Predictions
def visualize_prediction(model, dataset, index=0):
    """Plot input image, ground-truth mask and predicted mask for one sample.

    Args:
        model: Segmentation model whose output dict has an ``'out'`` logits entry.
        dataset: Dataset yielding (normalised image tensor, mask) pairs.
        index: Position of the sample in ``dataset`` to display.

    Note: switches ``model`` to eval mode and does not switch it back.
    """
    model.eval()
    image, true_mask = dataset[index]
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(device))['out']
    # Per-pixel prediction = channel with the highest score.
    pred_mask = logits.squeeze().argmax(dim=0).cpu().numpy()

    # Normalize(mean=-m/s, std=1/s) computes x * s + m, undoing the input normalisation.
    unnormalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )
    image_disp = unnormalize(image).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

    panels = (
        (image_disp, "Input Image", None),
        (true_mask, "Ground Truth Mask", 'gray'),
        (pred_mask, "Predicted Mask", 'gray'),
    )
    _, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (data, title, cmap) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Display the model's prediction for the first sample of the training set.
visualize_prediction(model=model, dataset=dataset, index=0)
