## 🌐 **Google Drive and Kaggle Connection**

In [None]:
# Install and configure Kaggle API
!pip install -q kaggle

import os
from google.colab import drive
from google.colab import files

# --- Configuration ---
# Destination path on Google Drive
DRIVE_MOUNT_PATH = "/content/gdrive"
DATASET_PATH = f"{DRIVE_MOUNT_PATH}/MyDrive/Artificial_Neural_Networks/Images_Classification_Challenge/dataset"
COMPETITION_ID = "an2dl2526c2v2"
ZIP_FILENAME = f"{COMPETITION_ID}.zip"
EXPECTED_ZIP_FILE = os.path.join(DATASET_PATH, ZIP_FILENAME)

# --- 1. Mount Google Drive ---
print("1. Mounting Google Drive...")
# Mount G-Drive if it's not already mounted
# NOTE(review): os.path.exists only tells us the mount directory exists, not that
# Drive is actually mounted on it — confirm this is a sufficient check.
if not os.path.exists(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)
else:
    print("Drive already mounted.")

# --- 2. Check for Existing Data ---
# The presence of the competition zip is used as the "already downloaded" marker.
if os.path.exists(EXPECTED_ZIP_FILE):
    print(f"\n‚úÖ Dataset found at {DATASET_PATH}. Skipping download and setup.")
    # You can also add a check here for the unzipped folders if you prefer.
else:
    # --- 3. Setup Kaggle Credentials (Only if download is needed) ---
    print("\n‚è≥ Dataset not found. Starting Kaggle setup and download.")

    # 3a. Upload kaggle.json
    print("Carica il file kaggle.json (scaricabile dal tuo profilo Kaggle)")

    # Check if files.upload() returned a file (if running interactively)
    # The uploaded dictionary keys are the filenames.
    uploaded = files.upload()

    if "kaggle.json" in uploaded:
        # 3b. Configure the credentials
        # (the uploaded file is moved into ~/.kaggle and made readable by the owner only)
        !mkdir -p ~/.kaggle
        !mv kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
        print("Kaggle credentials configured.")

        # 3c. Create destination folder and Download
        !mkdir -p {DATASET_PATH}
        print(f"Downloading dataset to: {DATASET_PATH}")
        # Download the dataset directly from Kaggle into the chosen folder
        !kaggle competitions download -c {COMPETITION_ID} -p {DATASET_PATH}

    else:
        print("\n‚ö†Ô∏è kaggle.json not uploaded. Cannot proceed with download.")


# --- 4. Decompress Data (Always check for existence before unzipping) ---

# train_labels.csv is used as the marker that the archive has already been extracted.
EXPECTED_UNZIPPED_FILE = os.path.join(DATASET_PATH, "train_labels.csv")

if os.path.exists(EXPECTED_UNZIPPED_FILE):
    # Check if the key unzipped file is already there
    print(f"\nüì¶ Data appears to be already unzipped (found: {os.path.basename(EXPECTED_UNZIPPED_FILE)}). Skipping decompression.")

elif os.path.exists(EXPECTED_ZIP_FILE):
    # Only unzip if the zip file is present AND the unzipped files are missing
    print("\nüì¶ Competition zip found but data not yet extracted. Starting decompression...")

    # -o flag is usually kept for safety, but if you want STRICTLY NO overwrite, you can remove it.
    # For speed optimization, we rely on the outer 'if' block to skip the entire step.
    # NOTE(review): the interpolated paths are not quoted, so a path containing
    # spaces would be split by the shell — fine for the current DATASET_PATH.
    !unzip -o {EXPECTED_ZIP_FILE} -d {DATASET_PATH}

else:
    # This scenario means neither the zip nor the unzipped files were found.
    # This should only happen if the preceding download step failed or was skipped.
    print("\n‚ö†Ô∏è Cannot decompress: Competition zip file is missing.")

# NOTE(review): this line is printed unconditionally, including after the two
# warning branches above where the dataset is in fact not available.
print(f"\nFinal status: Dataset available in: {DATASET_PATH}")

1. Mounting Google Drive...
Mounted at /content/gdrive

‚úÖ Dataset found at /content/gdrive/MyDrive/Artificial_Neural_Networks/Images_Classification_Challenge/dataset. Skipping download and setup.

üì¶ Data appears to be already unzipped (found: train_labels.csv). Skipping decompression.

Final status: Dataset available in: /content/gdrive/MyDrive/Artificial_Neural_Networks/Images_Classification_Challenge/dataset


## ⚙️ **Library Imports**

In [None]:
# Set seed for reproducibility
SEED = 42

# Import necessary libraries
import os

# Set environment variables before importing modules
os.environ['PYTHONHASHSEED'] = str(SEED)
os.environ['MPLCONFIGDIR'] = os.getcwd() + '/configs/'

# Suppress warnings
# (Warning is the base class, so the second filter also covers FutureWarning)
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

# Import necessary modules
import logging
import random
import numpy as np

# Set seeds for random number generators in NumPy and Python
np.random.seed(SEED)
random.seed(SEED)

# Import PyTorch
import torch
torch.manual_seed(SEED)
from torch import nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import torchvision
# `transforms` is bound to the torchvision v2 API here.
# NOTE(review): a later cell rebinds the same name with `from torchvision import transforms`.
from torchvision.transforms import v2 as transforms
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import OxfordIIITPet
from torchvision.transforms import InterpolationMode
!pip install torchview
from torchview import draw_graph

# TensorBoard configuration and output directories
# (any running tensorboard process is killed, the extension is loaded, and ./models is created)
logs_dir = "tensorboard"
!pkill -f tensorboard
%load_ext tensorboard
!mkdir -p models

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(SEED)
    # NOTE(review): cudnn.benchmark lets cuDNN auto-select kernels, which may make
    # runs non-identical despite the seeds above — confirm this trade-off is intended.
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Import other libraries
import requests
from io import BytesIO
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Configure plot display settings
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

Collecting torchview
  Downloading torchview-0.2.7-py3-none-any.whl.metadata (13 kB)
Downloading torchview-0.2.7-py3-none-any.whl (26 kB)
Installing collected packages: torchview
Successfully installed torchview-0.2.7
PyTorch version: 2.9.0+cu126
Device: cuda


## ⏳ **Data Loading**

In [None]:
import os
import pandas as pd

def delete_samples_from_list(txt_file, img_dir, labels_csv):
    """
    Purge a blacklist of samples from disk and from the labels CSV.

    Args:
        txt_file: text file with one sample ID per line; blank lines are ignored
            and surrounding whitespace is stripped.
        img_dir: directory holding the img_<ID>.png / mask_<ID>.png files.
        labels_csv: CSV with a `sample_index` column; it is overwritten in place
            without the rows whose `sample_index` equals img_<ID>.png.

    Returns:
        None. Progress is reported through print().
    """
    # Gather every non-empty ID, exactly as written in the file
    ids_raw = []
    with open(txt_file, "r") as handle:
        for raw_line in handle:
            token = raw_line.strip()
            if token:
                ids_raw.append(token)

    print(f"IDs to be deleted found: {len(ids_raw)}")

    # File names in the form used by the CSV's sample_index column
    filenames = [f"img_{idx}.png" for idx in ids_raw]

    # Each image is followed by its mask, so files are removed pairwise in list order
    candidate_paths = []
    for fname in filenames:
        candidate_paths.append(os.path.join(img_dir, fname))
        candidate_paths.append(os.path.join(img_dir, fname.replace("img_", "mask_")))

    removed_files = []
    for path in candidate_paths:
        if not os.path.exists(path):
            continue
        os.remove(path)
        removed_files.append(path)

    print(f"üóëÔ∏è Deleted {len(removed_files)} image/mask files.")

    # Rewrite the CSV keeping only the rows that are not blacklisted
    df = pd.read_csv(labels_csv)
    df["sample_index"] = df["sample_index"].astype(str)

    is_blacklisted = df["sample_index"].isin(filenames)
    df_new = df[~is_blacklisted]
    df_new.to_csv(labels_csv, index=False)

    print(f"üìÑ Remaining labels saved: {len(df_new)}")
    print("‚úÖ Cleanup completed.")

In [None]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
from PIL import Image

# Define the final directories containing the images/masks
# These are the directories created by the unzipping above.
train_img_dir = os.path.join(DATASET_PATH, "train_data")
test_img_dir = os.path.join(DATASET_PATH, "test_data")
labels_file = os.path.join(DATASET_PATH, "train_labels.csv")

