# Stained area per marker (before/after normalization)
This notebook computes how much of each marker channel is 'stained' (the percentage of pixels whose raw value exceeds a threshold),
and compares per-image marker means before and after the **same control-marker regression normalization** used in preprocessing.

Output (per dataset): one table with **mean / std / median / quartiles of stained area** for each marker, plus Spearman correlation matrices of the per-image marker means (raw and normalized).


In [None]:
!pip install imagecodecs

In [None]:
import os, glob, gc, json
import numpy as np
import pandas as pd
import imagecodecs
import tifffile
import zipfile
from sklearn.linear_model import SGDRegressor

import matplotlib.pyplot as plt
import seaborn as sns


# Dataset identifiers. Each name is used both as the folder name under
# /kaggle/input and as the suffix of the output PNG/CSV file names.
DATASET_A_NAME = "jackson2020"
DATASET_A_PATH = f"/kaggle/input/{DATASET_A_NAME}"   # root searched recursively for TIFFs

DATASET_B_NAME = "danenberg2022"
DATASET_B_PATH = f"/kaggle/input/{DATASET_B_NAME}"  # root searched recursively for TIFFs

# Channel names per dataset. Markers are addressed by list position, and the
# list length is checked against each image's channel count, so the order here
# must match the TIFF channel order.
# NOTE(review): both lists are alphabetically sorted -- confirm that the TIFF
# channels really are stored in this order.
MARKER_NAMES_A = ['c_erb_b_2_her2', 'c_myc', 'carbonic_anhydrase_ix', 'cd20', 'cd3', 'cd44', 'cd45', 'cd68', 'cleaved_parp', 'cytokeratin_19', 'cytokeratin_5', 'cytokeratin_7', 'cytokeratin_8_18', 'dna1', 'dna2', 'e_cadherin_p_cadherin', 'egfr', 'fibronectin', 'gata3', 'histone_h3', 'histone_h3_phospho', 'histone_h3_trimethylate', 'keratin_14_krt14', 'ki_67', 'm_tor', 'p53', 'pan_cytokeratin', 'progesterone_receptor_a_b', 'rabbit_ig_g_h_l', 's6', 'slug', 'sma', 'twist', 'v_wf', 'vimentin']
MARKER_NAMES_B = ['beta_2_microglobulin', 'c_erb_b_2_her2_3b5', 'c_erb_b_2_her2_d8f12', 'caveolin_1', 'cd11c', 'cd134', 'cd140b_pdgf_receptor_beta', 'cd15', 'cd16', 'cd163', 'cd20', 'cd278_icos', 'cd279_pd_1', 'cd3', 'cd31_v_wf', 'cd38', 'cd4', 'cd45', 'cd45ra', 'cd57', 'cd68', 'cd8a', 'cleaved_caspase3', 'cxcl12_sdf_1', 'cytokeratin_5', 'cytokeratin_8_18', 'dna1', 'dna2', 'estrogen_receptor_alpha', 'foxp3', 'fsp1', 'gitr_tnfrsf18', 'histone_h3', 'hla_abc', 'hla_dr', 'ki_67', 'pan_cytokeratin', 'podoplanin', 'sma']


# NOTE(review): COFACTOR is not referenced by any code visible in this file;
# presumably an arcsinh cofactor left over from preprocessing -- confirm before removing.
COFACTOR = 5.0
CHUNK_SIZE = 100_000         # pixel rows processed per chunk during regression fitting & area computation
STAIN_THRESHOLD_RAW = 0.0        # a pixel counts as "stained" when raw > threshold (the visible code applies this to raw values only)
RIDGE_LAMBDA = 1e-6  # added to the diagonal of X^T X before solving the normal equations

# Channels used as regressors; every other marker is predicted from these.
CONTROL_MARKER_NAMES = ['dna1', 'dna2', 'histone_h3'] 

In [None]:
def list_tiff_files(root: str):
    """Return the sorted paths of all ``*.tif`` / ``*.tiff`` files below *root*.

    The search is recursive. Matches that are not regular files (for example
    a directory whose name ends in ``.tif``) are dropped.
    """
    matches = []
    for pattern in ("*.tif", "*.tiff"):
        matches.extend(glob.glob(os.path.join(root, "**", pattern), recursive=True))
    matches.sort()
    return list(filter(os.path.isfile, matches))

