In [None]:
#!/usr/bin/env python3
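# Description:
#   Processes a flat directory of WSI feature files (ResNet50)
#   to detect distinct tissue sections by running DBSCAN on the patch
#   coordinates parsed from the patch names (the feature vectors
#   themselves are not used for clustering).
#   Spatially isolated patches are removed first using local neighbor density.
#   Outputs a CSV and plot for each WSI with section IDs.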
#
# Input:
#   - FEATURE_DIR: Directory containing .pickle files with patch features
#
# Output:
#   - CSV_DIR: CSV files listing patch coordinates and section IDs
#   - PLOT_DIR: PNG plots of clustered tissue sections

# === IMPORTS ===
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Safe for multiprocessing
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
from multiprocessing import Pool, cpu_count

# === CONFIG ===
# Input directory: scanned for per-slide ".pickle" files.
FEATURE_DIR = "data/features/features_from_resnet50"

# Output directories: one CSV and one PNG per slide.
CSV_DIR  = "data/section_assignment_csvs"
PLOT_DIR = "data/section_assignment_plots"

# DBSCAN parameters (passed as eps / min_samples). MIN_SAMPLES is also the
# minimum number of patches a slide must keep after noise filtering.
EPS, MIN_SAMPLES = 10, 50

# Noise-filter defaults for filter_noise(): a patch is kept when at least
# MIN_NEIGHBORS points lie within NOISE_RADIUS of it. Both distances are in
# the same units as the coordinates parsed from the patch names.
NOISE_RADIUS, MIN_NEIGHBORS = 400, 3

# Create the output directories up front so workers can write immediately.
for _out_dir in (CSV_DIR, PLOT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# === helpers ===
def extract_coords(patch_name):
    """Parse the two integer coordinates that follow the 'tile' token.

    The name is stripped of every ".png" occurrence and split on "_"; the
    two fields immediately after the first 'tile' field are converted to
    int. process_slide() stores them as (row, col).

    Args:
        patch_name: Patch file name, e.g. "slide_tile_12_34.png".

    Returns:
        [first, second] as a list of two ints, or None when there is no
        'tile' field, fewer than two fields follow it, or either field is
        not an integer.
    """
    parts = patch_name.replace(".png", "").split("_")
    try:
        idx = parts.index('tile')  # ValueError when 'tile' is absent
        return [int(parts[idx + 1]), int(parts[idx + 2])]
    except (ValueError, IndexError):
        # Only "malformed name" failures are swallowed; anything else
        # (e.g. KeyboardInterrupt) propagates, unlike a bare except.
        return None

def filter_noise(coords, radius=NOISE_RADIUS, min_neighbors=MIN_NEIGHBORS):
    """Drop spatially isolated points based on local neighbor count.

    Args:
        coords: NumPy array of point coordinates, one point per row (it is
            indexed with a boolean mask on return).
        radius: Search radius around each point.
        min_neighbors: Minimum number of points that must fall within
            `radius` for a point to be kept.

    Returns:
        (kept_coords, keep_mask): the retained rows of `coords` and the
        boolean mask over the input rows that selects them.
    """
    index = NearestNeighbors(radius=radius)
    index.fit(coords)

    # One query per point keeps peak memory small even when the radius
    # covers most of the slide.
    # NOTE(review): each point is queried explicitly, so per the
    # scikit-learn docs its own entry should be included in the count —
    # confirm that min_neighbors is meant to include the point itself.
    neighbor_counts = []
    for point in coords:
        hits = index.radius_neighbors([point], return_distance=False)
        neighbor_counts.append(len(hits[0]))

    keep_mask = np.array(neighbor_counts) >= min_neighbors
    return coords[keep_mask], keep_mask

# === slide‐level processing ===
def process_slide(file):
    """Cluster one slide's patches into tissue sections; save a CSV and a plot.

    Steps: load the slide's pickle from FEATURE_DIR, parse coordinates from
    its keys with extract_coords(), drop isolated patches with
    filter_noise(), run DBSCAN(eps=EPS, min_samples=MIN_SAMPLES) on the
    remaining coordinates, then write `<slide_id>_sections.csv` to CSV_DIR
    and `<slide_id>_sections.png` to PLOT_DIR.

    Args:
        file: File name (not a full path) of the slide's pickle inside
            FEATURE_DIR. The suffix "_resnet50Features_dict.pickle" is
            removed to form the slide ID.

    Returns:
        None. Slides with no parsable coordinates or fewer than MIN_SAMPLES
        patches after filtering are skipped with a message. Any exception is
        caught and printed so one bad slide does not abort the whole pool.
    """
    slide_id    = file.replace("_resnet50Features_dict.pickle", "")
    pickle_path = os.path.join(FEATURE_DIR, file)
    csv_out     = os.path.join(CSV_DIR,    f"{slide_id}_sections.csv")
    plot_out    = os.path.join(PLOT_DIR,   f"{slide_id}_sections.png")

    try:
        # NOTE(security): pickle.load can execute arbitrary code; only run
        # this on feature files produced by a trusted pipeline.
        with open(pickle_path, "rb") as f:
            patch_dict = pickle.load(f)

        # Only the keys (patch names) are used; keep names aligned with coords.
        coords, names = [], []
        for pname in patch_dict:
            c = extract_coords(pname)
            if c is not None:
                coords.append(c)
                names.append(pname)

        if not coords:
            print(f"No valid coords in {slide_id}")
            return

        coords_arr, mask = filter_noise(np.array(coords))
        names_filt       = np.array(names)[mask]

        # DBSCAN cannot form any cluster with fewer points than min_samples.
        if len(coords_arr) < MIN_SAMPLES:
            print(f"Too few valid patches in {slide_id}")
            return

        db      = DBSCAN(eps=EPS, min_samples=MIN_SAMPLES).fit(coords_arr)
        labels  = db.labels_
        # Label -1 marks DBSCAN noise and is not counted as a section.
        n_sect  = len(set(labels)) - (1 if -1 in labels else 0)

        # save CSV (noise patches are included with section_id == -1)
        df = pd.DataFrame({
            "patch_name": names_filt,
            "row":        coords_arr[:,0],
            "col":        coords_arr[:,1],
            "section_id": labels
        })
        df.to_csv(csv_out, index=False)

        # save plot; the figure is closed even if drawing or saving fails so
        # long-lived pool workers do not accumulate open figures
        fig = plt.figure(figsize=(12, 8))
        try:
            plt.scatter(coords_arr[:,1], coords_arr[:,0], c=labels, cmap='tab10', s=10)
            plt.gca().invert_yaxis()
            plt.title(f"{slide_id} | Sections: {n_sect} | EPS={EPS}, MIN={MIN_SAMPLES}", fontsize=30)
            plt.xlabel("Column", fontsize=29)
            plt.ylabel("Row",    fontsize=29)
            plt.xticks(fontsize=24)
            plt.yticks(fontsize=24)
            plt.tight_layout()
            plt.savefig(plot_out, dpi=600)
        finally:
            plt.close(fig)

        print(f"{slide_id} | Sections: {n_sect}")

    except Exception as e:
        # Worker boundary: report and continue with the next slide.
        print(f" Error processing {slide_id}: {e}")

# === gather and run in parallel ===
if __name__ == "__main__":
    # Every ".pickle" entry in the flat feature directory is one slide.
    all_files = list(
        filter(lambda name: name.endswith(".pickle"), os.listdir(FEATURE_DIR))
    )

    print(f"Starting parallel section detection on {len(all_files)} slides using {cpu_count()} cores...")

    # One worker per core; process_slide reports its own errors, so the
    # results of map() are not inspected.
    with Pool(processes=cpu_count()) as pool:
        pool.map(process_slide, all_files)

    print("\n All tissue sections processed and saved.")


In [None]:
# === Section-Wise Patch Saving for a Single Class ===
# For each slide:
# 1. Load full patch features (ResNet50 embeddings).
# 2. Load corresponding section IDs from CSV (via DBSCAN clustering).
# 3. Split patches by section and save each section as a new .pickle file.
# 4. Append metadata to a summary CSV.
# ===========================================================================
import os
import pickle
import pandas as pd

# === Config ===
# Inputs: per-slide feature pickles and the per-slide section CSVs.
ORIG_PICKLE_DIR = "/data/features/features_from_resnet50/..."
SECTION_CSV_DIR = "/data/features/section_assignment_csvs/..."

# Outputs: one pickle per (slide, section) plus a summary table.
# NOTE(review): OUTPUT_PICKLE_DIR is relative while the other paths are
# absolute — confirm this is intended before running from another cwd.
OUTPUT_PICKLE_DIR = "data/features/section_patches/..."
SUMMARY_CSV = "/data/features/section_patch_summary.csv"

# Accumulates one metadata dict per saved section pickle.
summary = list()

os.makedirs(OUTPUT_PICKLE_DIR, exist_ok=True)

# === Main Loop ===
# For every slide pickle: look up its section CSV, split the patch dict by
# section_id, write one pickle per section, and record a summary row.
for file in os.listdir(ORIG_PICKLE_DIR):
    if not file.endswith(".pickle"):
        continue

    # derive slide ID and paths
    slide_id = file.replace("_resnet50Features_dict.pickle", "")
    pickle_path = os.path.join(ORIG_PICKLE_DIR, file)
    section_csv_path = os.path.join(SECTION_CSV_DIR, f"{slide_id}_sections.csv")

    # Slides without a section CSV are reported and skipped.
    if not os.path.exists(section_csv_path):
        print(f"⚠️ Missing section CSV for {slide_id}")
        continue

    # Load full patch dict
    # NOTE(security): pickle.load can execute arbitrary code; only use it on
    # feature files produced by a trusted pipeline.
    with open(pickle_path, "rb") as f:
        patch_dict = pickle.load(f)

    # Load section mapping (needs "patch_name" and "section_id" columns)
    df_sec = pd.read_csv(section_csv_path)

    # Group patches by section
    # NOTE(review): if the CSV carries the DBSCAN noise label (-1), those
    # patches are saved as their own "section-1" — confirm this is intended.
    for section_id, group in df_sec.groupby("section_id"):
        # Iterate the name column directly: iterrows() builds a Series per
        # row and may coerce dtypes. Names missing from the pickle are skipped.
        section_dict = {
            patch_name: patch_dict[patch_name]
            for patch_name in group["patch_name"]
            if patch_name in patch_dict
        }

        new_slide_id = f"{slide_id}_section{section_id}"
        output_path = os.path.join(OUTPUT_PICKLE_DIR, f"{new_slide_id}.pickle")

        # Save section‐level pickle
        with open(output_path, "wb") as f_out:
            pickle.dump(section_dict, f_out)

        # class_label is fixed: this cell is run for a single class at a time.
        summary.append({
            "new_slide_id": new_slide_id,
            "original_slide_id": slide_id,
            "section_id": section_id,
            "class_label": "normal",
            "num_patches": len(section_dict)
        })

    print(f"✅ {slide_id} → {df_sec['section_id'].nunique()} sections saved")

# Save summary CSV
# Columns are listed explicitly so the file still gets a header row when no
# section was saved; a header-less empty CSV cannot be read by pd.read_csv.
df_summary = pd.DataFrame(
    summary,
    columns=[
        "new_slide_id",
        "original_slide_id",
        "section_id",
        "class_label",
        "num_patches",
    ],
)
df_summary.to_csv(SUMMARY_CSV, index=False)
print(f"\n📄 Summary saved to: {SUMMARY_CSV}")


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

# --- Paths ---
INPUT_CSV = "/data/features/section_patch_summary.csv"
# NOTE: this rebinds SUMMARY_CSV, which the section-splitting cell also
# defines (with a different path).
SUMMARY_CSV = "/data/features/expanded_wsi_summary.csv"
OUTPUT_PLOT = "/data/features/expanded_wsi_distribution.png"

# --- Load section summary ---
df = pd.read_csv(INPUT_CSV)

# Display names for the class labels, keyed by the label written as a string.
class_names = {"2": "WD", "3": "MD", "4": "PD"}
class_order = [2, 3, 4]


def _label_key(label):
    """Return a class label as a string key, so 2, 2.0 and "2" all give "2"."""
    try:
        as_float = float(label)
    except (TypeError, ValueError):
        return str(label)
    return str(int(as_float)) if as_float.is_integer() else str(label)


# --- Count WSI per class ---
class_counts = df['class_label'].value_counts().sort_index()

# Re-key the counts by normalised label. read_csv yields ints for a purely
# numeric column but strings when other labels (e.g. "normal") are present,
# and an int lookup against string labels would silently plot 0.
counts_by_key = {}
for label, n_rows in class_counts.items():
    key = _label_key(label)
    counts_by_key[key] = counts_by_key.get(key, 0) + n_rows

# --- Save summary CSV ---
summary_df = df.groupby('class_label')['new_slide_id'].nunique().reset_index()
summary_df.columns = ['class_label', 'num_expanded_wsis']
# Labels without an entry in class_names get NaN as class_name.
summary_df['class_name'] = summary_df['class_label'].map(_label_key).map(class_names)
summary_df.to_csv(SUMMARY_CSV, index=False)
print(f" Saved summary: {SUMMARY_CSV}")

# --- Prepare Plot ---
classes = [class_names[str(c)] for c in class_order]
counts = [counts_by_key.get(str(c), 0) for c in class_order]
x = np.arange(len(classes))
bar_width = 0.8

plt.figure(figsize=(7, 6))
# These rcParams are global and stay in effect for later plots in the session.
plt.rcParams['axes.spines.left'] = True
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.bottom'] = True

plt.bar(x, counts, color=["burlywood", "wheat", "beige"], width=bar_width)
plt.xticks(x, classes, fontsize=18, rotation=0)
plt.yticks(fontsize=18)
plt.xlabel('Class', fontdict={'fontsize': 19, 'fontweight': 'bold', 'fontfamily': 'arial'})
plt.ylabel('Count', fontdict={'fontsize': 19, 'fontweight': 'bold', 'fontfamily': 'arial'})
plt.title('WSI distribution', fontdict={'fontsize': 24, 'fontweight': 'bold', 'fontfamily': 'arial'})
# Headroom above the tallest bar leaves space for the value labels.
plt.ylim(-10, max(counts) + 100)

# Write each count just above its bar.
for i, value in enumerate(counts):
    plt.text(x[i], value + 5, f'{value}', ha='center', va='bottom', fontsize=18)

plt.tight_layout()
plt.savefig(OUTPUT_PLOT, dpi=600)
plt.close()
print(f" Plot saved: {OUTPUT_PLOT}")
