In [11]:
import os

# Directory that the relative paths used further down are resolved against.
# Replace with your actual path.
project_dir = '/Users/jakedugan/Documents/UniversityofEdinburgh/CV/cw1'
os.chdir(project_dir)

# Print the current working directory to confirm the switch took effect.
print(os.getcwd())

/Users/jakedugan/Documents/UniversityofEdinburgh/CV/cw1


## Preprocessing


In [None]:
#############################################
#          Preprocessing Definition         #
#############################################

import os
import sys
import csv
import cv2
import time
import torch
import random
import shutil
import argparse
import numpy as np
from tqdm import tqdm
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sklearn.model_selection import train_test_split


# --------------------------
# Resizing and Padding Classes
# --------------------------

class ResizeWithPadding:
    """
    Resize a colour image (bilinear interpolation) and pad it to a fixed canvas.

    The image is scaled so that its longer side equals ``resize_dims[0]`` while
    keeping the aspect ratio, then padded symmetrically up to ``resize_dims``
    (when ``force_resize`` is True) or to ``target_size`` x ``target_size``
    (otherwise). Images already at least as large as the canvas are not cropped.

    Parameters:
      - target_size: Side length of the canvas used when ``force_resize`` is False.
      - padding_mode: How the border is filled:
          * 'mean'    - constant border using the per-channel mean of the image.
          * 'reflect' - mirrored copy of the image content.
          * 'hybrid'  - the inner half of each border is reflected and the
                        outer half is filled with the mean colour.
      - force_resize: Selects ``resize_dims`` (True) or ``target_size`` (False)
        as the padding canvas.
      - resize_dims: (height, width) canvas; ``resize_dims[0]`` is also the
        length the longer image side is resized to.

    Raises:
      - ValueError: if ``padding_mode`` is not one of the supported modes.
    """
    _PADDING_MODES = ('mean', 'reflect', 'hybrid')

    def __init__(self, target_size=512, padding_mode='mean', force_resize=True, resize_dims=(128, 128)):
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if padding_mode not in self._PADDING_MODES:
            raise ValueError("Padding mode must be 'mean', 'reflect', or 'hybrid'")
        self.target_size = target_size
        self.force_resize = force_resize
        self.resize_dims = resize_dims
        self.padding_mode = padding_mode

    def resize_image(self, image):
        """Scale ``image`` (a NumPy array) so its longer side is ``resize_dims[0]``."""
        h, w = image.shape[:2]
        scale = self.resize_dims[0] / max(h, w)
        # int() truncates, so a very elongated image could end up with a
        # zero-length short side, which cv2.resize rejects; keep at least 1 px.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    def pad_image(self, image):
        """Pad ``image`` symmetrically up to the configured canvas size."""
        h, w = image.shape[:2]
        # Determine target dimensions
        target_h, target_w = self.resize_dims if self.force_resize else (self.target_size, self.target_size)
        delta_w = max(0, target_w - w)
        delta_h = max(0, target_h - h)
        # Any odd leftover pixel goes to the bottom / right side.
        top = delta_h // 2
        bottom = delta_h - top
        left = delta_w // 2
        right = delta_w - left

        if self.padding_mode == 'reflect':
            return cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)

        # Integer per-channel mean of the unpadded image, used as the fill colour.
        mean_pixel = np.mean(image, axis=(0, 1), dtype=int).tolist()
        if self.padding_mode == 'mean':
            return cv2.copyMakeBorder(image, top, bottom, left, right,
                                      cv2.BORDER_CONSTANT, value=mean_pixel)

        # hybrid: split every border into a reflected inner part and a
        # mean-coloured outer part so the two passes together add exactly
        # (top, bottom, left, right) and the output matches the canvas size.
        inner = [side // 2 for side in (top, bottom, left, right)]
        outer = [side - part for side, part in zip((top, bottom, left, right), inner)]
        reflected = cv2.copyMakeBorder(image, inner[0], inner[1], inner[2], inner[3],
                                       cv2.BORDER_REFLECT)
        return cv2.copyMakeBorder(reflected, outer[0], outer[1], outer[2], outer[3],
                                  cv2.BORDER_CONSTANT, value=mean_pixel)

    def __call__(self, img):
        # Convert PIL image to numpy array, process, and convert back
        arr = np.array(img)
        resized = self.resize_image(arr)
        padded = self.pad_image(resized)
        return Image.fromarray(padded)

class ResizeWithPaddingLabel:
    """
    Resize a label/mask (nearest-neighbour interpolation) and zero-pad it.

    Nearest-neighbour resizing keeps the original label values instead of
    blending them. The mask is scaled so that its longer side equals
    ``resize_dims[0]`` and then padded symmetrically with 0 up to
    ``resize_dims`` (when ``force_resize`` is True) or to
    ``target_size`` x ``target_size`` (otherwise).
    Assumes labels are single-channel images.

    Parameters:
      - target_size: Side length of the canvas used when ``force_resize`` is False.
      - force_resize: Selects ``resize_dims`` (True) or ``target_size`` (False)
        as the padding canvas.
      - resize_dims: (height, width) canvas; ``resize_dims[0]`` is also the
        length the longer side is resized to.
    """
    def __init__(self, target_size=512, force_resize=False, resize_dims=(256, 256)):
        self.target_size = target_size
        self.force_resize = force_resize
        self.resize_dims = resize_dims

    def resize_image(self, image):
        """Scale ``image`` (a NumPy array) so its longer side is ``resize_dims[0]``."""
        h, w = image.shape[:2]
        scale = self.resize_dims[0] / max(h, w)
        # int() truncates, so a very elongated mask could end up with a
        # zero-length short side, which cv2.resize rejects; keep at least 1 px.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    def pad_image(self, image):
        """Pad ``image`` symmetrically with zeros up to the configured canvas size."""
        h, w = image.shape[:2]
        target_h, target_w = self.resize_dims if self.force_resize else (self.target_size, self.target_size)
        delta_w = max(0, target_w - w)
        delta_h = max(0, target_h - h)
        # Any odd leftover pixel goes to the bottom / right side.
        top = delta_h // 2
        bottom = delta_h - top
        left = delta_w // 2
        right = delta_w - left
        # For labels, use a constant value (e.g., 0 for background)
        return cv2.copyMakeBorder(image, top, bottom, left, right,
                                  cv2.BORDER_CONSTANT, value=0)

    def __call__(self, label_img):
        # Convert PIL label to numpy array, process, and convert back
        arr = np.array(label_img)
        resized = self.resize_image(arr)
        padded = self.pad_image(resized)
        return Image.fromarray(padded)

# --------------------------
# Augmentation Function for Synchronized Transforms
# --------------------------

def elastic_transform_pair(image, label, alpha=34, sigma=4):
    """
    Warp an image and its label with one shared random elastic displacement.

    Uniform noise in [-1, 1) is smoothed with a 17x17 Gaussian kernel
    (standard deviation ``sigma``) and multiplied by ``alpha`` to form the
    per-pixel x/y offsets. The image is resampled bilinearly and the label with
    nearest-neighbour so label values are not blended; both reflect at the
    borders. Returns the pair as PIL images.
    """
    img_arr = np.array(image)
    lbl_arr = np.array(label)
    rng = np.random.RandomState(None)
    rows, cols = img_arr.shape[:2]

    # Draw the x-offset noise first, then the y-offset noise.
    noise_x = rng.rand(rows, cols) * 2 - 1
    noise_y = rng.rand(rows, cols) * 2 - 1
    shift_x = cv2.GaussianBlur(noise_x, (17, 17), sigma) * alpha
    shift_y = cv2.GaussianBlur(noise_y, (17, 17), sigma) * alpha

    # Absolute sampling coordinates = pixel grid + displacement.
    grid_x, grid_y = np.meshgrid(np.arange(cols), np.arange(rows))
    sample_x = (grid_x + shift_x).astype(np.float32)
    sample_y = (grid_y + shift_y).astype(np.float32)

    warped_img = cv2.remap(img_arr, sample_x, sample_y,
                           interpolation=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_REFLECT)
    warped_lbl = cv2.remap(lbl_arr, sample_x, sample_y,
                           interpolation=cv2.INTER_NEAREST,
                           borderMode=cv2.BORDER_REFLECT)
    return Image.fromarray(warped_img), Image.fromarray(warped_lbl)

def augment_pair(image, label, size=(256, 256), apply_color=True, **params):
    """
    Apply a random subset of augmentations, in random order, to an image/label pair.

    Geometric augmentations use the same random parameters for the image
    (bilinear interpolation) and the label (nearest-neighbour), so the two stay
    aligned. Colour augmentations only touch the image. After all selected
    augmentations have run, both outputs are centre-cropped to ``size``.

    Parameters:
      - image, label: PIL images of the colour picture and its mask.
      - size: (height, width) of the returned pair.
      - apply_color: When False, the colour augmentation (blur + colour jitter)
        is never applied, regardless of ``color_prob``.
      - **params: Optional overrides. Each ``*_prob`` is the chance that the
        augmentation runs; set it to 0 to disable or 1 to guarantee it.
          flip_prob (0.5)
          rotation_prob (0.25), rotation_angle (5): angle drawn from
              [-rotation_angle, rotation_angle] degrees
          translate_prob (0.05), translate_factor (0.05): fraction of size[0]
          crop_prob (0.01), crop_scale_range ((0.9, 1.0)),
              crop_ratio_range ((1.0, 1.0))
          elastic_prob (0.01), elastic_alpha (15), elastic_sigma (2)
          scaling_prob (0.0), scaling_range ((1.0, 1.0))
          color_prob (0.25), blur_prob (0.0), blur_radius_range ((0.5, 1.5)),
              color_jitter_params (brightness/contrast/saturation 0.2, hue 0.1)

    Returns:
      The augmented (image, label) pair.
    """
    # Imported locally under an unambiguous name: the file binds the alias `F`
    # to both torch.nn.functional and torchvision.transforms.functional, so
    # which module `F` refers to depends on which import ran last.
    import torchvision.transforms.functional as TF
    from torchvision.transforms import InterpolationMode

    bilinear = InterpolationMode.BILINEAR
    nearest = InterpolationMode.NEAREST

    def flip_aug(img, lbl):
        return TF.hflip(img), TF.hflip(lbl)

    def rotate_aug(img, lbl):
        # Symmetric range around zero, controlled by a single parameter.
        max_angle = params.get("rotation_angle", 5)
        angle = random.uniform(-max_angle, max_angle)
        return (TF.rotate(img, angle, interpolation=bilinear),
                TF.rotate(lbl, angle, interpolation=nearest))

    def translate_aug(img, lbl):
        max_translate = params.get("translate_factor", 0.05) * size[0]
        tx = int(random.uniform(-max_translate, max_translate))
        ty = int(random.uniform(-max_translate, max_translate))
        return (TF.affine(img, angle=0, translate=(tx, ty), scale=1.0, shear=0,
                          interpolation=bilinear),
                TF.affine(lbl, angle=0, translate=(tx, ty), scale=1.0, shear=0,
                          interpolation=nearest))

    def crop_aug(img, lbl):
        # One crop window is sampled and reused for both image and label.
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            img,
            scale=params.get("crop_scale_range", (0.9, 1.0)),
            ratio=params.get("crop_ratio_range", (1.0, 1.0))
        )
        return (TF.resized_crop(img, i, j, h, w, size, interpolation=bilinear),
                TF.resized_crop(lbl, i, j, h, w, size, interpolation=nearest))

    def elastic_aug(img, lbl):
        return elastic_transform_pair(
            img, lbl,
            alpha=params.get("elastic_alpha", 15),
            sigma=params.get("elastic_sigma", 2)
        )

    def scaling_aug(img, lbl):
        # The default range is the identity scale; a (0, 0) range would ask
        # for a 0x0 resize. Each side is also kept at 1 px or more.
        low, high = params.get("scaling_range", (1.0, 1.0))
        scale_factor = random.uniform(low, high)
        new_size = (max(1, int(size[0] * scale_factor)),
                    max(1, int(size[1] * scale_factor)))
        return (TF.resize(img, new_size, interpolation=bilinear),
                TF.resize(lbl, new_size, interpolation=nearest))

    def color_aug(img, lbl):
        # Optional Gaussian blur followed by colour jitter; the label is untouched.
        if random.random() < params.get("blur_prob", 0.0):
            img = img.filter(
                ImageFilter.GaussianBlur(
                    radius=random.uniform(*params.get("blur_radius_range", (0.5, 1.5)))
                )
            )
        color_jitter = transforms.ColorJitter(**params.get("color_jitter_params", {
            'brightness': 0.2, 'contrast': 0.2, 'saturation': 0.2, 'hue': 0.1
        }))
        img = color_jitter(img)
        return img, lbl

    # (augmentation, probability of applying it)
    augmentations = [
        (flip_aug, params.get("flip_prob", 0.5)),
        (rotate_aug, params.get("rotation_prob", 0.25)),
        (translate_aug, params.get("translate_prob", 0.05)),
        (crop_aug, params.get("crop_prob", 0.01)),
        (elastic_aug, params.get("elastic_prob", 0.01)),
        (scaling_aug, params.get("scaling_prob", 0.0)),
    ]
    if apply_color:
        augmentations.append((color_aug, params.get("color_prob", 0.25)))

    # Randomize the order in which the augmentations are tried.
    random.shuffle(augmentations)

    # Apply each augmentation based on its probability.
    for func, prob in augmentations:
        if random.random() < prob:
            image, label = func(image, label)

    # Enforce a final center crop so both outputs have the desired size.
    image = TF.center_crop(image, size)
    label = TF.center_crop(label, size)

    return image, label

