<a href="https://colab.research.google.com/github/Kamohelo99/C0S711_Assignment_3/blob/Ndumiso/semi_supervised_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semi‑Supervised Training

This notebook uses the previously trained supervised model to leverage unlabelled images through pseudo‑labelling and consistency regularisation.


## 1. Setup

Import necessary modules and define file paths for labelled data, unlabelled data, the labels CSV, and the supervised model checkpoint.

In [None]:
!pip install -q iterative-stratification torchmetrics astropy
import os
import re
import warnings
from pathlib import Path
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms, models
from PIL import Image
from tqdm.notebook import tqdm

from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelF1Score

# Silence every warning category for the rest of the session so the notebook
# output stays readable. NOTE(review): this also hides warnings raised by
# later cells — confirm that is acceptable.
warnings.filterwarnings("ignore")
# Fixed seed applied to the PyTorch and NumPy random number generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# TODO: add data in these paths
# All project directories live under this Drive folder.
DRIVE_PATH = Path("/content/drive/MyDrive/assignmentdata")
DATA_DIR = DRIVE_PATH / "data"
CHECKPOINT_DIR = DRIVE_PATH / "checkpoints"
SPLIT_DIR = DRIVE_PATH / "splits"

# Path to the unlabeled images
unl_root = DRIVE_PATH / 'unl/unl_PNG'

# Path to the model trained in the supervised notebook
supervised_ckpt = CHECKPOINT_DIR / 'supervised_best_model.pth'

# The class names are hard-coded here (they are not read from a file). Their
# position in this list is the class index used for the model outputs and for
# the multi-hot label vectors later in the notebook.
# NOTE(review): presumably this order must match the one used when the
# supervised checkpoint was trained — verify against the first notebook.
CLASSES = ['Bent', 'Exotic', 'FR I', 'FR II', 'Point Source', 'S/Z shaped', 'Should be discarded', 'X-Shaped', 'typical']
num_classes = len(CLASSES)

print(f"Found {num_classes} classes: {CLASSES}")
print(f"Supervised model checkpoint path: {supervised_ckpt}")


## 2. Load datasets and model

We load the labelled dataset (typical and exotic images) and the unlabelled dataset.  We also instantiate the same network architecture and load the weights from the supervised training stage.

In [None]:
# Define Model Architecture (must match the supervised model)
def build_adapted_model(model_name="efficientnet_b0", num_classes=num_classes, weights='IMAGENET1K_V1'):
    """Build a torchvision classifier adapted to 1-channel input.

    The first convolution (``model.features[0][0]``) is replaced by a
    single-input-channel convolution whose kernel is the mean of the original
    kernel over its input-channel axis, and the classifier is replaced by
    ``Dropout(0.3) -> Linear(in_features, num_classes)``.

    Args:
        model_name: Name passed to ``torchvision.models.get_model``. The model
            must expose ``features[0][0]`` (a conv layer) and ``classifier[1]``
            (a layer with ``in_features``), as this function indexes both.
        num_classes: Number of output logits of the new classifier head.
        weights: Weights identifier forwarded to ``get_model``. Pass ``None``
            to skip loading pre-trained weights (e.g. when a checkpoint is
            loaded right afterwards).

    Returns:
        The adapted ``nn.Module``.
    """
    model = models.get_model(model_name, weights=weights)

    stem_conv = model.features[0][0]
    new_conv = nn.Conv2d(
        1,
        stem_conv.out_channels,
        kernel_size=stem_conv.kernel_size,
        stride=stem_conv.stride,
        padding=stem_conv.padding,
        bias=stem_conv.bias is not None,
    )
    with torch.no_grad():
        # Collapse the input-channel axis so the 1-channel kernel keeps the
        # average response of the original multi-channel kernel.
        new_conv.weight.copy_(stem_conv.weight.mean(dim=1, keepdim=True))
        # Carry the bias over when the original layer has one; otherwise it
        # would silently stay at its random initialisation.
        if stem_conv.bias is not None:
            new_conv.bias.copy_(stem_conv.bias)
    model.features[0][0] = new_conv

    in_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(nn.Dropout(p=0.3), nn.Linear(in_features, num_classes))
    return model

