# PANDA Slide Vector Extraction Pipeline using Trident + Titan
## Step 1: Deduplicate slides using image hashing
## Step 2: Remove penmarked slides
## Step 3: Extract slide vectors (tissue segmentation, patch encoding, MIL aggregation)

# Install all dependencies

In [None]:
# Trident is installed straight from its GitHub repository.
!pip install git+https://github.com/mahmoodlab/TRIDENT.git --quiet
# System package for OpenSlide; its output is discarded to keep the cell quiet.
!apt-get install -y openslide-tools > /dev/null
# Python bindings for OpenSlide, plus imagehash (used for slide deduplication).
!pip install openslide-python imagehash --quiet
# Client library used below to authenticate against the Hugging Face Hub.
!pip install huggingface_hub --quiet



In [None]:
import os

from huggingface_hub import login

# SECURITY: never commit an access token to source control. The token is read
# from the HF_TOKEN environment variable instead; when the variable is unset,
# `login` receives None and falls back to its interactive prompt.
# NOTE(review): the token that was previously hard-coded here has been exposed
# and should be revoked in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))


## Import all dependencies

In [None]:
# Import all dependencies
import time
import os
import json
import torch
import h5py
from pathlib import Path
from PIL import Image
import numpy as np
import imagehash
from openslide import OpenSlide
from trident import load_wsi
from trident.segmentation_models import segmentation_model_factory
from trident.patch_encoder_models import encoder_factory as patch_factory
from trident.slide_encoder_models import encoder_factory as slide_factory
import pickle
import geopandas as gpd
import shutil

# Compatibility shim: when the installed geopandas has no
# `GeoSeries.union_all`, provide one that falls back to `unary_union`.
# NOTE(review): presumably needed because Trident calls `union_all` -- confirm.
if not hasattr(gpd.GeoSeries, "union_all"):

    def _union_all_fallback(self, *args, **kwargs):
        """Return `self.unary_union`; any extra arguments are ignored."""
        return self.unary_union

    gpd.GeoSeries.union_all = _union_all_fallback


## Configurations

In [None]:
class Config:
    """Settings shared by the preprocessing and feature-extraction steps.

    Attributes:
        input_dir: Directory that is searched for ``*.tiff`` slides.
        output_dir: Root directory for everything the pipeline writes.
        clean_slide_list: Path of the JSON file holding the names of the
            slides that survived preprocessing.
        do_deduplication: Enable the image-hash deduplication step.
        do_penmark_check: Enable the pen-mark filtering step.
        preprocessing: When true, feature extraction reads its slide list
            from ``clean_slide_list`` instead of globbing ``input_dir``.
        hash_threshold: Similarity fraction for deduplication; two hashes
            closer than ``(1 - hash_threshold)`` of the hash bits count as
            duplicates.
        patch_mag: Target magnification passed to patch-coordinate extraction.
        patch_size: Patch size passed to patch-coordinate extraction.
        patch_encoder_name: Name handed to the patch encoder factory.
        slide_encoder_name: Name handed to the slide encoder factory.
        device: Torch device string; a falsy value selects ``"cuda"`` when
            available and ``"cpu"`` otherwise.
    """

    def __init__(self,
                 input_dir="/kaggle/input/prostate-cancer-grade-assessment/train_images",
                 output_dir="/kaggle/working/slide_vectors",
                 clean_slide_list="/kaggle/working/clean_slides.json",
                 do_deduplication=True,
                 do_penmark_check=True,
                 preprocessing=False,
                 hash_threshold=0.9,
                 patch_mag=10,
                 patch_size=224,
                 patch_encoder_name="phikon",
                 slide_encoder_name="titan",
                 device=None):
        # Locations on disk.
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.clean_slide_list = clean_slide_list

        # Preprocessing switches and their tuning knob.
        self.preprocessing = preprocessing
        self.do_deduplication = do_deduplication
        self.do_penmark_check = do_penmark_check
        self.hash_threshold = hash_threshold

        # Patch geometry.
        self.patch_mag = patch_mag
        self.patch_size = patch_size

        # Encoder selection.
        self.patch_encoder_name = patch_encoder_name
        self.slide_encoder_name = slide_encoder_name

        # An explicit device wins; otherwise prefer the GPU when one is present.
        if device:
            self.device = device
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"


## Removing penmarks