# --------------------------
# OOP Preprocessor Class
# --------------------------

class Preprocessor:
    """
    Resize/pad raw colour images and their label masks, optionally writing
    extra augmented copies for training data.
    """
    def __init__(self, raw_color_path, raw_label_path,
                 proc_color_path, proc_label_path,
                 resize_dim=128, do_augmentation=True,
                 is_train=True, max_images=None, aug_count=10, aug_params=None):

        """
        Parameters:
          - raw_color_path: Relative path to raw color images.
          - raw_label_path: Relative path to raw label/mask images.
          - proc_color_path: Relative path to save processed color images.
          - proc_label_path: Relative path to save processed label images.
          - resize_dim: Target dimension for resizing/padding.
          - do_augmentation: Whether to augment (only for training).
          - is_train: True if processing training data.
          - max_images: Process only a subset of images (for testing the pipeline).
          - aug_count: Number of augmentations to create per image (train only).
          - aug_params: Optional dict of keyword arguments forwarded to
            augment_pair; None means no overrides.

        The two output directories are created if they do not exist.
        """
        self.aug_params = aug_params if aug_params is not None else {}

        self.raw_color_path = Path(raw_color_path)
        self.raw_label_path = Path(raw_label_path)
        self.proc_color_path = Path(proc_color_path)
        self.proc_label_path = Path(proc_label_path)
        self.resize_dim = resize_dim
        # Augmentation is only ever enabled for training data.
        self.do_augmentation = do_augmentation and is_train
        self.is_train = is_train
        self.max_images = max_images
        self.aug_count = aug_count

        self.proc_color_path.mkdir(parents=True, exist_ok=True)
        self.proc_label_path.mkdir(parents=True, exist_ok=True)

        # Create transforms for images and labels
        self.transform_img = ResizeWithPadding(target_size=224, padding_mode='mean',
                                               force_resize=True, resize_dims=(resize_dim, resize_dim))
        self.transform_label = ResizeWithPaddingLabel(target_size=224, force_resize=True,
                                                      resize_dims=(resize_dim, resize_dim))

    def process(self):
        """
        Process every image under ``raw_color_path`` that has a matching label.

        For each image ``<stem>.<ext>`` the label ``<stem>.png`` is looked up
        directly inside ``raw_label_path``; images without a label are skipped
        with a warning. The resized/padded pair is saved with a ``processed_``
        prefix and, when augmentation is enabled, ``aug_count`` augmented
        copies are saved with an ``_aug_<i>`` suffix.
        """
        # Find all color images (assume image file names match label file names)
        image_extensions = (".jpg", ".jpeg", ".png")
        # Sorted so that the `max_images` subset does not depend on the order
        # in which the filesystem happens to list the files.
        image_files = sorted(
            f for f in self.raw_color_path.rglob("*")
            if f.is_file() and f.suffix.lower() in image_extensions
        )
        if self.max_images is not None:
            image_files = image_files[:self.max_images]

        if not image_files:
            print("❌ No images found in", self.raw_color_path)
            return

        out_size = (self.resize_dim, self.resize_dim)
        for img_file in image_files:
            label_file = self.raw_label_path / f"{img_file.stem}.png"
            if not label_file.exists():
                print(f"⚠️  Label for {img_file.name} not found, skipping.")
                continue

            # Open image and label (assume label is a segmentation mask in grayscale).
            # convert() returns a new in-memory image, so the source files can
            # be closed as soon as the conversion is done.
            with Image.open(img_file) as raw_img:
                img = raw_img.convert("RGB")
            with Image.open(label_file) as raw_label:
                label = raw_label.convert("L")

            # Apply resizing and padding to both image and label
            proc_img = self.transform_img(img)
            proc_label = self.transform_label(label)

            # Save processed (base) image and label
            proc_img.save(self.proc_color_path / f"processed_{img_file.name}")
            proc_label.save(self.proc_label_path / f"processed_{label_file.name}")
            print(f"Processed {img_file.name}")

            # If training and augmentation is enabled, create additional augmented pairs
            if self.is_train and self.do_augmentation:
                for i in range(self.aug_count):
                    aug_img, aug_label = augment_pair(proc_img, proc_label, size=out_size, **self.aug_params)
                    aug_img.save(self.proc_color_path / f"processed_{img_file.stem}_aug_{i}{img_file.suffix}")
                    aug_label.save(self.proc_label_path / f"processed_{label_file.stem}_aug_{i}{label_file.suffix}")
                    print(f"Augmented {img_file.name} -> aug {i}")