# Define Transformations
# Side length in pixels used for resizing (here) and cropping (training
# augmentation later in the notebook).
IMG_SIZE = 128
# Deterministic preprocessing with no augmentation: resize to a square,
# convert to a tensor, then normalise with a single mean/std value
# (i.e. one channel).
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Define Unlabeled Dataset Class
class UnlabeledDataset(Dataset):
    """Dataset over the ``*.png`` files located directly inside a folder.

    Each item is a ``(image, path)`` pair: the image is opened as single-channel
    greyscale (mode ``'L'``) and passed through ``transform`` when one is given;
    the path is returned as a string so it can be used as a dictionary key.

    Args:
        folder_path: Directory to scan (non-recursive) for ``*.png`` files.
        transform: Optional callable applied to the PIL image.
    """

    def __init__(self, folder_path, transform=None):
        # Sorted so that an index always maps to the same file.
        self.files = sorted(Path(folder_path).glob("*.png"))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fpath = self.files[idx]
        # The context manager closes the underlying file; convert() returns a
        # separate, fully loaded image that remains valid afterwards.
        with Image.open(fpath) as raw:
            img = raw.convert('L')
        if self.transform is not None:
            img = self.transform(img)
        return img, str(fpath)  # Return image and its path

# Instantiate Unlabeled Dataset and DataLoader
# shuffle=False keeps the batches in the dataset's own (sorted) file order.
unlabelled_dataset = UnlabeledDataset(unl_root, transform=inference_transform)
unlabelled_loader = DataLoader(unlabelled_dataset, batch_size=64, shuffle=False, num_workers=2)
print(f"Found {len(unlabelled_dataset)} unlabeled images.")

# Load Model and Supervised Weights
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_adapted_model(num_classes=num_classes)
if supervised_ckpt.exists():
    # The checkpoint is loaded as a state dict, so the architecture built above
    # must have the same parameter names and shapes as the saved model.
    model.load_state_dict(torch.load(supervised_ckpt, map_location=device))
    print(f"Successfully loaded supervised weights from {supervised_ckpt.name}")
else:
    # NOTE(review): execution continues after this warning, so any later cell
    # that uses `model` runs without the supervised weights — confirm that
    # continuing (rather than stopping) is intended.
    print(f"WARNING: Supervised checkpoint not found at {supervised_ckpt}. Model has random weights.")
model = model.to(device)

## 3. Generate pseudo‑labels

Use the supervised model to infer labels for the unlabelled dataset. For each image, compute the sigmoid output and assign labels for classes whose probabilities reach a confidence threshold (e.g. 0.95; the code below uses 0.90). You must define a mapping between class indices and label names (`class_names`).

In [None]:
# The class names come from the `CLASSES` list defined in the setup cell; the
# position of a name in that list is its index in the model output.
class_names = CLASSES
class_to_idx = {name: i for i, name in enumerate(class_names)}

confidence_threshold = 0.90 # Using a slightly lower threshold to get more labels
margin_threshold = 0.2     # For FR I/II mutual exclusivity


def _select_confident_labels(prob_vec):
    """Return the class names to use as pseudo-labels for one image.

    A class is kept when its probability is >= ``confidence_threshold``. If both
    'FR I' and 'FR II' are kept, they are treated as mutually exclusive: both
    are dropped when their probabilities differ by less than
    ``margin_threshold``; otherwise only the more probable one is kept.

    NOTE(review): with confidence_threshold=0.90 and margin_threshold=0.2 two
    probabilities that both pass the threshold can differ by at most 0.1, so
    the "keep the more probable one" branches can never run and such images
    always lose both labels. Confirm whether that is the intended rule.

    Args:
        prob_vec: 1-D NumPy array of per-class probabilities, indexed like
            ``class_names``.

    Returns:
        List of class names (possibly empty), in class-index order.
    """
    labels_for_img = [class_names[j] for j in np.where(prob_vec >= confidence_threshold)[0]]

    if 'FR I' in labels_for_img and 'FR II' in labels_for_img:
        fr1_prob = prob_vec[class_to_idx['FR I']]
        fr2_prob = prob_vec[class_to_idx['FR II']]
        if abs(fr1_prob - fr2_prob) < margin_threshold:
            # Too close to call: discard both rather than guess.
            labels_for_img.remove('FR I')
            labels_for_img.remove('FR II')
        elif fr1_prob > fr2_prob:
            labels_for_img.remove('FR II')
        else:
            labels_for_img.remove('FR I')

    return labels_for_img


# Dictionary to store pseudo-labels keyed by image path
pseudo_labels = {}

