## Import


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image

from utils.PathologyDataset import PathologyDataset
from utils.TissueExtractor import TissueExtractor

%config InlineBackend.figure_format = 'retina'
%matplotlib inline

## Check tissue extractor


In [None]:
def visualize_extraction_strategies(
    test_img,
    test_mask,
    *,
    patch_size=224,
    num_patches=16,
    min_distance=32,
):
    """Extract patches with the "random" and "grid" strategies and plot them.

    Builds a ``TissueExtractor`` (with ``min_annotation_pixels=1``), asks it for
    ``num_patches`` patches from ``test_img``/``test_mask`` with each strategy,
    prints how many patches each strategy actually returned, and shows them in
    a two-row figure (row 0: random strategy, row 1: grid strategy).

    Args:
        test_img: Image handed to ``TissueExtractor.get_valid_patches``
            (the notebook passes an RGB ``np.ndarray``).
        test_mask: Mask handed alongside ``test_img`` (the notebook passes a
            single-channel ``np.ndarray``).
        patch_size: Patch edge length passed to ``TissueExtractor``.
        num_patches: Number of patches requested from each strategy.
        min_distance: ``min_distance`` forwarded to the random strategy only.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    extractor = TissueExtractor(patch_size=patch_size, min_annotation_pixels=1)

    # Test random strategy
    print("\n--- Testing RANDOM strategy ---")
    images_random, _ = extractor.get_valid_patches(
        test_img,
        test_mask,
        num_patches=num_patches,
        strategy="random",
        min_distance=min_distance,
    )
    print(f"Random: Extracted {len(images_random)} patches")

    # Test grid strategy
    print("\n--- Testing GRID strategy ---")
    images_grid, _ = extractor.get_valid_patches(
        test_img,
        test_mask,
        num_patches=num_patches,
        strategy="grid",
    )
    print(f"Grid: Extracted {len(images_grid)} patches")

    # One figure row per strategy: (row title, extracted patches).
    rows = [
        ("RANDOM strategy extracted patches", images_random),
        ("GRID strategy extracted patches", images_grid),
    ]

    # At least one column so plt.subplots is valid even if both strategies
    # return nothing; squeeze=False keeps `axes` 2-D for the one-column case.
    n_cols = max(len(images_random), len(images_grid), 1)
    fig, axes = plt.subplots(len(rows), n_cols, figsize=(n_cols * 2, 4), squeeze=False)

    for row_axes, (title, images) in zip(axes, rows):
        for col, ax in enumerate(row_axes):
            # The shorter row leaves its trailing axes empty.
            if col < len(images):
                ax.imshow(images[col])
            ax.axis("off")
        row_axes[0].set_title(title, loc="left")

    plt.tight_layout()
    plt.show()

### Triple negative


In [None]:
# Triple-negative example (sample 0000): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0000.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0000.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

In [None]:
# Triple-negative example (sample 0033): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0033.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0033.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

### Luminal A


In [None]:
# Luminal A example (sample 0006): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0006.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0006.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

In [None]:
# Luminal A example (sample 0011): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0011.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0011.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

### Luminal B


In [None]:
# Luminal B example (sample 0021): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0021.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0021.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

In [None]:
# Luminal B example (sample 0023): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0023.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0023.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

### HER2(+)


In [None]:
# HER2(+) example (sample 0029): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0029.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0029.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

In [None]:
# HER2(+) example (sample 0030): image as 3-channel RGB, mask as 8-bit grayscale.
test_img = np.array(
    PIL.Image.open("data/train_data/img_0030.png").convert(mode="RGB")
)
test_mask = np.array(
    PIL.Image.open("data/train_data/mask_0030.png").convert(mode="L")
)
visualize_extraction_strategies(test_img=test_img, test_mask=test_mask)

## Check Dataset


In [None]:
import pandas as pd
from torchvision import transforms

# Dataset locations, relative to the notebook's working directory.
TRAIN_DATA_DIR, TEST_DATA_DIR = "./data/train_data", "./data/test_data"
TRAIN_LABELS_PATH = "./data/train_labels.csv"

# Training label table; handed to PathologyDataset as `labels_df` in later cells.
train_df = pd.read_csv(filepath_or_buffer=TRAIN_LABELS_PATH)

### Test 1: Without patches (whole images)


In [None]:
# Test 1: Without patches (whole images)
print("=" * 50)
print("Test 1: Whole Images (no patches)")
print("=" * 50)

# Define simple transforms for testing: resize to 224x224 and convert to a
# float tensor. Later cells reuse this transform for their patch datasets.
test1_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

dataset_test1 = PathologyDataset(
    data_dir=TRAIN_DATA_DIR,
    labels_df=train_df,
    transform=test1_transform,
    use_mask=True,
    use_patches=False,
    use_stain_norm=False,
    is_test=False,
)

# Gather up to the first 4 samples. Clamping to the dataset length avoids an
# IndexError when the dataset is smaller than 4.
num_samples = min(4, len(dataset_test1))
imgs = []
labels = []
sample_ids = []

for i in range(num_samples):
    img_tensor, label_tensor = dataset_test1[i]
    imgs.append(img_tensor)
    labels.append(label_tensor)
    sample_ids.append(dataset_test1.samples[i])

if not imgs:
    print("Test 1 - dataset is empty, nothing to plot")
else:
    # Print shapes for Test 1
    print(f"Test 1 - single image tensor shape: {imgs[0].shape}")
    print(f"Test 1 - label tensor: {labels}")

    # squeeze=False + [0] keeps `axes` a 1-D array even for a single sample,
    # where plt.subplots would otherwise return a bare Axes object.
    fig, axes = plt.subplots(1, num_samples, figsize=(16, 4), squeeze=False)
    axes = axes[0]
    for i, ax in enumerate(axes):
        img = imgs[i].permute(1, 2, 0).numpy()  # CHW tensor -> HWC array for imshow
        label_name = dataset_test1.label_encoder.inverse_transform([labels[i].item()])[0]

        ax.imshow(img)
        ax.set_title(f"{label_name}\n{sample_ids[i]}")
        ax.axis("off")
    plt.suptitle(
        "PathologyDataset - Whole Images (First Batch)",
        fontsize=14,
    )
    plt.tight_layout()
    plt.show(block=False)

### Test 2: With patches


In [None]:
# Test 2: With patches (random patches)
print("\n" + "=" * 50)
print("Test 2: Patch-based Images")
print("=" * 50)

dataset_patches = PathologyDataset(
    data_dir=TRAIN_DATA_DIR,
    labels_df=train_df,
    transform=test1_transform,  # Reuse transform from Test 1
    use_mask=True,
    use_patches=True,
    patch_size=32,
    num_patches=16,
    patch_strategy="random",
    min_tissue_ratio=0.05,
    use_stain_norm=False,
    is_test=False,
)

# Fetch patches for a specific sample (e.g., index 0)
idx = 0
patches_batch, label_tensor = dataset_patches[idx]  # [num_patches, C, H, W]
sample_id = dataset_patches.samples[idx]
label_name = dataset_patches.label_encoder.inverse_transform([label_tensor.item()])[0]

# Print shapes for Test 2
print(f"Sample: {sample_id} ({label_name})")
print(f"Test 2 - patches tensor shape: {patches_batch.shape}")
print(f"Test 2 - label tensor: {label_tensor}")

# Plot patches, 8 per row. The row count follows the number of patches actually
# returned, so changing `num_patches` above does not silently hide patches.
n_cols = 8
n_rows = max(1, (len(patches_batch) + n_cols - 1) // n_cols)  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < len(patches_batch):
        img = patches_batch[i].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
        ax.imshow(img)
        ax.set_title(f"Patch {i}")
    ax.axis("off")

plt.suptitle(f"PathologyDataset - Patches from {sample_id} ({label_name})", fontsize=14)
plt.tight_layout()
plt.show(block=False)

### Test 3: Grid-based patch extraction


In [None]:
# Test 3: Grid-based patch extraction
print("\n" + "=" * 50)
print("Test 3: Grid-based Patch Extraction")
print("=" * 50)

dataset_grid = PathologyDataset(
    data_dir=TRAIN_DATA_DIR,
    labels_df=train_df,
    transform=test1_transform,
    use_mask=True,
    use_patches=True,
    patch_size=64,
    num_patches=16,
    patch_strategy="grid",
    min_tissue_ratio=0.05,
    use_stain_norm=False,
    is_test=False,
)

# Fetch patches for a specific sample
idx = 0
patches_batch, label_tensor = dataset_grid[idx]
sample_id = dataset_grid.samples[idx]
label_name = dataset_grid.label_encoder.inverse_transform([label_tensor.item()])[0]

# Print shapes
print(f"Sample: {sample_id} ({label_name})")
print(f"Test 3 - patches tensor shape: {patches_batch.shape}")
print(f"Test 3 - label tensor: {label_tensor}")

# Plot patches, 8 per row. The row count follows the number of patches actually
# returned, so changing `num_patches` above does not silently hide patches.
n_cols = 8
n_rows = max(1, (len(patches_batch) + n_cols - 1) // n_cols)  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < len(patches_batch):
        img = patches_batch[i].permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
        ax.imshow(img)
        ax.set_title(f"Patch {i}")
    ax.axis("off")

plt.suptitle(
    f"PathologyDataset - Grid Patches from {sample_id} ({label_name})", fontsize=14
)
plt.tight_layout()
plt.show(block=False)

## Train on the first batch


In [None]:
import torch


def get_balanced_batch(dataset, num_classes=4, samples_per_class=4, rng=None):
    """Construct a class-balanced batch from the dataset.

    For every class index in ``range(num_classes)``, picks ``samples_per_class``
    dataset indices whose encoded label equals that class. Sampling is done
    without replacement when the class has enough samples, and with
    replacement otherwise. The combined selection is shuffled, each item is
    loaded via ``dataset[idx]``, and the results are stacked.

    Args:
        dataset: Object exposing ``encoded_labels`` (array-like of integer
            class indices, one per sample) and ``__getitem__`` returning a
            ``(patches, label)`` pair of tensors.
        num_classes: Classes to draw from, as ``0 .. num_classes - 1``.
        samples_per_class: Number of items drawn for each class.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            sampling and shuffling. ``None`` uses the global ``np.random``
            state, which is the original behaviour.

    Returns:
        ``(patches_tensor, labels_tensor)``: ``torch.stack`` of the per-item
        patches and labels. Each has a leading dimension of
        ``num_classes * samples_per_class``.

    Raises:
        ValueError: If a class in ``range(num_classes)`` has no samples while
            ``samples_per_class > 0``.
    """
    sampler = np.random if rng is None else rng
    # np.asarray makes `==` element-wise even if encoded_labels is a plain list
    # (list == int evaluates to a single False and would select nothing).
    encoded_labels = np.asarray(dataset.encoded_labels)

    selected_indices = []
    for class_idx in range(num_classes):
        class_indices = np.where(encoded_labels == class_idx)[0]
        if class_indices.size == 0 and samples_per_class > 0:
            raise ValueError(f"No samples found for class index {class_idx}")

        # Draw with replacement only when the class is too small.
        replace = class_indices.size < samples_per_class
        chosen = sampler.choice(class_indices, samples_per_class, replace=replace)
        selected_indices.extend(chosen)

    # Interleave classes so the batch is not ordered by label.
    sampler.shuffle(selected_indices)

    batch_patches = []
    batch_labels = []

    print(f"Constructing balanced batch with {len(selected_indices)} samples...")
    for idx in selected_indices:
        patches, label = dataset[idx]
        batch_patches.append(patches)
        batch_labels.append(label)

    patches_tensor = torch.stack(batch_patches)
    labels_tensor = torch.stack(batch_labels)

    return patches_tensor, labels_tensor

### Plot the first batch


In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

train_dataset = PathologyDataset(
    data_dir=TRAIN_DATA_DIR,
    labels_df=train_df,
    transform=test1_transform,
    use_mask=True,
    use_patches=True,
    patch_size=64,
    num_patches=16,
    patch_strategy="random",
    min_tissue_ratio=0.05,
    use_stain_norm=False,
    is_test=False,
)

# DataLoader with batch_size=16. It is not iterated in this cell; the batch
# plotted below comes from get_balanced_batch instead.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

patches, labels = get_balanced_batch(train_dataset, num_classes=4, samples_per_class=4)
# Expected patches shape: [Batch_Size, Num_Patches, 3, H, W] -> [16, 16, 3, 224, 224]
# Expected labels shape: [Batch_Size] -> [16]

print(f"Input batch shape: {patches.shape}")
print(f"Labels batch shape: {labels.shape}")
print(f"Labels: {labels}")
print(
    f"Label names: {[train_dataset.label_encoder.inverse_transform([label.item()])[0] for label in labels]}"
)

# Plot up to 4 samples (one row each) with every patch they contain. Both
# counts come from the tensor, so a smaller batch or fewer patches cannot
# index past the end.
num_samples_to_plot = min(4, patches.shape[0])
num_patches_to_plot = patches.shape[1]
fig, axes = plt.subplots(
    num_samples_to_plot,
    num_patches_to_plot,
    figsize=(24, num_samples_to_plot * 1.5),
    squeeze=False,  # always index axes as [row, col]
)

for sample_idx in range(num_samples_to_plot):
    label_idx = labels[sample_idx].item()
    label_name = train_dataset.label_encoder.inverse_transform([label_idx])[0]

    for patch_idx in range(num_patches_to_plot):
        ax = axes[sample_idx, patch_idx]

        # CHW tensor -> HWC numpy array for imshow
        patch = patches[sample_idx, patch_idx].permute(1, 2, 0).numpy()
        ax.imshow(patch)
        ax.axis("off")

        # Add title only to first patch of each row
        if patch_idx == 0:
            ax.set_title(f"{label_name} (Label: {label_idx})", loc="left", fontsize=10)

plt.suptitle(
    f"First Batch - {num_samples_to_plot} Samples with {num_patches_to_plot} Patches Each",
    fontsize=14,
)
plt.tight_layout()
plt.show()

### Train a simple CNN on the first batch (code currently commented out)


In [None]:
# # Define a simple and fast CNN model
# class SimpleCNN(nn.Module):
#     def __init__(self, num_classes=4):
#         super(SimpleCNN, self).__init__()
#         self.features = nn.Sequential(
#             # Conv block 1
#             nn.Conv2d(3, 32, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             # Conv block 2
#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             # Conv block 3
#             nn.Conv2d(64, 128, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#         )

#         self.classifier = nn.Sequential(
#             nn.AdaptiveAvgPool2d((1, 1)),
#             nn.Flatten(),
#             nn.Linear(128, 64),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.5),
#             nn.Linear(64, num_classes),
#         )

#     def forward(self, x):
#         x = self.features(x)
#         x = self.classifier(x)
#         return x


# # 2. Define Model (Simple CNN)
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# print(f"Using device: {device}")

# num_classes = len(train_dataset.label_encoder.classes_)
# model = SimpleCNN(num_classes=num_classes)
# model = model.to(device)

# print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# # 3. Define Loss and Optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # 4. Train on the First Batch
# print("\n" + "=" * 50)
# print("Overfitting on the First Batch (50 Epochs)")
# print("=" * 50)

# model.train()

# # Get the first batch
# patches, labels = get_balanced_batch(train_dataset, num_classes=4, samples_per_class=4)
# # patches shape: [Batch_Size, Num_Patches, 3, H, W] -> [16, 16, 3, 224, 224]
# # labels shape: [Batch_Size] -> [16]

# print(f"Input batch shape: {patches.shape}")
# print(f"Labels batch shape: {labels.shape}")

# # Reshape for CNN: treat every patch as an independent image sharing the bag label
# # New shape: [Batch_Size * Num_Patches, 3, H, W]
# batch_size, num_patches, c, h, w = patches.shape
# inputs = patches.view(-1, c, h, w).to(device)  # Flatten batch and patches dimensions
# targets = (
#     labels.view(-1, 1).expand(-1, num_patches).reshape(-1).to(device)
# )  # Repeat labels for each patch

# print(f"Reshaped inputs for CNN: {inputs.shape}")
# print(f"Reshaped targets: {targets.shape}")

# # Train for 50 epochs on this single batch
# num_epochs = 500
# print(f"\nStarting training for {num_epochs} epochs...")

# for epoch in range(num_epochs):
#     # Forward pass
#     optimizer.zero_grad()
#     outputs = model(inputs)
#     loss = criterion(outputs, targets)

#     # Backward pass and optimize
#     loss.backward()
#     optimizer.step()

#     # Calculate accuracy
#     _, preds = torch.max(outputs, 1)
#     acc = (preds == targets).sum().item() / targets.size(0)

#     print(
#         f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Patch Acc: {acc:.4f}"
#     )

### Overfit a pretrained MobileNetV2 on the first batch


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models

# 1. Setup Dataset and DataLoader
# We use the patch-based dataset configuration
train_dataset = PathologyDataset(
    data_dir=TRAIN_DATA_DIR,
    labels_df=train_df,
    transform=test1_transform,  # Resize(224) + ToTensor
    use_mask=True,
    use_patches=True,
    patch_size=64,
    num_patches=16,
    patch_strategy="random",
    min_tissue_ratio=0.05,
    use_stain_norm=False,
    is_test=False,
)

# DataLoader with batch_size=8. The overfitting run below draws its batch from
# get_balanced_batch, so this loader is not iterated in this cell.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 2. Define Model (Small Pretrained Model - MobileNetV2)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# IMAGENET1K_V1 is exactly what the deprecated `pretrained=True` maps to.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

# Replace the final classifier layer so it outputs one logit per label class.
num_classes = len(train_dataset.label_encoder.classes_)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

print("Model: MobileNetV2")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# 3. Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 4. Train on the First Batch
num_epochs = 50
print("\n" + "=" * 50)
print(f"Overfitting on the First Batch ({num_epochs} Epochs)")
print("=" * 50)

model.train()

# Balanced batch: 2 samples per class, covering every class the model predicts.
patches, labels = get_balanced_batch(
    train_dataset, num_classes=num_classes, samples_per_class=2
)
# Expected patches shape: [num_classes * 2, Num_Patches, 3, H, W]
#   -> [8, 16, 3, 224, 224] for 4 classes
# Expected labels shape: [num_classes * 2] -> [8]

print(f"Input batch shape: {patches.shape}")
print(f"Labels batch shape: {labels.shape}")

# Reshape for CNN: treat every patch as an independent image sharing the bag label
# New shape: [Batch_Size * Num_Patches, 3, H, W]
batch_size, num_patches, c, h, w = patches.shape
inputs = patches.view(-1, c, h, w).to(device)  # Flatten batch and patches dimensions
targets = (
    labels.view(-1, 1).expand(-1, num_patches).reshape(-1).to(device)
)  # Repeat labels for each patch

# The ImageNet-pretrained weights expect inputs normalized with the ImageNet
# channel mean/std. test1_transform only does Resize + ToTensor, so normalize here.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, -1, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, -1, 1, 1)
inputs = (inputs - imagenet_mean) / imagenet_std

print(f"Reshaped inputs for MobileNetV2: {inputs.shape}")
print(f"Reshaped targets: {targets.shape}")

print(f"\nStarting training for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Forward pass
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    # Patch-level accuracy of this epoch's forward pass (logits from before the step)
    preds = outputs.argmax(dim=1)
    acc = (preds == targets).sum().item() / targets.size(0)

    if (epoch + 1) % 5 == 0:
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Patch Acc: {acc:.4f}"
        )

print("\nTraining completed successfully!")