# --------------------------
# Command Line Interface
# --------------------------

if __name__ == "__main__":
    # Entry point: parse the options, pick the dataset paths from --set_type,
    # build the augmentation parameters and run the Preprocessor.
    parser = argparse.ArgumentParser(
        description="Preprocess train/test images with labels/masks using relative paths."
    )
    # The four path options have defaults, so they are not marked required
    # (argparse ignores `default` on a required option).
    # NOTE: their values are currently not used; the paths are derived from
    # --set_type further down.
    parser.add_argument("--raw_color", type=str, default='raw/TrainVal/color',
                        help="Relative path to raw color images (e.g., Dataset/raw/TrainVal/color).")
    parser.add_argument("--raw_label", type=str, default='raw/TrainVal/label',
                        help="Relative path to raw label images (e.g., Dataset/raw/TrainVal/label).")
    parser.add_argument("--proc_color", type=str, default='Dataset/processed/TrainVal/color',
                        help="Relative path to save processed color images (e.g., Dataset/processed/TrainVal/color).")
    parser.add_argument("--proc_label", type=str, default='Dataset/processed/TrainVal/label',
                        help="Relative path to save processed label images (e.g., Dataset/processed/TrainVal/label).")
    parser.add_argument("--resize_dim", type=int, default=128,
                        help="Output dimension for resizing (e.g., 128 or 256).")
    parser.add_argument("--no_augment", action="store_true",
                        help="Disable augmentation (for test sets).")
    parser.add_argument("--max_images", type=int, default=None,
                        help="Process only a subset of images (default: all).")
    parser.add_argument("--aug_count", type=int, default=10,
                        help="Number of augmentations per image (train only).")
    parser.add_argument("--set_type", type=str, choices=["TrainVal", "Test"], default="TrainVal",
                        help="Dataset type: TrainVal or Test.")
    parser.add_argument("--flip_prob", type=float, default=0.5, help="Probability for horizontal flip.")
    parser.add_argument("--rotation_angle", type=float, default=5, help="Maximum rotation angle (in degrees).")
    parser.add_argument("--translate_factor", type=float, default=0.05, help="Translation factor (fraction of image size).")
    parser.add_argument("--crop_scale_min", type=float, default=0.9, help="Minimum scale for random resized crop.")
    parser.add_argument("--crop_scale_max", type=float, default=1.0, help="Maximum scale for random resized crop.")
    parser.add_argument("--crop_ratio_min", type=float, default=1.0, help="Minimum aspect ratio for random resized crop.")
    parser.add_argument("--crop_ratio_max", type=float, default=1.0, help="Maximum aspect ratio for random resized crop.")
    parser.add_argument("--elastic_alpha", type=float, default=15, help="Elastic transformation alpha parameter.")
    parser.add_argument("--elastic_sigma", type=float, default=2, help="Elastic transformation sigma parameter.")
    # NOTE: --scaling_min/--scaling_max are parsed but not used; the scaling
    # range passed to the augmentation is fixed below.
    parser.add_argument("--scaling_min", type=float, default=0.0, help="Minimum scaling factor for random scaling.")
    parser.add_argument("--scaling_max", type=float, default=0.0, help="Maximum scaling factor for random scaling.")
    parser.add_argument("--blur_prob", type=float, default=0.3, help="Probability of applying Gaussian blur.")
    parser.add_argument("--blur_radius_min", type=float, default=0.5, help="Minimum radius for Gaussian blur.")
    parser.add_argument("--blur_radius_max", type=float, default=1.5, help="Maximum radius for Gaussian blur.")
    parser.add_argument("--color_jitter_brightness", type=float, default=0.2, help="Brightness for color jitter.")
    parser.add_argument("--color_jitter_contrast", type=float, default=0.2, help="Contrast for color jitter.")
    parser.add_argument("--color_jitter_saturation", type=float, default=0.2, help="Saturation for color jitter.")
    parser.add_argument("--color_jitter_hue", type=float, default=0.1, help="Hue for color jitter.")

    # The argument list is set explicitly instead of being read from a real
    # command line (presumably because this runs as a notebook cell).
    # Two presets are kept; assign the one to run to sys.argv.

    ## Test processing (resizing)
    test_argv = ['preprocessing.py', '--raw_color', 'raw/Test/color',
                 '--raw_label', '/raw/Test/label',
                 '--proc_color', '/processed/Test/color',
                 '--proc_label', '/processed/Test/label',
                 '--resize_dim', '128',
                 '--no_augment',
                 '--set_type', 'Test']

    ## TrainVal augmentation
    trainval_argv = ['preprocessing.py',
                     '--raw_color', '/raw/TrainVal/color',
                     '--raw_label', '/raw/TrainVal/label',
                     '--proc_color', '/processed/TrainVal/color',
                     '--proc_label', '/processed/TrainVal/label',
                     '--resize_dim', '128',
                     '--aug_count', '10',
                     '--set_type', 'TrainVal',
                     '--translate_factor', '0.02',
                     '--elastic_alpha', '0.0',
                     '--elastic_sigma', '0.05',
                     '--blur_radius_min', '0.0',
                     '--blur_radius_max', '0.05']

    sys.argv = trainval_argv

    args = parser.parse_args()

    # Define relative paths based on set type
    if args.set_type == "TrainVal":
        raw_color = Path("raw_Dataset/TrainVal/color")
        raw_label = Path("raw_Dataset/TrainVal/label")
        proc_color = Path("final_local_processed/TrainVal/color")
        proc_label = Path("final_local_processed/TrainVal/label")
    else:  # Test set
        raw_color = Path("raw_Dataset/Test/color")
        raw_label = Path("raw_Dataset/Test/label")
        proc_color = Path("final_local_processed/Test/color")
        proc_label = Path("final_local_processed/Test/label")

    # Each *_prob is the chance that the augmentation is applied to a sample.
    aug_params = {
        "flip_prob": args.flip_prob,
        "rotation_angle": args.rotation_angle,
        "rotation_prob": 0.25,         # Rotate 25% of the time.
        "translate_factor": args.translate_factor,
        "translate_prob": 0.05,        # Translate 5% of the time.
        "crop_scale_range": (args.crop_scale_min, args.crop_scale_max),
        "crop_ratio_range": (args.crop_ratio_min, args.crop_ratio_max),
        "crop_prob": 0.01,             # Random resized crop 1% of the time.
        "elastic_alpha": args.elastic_alpha,
        "elastic_sigma": args.elastic_sigma,
        "elastic_prob": 0.01,          # Elastic transform 1% of the time.
        "scaling_range": (0.8, 1.2),
        "scaling_prob": 0.05,          # Random scaling 5% of the time.
        "blur_prob": args.blur_prob,
        "blur_radius_range": (args.blur_radius_min, args.blur_radius_max),
        "color_jitter_params": {
            "brightness": args.color_jitter_brightness,
            "contrast": args.color_jitter_contrast,
            "saturation": args.color_jitter_saturation,
            "hue": args.color_jitter_hue,
        },
        "color_prob": 1.0            # Always apply color adjustments.
    }

    preprocessor = Preprocessor(raw_color, raw_label,
                                proc_color, proc_label,
                                resize_dim=args.resize_dim,
                                do_augmentation=not args.no_augment,
                                # TrainVal is the training data; the Test set
                                # must never be augmented.
                                is_train=(args.set_type == "TrainVal"),
                                max_images=args.max_images,
                                aug_count=args.aug_count,
                                aug_params=aug_params)

    preprocessor.process()

