In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

%config InlineBackend.figure_format = 'retina'
%matplotlib inline

In [None]:
# Load the training labels and show how many samples each class has.
labels_csv = "data/train_labels.csv"
label_df = pd.read_csv(labels_csv)
label_df["label"].value_counts()

In [None]:
# Number of example rows to keep for every class.
samples_per_class = 2

# One small frame per label (labels taken in order of first appearance),
# each holding the first `samples_per_class` rows of that label.
selected_samples = [
    label_df[label_df["label"] == label].head(samples_per_class)
    for label in label_df["label"].unique()
]

# Stack the per-class frames into a single frame with a fresh 0..n-1 index.
selected_df = pd.concat(selected_samples, ignore_index=True)
selected_df

In [None]:
# Visualize the selected samples: one row per class, one column per sample.
class_labels = sorted(label_df["label"].unique())
num_classes = len(class_labels)

# squeeze=False keeps `axes` two-dimensional even with a single class or a
# single sample per class, so the [row, col] indexing below always works.
fig, axes = plt.subplots(
    nrows=num_classes,
    ncols=samples_per_class,
    figsize=(12, 16),
    squeeze=False,
)

for class_idx, label in enumerate(class_labels):
    class_samples = selected_df[selected_df["label"] == label].iloc[:samples_per_class]

    for sample_idx, (_, row) in enumerate(class_samples.iterrows()):
        ax = axes[class_idx, sample_idx]
        img_path = Path("data/train_data") / row["sample_index"]

        # Draw while the file is open; the handle is released right after.
        with Image.open(img_path) as image:
            ax.imshow(image)

        # Hide ticks and frame by hand: ax.axis("off") would also hide the
        # y-label that names the class on the first column.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Add label on the first column
        if sample_idx == 0:
            ax.set_ylabel(f"{label}", fontsize=12, rotation=0, labelpad=40)

        # Add image filename and label as title
        ax.set_title(f"{row['sample_index']}\n{row['label']}", fontsize=8)

    # Blank the cells of a class that has fewer than `samples_per_class` rows.
    for unused_ax in axes[class_idx, len(class_samples):]:
        unused_ax.axis("off")

# Add overall title
fig.suptitle("Sample Images by Class", fontsize=16, y=0.995)

plt.tight_layout()
plt.show()

In [None]:
# Compare one image, its mask, and the masked result side by side.
img_name = "0000.png"
img_path = Path("data/train_data") / f"img_{img_name}"
mask_path = Path("data/train_data") / f"mask_{img_name}"

# Load original image and mask; convert() returns an independent in-memory
# image, so the files can be closed immediately.
with Image.open(img_path) as img_file:
    original_image = img_file.convert("RGB")
with Image.open(mask_path) as mask_file:
    mask_image = mask_file.convert("L")

# Bring the mask to the image resolution if they differ (same handling as
# SubtypeDataset._load_masked_image); NEAREST avoids inventing gray levels.
if mask_image.size != original_image.size:
    mask_image = mask_image.resize(original_image.size, resample=Image.NEAREST)

# Apply mask to create masked image: pixels whose mask value is <= 100 become 0.
mask_np = np.array(mask_image)
mask_binary = (mask_np > 100).astype(np.uint8)
mask_3ch = np.stack([mask_binary] * 3, axis=-1)

image_np = np.array(original_image)
image_masked = image_np * mask_3ch
masked_image = Image.fromarray(image_masked)

# Plot comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

