<a href="https://colab.research.google.com/github/Sibusisongwenya/WIP-Project/blob/main/ucmayo4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import glob
import logging
import os
import re
from typing import Any, Callable, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset

# Configure the root logger once at import time: records at INFO level and
# above are emitted, each prefixed with a timestamp and its level name.
# Note: logging.basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class UCMayo4(Dataset):
    """
    Ulcerative Colitis dataset grouped according to the Endoscopic Mayo scoring system.
    This dataset is intended for regression, returning a continuous Mayo score.

    Every direct sub-folder of ``root_dir`` whose name starts with
    ``Mayo <score>`` (case-insensitive; decimal scores such as ``Mayo 1.5`` are
    accepted) is treated as one class. All files found in a class folder are
    decoded to RGB PIL images while the dataset is constructed and are kept in
    memory, so ``__getitem__`` does no disk access.

    Attributes:
        root_dir: The directory passed to the constructor.
        transform: The transform passed to the constructor (may be ``None``).
        class_names: Names of the accepted class folders, in sorted order.
        number_of_class: Number of accepted class folders.
        samples: ``(image, score)`` pairs for every image that could be decoded.
    """

    # Matched with ``.match``, i.e. anchored at the start of the folder name only,
    # so any text following the score is tolerated.
    _CLASS_PATTERN = re.compile(r"Mayo\s*(\d+(\.\d+)?)", re.IGNORECASE)

    def __init__(self, root_dir: str, transform: Optional[Callable[[Any], Any]] = None):
        """
        Scan ``root_dir`` for Mayo class folders and load all of their images.

        Args:
            root_dir (str): Path to the parent folder where class folders are located.
            transform (callable, optional): Optional transform to be applied on a sample.

        Raises:
            RuntimeError: If no sub-folder of ``root_dir`` matches the
                'Mayo [score]' naming format.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.class_names: List[str] = []  # Valid class names
        self.samples: List[Tuple[Image.Image, float]] = []  # (image, label) tuples

        # Single pass over the sub-folders: remember each valid folder together
        # with the score parsed from its name. glob.escape keeps characters such
        # as '[' or '*' in root_dir from being interpreted as wildcards.
        class_folders: List[Tuple[str, float]] = []
        for folder in sorted(glob.glob(os.path.join(glob.escape(root_dir), "*"))):
            if not os.path.isdir(folder):
                continue
            folder_name = os.path.basename(folder)
            m = self._CLASS_PATTERN.match(folder_name)
            if m is None:
                logging.warning(f"Skipping folder '{folder_name}' as it does not match 'Mayo [score]' format.")
                continue
            # The pattern only captures digits with an optional fraction, so the
            # conversion to float cannot fail here.
            self.class_names.append(folder_name)
            class_folders.append((folder, float(m.group(1))))

        self.number_of_class = len(self.class_names)
        logging.info(f"Found {self.number_of_class} valid Mayo classes: {self.class_names}")

        if not self.class_names:
            raise RuntimeError(f"No valid Mayo class folders found in '{root_dir}'. Please ensure subfolders are named like 'Mayo 0', 'Mayo 1', etc.")

        # Load images and corresponding labels.
        for folder, label in class_folders:
            self.samples.extend((image, label) for image in self._load_images(folder))

    @staticmethod
    def _load_images(folder: str) -> List["Image.Image"]:
        """Decode every entry of ``folder`` to an RGB image, skipping (and logging) failures."""
        images = []
        # Sorted so that the sample order does not depend on the file system.
        for image_path in sorted(glob.glob(os.path.join(glob.escape(folder), "*"))):
            try:
                # convert() fully decodes the pixel data into a new image, so the
                # source file can be closed as soon as the conversion is done.
                with Image.open(image_path) as source:
                    images.append(source.convert('RGB'))
            except FileNotFoundError:
                logging.warning(f"Image file not found: {image_path}. Skipping.")
            except Image.UnidentifiedImageError:
                logging.warning(f"Could not decode image: {image_path}. Skipping.")
            except Exception as e:
                # Best effort: one unreadable entry must not abort the whole load.
                logging.error(f"Error loading image {image_path}: {e}")
        return images

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, torch.Tensor]:
        """
        Retrieve a sample from the dataset by index.

        Args:
            idx (int): Index of the sample.

        Returns:
            Tuple[Any, torch.Tensor]: (image, Mayo score label). The image is the
            output of ``transform`` when a callable transform is set, otherwise
            the stored PIL image. The label is a scalar float32 tensor.

        Raises:
            IndexError: If ``idx`` is outside the range of ``samples``.
        """
        sample_image, label = self.samples[idx]
        if self.transform is not None:
            if callable(self.transform):
                sample_image = self.transform(sample_image)
            else:
                logging.warning("Provided transform is not callable. Skipping transformation.")
        return sample_image, torch.tensor(label, dtype=torch.float32)


class UCMayo4Remission(Dataset):
    """
    Ulcerative Colitis dataset with binary remission labels based on Mayo scores.
    In this setup, scores in the 'remission' list are labeled as 0 (non-remission),
    and all other scores are labeled as 1 (remission).

    Note that, despite its name, the ``remission`` argument therefore lists the
    scores that map to label 0.

    Every direct sub-folder of ``root_dir`` whose name starts with
    ``Mayo <integer>`` (case-insensitive) is treated as one class. All files
    found in a class folder are decoded to RGB PIL images while the dataset is
    constructed and are kept in memory.

    Attributes:
        root_dir: The directory passed to the constructor.
        transform: The transform passed to the constructor (may be ``None``).
        remission: Scores mapped to label 0.
        class_names: Names of the accepted class folders, in sorted order.
        number_of_class: Number of accepted class folders.
        samples: ``(image, label)`` pairs for every image that could be decoded.
    """

    # Expects integer scores only. Matched with ``.match``, i.e. anchored at the
    # start of the folder name only.
    # NOTE(review): because of the prefix match, a folder named "Mayo 1.5" is
    # read as score 1 rather than rejected.
    _CLASS_PATTERN = re.compile(r"Mayo\s*(\d+)", re.IGNORECASE)

    def __init__(
        self,
        root_dir: str,
        remission: Optional[List[int]] = None,
        transform: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Scan ``root_dir`` for Mayo class folders and load all of their images.

        Args:
            root_dir (str): Path to the parent folder where class folders are located.
            remission (List[int], optional): Mayo scores regarded as non-remission
                (label 0); others are remission (label 1). Defaults to ``[2, 3]``.
            transform (callable, optional): Optional transform to be applied on a sample.

        Raises:
            RuntimeError: If no sub-folder of ``root_dir`` matches the
                'Mayo [score]' naming format.
        """
        self.root_dir = root_dir
        self.transform = transform
        # A fresh list per instance: a list literal used as the default value
        # would be shared (and mutable) across every instance.
        self.remission = [2, 3] if remission is None else remission
        self.class_names: List[str] = []
        self.samples: List[Tuple[Image.Image, int]] = []

        # Single pass over the sub-folders: remember each valid folder together
        # with its binary label. glob.escape keeps characters such as '[' or '*'
        # in root_dir from being interpreted as wildcards.
        class_folders: List[Tuple[str, int]] = []
        for folder in sorted(glob.glob(os.path.join(glob.escape(root_dir), "*"))):
            if not os.path.isdir(folder):
                continue
            folder_name = os.path.basename(folder)
            m = self._CLASS_PATTERN.match(folder_name)
            if m is None:
                logging.warning(f"Skipping folder '{folder_name}' as it does not match 'Mayo [score]' format.")
                continue
            # The pattern only captures digits, so the conversion cannot fail here.
            label_score = int(m.group(1))
            label = 0 if label_score in self.remission else 1
            self.class_names.append(folder_name)
            class_folders.append((folder, label))

        self.number_of_class = len(self.class_names)
        logging.info(f"Found {self.number_of_class} valid Mayo classes for remission dataset: {self.class_names}")

        if not self.class_names:
            raise RuntimeError(f"No valid Mayo class folders found in '{root_dir}' for remission dataset.")

        # Load images and assign binary remission labels.
        for folder, label in class_folders:
            self.samples.extend((image, label) for image in self._load_images(folder))

    @staticmethod
    def _load_images(folder: str) -> List["Image.Image"]:
        """Decode every entry of ``folder`` to an RGB image, skipping (and logging) failures."""
        images = []
        # Sorted so that the sample order does not depend on the file system.
        for image_path in sorted(glob.glob(os.path.join(glob.escape(folder), "*"))):
            try:
                # convert() fully decodes the pixel data into a new image, so the
                # source file can be closed as soon as the conversion is done.
                with Image.open(image_path) as source:
                    images.append(source.convert('RGB'))
            except FileNotFoundError:
                logging.warning(f"Image file not found: {image_path}. Skipping.")
            except Image.UnidentifiedImageError:
                logging.warning(f"Could not decode image: {image_path}. Skipping.")
            except Exception as e:
                # Best effort: one unreadable entry must not abort the whole load.
                logging.error(f"Error loading image {image_path}: {e}")
        return images

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, torch.Tensor]:
        """
        Retrieve a sample from the dataset by index.

        Args:
            idx (int): Index of the sample.

        Returns:
            Tuple[Any, torch.Tensor]: (image, binary remission label). The image
            is the output of ``transform`` when a callable transform is set,
            otherwise the stored PIL image. The label is a scalar float32 tensor
            holding 0.0 or 1.0.

        Raises:
            IndexError: If ``idx`` is outside the range of ``samples``.
        """
        sample_image, label = self.samples[idx]
        if self.transform is not None:
            if callable(self.transform):
                sample_image = self.transform(sample_image)
            else:
                logging.warning("Provided transform is not callable. Skipping transformation.")
        return sample_image, torch.tensor(label, dtype=torch.float32)
