In [2]:
import os
import re
import json
import random
from pathlib import Path
import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tifffile as tiff
import openslide
import rasterio
from rasterio.windows import Window
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import models, transforms, datasets
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from tifffile import imread
import glob
import gc

In [3]:
happiny_folder = Path(r"C:\Users\luukn\AIMI_MONKEY2\monkey-training")
orig_dir = happiny_folder / "images/pas-original"
cpg_dir = happiny_folder / "images/pas-cpg"
diagnostic_dir = happiny_folder / "images/pas-diagnostic"
xml_dir = happiny_folder / "annotations/xml"

# Extracting patches for training 
(including json file for within patch annotations)

In [6]:
# === CONFIGURATION ===
PATCH_SIZE = 256
BBOX_SIZE = 32
NUM_JSONS = 81
NUM_PATCHES_PER_FILE = 300

json_pixel_dir = happiny_folder / "annotations/json_pixel"
PATCHES_DIR = happiny_folder / "patches_with_annotations"
modalities = {
    "pas-original": happiny_folder / "images/pas-original",
    "pas-cpg": happiny_folder / "images/pas-cpg",
    "pas-diagnostic": happiny_folder / "images/pas-diagnostic",
}

PATCHES_DIR.mkdir(parents=True, exist_ok=True)

# === UTILITIES ===

def get_sample_id(json_path):
    match = re.search(r'P\d{6}', json_path.name)
    return match.group(0) if match else None