panels = [
    (original_image, "Original Image", None),
    (mask_image, "Mask", "gray"),
    (masked_image, "Masked Image", None),
]
for ax, (panel_image, panel_title, panel_cmap) in zip(axes, panels):
    ax.imshow(panel_image, cmap=panel_cmap)
    ax.set_title(panel_title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
class SubtypeDataset(Dataset):
    """Image dataset that blanks out the background using per-image masks.

    In 'train' mode the image ids and labels are read from a CSV file; in
    'test' mode the image ids are discovered by listing ``img_dir``. Every
    image is loaded as RGB and, when a matching ``mask_*`` file exists next
    to it, multiplied by the binarized mask before the transform is applied.
    """

    def __init__(
        self,
        img_dir: str,
        train_labels_path: str,
        mode: str,
        transform=None,
        mask_threshold: int = 100,
    ):
        """
        Args:
            img_dir (str): Directory with all the images (and their masks).
            train_labels_path (str): Path to the CSV file with training labels.
                The first column is used as the image file name and the second
                column as the label. May be None in 'test' mode.
            mode (str): One of 'train' or 'test'.
            transform: Optional callable applied to the (masked) PIL image.
            mask_threshold (int): Mask pixels strictly greater than this value
                keep the image content; all other pixels are set to zero.

        Raises:
            AssertionError: If ``mode`` is neither 'train' nor 'test'.
            ValueError: If ``mode`` is 'train' and ``train_labels_path`` is None.
        """
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.mask_threshold = mask_threshold

        assert mode in ["train", "test"], "mode must be 'train', or 'test'"
        self.mode = mode

        # Stay empty in test mode, where no labels are available.
        self.label_to_idx = {}
        self.idx_to_label = {}

        if self.mode != "test":
            # If in training mode, load labels
            if train_labels_path is None:
                raise ValueError("Training mode requires a train_labels_path!")

            df = pd.read_csv(train_labels_path)
            self.img_ids = df.iloc[:, 0].values
            self.labels = df.iloc[:, 1].values

            # Create label mappings; sorting makes the class indices
            # independent of the row order in the CSV.
            unique_labels = sorted(set(self.labels))
            self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
            self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

            print(f"Label Mapping: {self.label_to_idx}")

        else:
            # Test mode: every image file in the directory that is not itself
            # a mask, in sorted order.
            self.img_ids = sorted(
                f.name
                for f in self.img_dir.iterdir()
                if f.suffix.lower() in [".png", ".jpg", ".jpeg"]
                and not f.name.startswith("mask_")
            )

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_ids)

    def _load_masked_image(self, img_name):
        """Load an image as RGB and zero out the pixels its mask rejects.

        The mask is looked up in ``img_dir`` as ``mask_<name>``, where
        ``<name>`` is ``img_name`` without a leading ``img_`` prefix. If no
        such file exists the image is returned unmasked. If applying the mask
        fails, the error is printed and the unmasked image is returned.

        Args:
            img_name: File name of the image inside ``img_dir``.

        Returns:
            PIL.Image.Image: The RGB image, masked when a mask was applied.
        """
        img_path = self.img_dir / img_name
        # convert() yields an in-memory copy, so the file can be closed here.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        # Remove "img_" prefix if present
        base_name = img_name[4:] if img_name.startswith("img_") else img_name

        # Find corresponding mask
        mask_path = self.img_dir / f"mask_{base_name}"
        if not mask_path.exists():
            return image

        try:
            with Image.open(mask_path) as mask_file:
                mask = mask_file.convert("L")
            if mask.size != image.size:
                # NEAREST keeps the mask values unchanged while resizing.
                mask = mask.resize(image.size, resample=Image.NEAREST)

            mask_binary = (np.array(mask) > self.mask_threshold).astype(np.uint8)
            mask_3ch = np.stack([mask_binary] * 3, axis=-1)

            image = Image.fromarray(np.array(image) * mask_3ch)
        except Exception as e:
            # Best effort: report the problem and fall back to the unmasked
            # image. The real image file name is reported so it can be found.
            print(f"Error applying mask for {img_name}: {e}")

        return image

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            tuple: ``(image, img_name)`` in 'test' mode, otherwise
            ``(image, label)`` where ``label`` is a ``torch.long`` scalar
            tensor holding the class index.
        """
        img_name = self.img_ids[idx]

        # Load and apply mask
        image = self._load_masked_image(img_name)

        # Apply Transforms (Resize, Tensor, Norm). Compare against None rather
        # than relying on truthiness: a callable that defines __len__ or
        # __bool__ may be falsy and would otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)

        if self.mode == "test":
            return image, img_name

        label_idx = self.label_to_idx[self.labels[idx]]
        return image, torch.tensor(label_idx, dtype=torch.long)

In [None]:
# Smoke-test SubtypeDataset in training mode. No transform is passed, so
# __getitem__ hands back the masked PIL image, which imshow can draw directly.
dataset = SubtypeDataset(
    img_dir="data/train_data",
    train_labels_path="data/train_labels.csv",
    mode="train",
    transform=None,
)

# Draw the first few samples on a 2 x 4 grid to eyeball images and labels.
num_samples = 8
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
axes = axes.flatten()

for i, ax in zip(range(num_samples), axes):
    image, label_idx = dataset[i]
    img_name = dataset.img_ids[i]
    # Map the integer class index back to its original label string.
    label_str = dataset.idx_to_label[label_idx.item()]

    ax.imshow(image)
    ax.axis("off")
    ax.set_title(f"{img_name}\nLabel: {label_str}", fontsize=10)

fig.suptitle("Sample Images from SubtypeDataset (with masks applied)", fontsize=14)
plt.tight_layout()
plt.show()

# Print dataset info
print(f"\nDataset size: {len(dataset)}")
print(f"Label mapping: {dataset.label_to_idx}")