print(f"Generating pseudo-labels with confidence >= {confidence_threshold}")
model.eval()
with torch.no_grad():
    for images, paths in tqdm(unlabelled_loader, desc="Generating Pseudo-Labels"):
        # One device-to-host transfer per batch instead of one per image.
        batch_probs = torch.sigmoid(model(images.to(device))).cpu().numpy()

        for img_path, prob_vec in zip(paths, batch_probs):
            labels_for_img = _select_confident_labels(prob_vec)
            if labels_for_img:  # Only add if there are any labels left
                pseudo_labels[img_path] = labels_for_img

print(f"\nGenerated {len(pseudo_labels)} pseudo-labels for the unlabeled set.")

# Inspect a few pseudo-labels
print("\nSample Pseudo-Labels")
for path, labels in list(pseudo_labels.items())[:5]:
    print(f"{Path(path).name}: {labels}")

# Save the pseudo-labels to a file for reuse
pseudo_labels_df = pd.DataFrame(pseudo_labels.items(), columns=['image_path', 'labels_list'])
# Make sure the target directory exists so the save cannot fail after a long
# inference pass.
SPLIT_DIR.mkdir(parents=True, exist_ok=True)
pseudo_labels_df.to_csv(SPLIT_DIR / "pseudo_labels.csv", index=False)
print(f"\nPseudo-labels saved to {SPLIT_DIR / 'pseudo_labels.csv'}")


## 4. Create a combined dataset

To train on both labelled and pseudo‑labelled data, you need a dataset that returns a multi‑hot vector for each image.  Extend the `RadioDataset` class or write a new dataset class that looks up pseudo‑labels in the dictionary created above (for unlabelled images) and uses true labels for the labelled images.  Then concatenate the two datasets.

In [None]:
# Define the Dataset Class for Combined Data
# We will use the same dataset class from the supervised notebook, as it's already compatible.
class RadioDataset(Dataset):
    """Image dataset yielding ``(image, multi-hot labels, sample weight)``.

    Args:
        df: DataFrame with an ``image_path`` column (path of the image file)
            and a ``labels_list`` column (iterable of class names per row).
        classes: Ordered class names; column ``i`` of the multi-hot vector
            corresponds to ``classes[i]``.
        transform: Optional callable applied to the PIL image.
        is_pseudo: When true, every sample gets weight 0.6 instead of 1.0 so
            that pseudo-labelled data contributes less to the loss.
    """

    def __init__(self, df, classes, transform=None, is_pseudo=False):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        # Fit the binarizer once here; with an explicit `classes` list the
        # encoding is fixed, so refitting for every sample is wasted work.
        self.mlb = MultiLabelBinarizer(classes=classes)
        self.mlb.fit([classes])
        self.is_pseudo = is_pseudo

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The context manager closes the underlying file; convert() returns a
        # separate, fully loaded greyscale image.
        with Image.open(row['image_path']) as raw:
            img = raw.convert('L')
        if self.transform:
            img = self.transform(img)
        labels_one_hot = self.mlb.transform([row['labels_list']])[0]
        labels = torch.tensor(labels_one_hot, dtype=torch.float32)
        # Apply lower weight for pseudo-labeled samples
        weight = 0.6 if self.is_pseudo else 1.0
        return img, labels, torch.tensor(weight, dtype=torch.float32)

# Load the Original Labeled Data and New Pseudo-Labeled Data
train_df_manual = pd.read_csv(SPLIT_DIR / "train.csv")
# Turn the comma-separated "labels" string of each row into a list of names,
# trimming whitespace and dropping empty entries.
train_df_manual["labels_list"] = train_df_manual["labels"].astype(str).apply(
    lambda s: [lbl.strip() for lbl in s.split(",") if lbl.strip()]
)

# Convert the pseudo-labels dictionary back to a DataFrame with the right format
# (built from the in-memory dictionary, not re-read from the saved CSV, so
# `labels_list` holds real Python lists).
pseudo_df = pd.DataFrame(pseudo_labels.items(), columns=['image_path', 'labels_list'])