In [None]:
def has_penmarks(slide_path):
    """Heuristically decide whether a slide shows blue pen markings.

    A 256x256 region is read at pyramid level 2 starting from location
    (0, 0); a pixel counts as "pen" when its blue channel exceeds 150 while
    its red channel is below 100. The slide is flagged when more than 500
    such pixels are found.

    Args:
        slide_path: Path (str or path-like) of the slide to inspect.

    Returns:
        True if the slide looks pen-marked OR could not be inspected (for
        example the file is unreadable or has no level 2). Unreadable slides
        are therefore treated as marked and excluded by callers that filter
        on this result.
    """
    try:
        # The context manager closes the slide handle even if reading fails.
        with OpenSlide(str(slide_path)) as slide:
            # NOTE(review): this reads only the 256x256 window at the origin,
            # not a downsampled view of the whole slide -- confirm that this
            # is the intended sampling area.
            region = slide.read_region((0, 0), 2, (256, 256)).convert("RGB")
        arr = np.array(region)
        # "Bright blue": strong blue channel together with a weak red channel.
        pen_pixels = np.sum((arr[:, :, 2] > 150) & (arr[:, :, 0] < 100))
        return bool(pen_pixels > 500)
    except Exception as exc:
        # Best effort: anything that cannot be checked is treated as marked.
        # Catching Exception (rather than a bare except) lets Ctrl-C and
        # SystemExit propagate instead of being misread as "has penmarks".
        print(f" Could not check {slide_path} for penmarks: {exc}")
        return True

## Preprocessing by image hash and removing penmarks

In [None]:
def preprocess_slides(cfg):
    """Build the list of clean slides and save it as JSON.

    Steps:
      1. Collect every ``*.tiff`` file in ``cfg.input_dir`` (sorted by name so
         that the outcome of deduplication is reproducible).
      2. If ``cfg.do_deduplication`` is set, drop slides whose average hash is
         closer than ``(1 - cfg.hash_threshold)`` of the hash bits to a slide
         that was already kept. Slides that cannot be opened or hashed are
         skipped, i.e. they do not reach the clean list.
      3. If ``cfg.do_penmark_check`` is set, drop slides for which
         ``has_penmarks`` returns True.
      4. Write the remaining file names to ``cfg.clean_slide_list``.

    Args:
        cfg: Object exposing ``input_dir``, ``clean_slide_list``,
            ``do_deduplication``, ``do_penmark_check`` and ``hash_threshold``.

    Returns:
        The list of clean slide file names that was written to disk.
    """
    print("Start Preprocessing ...")
    # Sorting makes "which of two duplicates is kept" independent of the
    # arbitrary order in which the file system lists the directory.
    slide_paths = sorted(Path(cfg.input_dir).glob("*.tiff"), key=lambda p: p.name)
    total_found = len(slide_paths)
    print(f"found {total_found} slides")

    start_total = time.time()

    # Step 1: find duplicates using imagehash.
    if cfg.do_deduplication:
        kept_hashes = []
        unique_slides = []
        for slide in slide_paths:
            try:
                # NOTE(review): PIL opens the full-resolution image here; very
                # large slides may be slow or rejected by PIL, in which case
                # they are skipped below -- confirm this is acceptable.
                with Image.open(slide) as img:
                    h = imagehash.average_hash(img.resize((256, 256)))
                # Maximum Hamming distance (in bits) still counted as "same".
                max_distance = (1 - cfg.hash_threshold) * len(h.hash) ** 2
                is_duplicate = any(h - kept < max_distance for kept in kept_hashes)
            except Exception as e:
                print(f" Skipping {slide.name}: {e}")
                continue
            if not is_duplicate:
                kept_hashes.append(h)
                unique_slides.append(slide)
        print(f"After deduplication: {len(unique_slides)}")
        slide_paths = unique_slides  # continue with the deduplicated slides only

    # Step 2: remove pen-marked slides.
    if cfg.do_penmark_check:
        clean_slides = [s.name for s in slide_paths if not has_penmarks(s)]
    else:
        clean_slides = [s.name for s in slide_paths]
    print(f"After Penmark removal: {len(clean_slides)} slides ")

    # Step 3: save all clean slides in a JSON file.
    with open(cfg.clean_slide_list, "w") as f:
        json.dump(clean_slides, f)
    print(f" Saved clean slide list to {cfg.clean_slide_list}")

    # Timing summary. The average is taken over every slide that was found,
    # not over the (smaller) deduplicated list.
    total_time = time.time() - start_total
    avg_time = total_time / total_found if total_found else 0

    print(f"\n Total preprocessing time (dedup + penmark): {total_time:.2f} seconds")
    print(f" Average time per slide: {avg_time:.2f} seconds")

    return clean_slides

    

