## **I. Google Colab Initialization (Only on Colab)**

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# cur_dir = "/content/drive/Othercomputers/My laptop/POLIMI/AN2DL/AN2DL_CH_2/Notebooks"
# %cd $cur_dir


In [None]:
#%pip install torchview

## **1. Import Libraries**

In [None]:
# Set seed for reproducibility
SEED = 42

# Import necessary libraries
import os

# Set environment variables before importing modules
# NOTE(review): PYTHONHASHSEED is read at interpreter startup, so setting it here only
# affects subprocesses launched later, not string hashing in this running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Keep matplotlib's config/cache directory inside the working directory
os.environ['MPLCONFIGDIR'] = os.getcwd() + '/configs/'
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous; it is a
# debugging aid that slows GPU training - confirm it is meant to stay enabled.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Suppress warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# Warning is the base class of every warning category, so this line silences ALL
# warnings (which also makes the FutureWarning filter above redundant).
warnings.simplefilter(action='ignore', category=Warning)

# Import necessary modules
import logging
import random
import numpy as np

# Set seeds for random number generators in NumPy and Python
np.random.seed(SEED)
random.seed(SEED)

# Import PyTorch
import torch
torch.manual_seed(SEED)
from torch import nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.transforms import v2 as transforms
from torch.utils.data import TensorDataset, DataLoader
from torchview import draw_graph
from scipy import ndimage
from PIL import Image
import timm
# TensorBoard setup and output directories.
# The lines starting with "!" / "%" are IPython shell escapes / magics: this cell only
# runs inside Jupyter/Colab, not as a plain Python script.
logs_dir = "tensorboard"
!pkill -f tensorboard
%load_ext tensorboard
!mkdir -p models

# Use the GPU when available and seed all CUDA devices too
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(SEED)
    # NOTE(review): cudnn.benchmark lets cuDNN pick algorithms by timing them, which can
    # make results non-deterministic despite the seeds set above.
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Import other libraries
import cv2
import copy
import shutil
from itertools import product
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from PIL import Image  # NOTE: duplicate of the PIL import above (harmless)
import matplotlib.gridspec as gridspec
import requests
from io import BytesIO
from tqdm import tqdm
import glob
from pathlib import Path
import shutil  # NOTE: duplicate of the shutil import above (harmless)
import gc
import torchvision.transforms as T  # classic (v1) transforms API, used as T.Compose / T.Resize / ...
import csv

# Configure plot display settings
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

## **2. Preprocessing**

- Preprocessing pipeline:
    - Load the images
    - Create goo masks
    - Apply goo removal + resizing
    - Discard Shrek images
    - Apply the external masks (optional)
    

### 2.1 Preprocessing Functions

#### 2.1.1 _get_smart_goo_mask