# Delete Shrek and splash images
# (the IDs to drop are listed in delete.txt; this deletes files on Drive and
# rewrites train_labels.csv in place every time the cell runs)
delete_path = os.path.join(DATASET_PATH, "delete.txt")
delete_samples_from_list(
     txt_file=delete_path,
     img_dir=train_img_dir,
     labels_csv=labels_file
)

# --- Load Labels and Map Classes ---

# Load the labels file
labels_df = pd.read_csv(labels_file)

# The classes are string labels (e.g., 'HER2(+)', 'Luminal B'). We need to map them to integers.
# This also ensures we get the ordered list of class names.
le = LabelEncoder()
labels_df['label_encoded'] = le.fit_transform(labels_df['label'])

# Store the class names and mapping
class_names = list(le.classes_)
num_classes = len(class_names)
class_mapping = dict(zip(le.classes_, le.transform(le.classes_)))

print(f"\nFound {len(labels_df)} training samples.")
print(f"Number of classes: {num_classes}")
print(f"Class Names: {class_names}")
print(f"Label Mapping: {class_mapping}")
print(labels_df)


IDs to be deleted found: 111
üóëÔ∏è Deleted 0 image/mask files.
üìÑ Remaining labels saved: 581
‚úÖ Cleanup completed.

Found 581 training samples.
Number of classes: 4
Class Names: ['HER2(+)', 'Luminal A', 'Luminal B', 'Triple negative']
Label Mapping: {'HER2(+)': np.int64(0), 'Luminal A': np.int64(1), 'Luminal B': np.int64(2), 'Triple negative': np.int64(3)}
     sample_index            label  label_encoded
0    img_0000.png  Triple negative              3
1    img_0002.png        Luminal B              2
2    img_0003.png        Luminal B              2
3    img_0004.png        Luminal B              2
4    img_0006.png        Luminal A              1
..            ...              ...            ...
576  img_0686.png  Triple negative              3
577  img_0687.png  Triple negative              3
578  img_0688.png        Luminal A              1
579  img_0689.png        Luminal A              1
580  img_0690.png        Luminal A              1

[581 rows x 3 columns]


In [None]:
# --- Class Distribution ---

# Absolute count and percentage share of every label, both ordered by class name
class_counts = labels_df['label'].value_counts().sort_index()
class_percent = 100 * labels_df['label'].value_counts(normalize=True).sort_index()

print("\n=== Class Distribution ===")
for cls, count in class_counts.items():
    share = class_percent[cls]
    print(f"{cls:15}  Count: {count:4d}   ({share:5.2f}%)")



=== Class Distribution ===
HER2(+)          Count:  150   (25.82%)
Luminal A        Count:  158   (27.19%)
Luminal B        Count:  204   (35.11%)
Triple negative  Count:   69   (11.88%)


In [None]:
import numpy as np
import cv2

def is_green_artifact_lab(img, a_thresh=100, ratio_thresh=0.01):
    """
    Flag an image as containing green artifacts using the 'a' channel of LAB.

    Args:
        img: RGB image (anything np.array() can turn into an RGB pixel array).
        a_thresh: a pixel counts as green when its 'a' value is strictly below this.
            NOTE(review): for 8-bit input OpenCV presumably stores 'a' offset by 128,
            which would put 100 on the green side of neutral — confirm.
        ratio_thresh: fraction of green pixels above which the image is flagged.

    Returns:
        True when the share of green pixels is strictly greater than ratio_thresh.
    """
    lab_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2LAB)
    a_channel = lab_pixels[:, :, 1]

    # The mean of the boolean mask is the fraction of pixels below the threshold
    green_fraction = (a_channel < a_thresh).mean()
    return green_fraction > ratio_thresh

In [None]:
import os
from PIL import Image

contaminated, clean = [], []

# Scan only the img_* files of the training folder and sort each one into
# `contaminated` or `clean` according to the LAB green-artifact filter.
for fname in os.listdir(train_img_dir):
    if fname.startswith("img_"):
        img = Image.open(os.path.join(train_img_dir, fname)).convert("RGB")
        bucket = contaminated if is_green_artifact_lab(img) else clean
        bucket.append(fname)

print("Totale immagini analizzate :", len(contaminated) + len(clean))
print("Immagini contaminate       :", len(contaminated))
print("Immagini pulite            :", len(clean))

In [None]:
import matplotlib.pyplot as plt