Processed basset_hound_112.jpg
Augmented basset_hound_112.jpg -> aug 0
Augmented basset_hound_112.jpg -> aug 1
Augmented basset_hound_112.jpg -> aug 2
Augmented basset_hound_112.jpg -> aug 3
Augmented basset_hound_112.jpg -> aug 4
Augmented basset_hound_112.jpg -> aug 5
Augmented basset_hound_112.jpg -> aug 6
Augmented basset_hound_112.jpg -> aug 7
Augmented basset_hound_112.jpg -> aug 8
Augmented basset_hound_112.jpg -> aug 9
Processed Siamese_193.jpg
Augmented Siamese_193.jpg -> aug 0
Augmented Siamese_193.jpg -> aug 1
Augmented Siamese_193.jpg -> aug 2
Augmented Siamese_193.jpg -> aug 3
Augmented Siamese_193.jpg -> aug 4
Augmented Siamese_193.jpg -> aug 5
Augmented Siamese_193.jpg -> aug 6
Augmented Siamese_193.jpg -> aug 7
Augmented Siamese_193.jpg -> aug 8
Augmented Siamese_193.jpg -> aug 9
Processed shiba_inu_122.jpg
Augmented shiba_inu_122.jpg -> aug 0
Augmented shiba_inu_122.jpg -> aug 1
Augmented shiba_inu_122.jpg -> aug 2
Augmented shiba_inu_122.jpg -> aug 3
Augmented shiba_i

