## **I. Google Colab Initialization (Only on Colab)**

In [21]:
# from google.colab import drive
# drive.mount('/content/drive')
# cur_dir = "/content/drive/MyDrive/CH2/Notebooks"
# %cd $cur_dir



In [22]:
#%pip install torchview

## **1. Import Libraries**

In [23]:
# Set seed for reproducibility
SEED = 42

# Import necessary libraries
import os

# Set environment variables before importing modules
# NOTE(review): PYTHONHASHSEED set at runtime does not change str hashing of the
# already-running kernel; it only applies to Python processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Matplotlib config/cache dir = <cwd>/configs/; must be set before matplotlib is imported below.
os.environ['MPLCONFIGDIR'] = os.getcwd() + '/configs/'
# Makes CUDA kernel launches synchronous (clearer error locations, slower GPU work);
# presumably a debugging aid — consider removing for real training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Suppress warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
# NOTE(review): category=Warning silences every warning class (making the line above
# redundant) and can hide real problems such as deprecations or numeric issues.
warnings.simplefilter(action='ignore', category=Warning)

# Import necessary modules
import logging
import random
import numpy as np

# Set seeds for random number generators in NumPy and Python
# NOTE(review): torch's RNG is not seeded in this cell; seed it where torch is used
# if deterministic training is required.
np.random.seed(SEED)
random.seed(SEED)


from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision.transforms import v2 as transforms
from torch.utils.data import TensorDataset, DataLoader
from torchview import draw_graph
from scipy import ndimage
from PIL import Image


# Import other libraries
import cv2
import copy
import shutil
from itertools import product
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


# NOTE(review): duplicate of the PIL import above (harmless).
from PIL import Image
import matplotlib.gridspec as gridspec

from tqdm import tqdm
import glob
from pathlib import Path
# NOTE(review): duplicate of the shutil import above (harmless).
import shutil
import gc

# Configure plot display settings
sns.set(font_scale=1.4)
sns.set_style('white')
plt.rc('font', size=14)
%matplotlib inline

## **2. Preprocessing**

- Preprocessing pipeline:
    - Load the raw images
    - Create goo masks
    - Apply goo removal (+ optional resizing)
    - Discard Shrek images
    - Apply the external masks (optional)
    

### 2.1 Preprocessing Functions

#### 2.1.1 _get_smart_goo_mask