def show_contaminated(samples, base_path, n=20):
    """
    Plot the first `n` flagged images in a two-row grid.

    Args:
        samples: list of image file names flagged by the LAB filter.
        base_path: directory the file names are relative to.
        n: maximum number of images to show; it also fixes the grid width
            at (n + 1) // 2 columns.

    Returns:
        None. Prints a message and returns early when `samples` is empty.
    """
    # The body is indented uniformly at 4 spaces: the previous mix of 4- and
    # 5-space indentation made the cell fail with an IndentationError.
    if len(samples) == 0:
        print("Nessuna immagine contaminata trovata.")
        return

    plt.figure(figsize=(40, 20))

    for i, fname in enumerate(samples[:n]):
        plt.subplot(2, (n + 1) // 2, i + 1)
        # imshow copies the pixel data, so the file can be closed right after
        with Image.open(os.path.join(base_path, fname)) as img:
            plt.imshow(img)
        plt.title(fname)
        plt.axis("off")

    plt.suptitle("Immagini contaminate rilevate dal filtro LAB")
    plt.show()


# Preview at most six of the images flagged by the scan above
show_contaminated(contaminated, train_img_dir, n=6)

In [None]:
import pandas as pd

def delete_contaminated_samples(img_list, img_dir, labels_csv):
    """
    Delete flagged samples from disk and from the labels CSV.

    For every file name in `img_list` this removes:
      - img_XXXX.png
      - mask_XXXX.png
      - the matching row of the CSV (matched on the `sample_index` column)

    Args:
        img_list: image file names (img_XXXX.png) to delete.
        img_dir: directory containing the images and masks.
        labels_csv: path of the labels CSV, overwritten in place.

    Returns:
        None. Progress is reported through print().
    """
    # The body is indented uniformly at 4 spaces: it used to sit one column to the
    # right of the docstring, which Python rejects as an unexpected indent.
    removed_files = []

    for fname in img_list:
        img_path = os.path.join(img_dir, fname)
        mask_path = os.path.join(img_dir, fname.replace("img_", "mask_"))

        # Image first, then its mask; files that are already gone are skipped
        for path in (img_path, mask_path):
            if os.path.exists(path):
                os.remove(path)
                removed_files.append(path)

    print(f"üóëÔ∏è Deleted {len(removed_files)} files (images + masks).")

    # --- Update the CSV ---
    df = pd.read_csv(labels_csv)
    df["sample_index"] = df["sample_index"].astype(str)

    df_new = df[~df["sample_index"].isin(img_list)]
    df_new.to_csv(labels_csv, index=False)

    print(f"üìÑ Remaining labels saved: {len(df_new)}")
    print("‚úÖ Automatic cleanup completed.")

In [None]:
delete_contaminated_samples(contaminated, train_img_dir, labels_file)

# The call above only rewrites the CSV on disk. Drop the same rows from the
# in-memory labels_df as well, otherwise the later slide split would still
# contain samples whose image/mask files were just deleted.
labels_df = labels_df[
    ~labels_df["sample_index"].astype(str).isin(contaminated)
].reset_index(drop=True)



##  **Data Preprocessing NEW**

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Training images/masks folder and labels CSV inside the dataset directory
# (same locations as train_img_dir / labels_file defined in the data-loading cell)
TRAIN_IMG_DIR = os.path.join(DATASET_PATH, "train_data")
LABELS_CSV = os.path.join(DATASET_PATH, "train_labels.csv")

# ---------------------------------------------------------
#  LOADING FUNCTIONS
# ---------------------------------------------------------

def load_rgb(path):
    """Open the image stored at `path` and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert("RGB")

def load_mask(path):
    """Open the mask stored at `path` and return it as a single-channel ("L") image."""
    mask_image = Image.open(path)
    return mask_image.convert("L")


# ---------------------------------------------------------
#  MASK CLEANING (closing + dilation)
# ---------------------------------------------------------

def clean_mask(mask_arr, close_size=10, dilate_size=10):
    """
    Smooth a mask with a morphological closing followed by one dilation pass.

    Args:
        mask_arr: mask as a NumPy array.
        close_size: side of the square all-ones kernel used for the closing.
        dilate_size: side of the square all-ones kernel used for the dilation.

    Returns:
        The closed-then-dilated mask.
    """
    # Closing comes first: it fills holes and joins fragments that lie close together
    closed = cv2.morphologyEx(
        mask_arr,
        cv2.MORPH_CLOSE,
        np.ones((close_size, close_size), np.uint8),
    )

    # One dilation pass then grows the resulting regions slightly
    return cv2.dilate(
        closed,
        np.ones((dilate_size, dilate_size), np.uint8),
        iterations=1,
    )


# ---------------------------------------------------------
#  TILE EXTRACTION FROM CENTROIDS
# ---------------------------------------------------------

def extract_tiles_from_centroids(img, mask, tile_size=224, min_mask_ratio=0.01):
    """
    Extract one tile per connected component of the cleaned mask.

    Each tile is a tile_size x tile_size window positioned on the centroid of a
    component; parts of the window that fall outside the image are zero-padded.

    Args:
        img: RGB image (PIL image or array).
        mask: mask aligned with `img`; non-zero pixels mark the regions of interest.
        tile_size: side of the square tile in pixels (odd values are supported).
        min_mask_ratio: tiles whose share of non-zero cleaned-mask pixels is below
            this value are discarded.

    Returns:
        List of (PIL.Image tile, cleaned-mask tile array) tuples.
    """
    img_arr  = np.array(img)
    mask_arr = np.array(mask)

    # Clean the mask, then label its connected components
    mask_clean = clean_mask(mask_arr)
    num_labels, labels = cv2.connectedComponents((mask_clean > 0).astype(np.uint8))

    H, W = img_arr.shape[:2]
    tiles = []
    half = tile_size // 2

    for lbl in range(1, num_labels):  # label 0 is the background
        ys, xs = np.where(labels == lbl)
        if len(xs) == 0:
            continue

        # Centroid of the component
        cx, cy = int(xs.mean()), int(ys.mean())

        # Window of exactly tile_size pixels. Deriving the far edge from the near
        # edge (instead of cx + half) keeps the size right for odd tile_size too.
        x1, y1 = cx - half, cy - half
        x2, y2 = x1 + tile_size, y1 + tile_size

        # Amount of padding needed on each side when the window leaves the image
        pad_left   = max(0, -x1)
        pad_top    = max(0, -y1)
        pad_right  = max(0, x2 - W)
        pad_bottom = max(0, y2 - H)

        tile_img  = img_arr[max(y1, 0):min(y2, H), max(x1, 0):min(x2, W)]
        tile_mask = mask_clean[max(y1, 0):min(y2, H), max(x1, 0):min(x2, W)]

        if pad_left or pad_top or pad_right or pad_bottom:
            tile_img = cv2.copyMakeBorder(tile_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)
            tile_mask = cv2.copyMakeBorder(tile_mask, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)

        # Reject tiles with too little mask content
        if (tile_mask > 0).mean() < min_mask_ratio:
            continue

        tiles.append((Image.fromarray(tile_img), tile_mask))

    return tiles


# ---------------------------------------------------------
#  OPTIONAL FALLBACK: SLIDING WINDOW (for safety)
# ---------------------------------------------------------

def fallback_sliding_window(img, mask, tile_size=224, min_mask_ratio=0.01):
    """
    Sliding-window tiling used when centroid-based extraction yields nothing.

    Windows of tile_size x tile_size are taken with a stride of tile_size // 2
    (50% overlap); only windows that fit entirely inside the image are visited.

    Args:
        img: RGB image (PIL image or array).
        mask: mask aligned with `img`.
        tile_size: side of the square window.
        min_mask_ratio: windows whose share of non-zero mask pixels is below this
            value are skipped.

    Returns:
        List of (PIL.Image tile, mask window array) tuples, in row-major order.
    """
    img_arr = np.array(img)
    mask_arr = np.array(mask)

    height, width = img_arr.shape[:2]
    step = tile_size // 2  # 50% overlap

    tiles = []
    for top in range(0, height - tile_size + 1, step):
        bottom = top + tile_size
        for left in range(0, width - tile_size + 1, step):
            right = left + tile_size

            window_mask = mask_arr[top:bottom, left:right]
            coverage = (window_mask > 0).mean()
            if coverage < min_mask_ratio:
                continue

            window_img = img_arr[top:bottom, left:right]
            tiles.append((Image.fromarray(window_img), window_mask))

    return tiles

In [None]:
# ============================================================
# MULTISCALE TILE EXTRACTION VIA CENTROIDS (need also the clean_mask cells runned)
# ============================================================

def extract_centered_crop(img_arr, cx, cy, crop_size):
    """
    Extract a crop_size x crop_size window positioned on (cx, cy), zero-padded
    wherever the window falls outside the array.

    Args:
        img_arr: 2-D (H, W) or 3-D (H, W, C) array.
        cx, cy: integer column / row of the window centre.
        crop_size: side of the square window in pixels (odd values are supported).

    Returns:
        A new array of shape (crop_size, crop_size[, C]) with the dtype of `img_arr`.
    """
    H, W = img_arr.shape[:2]
    half = crop_size // 2

    # Deriving the far edge from the near edge keeps the window exactly crop_size
    # wide; cx + half would make it one pixel short for odd sizes.
    x1, y1 = cx - half, cy - half
    x2, y2 = x1 + crop_size, y1 + crop_size

    # Part of the window that actually overlaps the array
    src_x1, src_x2 = max(x1, 0), min(x2, W)
    src_y1, src_y2 = max(y1, 0), min(y2, H)

    # Start from an all-zero canvas and paste the overlapping region at its offset
    # inside the window; everything left untouched is the zero padding.
    crop = np.zeros((crop_size, crop_size) + img_arr.shape[2:], dtype=img_arr.dtype)
    if src_x1 < src_x2 and src_y1 < src_y2:
        crop[src_y1 - y1:src_y2 - y1, src_x1 - x1:src_x2 - x1] = \
            img_arr[src_y1:src_y2, src_x1:src_x2]

    return crop

def extract_multiscale_tiles_from_centroids(
    img,
    mask,
    tile_size=224,
    min_mask_ratio=0.01,
    zoom_factors=(0.58, 1.0)  # default: one zoomed-in scale + the base scale
):
    """
    Extract tiles at several scales around each connected component of the cleaned mask.

    For every component and every zoom factor `zf`, a square window of side
    int(tile_size * zf) is cut around the component centroid and resized to
    tile_size x tile_size. Factors below 1 therefore zoom in, factors above 1 zoom out.

    Args:
        img: RGB image (PIL image or array).
        mask: mask aligned with `img`; non-zero pixels mark the regions of interest.
        tile_size: side of every returned tile.
        min_mask_ratio: tiles whose resized mask has a smaller share of non-zero
            pixels than this are discarded.
        zoom_factors: iterable of scale factors applied to tile_size.

    Returns:
        List of (PIL.Image tile, resized cleaned-mask array) tuples.
    """
    img_arr  = np.array(img)
    mask_arr = np.array(mask)

    mask_clean = clean_mask(mask_arr)
    num_labels, labels = cv2.connectedComponents((mask_clean > 0).astype(np.uint8))

    tiles = []

    for lbl in range(1, num_labels):  # label 0 is the background
        ys, xs = np.where(labels == lbl)
        if len(xs) == 0:
            continue

        cx, cy = int(xs.mean()), int(ys.mean())

        for zf in zoom_factors:
            crop_size = int(tile_size * zf)

            # Same window on the image and on the cleaned mask
            crop_img  = extract_centered_crop(img_arr,  cx, cy, crop_size)
            crop_mask = extract_centered_crop(mask_clean, cx, cy, crop_size)

            # Bring both to tile_size x tile_size; nearest-neighbour keeps the mask values intact
            crop_img  = cv2.resize(crop_img,  (tile_size, tile_size), interpolation=cv2.INTER_LINEAR)
            crop_mask = cv2.resize(crop_mask, (tile_size, tile_size), interpolation=cv2.INTER_NEAREST)

            # Mask content check (on the resized mask)
            if (crop_mask > 0).mean() < min_mask_ratio:
                continue

            tiles.append((Image.fromarray(crop_img), crop_mask))

    return tiles

In [None]:
from torch.utils.data import Dataset

class HistologyOvRTiles(Dataset):
    """
    One-vs-Rest dataset backed by tiles that are already held in memory.

    `labels_df` must provide the columns:
      - tile_img: the tile itself (PIL.Image)
      - binary_label: 0/1 target
    """
    def __init__(self, labels_df, transform=None):
        self.transform = transform
        self.labels_df = labels_df.reset_index(drop=True)

        # Materialise (tile, int label) pairs once so __getitem__ is a plain list lookup
        self.samples = []
        for record in self.labels_df.to_dict("records"):
            self.samples.append((record["tile_img"], int(record["binary_label"])))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        tile, target = self.samples[idx]
        if self.transform is None:
            return tile, target
        return self.transform(tile), target



##  **Data Split and Data Transformation**

In [None]:
import pandas as pd

def extract_tiles_in_memory(
    img_dir,
    slide_df,
    tile_size=224,
    min_mask_ratio=0.01,
    zoom_factors=(0.58, 1.0)
):
    """
    Extract the tiles of every slide listed in `slide_df` and keep them in memory.

    Args:
        img_dir: directory containing img_XXXX.png and mask_XXXX.png files.
        slide_df: DataFrame with `sample_index` (image file name) and `label` columns.
        tile_size: side of the extracted tiles.
        min_mask_ratio: minimum mask coverage for a tile to be kept; the sliding-window
            fallback uses a tenth of this value.
        zoom_factors: scales forwarded to the multiscale centroid extraction.

    Returns:
        tiles_df with columns:
          sample_index | label | tile_img (PIL.Image)
        The three columns are present even when no tile was extracted.
    """
    columns = ["sample_index", "label", "tile_img"]
    rows = []

    for _, row in slide_df.iterrows():
        slide_id = row["sample_index"]
        label    = row["label"]

        img_path = os.path.join(img_dir, slide_id)
        # Swap the prefix on the file name only: running replace() on the full path
        # would also rewrite any "img_" occurring in a directory name.
        mask_path = os.path.join(img_dir, slide_id.replace("img_", "mask_"))

        if not os.path.exists(img_path) or not os.path.exists(mask_path):
            print(f"[WARNING] Missing img/mask for {slide_id}")
            continue

        img  = load_rgb(img_path)
        mask = load_mask(mask_path)

        tiles = extract_multiscale_tiles_from_centroids(
            img,
            mask,
            tile_size=tile_size,
            min_mask_ratio=min_mask_ratio,
            zoom_factors=zoom_factors
        )

        # Fall back to a sliding window, with a looser coverage threshold,
        # when the centroid-based extraction found nothing
        if not tiles:
            tiles = fallback_sliding_window(
                img, mask,
                tile_size=tile_size,
                min_mask_ratio=min_mask_ratio * 0.1
            )

        if not tiles:
            print(f"[ERROR] No tiles for {slide_id}")
            continue

        for tile_img, _ in tiles:
            rows.append({
                "sample_index": slide_id,
                "label": label,
                "tile_img": tile_img
            })

    # Passing the columns explicitly keeps the schema stable for an empty result
    tiles_df = pd.DataFrame(rows, columns=columns)
    print(f"Total tiles extracted: {len(tiles_df)}")
    return tiles_df

def create_ovr_dataframe_tiles(tiles_df, target_class, seed=42):
    """
    Build a balanced One-vs-Rest tile table for `target_class`.

    All tiles of the target class are kept as positives (binary_label = 1). The
    negatives (binary_label = 0) are drawn from the other classes: each contributes
    at most target_count // n_rest_classes tiles (at least 1), sampled with `seed`.

    Args:
        tiles_df: DataFrame with columns [sample_index, label, tile_img].
        target_class: label treated as the positive class.
        seed: random_state used for the per-class sampling and the final shuffle.

    Returns:
        Shuffled DataFrame with the input columns plus `binary_label`. When
        `tiles_df` contains no other class, only the positives are returned.
    """
    target_df = tiles_df[tiles_df["label"] == target_class].copy()
    target_df["binary_label"] = 1

    rest_df = tiles_df[tiles_df["label"] != target_class].copy()
    rest_classes = rest_df["label"].unique()

    target_count = len(target_df)
    samples_per_rest = max(1, target_count // max(1, len(rest_classes)))

    rest_samples = []
    for rc in rest_classes:
        rc_df = rest_df[rest_df["label"] == rc]
        n = min(samples_per_rest, len(rc_df))
        rest_samples.append(rc_df.sample(n=n, random_state=seed))

    frames = [target_df]
    # pd.concat raises on an empty list, so negatives are only assembled when
    # at least one other class is present
    if rest_samples:
        rest_balanced_df = pd.concat(rest_samples).copy()
        rest_balanced_df["binary_label"] = 0
        frames.append(rest_balanced_df)

    ovr_df = pd.concat(frames, ignore_index=True)
    ovr_df = ovr_df.sample(frac=1, random_state=seed).reset_index(drop=True)

    return ovr_df


In [None]:

from sklearn.model_selection import train_test_split

SEED = 42

# One row per slide: the split is done on slides, not on tiles
slide_df = labels_df[["sample_index", "label"]].drop_duplicates().copy()

# 70% for training OvR, 15% for validation and 15% for XGBoost
cnn_train_ids, temp_ids = train_test_split(
    slide_df["sample_index"],
    test_size=0.30,
    random_state=SEED,
    stratify=slide_df["label"]
)

# train_test_split pairs `stratify` with the samples by position. temp_ids comes
# back shuffled but keeps slide_df's index labels, so the labels are looked up
# through that index to stay aligned with temp_ids row by row. (A boolean isin()
# filter would return them in slide_df order and stratify on mismatched labels.)
cnn_val_ids, xgb_slide_ids = train_test_split(
    temp_ids,
    test_size=0.50,
    random_state=SEED,
    stratify=slide_df.loc[temp_ids.index, "label"]
)

print("CNN train slides:", len(cnn_train_ids))
print("CNN val slides:",   len(cnn_val_ids))
print("XGB slides:",       len(xgb_slide_ids))
print("Total:", len(cnn_train_ids) + len(cnn_val_ids) + len(xgb_slide_ids))



CNN train slides: 406
CNN val slides: 87
XGB slides: 88
Total: 581


In [None]:
# Tile the CNN training slides first, then the CNN validation slides, with the
# same settings for both splits.
tiles_train_df, tiles_val_df = [
    extract_tiles_in_memory(
        img_dir=TRAIN_IMG_DIR,
        slide_df=slide_df[slide_df["sample_index"].isin(split_ids)],
        tile_size=224,
    )
    for split_ids in (cnn_train_ids, cnn_val_ids)
]

Total tiles extracted: 5113
Total tiles extracted: 988


In [None]:
## custom 90-degree rotation function ##
class Random90Rotation:
    """
    With probability `p`, rotate the image by an angle drawn uniformly from
    0, 90, 180 or 270 degrees; otherwise return it untouched.
    """
    def __init__(self, p=0.5):
        self.p = p
        self.degrees = [0, 90, 180, 270]

    def __call__(self, img):
        # Coin flip first, angle second: the order of the random draws is part of
        # the seeded behaviour
        if not (random.random() < self.p):
            return img
        chosen_angle = random.choice(self.degrees)
        return img.rotate(chosen_angle, resample=Image.BICUBIC)


In [None]:

# Channel statistics passed to Normalize in both pipelines below
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: the geometric and colour augmentations run on the PIL tile,
# ToTensor converts it, and RandomErasing / Normalize then run on the tensor.
# NOTE(review): `transforms` is bound twice earlier in the notebook (the v2 alias in
# the library-import cell, then `from torchvision import transforms` in the
# preprocessing cell); the later binding wins — confirm which API is intended.
train_transforms = transforms.Compose([
    Random90Rotation(p=0.75),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.RandomAffine(degrees=20, translate=(0.05,0.05), scale=(0.95,1.05)),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.25),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: no augmentation, only tensor conversion and normalisation
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Rebinds class_names to the sorted labels present in slide_df
class_names = sorted(slide_df["label"].unique())

# One train/val dataset pair per class, keyed by class name
ovr_train_datasets = {}
ovr_val_datasets   = {}

for cls in class_names:
    print(f"\nPreparing OvR dataset for class: {cls}")

    # Train: balanced positives/negatives. Validation: every validation tile,
    # labelled 1 when it belongs to `cls` and 0 otherwise (not rebalanced).
    train_ovr_df = create_ovr_dataframe_tiles(tiles_train_df, cls)
    val_ovr_df   = tiles_val_df.copy()
    val_ovr_df["binary_label"] = (val_ovr_df["label"] == cls).astype(int)

    ovr_train_datasets[cls] = HistologyOvRTiles(
        labels_df=train_ovr_df,
        transform=train_transforms
    )

    ovr_val_datasets[cls] = HistologyOvRTiles(
        labels_df=val_ovr_df,
        transform=val_transforms
    )

    print(f"  Train tiles: {len(ovr_train_datasets[cls])}")
    print(f"  Val tiles:   {len(ovr_val_datasets[cls])}")



Preparing OvR dataset for class: HER2(+)
  Train tiles: 2808
  Val tiles:   988

Preparing OvR dataset for class: Luminal A
  Train tiles: 2762
  Val tiles:   988

Preparing OvR dataset for class: Luminal B
  Train tiles: 3525
  Val tiles:   988

Preparing OvR dataset for class: Triple negative
  Train tiles: 1056
  Val tiles:   988


##  **Data Loader Creation**

In [None]:

"""##  **Data Loader Creation (OvR ‚Äì UPDATED)**"""

from torch.utils.data import DataLoader

# ---------------------------------------------------------
# Configuration
# ---------------------------------------------------------
BATCH_SIZE  = 64
NUM_WORKERS = 4
PIN_MEMORY = (device.type == "cuda")

# Settings shared by the training and the validation loaders
_shared_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

ovr_train_loaders, ovr_val_loaders = {}, {}

print("\n--- Creating OvR DataLoaders (from in-memory datasets) ---")

for class_name in class_names:
    print(f"Processing {class_name}...")

    train_dataset, val_dataset = ovr_train_datasets[class_name], ovr_val_datasets[class_name]

    # Training: shuffled, incomplete last batch dropped
    ovr_train_loaders[class_name] = DataLoader(
        train_dataset, shuffle=True, drop_last=True, **_shared_loader_kwargs
    )

    # Validation: fixed order, every tile kept
    ovr_val_loaders[class_name] = DataLoader(
        val_dataset, shuffle=False, **_shared_loader_kwargs
    )

    print(f"  -> Train tiles: {len(train_dataset)} | Val tiles: {len(val_dataset)}")

print("\n‚úÖ All 4 OvR DataLoaders created correctly.")



--- Creating OvR DataLoaders (from in-memory datasets) ---
Processing HER2(+)...
  -> Train tiles: 2808 | Val tiles: 988
Processing Luminal A...
  -> Train tiles: 2762 | Val tiles: 988
Processing Luminal B...
  -> Train tiles: 3525 | Val tiles: 988
Processing Triple negative...
  -> Train tiles: 1056 | Val tiles: 988

‚úÖ All 4 OvR DataLoaders created correctly.


##  🧮 **Network Parameters**

In [None]:

"""##  üßÆ **Network Parameters**"""

# Optimisation schedule
LEARNING_RATE = 5e-4
EPOCHS        = 200
PATIENCE      = 20

# Regularisation
DROPOUT_RATE  = 0.3

# Separate (identically configured) loss objects for training and validation
criterion     = nn.CrossEntropyLoss()
criterion_val = nn.CrossEntropyLoss()

# Print the defined parameters
for _param_name, _param_value in (
    ("Epochs", EPOCHS),
    ("Batch Size", BATCH_SIZE),
    ("Learning Rate", LEARNING_RATE),
    ("Dropout Rate", DROPOUT_RATE),
    ("Patience", PATIENCE),
):
    print(f"{_param_name}: {_param_value}")


Epochs: 200
Batch Size: 64
Learning Rate: 0.0005
Dropout Rate: 0.3
Patience: 20


##  🧠 **Training Functions**

In [None]:

"""##  üß† **Training Functions**"""

def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device):
    """
    Run one training epoch and return the mean per-sample loss.

    Args:
        model: network to train (switched to train mode here).
        train_loader: iterable yielding (inputs, targets) batches.
        criterion: loss function called as criterion(logits, targets).
        optimizer: optimizer updating the model parameters.
        scaler: gradient scaler exposing scale() / step() / update().
        device: torch.device the batches are moved to; autocast is enabled
            only when it is a CUDA device.

    Returns:
        Loss averaged over the samples actually processed, or NaN when the
        loader produced no batch.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
            logits = model(inputs)
            loss = criterion(logits, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Average over the samples that were really seen: with drop_last=True the
    # loader skips the final partial batch, so dividing by len(dataset) would
    # under-report the loss.
    if samples_seen == 0:
        return float("nan")
    return running_loss / samples_seen


from sklearn.metrics import accuracy_score, f1_score

def validate_one_epoch(model, val_loader, criterion, device):
    """
    Evaluate the model on the validation loader without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_loader: loader yielding (inputs, targets) batches; its `dataset`
            length is used to average the loss.
        criterion: loss function called as criterion(logits, targets).
        device: torch.device the batches are moved to; autocast is enabled
            only when it is a CUDA device.

    Returns:
        (mean loss, accuracy, binary F1) computed over the whole loader.
    """
    model.eval()

    use_amp = (device.type == 'cuda')
    loss_sum = 0.0
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                logits = model(inputs)
                loss = criterion(logits, targets)

            loss_sum += loss.item() * inputs.size(0)

            # Predicted class = index of the largest logit
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    y_true = np.concatenate(target_chunks)
    y_pred = np.concatenate(pred_chunks)

    # F1 is the binary score (for monitoring only, not early stopping)
    return (
        loss_sum / len(val_loader.dataset),
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred),
    )


def fit(
    model, train_loader, val_loader,
    epochs, criterion, optimizer, scaler, device,
    unfreeze_epoch=0, fine_tune_lr=None,
    patience=15,
    restore_best_weights=True,
    writer=None,
    verbose=1,
    experiment_name=""
):
    """
    Train `model` with early stopping on the validation loss.

    Args:
        model, train_loader, val_loader: network and its data loaders.
        epochs: maximum number of epochs.
        criterion: loss used for both training and validation.
        optimizer: initial optimizer; replaced by a fresh Adam over all parameters
            at `unfreeze_epoch`.
        scaler: gradient scaler forwarded to train_one_epoch.
        device: torch.device used for training and validation.
        unfreeze_epoch: epoch (1-based) at which every parameter is made trainable;
            0 disables unfreezing.
        fine_tune_lr: learning rate of the optimizer created at unfreeze time; when
            None, the learning rate of the current optimizer is reused.
        patience: number of consecutive non-improving epochs tolerated.
        restore_best_weights: reload the best checkpoint written during this call.
        writer: optional TensorBoard SummaryWriter.
        verbose: print one line per epoch when truthy.
        experiment_name: checkpoint is stored as models/<experiment_name>.pt.

    Returns:
        (model, history) where history holds per-epoch train_loss, val_loss,
        val_acc and val_f1 lists.
    """
    history = {
        "train_loss": [],
        "val_loss": [],
        "val_acc": [],
        "val_f1": []
    }

    checkpoint_path = f"models/{experiment_name}.pt"
    # Make sure the target folder exists instead of relying on an earlier shell cell
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_val_loss = float("inf")
    patience_counter = 0
    checkpoint_saved = False  # True once THIS run has written a checkpoint

    for epoch in range(1, epochs + 1):

        # Unfreeze backbone
        if epoch == unfreeze_epoch and unfreeze_epoch > 0:
            for p in model.parameters():
                p.requires_grad = True
            # Without an explicit fine-tuning rate keep the current one rather than
            # handing lr=None to Adam
            new_lr = fine_tune_lr if fine_tune_lr is not None else optimizer.param_groups[0]["lr"]
            optimizer = torch.optim.Adam(model.parameters(), lr=new_lr)

        train_loss = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device
        )

        val_loss, val_acc, val_f1 = validate_one_epoch(
            model, val_loader, criterion, device
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_f1"].append(val_f1)

        if writer:
            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Loss/Val", val_loss, epoch)
            writer.add_scalar("Acc/Val", val_acc, epoch)
            writer.add_scalar("F1/Val", val_f1, epoch)

        if verbose:
            print(f"Epoch {epoch:03d} | "
                  f"Train Loss {train_loss:.4f} | "
                  f"Val Loss {val_loss:.4f} | "
                  f"Acc {val_acc:.4f} | F1 {val_f1:.4f}")

        # Early stopping on VAL LOSS
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Only restore a checkpoint written by this call: if the validation loss never
    # improved (e.g. it was NaN throughout, or epochs == 0), a file with the same
    # name left by an earlier run must not be loaded.
    if restore_best_weights and checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return model, history




##  **One vs Rest with XGBoost**

In [None]:

"""## üõ†Ô∏è **One-vs-Rest **"""