def load_json_annotations(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return [(int(p["point"][0]), int(p["point"][1])) for p in data.get("points", [])]

def check_blackwhite_space(patch):
    white_threshold = 240
    black_threshold = 15
    white_pixels = np.sum(patch > white_threshold)
    black_pixels = np.sum(patch < black_threshold)
    total_pixels = patch.size
    return (white_pixels / total_pixels) > 0.9 or (black_pixels / total_pixels) > 0.9

def is_low_contrast(patch, std_thresh=15):
    # check if the patch is low contrast by calculating the standard deviation of the grayscale values
    gray = np.dot(patch[..., :3], [0.2989, 0.5870, 0.1140])
    return np.std(gray) < std_thresh

def check_background(patch):
    if check_blackwhite_space(patch):
        return True
    if is_low_contrast(patch):
        return True
    return False

def extract_patch_and_boxes(tif_path, center_x, center_y, all_coords, patch_size=PATCH_SIZE, box_size=BBOX_SIZE):
    half_patch = patch_size // 2
    half_box = box_size // 2

    xmin = max(center_x - half_patch, 0)
    ymin = max(center_y - half_patch, 0)

    with rasterio.open(tif_path) as src:
        max_width, max_height = src.width, src.height
        xmin = min(xmin, max_width - patch_size)
        ymin = min(ymin, max_height - patch_size)
        window = Window(xmin, ymin, patch_size, patch_size)
        patch = src.read(window=window)
        patch = np.transpose(patch, (1, 2, 0))  # (H, W, C)

    # Convert global coords to local patch coords
    boxes = []
    for x, y in all_coords:
        if xmin <= x < xmin + patch_size and ymin <= y < ymin + patch_size:
            x_local = x - xmin
            y_local = y - ymin
            box = [
                int(x_local - half_box),
                int(y_local - half_box),
                int(x_local + half_box),
                int(y_local + half_box)
            ]
            # Clip to patch bounds
            box = [
                max(0, box[0]),
                max(0, box[1]),
                min(patch_size, box[2]),
                min(patch_size, box[3])
            ]
            boxes.append(box)

    return patch, boxes

# === MAIN PATCH EXTRACTION ===

def save_selected_patches():
    json_files = list(json_pixel_dir.glob("*inflammatory-cells.json"))
    print(f"🔍 Found {len(json_files)} inflammatory cell JSON files")

    json_files_sampled = random.sample(json_files, min(NUM_JSONS, len(json_files)))

    for modality, tif_dir in modalities.items():
        output_dir = PATCHES_DIR / modality
        img_dir = output_dir / "images"
        ann_dir = output_dir / "annotations"
        img_dir.mkdir(parents=True, exist_ok=True)
        ann_dir.mkdir(parents=True, exist_ok=True)
        print(f"\n📂 Processing modality: {modality}")

        for json_file in json_files_sampled:
            all_coords = load_json_annotations(json_file)
            if not all_coords:
                continue

            sample_id = get_sample_id(json_file)
            if not sample_id:
                continue

            tif_matches = list(tif_dir.glob(f"*{sample_id}*.tif"))
            if not tif_matches:
                print(f"⚠️ No TIFF file found for sample ID {sample_id} in {modality}")
                continue

            tif_path = tif_matches[0]
            print(f"  Using {tif_path.name} with {len(all_coords)} annotated cells")

            selected_coords = all_coords[:NUM_PATCHES_PER_FILE]
            for i, (x, y) in enumerate(selected_coords):
                try:
                    patch, boxes = extract_patch_and_boxes(
                        tif_path, x, y, all_coords, PATCH_SIZE, BBOX_SIZE
                    )
                    patch_img = Image.fromarray(patch.astype(np.uint8))
                    if check_background(patch):
                        print(f"⚠️ Patch {i} at ({x},{y}) is too white/black, skipping.")
                        continue

                    patch_name = f"{sample_id}_patch{i}.png"
                    patch_img.save(img_dir / patch_name)

                    # Save bounding boxes in JSON
                    box_data = [{"bbox": box, "label": "inflammatory-cell"} for box in boxes]
                    json_path = ann_dir / f"{patch_name.replace('.png', '.json')}"
                    with open(json_path, "w") as f:
                        json.dump(box_data, f)

                except Exception as e:
                    print(f"    ⚠️ Failed to extract patch {i} at ({x},{y}): {e}")

if __name__ == "__main__":
    save_selected_patches()
    print("\n✅ Patch extraction with annotations complete.")


🔍 Found 81 inflammatory cell JSON files

📂 Processing modality: pas-original
⚠️ No TIFF file found for sample ID P000029 in pas-original
⚠️ No TIFF file found for sample ID P000035 in pas-original
⚠️ No TIFF file found for sample ID P000031 in pas-original
  Using D_P000014_PAS_Original.tif with 948 annotated cells
⚠️ Patch 0 at (63576,85682) is too white/black, skipping.
⚠️ Patch 1 at (63318,85694) is too white/black, skipping.
⚠️ Patch 2 at (63198,85872) is too white/black, skipping.
⚠️ Patch 3 at (63230,86002) is too white/black, skipping.
⚠️ Patch 4 at (68862,85144) is too white/black, skipping.
⚠️ Patch 5 at (68612,85246) is too white/black, skipping.
⚠️ Patch 6 at (68704,85314) is too white/black, skipping.
⚠️ Patch 7 at (68946,85324) is too white/black, skipping.
⚠️ Patch 8 at (68514,85462) is too white/black, skipping.
⚠️ Patch 9 at (69182,85718) is too white/black, skipping.
⚠️ Patch 10 at (68732,85730) is too white/black, skipping.
⚠️ Patch 11 at (68678,85956) is too white/bl

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


⚠️ Patch 24 at (64102,86746) is too white/black, skipping.
⚠️ Patch 25 at (62872,86828) is too white/black, skipping.
⚠️ Patch 26 at (64626,86906) is too white/black, skipping.
⚠️ Patch 27 at (62628,86954) is too white/black, skipping.
⚠️ Patch 28 at (64016,87102) is too white/black, skipping.
⚠️ Patch 29 at (63990,87138) is too white/black, skipping.
⚠️ Patch 30 at (62802,87354) is too white/black, skipping.
⚠️ Patch 31 at (63898,87362) is too white/black, skipping.
⚠️ Patch 32 at (62286,87394) is too white/black, skipping.
⚠️ Patch 33 at (63462,87394) is too white/black, skipping.
⚠️ Patch 34 at (62518,87412) is too white/black, skipping.
⚠️ Patch 35 at (63266,87594) is too white/black, skipping.
⚠️ Patch 36 at (63466,87632) is too white/black, skipping.
⚠️ Patch 37 at (63010,87762) is too white/black, skipping.
⚠️ Patch 38 at (64024,87840) is too white/black, skipping.
⚠️ Patch 39 at (61622,87848) is too white/black, skipping.
⚠️ Patch 40 at (63196,87884) is too white/black, skippin

# Extracting patches for testing

In [3]:
processed_imgs = [r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000001_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000002_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000003_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000004_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000005_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000006_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000007_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000011_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000014_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000016_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000017_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000018_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000020_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000021_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000022_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000024_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000029_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000030_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000031_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000032_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000033_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000034_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000035_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000036_PAS_CPG.tif',
                  r'C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000037_PAS_CPG.tif',]
#processed_imgs = []

In [4]:
patch_size = 256
patches_per_image = 50
output_folder = Path(r"C:\Users\luukn\AIMI_MONKEY2\monkey-training\patches_newest\cpg_test")
output_folder.mkdir(parents=True, exist_ok=True)


def check_blackwhite_space(patch):
    white_threshold = 240
    black_threshold = 15
    white_pixels = np.sum(patch > white_threshold)
    black_pixels = np.sum(patch < black_threshold)
    total_pixels = patch.size
    return (white_pixels / total_pixels) > 0.95 or (black_pixels / total_pixels) > 0.95

def is_low_contrast(patch, std_thresh=15):
    # check if the patch is low contrast by calculating the standard deviation of the grayscale values
    gray = np.dot(patch[..., :3], [0.2989, 0.5870, 0.1140])
    return np.std(gray) < std_thresh

def check_background(patch):
    if check_blackwhite_space(patch):
        return True
    if is_low_contrast(patch):
        return True
    return False


def process_image(img_path, max_patches):
    print(f"🔄 Processing {img_path}")
    img = imread(str(img_path), maxworkers=1)

    print(f"Image loaded, shape: {img.shape}")

    if img.ndim == 2:
        img = np.stack([img]*3, axis=-1)
    elif img.shape[0] in [3, 4] and img.ndim == 3:
        img = np.moveaxis(img, 0, -1)

    height, width = img.shape[:2]

    print('getting valid coordinates for patches...')

    valid_coords = []
    for y in range(0, height - patch_size + 1, patch_size):
        for x in range(0, width - patch_size + 1, patch_size):
            patch = img[y:y+patch_size, x:x+patch_size]
            if not check_background(patch):
                valid_coords.append((x, y))
    
    print(f"Found {len(valid_coords)} valid coordinates for patches, out of which {max_patches} will be randomly sampled...")

    selected_coords = random.sample(valid_coords, min(max_patches, len(valid_coords)))

    print('Extracting patches utilizing the selected coordinates...')
    for x, y in selected_coords:
        patch = img[y:y+patch_size, x:x+patch_size]
        patch_img = Image.fromarray(patch.astype(np.uint8))
        patch_img.save(output_folder / f"{Path(img_path).stem}_x{x}_y{y}.png")

    print(f"✅ {len(selected_coords)} patches saved to {output_folder}")


    # garbage collection to free memory
    del img
    gc.collect()

    return len(selected_coords)

def main():
    image_paths = glob.glob(str(cpg_dir / "*.tif"))

    total_patches = 0
    for img_path in image_paths:
        if img_path not in processed_imgs:
            count = process_image(img_path, patches_per_image)
            total_patches += count
            print('--------------------------------------------------------------------------------')

    print(f"✅ Done. Total patches saved: {total_patches}")

main()

🔄 Processing C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\A_P000038_PAS_CPG.tif
Image loaded, shape: (184832, 86272, 3)
getting valid coordinates for patches...
Found 58117 valid coordinates for patches, out of which 50 will be randomly sampled...
Extracting patches utilizing the selected coordinates...
✅ 50 patches saved to C:\Users\luukn\AIMI_MONKEY2\monkey-training\patches_newest\cpg_test
--------------------------------------------------------------------------------
🔄 Processing C:\Users\luukn\AIMI_MONKEY2\monkey-training\images\pas-cpg\B_P000001_PAS_CPG.tif
Image loaded, shape: (120832, 80896, 3)
getting valid coordinates for patches...
Found 11910 valid coordinates for patches, out of which 50 will be randomly sampled...
Extracting patches utilizing the selected coordinates...
✅ 50 patches saved to C:\Users\luukn\AIMI_MONKEY2\monkey-training\patches_newest\cpg_test
--------------------------------------------------------------------------------
🔄 Processing C:\User

# Extracting patches from ROIs from the train dirs

In [None]:
PATCH_SIZE = 256
NUM_JSONS = 81
NUM_PATCHES_PER_FILE = 300

#happiny_folder = Path("/data/temporary/archives/lung/generator/MONKEY_challenge")
json_pixel_dir = happiny_folder / "annotations/json_pixel"
PATCHES_DIR = happiny_folder / "patches_newest"  
modalities = {
    "pas-original": happiny_folder / "images/pas-original",
    "pas-cpg": happiny_folder / "images/pas-cpg",
    "pas-diagnostic": happiny_folder / "images/pas-diagnostic",
}

PATCHES_DIR.mkdir(parents=True, exist_ok=True)

def get_sample_id(json_path):
    match = re.search(r'P\d{6}', json_path.name)
    return match.group(0) if match else None

def load_json_annotations(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return [(int(p["point"][0]), int(p["point"][1])) for p in data.get("points", [])]

def check_white_space(patch):
    white_threshold = 240  # Adjust this threshold as needed
    black_threshold = 15
    white_pixels = np.sum(patch > white_threshold)
    black_pixels = np.sum(patch < black_threshold)
    total_pixels = patch.size
    return (white_pixels / total_pixels) > 0.95 or  (black_pixels / total_pixels) > 0.95 # Adjust the ratio as needed

def extract_patch(tif_path, center_x, center_y, patch_size=PATCH_SIZE):
    half = patch_size // 2
    xmin = max(center_x - half, 0)
    ymin = max(center_y - half, 0)
    with rasterio.open(tif_path) as src:
        max_width, max_height = src.width, src.height
        xmin = min(xmin, max_width - patch_size)
        ymin = min(ymin, max_height - patch_size)
        window = Window(xmin, ymin, patch_size, patch_size)
        patch = src.read(window=window)
        patch = np.transpose(patch, (1, 2, 0))  # (H, W, C)
    return patch

def save_selected_patches():
    json_files = list(json_pixel_dir.glob("*inflammatory-cells.json"))
    print(f"🔍 Found {len(json_files)} inflammatory cell JSON files")

    # Randomly sample JSON files without replacement
    json_files_sampled = random.sample(json_files, min(NUM_JSONS, len(json_files)))

    for modality, tif_dir in modalities.items():
        output_dir = PATCHES_DIR / modality
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"\nProcessing modality: {modality}")

        for json_file in json_files_sampled:
            coords = load_json_annotations(json_file)[:NUM_PATCHES_PER_FILE]
            sample_id = get_sample_id(json_file)
            if not sample_id:
                continue

            tif_matches = list(tif_dir.glob(f"*{sample_id}*.tif"))
            if not tif_matches:
                print(f"  ⚠️ No TIFF file found for sample ID {sample_id} in {modality}")
                continue

            tif_path = tif_matches[0]  # Take the first match
            print(f"  Using {tif_path.name} with {len(coords)} ROIs")

            for i, (x, y) in enumerate(coords):
                try:
                    patch = extract_patch(tif_path, x, y, PATCH_SIZE)
                    patch_img = Image.fromarray(patch.astype(np.uint8))
                    if check_white_space(patch):
                        print(f"    ⚠️ Patch {i} at ({x},{y}) is too white or black, skipping.")
                        continue
                    else:
                        patch_name = f"{sample_id}_patch{i}.png"
                        patch_img.save(output_dir / patch_name)
                except Exception as e:
                    print(f"    ⚠️ Failed to extract patch {i} at ({x},{y}): {e}")

if __name__ == "__main__":
    save_selected_patches()
    print("\n✅ Patch extraction complete.")


🔍 Found 81 inflammatory cell JSON files

Processing modality: pas-original
  ⚠️ No TIFF file found for sample ID P000035 in pas-original
  ⚠️ No TIFF file found for sample ID P000038 in pas-original
  ⚠️ No TIFF file found for sample ID P000032 in pas-original
  ⚠️ No TIFF file found for sample ID P000029 in pas-original
  ⚠️ No TIFF file found for sample ID P000037 in pas-original
  ⚠️ No TIFF file found for sample ID P000028 in pas-original
  Using D_P000017_PAS_Original.tif with 300 ROIs
    ⚠️ Patch 0 at (39399,71509) is too white or black, skipping.
    ⚠️ Patch 1 at (39409,71481) is too white or black, skipping.
    ⚠️ Patch 2 at (39511,71515) is too white or black, skipping.
    ⚠️ Patch 3 at (39486,71349) is too white or black, skipping.
    ⚠️ Patch 4 at (39327,71313) is too white or black, skipping.
    ⚠️ Patch 5 at (39812,71470) is too white or black, skipping.
    ⚠️ Patch 6 at (39657,71141) is too white or black, skipping.
    ⚠️ Patch 7 at (39682,71197) is too white or b

# Extracting patches from ROIs from the test dirs

In [None]:
# Configuration
PATCH_SIZE = 256
NUM_PATCHES_PER_FILE = 5

# Paths
#happiny_folder = Path("/data/temporary/archives/lung/generator/MONKEY_challenge")
#json_pixel_dir = happiny_folder / "annotations/json_pix"
PATCHES_DIR = happiny_folder / "patches_test"

cpg_dir = happiny_folder / "images/pas-cpg"

# Create output directories for cpg modalities + cell types
for cell_type in ["inflammatory-cells", "lymphocytes", "monocytes"]:
    (PATCHES_DIR / "cpg" / cell_type).mkdir(parents=True, exist_ok=True)

# Helpers
def get_sample_id(json_path):
    match = re.search(r'P\d{6}', json_path.name)
    return match.group(0) if match else None

def load_json_annotations(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return [(int(p["point"][0]), int(p["point"][1])) for p in data.get("points", [])]

def extract_patch(tif_path, center_x, center_y, patch_size=PATCH_SIZE):
    half = patch_size // 2
    xmin = max(center_x - half, 0)
    ymin = max(center_y - half, 0)
    with rasterio.open(tif_path) as src:
        max_width, max_height = src.width, src.height
        xmin = min(xmin, max_width - patch_size)
        ymin = min(ymin, max_height - patch_size)
        window = Window(xmin, ymin, patch_size, patch_size)
        patch = src.read(window=window)
        patch = np.transpose(patch, (1, 2, 0))  # (H, W, C)
    return patch

def save_cpg_patches_for_cell_types():
    # Cell types to process
    cell_types = ["inflammatory-cells", "lymphocytes", "monocytes"]
    
    # For each cell type, find relevant json files
    for cell_type in cell_types:
        json_files = sorted(json_pixel_dir.glob(f"*{cell_type}.json"))
        print(f"Found {len(json_files)} {cell_type} JSON files")

        for json_file in json_files:
            coords = load_json_annotations(json_file)[:NUM_PATCHES_PER_FILE]
            sample_id = get_sample_id(json_file)
            if not sample_id:
                continue

            tif_matches = list(cpg_dir.glob(f"*{sample_id}*.tif"))
            if not tif_matches:
                continue

            tif_path = tif_matches[0]  # Use the first match
            print(f"Processing {tif_path.name} with {len(coords)} {cell_type} ROIs")

            output_dir = PATCHES_DIR / "cpg" / cell_type

            for i, (x, y) in enumerate(coords):
                patch = extract_patch(tif_path, x, y, PATCH_SIZE)
                patch_img = Image.fromarray(patch.astype(np.uint8))
                patch_name = f"{sample_id}_{cell_type}_patch{i}.png"
                patch_img.save(output_dir / patch_name)

if __name__ == "__main__":
    save_cpg_patches_for_cell_types()
    print("✅ Finished extracting cpg patches for inflammatory, lymphocytes, and monocytes.")


Found 81 inflammatory-cells JSON files
Processing B_P000001_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing D_P000002_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing B_P000003_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing D_P000004_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000005_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing D_P000006_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000007_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing B_P000011_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing D_P000014_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing D_P000016_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000017_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000018_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000020_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing A_P000021_PAS_CPG.tif with 5 inflammatory-cells ROIs
Processing C_P000022_PAS_CPG.tif with 5 inflammatory-cells ROIs
P

# Model building

In [6]:
train_dirs = [
    PATCHES_DIR / "pas-original",
    PATCHES_DIR / "pas-diagnostic"
]
test_dirs = [
    PATCHES_DIR / "cpg" / "inflammatory-cells",
    PATCHES_DIR / "cpg" / "monocytes",
    PATCHES_DIR / "cpg" / "lymphocytes"
]
cpg_inflammatory_dir = PATCHES_DIR / "cpg" / "inflammatory-cells"

class_names_all = ["inflammatory-cells", "monocytes", "lymphocytes"]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

def wrap_with_label_folder(path, label="dummy"):
    temp = path.parent / f"{path.name}_wrapped"
    temp.mkdir(exist_ok=True)
    label_path = temp / label
    label_path.mkdir(exist_ok=True)
    for img in path.glob("*.png"):
        symlink = label_path / img.name
        if not symlink.exists():
            symlink.symlink_to(img)
    return temp

# Create dataloaders from directories
def create_dataloader(dirs, transform, batch_size=32, shuffle=True):
    wrapped_dirs = [wrap_with_label_folder(d) for d in dirs]
    datasets_list = [datasets.ImageFolder(d, transform=transform) for d in wrapped_dirs]
    combined_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=shuffle)
    return loader

# Extract features for a given dataset loader and encoder model
def extract_features(encoder, loader, device):
    encoder.eval()
    all_feats = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            feats = encoder(imgs).cpu()
            all_feats.append(feats)
            all_labels.append(labels)
    X = torch.cat(all_feats).numpy()
    y = torch.cat(all_labels).numpy()
    return X, y

# Pretrain encoder on train_dirs (pas-original + pas-diagnostic)
def pretrain_encoder(train_dirs, transform, device, epochs=5):
    train_loader = create_dataloader(train_dirs, transform, batch_size=32, shuffle=True)
    encoder = models.resnet18(pretrained=True)
    encoder.fc = nn.Identity()  
    encoder = encoder.to(device)
    optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
    encoder.train()
    for epoch in range(epochs):
        for imgs, _ in tqdm(train_loader, desc=f"Pretraining Epoch {epoch+1}"):
            imgs = imgs.to(device)
            features = encoder(imgs)
            loss = features.norm(dim=1).mean() 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return encoder

# Fine-tune encoder on only CPG inflammatory-cells (binary classification)
def finetune_on_cpg_inflammatory(encoder, cpg_dir, transform, device, epochs=3):
    # Prepare dataset with label 0 for inflammatory cells (binary classification)
    wrapped_dir = wrap_with_label_folder(cpg_dir, label="inflammatory-cells")
    dataset = datasets.ImageFolder(wrapped_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # Add classification head for binary classification
    encoder.fc = nn.Linear(512, 2)
    encoder = encoder.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
    
    encoder.train()
    for epoch in range(epochs):
        for imgs, labels in tqdm(loader, desc=f"Fine-tuning CPG inflammatory Epoch {epoch+1}"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = encoder(imgs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return encoder

# Prepare test dataloader for all CPG classes with correct labels
def create_cpg_test_loader(test_dirs, transform, batch_size=32):
    wrapped_dirs = []
    for i, d in enumerate(test_dirs):
        wrapped_dirs.append(wrap_with_label_folder(d, label=str(i)))  # label folder name = class index
    datasets_list = [datasets.ImageFolder(d, transform=transform) for d in wrapped_dirs]
    combined_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=False)
    return loader

# Extract features with fine-tuned encoder for multi-class classification
def extract_features_multiclass(encoder, loader, device):
    encoder.eval()
    all_feats = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            feats = encoder(imgs)
            all_feats.append(feats.cpu())
            all_labels.append(labels.cpu())
    X = torch.cat(all_feats).numpy()
    y = torch.cat(all_labels).numpy()
    return X, y

In [7]:
# --------- MAIN PIPELINE ---------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model A: Pretrain encoder on pas-original + pas-diagnostic only
print("Pretraining encoder (Model A)...")
encoder_A = pretrain_encoder(train_dirs, transform, device)

# Test Model A on all CPG classes with RF classifier on frozen encoder features
print("\nEvaluating Model A (no fine-tuning on CPG)...")
cpg_test_loader = create_cpg_test_loader(test_dirs, transform, batch_size=32)
X_A, y_A = extract_features(encoder_A, cpg_test_loader, device)

clf_A = RandomForestClassifier(n_estimators=100, random_state=42)
clf_A.fit(X_A, y_A)
y_pred_A = clf_A.predict(X_A)

print("Model A Confusion Matrix:")
print(confusion_matrix(y_A, y_pred_A))
print("Model A Classification Report:")
print(classification_report(y_A, y_pred_A, target_names=class_names_all))

Pretraining encoder (Model A)...


OSError: [WinError 1314] A required privilege is not held by the client: 'C:\\Users\\luukn\\AIMI_MONKEY2\\monkey-training\\patches\\pas-original\\P000001_patch0.png' -> 'C:\\Users\\luukn\\AIMI_MONKEY2\\monkey-training\\patches\\pas-original_wrapped\\dummy\\P000001_patch0.png'

In [None]:
# Model B: Fine-tune encoder on CPG inflammatory-cells (binary) 
print("\nFine-tuning encoder on CPG inflammatory-cells (Model B)...")
encoder_B = finetune_on_cpg_inflammatory(encoder_A, cpg_inflammatory_dir, transform, device, epochs=3)

# For Model B testing, remove classification head for feature extraction (to get embeddings)
encoder_B.fc = nn.Identity()
encoder_B = encoder_B.to(device)

# Extract features from all CPG classes with fine-tuned encoder
print("\nEvaluating Model B (fine-tuned on CPG inflammatory-cells)...")
X_B, y_B = extract_features(encoder_B, cpg_test_loader, device)

clf_B = RandomForestClassifier(n_estimators=100, random_state=42)
clf_B.fit(X_B, y_B)
y_pred_B = clf_B.predict(X_B)

print("Model B Confusion Matrix:")
print(confusion_matrix(y_B, y_pred_B))
print("Model B Classification Report:")
print(classification_report(y_B, y_pred_B, target_names=class_names_all))