In [None]:
def _get_smart_goo_mask(img_bgr):
    """
    Detect green "goo" with Core & Shell logic plus a 1-pixel nudge.

    A strict HSV range ("core") finds solid green and a looser range ("shell")
    finds the faint halo around it. Shell blobs (8-connected) are kept only if
    they contain at least one core pixel. Kept blobs whose outer contour area
    is <= 200 px are discarded as noise, holes inside the remaining blobs are
    filled, and the result is dilated by one pixel (3x3 kernel).

    Args:
        img_bgr: 8-bit BGR image (as returned by ``cv2.imread``).

    Returns:
        Single-channel uint8 mask with the same height/width as ``img_bgr``:
        255 = goo, 0 = safe.
    """
    # 1. Convert to HSV
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    # 2. Define ranges
    # CORE: solid green (strict)
    core_lower = np.array([35, 100, 50])
    core_upper = np.array([85, 255, 255])
    # SHELL: faint halo (loose / semi-transparent)
    shell_lower = np.array([30, 30, 30])
    shell_upper = np.array([95, 255, 255])

    # 3. Initial masks
    mask_core = cv2.inRange(hsv, core_lower, core_upper)
    mask_shell = cv2.inRange(hsv, shell_lower, shell_upper)

    # 4. Smart combine: keep a shell blob iff it overlaps the core.
    # The labels found under core pixels are exactly the blobs touching the core,
    # so one vectorized selection replaces building a full-image mask per blob
    # (O(num_blobs * H * W) on noisy images).
    _, labels = cv2.connectedComponents(mask_shell, connectivity=8)
    touching = np.unique(labels[mask_core > 0])
    touching = touching[touching != 0]  # label 0 is the background
    smart_mask = np.isin(labels, touching).astype(np.uint8) * 255

    # 5. Fill holes (e.g. shiny reflections inside the goo) and drop blobs whose
    # outer contour area is <= 200 px (stray noise)
    contours, _ = cv2.findContours(smart_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    final_filled_mask = np.zeros_like(smart_mask)
    for contour in contours:
        if cv2.contourArea(contour) > 200:
            cv2.drawContours(final_filled_mask, [contour], -1, 255, thickness=cv2.FILLED)

    # 6. The "1-pixel nudge": grow by one pixel to cover the anti-aliased fringe
    kernel = np.ones((3, 3), np.uint8)
    return cv2.dilate(final_filled_mask, kernel, iterations=1)



#### 2.1.2 remove_goo

In [None]:
def remove_goo(input_dir, output_dir, target_size=(224, 224), remove_goo=True, save_masks=True, replacement_color=(0, 0, 0)):
    """
    Optionally resize every ``img_*`` image of ``input_dir`` and paint over its goo.

    Images whose output already exists in ``output_dir`` are skipped; if their
    goo mask is requested but missing, the image is reprocessed so the mask
    gets written.

    Args:
        input_dir: Folder containing the source ``img_*`` images
            (.jpg/.jpeg/.png/.bmp/.tiff).
        output_dir: Destination folder (created if missing). Goo masks are
            written to ``output_dir/goo_masks``.
        target_size: ``(width, height)`` passed to ``cv2.resize``; ``None``
            keeps the original resolution.
        remove_goo: If True, pixels flagged by ``_get_smart_goo_mask`` are
            replaced with ``replacement_color``.
        save_masks: If True (and ``remove_goo``), save the inverted goo mask
            (255 = safe/tissue, 0 = goo) as ``goo_mask_<id>.png``. Masks are
            always written as PNG so they stay lossless even for JPEG inputs.
        replacement_color: ``(B, G, R)`` tuple painted over goo pixels.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    print(f"Scanning for images in: {input_dir}...")
    # Sorted so the processing order is deterministic
    image_files = sorted(
        f for f in input_dir.iterdir()
        if f.name.startswith('img_') and f.suffix.lower() in valid_extensions
    )

    if not image_files:
        print("No images found starting with 'img_' in the directory.")
        return

    write_masks = remove_goo and save_masks
    mask_dir = output_dir / "goo_masks"
    if write_masks:
        mask_dir.mkdir(parents=True, exist_ok=True)  # created once, not per image

    count = 0
    for file_path in tqdm(image_files, desc="Removing Goo from Images", unit="img"):
        output_path = output_dir / file_path.name
        # Binary masks must not go through lossy JPEG compression -> always PNG
        mask_path = mask_dir / (file_path.stem.replace('img_', 'goo_mask_', 1) + '.png')

        # Skip work already done; only redo an image whose requested mask is missing
        if output_path.exists() and (not write_masks or mask_path.exists()):
            continue

        img = cv2.imread(str(file_path))
        if img is None:
            continue  # unreadable / corrupt file

        if target_size is not None:
            img = cv2.resize(img, target_size)

        if remove_goo:
            goo_mask = _get_smart_goo_mask(img)  # 255 = goo, 0 = safe
            # Paint goo pixels with the replacement colour; all other pixels are untouched
            img[goo_mask > 0] = replacement_color

            if save_masks:
                # Saved inverted: 255 = safe/tissue, 0 = goo
                cv2.imwrite(str(mask_path), cv2.bitwise_not(goo_mask))

        cv2.imwrite(str(output_path), img)
        count += 1

    print(f"Goo removal complete. Processed {count} new images.")

#### 2.1.3 clean_and_save_masks

In [None]:
def clean_and_save_masks(goo_masks_dir, external_masks_dir, output_dir, target_size=(224, 224)):
    """
    Remove goo regions from the external (ROI) masks.

    For every ``goo_mask_<id>.png`` in ``goo_masks_dir`` (255 = safe, 0 = goo)
    the matching ``mask_<id>.png`` is loaded from ``external_masks_dir``
    (255 = ROI). Both are binarised at 127, and their intersection
    (ROI AND not-goo) is written to ``output_dir/mask_<id>.png``. Existing
    outputs are recomputed and overwritten on every run.

    Args:
        goo_masks_dir: Folder with the goo masks.
        external_masks_dir: Folder with the original ``mask_*.png`` files.
        output_dir: Destination folder (created if missing).
        target_size: ``(width, height)`` to resize both masks to (nearest
            neighbour). With ``None`` the external mask keeps its size, and the
            goo mask is resized to match it whenever their shapes differ.
    """
    goo_masks_dir = Path(goo_masks_dir)
    external_masks_dir = Path(external_masks_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing order
    goo_mask_files = sorted(goo_masks_dir.glob('goo_mask_*.png'))
    if not goo_mask_files:
        print(f"No goo masks found in {goo_masks_dir}")
        return

    print(f"Found {len(goo_mask_files)} goo masks. Processing...")

    count = 0
    missing = 0
    for goo_mask_path in tqdm(goo_mask_files, desc="Cleaning External Masks"):
        # goo_mask_xxxx.png -> mask_xxxx.png
        mask_name = goo_mask_path.name.replace('goo_mask_', 'mask_', 1)
        external_mask_path = external_masks_dir / mask_name
        if not external_mask_path.exists():
            missing += 1
            continue

        goo_mask = cv2.imread(str(goo_mask_path), cv2.IMREAD_GRAYSCALE)  # 255 = safe, 0 = goo
        external_mask = cv2.imread(str(external_mask_path), cv2.IMREAD_GRAYSCALE)  # 255 = ROI
        if goo_mask is None or external_mask is None:
            continue  # unreadable file

        if target_size is not None:
            external_mask = cv2.resize(external_mask, target_size, interpolation=cv2.INTER_NEAREST)
        # bitwise_and needs equal shapes; numpy shapes are (h, w), cv2 sizes are (w, h)
        ext_h, ext_w = external_mask.shape[:2]
        if goo_mask.shape[:2] != (ext_h, ext_w):
            goo_mask = cv2.resize(goo_mask, (ext_w, ext_h), interpolation=cv2.INTER_NEAREST)

        # Ensure binary (0 or 255)
        _, external_mask = cv2.threshold(external_mask, 127, 255, cv2.THRESH_BINARY)
        _, goo_mask = cv2.threshold(goo_mask, 127, 255, cv2.THRESH_BINARY)

        # White only where the pixel is in the ROI AND is not goo
        cleaned_mask = cv2.bitwise_and(external_mask, goo_mask)
        cv2.imwrite(str(output_dir / mask_name), cleaned_mask)
        count += 1

    if missing:
        print(f"Skipped {missing} goo masks with no matching external mask in {external_masks_dir}.")
    print(f"Cleaned masks saved to {output_dir}. Processed {count} masks.")

#### 2.1.4 apply_mask

#### 2.1.5 filter_bright_green_areas

In [None]:
def filter_bright_green_areas(image, lg_H=20, lg_S=45, lg_V=0, ug_H=84, ug_S=255, ug_V=255, dilate_iterations=2):
    """
    Detect bright-green areas in an image and black them out.

    Args:
        image: RGB image, either float in [0, 1] or uint8 in [0, 255].
        lg_H, lg_S, lg_V: Lower HSV bounds of the main green range
            (OpenCV 8-bit HSV scale).
        ug_H, ug_S, ug_V: Upper HSV bounds of the main green range.
        dilate_iterations: Number of extra dilations (5x5 ellipse) applied to the
            main mask to catch edge artifacts; 0 disables them.

    Returns:
        (result_bgr, combined_mask): the uint8 BGR image with every detected pixel
        set to 0, and the uint8 detection mask (255 = green).
    """
    if image.dtype == np.uint8:
        rgb_u8 = image
    else:
        # Round instead of truncating: a value v/255*255 can come back slightly
        # below v in floating point and would otherwise drop to v-1.
        rgb_u8 = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
    # RGB -> BGR for OpenCV; made contiguous because the reversed view has a negative stride
    original_bgr = np.ascontiguousarray(rgb_u8[..., ::-1])

    # 1. Convert to HSV
    hsv = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2HSV)

    # 2. Main "bright green" mask
    lower_green = (lg_H, lg_S, lg_V)
    upper_green = (ug_H, ug_S, ug_V)
    mask = cv2.inRange(hsv, lower_green, upper_green)

    # 3. Morphological clean-up
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    # OPEN removes small speckles
    clean_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
    # DILATE grows the mask to cover boundary/residual pixels
    if dilate_iterations > 0:
        clean_mask = cv2.dilate(clean_mask, kernel, iterations=dilate_iterations)

    # 4. Wider "subtle green" range (H and S bounds widened by 10, any V).
    # NOTE: this mask is only opened; it is NOT restricted to the neighbourhood of
    # the main mask, so it adds subtle-green regions anywhere in the image.
    lower_green_subtle = (max(0, lg_H - 10), max(0, lg_S - 10), 0)
    upper_green_subtle = (min(180, ug_H + 10), 255, 255)
    subtle_mask = cv2.inRange(hsv, lower_green_subtle, upper_green_subtle)
    subtle_mask = cv2.morphologyEx(subtle_mask, cv2.MORPH_OPEN, kernel, iterations=1)

    combined_mask = cv2.bitwise_or(clean_mask, subtle_mask)

    # 5-6. Keep only the pixels outside the combined mask
    mask_inv = cv2.bitwise_not(combined_mask)
    result_bgr = cv2.bitwise_and(original_bgr, original_bgr, mask=mask_inv)

    return result_bgr, combined_mask

#### 2.1.6 analyze_dataset_for_shreks

In [None]:
def analyze_dataset_for_shreks(directory, ratio_threshold=0.0125):
    """
    Split the ``img_*.png`` images of ``directory`` into "Shrek" and tissue images.

    An image is classified as Shrek when the fraction of its pixels flagged by
    ``filter_bright_green_areas`` is strictly greater than ``ratio_threshold``.

    Args:
        directory: Folder to scan (non-recursive).
        ratio_threshold: Green-pixel fraction above which an image is a Shrek.

    Returns:
        (shrek_images, tissue_images): lists of dicts with keys ``name`` (file
        name), ``path``, ``img`` (RGB uint8 array), ``ratio`` (green fraction)
        and ``mask`` (uint8 green mask). Unreadable files are skipped; files
        that raise during analysis are reported and skipped.
        NOTE: every entry keeps the full image and mask in memory.
    """
    shrek_images = []
    tissue_images = []

    # Sorted for a deterministic order (glob order depends on the filesystem)
    image_files = sorted(glob.glob(os.path.join(directory, 'img_*.png')))
    print(f"Found {len(image_files)} images in {directory}")

    for f in tqdm(image_files, desc="Analyzing for Shreks"):
        try:
            img = cv2.imread(f)  # BGR uint8, or None if unreadable
            if img is None:
                continue

            # Convert once: reused as the filter input (float in [0, 1]) and for plotting
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            _, mask = filter_bright_green_areas(img_rgb.astype(np.float32) / 255.0)

            # Fraction of pixels flagged as green
            ratio = np.count_nonzero(mask) / (img.shape[0] * img.shape[1])

            entry = {
                'name': os.path.basename(f),
                'path': f,
                'img': img_rgb,
                'ratio': ratio,
                'mask': mask,
            }

            # === CLASSIFICATION LOGIC ===
            if ratio > ratio_threshold:
                shrek_images.append(entry)
            else:
                tissue_images.append(entry)

        except Exception as e:
            # Best effort: one bad file must not abort the whole scan
            print(f"Skipping {f}: {e}")

    return shrek_images, tissue_images

#### 2.1.8 process_classification_results

In [None]:
def _copy_classified(items, dest_dir, desc, folder_label):
    """Copy each item's ``path`` into ``dest_dir`` as ``name``; existing files are never overwritten."""
    for item in tqdm(items, desc=desc):
        dest_path = os.path.join(dest_dir, item['name'])

        # Check if file exists to prevent overwriting
        if os.path.exists(dest_path):
            continue

        try:
            shutil.copy2(item['path'], dest_path)
        except Exception as e:
            # Best effort: report and continue with the remaining files
            print(f"Error copying {item['name']} to {folder_label} folder: {e}")


def _plot_classification_results(shrek_list, tissue_list, threshold):
    """Show a 2x2 grid of example images and a scatter plot of green ratios vs. the threshold."""
    # Example grid: two Shrek images (top row), two tissue images (bottom row)
    if len(shrek_list) >= 2 and len(tissue_list) >= 2:
        fig_ex, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig_ex.suptitle(f"Classification Results (Threshold: {threshold:.1%})", fontsize=16)

        def show_img(ax, item, label):
            ax.imshow(item['img'])
            # The green ratio in the title shows WHY the image was classified this way
            ax.set_title(f"{label}\n{item['name']}\nGreen Pixels: {item['ratio']:.2%}")
            ax.axis('off')

        show_img(axes[0, 0], shrek_list[0], "Detected Shrek")
        show_img(axes[0, 1], shrek_list[1], "Detected Shrek")
        show_img(axes[1, 0], tissue_list[0], "Detected Tissue")
        show_img(axes[1, 1], tissue_list[1], "Detected Tissue")

        plt.tight_layout()
        plt.show()
    else:
        print("Not enough images in one or both classes to generate 2x2 sample grid.")

    # Scatter of green ratios for threshold tuning: tissue points first,
    # Shrek points shifted on the x-axis so they appear after them
    shrek_ratios = [x['ratio'] for x in shrek_list]
    tissue_ratios = [x['ratio'] for x in tissue_list]
    n_tissue = len(tissue_ratios)

    plt.figure(figsize=(12, 6))
    plt.scatter(range(n_tissue), tissue_ratios, color='blue', alpha=0.6, label='Classified as Tissue')
    plt.scatter(range(n_tissue, n_tissue + len(shrek_ratios)), shrek_ratios, color='green', alpha=0.6, label='Classified as Shrek')
    plt.axhline(y=threshold, color='red', linestyle='--', linewidth=2, label=f'Threshold ({threshold:.1%})')

    plt.title('Green Pixel Ratio per Image', fontsize=14)
    plt.ylabel('Ratio of Green Pixels (0.0 - 1.0)')
    plt.xlabel('Image Index')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


def process_classification_results(shrek_list, tissue_list, shrek_dir, tissue_dir, threshold, visualize=True):
    """
    Copy classified images to their folders and optionally visualize the result.

    Args:
        shrek_list (list): Dicts (``name``, ``path``, ``img``, ``ratio``) of Shrek images.
        tissue_list (list): Dicts with the same keys for tissue images.
        shrek_dir (str): Destination folder for Shrek images (created if missing).
        tissue_dir (str): Destination folder for tissue images (created if missing).
        threshold (float): Green-ratio threshold used for the classification
            (only used for the plots).
        visualize (bool): If True, show example images and a ratio scatter plot.

    Existing files in the destination folders are never overwritten.
    """
    # 1. Summary
    print(f"Classified {len(shrek_list)} as Shrek")
    print(f"Classified {len(tissue_list)} as Tissue")

    os.makedirs(shrek_dir, exist_ok=True)
    os.makedirs(tissue_dir, exist_ok=True)

    # 2. Save Shrek images
    print(f"Saving {len(shrek_list)} Shrek images to {shrek_dir}...")
    _copy_classified(shrek_list, shrek_dir, "Saving Shrek Images", "shrek")

    # 3. Save tissue images
    print(f"Saving {len(tissue_list)} Tissue images to {tissue_dir}...")
    _copy_classified(tissue_list, tissue_dir, "Saving Tissue Images", "tissue")

    # 4-5. Visualization
    if visualize:
        _plot_classification_results(shrek_list, tissue_list, threshold)
    else:
        print("Shrek removal visulization OFF.")

# --- Example Usage ---
# shrek_list, tissue_list = analyze_dataset_for_shreks(DATASET_PATH) # Assuming this runs before
# process_classification_results(shrek_list, tissue_list, SHREK_DIR, TISSUE_DIR, RATIO_THRESHOLD)

#### 2.1.9 copy_masks

In [None]:
def copy_masks(image_list, masks_dir, output_dir):
    """
    Copy the mask of each image into ``output_dir``.

    Args:
        image_list: Image file names such as ``img_0001.png``; the mask name is
            obtained by replacing the first ``img_`` with ``mask_``.
        masks_dir: Folder containing the masks.
        output_dir: Destination folder (created if missing).

    Masks missing from ``masks_dir`` are skipped and reported at the end,
    instead of one missing file aborting the whole copy.
    """
    os.makedirs(output_dir, exist_ok=True)
    mask_names = [name.replace('img_', 'mask_', 1) for name in image_list]

    missing = []
    for mask_name in tqdm(mask_names, desc="Copying Masks"):
        src_path = os.path.join(masks_dir, mask_name)
        if not os.path.isfile(src_path):
            missing.append(mask_name)
            continue
        shutil.copy(src_path, os.path.join(output_dir, mask_name))

    if missing:
        print(f"Warning: {len(missing)} masks not found in {masks_dir} (e.g. {missing[0]})")

#### 2.1.10 extract_smart_patches

In [None]:
def extract_smart_patches(img_path, mask_path, patch_size=224, stride=224, threshold=0.30,
                          merge_iterations=15, min_tissue_ratio=0.15, white_level=235):
    """
    Extract one patch centred on each cluster of positive mask pixels.

    Positive pixels (mask > 128) are merged into regions by binary dilation.
    A patch is centred on the centre of each region's bounding box, clamped to
    the image. A patch is kept only if its fraction of positive mask pixels is
    >= ``threshold`` and its tissue fraction (pixels whose mean RGB value is
    below ``white_level``) is > ``min_tissue_ratio``.

    Args:
        img_path: Image path; if it does not exist and ends in ``.png``, the
            same path ending in ``.jpg`` is tried.
        mask_path: Mask path (loaded as grayscale; a 0/1 mask is scaled to 0/255).
        patch_size: int or ``(height, width)``.
        stride: int or ``(height, width)``. Kept for API compatibility but
            currently NOT used: exactly one patch is placed per merged region.
        threshold: Minimum fraction of positive mask pixels in a kept patch.
        merge_iterations: Dilation iterations used to merge nearby spots. Each
            spot grows by this many pixels, so gaps of up to roughly twice this
            value are bridged. Values < 1 disable merging.
        min_tissue_ratio: Tissue fraction a kept patch must exceed.
        white_level: Mean-RGB level at or above which a pixel counts as background.

    Returns:
        (patches, coords, img_arr, mask_arr):
        - patches: RGB uint8 patch arrays.
        - coords: their ``(x, y)`` top-left corners, in row-major order.
        - img_arr: the full RGB image.
        - mask_arr: the mask as loaded.
        Patches can be smaller than ``patch_size`` when the image is smaller.
    """
    def _load_rgb(path):
        # Context manager closes the file handle as soon as the pixels are read
        with Image.open(path) as im:
            return np.array(im.convert("RGB"))

    img_path = str(img_path)
    try:
        img_arr = _load_rgb(img_path)
    except FileNotFoundError:
        # Fallback for a common extension swap (.png on disk saved as .jpg)
        if img_path.endswith(".png"):
            img_arr = _load_rgb(img_path[:-len(".png")] + ".jpg")
        else:
            raise

    with Image.open(mask_path) as im:
        mask_arr = np.array(im.convert("L"))

    # Accept 0/1 masks as well as 0/255 masks
    mask_check = mask_arr * 255 if mask_arr.max() <= 1 else mask_arr

    h, w = img_arr.shape[:2]
    ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

    # 1. GROUPING: dilate so nearby small dots merge into larger regions
    # (avoids one patch per pixel-sized dot). scipy repeats the dilation until
    # convergence when iterations < 1, so merging is skipped in that case.
    positive = mask_check > 128
    if merge_iterations >= 1:
        dilated_mask = ndimage.binary_dilation(positive, iterations=merge_iterations)
    else:
        dilated_mask = positive

    # 2. Label the merged regions and get their bounding boxes
    labeled_mask, _ = ndimage.label(dilated_mask)
    objects = ndimage.find_objects(labeled_mask)

    def clamp_start(start, size, limit):
        # Keep the patch inside the image (start at 0 if the image is smaller than the patch)
        return max(0, min(start, limit - size))

    candidate_coords = set()  # identical positions from different regions are deduplicated
    for region in objects:
        if region is None:
            continue
        y_slice, x_slice = region
        # Centre the patch on the centre of the region's bounding box
        center_y = (y_slice.start + y_slice.stop) // 2
        center_x = (x_slice.start + x_slice.stop) // 2
        candidate_coords.add((clamp_start(center_x - pw // 2, pw, w),
                              clamp_start(center_y - ph // 2, ph, h)))

    # 3. Final validation, in row-major order so patch indices are stable
    patches = []
    coords = []
    for x, y in sorted(candidate_coords, key=lambda c: (c[1], c[0])):
        img_patch = img_arr[y:y + ph, x:x + pw]
        mask_patch = mask_check[y:y + ph, x:x + pw]
        # Normalise by the real patch area (smaller than ph*pw if the image is)
        area = img_patch.shape[0] * img_patch.shape[1]

        # Tissue = non-white pixels; tumour = positive mask pixels
        tissue_ratio = np.count_nonzero(np.mean(img_patch, axis=2) < white_level) / area
        mask_ratio = np.count_nonzero(mask_patch > 128) / area

        # Keep the patch only if it has enough tumour AND enough tissue
        if mask_ratio >= threshold and tissue_ratio > min_tissue_ratio:
            patches.append(img_patch)
            coords.append((x, y))

    return patches, coords, img_arr, mask_arr

#### 2.1.11 create_patches_dataset

In [None]:
def create_patches_dataset(input_dir, output_dir, mask_dir, patch_size=224, stride=224, threshold=0.30):
    """
    Extract smart patches from every ``img_*`` image and save them as PNG files.

    For each ``img_<id>.{png,jpg,jpeg}`` in ``input_dir`` the mask
    ``mask_<id>`` with the same extension is looked up in ``mask_dir``, falling
    back to ``mask_<id>.png``. Images without a mask are skipped and counted.
    The patches returned by ``extract_smart_patches`` are written to
    ``output_dir/<image stem>_p<i>.png``; files with the same name are
    overwritten. Patches left over from a previous run with more patches per
    image are not removed.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    mask_dir = Path(mask_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing order
    image_files = sorted(
        f for f in input_dir.iterdir()
        if f.name.startswith('img_') and f.suffix.lower() in {'.png', '.jpg', '.jpeg'}
    )

    count = 0
    missing_masks = 0
    print(f"Starting patch extraction from {input_dir} to {output_dir} using masks from {mask_dir}...")

    for file_path in tqdm(image_files, desc="Extracting Patches"):
        # img_xxxx.ext -> mask_xxxx.ext, falling back to mask_xxxx.png
        mask_path = mask_dir / file_path.name.replace('img_', 'mask_', 1)
        if not mask_path.exists():
            mask_path = mask_dir / (file_path.stem.replace('img_', 'mask_', 1) + ".png")
        if not mask_path.exists():
            missing_masks += 1
            continue

        patches, _, _, _ = extract_smart_patches(
            str(file_path),
            str(mask_path),
            patch_size=patch_size,
            stride=stride,
            threshold=threshold,
        )

        # "<image stem>_p<i>.png": the bag id is everything before the last "_p"
        for i, patch in enumerate(patches):
            Image.fromarray(patch).save(output_dir / f"{file_path.stem}_p{i}.png")
            count += 1

    if missing_masks:
        print(f"Skipped {missing_masks} images with no matching mask in {mask_dir}.")
    print(f"Extraction complete. Saved {count} patches to {output_dir}")

# --- Execute Patch Extraction ---
# Patch extraction runs on TISSUE_OUT, which holds the images kept after goo and Shrek removal.
# The masks used for it are the goo-cleaned masks in GOO_REMOVAL_OUT/cleaned_masks,
# not the raw masks copied into TISSUE_OUT in step 2.

### 2.2 Run Preprocessing

In [None]:
# Root of the challenge dataset, one level above the working directory
datasets_path = os.path.join(os.path.pardir, "an2dl2526c2")

# Raw training images/masks and the per-image labels CSV
train_data_path = os.path.join(datasets_path, "train_data")
train_labels_path = os.path.join(datasets_path, "train_labels.csv")

CSV_PATH = train_labels_path     # labels CSV (first column = image id, second = label)
SOURCE_FOLDER = train_data_path  # original img_* / mask_* files

print(f"Dataset path: {datasets_path}")
print(f"Train data path: {train_data_path}")
print(f"Train labels path: {train_labels_path}")

# Step 1 output: goo-free images (plus goo_masks/ and cleaned_masks/ subfolders)
GOO_REMOVAL_OUT = os.path.join(datasets_path, "train_nogoo")

# Step 2 outputs: images split into Shrek (discarded) and tissue (kept)
SHREK_REMOVAL_OUT = os.path.join(datasets_path, "train_noshreks")
SHREKS_OUT = os.path.join(SHREK_REMOVAL_OUT, "train_shreks")   # discarded Shrek images
TISSUE_OUT = os.path.join(SHREK_REMOVAL_OUT, "train_tissue")   # kept tissue images

# Step 3 outputs
# NOTE(review): FINAL_TRAIN_OUT is not referenced by the pipeline cells in this notebook section
FINAL_TRAIN_OUT = os.path.join(datasets_path, "train_masked_noshreks")
PATCHES_OUT = os.path.join(datasets_path, "train_patches")     # extracted tumour-centred patches

TARGET_SIZE = (224, 224)  # (width, height) for resized images and masks

In [None]:
# Step 1: remove goo at the original resolution (target_size=None -> no resizing).
# Goo pixels are painted grey (195, 195, 195); goo masks go to GOO_REMOVAL_OUT/goo_masks.
remove_goo(
    input_dir=SOURCE_FOLDER,
    output_dir=GOO_REMOVAL_OUT,
    target_size=None,
    remove_goo=True,
    save_masks=True,
    replacement_color=(195, 195, 195),
)

In [None]:
# Intersect the original external masks with the Step-1 goo masks so that goo
# areas are removed from the ROI; masks keep their size (target_size=None).
clean_and_save_masks(
    os.path.join(GOO_REMOVAL_OUT, "goo_masks"),
    SOURCE_FOLDER,
    os.path.join(GOO_REMOVAL_OUT, "cleaned_masks"),
    None,
)

In [None]:
# Step 2: discard Shrek images (green-pixel ratio above the threshold) and keep tissue images.
# A single constant keeps the classification and the reported/plotted threshold in sync.
RATIO_THRESHOLD = 0.0125

shreks_list, tissue_list = analyze_dataset_for_shreks(GOO_REMOVAL_OUT, ratio_threshold=RATIO_THRESHOLD)

process_classification_results(shreks_list, tissue_list, SHREKS_OUT, TISSUE_OUT, RATIO_THRESHOLD, visualize=False)

# Copy the original (un-cleaned) external masks next to the kept tissue images
tissue_image_names = [item['name'] for item in tissue_list]
copy_masks(tissue_image_names, SOURCE_FOLDER, TISSUE_OUT)

# Free the in-memory images/masks held by the analysis lists
del shreks_list
del tissue_list
gc.collect()

In [None]:
# Step 3: extract tumour-centred patches from the kept tissue images using the
# goo-cleaned masks; threshold=0.01 keeps patches with >= 1% positive mask pixels.
CLEANED_MASKS_DIR = os.path.join(GOO_REMOVAL_OUT, "cleaned_masks")
create_patches_dataset(
    input_dir=TISSUE_OUT,
    output_dir=PATCHES_OUT,
    mask_dir=CLEANED_MASKS_DIR,
    patch_size=224,
    stride=224,
    threshold=0.01,
)

## **5. Train/Val Split**

## experiment

In [None]:
def create_metadata_dataframe(patches_dir, labels_csv_path):
    """
    Build a DataFrame mapping every patch file to its bag (source image) and label.

    Patch files are expected to be named ``<bag_id>_p<index>.png`` (as written
    by ``create_patches_dataset``). The labels CSV is read with its first
    column as the bag ID (any file extension is stripped) and its second
    column as the label; other CSV columns are ignored. The following are dropped:
    - patches whose bag ID is not in the CSV;
    - ``.png`` files that do not follow the naming scheme (with a message).

    Args:
        patches_dir: Folder containing the patch PNGs.
        labels_csv_path: Path to the labels CSV.

    Returns:
        DataFrame with columns ``filename``, ``label``, ``sample_id`` and
        ``path`` (one row per patch; empty, but with these columns, if nothing matches).
    """
    # 1. Labels: keep only (id, label) under fixed names so no CSV column name can
    # collide with the patch columns during the merge
    df_labels = pd.read_csv(labels_csv_path).iloc[:, :2].copy()
    df_labels.columns = ['sample_id', 'label']
    # IDs as extension-less strings, to match the parsed bag IDs ('img_001.png' -> 'img_001')
    df_labels['sample_id'] = df_labels['sample_id'].astype(str).map(lambda x: os.path.splitext(x)[0])

    # 2. Patch files, sorted for a deterministic row order
    patch_files = sorted(f for f in os.listdir(patches_dir) if f.endswith('.png'))

    # 3. Parse bag IDs: "img_0015_p12.png" -> bag "img_0015", patch index "12"
    data = []
    print(f"Found {len(patch_files)} patches. Parsing metadata...")
    for filename in patch_files:
        bag_id, sep, index = os.path.splitext(filename)[0].rpartition('_p')
        if not sep or not bag_id or not index.isdigit():
            print(f"Skipping malformed filename: {filename}")
            continue
        data.append({
            'filename': filename,
            'sample_id': bag_id,
            'path': os.path.join(patches_dir, filename),
        })

    # Explicit columns keep the merge valid even when no patch was found
    df_patches = pd.DataFrame(data, columns=['filename', 'sample_id', 'path'])

    unmatched = int((~df_patches['sample_id'].isin(df_labels['sample_id'])).sum())
    if unmatched:
        print(f"Dropping {unmatched} patches whose bag ID is not in the labels CSV.")

    # 4. Inner join: every patch inherits the label of its bag
    df = pd.merge(df_patches, df_labels, on='sample_id', how='inner')

    # 5. Fixed column order
    df = df[['filename', 'label', 'sample_id', 'path']]

    print(f"Successfully created DataFrame with {len(df)} rows.")
    return df

In [None]:
patches_metadata_df = create_metadata_dataframe(PATCHES_OUT, CSV_PATH)

# Sanity checks: a short preview without the long 'path' column...
print("\nFirst 5 rows:")
print(patches_metadata_df.head(5).drop(columns=['path']))
# ...and summary statistics of the number of patches per bag
print("\nPatches per Bag (Distribution):")
print(patches_metadata_df['sample_id'].value_counts().describe())

In [None]:
# Add Label Encoding
print("\n" + "="*50)
print("Label Encoding")
print("="*50)

# Map the string labels to integer codes; a label's code is its index in classes_
label_encoder = LabelEncoder()
patches_metadata_df['label_encoded'] = label_encoder.fit_transform(patches_metadata_df['label'])

print(f"\nOriginal Labels: {label_encoder.classes_}")
print(f"Encoded as: {list(range(len(label_encoder.classes_)))}")
print("\nLabel Mapping:")
for enc, orig in enumerate(label_encoder.classes_):
    print(f"  {orig} -> {enc}")

In [None]:
def prepare_bag_lists(df):
    """
    Group patch rows into bags for MIL training.

    Args:
        df: DataFrame with at least the ``sample_id``, ``path`` and
            ``label_encoded`` columns (one row per patch).

    Returns:
        (bag_paths_list, bag_labels_list): one entry per bag, in order of first
        appearance in ``df``. Each entry holds the bag's patch paths (in row
        order) and its encoded label, taken from the bag's first row (all
        patches of a bag share the label). Rows with a missing ``sample_id``
        are ignored.
    """
    print(f"Processing {len(df['sample_id'].unique())} bags...")

    bag_paths_list = []
    bag_labels_list = []
    # One groupby pass instead of re-filtering the whole frame for every bag
    # (which is O(num_bags * num_rows)); sort=False keeps first-appearance order.
    for _, group in df.groupby('sample_id', sort=False):
        bag_paths_list.append(group['path'].tolist())
        bag_labels_list.append(group['label_encoded'].iloc[0])

    return bag_paths_list, bag_labels_list

In [None]:
# 1. Split unique bag IDs (NOT patches) so all patches of an image stay in the same split
unique_bag_ids = patches_metadata_df['sample_id'].unique()
# Bag ID -> label (every patch of a bag carries the bag label, so the first row is enough)
bag_id_to_label = patches_metadata_df.groupby('sample_id')['label'].first()
# Labels aligned with unique_bag_ids, used to stratify the split by class
stratify_labels = [bag_id_to_label[bag_id] for bag_id in unique_bag_ids]

# random_state uses the notebook-wide SEED so the split follows the global seed setting
train_ids, val_ids = train_test_split(
    unique_bag_ids, test_size=0.2, random_state=SEED, stratify=stratify_labels
)

# 2. Slice the patch table by bag membership
train_df = patches_metadata_df[patches_metadata_df['sample_id'].isin(train_ids)]
val_df = patches_metadata_df[patches_metadata_df['sample_id'].isin(val_ids)]

print(f"Total Bags: {len(unique_bag_ids)}")
print(f"Train Bags: {len(train_ids)} | Val Bags: {len(val_ids)}")

# 3. Convert DataFrames to per-bag lists for the Dataset class
train_paths, train_labels = prepare_bag_lists(train_df)
val_paths, val_labels = prepare_bag_lists(val_df)

# 4. Define transforms (normalised with the ImageNet mean/std).
# Resize is a safety net: patches are extracted at 224x224 but can be smaller
# when the source image is smaller than the patch.
# Train: flips, small rotations and mild colour jitter as augmentation
train_transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(15),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Val: no augmentation, just resize and normalize
val_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

### 4.1 Set Input Size and Number of Classes

In [None]:
# Model input: RGB patches of 224x224, channels-first as PyTorch expects
input_shape = (3, 224, 224)
# One output per class known to the label encoder
num_classes = len(label_encoder.classes_)

print(f"Input Shape: {input_shape}")
print(f"Number of Classes: {num_classes}")
print(f"Class Names: {label_encoder.classes_}")

### 4.2 Create MIL Datasets and DataLoaders

In [None]:
class DynamicMILDataset(torch.utils.data.Dataset):
    """Bag-level Multiple Instance Learning dataset that loads patches lazily.

    Each item is one bag: all of its patch images are read from disk when the
    item is accessed, converted to RGB, optionally transformed, and stacked.

    Args:
        bag_paths: Sequence where ``bag_paths[i]`` is the list of patch image
            paths belonging to bag ``i``.
        bag_labels: Sequence where ``bag_labels[i]`` is the label of bag ``i``.
            Labels are returned exactly as stored (not converted to tensors).
        transform: Optional callable applied to each RGB PIL image. It must
            return a tensor so that the patches can be stacked.

    Raises:
        ValueError: If ``bag_paths`` and ``bag_labels`` differ in length.
    """

    def __init__(self, bag_paths, bag_labels, transform=None):
        if len(bag_paths) != len(bag_labels):
            raise ValueError(
                f"bag_paths ({len(bag_paths)}) and bag_labels ({len(bag_labels)}) "
                "must have the same length"
            )
        self.bag_paths = bag_paths
        self.bag_labels = bag_labels
        self.transform = transform

    def __len__(self):
        return len(self.bag_labels)

    def __getitem__(self, idx):
        """Return ``(bag_tensor, label)`` for bag ``idx``.

        ``bag_tensor`` stacks the transformed patches along a new first axis,
        i.e. ``(N_patches, C, H, W)`` when ``transform`` yields ``(C, H, W)``
        tensors. ``label`` is ``bag_labels[idx]`` unchanged.

        Raises:
            ValueError: If the bag contains no patch paths.
        """
        images = []
        for path in self.bag_paths[idx]:
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image that stays valid after the file closes.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            images.append(img)

        if not images:
            raise ValueError(f"Bag {idx} contains no patches")

        return torch.stack(images), self.bag_labels[idx]


In [None]:
def mil_collate_fn(batch):
    """Collate function for MIL DataLoaders built with ``batch_size=1``.

    Bags have a variable number of patches, so they cannot be stacked by the
    default collate. This returns the single ``(bag_tensor, label)`` sample
    unchanged, without adding a leading batch dimension.

    Args:
        batch: List of samples assembled by the DataLoader; it must contain
            exactly one sample.

    Returns:
        The single ``(bag_tensor, label)`` sample in ``batch``.

    Raises:
        ValueError: If ``batch`` does not contain exactly one sample (e.g. the
            loader was built with ``batch_size > 1``). Otherwise every bag but
            the first would be silently dropped.
    """
    if len(batch) != 1:
        raise ValueError(
            f"mil_collate_fn expects exactly one bag per batch (use batch_size=1), "
            f"got {len(batch)}"
        )
    return batch[0]

In [None]:
train_dataset = DynamicMILDataset(train_paths, train_labels, transform=train_transform)
val_dataset = DynamicMILDataset(val_paths, val_labels, transform=val_transform)

# Bags have a variable number of patches, so each batch holds exactly one bag
# and mil_collate_fn returns it as-is (no extra batch dimension).
# Pinned memory only speeds up host->GPU copies; enabling it without CUDA just
# triggers a warning.
_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=_pin_memory,
    collate_fn=mil_collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=_pin_memory,
    collate_fn=mil_collate_fn,
)

# Sanity-check one batch. Because mil_collate_fn unwraps the single sample,
# the bag is already model-ready: (N_patches, 3, 224, 224), no squeeze needed
# (squeezing would wrongly drop the patch axis for a 1-patch bag).
data, label = next(iter(train_loader))
print(f"Bag Shape: {data.shape}")
print(f"Bag Label: {label}")

In [None]:
class PhikonFeatureExtractor(nn.Module):
    """Patch-level feature extractor built on a ViT backbone.

    First tries to load Phikon (``owkin/phikon``) through Hugging Face
    ``transformers``. If that fails for any reason, it falls back to timm's
    ImageNet-pretrained ``vit_base_patch16_224`` with the classification head
    removed.

    Args:
        freeze_backbone: If True, every backbone parameter gets
            ``requires_grad=False`` so that only downstream layers train.

    Attributes:
        backbone: The loaded ViT module.
        backbone_type: ``"hf_vit"`` or ``"timm_vit"``; selects how ``forward``
            reads the backbone output.
        feature_dim: Size of each patch embedding, read from the loaded
            backbone (768 for both ViT-B/16 backbones).

    Raises:
        RuntimeError: If neither backbone can be loaded.
    """

    def __init__(self, freeze_backbone=True):
        super().__init__()

        self.backbone_type = None

        # Preferred backbone: Phikon from the Hugging Face Hub.
        try:
            from transformers import ViTModel
            print("Attempting to load Phikon from Hugging Face...")

            # No pooler: forward() takes the CLS token of last_hidden_state.
            self.backbone = ViTModel.from_pretrained(
                "owkin/phikon",
                add_pooling_layer=False
            )
            self.backbone_type = "hf_vit"
            # Read the embedding size from the model config instead of hard-coding it.
            self.feature_dim = self.backbone.config.hidden_size
            print("✓ Phikon loaded successfully from Hugging Face!")

        except Exception as e:
            # Deliberately broad: any failure (missing package, no network,
            # bad checkpoint) should trigger the fallback instead of aborting.
            print(f"Could not load Phikon: {e}")
            print("Loading ImageNet ViT-B/16 as fallback...")

            try:
                import timm
                self.backbone = timm.create_model(
                    'vit_base_patch16_224',
                    pretrained=True,
                    num_classes=0  # No head: forward returns the pooled feature tensor
                )
                self.backbone_type = "timm_vit"
                self.feature_dim = self.backbone.num_features
                print("✓ ImageNet ViT-B/16 loaded as fallback")
            except Exception as e2:
                # Chain explicitly so the traceback shows the original cause.
                raise RuntimeError(f"Could not load any ViT model: {e2}") from e2

        if freeze_backbone:
            self._freeze_backbone()

    def _freeze_backbone(self):
        """Set ``requires_grad=False`` on all backbone parameters and report the count."""
        for param in self.backbone.parameters():
            param.requires_grad = False

        frozen = sum(p.numel() for p in self.backbone.parameters() if not p.requires_grad)
        print(f"✓ Backbone frozen ({frozen:,} parameters)")

    def forward(self, x):
        """Extract one embedding per patch.

        Args:
            x: Batch of image patches, expected as ``(N_patches, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(N_patches, feature_dim)``.

        Raises:
            TypeError: If the backbone output type is not recognized (only
                reachable when ``backbone_type`` is neither known value).
        """
        outputs = self.backbone(x)

        if self.backbone_type == "hf_vit":
            # Hugging Face ViT returns a ModelOutput; the CLS token is position 0.
            features = outputs.last_hidden_state[:, 0, :]

        elif self.backbone_type == "timm_vit":
            # timm with num_classes=0 already returns the pooled features tensor.
            features = outputs

        else:
            # Generic fallback for other backbone output conventions.
            if isinstance(outputs, torch.Tensor):
                features = outputs
            elif hasattr(outputs, 'last_hidden_state'):
                features = outputs.last_hidden_state[:, 0, :]
            elif hasattr(outputs, 'pooler_output'):
                features = outputs.pooler_output
            else:
                raise TypeError(f"Unexpected output type: {type(outputs)}")

        return features


# ============================================================
# Test the Feature Extractor
# ============================================================

def test_feature_extractor():
    """Smoke-test PhikonFeatureExtractor on a random bag of 10 patches.

    Builds a frozen extractor on CUDA when available (else CPU), runs one
    forward pass without gradients, and asserts the output shape is (10, 768).

    Returns:
        The instantiated extractor, left in eval mode for reuse.
    """
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    banner = "=" * 60

    print("\n" + banner)
    print("Testing Feature Extractor")
    print(banner)

    extractor = PhikonFeatureExtractor(freeze_backbone=True).to(run_device)
    extractor.eval()

    n_patches = 10
    fake_bag = torch.randn(n_patches, 3, 224, 224).to(run_device)
    print(f"\nInput shape: {fake_bag.shape}")

    with torch.no_grad():
        out = extractor(fake_bag)

    print(f"Output shape: {out.shape}")
    print(f"Expected shape: (10, 768)")

    assert out.shape == (n_patches, 768), f"Wrong shape! Got {out.shape}"
    print("\n✓ Feature extraction works correctly!")

    return extractor

In [None]:
class AttentionMILAggregator(nn.Module):
    """Gated-attention MIL pooling followed by a small MLP classifier.

    Every patch embedding receives a scalar score from a gated attention
    branch: a tanh projection multiplied element-wise by a sigmoid gate,
    then reduced to one value. Scores are softmax-normalized over the bag.
    The resulting weights average the embeddings into one bag vector, which
    is then classified.

    Args:
        feature_dim: Size of each input patch embedding.
        hidden_dim: Width of the attention branches and of the classifier's
            hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, feature_dim=768, hidden_dim=256, num_classes=4):
        super().__init__()

        # NOTE: attribute names and construction order define both the
        # state_dict keys of saved checkpoints and the seeded initialization.
        self.attention_V = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.Sigmoid())
        self.attention_weights = nn.Linear(hidden_dim, 1)

        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, features):
        """Pool a bag of patch embeddings and classify it.

        Args:
            features: 2-D tensor ``(N_patches, feature_dim)`` (``torch.mm``
                below requires exactly two dimensions).

        Returns:
            Tuple ``(logits, attention)`` with shapes ``(1, num_classes)`` and
            ``(1, N_patches)``; ``attention`` sums to 1 over the patches.
        """
        gated = self.attention_V(features) * self.attention_U(features)
        scores = self.attention_weights(gated)                      # (N, 1)
        attention = torch.softmax(scores.transpose(0, 1), dim=1)    # (1, N)
        bag_embedding = torch.mm(attention, features)               # (1, feature_dim)
        return self.classifier(bag_embedding), attention


class PhikonMIL(nn.Module):
    """End-to-end MIL model: per-patch ViT features + gated-attention pooling.

    Args:
        num_classes: Number of bag-level classes.
        freeze_backbone: Forwarded to ``PhikonFeatureExtractor``; when True
            the backbone parameters are frozen and only the aggregator trains.
        hidden_dim: Hidden width of the attention aggregator (default 256).
    """

    def __init__(self, num_classes=4, freeze_backbone=True, hidden_dim=256):
        super().__init__()

        self.feature_extractor = PhikonFeatureExtractor(freeze_backbone=freeze_backbone)
        # Size the aggregator from the extractor instead of hard-coding 768,
        # so a backbone with a different embedding size cannot mismatch.
        self.aggregator = AttentionMILAggregator(
            feature_dim=self.feature_extractor.feature_dim,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
        )

    def forward(self, bag):
        """Classify one bag of patches.

        Args:
            bag: ``(N_patches, 3, 224, 224)`` tensor. A DataLoader-style
                ``(1, N_patches, 3, 224, 224)`` tensor is also accepted and
                its leading batch dimension is dropped.

        Returns:
            Tuple ``(logits, attention)`` with shapes ``(1, num_classes)`` and
            ``(1, N_patches)``.

        Raises:
            ValueError: If a 5-D input holds more than one bag.
        """
        if bag.dim() == 5:
            if bag.size(0) != 1:
                raise ValueError(
                    f"PhikonMIL processes one bag at a time, got batch of {bag.size(0)}"
                )
            bag = bag.squeeze(0)

        features = self.feature_extractor(bag)
        logits, attention = self.aggregator(features)
        return logits, attention

In [None]:

# Build the Phikon-MIL model with a frozen backbone.
mil_model = PhikonMIL(num_classes=num_classes, freeze_backbone=True).to(device)

# Smoke-test one forward pass on a random 5-patch bag. No gradients are
# needed, and eval() keeps the classifier's dropout out of the check
# (the training loop switches back to train() itself).
mil_model.eval()
with torch.no_grad():
    test_bag = torch.randn(5, 3, 224, 224).to(device)
    test_logits, test_attn = mil_model(test_bag)
print(f"Test output shape: {test_logits.shape}")  # expected: (1, num_classes)

# Setup optimizer (only trainable parameters)
trainable_params = [p for p in mil_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

# Print parameter counts
total = sum(p.numel() for p in mil_model.parameters())
trainable = sum(p.numel() for p in trainable_params)
print(f"\nTotal parameters: {total:,}")
print(f"Trainable parameters: {trainable:,} ({100*trainable/total:.2f}%)")

In [None]:

epochs = 5

# Optimize only parameters that still require gradients (frozen backbone
# weights are excluded).
trainable_params = [p for p in mil_model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)

# Cosine-anneal the LR from its initial value down to eta_min over `epochs`
# scheduler steps. eta_min must be below the initial LR (1e-4); with
# eta_min equal to it, the schedule would be flat and do nothing.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-6,
)

# Print parameter counts
total = sum(p.numel() for p in mil_model.parameters())
trainable = sum(p.numel() for p in trainable_params)
print(f"Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.1f}%)")

In [None]:

criterion = nn.CrossEntropyLoss()
# Optimizer over the trainable (non-frozen) parameters only.
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, mil_model.parameters()),
    lr=1e-3, weight_decay=1e-4
)
# The scheduler must be built on THIS optimizer. A scheduler created earlier
# stays bound to the old optimizer: stepping it would leave this optimizer's
# LR constant, and get_last_lr() would report the wrong value.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=1e-6,  # must be below the initial LR, otherwise the schedule is flat
)