import torch
import torch.nn as nn
import torchvision
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

# ---------------------------------------------------------
# Configuration
# ---------------------------------------------------------
# Backbone identifier; in this section it is only used to build experiment names.
MODEL_BACKBONE = 'resnet18'

# Every One-vs-Rest head emits two logits (index 1 is read as the positive class
# by the inference helpers further down).
NUM_OVR_CLASSES = 2

# presumably (channels, height, width) of a tile; not referenced in this section
input_shape = (3, 224, 224)

# Two-phase schedule: start with a frozen backbone, then fit() unfreeze everything
# at UNFREEZE_EPOCH and continues with FINE_TUNE_LEARNING_RATE.
FREEZE_BACKBONE = True
UNFREEZE_EPOCH = 20
FINE_TUNE_LEARNING_RATE = 1e-5

# ---------------------------------------------------------
# Binary ResNet18 (OvR)
# ---------------------------------------------------------
class BinaryResNet18(nn.Module):
    """
    ImageNet-pretrained ResNet18 whose final layer is swapped for a small MLP head
    producing NUM_OVR_CLASSES logits (One-vs-Rest classification).

    Args:
        dropout_rate: dropout probability applied inside the new head.
        freeze_backbone: when True, every pretrained parameter is frozen; the
            replacement head is created afterwards and therefore stays trainable.
    """
    def __init__(self, dropout_rate=0.3, freeze_backbone=True):
        super().__init__()

        pretrained_weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        self.backbone = torchvision.models.resnet18(weights=pretrained_weights)

        # Freeze BEFORE attaching the new head so only pretrained weights are affected.
        if freeze_backbone:
            self.backbone.requires_grad_(False)

        feature_dim = self.backbone.fc.in_features
        hidden_dim = feature_dim // 2

        # Head: Linear -> ReLU -> BatchNorm -> Dropout -> Linear
        head_layers = [
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, NUM_OVR_CLASSES),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the backbone (including the replaced head) and return its logits."""
        logits = self.backbone(x)
        return logits


In [None]:

# ---------------------------------------------------------
# Training one OvR classifier per entry of class_names
# ---------------------------------------------------------
# Results are keyed by class name:
#   OVR_TRAINED_MODELS[class_name]     -> model returned by fit()
#   OVR_TRAINING_HISTORIES[class_name] -> history dict returned by fit()
OVR_TRAINED_MODELS = {}
OVR_TRAINING_HISTORIES = {}

# Derive the count from the data instead of hard-coding it in the log messages.
n_classifiers = len(class_names)

# fit() and this cell both write checkpoints to models/; make sure it exists.
os.makedirs("models", exist_ok=True)

print("\n=======================================================")
print(f"== STARTING TRAINING FOR {n_classifiers} ONE-VS-REST CLASSIFIERS ==")
print("=======================================================")

for i, class_name in enumerate(class_names):

    print(f"\n>>>> TRAINING CLASSIFIER {i+1}/{n_classifiers}: {class_name} vs REST <<<<")

    # -----------------------------------------------------
    # Model
    # -----------------------------------------------------
    model_ovr = BinaryResNet18(
        dropout_rate=DROPOUT_RATE,
        freeze_backbone=FREEZE_BACKBONE
    ).to(device)

    # -----------------------------------------------------
    # Loss (STANDARD CE)
    # -----------------------------------------------------
    criterion = nn.CrossEntropyLoss()

    # -----------------------------------------------------
    # Optimizer (only parameters that currently require grad,
    # i.e. the head while the backbone is frozen)
    # -----------------------------------------------------
    optimizer = Adam(
        filter(lambda p: p.requires_grad, model_ovr.parameters()),
        lr=LEARNING_RATE
    )

    # Gradient scaling is only active on CUDA.
    scaler = torch.amp.GradScaler(enabled=(device.type == 'cuda'))

    # -----------------------------------------------------
    # Logging
    # -----------------------------------------------------
    experiment_name = f"ovr_{class_name.replace(' ', '_').replace('+', 'pos')}_{MODEL_BACKBONE}"
    writer = SummaryWriter(f"./{logs_dir}/{experiment_name}")

    print(f"Starting experiment: {experiment_name}")

    # -----------------------------------------------------
    # Training
    # -----------------------------------------------------
    # Close the writer even if fit() raises, so every run flushes its events
    # and no file handle is leaked across the classifiers.
    try:
        trained_model, history = fit(
            model=model_ovr,
            train_loader=ovr_train_loaders[class_name],
            val_loader=ovr_val_loaders[class_name],
            epochs=EPOCHS,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            device=device,
            writer=writer,
            verbose=5,
            experiment_name=experiment_name,
            patience=PATIENCE,
            unfreeze_epoch=UNFREEZE_EPOCH,
            fine_tune_lr=FINE_TUNE_LEARNING_RATE
        )
    finally:
        writer.close()

    # -----------------------------------------------------
    # Save results
    # -----------------------------------------------------
    OVR_TRAINED_MODELS[class_name] = trained_model
    OVR_TRAINING_HISTORIES[class_name] = history

    # fit() already checkpoints its best epoch to this same path.
    # NOTE(review): if fit() ran without restore_best_weights, this save replaces
    # that best checkpoint with the last-epoch weights — confirm the default.
    model_path = f"models/{experiment_name}.pt"
    torch.save(trained_model.state_dict(), model_path)

    print(f"Saved model for {class_name} ‚Üí {model_path}")
    print(f"<<<< CLASSIFIER {class_name} TRAINING COMPLETE >>>>")

print("\n=======================================================")
print(f"== ALL {n_classifiers} OVR CLASSIFIERS TRAINED ==")
print("=======================================================")



== STARTING TRAINING FOR 4 ONE-VS-REST CLASSIFIERS ==

>>>> TRAINING CLASSIFIER 1/4: HER2(+) vs REST <<<<
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 243MB/s]


Starting experiment: ovr_HER2(pos)_resnet18
Epoch 001 | Train Loss 0.7120 | Val Loss 0.7068 | Acc 0.5698 | F1 0.3451
Epoch 002 | Train Loss 0.6737 | Val Loss 0.8331 | Acc 0.4555 | F1 0.4049
Epoch 003 | Train Loss 0.6698 | Val Loss 0.8391 | Acc 0.3957 | F1 0.3889
Epoch 004 | Train Loss 0.6597 | Val Loss 0.6809 | Acc 0.5962 | F1 0.3179
Epoch 005 | Train Loss 0.6578 | Val Loss 0.7464 | Acc 0.5111 | F1 0.3620
Epoch 006 | Train Loss 0.6468 | Val Loss 0.6836 | Acc 0.5860 | F1 0.3103
Epoch 007 | Train Loss 0.6406 | Val Loss 0.7339 | Acc 0.5071 | F1 0.3301
Epoch 008 | Train Loss 0.6484 | Val Loss 0.6948 | Acc 0.5749 | F1 0.2977
Epoch 009 | Train Loss 0.6451 | Val Loss 0.6941 | Acc 0.5536 | F1 0.3055
Epoch 010 | Train Loss 0.6368 | Val Loss 0.7434 | Acc 0.5091 | F1 0.3383
Epoch 011 | Train Loss 0.6433 | Val Loss 0.6817 | Acc 0.5769 | F1 0.3301
Epoch 012 | Train Loss 0.6399 | Val Loss 0.7676 | Acc 0.4889 | F1 0.4229
Epoch 013 | Train Loss 0.6361 | Val Loss 0.7353 | Acc 0.5061 | F1 0.3838
Epoch 0

In [None]:

""" XGboost new """
# ---------------------------------------------------------
# XGBoost WSI-level split
# ---------------------------------------------------------
# Keep only the slides reserved for the XGBoost stage.
xgb_slide_mask = slide_df["sample_index"].isin(xgb_slide_ids)
xgb_slide_df = slide_df[xgb_slide_mask]

print(f"XGBoost slides: {len(xgb_slide_df)}")

# ---------------------------------------------------------
# Extract tiles for the XGBoost WSIs
# ---------------------------------------------------------
tiles_xgb_df = extract_tiles_in_memory(
    img_dir=TRAIN_IMG_DIR,
    slide_df=xgb_slide_df,
    tile_size=224,
)

print(f"Total XGBoost tiles extracted: {len(tiles_xgb_df)}")

# Guard against leakage: no slide may feed both the XGBoost features and the
# CNN validation tiles (only the validation tiles are checked here).
assert set(tiles_xgb_df["sample_index"]).isdisjoint(
    tiles_val_df["sample_index"]
), "‚ùå DATA LEAKAGE: XGB tiles overlap CNN validation tiles"


XGBoost slides: 88
Total tiles extracted: 988
Total XGBoost tiles extracted: 988


In [None]:

import torch.nn.functional as F

def infer_tiles_ovr(model, tiles_df, transform, device):
    """
    Score every tile with one OvR classifier, grouped by slide.

    Args:
        model: classifier returning two logits per input; it is switched to eval mode.
        tiles_df: DataFrame with columns [sample_index, tile_img].
        transform: callable turning one tile image into a tensor (no batch dimension).
        device: torch.device used for inference; autocast is enabled only on CUDA.

    Returns:
        dict {sample_index: np.array of positive-class (index 1) probabilities},
        one entry per tile of that slide, in the order the tiles appear in the group.
    """
    model.eval()
    use_amp = device.type == 'cuda'

    def _positive_probability(img):
        # A single tile is wrapped into a batch of one.
        batch = transform(img).unsqueeze(0).to(device)
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            class_probs = F.softmax(model(batch), dim=1)
        return class_probs[0, 1].item()  # prob positive

    with torch.no_grad():
        return {
            slide_id: np.array(
                [_positive_probability(img) for img in slide_tiles["tile_img"]]
            )
            for slide_id, slide_tiles in tiles_df.groupby("sample_index")
        }

from sklearn.metrics import precision_recall_curve

def compute_pr_threshold(y_true, y_scores):
    """
    Return the score threshold that maximizes F1 along the precision-recall curve.

    Args:
        y_true: binary ground-truth labels (1 = positive).
        y_scores: predicted scores for the positive class, same length as y_true.

    Returns:
        The element of the curve's `thresholds` array with the highest F1.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

    # precision/recall carry one trailing point that has no associated threshold,
    # so they are one element longer than `thresholds`. Drop it so the argmax
    # can never index past the end of `thresholds`.
    precision = precision[:-1]
    recall = recall[:-1]

    # The small epsilon avoids a 0/0 when precision and recall are both zero.
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    best_idx = np.argmax(f1)
    return thresholds[best_idx]