In [24]:
def _get_smart_goo_mask(img_bgr, min_area=200):
    """
    Detect green "goo" with Core & Shell hysteresis logic plus a 1-pixel nudge.

    A strict HSV range ("core") finds solid green and a looser range ("shell")
    finds the faint halo around it. Shell blobs are kept only when they contain
    at least one core pixel. Blobs whose external contour area is <= min_area
    are dropped, holes inside kept blobs are filled, and the result is dilated
    by one pixel to cover the anti-aliased fringe.

    Args:
        img_bgr: 8-bit BGR image (as returned by cv2.imread).
        min_area: Contours with area <= this many pixels are treated as noise.
            Defaults to 200, the previously hard-coded value.

    Returns:
        uint8 mask with the image's height/width (255 = goo, 0 = safe).
    """
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)

    # CORE: solid green (strict). SHELL: faint halo (loose / transparent).
    # The core range lies entirely inside the shell range, so every core pixel
    # belongs to some shell blob.
    mask_core = cv2.inRange(hsv, np.array([35, 100, 50]), np.array([85, 255, 255]))
    mask_shell = cv2.inRange(hsv, np.array([30, 30, 30]), np.array([95, 255, 255]))

    # Keep shell blobs that touch the core. Instead of AND-ing one full-image
    # mask per blob (O(blobs * pixels)), collect the labels found under core
    # pixels once and select them all with a single vectorised lookup.
    _, labels, _, _ = cv2.connectedComponentsWithStats(mask_shell, connectivity=8)
    touching_labels = np.unique(labels[mask_core > 0])
    touching_labels = touching_labels[touching_labels != 0]  # 0 is background
    smart_mask = np.where(np.isin(labels, touching_labels), 255, 0).astype(np.uint8)

    # Fill holes (e.g. shiny reflections) and drop tiny stray blobs.
    contours, _ = cv2.findContours(smart_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    final_filled_mask = np.zeros_like(smart_mask)
    for contour in contours:
        if cv2.contourArea(contour) > min_area:
            cv2.drawContours(final_filled_mask, [contour], -1, (255), thickness=cv2.FILLED)

    # The "1-pixel nudge": expand by one pixel to cover the anti-aliased fringe.
    kernel = np.ones((3, 3), np.uint8)
    return cv2.dilate(final_filled_mask, kernel, iterations=1)



#### 2.1.2 remove_goo

In [25]:
def remove_goo(input_dir, output_dir, target_size=(224, 224), remove_goo=True, save_masks=True, replacement_color=(0, 0, 0)):
    """
    Resize 'img_*' images and (optionally) paint detected goo with a flat colour.

    Every file in input_dir whose name starts with 'img_' and has a common image
    extension is read, resized to target_size (skipped when None) and written to
    output_dir under the same name. When remove_goo is True, pixels flagged by
    _get_smart_goo_mask are replaced with replacement_color; if save_masks is
    also True, the inverted goo mask (255 = safe/tissue, 0 = goo) is written to
    output_dir/goo_masks/goo_mask_<suffix>.

    An image is skipped only when its output already exists AND, if masks are
    requested, its goo mask exists too, so a rerun with save_masks=True fills in
    masks that an earlier run did not write.

    Args:
        input_dir: Directory with the source images.
        output_dir: Destination directory (created if missing).
        target_size: dsize passed to cv2.resize, or None to keep the original size.
        remove_goo: Whether to detect and replace goo pixels.
        save_masks: Whether to save the safe-area mask (only used with remove_goo).
        replacement_color: (B, G, R) colour painted over goo. Default black.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Masks are only produced when goo is actually being removed.
    want_masks = remove_goo and save_masks
    masks_dir = output_dir / "goo_masks"
    if want_masks:
        masks_dir.mkdir(parents=True, exist_ok=True)

    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    print(f"Scanning for images in: {input_dir}...")
    # Sorted for a reproducible processing order; directories are ignored.
    image_files = sorted(
        f for f in input_dir.iterdir()
        if f.is_file() and f.name.startswith('img_') and f.suffix.lower() in valid_extensions
    )

    if not image_files:
        print("No images found starting with 'img_' in the directory.")
        return

    count = 0
    for file_path in tqdm(image_files, desc="Removing Goo from Images", unit="img"):
        output_path = output_dir / file_path.name
        # NOTE(review): the mask keeps the source extension, so .jpg inputs give
        # lossy JPEG masks that a 'goo_mask_*.png' glob will not find.
        mask_path = masks_dir / file_path.name.replace('img_', 'goo_mask_', 1)

        # Skip only when every requested artefact is already on disk.
        if output_path.exists() and (not want_masks or mask_path.exists()):
            continue

        img = cv2.imread(str(file_path))
        if img is None:
            tqdm.write(f"Warning: Could not read image {file_path}")
            continue

        if target_size is not None:
            img = cv2.resize(img, target_size)

        if remove_goo:
            goo_mask = _get_smart_goo_mask(img)          # 255 = goo
            not_goo_mask = cv2.bitwise_not(goo_mask)     # 255 = safe

            # Safe pixels keep their colour; goo pixels take replacement_color.
            img_safe = cv2.bitwise_and(img, img, mask=not_goo_mask)
            bg = np.full_like(img, replacement_color)
            bg_goo = cv2.bitwise_and(bg, bg, mask=goo_mask)
            img = cv2.add(img_safe, bg_goo)

            if save_masks:
                cv2.imwrite(str(mask_path), not_goo_mask)

        cv2.imwrite(str(output_path), img)
        count += 1

    print(f"Resizing complete. Processed {count} new images.")

#### 2.1.3 clean_and_save_masks

In [26]:
def clean_and_save_masks(goo_masks_dir, external_masks_dir, output_dir, target_size=(224, 224)):
    """
    Remove goo areas from external masks and save the cleaned masks.

    For every 'goo_mask_*.png' in goo_masks_dir (255 = safe, 0 = goo), the
    matching 'mask_*' file in external_masks_dir is loaded, optionally resized
    to target_size (nearest neighbour), binarised at 127, AND-ed with the goo
    mask and written to output_dir. Masks already present in output_dir are
    skipped. If the goo mask's size differs from the external mask's, it is
    resized to match instead of crashing in cv2.bitwise_and.

    Args:
        goo_masks_dir: Directory with 'goo_mask_*.png' files.
        external_masks_dir: Directory with the original 'mask_*' files.
        output_dir: Destination for cleaned masks (created if missing).
        target_size: dsize for cv2.resize, or None to keep the external mask's size.
    """
    goo_masks_dir = Path(goo_masks_dir)
    external_masks_dir = Path(external_masks_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    goo_mask_files = sorted(goo_masks_dir.glob('goo_mask_*.png'))
    if not goo_mask_files:
        print(f"No goo masks found in {goo_masks_dir}")
        return

    print(f"Found {len(goo_mask_files)} goo masks. Checking for existing and processing new masks...")

    processed_count = 0
    skipped_count = 0
    missing_count = 0
    for goo_mask_path in tqdm(goo_mask_files, desc="Cleaning External Masks"):
        mask_name = goo_mask_path.name.replace('goo_mask_', 'mask_', 1)
        output_path = output_dir / mask_name

        if output_path.exists():
            skipped_count += 1
            continue

        external_mask_path = external_masks_dir / mask_name
        if not external_mask_path.exists():
            # Counted and reported once in the summary rather than per file.
            missing_count += 1
            continue

        goo_mask = cv2.imread(str(goo_mask_path), cv2.IMREAD_GRAYSCALE)
        external_mask = cv2.imread(str(external_mask_path), cv2.IMREAD_GRAYSCALE)
        if goo_mask is None or external_mask is None:
            tqdm.write(f"Warning: Could not read one of the masks for {mask_name}")
            continue

        if target_size is not None:
            external_mask = cv2.resize(external_mask, target_size, interpolation=cv2.INTER_NEAREST)
        # Both masks must share a shape for bitwise_and; align the goo mask.
        if goo_mask.shape[:2] != external_mask.shape[:2]:
            goo_mask = cv2.resize(
                goo_mask,
                (external_mask.shape[1], external_mask.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        # Ensure both masks are strictly binary (0 or 255).
        _, external_mask = cv2.threshold(external_mask, 127, 255, cv2.THRESH_BINARY)
        _, goo_mask = cv2.threshold(goo_mask, 127, 255, cv2.THRESH_BINARY)

        # White only where the region of interest is also goo-free.
        cleaned_mask = cv2.bitwise_and(external_mask, goo_mask)
        cv2.imwrite(str(output_path), cleaned_mask)
        processed_count += 1

    print("\nProcessing complete.")
    print(f"  - Cleaned and saved: {processed_count} masks to {output_dir}")
    if skipped_count > 0:
        print(f"  - Skipped: {skipped_count} masks that already existed.")
    if missing_count > 0:
        print(f"  - Missing external mask: {missing_count} goo masks had no match in {external_masks_dir}.")

#### 2.1.4 apply_mask

In [27]:
def apply_mask(image_path, mask_path, output_path, target_size=(224, 224), remove_goo=True):
    """
    Black out everything outside a mask (and, optionally, any goo) and save.

    The image and the grayscale mask are loaded and resized to target_size
    (when not None). If their sizes still differ, the mask is resized to the
    image. The mask is binarised at 127; when remove_goo is True, goo detected
    by _get_smart_goo_mask is also removed from the kept area. Pixels outside
    the final mask become black. Load or save failures are reported via
    tqdm.write and the function returns None.

    Args:
        image_path: Path of the source image.
        mask_path: Path of the mask (read as grayscale).
        output_path: Destination file; its parent directory is created if needed.
        target_size: dsize for cv2.resize, or None to keep the image size.
        remove_goo: Whether to subtract detected goo from the mask.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        tqdm.write(f"Error: Could not load image at {image_path}")
        return

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        tqdm.write(f"Error: Could not load mask at {mask_path}")
        return

    if target_size is not None:
        img = cv2.resize(img, target_size)
        mask = cv2.resize(mask, target_size)
    elif mask.shape[:2] != img.shape[:2]:
        # bitwise_and below requires mask and image to share height/width.
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))

    # Standardise the external mask to 0/255 (also undoes resize blurring).
    _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    if remove_goo:
        # Must be tissue (binary_mask) AND not goo.
        not_goo_mask = cv2.bitwise_not(_get_smart_goo_mask(img))
        final_mask = cv2.bitwise_and(binary_mask, not_goo_mask)
    else:
        final_mask = binary_mask

    masked_img = cv2.bitwise_and(img, img, mask=final_mask)

    # Path('x.png').parent is '.', so a bare file name works; the previous
    # os.makedirs(os.path.dirname(...)) raised on an empty directory name.
    out_path = Path(output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if not cv2.imwrite(str(out_path), masked_img):
        tqdm.write(f"Error: Could not write image to {output_path}")

#### 2.1.5 filter_bright_green_areas

In [28]:
def filter_bright_green_areas(image, lg_H=20, lg_S=45, lg_V=0, ug_H=84, ug_S=255, ug_V=255, dilate_iterations=2):
    """
    Detect green areas in an RGB image and black them out.

    Two HSV masks are OR-ed together:
      * a "main" mask for the range (lg_H, lg_S, lg_V)..(ug_H, ug_S, ug_V),
        opened to drop specks and then dilated dilate_iterations times;
      * a "subtle" mask whose hue range is widened by 10 on each side, whose
        saturation lower bound is lowered by 10, and whose value is
        unrestricted, opened once.

    Args:
        image: RGB image with float values in [0, 1] (H x W x 3).
        lg_H, lg_S, lg_V: Lower HSV bounds on OpenCV's 8-bit scale (H in 0..179).
        ug_H, ug_S, ug_V: Upper HSV bounds.
        dilate_iterations: Dilation passes for the main mask (0 disables).

    Returns:
        (result_bgr, combined_mask): the uint8 BGR image with masked pixels set
        to 0, and the uint8 mask (255 = detected green).
    """
    # RGB float [0, 1] -> BGR uint8. Round rather than truncate: a float32 round
    # trip such as (x / 255) * 255 can land just below x and astype would floor
    # it to x - 1. Clip so out-of-range values saturate instead of wrapping.
    rgb_u8 = np.clip(np.rint(np.asarray(image) * 255.0), 0, 255).astype(np.uint8)
    original_bgr = np.ascontiguousarray(rgb_u8[..., ::-1])

    hsv = cv2.cvtColor(original_bgr, cv2.COLOR_BGR2HSV)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    # Main mask: OPEN removes small noise, DILATE swallows edge residue.
    mask = cv2.inRange(hsv, (lg_H, lg_S, lg_V), (ug_H, ug_S, ug_V))
    clean_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
    if dilate_iterations > 0:
        clean_mask = cv2.dilate(clean_mask, kernel, iterations=dilate_iterations)

    # Subtle mask for faint green tones. NOTE: it is not restricted to pixels
    # near the main green area; it is simply opened once and OR-ed in.
    lower_green_subtle = (max(0, lg_H - 10), max(0, lg_S - 10), 0)
    upper_green_subtle = (min(180, ug_H + 10), 255, 255)
    subtle_mask = cv2.inRange(hsv, lower_green_subtle, upper_green_subtle)
    subtle_mask = cv2.morphologyEx(subtle_mask, cv2.MORPH_OPEN, kernel, iterations=1)

    combined_mask = cv2.bitwise_or(clean_mask, subtle_mask)

    # Keep only the pixels outside the combined mask.
    mask_inv = cv2.bitwise_not(combined_mask)
    result_bgr = cv2.bitwise_and(original_bgr, original_bgr, mask=mask_inv)

    return result_bgr, combined_mask

#### 2.1.6 analyze_dataset_for_shreks

In [29]:
def analyze_dataset_for_shreks(directory, shrek_dir, ratio_threshold=0.0125):
    """
    Classify 'img_*.png' images in a directory as "shrek" or "tissue".

    An image is a "shrek" when the fraction of pixels flagged by
    filter_bright_green_areas exceeds ratio_threshold.

    NOTE(review): the next cell redefines this function (adding a skip check and
    expected_count), so this version is shadowed once that cell has run.

    Args:
        directory: Folder containing the images to analyse.
        shrek_dir: Unused in this version; kept for signature compatibility.
        ratio_threshold: Green-pixel fraction above which an image is a shrek.

    Returns:
        (shrek_images, tissue_images): lists of dicts with keys 'name', 'path',
        'img' (RGB uint8 array), 'ratio' and 'mask', in sorted file-name order.
    """
    shrek_images = []
    tissue_images = []

    # Sorted so the list order (used for example plots) is reproducible.
    image_files = sorted(glob.glob(os.path.join(directory, 'img_*.png')))
    print(f"Found {len(image_files)} images in {directory}")

    for f in tqdm(image_files, desc="Analyzing for Shreks"):
        try:
            img = cv2.imread(f)
            if img is None:
                tqdm.write(f"Skipping {f}: could not read image")
                continue

            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # filter_bright_green_areas expects RGB floats in [0, 1].
            _, mask = filter_bright_green_areas(img_rgb.astype(np.float32) / 255.0)

            ratio = np.count_nonzero(mask) / (img.shape[0] * img.shape[1])
            entry = {
                'name': os.path.basename(f),
                'path': f,
                'img': img_rgb,
                'ratio': ratio,
                'mask': mask,
            }

            if ratio > ratio_threshold:
                shrek_images.append(entry)
            else:
                tissue_images.append(entry)

        except Exception as e:
            # Best effort: one bad file must not abort the whole scan.
            print(f"Skipping {f}: {e}")

    return shrek_images, tissue_images

#### 2.1.7 analyze_dataset_for_shreks (redefinition with skip check)

In [30]:
def analyze_dataset_for_shreks(directory, shrek_dir, ratio_threshold=0.0125, expected_count=150):
    """
    Analyzes images in a directory to classify them as "shrek" or "tissue".

    An image is a "shrek" when the fraction of pixels flagged by
    filter_bright_green_areas exceeds ratio_threshold. If shrek_dir already
    exists and contains exactly expected_count 'img_*.png' files, the analysis
    is skipped to save time.

    Args:
        directory (str): The source directory containing 'img_*.png' images.
        shrek_dir (str): Directory where "shrek" images are saved; only used to
                         decide whether processing can be skipped.
        ratio_threshold (float): Green-pixel ratio above which an image is a "shrek".
        expected_count (int | None): Number of images in shrek_dir that triggers
                                     the skip. None always runs the analysis.

    Returns:
        tuple: (shrek_images, tissue_images), lists of dicts with keys 'name',
               'path', 'img' (RGB uint8), 'ratio' and 'mask', in sorted
               file-name order. Returns ([], []) if processing is skipped.
    """
    if expected_count is not None:
        shrek_dir_path = Path(shrek_dir)
        if shrek_dir_path.is_dir():
            # Only 'img_*.png' files are counted, matching the glob used below.
            num_existing_images = len(list(shrek_dir_path.glob('img_*.png')))
            if num_existing_images == expected_count:
                print(f"'{shrek_dir}' already contains exactly {expected_count} images. Skipping analysis.")
                return [], []

    shrek_images = []
    tissue_images = []

    # Sorted so the list order (used for example plots) is reproducible.
    image_files = sorted(glob.glob(os.path.join(directory, 'img_*.png')))
    print(f"Found {len(image_files)} images in '{directory}'. Analyzing for 'Shreks'...")

    for f in tqdm(image_files, desc="Analyzing for Shreks"):
        try:
            img = cv2.imread(f)
            if img is None:
                tqdm.write(f"Skipping {f}: could not read image")
                continue

            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # filter_bright_green_areas expects RGB floats in [0, 1].
            _, mask = filter_bright_green_areas(img_rgb.astype(np.float32) / 255.0)

            ratio = np.count_nonzero(mask) / (img.shape[0] * img.shape[1])
            entry = {
                'name': os.path.basename(f),
                'path': f,
                'img': img_rgb,
                'ratio': ratio,
                'mask': mask,
            }

            if ratio > ratio_threshold:
                shrek_images.append(entry)
            else:
                tissue_images.append(entry)

        except Exception as e:
            # Best effort: one bad file must not abort the whole scan.
            print(f"Skipping {f}: {e}")

    return shrek_images, tissue_images

#### 2.1.8 process_classification_results

In [31]:
def _copy_classified_items(items, dest_dir, label, folder):
    """Copy each item's 'path' into dest_dir as item['name'], never overwriting existing files."""
    os.makedirs(dest_dir, exist_ok=True)
    print(f"Saving {len(items)} {label} images to {dest_dir}...")
    for item in tqdm(items, desc=f"Saving {label} Images"):
        dest_path = os.path.join(dest_dir, item['name'])
        if os.path.exists(dest_path):
            continue
        try:
            shutil.copy2(item['path'], dest_path)
        except Exception as e:
            # Best effort: report and keep copying the remaining files.
            print(f"Error copying {item['name']} to {folder} folder: {e}")


def process_classification_results(shrek_list, tissue_list, shrek_dir, tissue_dir, threshold, visualize=True):
    """
    Save classified images to their folders and optionally visualise the split.

    Args:
        shrek_list (list): Dicts with at least 'name', 'path', 'img', 'ratio'
                           for images classified as Shrek.
        tissue_list (list): Same structure, for images classified as Tissue.
        shrek_dir (str): Folder to copy Shrek images into (created if missing).
        tissue_dir (str): Folder to copy Tissue images into (created if missing).
        threshold (float): Green-ratio threshold used for classification (plot only).
        visualize (bool): If True, show a 2x2 example grid and a ratio scatter plot.

    Existing destination files are never overwritten.
    """
    print(f"Classified {len(shrek_list)} as Shrek")
    print(f"Classified {len(tissue_list)} as Tissue")

    _copy_classified_items(shrek_list, shrek_dir, "Shrek", "shrek")
    _copy_classified_items(tissue_list, tissue_dir, "Tissue", "tissue")

    if not visualize:
        print("Shrek removal visualization OFF.")
        return

    # Example grid: first two images of each class.
    if len(shrek_list) >= 2 and len(tissue_list) >= 2:
        fig_ex, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig_ex.suptitle(f"Classification Results (Threshold: {threshold:.1%})", fontsize=16)
        rows = [("Detected Shrek", shrek_list), ("Detected Tissue", tissue_list)]
        for row, (label, items) in enumerate(rows):
            for col in range(2):
                ax = axes[row, col]
                item = items[col]
                ax.imshow(item['img'])
                # The green ratio in the title shows why the image was classified this way.
                ax.set_title(f"{label}\n{item['name']}\nGreen Pixels: {item['ratio']:.2%}")
                ax.axis('off')
        plt.tight_layout()
        plt.show()
    else:
        print("Not enough images in one or both classes to generate 2x2 sample grid.")

    # Scatter of per-image green ratios for threshold tuning; Shrek points are
    # shifted on the x-axis so they appear after the tissue points.
    shrek_ratios = [x['ratio'] for x in shrek_list]
    tissue_ratios = [x['ratio'] for x in tissue_list]
    n_tissue = len(tissue_ratios)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.scatter(range(n_tissue), tissue_ratios, color='blue', alpha=0.6, label='Classified as Tissue')
    ax.scatter(range(n_tissue, n_tissue + len(shrek_ratios)), shrek_ratios, color='green', alpha=0.6, label='Classified as Shrek')
    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=2, label=f'Threshold ({threshold:.1%})')
    ax.set_title('Green Pixel Ratio per Image', fontsize=14)
    ax.set_ylabel('Ratio of Green Pixels (0.0 - 1.0)')
    ax.set_xlabel('Image Index')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