# Instantiate the Datasets
# Training-time augmentation: random rotation, random flips and a random
# resized crop to IMG_SIZE, followed by the same tensor conversion and
# normalisation used at inference time.
train_tf = transforms.Compose([
    transforms.RandomRotation(360),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# is_pseudo selects the per-sample loss weight returned by RadioDataset.
manual_dataset = RadioDataset(train_df_manual, classes=CLASSES, transform=train_tf, is_pseudo=False)
pseudo_dataset = RadioDataset(pseudo_df, classes=CLASSES, transform=train_tf, is_pseudo=True)

# Concatenate and Create Final DataLoader
# shuffle=True mixes manually labelled and pseudo-labelled samples in a batch.
combined_dataset = ConcatDataset([manual_dataset, pseudo_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=32, shuffle=True, num_workers=2)

print(f"Created a combined dataset:")
print(f"  - Manual (Human) Labels: {len(manual_dataset)}")
print(f"  - Pseudo (Generated) Labels: {len(pseudo_dataset)}")
print(f"  - Total Training Samples: {len(combined_dataset)}")


## 5. Fine‑tuning loop (e.g. FixMatch)

Implement the semi‑supervised training loop.  For example, using FixMatch:

- For each unlabelled image, create a weakly augmented and a strongly augmented version.
- Use the model to produce pseudo‑labels from the weak view and enforce that the model’s prediction on the strong view matches these pseudo‑labels.
- Combine this unsupervised loss with the supervised loss on labelled data.

Below is a skeleton structure.  You must fill in the details for augmentation, loss computation, and parameter updates.

In [None]:
# We will implement the Self-Training approach, which is simpler than FixMatch
# and directly matches our project plan. The core idea is to retrain the model
# on the combined dataset with weighted loss.

# Re-initialize the Model
# A fresh model is built instead of continuing from the supervised checkpoint,
# to see the full benefit of the larger dataset. Note that
# build_adapted_model still starts from ImageNet-pretrained weights.
model = build_adapted_model(num_classes=num_classes)
model = model.to(device)

# Define Optimizer and Loss
# reduction='none' keeps one loss value per (sample, class) so that each
# sample's loss can be scaled by its weight before averaging.
criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
metrics = MetricCollection({'MacroF1': MultilabelF1Score(num_labels=num_classes, average='macro')}).to(device)

# Define Validation Set
val_df = pd.read_csv(SPLIT_DIR / "val.csv")
val_df["labels_list"] = val_df["labels"].astype(str).apply(
    lambda s: [lbl.strip() for lbl in s.split(",") if lbl.strip()]
)
val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
val_dataset = RadioDataset(val_df, classes=CLASSES, transform=val_tf)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)


def evaluate_epoch(model, loader, device, metrics_collection):
    """Run one evaluation pass and return the computed metrics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels, weight)`` batches; the
            weight is ignored during validation.
        device: Device the batches are moved to.
        metrics_collection: Metric collection; it is reset, updated with every
            batch and then computed.

    Returns:
        The result of ``metrics_collection.compute()``.
    """
    model.eval()
    metrics_collection.reset()
    with torch.no_grad():
        for images, labels, _ in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Raw logits are passed on; presumably the metric applies the
            # sigmoid/threshold itself — verify against torchmetrics docs.
            metrics_collection.update(outputs, labels.int())
    return metrics_collection.compute()


# Make sure checkpoints can be written before spending time on training.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Fine-tuning Loop
num_epochs = 50
best_f1 = 0.0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(combined_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for step, (images, labels, weights) in enumerate(pbar, start=1):
        images, labels, weights = images.to(device), labels.to(device), weights.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Calculate unweighted loss per sample and class
        unweighted_loss = criterion(outputs, labels)

        # Scale each sample's row by its weight, then average over the batch
        weighted_loss = (unweighted_loss * weights.view(-1, 1)).mean()

        weighted_loss.backward()
        optimizer.step()

        running_loss += weighted_loss.item()
        # Average over the batches processed so far (not the epoch total), so
        # the displayed value is a true running mean.
        pbar.set_postfix(loss=running_loss / step)

    val_metrics = evaluate_epoch(model, val_loader, device, metrics)
    val_f1 = val_metrics['MacroF1'].item()

    print(f"Epoch {epoch+1} Summary: Train Loss: {running_loss/len(combined_loader):.4f}, Val MacroF1: {val_f1:.4f}")

    # Keep the weights of the epoch with the best validation macro-F1.
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), CHECKPOINT_DIR / 'semi_supervised_best_model.pth')
        print(f'New best semi-supervised model saved with F1-score: {best_f1:.4f}')

# Save the final model
torch.save(model.state_dict(), CHECKPOINT_DIR / 'semi_supervised_final_model.pth')
print("Training finished. Final semi-supervised model saved.")