In [None]:

# ---------------------------------------------------------
# Compute the decision threshold tau_c of every OvR classifier
# ---------------------------------------------------------
# Thresholds are fitted at tile level on the CNN validation tiles:
# each tile inherits the label of the slide it was cut from.

ovr_thresholds = {}

for cls in class_names:
    print(f"\nComputing PR threshold for class {cls}")

    model = OVR_TRAINED_MODELS[cls]

    # {slide_id: array of positive-class probabilities, one per tile}
    slide_probs = infer_tiles_ovr(model, tiles_val_df, val_transforms, device)

    y_true, y_score = [], []

    for slide_id, probs in slide_probs.items():
        slide_mask = slide_df["sample_index"] == slide_id
        label = slide_df.loc[slide_mask, "label"].values[0]
        # One binary target per tile: 1 when the slide belongs to `cls`, else 0.
        y_true += [int(label == cls)] * len(probs)
        y_score += probs.tolist()

    tau = compute_pr_threshold(np.array(y_true), np.array(y_score))
    ovr_thresholds[cls] = tau

    print(f"  œÑ_{cls} = {tau:.4f}")

def build_xgb_features(slide_ids, tiles_df, ovr_models, ovr_thresholds, transform, device):
    """
    Build slide-level features for XGBoost from the tile-level OvR predictions.

    For every slide and every class in the global `class_names`, two features are
    emitted: the number of tiles whose positive-class probability is >= the class
    threshold, and the number of tiles below it.

    Args:
        slide_ids: iterable of slide identifiers (values of `sample_index`).
        tiles_df: DataFrame with columns [sample_index, tile_img].
        ovr_models: dict {class_name: model}; every model is switched to eval mode.
        ovr_thresholds: dict {class_name: threshold}.
        transform: callable turning one tile image into a tensor (no batch dimension).
        device: torch.device used for inference; autocast is enabled only on CUDA.

    Returns:
      X: np.ndarray [n_slides, 2 * len(class_names)]  (8 columns for four classes)
      y: np.ndarray [n_slides], labels looked up in the global `slide_df`
    """
    use_amp = device.type == 'cuda'

    # Tiles are scored one at a time, so the models MUST be in eval mode:
    # a train-mode BatchNorm1d rejects a batch of one and Dropout would make
    # the features random. Do not rely on the caller having done this.
    for cls in class_names:
        ovr_models[cls].eval()

    X, y = [], []

    with torch.no_grad():
        for slide_id in slide_ids:
            slide_tiles = tiles_df[tiles_df["sample_index"] == slide_id]

            # Preprocess every tile once and reuse it for all classifiers
            # (assumes `transform` is deterministic, as validation transforms are).
            inputs = [
                transform(img).unsqueeze(0).to(device)
                for img in slide_tiles["tile_img"]
            ]

            features = []

            for cls in class_names:
                model = ovr_models[cls]
                tau   = ovr_thresholds[cls]

                probs = []
                for x in inputs:
                    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                        p = F.softmax(model(x), dim=1)[0, 1].item()
                    probs.append(p)

                probs = np.array(probs)
                features.extend([
                    np.sum(probs >= tau),
                    np.sum(probs <  tau)
                ])

            X.append(features)
            label = slide_df.loc[slide_df["sample_index"] == slide_id, "label"].values[0]
            y.append(label)

    return np.array(X), np.array(y)



