In [None]:
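# Generate WSI overlays with attention heatmaps based on patch-level attention scores for visual interpretability.
# NOTE: a Reinhard color-normalization helper (normalize_wsi) is defined below, but the code visible here saves and overlays the un-normalized slide image.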
#================================================================================================================================
import os
import openslide
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from histomicstk.preprocessing.color_normalization import reinhard
from PIL import Image

# === Color Normalization Parameters (TCGA-A2-A3XS-DX1, Amgad et al., 2019) ===
# Two three-element vectors that are handed to the Reinhard normalizer as its
# target mean ('mu') and target standard deviation ('sigma').
cnorm = dict(
    mu=np.array([8.74108109, -0.12440419, 0.0444982]),
    sigma=np.array([0.6135447, 0.10989545, 0.0286032]),
)

def normalize_wsi(wsi_image):
    """Return a Reinhard color-normalized copy of a WSI image.

    The input is converted to RGB, normalized toward the target statistics
    stored in ``cnorm``, limited to the 0-255 range, and returned as a new
    PIL image.
    """
    rgb_pixels = np.array(wsi_image.convert("RGB"))
    target_stats = {"target_mu": cnorm['mu'], "target_sigma": cnorm['sigma']}
    matched = reinhard(rgb_pixels, **target_stats)
    # Clamp before the uint8 cast so out-of-range values cannot wrap around.
    as_uint8 = np.clip(matched, 0, 255).astype(np.uint8)
    return Image.fromarray(as_uint8)

def save_wsi(wsi_path, output_path, slide_id, display_level=2):
    """Save a downsampled copy of the original WSI as a PNG.

    Args:
        wsi_path: Path of the slide file to open with OpenSlide.
        output_path: Directory that receives the PNG; created if missing.
        slide_id: Identifier used in the output file name
            ``{slide_id}_original.png``.
        display_level: Requested pyramid level; clamped to the slide's last
            available level.
    """
    # Use the slide as a context manager so its handle is released even when
    # reading fails; this runs once per slide, so leaked handles accumulate.
    with openslide.OpenSlide(wsi_path) as slide:
        display_level = min(display_level, slide.level_count - 1)
        downsampled = slide.read_region((0, 0), display_level, slide.level_dimensions[display_level])
    os.makedirs(output_path, exist_ok=True)
    downsampled.convert("RGB").save(os.path.join(output_path, f"{slide_id}_original.png"))

def highlight_attention_patches(wsi_path, attention_csv, output_path, slide_id, display_level=2, tile_size=256, threshold=0.5):
    """Overlay attention patches on the WSI image and save the heatmap.

    Writes ``{slide_id}_heatmap.png`` into ``output_path`` and appends the
    slide's mean attention score to ``slide_attention_scores.csv`` in the
    same directory.

    Args:
        wsi_path: Path of the slide file to open with OpenSlide.
        attention_csv: CSV with ``attention_score``, ``col`` and ``row``
            columns. ``col``/``row`` are multiplied by ``tile_size`` and
            divided by the level downsample, so they are presumably tile
            grid indices at full resolution (verify against the producer).
        output_path: Output directory; created if missing.
        slide_id: Identifier used in the title, file name and summary row.
        display_level: Requested pyramid level; clamped to the slide's last
            available level.
        tile_size: Tile edge length used to convert grid indices to pixels.
        threshold: Only rows with ``attention_score`` strictly greater than
            this value are drawn. The mean and the colour scale still use
            every row.
    """
    df = pd.read_csv(attention_csv)
    df_filtered = df[df["attention_score"] > threshold]

    # Hold the slide only as long as it is needed, and release the handle
    # even if reading the region fails.
    with openslide.OpenSlide(wsi_path) as slide:
        display_level = min(display_level, slide.level_count - 1)
        downsampled = slide.read_region((0, 0), display_level, slide.level_dimensions[display_level])
        scale = slide.level_downsamples[display_level]

    norm = mcolors.Normalize(vmin=-0.1, vmax=df["attention_score"].max())
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in both old and new releases.
    cmap = plt.get_cmap("jet")

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.imshow(downsampled)
        ax.set_title(f"Attention Heatmap – {slide_id}", fontsize=18)

        # Tile edge length expressed in pixels of the displayed level.
        side = tile_size / scale
        tiles = zip(df_filtered["col"], df_filtered["row"], df_filtered["attention_score"])
        for col, row, score in tiles:
            x = col * tile_size / scale
            y = row * tile_size / scale
            rect = patches.Rectangle((x, y), side, side, linewidth=0, facecolor=cmap(norm(score)), alpha=0.7)
            ax.add_patch(rect)

        # Stand-alone mappable so the colorbar reflects the patch colours
        # rather than the underlying slide image.
        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="Attention Score")
        os.makedirs(output_path, exist_ok=True)
        fig.savefig(os.path.join(output_path, f"{slide_id}_heatmap.png"), dpi=600)
    finally:
        # Close this specific figure, also on failure, so figures do not
        # pile up in memory when many slides are processed.
        plt.close(fig)

    # Save mean attention score
    mean_score = df["attention_score"].mean()
    summary_path = os.path.join(output_path, "slide_attention_scores.csv")
    score_entry = pd.DataFrame([{"Slide_ID": slide_id, "Attention_Mean": mean_score}])
    if os.path.exists(summary_path):
        existing = pd.read_csv(summary_path)
        combined = pd.concat([existing, score_entry], ignore_index=True)
        combined.to_csv(summary_path, index=False)
    else:
        score_entry.to_csv(summary_path, index=False)
    print(f" Processed {slide_id} | Mean Attention: {mean_score:.4f}")

# === Paths Configuration ===
base_wsi_path = "/path/to/wsi/files/"
attention_csv_dir = "/path/to/attention/csvs/"
output_dir = "/path/to/output/heatmaps/"

# === Run Heatmap Generation for All Slides in Folder ===
for file in os.listdir(attention_csv_dir):
    # Only per-slide attention CSVs carry this suffix; ignore everything else.
    if not file.endswith("_resnet50Features_dict.csv"):
        continue
    slide_id = file.replace("_resnet50Features_dict.csv", "")

    # Probe each supported slide extension in order; the first hit wins.
    for ext in (".svs", ".ndpi"):
        candidate_path = os.path.join(base_wsi_path, slide_id + ext)
        if os.path.exists(candidate_path):
            wsi_path = candidate_path
            break
    else:
        wsi_path = None

    if not wsi_path:
        print(f"Missing WSI file for {slide_id}")
        continue

    # Write the plain downsampled slide first, then the attention overlay.
    csv_path = os.path.join(attention_csv_dir, file)
    save_wsi(wsi_path, output_dir, slide_id)
    highlight_attention_patches(wsi_path, csv_path, output_dir, slide_id, threshold=0.0001)