In [None]:
import torch
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
from PIL import Image
from torchvision import transforms

class SegmentationDataset(Dataset):
    """
    A simple dataset for image segmentation.

    Assumes:
      - root_dir/color: contains the color (RGB) images.
      - root_dir/label: contains the corresponding label images.
      - Each label image has the same stem as its corresponding color image,
        and uses a .png extension.

    Optional transforms:
      - transform_img: transformation to apply to color images
        (defaults to ToTensor when not given).
      - transform_label: transformation to apply to the label tensor after
        the raw grey values have been mapped to class indices.

    Returns a tuple (image, label) where:
      - image is a Tensor of shape (C, H, W).
      - label is a Tensor of shape (H, W) containing integer class indices.
    """
    # File suffixes (lower-cased) that count as colour images.
    _IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")
    # Raw grey value -> class index. Every other grey value maps to class 0.
    _LABEL_MAP = {38: 1, 75: 1, 255: 2}

    def __init__(self, root_dir, transform_img=None, transform_label=None):
        self.root_dir = Path(root_dir)
        self.color_dir = self.root_dir / "color"
        self.label_dir = self.root_dir / "label"
        self.transform_img = transform_img
        self.transform_label = transform_label

        # Gather all image files from the color directory (sorted for a
        # stable index -> file mapping).
        self.image_files = sorted(
            f for f in self.color_dir.iterdir()
            if f.is_file() and f.suffix.lower() in self._IMAGE_SUFFIXES
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        # Build corresponding label path.
        label_path = self.label_dir / (img_path.stem + ".png")

        # Load color image; convert() returns a new in-memory image, so the
        # file can be closed right away.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Load the label as grayscale and convert it to an int64 array once.
        with Image.open(label_path) as raw_label:
            label_np = np.array(raw_label.convert("L"), dtype=np.int64)

        # Apply image transformation if provided; else default to ToTensor.
        if self.transform_img:
            image = self.transform_img(image)
        else:
            image = transforms.ToTensor()(image)

        # Mapping:
        #   38 or 75 become 1,
        #   255 becomes 2,
        #   everything else (including 0) becomes 0.
        label_new = np.zeros_like(label_np)
        for raw_value, class_index in self._LABEL_MAP.items():
            label_new[label_np == raw_value] = class_index

        # Convert mapped label to a torch tensor (of type long).
        label_tensor = torch.from_numpy(label_new).long()

        # Optionally apply label transform.
        if self.transform_label:
            label_tensor = self.transform_label(label_tensor)

        return image, label_tensor

In [None]:
#############################################
#             Patch-based dataset           #
#############################################

class PatchBasedDataset(Dataset):
    """
    Wraps SegmentationDataset and returns one random square patch per sample.
    """
    def __init__(self, root_dir, patch_size=128, transform_img=None, transform_label=None):
        """
        root_dir: Path to e.g. /content/processed/TrainVal
        patch_size: size of the patch to randomly crop
        transform_img: transformations to apply to the color image
        transform_label: transformations to apply to the label (optional)
        """
        self.base_dataset = SegmentationDataset(root_dir, transform_img, transform_label)
        self.patch_size = patch_size

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """
        Return (image_patch, mask_patch) cropped at the same random location.

        The crop is taken directly on the tensors, so the image values are
        returned exactly as the base dataset produced them (no conversion to
        an 8-bit PIL image and back, which would quantise the values).

        Raises:
          ValueError: if the image is smaller than ``patch_size`` in either
          dimension.
        """
        # Retrieve the full image and mask from the base dataset
        image, mask = self.base_dataset[idx]  # image: (C, H, W), mask: (H, W)

        crop = self.patch_size
        height, width = image.shape[-2:]
        if height < crop or width < crop:
            raise ValueError(
                f"Required crop size {(crop, crop)} is larger than input image size {(height, width)}"
            )

        # Pick the top-left corner of the patch; nothing is drawn when the
        # patch covers the whole image.
        if height == crop and width == crop:
            top, left = 0, 0
        else:
            top = torch.randint(0, height - crop + 1, size=(1,)).item()
            left = torch.randint(0, width - crop + 1, size=(1,)).item()

        # clone() so the patch does not keep the full-size tensor alive.
        image_patch = image[..., top:top + crop, left:left + crop].clone()
        # Mask stays an integer class map of type long.
        mask_patch = mask[..., top:top + crop, left:left + crop].clone().long()

        return image_patch, mask_patch

In [None]:
SPLITTING

## Imports and Utils

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd  # if you want to save stats to CSV

# -------------- IoU for segmentation ---------------
def compute_iou(preds, targets, num_classes=3):
    """
    Per-class Intersection-over-Union between predictions and ground truth.

    Args:
      preds: Tensor of predicted class indices (N, H, W)
      targets: Tensor of ground truth class indices (N, H, W)
      num_classes: Number of classes; classes 0 .. num_classes-1 are scored.

    Returns:
      A list with one IoU per class. A class that appears in neither
      ``preds`` nor ``targets`` gets ``nan``.
    """
    def _class_iou(class_id):
        predicted = (preds == class_id)
        actual = (targets == class_id)
        overlap = (predicted & actual).sum().item()
        combined = (predicted | actual).sum().item()
        # Empty union: the class is absent everywhere, so IoU is undefined.
        return overlap / combined if combined != 0 else float('nan')

    return [_class_iou(class_id) for class_id in range(num_classes)]

# -------------- Visualization ---------------
def visualize_reconstruction(model, dataset, device, num_samples=3):
    """
    Visualize a few reconstructed images from the autoencoder.

    Draws one randomly selected batch of up to ``num_samples`` images from
    ``dataset``, runs it through ``model`` without gradients, and shows the
    originals (top row) above their reconstructions (bottom row).

    Note: the model is switched to eval mode and left in that mode.
    """
    model.eval()
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=True)
    images, _ = next(iter(loader))  # we don't care about labels for autoencoder
    images = images.to(device)
    with torch.no_grad():
        recons = model(images)
    # Move to CPU
    images = images.cpu()
    recons = recons.cpu()

    # The batch can be smaller than requested when the dataset is small.
    shown = images.size(0)
    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing works for any number of samples.
    fig, axes = plt.subplots(2, shown, figsize=(3*shown, 6), squeeze=False)
    for i in range(shown):
        # Original (channels moved last for imshow)
        axes[0, i].imshow(images[i].permute(1,2,0).numpy())
        axes[0, i].set_title("Original")
        axes[0, i].axis("off")
        # Reconstructed
        axes[1, i].imshow(recons[i].permute(1,2,0).numpy())
        axes[1, i].set_title("Reconstructed")
        axes[1, i].axis("off")
    plt.tight_layout()
    plt.show()