Computing PR threshold for class HER2(+)
  œÑ_HER2(+) = 0.1669

Computing PR threshold for class Luminal A
  œÑ_Luminal A = 0.4395

Computing PR threshold for class Luminal B
  œÑ_Luminal B = 0.1357

Computing PR threshold for class Triple negative
  œÑ_Triple negative = 0.3996


In [None]:
from xgboost import XGBClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
import xgboost as xgb

# Encode labels (LabelEncoder maps the class strings to 0..n_classes-1)
le_xgb = LabelEncoder()
y_encoded = le_xgb.fit_transform(xgb_slide_df["label"])

# Feature rows follow the order of xgb_slide_df, i.e. the same order as y_encoded.
X_xgb, y_xgb = build_xgb_features(
    slide_ids=xgb_slide_df["sample_index"].values,
    tiles_df=tiles_xgb_df,   # tiles of the XGBoost-only slides (asserted disjoint from the CNN validation tiles)
    ovr_models=OVR_TRAINED_MODELS,
    ovr_thresholds=ovr_thresholds,
    transform=val_transforms,
    device=device
)

# Train / Val split for XGB
X_tr, X_va, y_tr, y_va = train_test_split(
    X_xgb, y_encoded,
    test_size=0.2,
    random_state=SEED,
    stratify=y_encoded
)

