In [None]:
from pathlib import Path

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset

%config InlineBackend.figure_format = 'retina'
%matplotlib inline

## Load dataset


In [None]:
# Label table; later cells use its `sample_index` (image filename) and `label` columns.
label_df = pd.read_csv(Path("../data") / "train_labels.csv")

# Class distribution, shown as the cell output
label_df["label"].value_counts()

In [None]:
# Keep the first `samples_per_class` rows of every label (labels in order of first appearance)
samples_per_class = 2

selected_samples = [
    label_df.loc[label_df["label"] == label].head(samples_per_class)
    for label in label_df["label"].unique()
]

# Stack the per-class slices into one frame with a fresh 0..n-1 index
selected_df = pd.concat(selected_samples, ignore_index=True)
selected_df

## Visualize some samples


### Luminal A samples


In [None]:
# Visualize up to 4 Luminal A samples (top row) with their masks (bottom row)
luminal_a_samples = label_df[label_df["label"] == "Luminal A"].head(4)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Hide ticks and frames instead of calling axis("off"): axis("off") also hides the
# axis labels, which would make the "Original"/"Mask" row labels set below invisible.
# Panels left unused (class with fewer than 4 samples) simply stay blank.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (_, row) in enumerate(luminal_a_samples.iterrows()):
    img_name = row["sample_index"]
    img_path = Path("../data/train_data") / img_name

    # Masks are stored as mask_<id>, where <id> is the image name without "img_"
    mask_name = img_name[4:] if img_name.startswith("img_") else img_name
    mask_path = Path("../data/train_data") / f"mask_{mask_name}"

    # Original image, converted to RGB like the other cells in this notebook
    image = Image.open(img_path).convert("RGB")
    axes[0, idx].imshow(image)
    axes[0, idx].set_title(f"{img_name}", fontsize=10)

    # Mask, or a placeholder when no mask file exists
    if mask_path.exists():
        mask = Image.open(mask_path).convert("L")
        axes[1, idx].imshow(mask, cmap="gray")
        axes[1, idx].set_title(f"mask_{mask_name}", fontsize=10)
    else:
        axes[1, idx].text(
            0.5, 0.5, "No mask", ha="center", va="center",
            transform=axes[1, idx].transAxes,
        )

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Mask", fontsize=12)

fig.suptitle("Luminal A Samples with Masks", fontsize=14)
plt.tight_layout()
plt.show()

### Luminal B samples


In [None]:
# Visualize up to 4 Luminal B samples (top row) with their masks (bottom row)
luminal_b_samples = label_df[label_df["label"] == "Luminal B"].head(4)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Hide ticks and frames instead of calling axis("off"): axis("off") also hides the
# axis labels, which would make the "Original"/"Mask" row labels set below invisible.
# Panels left unused (class with fewer than 4 samples) simply stay blank.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (_, row) in enumerate(luminal_b_samples.iterrows()):
    img_name = row["sample_index"]
    img_path = Path("../data/train_data") / img_name

    # Masks are stored as mask_<id>, where <id> is the image name without "img_"
    mask_name = img_name[4:] if img_name.startswith("img_") else img_name
    mask_path = Path("../data/train_data") / f"mask_{mask_name}"

    # Original image, converted to RGB like the other cells in this notebook
    image = Image.open(img_path).convert("RGB")
    axes[0, idx].imshow(image)
    axes[0, idx].set_title(f"{img_name}", fontsize=10)

    # Mask, or a placeholder when no mask file exists
    if mask_path.exists():
        mask = Image.open(mask_path).convert("L")
        axes[1, idx].imshow(mask, cmap="gray")
        axes[1, idx].set_title(f"mask_{mask_name}", fontsize=10)
    else:
        axes[1, idx].text(
            0.5, 0.5, "No mask", ha="center", va="center",
            transform=axes[1, idx].transAxes,
        )

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Mask", fontsize=12)

fig.suptitle("Luminal B Samples with Masks", fontsize=14)
plt.tight_layout()
plt.show()

### HER2(+) samples


In [None]:
# Visualize up to 4 HER2(+) samples (top row) with their masks (bottom row)
her2_samples = label_df[label_df["label"] == "HER2(+)"].head(4)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Hide ticks and frames instead of calling axis("off"): axis("off") also hides the
# axis labels, which would make the "Original"/"Mask" row labels set below invisible.
# Panels left unused (class with fewer than 4 samples) simply stay blank.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (_, row) in enumerate(her2_samples.iterrows()):
    img_name = row["sample_index"]
    img_path = Path("../data/train_data") / img_name

    # Masks are stored as mask_<id>, where <id> is the image name without "img_"
    mask_name = img_name[4:] if img_name.startswith("img_") else img_name
    mask_path = Path("../data/train_data") / f"mask_{mask_name}"

    # Original image, converted to RGB like the other cells in this notebook
    image = Image.open(img_path).convert("RGB")
    axes[0, idx].imshow(image)
    axes[0, idx].set_title(f"{img_name}", fontsize=10)

    # Mask, or a placeholder when no mask file exists
    if mask_path.exists():
        mask = Image.open(mask_path).convert("L")
        axes[1, idx].imshow(mask, cmap="gray")
        axes[1, idx].set_title(f"mask_{mask_name}", fontsize=10)
    else:
        axes[1, idx].text(
            0.5, 0.5, "No mask", ha="center", va="center",
            transform=axes[1, idx].transAxes,
        )

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Mask", fontsize=12)

fig.suptitle("HER2(+) Samples with Masks", fontsize=14)
plt.tight_layout()
plt.show()

### Triple negative samples


In [None]:
# Visualize up to 4 Triple negative samples (top row) with their masks (bottom row)
triple_negative_samples = label_df[label_df["label"] == "Triple negative"].head(4)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Hide ticks and frames instead of calling axis("off"): axis("off") also hides the
# axis labels, which would make the "Original"/"Mask" row labels set below invisible.
# Panels left unused (class with fewer than 4 samples) simply stay blank.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (_, row) in enumerate(triple_negative_samples.iterrows()):
    img_name = row["sample_index"]
    img_path = Path("../data/train_data") / img_name

    # Masks are stored as mask_<id>, where <id> is the image name without "img_"
    mask_name = img_name[4:] if img_name.startswith("img_") else img_name
    mask_path = Path("../data/train_data") / f"mask_{mask_name}"

    # Original image, converted to RGB like the other cells in this notebook
    image = Image.open(img_path).convert("RGB")
    axes[0, idx].imshow(image)
    axes[0, idx].set_title(f"{img_name}", fontsize=10)

    # Mask, or a placeholder when no mask file exists
    if mask_path.exists():
        mask = Image.open(mask_path).convert("L")
        axes[1, idx].imshow(mask, cmap="gray")
        axes[1, idx].set_title(f"mask_{mask_name}", fontsize=10)
    else:
        axes[1, idx].text(
            0.5, 0.5, "No mask", ha="center", va="center",
            transform=axes[1, idx].transAxes,
        )

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Mask", fontsize=12)

fig.suptitle("Triple Negative Samples with Masks", fontsize=14)
plt.tight_layout()
plt.show()

## Apply the mask to the original image


In [None]:
def apply_mask_to_image(img_name, img_dir="../data/train_data", threshold=100):
    """
    Apply binary mask to an image to remove background.

    The mask is looked up in the same directory as ``mask_<id>``, where ``<id>`` is
    ``img_name`` with a leading ``img_`` removed (e.g. ``img_0000.png`` ->
    ``mask_0000.png``). Pixels whose mask value is not greater than ``threshold``
    are set to black.

    Args:
        img_name: str, image filename (e.g., "img_0000.png" or "0000.png")
        img_dir: str or Path, directory containing images and masks
            (default: "../data/train_data")
        threshold: int, mask values strictly greater than this are kept (default: 100)

    Returns:
        PIL Image: RGB image with background pixels zeroed. If the mask file does
        not exist, a warning is printed and the unmasked RGB image is returned.
    """
    img_dir = Path(img_dir)

    # Context managers close the underlying files as soon as the pixels are loaded.
    with Image.open(img_dir / img_name) as img_file:
        image = img_file.convert("RGB")

    # Remove "img_" prefix if present to find mask
    mask_name = img_name[4:] if img_name.startswith("img_") else img_name
    mask_path = img_dir / f"mask_{mask_name}"

    if not mask_path.exists():
        print(f"Warning: Mask not found for {img_name}")
        return image

    with Image.open(mask_path) as mask_file:
        mask = mask_file.convert("L")

    # Nearest-neighbour resampling avoids blending mask values at the edges.
    if mask.size != image.size:
        mask = mask.resize(image.size, resample=Image.NEAREST)

    keep = np.array(mask) > threshold
    # Broadcast the (H, W) boolean mask over the RGB channel axis.
    image_masked = np.where(keep[..., None], np.array(image), 0).astype(np.uint8)

    return Image.fromarray(image_masked)

In [None]:
# Sanity-check apply_mask_to_image on sample 0000: original | mask | masked result
img_name = "0000.png"
img_path = Path("../data/train_data") / f"img_{img_name}"
mask_path = Path("../data/train_data") / f"mask_{img_name}"

original_image = Image.open(img_path).convert("RGB")
mask_image = Image.open(mask_path).convert("L")
masked_image = apply_mask_to_image(f"img_{img_name}", threshold=100)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (picture, panel title, colormap) for each column; None keeps imshow's default
panels = [
    (original_image, "Original Image", None),
    (mask_image, "Mask", "gray"),
    (masked_image, "Masked Image", None),
]
for ax, (picture, panel_title, panel_cmap) in zip(axes, panels):
    ax.imshow(picture, cmap=panel_cmap)
    ax.set_title(panel_title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Compare original, mask, and masked result for sample 0109
img_name = "0109.png"
img_path = Path("../data/train_data") / f"img_{img_name}"
mask_path = Path("../data/train_data") / f"mask_{img_name}"

original_image = Image.open(img_path).convert("RGB")
mask_image = Image.open(mask_path).convert("L")
masked_image = apply_mask_to_image(f"img_{img_name}", threshold=100)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# (picture, panel title, colormap) for each column; None keeps imshow's default
panels = [
    (original_image, "Original Image", None),
    (mask_image, "Mask", "gray"),
    (masked_image, "Masked Image", None),
]
for ax, (picture, panel_title, panel_cmap) in zip(axes, panels):
    ax.imshow(picture, cmap=panel_cmap)
    ax.set_title(panel_title)
    ax.axis("off")

plt.tight_layout()
plt.show()

## Data Augmentation


### Image resize


In [None]:
# Compare the first image of each class with its Albumentations-resized version.
# ImageNet statistics, shared by Normalize and by the de-normalization used for display.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

resize_transform = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

# First sample of each class, keyed by label (labels in order of first appearance)
samples_per_class = {}
for label in label_df["label"].unique():
    class_sample = label_df[label_df["label"] == label].iloc[0]
    samples_per_class[label] = class_sample["sample_index"]

# De-normalization constants are loop-invariant, so build them once.
mean = np.array(IMAGENET_MEAN)
std = np.array(IMAGENET_STD)

# One column per class instead of a hard-coded 4; squeeze=False keeps axes 2-D.
n_classes = len(samples_per_class)
fig, axes = plt.subplots(2, n_classes, figsize=(4 * n_classes, 8), squeeze=False)

# Hide ticks/frames rather than axis("off"), which would also hide the row labels.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (label, img_name) in enumerate(samples_per_class.items()):
    img_path = Path("../data/train_data") / img_name
    original_image = Image.open(img_path).convert("RGB")
    original_np = np.array(original_image)

    transformed = resize_transform(image=original_np)
    resized_tensor = transformed["image"]  # channel-first; permuted to HWC below

    axes[0, idx].imshow(original_image)
    axes[0, idx].set_title(
        f"{label}\n{img_name}\nSize: {original_image.size}", fontsize=9
    )

    # Undo Normalize (x * std + mean) and clip to [0, 1] so imshow accepts it
    resized_display = resized_tensor.permute(1, 2, 0).numpy()
    resized_display = np.clip(std * resized_display + mean, 0, 1)

    axes[1, idx].imshow(resized_display)
    axes[1, idx].set_title(
        f"Resized: {resized_tensor.shape[1]}x{resized_tensor.shape[2]}", fontsize=9
    )

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Resized (224x224)", fontsize=12)

fig.suptitle("Albumentations Resize Transformation (1 Sample per Class)", fontsize=14)
plt.tight_layout()
plt.show()

### Color Jitter


In [None]:
# Compare the first image of each class with a colour-jittered copy.
# NOTE(review): no random seed is set here, so the jittered images differ per run.
color_jitter_transform = A.Compose(
    [
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=1.0),
    ]
)

# First sample of each class, keyed by label (labels in order of first appearance)
samples_per_class = {}
for label in label_df["label"].unique():
    class_sample = label_df[label_df["label"] == label].iloc[0]
    samples_per_class[label] = class_sample["sample_index"]

# One column per class instead of a hard-coded 4; squeeze=False keeps axes 2-D.
n_classes = len(samples_per_class)
fig, axes = plt.subplots(2, n_classes, figsize=(4 * n_classes, 8), squeeze=False)

# Hide ticks/frames rather than axis("off"), which would also hide the row labels.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

for idx, (label, img_name) in enumerate(samples_per_class.items()):
    img_path = Path("../data/train_data") / img_name
    original_image = Image.open(img_path).convert("RGB")
    original_np = np.array(original_image)

    transformed = color_jitter_transform(image=original_np)
    jittered_image = transformed["image"]

    axes[0, idx].imshow(original_image)
    axes[0, idx].set_title(f"{label}\n{img_name}", fontsize=9)

    axes[1, idx].imshow(jittered_image)
    axes[1, idx].set_title("Color Jittered", fontsize=9)

axes[0, 0].set_ylabel("Original", fontsize=12)
axes[1, 0].set_ylabel("Color Jitter", fontsize=12)

fig.suptitle("Color Jitter Transformation (1 Sample per Class)", fontsize=14)
plt.tight_layout()
plt.show()

### Split masked image into patches


In [None]:
def split_image_to_patches(
    image, mask, patch_size=224, overlap=0, min_tissue_ratio=0.5, threshold=100
):
    """
    Split an image into multiple patches, filtering by tissue content.

    Square windows of ``patch_size`` are taken with step ``patch_size - overlap``
    in row-major order (top-to-bottom, left-to-right). Only windows that lie fully
    inside the image are considered, so a right/bottom remainder narrower than the
    step is not covered. A window is kept when the fraction of foreground mask
    pixels inside it is at least ``min_tissue_ratio``.

    Args:
        image: PIL Image or numpy array (RGB image)
        mask: PIL Image or numpy array (grayscale mask; for a 3-D mask the first
            channel is used; a boolean mask is used as-is without thresholding).
            If its size differs from the image it is resized (nearest-neighbour).
        patch_size: int, size of each square patch (default: 224)
        overlap: int, overlap between patches in pixels, must be < patch_size (default: 0)
        min_tissue_ratio: float, minimum ratio of tissue pixels required (default: 0.5)
        threshold: int, mask values strictly greater than this count as tissue (default: 100)

    Returns:
        list of PIL Images: List of valid image patches
        list of tuples: List of (x, y) top-left coordinates for each patch
        list of float: List of tissue ratios for each patch

    Raises:
        ValueError: if ``patch_size`` is not positive or ``overlap >= patch_size``
            (the window step would be zero or negative).
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
        )

    # np.asarray accepts both PIL Images and numpy arrays.
    img_np = np.asarray(image)
    mask_np = np.asarray(mask)

    # Ensure mask is single-channel
    if mask_np.ndim == 3:
        mask_np = mask_np[:, :, 0]

    # Boolean masks are already binary; numeric masks are thresholded.
    if mask_np.dtype == bool:
        mask_binary = mask_np.astype(np.uint8)
    else:
        mask_binary = (mask_np > threshold).astype(np.uint8)

    height, width = img_np.shape[:2]

    # Align the mask with the image; otherwise slicing would read misaligned or
    # truncated mask windows. Nearest-neighbour keeps the mask binary.
    if mask_binary.shape != (height, width):
        mask_binary = np.array(
            Image.fromarray(mask_binary).resize((width, height), resample=Image.NEAREST)
        )

    patches = []
    coordinates = []
    tissue_ratios = []

    total_patch_pixels = patch_size * patch_size

    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            mask_patch = mask_binary[y : y + patch_size, x : x + patch_size]
            tissue_ratio = np.sum(mask_patch) / total_patch_pixels

            # Only keep patches with sufficient tissue content
            if tissue_ratio >= min_tissue_ratio:
                img_patch = img_np[y : y + patch_size, x : x + patch_size]
                patches.append(Image.fromarray(img_patch.astype(np.uint8)))
                coordinates.append((x, y))
                tissue_ratios.append(tissue_ratio)

    return patches, coordinates, tissue_ratios

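### Dataset definition

> **TODO:** Macenko stain normalization (planned for this section) is not implemented yet; the cell below only defines the `SubtypeDataset` class, which loads masked images.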


In [None]:
class SubtypeDataset(Dataset):
    """Dataset of images whose background is blanked out using per-image masks.

    Images live in ``img_dir``; an optional mask for ``img_<id>`` (or ``<id>``) is
    looked up as ``mask_<id>`` in the same directory. In ``"train"`` mode, image ids
    and string labels come from the first two CSV columns and labels are mapped to
    integer indices in sorted order. In ``"test"`` mode every non-mask
    ``.png``/``.jpg``/``.jpeg`` file in ``img_dir`` is used.
    """

    def __init__(
        self,
        img_dir: str,
        train_labels_path: str,
        mode: str,
        transform=None,
        mask_threshold: int = 100,
    ):
        """
        Args:
            img_dir (str): Directory with all the images (and their masks).
            train_labels_path (str): Path to the CSV file with training labels
                (first column: image filename, second column: label). Required in
                'train' mode, ignored in 'test' mode.
            mode (str): One of 'train' or 'test'.
            transform: Optional transform applied to each masked image. Either a
                callable taking a PIL Image (e.g. torchvision-style), or an
                Albumentations transform/Compose, which is called as
                ``transform(image=ndarray)["image"]``.
            mask_threshold (int): Mask values strictly greater than this are kept
                as foreground (default: 100).
        """
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.mask_threshold = mask_threshold

        assert mode in ["train", "test"], "mode must be 'train', or 'test'"
        self.mode = mode

        self.label_to_idx = {}
        self.idx_to_label = {}

        if self.mode != "test":
            # If in training mode, load labels
            if train_labels_path is None:
                raise ValueError("Training mode requires a train_labels_path!")

            df = pd.read_csv(train_labels_path)
            self.img_ids = df.iloc[:, 0].values
            self.labels = df.iloc[:, 1].values

            # Sorted so the label -> index mapping does not depend on CSV row order.
            unique_labels = sorted(set(self.labels))
            self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
            self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

            print(f"Label Mapping: {self.label_to_idx}")

        else:
            # Test mode: every image file in the directory that is not a mask
            self.img_ids = sorted(
                f.name
                for f in self.img_dir.iterdir()
                if f.suffix.lower() in {".png", ".jpg", ".jpeg"}
                and not f.name.startswith("mask_")
            )

    def __len__(self):
        return len(self.img_ids)

    def _load_masked_image(self, img_name):
        """Load an image as RGB and zero out pixels outside its mask.

        If no mask file exists, or applying it fails, the unmasked RGB image is
        returned (failures are reported with ``print``).
        """
        with Image.open(self.img_dir / img_name) as img_file:
            image = img_file.convert("RGB")

        # Masks are named mask_<id>, where <id> is the image name without "img_"
        mask_id = img_name[4:] if img_name.startswith("img_") else img_name
        mask_path = self.img_dir / f"mask_{mask_id}"

        if not mask_path.exists():
            return image

        try:
            with Image.open(mask_path) as mask_file:
                mask = mask_file.convert("L")
            if mask.size != image.size:
                # Nearest-neighbour avoids blending mask values at the edges.
                mask = mask.resize(image.size, resample=Image.NEAREST)

            keep = np.array(mask) > self.mask_threshold
            # Broadcast the (H, W) mask over the RGB channel axis.
            masked_np = np.where(keep[..., None], np.array(image), 0).astype(np.uint8)
            image = Image.fromarray(masked_np)
        except Exception as e:
            # Best effort: fall back to the unmasked image instead of failing the batch.
            print(f"Error applying mask for {img_name}: {e}")

        return image

    def _apply_transform(self, image):
        """Apply ``self.transform`` (if set) to a PIL image.

        Albumentations transforms must be called with keyword data
        (``transform(image=ndarray)``) and return a dict, so they are detected by
        the module of any class in the transform's MRO; every other transform is
        called positionally with the PIL image.
        """
        if not self.transform:
            return image
        is_albumentations = any(
            cls.__module__.startswith("albumentations")
            for cls in type(self.transform).__mro__
        )
        if is_albumentations:
            return self.transform(image=np.array(image))["image"]
        return self.transform(image)

    def __getitem__(self, idx):
        img_name = self.img_ids[idx]

        # Load, apply mask, then transforms (e.g. resize / tensor / normalize)
        image = self._apply_transform(self._load_masked_image(img_name))

        if self.mode == "test":
            return image, img_name

        label_idx = self.label_to_idx[self.labels[idx]]
        return image, torch.tensor(label_idx, dtype=torch.long)