## Extracting features

In [None]:
def extract_features(cfg, start_idx=9000, end_idx=None):
    """Extract one slide-level vector per slide and save the results.

    For every selected slide the function runs tissue segmentation, patch
    coordinate extraction, patch encoding and slide encoding through the
    Trident ``wsi`` object, converts the resulting ``.h5`` embedding into a
    torch tensor and stores it as ``<output_dir>/train/<stem>_<combo>.pt``.
    Slides whose ``.pt`` file already exists are skipped, so an interrupted
    run can be resumed. Every time the running slide index is a multiple of
    100 the vectors collected so far are flushed to a chunk file in
    ``<output_dir>/combined``; at the end all chunk files of the current
    encoder combination are merged into ``slide_vectors_ALL_<combo>.pt``.

    Args:
        cfg: Object exposing ``preprocessing``, ``clean_slide_list``,
            ``input_dir``, ``output_dir``, ``patch_encoder_name``,
            ``slide_encoder_name``, ``patch_mag``, ``patch_size`` and
            ``device``.
        start_idx: Zero-based position (in the name-sorted slide list) of the
            first slide to process. The default of 9000 keeps the offset that
            used to be hard-coded here; pass 0 to process from the beginning.
        end_idx: Position one past the last slide to process; ``None`` means
            "through the end of the list".

    Returns:
        None. Results are written to disk and progress is printed.
    """
    print("Starting slide vector extraction...")

    # Step 1: determine which slide paths to use.
    if cfg.preprocessing:
        with open(cfg.clean_slide_list, "r") as f:
            slide_names = json.load(f)
        slide_paths = [Path(cfg.input_dir) / name for name in slide_names]
    else:
        slide_paths = list(Path(cfg.input_dir).glob("*.tiff"))

    slide_paths.sort(key=lambda x: x.name)

    # Step 2: create the output directory (it does not exist on the first run).
    save_dir = Path(cfg.output_dir) / "train"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Step 3: load the Trident models.
    seg = segmentation_model_factory("hest", confidence_thresh=0.5)
    patch_encoder = patch_factory(cfg.patch_encoder_name).eval().to(cfg.device)
    slide_encoder = slide_factory(cfg.slide_encoder_name).eval().to(cfg.device)

    # Step 4: folder for the combined .pt files; the encoder combination is
    # part of every file name.
    encoder_combo = f"{cfg.patch_encoder_name}_{cfg.slide_encoder_name}"
    combined_vectors_dir = Path(cfg.output_dir) / "combined"
    combined_vectors_dir.mkdir(parents=True, exist_ok=True)

    # Step 5: loop over the selected slides.
    if end_idx is None:
        end_idx = len(slide_paths)
    selected_paths = slide_paths[start_idx:end_idx]

    slide_vectors = {}   # vectors collected since the last chunk file
    slide_times = []     # per-slide wall-clock time of completed iterations
    saved_count = 0      # number of per-slide .pt files written in this run

    # `idx` is the 1-based position of the slide in the full sorted list.
    for idx, slide_path in enumerate(selected_paths, start_idx + 1):
        out_path = save_dir / f"{slide_path.stem}_{encoder_combo}.pt"
        if out_path.exists():
            # NOTE(review): vectors of skipped slides are not reloaded, so
            # they only reach the combined file if an earlier chunk file
            # already contains them.
            print(f" Skipping {slide_path.name}, already processed.")
            continue

        try:
            start_time = time.time()
            job_dir = save_dir / slide_path.stem
            job_dir.mkdir(parents=True, exist_ok=True)

            print(f" Processing {slide_path.name}...")
            wsi = load_wsi(slide_path, lazy_init=False)

            # Tissue segmentation.
            wsi.segment_tissue(seg, seg.target_mag, job_dir, cfg.device)

            # Extract patch coordinates.
            coords_path = wsi.extract_tissue_coords(
                target_mag=cfg.patch_mag,
                patch_size=cfg.patch_size,
                save_coords=str(job_dir),
            )

            # Extract patch features.
            patch_features_path = wsi.extract_patch_features(
                patch_encoder=patch_encoder,
                coords_path=str(coords_path),
                save_features=str(job_dir),
                device=cfg.device,
                batch_limit=32,
            )

            # Create the embeddings directory and extract slide-level features.
            embeddings_dir = job_dir / "embeddings"
            embeddings_dir.mkdir(parents=True, exist_ok=True)

            wsi.extract_slide_features(
                str(patch_features_path),
                slide_encoder,
                str(embeddings_dir),
                cfg.device
            )

            # Convert the .h5 embedding into a .pt file for use with PyTorch.
            embeddings_h5 = embeddings_dir / f"{slide_path.stem}.h5"
            if embeddings_h5.exists():
                with h5py.File(embeddings_h5, "r") as f:
                    features = f["features"][:]
                tensor = torch.from_numpy(features)
                slide_vectors[slide_path.stem] = tensor
                torch.save(tensor, out_path)
                saved_count += 1
            else:
                print(f" Warning: embeddings .h5 file not found for {slide_path.stem}")

            # Save metadata next to the intermediate files of this slide.
            with open(job_dir / "meta.json", "w") as f:
                json.dump({
                    "slide_name": slide_path.name,
                    "patch_encoder": cfg.patch_encoder_name,
                    "slide_encoder": cfg.slide_encoder_name,
                    "patch_mag": cfg.patch_mag,
                    "patch_size": cfg.patch_size
                }, f)

            elapsed = time.time() - start_time
            slide_times.append(elapsed)
            print(f" Done: {slide_path.name} in {elapsed:.2f} seconds")

            # Report the average over every block of 10 completed slides.
            # Dedicated names are used so the slice bounds are not clobbered.
            if len(slide_times) % 10 == 0:
                avg_time = sum(slide_times[-10:]) / 10
                window_last = len(slide_times)
                window_first = window_last - 9
                print(f" Average time for slides {window_first}-{window_last}: {avg_time:.2f} seconds")

            # Save an intermediate combined .pt file every 100 slides.
            if idx % 100 == 0:
                dict_save_path = combined_vectors_dir / f"slide_vectors_{idx}_{encoder_combo}.pt"
                torch.save(slide_vectors, dict_save_path)
                print(f" Saved intermediate {idx} slide vectors to {dict_save_path}")
                slide_vectors.clear()  # free memory held by the flushed chunk

        except Exception as e:
            # One bad slide must not abort the whole run.
            print(f" Error processing {slide_path.name}: {e}")

    # Final save of any remaining slide vectors.
    if slide_vectors:
        dict_save_path = combined_vectors_dir / f"slide_vectors_final_{encoder_combo}.pt"
        torch.save(slide_vectors, dict_save_path)
        print(f" Saved final slide vectors to {dict_save_path}")

    print(" Combining all intermediate .pt files into one final .pt file...")
    final_slide_vectors = {}
    # Only merge chunk files produced with the current encoder combination;
    # files of other encoders may live in the same directory and would
    # otherwise overwrite entries for the same slide.
    combo_suffix = f"_{encoder_combo}.pt"
    chunk_files = sorted(
        p for p in combined_vectors_dir.glob("slide_vectors_*.pt")
        if p.name.endswith(combo_suffix)
    )
    for pt_file in chunk_files:
        chunk = torch.load(pt_file, map_location="cpu")
        final_slide_vectors.update(chunk)
    final_pt_path = combined_vectors_dir / f"slide_vectors_ALL_{encoder_combo}.pt"
    torch.save(final_slide_vectors, final_pt_path)
    print(f"Final combined .pt file saved at {final_pt_path} with {len(final_slide_vectors)} slides!")

    # Final summary: report what was actually written, not the input size.
    if slide_times:
        total_avg = sum(slide_times) / len(slide_times)
        print(f"\n Extracted {saved_count} slide vectors to {cfg.output_dir}")
        print(f" Final average processing time per slide: {total_avg:.2f} seconds")
    else:
        print(" No slides were processed.")


In [None]:
def main():
    """Run the pipeline: optional preprocessing, then feature extraction."""
    # Initial configuration; `preprocessing` may be switched off just below.
    cfg = Config(
        do_deduplication=False,
        do_penmark_check=False,
        preprocessing=True,
    )

    # With neither check enabled there is nothing for preprocessing to do.
    any_check_requested = cfg.do_deduplication or cfg.do_penmark_check
    if not any_check_requested:
        print(" Neither removing duplicates nor penmarks removal")
        cfg.preprocessing = False

    # Preprocess only when it is still enabled.
    if not cfg.preprocessing:
        print(" Skip preprocessing")
    else:
        preprocess_slides(cfg)

    extract_features(cfg)

# Entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