# Early stopping variables
patience = 10
best_val_f1 = -np.inf
patience_counter = 0
best_model_state = None

# Track metrics for plotting
train_losses = []
val_losses = []
train_f1_scores = []
val_f1_scores = []


def _label_to_device(label):
    """Return ``label`` as a long tensor on ``device`` (the dataset may yield plain ints)."""
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, dtype=torch.long)
    return label.to(device)


for epoch in range(epochs):
    # ---------------- Training ----------------
    mil_model.train()
    train_loss = 0.0
    train_preds = []
    train_labels_list = []

    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)

    for bag, label in train_loop:
        bag = bag.to(device)
        label = _label_to_device(label)

        optimizer.zero_grad()
        logits, _ = mil_model(bag)

        # logits: (1, num_classes) -> (num_classes,) to match the 0-d target.
        loss = criterion(logits.squeeze(0), label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        train_preds.append(logits.argmax(dim=1).item())
        train_labels_list.append(label.item())

        train_loop.set_postfix(loss=batch_loss)

    train_f1 = f1_score(train_labels_list, train_preds, average='weighted', zero_division=0)
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_f1_scores.append(train_f1)

    # ---------------- Validation ----------------
    mil_model.eval()
    val_loss = 0.0
    val_preds = []
    val_labels_list = []

    with torch.no_grad():
        for bag, label in val_loader:
            bag = bag.to(device)
            label = _label_to_device(label)

            logits, _ = mil_model(bag)
            loss = criterion(logits.squeeze(0), label)
            val_loss += loss.item()

            val_preds.append(logits.argmax(dim=1).item())
            val_labels_list.append(label.item())

    val_f1 = f1_score(val_labels_list, val_preds, average='weighted', zero_division=0)
    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_f1_scores.append(val_f1)

    # One scheduler step per epoch (T_max is expressed in epochs).
    scheduler.step()

    print(f"Epoch {epoch+1:2d} | Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | Train F1: {train_f1:.4f} | "
          f"Val F1: {val_f1:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}", end="")

    # ---------------- Early stopping on validation F1 ----------------
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        # Snapshot on CPU so the checkpoint copy does not occupy GPU memory.
        best_model_state = {
            k: v.detach().cpu().clone() for k, v in mil_model.state_dict().items()
        }
        print(" (Best F1!)")
    else:
        patience_counter += 1
        print(f" (Patience: {patience_counter}/{patience})")

        if patience_counter >= patience:
            print(f"\nEarly stopping triggered! Best Val F1: {best_val_f1:.4f}")
            break

# Always continue with the best checkpoint, not only when early stopping
# fires: the evaluation and submission cells use mil_model directly.
if best_model_state is not None:
    mil_model.load_state_dict(best_model_state)
    print(f"Restored best model weights (Val F1: {best_val_f1:.4f})")

In [None]:
# Save the best checkpoint (a state_dict) selected by validation F1.
MODEL_NAME = "attention_mil_best_model.pt"
# Create the directory here instead of relying on the shell `mkdir` in the setup cell.
os.makedirs("models", exist_ok=True)
model_save_path = os.path.join("models", MODEL_NAME)

if best_model_state is None:
    # Saving None would silently produce a useless checkpoint file.
    raise RuntimeError("No best model state recorded; run the training cell (with epochs > 0) first.")

torch.save(best_model_state, model_save_path)
print(f"Best model saved to {model_save_path}")
print(f"Best Validation F1 Score: {best_val_f1:.4f}")

In [None]:
# Evaluate the current mil_model on the validation bags for the confusion matrix.
mil_model.eval()
final_val_preds = []
final_val_labels = []

with torch.no_grad():
    for bag, label in val_loader:
        logits, _ = mil_model(bag.to(device))
        final_val_preds.append(logits.argmax(dim=1).item())
        # Labels may arrive as plain ints or 0-d tensors; int() handles both.
        final_val_labels.append(int(label))

# One x position per recorded epoch. This stays correct after early stopping
# and does not depend on the leftover loop variable `epoch`.
epoch_axis = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Confusion matrix over ALL classes (fixed size even if a class is missing
#    from the validation split), labelled with the encoder's class names.
cm = confusion_matrix(final_val_labels, final_val_preds, labels=list(range(num_classes)))
sns.heatmap(
    cm, annot=True, fmt='d', cmap='Blues', ax=axes[0], cbar=True,
    xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_,
)
axes[0].set_title('Validation Confusion Matrix (Bag Level)', fontsize=14, fontweight='bold')
axes[0].set_ylabel('True Label')
axes[0].set_xlabel('Predicted Label')

# 2. Loss curves
axes[1].plot(epoch_axis, train_losses, 'o-', label='Train Loss', linewidth=2, markersize=6)
axes[1].plot(epoch_axis, val_losses, 's-', label='Val Loss', linewidth=2, markersize=6)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Training vs Validation Loss', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# 3. F1 curves
axes[2].plot(epoch_axis, train_f1_scores, 'o-', label='Train F1', linewidth=2, markersize=6)
axes[2].plot(epoch_axis, val_f1_scores, 's-', label='Val F1', linewidth=2, markersize=6)
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('F1 Score', fontsize=12)
axes[2].set_title('Training vs Validation F1 Score', fontsize=14, fontweight='bold')
axes[2].legend(fontsize=11)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Submission

In [None]:
# Directory with the unlabeled test images (img_*) and their masks (mask_*).
_TEST_SUBDIR = "test_data"
SUBMISSION_PATH = os.path.join(datasets_path, _TEST_SUBDIR)

In [None]:
def create_submission_csv_phikon(model, submission_path, label_encoder, output_csv="submission.csv"):
    """Predict a label for every test image and write a submission CSV.

    For every ``img_*`` file (png/jpg/jpeg) in ``submission_path``, patches
    are extracted with ``extract_smart_patches`` using the matching
    ``mask_*`` file. They get the same preprocessing as the validation data
    and are classified together as one bag.

    Args:
        model: MIL model returning ``(logits, attention)`` for a bag tensor of
            shape ``(N_patches, 3, 224, 224)``.
        submission_path: Directory containing the test ``img_*`` / ``mask_*`` files.
        label_encoder: Fitted encoder used to map class indices back to labels.
        output_csv: Path of the CSV to write (columns ``sample_index``, ``label``).

    Notes:
        Images whose patch extraction fails or yields no patches are assigned
        class index 0 as a fallback; they are listed in a warning at the end.
    """
    import csv

    model.eval()
    # Run on the model's own device rather than relying on a global `device`.
    model_device = next(model.parameters()).device

    # Same preprocessing as the validation transform (resize + ImageNet normalization).
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    image_files = sorted(
        f for f in os.listdir(submission_path)
        if f.startswith('img_') and f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    print(f"Found {len(image_files)} images in {submission_path}")

    results = []
    fallback_files = []

    with torch.no_grad():
        for filename in tqdm(image_files, desc="Processing Submission Images"):
            img_path = os.path.join(submission_path, filename)
            mask_path = os.path.join(submission_path, filename.replace('img_', 'mask_'))

            try:
                patches, _, _, _ = extract_smart_patches(
                    img_path,
                    mask_path,
                    patch_size=224,
                    stride=224,
                    threshold=0.01
                )
            except Exception as e:
                # Best effort: one bad image must not abort the whole submission.
                print(f"Error extracting patches for {filename}: {e}")
                patches = []

            if len(patches) == 0:
                # No usable patches: fall back to class index 0.
                pred_idx = 0
                fallback_files.append(filename)
            else:
                bag_tensor = torch.stack(
                    [transform(Image.fromarray(patch)) for patch in patches]
                ).to(model_device)
                logits, _ = model(bag_tensor)
                pred_idx = logits.argmax(dim=1).item()

            pred_label = label_encoder.inverse_transform([pred_idx])[0]
            results.append([filename, pred_label])

    if fallback_files:
        print(f"Warning: {len(fallback_files)} image(s) had no patches and were "
              f"assigned class index 0: {fallback_files}")

    with open(output_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["sample_index", "label"])
        writer.writerows(results)

    print(f"Submission file saved to {output_csv}")

# Write the submission CSV using the current mil_model.
create_submission_csv_phikon(
    model=mil_model,
    submission_path=SUBMISSION_PATH,
    label_encoder=label_encoder,
    output_csv="phikon_submission.csv",
)