def visualize_segmentation_predictions(model, dataset, device, num_samples=3, num_classes=3):
    """
    Visualize a few segmentation predictions from a trained model
    (frozen encoder + new decoder).

    Draws one randomly selected batch of up to ``num_samples`` samples from
    ``dataset`` and shows, per sample, the input image, the ground-truth mask
    and the predicted mask (argmax over the channel dimension of the model
    output). Masks are drawn in grayscale scaled to ``0 .. num_classes-1``.

    Note: the model is switched to eval mode and left in that mode.
    """
    model.eval()
    loader = DataLoader(dataset, batch_size=num_samples, shuffle=True)
    images, masks = next(iter(loader))
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)  # raw logits
        preds = torch.argmax(outputs, dim=1)  # (N, H, W)
    # Move to CPU
    images = images.cpu()
    masks = masks.cpu()
    preds = preds.cpu()

    # The batch can be smaller than requested when the dataset is small.
    shown = images.size(0)
    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing works for any number of samples.
    fig, axes = plt.subplots(3, shown, figsize=(3*shown, 9), squeeze=False)
    for i in range(shown):
        # Input image (channels moved last for imshow)
        axes[0, i].imshow(images[i].permute(1,2,0).numpy())
        axes[0, i].set_title("Input")
        axes[0, i].axis("off")
        # Ground truth mask
        axes[1, i].imshow(masks[i], cmap='gray', vmin=0, vmax=num_classes-1)
        axes[1, i].set_title("Ground Truth Mask")
        axes[1, i].axis("off")
        # Prediction
        axes[2, i].imshow(preds[i], cmap='gray', vmin=0, vmax=num_classes-1)
        axes[2, i].set_title("Prediction")
        axes[2, i].axis("off")
    plt.tight_layout()
    plt.show()

## Model

In [None]:
class SimpleAutoencoder(nn.Module):
    """
    A simple convolutional autoencoder for 128x128 color images.

    The encoder applies four stride-2 convolutions (128 -> 8 spatially,
    channels widened to 256) followed by a linear projection of the flattened
    (256, 8, 8) map to `latent_dim`. The decoder mirrors this with a linear
    layer and four stride-2 transposed convolutions, ending in a sigmoid so
    that reconstructions lie in [0, 1].
    """

    def __init__(self, in_channels=3, latent_dim=256):
        """
        in_channels: number of channels of the input (and reconstructed) image.
        latent_dim: dimension of the flattened feature vector after encoder.
        """
        super().__init__()
        # Every (transposed) conv uses the same geometry: it halves (or
        # doubles) height and width.
        conv_cfg = dict(kernel_size=4, stride=2, padding=1)
        widths = [in_channels, 32, 64, 128, 256]
        flat_features = 256 * 8 * 8  # size of the flattened (256, 8, 8) map

        # ------------- ENCODER -------------
        # (B, in_channels, 128, 128) -> (B, 256, 8, 8)
        enc_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, **conv_cfg))
            enc_layers.append(nn.ReLU(True))
        self.enc = nn.Sequential(*enc_layers)
        self.fc_enc = nn.Linear(flat_features, latent_dim)

        # ------------- DECODER -------------
        # (B, 256, 8, 8) -> (B, in_channels, 128, 128)
        self.fc_dec = nn.Linear(latent_dim, flat_features)
        mirrored = widths[::-1]
        final_stage = len(mirrored) - 2
        dec_layers = []
        for stage, (c_in, c_out) in enumerate(zip(mirrored[:-1], mirrored[1:])):
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, **conv_cfg))
            # The last stage squashes to [0, 1]; the others use ReLU.
            dec_layers.append(nn.Sigmoid() if stage == final_stage else nn.ReLU(True))
        self.dec = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Map an image batch to latent vectors of shape (B, latent_dim)."""
        feats = self.enc(x)
        flat = feats.view(feats.size(0), -1)  # (B, 256*8*8)
        return self.fc_enc(flat)

    def decode(self, z):
        """Map latent vectors (B, latent_dim) back to images."""
        hidden = self.fc_dec(z)                         # (B, 256*8*8)
        grid = hidden.view(hidden.size(0), 256, 8, 8)   # (B, 256, 8, 8)
        return self.dec(grid)

    def forward(self, x):
        """Reconstruct `x` by encoding and then decoding it."""
        return self.decode(self.encode(x))

## Pretraining Encoder Reconstruction

In [None]:
def train_autoencoder(
    model,
    train_dataset,
    val_dataset=None,
    epochs=50,
    batch_size=16,
    lr=1e-3,
    save_every=5,
    early_stopping_patience=None,
    device="cuda"  # or "cpu" or "mps"
):
    """
    Train the autoencoder on (image -> image) reconstruction with MSE loss.

    - model: SimpleAutoencoder (must expose `.enc` and `.dec` for checkpoints)
    - train_dataset: dataset returning (img, _); the second item is ignored
    - val_dataset: optional validation dataset
    - epochs: max number of epochs
    - batch_size: training batch size
    - lr: learning rate (Adam)
    - save_every: save encoder/decoder weights every N epochs; 0 or None
      disables these periodic checkpoints
    - early_stopping_patience: number of epochs to wait for val loss
      improvement before stopping (only used when val_dataset is given)
    - device: "cuda", "cpu", or "mps"

    Side effects: writes `encoder_epoch_{N}.pth` / `decoder_epoch_{N}.pth`
    every `save_every` epochs and `autoencoder_final.pth` at the end, all in
    the current working directory.

    Returns the trained model (weights of the last epoch that ran).
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val_loss = float("inf")
    epochs_no_improve = 0

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        for imgs, _ in train_loader:
            imgs = imgs.to(device)
            optimizer.zero_grad()
            recons = model(imgs)
            loss = criterion(recons, imgs)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean
            # even when the last batch is smaller.
            running_loss += loss.item() * imgs.size(0)

        train_epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        val_epoch_loss = None
        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for imgs, _ in val_loader:
                    imgs = imgs.to(device)
                    recons = model(imgs)
                    loss = criterion(recons, imgs)
                    val_loss += loss.item() * imgs.size(0)
            val_epoch_loss = val_loss / len(val_loader.dataset)
            print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}")
        else:
            print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_epoch_loss:.4f}")

        # Save every 'save_every' epochs. The truthiness guard lets callers
        # pass 0/None to switch periodic checkpoints off instead of crashing
        # on the modulo.
        if save_every and epoch % save_every == 0:
            encoder_path = f"encoder_epoch_{epoch}.pth"
            decoder_path = f"decoder_epoch_{epoch}.pth"
            # NOTE: these two files hold only the conv stacks (`enc`, `dec`).
            # The linear layers `fc_enc` / `fc_dec` are not included; the
            # complete state is only written by the final save below.
            torch.save(model.enc.state_dict(), encoder_path)
            torch.save(model.dec.state_dict(), decoder_path)
            print(f"Saved encoder & decoder checkpoints at epoch {epoch}")

        # Early stopping
        if val_loader is not None and early_stopping_patience is not None:
            if val_epoch_loss < best_val_loss:
                best_val_loss = val_epoch_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= early_stopping_patience:
                    print(f"Early stopping triggered at epoch {epoch}")
                    break

    # Final save of the entire autoencoder
    torch.save(model.state_dict(), "autoencoder_final.pth")
    print("Training completed. Final model saved as autoencoder_final.pth")

    return model