def safe_imread(path: str):
    """Read a TIFF with ``tifffile.imread`` and report failures with the file path.

    Args:
        path: location of the TIFF file to read.

    Returns:
        Whatever ``tifffile.imread(path)`` returns.

    Raises:
        RuntimeError: if reading fails for any reason. The message contains
            the path and the original error text, and the original exception
            is attached as ``__cause__``.
    """
    import tifffile
    try:
        return tifffile.imread(path)
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying reader error
        # as the direct cause instead of "during handling ... another exception".
        raise RuntimeError(f"Failed to read TIFF: {path}\n{e}") from e

def to_hwc(img: np.ndarray) -> np.ndarray:
    """
    Return *img* as a 3-D float32 array in (height, width, channels) layout.

    Accepted inputs:
      - 2-D (H, W): a trailing channel axis of length 1 is appended
      - 3-D channels-last (H, W, C): returned unchanged apart from the dtype
      - 3-D channels-first (C, H, W): the first axis is moved to the end

    Channels-first is detected heuristically: the first axis must be strictly
    smaller than both of the other axes.

    Raises:
      ValueError: if the input is neither 2-D nor 3-D.
    """
    arr = img[..., None] if img.ndim == 2 else img
    if arr.ndim != 3:
        raise ValueError(f"Expected 2D or 3D image, got shape={arr.shape}")

    first, second, third = arr.shape
    if first < min(second, third):
        # Leading axis is the smallest one -> assume it holds the channels.
        arr = np.moveaxis(arr, 0, -1)

    return arr.astype(np.float32, copy=False)

# ----------------------------
# Marker indexing / validation
# ----------------------------

def make_index_map(marker_names):
    """Map every marker name to its position in *marker_names*.

    If a name occurs more than once, the position of its last occurrence wins.
    """
    positions = {}
    for pos, marker in enumerate(marker_names):
        positions[marker] = pos
    return positions

def validate_controls(marker_names, control_marker_names):
    """Resolve the control markers to channel indices.

    Returns:
        (ctrl_idx, name_to_idx): the channel index of each control marker, in
        the order given by *control_marker_names*, and the full name -> index
        mapping for *marker_names*.

    Raises:
        ValueError: if any control marker is absent from *marker_names*. The
            message lists the missing names and the first 25 available markers.
    """
    name_to_idx = make_index_map(marker_names)

    missing = [ctrl for ctrl in control_marker_names if ctrl not in name_to_idx]
    if not missing:
        return [name_to_idx[ctrl] for ctrl in control_marker_names], name_to_idx

    raise ValueError(
        f"Missing control markers: {missing}\n"
        f"Available markers (first 25): {marker_names[:25]}"
    )

# ----------------------------
# Core computation
# ----------------------------