# --- Example Usage ---
# shrek_list, tissue_list = analyze_dataset_for_shreks(DATASET_PATH, SHREK_DIR)  # must run first
# process_classification_results(shrek_list, tissue_list, SHREK_DIR, TISSUE_DIR, RATIO_THRESHOLD)

#### 2.1.9 copy_masks

In [32]:
def copy_masks(image_list, masks_dir, output_dir):
    """
    Copy the mask that corresponds to each image file name.

    For every name in image_list (e.g. 'img_0001.png'), the first 'img_' is
    replaced with 'mask_' and that file is copied from masks_dir to output_dir.
    output_dir is created if it does not exist. A missing source mask is
    reported and skipped instead of aborting the remaining copies.

    Args:
        image_list: Iterable of image file names (not paths).
        masks_dir: Directory containing the source masks.
        output_dir: Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)
    mask_names = [name.replace('img_', 'mask_', 1) for name in image_list]

    for mask_name in tqdm(mask_names, desc="Copying Masks"):
        src_path = os.path.join(masks_dir, mask_name)
        if not os.path.exists(src_path):
            tqdm.write(f"Warning: Mask not found: {src_path}")
            continue
        shutil.copy(src_path, os.path.join(output_dir, mask_name))

#### 2.1.10 extract_smart_patches

In [33]:
def extract_smart_patches(img_path, mask_path, patch_size=224, stride=224, threshold=0.30,
                          merge_iterations=15, min_tissue_ratio=0.15, white_level=235):
    """
    Extract patches centred on (grouped) mask regions.

    Nearby mask spots are merged by binary dilation, and one patch is centred on
    the bounding box of each merged region (clamped to the image). A candidate
    is kept when its fraction of mask pixels (> 128) is >= threshold AND its
    fraction of "tissue" pixels (mean RGB < white_level) is > min_tissue_ratio.
    Both fractions are relative to patch_size area; if the image is smaller
    than the patch, the returned patch is smaller than patch_size.

    Args:
        img_path: Image path. If it ends with '.png' and is missing, the same
            path with a '.jpg' suffix is tried.
        mask_path: Mask path (read as 8-bit grayscale; 0/1 masks are accepted).
        patch_size: int or (height, width) of the patches.
        stride: Unused; accepted for backward compatibility.
        threshold: Minimum fraction of mask pixels inside a kept patch.
        merge_iterations: Dilation iterations used to merge nearby spots
            (must be >= 1; scipy treats < 1 as "dilate until stable").
        min_tissue_ratio: Minimum fraction of non-white pixels in a kept patch.
        white_level: Mean-RGB value at or above which a pixel counts as background.

    Returns:
        (patches, coords, img_arr, mask_arr): list of RGB uint8 patches, their
        (x, y) top-left corners, the full RGB image array and the raw mask array.
    """
    img_path = str(img_path)
    try:
        img = Image.open(img_path).convert("RGB")
    except FileNotFoundError:
        # Fallback for a .png/.jpg extension swap; only the suffix is changed.
        if img_path.endswith(".png"):
            img = Image.open(img_path[:-len(".png")] + ".jpg").convert("RGB")
        else:
            raise

    mask = Image.open(mask_path).convert("L")
    img_arr = np.array(img)
    mask_arr = np.array(mask)

    # Masks stored as 0/1 are scaled to 0/255 so the > 128 tests below work.
    mask_check = mask_arr * 255 if mask_arr.max() <= 1 else mask_arr

    h, w = img_arr.shape[:2]
    ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

    # 1. Grouping: dilation merges dots within ~merge_iterations pixels so that
    #    a cluster of small spots yields one patch instead of one per spot.
    dilated_mask = ndimage.binary_dilation(mask_check > 128, iterations=merge_iterations)
    labeled_mask, _ = ndimage.label(dilated_mask)

    # 2. One candidate per merged region, centred on its bounding box and
    #    clamped so the patch stays inside the image.
    candidate_coords = set()
    for y_slice, x_slice in ndimage.find_objects(labeled_mask):
        blob_cy = (y_slice.start + y_slice.stop) // 2
        blob_cx = (x_slice.start + x_slice.stop) // 2
        start_y = max(0, min(blob_cy - ph // 2, h - ph))
        start_x = max(0, min(blob_cx - pw // 2, w - pw))
        candidate_coords.add((start_x, start_y))

    # 3. Keep candidates with enough tumour AND enough tissue.
    patches = []
    coords = []
    patch_area = ph * pw
    for (x, y) in candidate_coords:
        mask_patch = mask_check[y:y + ph, x:x + pw]
        img_patch = img_arr[y:y + ph, x:x + pw]

        tissue_ratio = np.sum(np.mean(img_patch, axis=2) < white_level) / patch_area
        mask_ratio = np.sum(mask_patch > 128) / patch_area

        if mask_ratio >= threshold and tissue_ratio > min_tissue_ratio:
            patches.append(img_patch)
            coords.append((x, y))

    return patches, coords, img_arr, mask_arr

#### 2.1.11 create_patches_dataset

In [34]:
def create_patches_dataset(input_dir, output_dir, mask_dir, patch_size=224, stride=224, threshold=0.01):
    """
    Extract smart patches for every 'img_*' image and save image + mask patches.

    For each image in input_dir (.png/.jpg/.jpeg), the mask 'mask_<suffix>' is
    looked up in mask_dir, with a '.png' fallback. Patches from
    extract_smart_patches are saved as '<stem>_p<i>.png' in output_dir. The
    matching mask crops are saved as '<stem with first img_ -> mask_>_p<i>.png'
    in output_dir/masks, which is the name apply_patch_masks_to_images looks up.
    Masks stored as 0/1 are saved as 0/255 so later >127 thresholds keep them.

    An image is skipped when any '<stem>_p*.png' already exists in output_dir.

    Args:
        input_dir: Directory with the source images.
        output_dir: Destination for image patches (created if missing).
        mask_dir: Directory with the full-size masks.
        patch_size: int or (height, width), forwarded to extract_smart_patches.
        stride: Forwarded to extract_smart_patches.
        threshold: Minimum mask fraction per patch, forwarded to extract_smart_patches.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    mask_dir = Path(mask_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    masks_output_dir = output_dir / "masks"
    masks_output_dir.mkdir(parents=True, exist_ok=True)

    # Patch height/width, needed to crop the mask at each patch coordinate.
    ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size

    image_files = sorted(
        f for f in input_dir.iterdir()
        if f.name.startswith('img_') and f.suffix.lower() in {'.png', '.jpg', '.jpeg'}
    )

    total_patches_saved = 0
    images_processed = 0
    images_skipped = 0
    images_without_patches = 0

    print(f"Starting patch extraction from {input_dir} to {output_dir}...")
    print(f"Found {len(image_files)} source images.")

    for source_image_path in tqdm(image_files, desc="Processing Images"):
        base_name = source_image_path.stem

        # Any existing '<stem>_p*.png' means this image was already processed;
        # next() stops at the first match instead of listing every file.
        if next(output_dir.glob(f"{base_name}_p*.png"), None):
            images_skipped += 1
            continue

        mask_path = mask_dir / source_image_path.name.replace('img_', 'mask_', 1)
        if not mask_path.exists():
            # Fallback: mask stored as .png while the image has another extension.
            mask_path = mask_dir / (base_name.replace('img_', 'mask_', 1) + ".png")
        if not mask_path.exists():
            tqdm.write(f"Warning: Mask not found for {source_image_path.name}")
            continue

        patches, coords, _, mask_arr = extract_smart_patches(
            str(source_image_path),
            str(mask_path),
            patch_size=patch_size,
            stride=stride,
            threshold=threshold,
        )
        if not patches:
            images_without_patches += 1
            continue

        # 0/1 masks would be erased by a later >127 threshold; save them as 0/255.
        if mask_arr.max() <= 1:
            mask_arr = mask_arr * 255

        mask_base = base_name.replace('img_', 'mask_', 1)
        for i, (patch_array, (x, y)) in enumerate(zip(patches, coords)):
            Image.fromarray(patch_array).save(output_dir / f"{base_name}_p{i}.png")
            mask_patch = mask_arr[y:y + ph, x:x + pw]
            Image.fromarray(mask_patch).save(masks_output_dir / f"{mask_base}_p{i}.png")

        total_patches_saved += len(patches)
        images_processed += 1

    print("\n--- Extraction Summary ---")
    print(f"Images Processed: {images_processed}")
    print(f"Images Skipped:   {images_skipped}")
    print(f"Images without valid patches: {images_without_patches}")
    print(f"Total Patches Saved in this run: {total_patches_saved}")
    print(f"Patches are located in: {output_dir}")
    print(f"Mask patches are located in: {masks_output_dir}")

#### 2.1.12 apply_patch_masks_to_images

In [35]:
def apply_patch_masks_to_images(patches_dir, masks_dir=None, output_dir=None):
    """Apply mask patches to image patches and save the masked outputs.

    For each 'img_*_p*.png' in patches_dir, the mask with the first 'img_'
    replaced by 'mask_' is loaded from masks_dir. If its size differs from the
    patch, it is resized (nearest neighbour) to the patch size. It is
    binarised at > 127, and pixels outside it are set to 0. Outputs that
    already exist are skipped; missing masks are reported and counted.

    Args:
        patches_dir: directory with img_*_p*.png
        masks_dir: directory with mask_*_p*.png (default: patches_dir / 'masks')
        output_dir: destination for masked patches (default: patches_dir / 'masked')
    """
    patches_dir = Path(patches_dir)
    masks_dir = Path(masks_dir) if masks_dir else patches_dir / "masks"
    output_dir = Path(output_dir) if output_dir else patches_dir / "masked"
    output_dir.mkdir(parents=True, exist_ok=True)

    patch_files = sorted(patches_dir.glob("img_*_p*.png"))
    if not patch_files:
        print(f"No patch images found in {patches_dir}")
        return

    applied, skipped, missing = 0, 0, 0
    for patch_path in tqdm(patch_files, desc="Applying mask patches"):
        mask_path = masks_dir / (patch_path.stem.replace("img_", "mask_", 1) + patch_path.suffix)
        out_path = output_dir / patch_path.name

        if out_path.exists():
            skipped += 1
            continue
        if not mask_path.exists():
            tqdm.write(f"Mask not found for {patch_path.name}")
            missing += 1
            continue

        # Open each file once; convert() returns an in-memory copy that stays
        # valid after the file handle is closed.
        with Image.open(patch_path) as patch_im:
            img = np.array(patch_im.convert("RGB"))
        with Image.open(mask_path) as mask_im:
            mask_gray = mask_im.convert("L")

        target_wh = (img.shape[1], img.shape[0])  # PIL sizes are (width, height)
        if mask_gray.size != target_wh:
            mask_gray = mask_gray.resize(target_wh, Image.NEAREST)

        mask_bin = (np.array(mask_gray) > 127).astype(np.uint8)
        masked = cv2.bitwise_and(img, img, mask=mask_bin)
        Image.fromarray(masked).save(out_path)
        applied += 1

    print(f"Applied masks to {applied} patches. Skipped {skipped} existing outputs. Results in: {output_dir}")
    if missing:
        print(f"Missing masks for {missing} patches (not written).")

# **3. Run Preprocessing**

## **3.1 Preprocess the Train Set**

In [36]:
# ---- Input locations ---------------------------------------------------------
# The dataset folder lives in the parent of the current working directory.
datasets_path = os.path.join(os.path.pardir, "an2dl2526c2")
train_data_path, train_labels_path, test_data_path = (
    os.path.join(datasets_path, "train_data"),
    os.path.join(datasets_path, "train_labels.csv"),
    os.path.join(datasets_path, "test_data"),
)

SOURCE_FOLDER = train_data_path      # raw training images
CSV_PATH = train_labels_path         # CSV file with the training labels

print(f"Dataset path: {datasets_path}")
print(f"Train data path: {train_data_path}")
print(f"Train labels path: {train_labels_path}")
print(f"Test data path: {test_data_path}")

# ---- Preprocessing outputs ---------------------------------------------------
PREPROCESSING_ROOT = os.path.join(datasets_path, "preprocessing_results_masked")

# Step 1 output: images with goo removed (+ goo_masks / cleaned_masks sub-folders)
GOO_REMOVAL_OUT = os.path.join(PREPROCESSING_ROOT, "train_nogoo")

# Step 2 outputs: one sub-folder per class after Shrek detection
SHREK_REMOVAL_OUT = os.path.join(PREPROCESSING_ROOT, "train_noshreks")
SHREKS_OUT = os.path.join(SHREK_REMOVAL_OUT, "train_shreks")
TISSUE_OUT = os.path.join(SHREK_REMOVAL_OUT, "train_tissue")

# Patch outputs (presumably for the patch-extraction steps)
PATCHES_OUT = os.path.join(PREPROCESSING_ROOT, "train_patches")
PATCHES_OUT_MASKED = os.path.join(PREPROCESSING_ROOT, "train_patches_masked")

TARGET_SIZE = (224, 224)   # target size for the resized images and masks

Dataset path: ..\an2dl2526c2
Train data path: ..\an2dl2526c2\train_data
Train labels path: ..\an2dl2526c2\train_labels.csv
Test data path: ..\an2dl2526c2\test_data


In [37]:
# Step 1: remove the goo from every training image without resizing (target_size=None).
# save_masks=True also writes the goo masks under GOO_REMOVAL_OUT ("goo_masks" is read by the
# mask-cleaning cell). replacement_color is presumably the fill used for goo pixels — confirm.
remove_goo(
    SOURCE_FOLDER,
    GOO_REMOVAL_OUT,
    target_size=None,
    remove_goo=True,
    save_masks=True,
    replacement_color=(195, 195, 195),
)

Scanning for images in: ..\an2dl2526c2\train_data...


Removing Goo from Images: 100%|██████████| 691/691 [00:00<00:00, 13458.70img/s]

Resizing complete. Processed 0 new images.





In [38]:
# Clean the external training masks (stored next to the raw images in SOURCE_FOLDER) using the
# goo masks written by Step 1. Both directories are bound to names so the cleaned-mask location
# is spelled the same way here, where it is written, as in the patch-extraction cell that reads it.
GOO_MASKS_DIR = os.path.join(GOO_REMOVAL_OUT, "goo_masks")
CLEANED_MASKS_DIR = os.path.join(GOO_REMOVAL_OUT, "cleaned_masks")

clean_and_save_masks(
    goo_masks_dir=GOO_MASKS_DIR,
    external_masks_dir=SOURCE_FOLDER,
    output_dir=CLEANED_MASKS_DIR,
    target_size=None,  # keep masks at their original resolution, matching Step 1
)

Found 691 goo masks. Checking for existing and processing new masks...


Cleaning External Masks: 100%|██████████| 691/691 [00:00<00:00, 18084.20it/s]


Processing complete.
  - Cleaned and saved: 0 masks to ..\an2dl2526c2\preprocessing_results_masked\train_nogoo\cleaned_masks
  - Skipped: 691 masks that already existed.





In [39]:
# Step 2: split the goo-free images into "Shrek" images (discarded) and tissue images (kept).
# The ratio threshold is passed to both calls below; it is bound once so the two cannot drift apart.
SHREK_RATIO_THRESHOLD = 0.0125

shreks_list, tissue_list = analyze_dataset_for_shreks(
    GOO_REMOVAL_OUT,
    shrek_dir=SHREKS_OUT,
    ratio_threshold=SHREK_RATIO_THRESHOLD,
    expected_count=150,
)

process_classification_results(
    shreks_list, tissue_list, SHREKS_OUT, TISSUE_OUT, SHREK_RATIO_THRESHOLD, visualize=False
)

# Presumably copies the masks of the kept tissue images into TISSUE_OUT.
# NOTE(review): masks are taken from the raw SOURCE_FOLDER, while patch extraction reads
# CLEANED_MASKS_DIR — confirm these copies are still needed.
tissue_image_names = [item['name'] for item in tissue_list]
copy_masks(tissue_image_names, SOURCE_FOLDER, TISSUE_OUT)

# Release the classification results. The trailing ';' keeps gc.collect()'s return value
# (the number of collected objects) out of the cell output.
del shreks_list, tissue_list
gc.collect();

Found 691 images in '..\an2dl2526c2\preprocessing_results_masked\train_nogoo'. Analyzing for 'Shreks'...


Analyzing for Shreks: 100%|██████████| 691/691 [00:46<00:00, 15.02it/s]
Analyzing for Shreks: 100%|██████████| 691/691 [00:46<00:00, 15.02it/s]


Classified 60 as Shrek
Classified 631 as Tissue
Saving 60 Shrek images to ..\an2dl2526c2\preprocessing_results_masked\train_noshreks\train_shreks...


Saving Shrek Images: 100%|██████████| 60/60 [00:00<00:00, 25695.14it/s]
Saving Shrek Images: 100%|██████████| 60/60 [00:00<00:00, 25695.14it/s]


Saving 631 Tissue images to ..\an2dl2526c2\preprocessing_results_masked\train_noshreks\train_tissue...


Saving Tissue Images: 100%|██████████| 631/631 [00:00<00:00, 27126.59it/s]
Saving Tissue Images: 100%|██████████| 631/631 [00:00<00:00, 27126.59it/s]


Shrek removal visulization OFF.


Copying Masks: 100%|██████████| 631/631 [00:00<00:00, 1476.40it/s]



13

In [40]:
# Step 3: tile each tissue image into 224x224 patches with stride equal to the patch size,
# using the cleaned masks; mask patches end up in PATCHES_OUT/masks.
# threshold=0.01 is forwarded unchanged — presumably a minimum tissue fraction per patch (TODO confirm).
CLEANED_MASKS_DIR = os.path.join(GOO_REMOVAL_OUT, "cleaned_masks")

create_patches_dataset(
    TISSUE_OUT,
    PATCHES_OUT,
    mask_dir=CLEANED_MASKS_DIR,
    patch_size=224,
    stride=224,
    threshold=0.01,
)

Starting patch extraction from ..\an2dl2526c2\preprocessing_results_masked\train_noshreks\train_tissue to ..\an2dl2526c2\preprocessing_results_masked\train_patches...
Found 631 source images.


Processing Images: 100%|██████████| 631/631 [00:01<00:00, 452.10it/s]


--- Extraction Summary ---
Images Processed: 0
Images Skipped:   631
Total Patches Saved in this run: 0
Patches are located in: ..\an2dl2526c2\preprocessing_results_masked\train_patches
Mask patches are located in: ..\an2dl2526c2\preprocessing_results_masked\train_patches\masks





In [41]:

# Step 4: zero out pixels outside each training patch's mask (mask patches default to
# PATCHES_OUT/masks) and write the results to PATCHES_OUT_MASKED.
apply_patch_masks_to_images(
    PATCHES_OUT,
    output_dir=PATCHES_OUT_MASKED,
)

Applying mask patches: 100%|██████████| 2788/2788 [00:09<00:00, 280.90it/s]

Applied masks to 2788 patches. Skipped 0 existing outputs. Results in: ..\an2dl2526c2\preprocessing_results_masked\train_patches_masked





## **3.2 Preprocess the Submission Set**

In [43]:
# Submission (test) images go through patch extraction and masking only.
# Patch parameters mirror the training call so the test patches have the same size and tiling.
# NOTE(review): unlike the training set, these images are not goo-cleaned and their raw external
# masks are used (mask_dir is the test folder itself, not cleaned masks) — confirm this
# train/test preprocessing difference is intended.
SUBMISSION_SOURCE_FOLDER = test_data_path  # reuse the path defined in the config cell
SUBMISSION_PATCHES_OUT = os.path.join(datasets_path, "preprocessing_results_masked", "submission_patches")
SUBMISSION_PATCHES_OUT_MASKED = os.path.join(datasets_path, "preprocessing_results_masked", "submission_patches_masked")

create_patches_dataset(
    SUBMISSION_SOURCE_FOLDER,
    SUBMISSION_PATCHES_OUT,
    mask_dir=SUBMISSION_SOURCE_FOLDER,
    patch_size=224,
    stride=224,
    threshold=0.01,
)

apply_patch_masks_to_images(SUBMISSION_PATCHES_OUT, output_dir=SUBMISSION_PATCHES_OUT_MASKED)

Starting patch extraction from ..\an2dl2526c2\test_data to ..\an2dl2526c2\preprocessing_results_masked\submission_patches...
Found 477 source images.


Processing Images: 100%|██████████| 477/477 [00:00<00:00, 619.76it/s]




--- Extraction Summary ---
Images Processed: 0
Images Skipped:   477
Total Patches Saved in this run: 0
Patches are located in: ..\an2dl2526c2\preprocessing_results_masked\submission_patches
Mask patches are located in: ..\an2dl2526c2\preprocessing_results_masked\submission_patches\masks


Applying mask patches: 100%|██████████| 2052/2052 [00:07<00:00, 274.37it/s]

Applied masks to 2052 patches. Skipped 0 existing outputs. Results in: ..\an2dl2526c2\preprocessing_results_masked\submission_patches_masked