In [None]:
class SegmentationDecoder(nn.Module):
    """
    Upsampling head that turns the autoencoder encoder's last conv feature
    map (taken before flattening) into per-pixel segmentation logits.

    Four stride-2 transposed convolutions each double the spatial size, so a
    (256, 8, 8) feature map from a 128x128 input becomes
    (num_classes, 128, 128). No activation follows the last layer: the
    outputs are raw logits.
    """

    def __init__(self, input_channels=256, num_classes=3):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)  # doubles H and W
        channel_plan = (input_channels, 128, 64, 32)

        layers = []
        for src, dst in zip(channel_plan, channel_plan[1:]):
            layers.append(nn.ConvTranspose2d(src, dst, **upsample))
            layers.append(nn.ReLU(True))
        # Final projection to class logits, e.g. (3, 128, 128) for 3 classes
        layers.append(nn.ConvTranspose2d(channel_plan[-1], num_classes, **upsample))
        self.seg_dec = nn.Sequential(*layers)

    def forward(self, feature_map):
        """
        feature_map: shape (B, 256, 8, 8) from the frozen encoder
        returns: (B, num_classes, 128, 128) logits
        """
        logits = self.seg_dec(feature_map)
        return logits

class Segmenter(nn.Module):
    """
    Segmentation model built from a frozen autoencoder encoder plus a fresh,
    trainable SegmentationDecoder.

    The encoder module is shared with the given autoencoder (not copied), so
    freezing it here also sets requires_grad=False on the autoencoder's own
    encoder parameters.
    """

    def __init__(self, autoencoder, num_classes=3):
        super().__init__()
        # Reuse the conv stack of the autoencoder (everything up to its last
        # conv layer) and switch off gradients for all of its parameters.
        encoder = autoencoder.enc
        encoder.requires_grad_(False)
        self.frozen_encoder = encoder

        # The only trainable part: a new decoder producing class logits.
        self.seg_decoder = SegmentationDecoder(input_channels=256, num_classes=num_classes)

    def forward(self, x):
        """
        x: (B, 3, 128, 128)
        returns: logits of shape (B, num_classes, 128, 128)
        """
        # No autograd graph is needed for the frozen part.
        with torch.no_grad():
            encoded = self.frozen_encoder(x)  # (B, 256, 8, 8)
        # Gradients flow only through the segmentation decoder.
        return self.seg_decoder(encoded)

## Training the New Segmentation Decoder