# ---------------------------------------------------------
# XGBoost model
# ---------------------------------------------------------
dtrain = xgb.DMatrix(X_tr, label=y_tr)
dval   = xgb.DMatrix(X_va, label=y_va)

params = {
    "objective": "multi:softprob",
    "num_class": len(class_names),
    "max_depth": 2,
    "eta": 0.03,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "eval_metric": "mlogloss",
    "seed": SEED
}

# The last entry of `evals` (val) is the one monitored by early stopping.
evals = [(dtrain, "train"), (dval, "val")]

xgb_model = xgb.train(
    params=params,
    dtrain=dtrain,
    num_boost_round=200,
    evals=evals,
    early_stopping_rounds=20,
    verbose_eval=True
)

# xgb.train returns the booster from the LAST boosting round, not the best one.
# Restrict prediction to the trees up to the best early-stopping iteration.
best_iteration = getattr(xgb_model, "best_iteration", None)
if best_iteration is not None:
    y_pred_proba = xgb_model.predict(dval, iteration_range=(0, best_iteration + 1))
else:
    y_pred_proba = xgb_model.predict(dval)
y_pred = y_pred_proba.argmax(axis=1)

# NOTE: dval also drove early stopping, so these metrics are optimistic.
print(classification_report(y_va, y_pred, target_names=le_xgb.classes_))
print(confusion_matrix(y_va, y_pred))