def fit_controls_to_all_markers_multitarget(
    files,
    marker_names,
    control_marker_names,
    chunk_size=100_000,
    ridge_lambda=1e-6,
):
    """
    Fit multi-target linear regression in raw space:
      Y_nonctrl ≈ [1, X_controls] @ B

    The normal-equation statistics (X^T X and X^T Y) are accumulated in
    float64 over every pixel of every file, one chunk at a time, and the
    ridge-regularised system is solved once at the end.

    Args:
      files: paths of the TIFF images to fit on.
      marker_names: channel names; the length must equal each image's channel count.
      control_marker_names: names of the channels used as regressors.
      chunk_size: number of pixel rows processed per chunk.
      ridge_lambda: value added to the diagonal of X^T X (intercept row included).

    Returns:
      ctrl_idx: list[int] indices of control markers in channel axis
      coef_full: (n_ctrl, C) float32 -- zeros for control columns
      intercept_full: (C,) float32   -- zeros for control columns

    Raises:
      ValueError: if a control marker is missing from marker_names, or if an
        image's channel count differs from len(marker_names).
    """
    ctrl_idx, _ = validate_controls(marker_names, control_marker_names)
    C = len(marker_names)
    k = len(ctrl_idx)

    ctrl_set = set(ctrl_idx)
    nonctrl_idx = [i for i in range(C) if i not in ctrl_set]
    m = len(nonctrl_idx)

    # normal equations stats for X_aug = [1, X_controls]
    XtX = np.zeros((k + 1, k + 1), dtype=np.float64)
    XtY = np.zeros((k + 1, m), dtype=np.float64)

    n_files = len(files)
    log_every = max(1, n_files // 5)  # roughly five progress lines per run

    for i, f in enumerate(files):
        if i % log_every == 0 or i == n_files - 1:
            print(f"[fit] {i+1}/{n_files}: {os.path.basename(f)}")

        img = to_hwc(safe_imread(f))
        C_img = img.shape[2]
        if C_img != C:
            raise ValueError(
                f"Channel mismatch in {f}:\n"
                f"  image channels={C_img}\n"
                f"  marker_names={C}"
            )

        # One row per pixel, one column per channel.
        flat = img.reshape(-1, C_img)
        N = flat.shape[0]

        for s in range(0, N, chunk_size):
            e = min(N, s + chunk_size)
            chunk = flat[s:e]  # (b, C)

            X = chunk[:, ctrl_idx].astype(np.float64, copy=False)  # (b, k)
            ones = np.ones((X.shape[0], 1), dtype=np.float64)
            X_aug = np.concatenate([ones, X], axis=1)              # (b, k+1)

            Y = chunk[:, nonctrl_idx].astype(np.float64, copy=False)  # (b, m)

            XtX += X_aug.T @ X_aug
            XtY += X_aug.T @ Y

        # Release the image before loading the next one to keep peak memory low.
        del img, flat
        gc.collect()

    # The ridge term keeps the system solvable even when XtX is singular
    # (e.g. constant or collinear control channels).
    XtX_reg = XtX + ridge_lambda * np.eye(k + 1, dtype=np.float64)
    B = np.linalg.solve(XtX_reg, XtY)  # (k+1, m)

    intercept = B[0, :]   # (m,)
    coef = B[1:, :]       # (k, m)

    # Expand to full marker space; controls get 0 correction
    coef_full = np.zeros((k, C), dtype=np.float32)
    intercept_full = np.zeros((C,), dtype=np.float32)

    coef_full[:, nonctrl_idx] = coef.astype(np.float32)
    intercept_full[nonctrl_idx] = intercept.astype(np.float32)

    return ctrl_idx, coef_full, intercept_full

def compute_means_raw_and_norm_and_raw_area(
    files,
    marker_names,
    ctrl_idx,
    coef_full,
    intercept_full,
    chunk_size=100_000,
    stain_threshold_raw=0.0,
):
    """
    For ALL markers:
      - raw_mean: mean(raw) per image
      - norm_mean: mean(resid) per image, resid = raw - pred_from_controls
      - raw_area_%: % pixels where raw > stain_threshold_raw
                    (computed BEFORE arcsinh and BEFORE normalization)

    Args:
      files: paths of the TIFF images to summarise.
      marker_names: channel names; the length must equal each image's channel count.
      ctrl_idx: channel indices of the control markers.
      coef_full: (n_ctrl, C) regression coefficients.
      intercept_full: (C,) regression intercepts.
      chunk_size: number of pixel rows processed per chunk.
      stain_threshold_raw: a pixel counts as stained when raw > this value.

    Returns:
      (raw_mean_df, norm_mean_df, raw_area_df): three DataFrames with one row
      per file (in the order of `files`) and one column per marker. All three
      are empty, but keep the marker columns, when `files` is empty.

    Raises:
      ValueError: if an image's channel count differs from len(marker_names).
    """
    C = len(marker_names)

    raw_means = []
    norm_means = []
    raw_areas = []

    n_files = len(files)
    log_every = max(1, n_files // 5)  # roughly five progress lines per run

    for i, f in enumerate(files):
        if i % log_every == 0 or i == n_files - 1:
            print(f"[summaries] {i+1}/{n_files}: {os.path.basename(f)}")

        img = to_hwc(safe_imread(f))
        C_img = img.shape[2]
        if C_img != C:
            raise ValueError(
                f"Channel mismatch in {f}:\n"
                f"  image channels={C_img}\n"
                f"  marker_names={C}"
            )

        # One row per pixel, one column per channel.
        flat = img.reshape(-1, C_img)
        N = flat.shape[0]

        sum_raw = np.zeros((C,), dtype=np.float64)
        sum_norm = np.zeros((C,), dtype=np.float64)
        cnt_stain_raw = np.zeros((C,), dtype=np.int64)

        for s in range(0, N, chunk_size):
            e = min(N, s + chunk_size)
            chunk = flat[s:e]  # (b, C)

            # RAW mean (no arcsinh). The chunk is float32, so reduce in
            # float64; otherwise each chunk-level column sum is accumulated
            # in float32 and loses precision.
            sum_raw += chunk.sum(axis=0, dtype=np.float64)

            # RAW stained area (no arcsinh, no normalization)
            cnt_stain_raw += (chunk > stain_threshold_raw).sum(axis=0)

            # Normalization residuals (raw space)
            X = chunk[:, ctrl_idx]                          # (b, k)
            pred = (X @ coef_full) + intercept_full         # (b, C)
            resid = chunk - pred                            # (b, C)
            sum_norm += resid.sum(axis=0, dtype=np.float64)

        raw_means.append(sum_raw / N)
        norm_means.append(sum_norm / N)
        raw_areas.append((cnt_stain_raw / N) * 100.0)

        # Release the image before loading the next one to keep peak memory low.
        del img, flat
        gc.collect()

    def _to_frame(rows):
        # np.vstack rejects an empty list, so build a (0, C) table explicitly.
        data = np.vstack(rows) if rows else np.empty((0, C), dtype=np.float64)
        return pd.DataFrame(data, columns=marker_names)

    return _to_frame(raw_means), _to_frame(norm_means), _to_frame(raw_areas)

# ----------------------------
# Reporting
# ----------------------------

def plot_spearman_corr_matrix(df_means: pd.DataFrame, title: str, out_png: str):
    """Draw the Spearman correlation matrix of *df_means* as a heatmap.

    The figure is saved to *out_png* at 200 dpi and then closed. The
    correlation matrix itself is returned.
    """
    corr = df_means.corr(method="spearman")
    n = corr.shape[0]

    # Scale the square figure with the panel size, clamped to 7..22 inches.
    side = max(7.0, min(22.0, 0.40 * n))
    fig, ax = plt.subplots(figsize=(side, side))
    heat = ax.imshow(corr.values, cmap="coolwarm", vmin=-1, vmax=1)
    ax.set_title(title)

    if n > 60:
        # Too many markers to label each one: keep about 40 evenly spaced ticks.
        step = int(np.ceil(n / 40))
        ticks = list(range(0, n, step))
        x_labels = [corr.columns[i] for i in ticks]
        y_labels = [corr.index[i] for i in ticks]
        label_size = 6
    else:
        ticks = range(n)
        x_labels = corr.columns.tolist()
        y_labels = corr.index.tolist()
        label_size = max(5, int(14 - n / 5))

    ax.set_xticks(ticks)
    ax.set_xticklabels(x_labels, rotation=90, fontsize=label_size)
    ax.set_yticks(ticks)
    ax.set_yticklabels(y_labels, fontsize=label_size)

    fig.colorbar(heat, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    fig.savefig(out_png, dpi=200)
    plt.close(fig)
    print(f"Saved: {out_png}")
    return corr

def summarize_raw_area(raw_area_df: pd.DataFrame):
    """Summarise the per-image stained-area percentages of every marker.

    Each column of *raw_area_df* yields one output row holding the number of
    non-NaN images and the NaN-aware mean, sample standard deviation (ddof=1),
    median and 25th / 75th percentiles. A marker with a single valid image
    gets a standard deviation of 0.0; a marker with none gets NaN throughout.
    Rows are sorted by marker name.
    """

    def _describe(marker):
        values = raw_area_df[marker].to_numpy(dtype=np.float64)
        n_valid = int(np.count_nonzero(~np.isnan(values)))

        if n_valid == 0:
            mean = std = median = q25 = q75 = np.nan
        else:
            mean = float(np.nanmean(values))
            median = float(np.nanmedian(values))
            q25 = float(np.nanpercentile(values, 25))
            q75 = float(np.nanpercentile(values, 75))
            # A sample std needs at least two observations.
            std = float(np.nanstd(values, ddof=1)) if n_valid > 1 else 0.0

        return {
            "marker": marker,
            "n_images": n_valid,
            "raw_area_mean_%": mean,
            "raw_area_std_%": std,
            "raw_area_median_%": median,
            "raw_area_q25_%": q25,
            "raw_area_q75_%": q75,
        }

    records = [_describe(marker) for marker in raw_area_df.columns]
    table = pd.DataFrame(records)
    return table.sort_values("marker").reset_index(drop=True)

def run_dataset(dataset_path, dataset_name, marker_names):
    """Run the whole analysis for one dataset and return its result tables.

    Steps, in order:
      1. collect the TIFF files under *dataset_path*,
      2. fit the control-marker regression,
      3. compute per-image raw means, normalized means and raw stained area,
      4. plot and save the Spearman correlation matrices (PNG and CSV),
      5. save the stained-area summary table (CSV).

    Output files are written to the current working directory, with
    *dataset_name* as the file-name suffix.

    Raises:
        ValueError: if no TIFF file is found under *dataset_path*.
    """
    files = list_tiff_files(dataset_path)
    if not files:
        raise ValueError(f"No TIFF files found under: {dataset_path}")

    print(f"\n=== {dataset_name} ===")
    print(f"Found {len(files)} TIFFs")
    print(f"Markers: {len(marker_names)}")

    # Step 2: regression of every non-control marker on the control markers.
    model = fit_controls_to_all_markers_multitarget(
        files=files,
        marker_names=marker_names,
        control_marker_names=CONTROL_MARKER_NAMES,
        chunk_size=CHUNK_SIZE,
        ridge_lambda=RIDGE_LAMBDA,
    )
    ctrl_idx, coef_full, intercept_full = model

    # Step 3: per-image tables, stored under the keys of the returned dict.
    summaries = compute_means_raw_and_norm_and_raw_area(
        files=files,
        marker_names=marker_names,
        ctrl_idx=ctrl_idx,
        coef_full=coef_full,
        intercept_full=intercept_full,
        chunk_size=CHUNK_SIZE,
        stain_threshold_raw=STAIN_THRESHOLD_RAW,
    )
    out = dict(zip(("raw_mean_df", "norm_mean_df", "raw_area_df"), summaries))

    # Step 4: Spearman correlations across images, before and after normalization.
    out["corr_raw"] = plot_spearman_corr_matrix(
        out["raw_mean_df"],
        title=f"{dataset_name} Spearman (RAW means, all markers)",
        out_png=f"spearman_raw_all_{dataset_name}.png",
    )
    out["corr_norm"] = plot_spearman_corr_matrix(
        out["norm_mean_df"],
        title=f"{dataset_name} Spearman (NORMALIZED means, all markers)",
        out_png=f"spearman_norm_all_{dataset_name}.png",
    )

    out["corr_raw"].to_csv(f"spearman_raw_all_{dataset_name}.csv")
    out["corr_norm"].to_csv(f"spearman_norm_all_{dataset_name}.csv")
    print(f"Saved: spearman_raw_all_{dataset_name}.csv")
    print(f"Saved: spearman_norm_all_{dataset_name}.csv")

    # Step 5: stained-area statistics on raw values.
    out["area_stats"] = summarize_raw_area(out["raw_area_df"])
    out["area_stats"].to_csv(f"raw_area_stats_all_{dataset_name}.csv", index=False)
    print(f"Saved: raw_area_stats_all_{dataset_name}.csv")

    return out

def main():
    """Process dataset A, then dataset B, and print the head of each area table.

    Raises:
        ValueError: if either marker-name list is empty.
    """
    # Both marker lists must be filled in before anything is read from disk.
    if not (MARKER_NAMES_A and MARKER_NAMES_B):
        raise ValueError(
            "Please paste MARKER_NAMES_A and MARKER_NAMES_B lists (must match TIFF channel order)."
        )

    datasets = (
        (DATASET_A_PATH, DATASET_A_NAME, MARKER_NAMES_A),
        (DATASET_B_PATH, DATASET_B_NAME, MARKER_NAMES_B),
    )
    results = [run_dataset(path, name, markers) for path, name, markers in datasets]

    # Quick look at the first ten rows of each stained-area summary.
    print("\n--- Example outputs ---")
    for result in results:
        print(result["area_stats"].head(10).to_string(index=False))

# Entry point: run the analysis only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()