In [None]:
def train_segmentation_decoder(
    model,
    train_dataset,
    val_dataset=None,
    epochs=30,
    batch_size=8,
    lr=1e-3,
    save_every=5,
    early_stopping_patience=5,
    num_classes=3,
    device="cuda"
):
    """
    Train the segmentation decoder (with frozen encoder) using cross-entropy.
    Validation mIoU is used for model selection and early stopping.

    - model: Segmenter(...) instance; only `model.seg_decoder` is optimised
    - train_dataset: segmentation dataset returning (img, mask)
    - val_dataset: optional dataset for validation
    - epochs, batch_size, lr: typical training hyperparams (Adam optimiser)
    - save_every: save model weights every N epochs; 0 or None disables
      these periodic checkpoints
    - early_stopping_patience: stops if val mIoU doesn't improve for this
      many epochs; None disables early stopping
    - num_classes: number of classes passed to compute_iou
    - device: "cuda", "cpu", or "mps"

    Side effects: writes `segmenter_epoch_{N}.pth` checkpoints,
    `segmenter_final.pth` and `segmentation_training_log.csv` to the current
    working directory.

    Returns: the model, restored to the weights with the highest val mIoU
    when a validation set was given (otherwise the last-epoch weights).
    """
    # Local imports: keeps this function independent of whether the file's
    # import cell brings these names into scope.
    import copy
    import pandas as pd

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # One dict per epoch; dumped to CSV at the end.
    training_stats = []

    model = model.to(device)

    # Only the segmentation decoder's parameters are trainable
    optimizer = optim.Adam(model.seg_decoder.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # -inf (not 0.0) so that the first finite mIoU, even 0.0, is recorded
    # as the best so far and gets a snapshot.
    best_val_iou = float("-inf")
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(1, epochs+1):
        start_time = time.time()

        model.train()
        running_loss = 0.0

        for imgs, masks in train_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            logits = model(imgs)        # (B, num_classes, 128,128)
            loss = criterion(logits, masks)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * imgs.size(0)

        train_epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        if val_loader is not None:
            model.eval()
            val_loss = 0.0
            all_preds = []
            all_targets = []
            with torch.no_grad():
                for val_imgs, val_masks in val_loader:
                    val_imgs = val_imgs.to(device)
                    val_masks = val_masks.to(device)
                    val_logits = model(val_imgs)
                    val_batch_loss = criterion(val_logits, val_masks)
                    val_loss += val_batch_loss.item() * val_imgs.size(0)

                    preds = torch.argmax(val_logits, dim=1)
                    all_preds.append(preds.cpu())
                    all_targets.append(val_masks.cpu())

            val_loss /= len(val_loader.dataset)
            all_preds = torch.cat(all_preds, dim=0)
            all_targets = torch.cat(all_targets, dim=0)
            ious = compute_iou(all_preds, all_targets, num_classes=num_classes)
            # nanmean: NaN entries in the per-class IoUs are ignored.
            mean_iou = np.nanmean(ious)

            print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_epoch_loss:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val mIoU: {mean_iou:.4f}, per-class: {ious}")
        else:
            # If no val_loader, just print train loss
            val_loss = train_epoch_loss
            mean_iou = float('nan')
            print(f"Epoch [{epoch}/{epochs}] - Train Loss: {train_epoch_loss:.4f}")

        epoch_duration = time.time() - start_time

        # Store stats for this epoch
        training_stats.append({
            'epoch': epoch,
            'train_loss': train_epoch_loss,
            'val_loss': val_loss,
            'val_mIoU': mean_iou,
            'time_sec': epoch_duration
        })

        print(f"Epoch [{epoch}/{epochs}] took {epoch_duration:.2f}s. "
              f"Train Loss: {train_epoch_loss:.4f}, Val Loss: {val_loss:.4f}, IoU: {mean_iou:.4f}")

        # Save checkpoint every 'save_every' epochs (0/None switches it off)
        if save_every and epoch % save_every == 0:
            checkpoint_path = f"segmenter_epoch_{epoch}.pth"
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved segmentation model checkpoint at epoch {epoch}")

        # Early Stopping on IoU
        if val_loader is not None:
            if mean_iou > best_val_iou:
                best_val_iou = mean_iou
                # state_dict() hands back references to the live parameter
                # tensors; without a copy the "best" snapshot would silently
                # follow every later optimiser step.
                best_model_state = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if (early_stopping_patience is not None
                        and epochs_no_improve >= early_stopping_patience):
                    print(f"Early stopping triggered at epoch {epoch}")
                    break

    # If we had validation, load the best model state
    if val_loader is not None and best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model state with mIoU={best_val_iou:.4f}")

    # Final save
    torch.save(model.state_dict(), "segmenter_final.pth")
    print("Segmentation training completed. Final model saved as segmenter_final.pth")

    # Persist the per-epoch statistics
    df = pd.DataFrame(training_stats)
    df.to_csv("segmentation_training_log.csv", index=False)
    print("Saved training log to segmentation_training_log.csv")

    return model

## Main Script

In [None]:
if __name__ == "__main__":
    # End-to-end pipeline: pretrain the autoencoder on reconstruction, freeze
    # its encoder, train a segmentation decoder on top, then evaluate and
    # visualise on the test split.
    from torchvision import transforms
    # random_split is needed for both splits below; imported here so the
    # script does not rely on it being imported elsewhere in the file.
    from torch.utils.data import random_split

    num_classes = 3  # used for the segmenter, the IoU computation and plots

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # ---------------------------
    # 1) Load your training data
    # ---------------------------
    # Suppose you have a train dataset of cats (128x128) for autoencoder:
    # We do not need the labels for autoencoder training, but the dataset
    # might give you (img, mask). We'll just ignore the mask in training.
    # NOTE(review): `your_code` looks like a placeholder module name -- point
    # this at the module that defines SegmentationDataset, or drop the import
    # if the class is already defined earlier in this file.
    from your_code import SegmentationDataset  # or define inline
    transform_img = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = SegmentationDataset(
        root_dir="./Dataset/processed/TrainVal",
        transform_img=transform_img,
        transform_label=None
    )

    # Optionally split train_dataset into (train, val)
    val_ratio = 0.1
    val_size = int(len(train_dataset)*val_ratio)
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # ---------------------------
    # 2) Train the autoencoder
    # ---------------------------
    autoencoder = SimpleAutoencoder(in_channels=3, latent_dim=256)
    autoencoder = train_autoencoder(
        model=autoencoder,
        train_dataset=train_dataset,
        val_dataset=val_dataset,     # or None
        epochs=50,
        batch_size=16,
        lr=1e-3,
        save_every=5,
        early_stopping_patience=5,
        device=device
    )

    # Visualize some reconstructions
    visualize_reconstruction(autoencoder, train_dataset, device=device, num_samples=3)

    # ---------------------------
    # 3) Freeze encoder, build new segmentation model
    # ---------------------------
    segmenter = Segmenter(autoencoder, num_classes=num_classes)

    # We now train the segmentation decoder with the same (img, mask) pairs.
    # Optionally load a different dataset or reuse partial.
    # If you have a specific segmentation training set separate from the cat images, load it here.
    # For illustration, let's reuse the same dataset (which may or may not have cat/dog masks).
    # But in practice, you'd have your labeled data for segmentation.
    # 80/20 split of the autoencoder training subset; the remainder goes to
    # validation so the two sizes always sum to train_size.
    seg_train_size = int(0.8*train_size)
    seg_val_size = train_size - seg_train_size
    seg_train_dataset, seg_val_dataset = random_split(train_dataset, [seg_train_size, seg_val_size])

    segmenter = train_segmentation_decoder(
        model=segmenter,
        train_dataset=seg_train_dataset,
        val_dataset=seg_val_dataset,   # or your real val dataset
        epochs=30,
        batch_size=8,
        lr=1e-3,
        save_every=5,
        early_stopping_patience=5,
        num_classes=num_classes,
        device=device
    )

    # ---------------------------
    # 4) Evaluate on Test Data
    # ---------------------------
    test_dataset = SegmentationDataset(
        root_dir="./Dataset/processed/Test",
        transform_img=transform_img,
        transform_label=None  # or your mask transform
    )

    # Evaluate IoU on test set
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
    segmenter.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for imgs, masks in test_loader:
            imgs = imgs.to(device)
            masks = masks.to(device)
            logits = segmenter(imgs)
            loss = criterion(logits, masks)
            # Weight by batch size so test_loss is a per-sample mean.
            total_loss += loss.item() * imgs.size(0)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_targets.append(masks.cpu())

    test_loss = total_loss / len(test_dataset)
    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    ious = compute_iou(all_preds, all_targets, num_classes=num_classes)
    mean_iou = np.nanmean(ious)

    print(f"Test Loss: {test_loss:.4f}, Test mIoU: {mean_iou:.4f}, per-class IoU: {ious}")

    # Show a few test predictions
    visualize_segmentation_predictions(segmenter, test_dataset, device=device, num_samples=3, num_classes=num_classes)