[0]	train-mlogloss:1.33861	val-mlogloss:1.35753
[1]	train-mlogloss:1.32809	val-mlogloss:1.35950
[2]	train-mlogloss:1.31843	val-mlogloss:1.35974
[3]	train-mlogloss:1.30858	val-mlogloss:1.36079
[4]	train-mlogloss:1.29988	val-mlogloss:1.35264
[5]	train-mlogloss:1.28942	val-mlogloss:1.35158
[6]	train-mlogloss:1.28034	val-mlogloss:1.35264
[7]	train-mlogloss:1.27118	val-mlogloss:1.35341
[8]	train-mlogloss:1.26036	val-mlogloss:1.35072
[9]	train-mlogloss:1.25337	val-mlogloss:1.35040
[10]	train-mlogloss:1.24566	val-mlogloss:1.34817
[11]	train-mlogloss:1.23740	val-mlogloss:1.34946
[12]	train-mlogloss:1.23046	val-mlogloss:1.35196
[13]	train-mlogloss:1.22159	val-mlogloss:1.34532
[14]	train-mlogloss:1.21443	val-mlogloss:1.34430
[15]	train-mlogloss:1.20616	val-mlogloss:1.34476
[16]	train-mlogloss:1.19815	val-mlogloss:1.34282
[17]	train-mlogloss:1.19106	val-mlogloss:1.34186
[18]	train-mlogloss:1.18394	val-mlogloss:1.34131
[19]	train-mlogloss:1.17619	val-mlogloss:1.33980
[20]	train-mlogloss:1.16973	va

## **Inference**

In [None]:
# Test slides are the entries of test_img_dir whose name starts with "img_",
# taken in sorted order so the submission rows are deterministic.
test_slide_ids = sorted(
    filter(lambda fname: fname.startswith("img_"), os.listdir(test_img_dir))
)

print(f"Test slides: {len(test_slide_ids)}")

# The tile extractor is handed a `label` column; the test set has none,
# so a constant placeholder is broadcast to every row.
test_slide_columns = {
    "sample_index": test_slide_ids,
    "label": "unknown",
}
test_slide_df = pd.DataFrame(test_slide_columns)

tiles_test_df = extract_tiles_in_memory(
    img_dir=test_img_dir,
    slide_df=test_slide_df,
    tile_size=224,
)

print(f"Total test tiles extracted: {len(tiles_test_df)}")

Test slides: 477
Total tiles extracted: 5533
Total test tiles extracted: 5533


In [None]:
def infer_tiles_ovr_test(model, tiles_df, transform, device):
    """
    Score every test tile with one OvR classifier, grouped by slide.

    Same contract as `infer_tiles_ovr`: the model is put in eval mode, tiles are
    scored one at a time under no_grad (autocast only on CUDA).

    Args:
        model: classifier returning two logits per input.
        tiles_df: DataFrame with columns [sample_index, tile_img].
        transform: callable turning one tile image into a tensor (no batch dimension).
        device: torch.device used for inference.

    Returns:
        dict {sample_index: np.array of positive-class (index 1) probabilities}.
    """
    model.eval()
    amp_enabled = device.type == 'cuda'
    per_slide = {}

    with torch.no_grad():
        for slide_id, slide_tiles in tiles_df.groupby("sample_index"):
            tile_scores = []

            for tile in slide_tiles["tile_img"]:
                batch = transform(tile).unsqueeze(0).to(device)

                with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
                    class_probs = torch.softmax(model(batch), dim=1)

                # column 1 holds the positive-class probability
                tile_scores.append(class_probs[0, 1].item())

            per_slide[slide_id] = np.array(tile_scores)

    return per_slide


In [None]:
def build_xgb_features_test(
    slide_ids, tiles_df,
    ovr_models, ovr_thresholds,
    transform, device
):
    """
    Build the slide-level XGBoost feature matrix for unlabeled (test) slides.

    For every slide and every class in the global `class_names`, two features are
    emitted: the number of tiles whose positive-class probability is >= the class
    threshold, and the number of tiles below it (the same layout that
    `build_xgb_features` produces for training).

    Args:
        slide_ids: iterable of slide identifiers (values of `sample_index`).
        tiles_df: DataFrame with columns [sample_index, tile_img].
        ovr_models: dict {class_name: model}; every model is switched to eval mode.
        ovr_thresholds: dict {class_name: threshold}.
        transform: callable turning one tile image into a tensor (no batch dimension).
        device: torch.device used for inference; autocast is enabled only on CUDA.

    Returns:
        np.ndarray [n_slides, 2 * len(class_names)], rows in the order of `slide_ids`.
    """
    use_amp = device.type == 'cuda'

    # Tiles are scored one at a time, so the models MUST be in eval mode:
    # a train-mode BatchNorm1d rejects a batch of one and Dropout would make
    # the features random. Do not rely on the caller having done this.
    for cls in class_names:
        ovr_models[cls].eval()

    X = []

    with torch.no_grad():
        for slide_id in slide_ids:
            slide_tiles = tiles_df[tiles_df["sample_index"] == slide_id]

            # Preprocess every tile once and reuse it for all classifiers
            # (assumes `transform` is deterministic).
            inputs = [
                transform(img).unsqueeze(0).to(device)
                for img in slide_tiles["tile_img"]
            ]

            features = []

            for cls in class_names:
                model = ovr_models[cls]
                tau   = ovr_thresholds[cls]

                probs = []
                for x in inputs:
                    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                        p = torch.softmax(model(x), dim=1)[0, 1].item()
                    probs.append(p)

                probs = np.array(probs)

                # SAME FEATURES AS TRAINING
                features.extend([
                    np.sum(probs >= tau),
                    np.sum(probs <  tau)
                ])

            X.append(features)

    return np.array(X)


In [None]:

# Deterministic test-time preprocessing: tensor conversion followed by
# normalization with IMAGENET_MEAN / IMAGENET_STD.
# NOTE(review): the XGBoost training features were built with val_transforms;
# confirm both pipelines are identical, otherwise train/test features drift.
to_tensor = transforms.ToTensor()
imagenet_normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
test_transforms = transforms.Compose([to_tensor, imagenet_normalize])

# One feature row per test slide, in the order of test_slide_ids.
xgb_test_inputs = dict(
    slide_ids=test_slide_ids,
    tiles_df=tiles_test_df,
    ovr_models=OVR_TRAINED_MODELS,
    ovr_thresholds=ovr_thresholds,
    transform=test_transforms,
    device=device,
)
X_test = build_xgb_features_test(**xgb_test_inputs)


In [None]:
import xgboost as xgb

dtest = xgb.DMatrix(X_test)

# The booster returned by xgb.train holds every boosting round, including the
# ones trained after the best validation score. Predict with the trees up to
# the best early-stopping iteration when that information is available.
best_iteration = getattr(xgb_model, "best_iteration", None)
if best_iteration is not None:
    y_test_proba = xgb_model.predict(dtest, iteration_range=(0, best_iteration + 1))
else:
    y_test_proba = xgb_model.predict(dtest)
y_test_pred = y_test_proba.argmax(axis=1)

# Map the encoded class indices back to the original label strings.
y_test_labels = le_xgb.inverse_transform(y_test_pred)

# Rows follow test_slide_ids, the same order used to build X_test.
submission_df = pd.DataFrame({
    "sample_index": test_slide_ids,
    "predicted_label": y_test_labels
})

submission_df.to_csv("oneVSrest.csv", index=False)
print(submission_df.head())

# Colab-only: triggers a browser download of the submission file.
from google.colab import files
files.download("oneVSrest.csv")



   sample_index predicted_label
0  img_0000.png       Luminal A
1  img_0001.png         HER2(+)
2  img_0002.png       Luminal B
3  img_0003.png         HER2(+)
4  img_0004.png       Luminal B
