In [17]:
#FIRST PROMPT:

# PROMPT_TEMPLATE = """You are creating grounded, segmentation-aware captions for pathology tiles.

# Image: {image_filename}

# Tissue area (px): {tissue_px}

# Region coverage (percentage of tissue area):
# {region_lines}

# Detected cell instances from the nucleus mask.
# Each cell has a unique instance_id and a bounding box in pixel coordinates [x_min, y_min, x_max, y_max].
# Cell instances:
# {cell_lines}

# TASK:
# Write 1–2 fluent paragraphs that describe the image at a level suitable for a pathologist audience. 
# Ground every mention of structures explicitly using tags:
#   • For tissue regions, append [REGION:{{region_id}}] after the mention (e.g., "stromal areas [REGION:2]").
#   • For specific cell mentions tied to visible instances, append [CELL:{{instance_id}}] after the mention.
#   • If you describe aggregate cell patterns (e.g., "dense lymphocytic infiltrate"), you may reference multiple instances like [CELL:{{id1}},CELL:{{id2}},CELL:{{id3}}].

# CONSTRAINTS:
#   • Do not invent categories not present in the label sets below.
#   • Avoid clinical conclusions (diagnoses, grading); focus on visual description and composition.
#   • Prefer precise, grounded language. If a region is <1% of tissue, describe it as sparse or rare.

# REFERENCE LABELS:
#   Regions: {region_label_map}
#   Nuclei:  {nucleus_label_map}
# """

In [18]:
# curate_claude_prompts.py
# Python 3.9+
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
from matplotlib.colors import rgb_to_hsv

# ========== CONFIG ==========
# Data locations. The defaults point at one user's local download; set TCGA_IMAGES_DIR /
# TCGA_MASKS_DIR in the environment instead of editing absolute paths in the notebook.
IMAGES_DIR = Path(os.getenv("TCGA_IMAGES_DIR", "/home/himanshu/Downloads/tcga-20251015T043235Z-1-002/tcga/images_512"))
MASKS_DIR  = Path(os.getenv("TCGA_MASKS_DIR", "/home/himanshu/Downloads/tcga-20251015T043235Z-1-002/tcga/masks_512"))
PREV_JSON  = Path("exclude_over_tissue_div_tissue_area_gt10.json")  # images listed here will be SKIPPED
OUT_JSONL  = Path("claude_grounded_prompts.jsonl")

IMG_EXTS   = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}
MASK_EXTS  = {".png"}
NUM_SAMPLES_LIMIT = 5  # for testing, cap the number of tiles processed; set to None for a full run

# ---------- Label dictionaries ----------
# Region ids stored in channel 0 of each mask.
REGION_LABELS: Dict[int, str] = {
    0: "Exclude",
    1: "Cancerous epithelium",
    2: "Stroma",
    3: "TILs",
    4: "Normal epithelium",
    5: "Junk/Debris",
    6: "Blood",
    7: "Other",
    8: "Whitespace/Empty",
}

# Nucleus-class ids stored in channel 1 of each mask.
NUCLEUS_LABELS: Dict[int, str] = {
    0: "Exclude",
    1: "Cancer nucleus",
    2: "Stromal nucleus",
    3: "Large stromal nucleus",
    4: "Lymphocyte nucleus",
    5: "Plasma/large TIL nucleus",
    6: "Normal epithelial nucleus",
    7: "Other nucleus",
    8: "Unknown/Ambiguous nucleus",
    9: "Background (non-nuclear material)",
}

# Nucleus classes that count as real cells: every class except 0 (Exclude) and 9 (Background).
# Derived from NUCLEUS_LABELS so the two cannot drift apart.
NUCLEUS_VALID_VALUES = [k for k in sorted(NUCLEUS_LABELS) if k not in (0, 9)]

# ========== UTILITIES ==========
def discover_pairs(images_dir: Path, masks_dir: Path) -> List[Tuple[Path, Path]]:
    """Pair each image with the mask that shares its filename stem.

    Only files whose lower-cased suffix is in IMG_EXTS (images) or MASK_EXTS (masks) are
    considered. Stems present on only one side are dropped. The result is a list of
    (image_path, mask_path) tuples sorted by stem. If two files in one folder share a stem,
    the one ``iterdir`` yields last wins.
    """
    def index_by_stem(folder: Path, allowed_suffixes) -> Dict[str, Path]:
        by_stem = {}
        for entry in folder.iterdir():
            if entry.suffix.lower() in allowed_suffixes:
                by_stem[entry.stem] = entry
        return by_stem

    image_by_stem = index_by_stem(images_dir, IMG_EXTS)
    mask_by_stem = index_by_stem(masks_dir, MASK_EXTS)
    shared_stems = sorted(image_by_stem.keys() & mask_by_stem.keys())
    return [(image_by_stem[stem], mask_by_stem[stem]) for stem in shared_stems]

def load_rgb(path: Path) -> np.ndarray:
    """Read an image file and return it as an (H, W, 3) RGB array.

    The file is opened in a ``with`` block so its handle is released deterministically.
    Pillow opens files lazily and may keep the handle open after decoding (e.g. for
    multi-frame TIFFs, which IMG_EXTS accepts).
    """
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))

def load_mask(path: Path) -> np.ndarray:
    """Read a 3-channel label mask and return it as an (H, W, 3) integer array.

    curate_prompts reads region ids from channel 0 and nucleus-class ids from channel 1.
    The file is opened in a ``with`` block, and pixel data is materialised inside it, so
    the handle is released deterministically.

    Raises:
        ValueError: if the decoded mask is not 3-D with exactly 3 channels.
    """
    with Image.open(path) as im:
        m = np.array(im)  # must decode before the file is closed
    if m.ndim != 3 or m.shape[-1] != 3:
        raise ValueError(f"Mask must be (H,W,3). Got {m.shape} at {path}.")
    if not np.issubdtype(m.dtype, np.integer):
        m = m.astype(np.uint8)
    return m

# ---------- Tissue via S-channel threshold ----------
def otsu_threshold_01(img01: np.ndarray) -> float:
    """Otsu threshold for values in [0, 1], computed on a 256-bin histogram.

    The threshold maximises the between-class variance. It is returned as the centre of
    the winning bin, clipped to [0, 1]. Small epsilons keep empty/constant inputs from
    dividing by zero.
    """
    n_bins = 256
    counts, edges = np.histogram(img01.ravel(), bins=n_bins, range=(0.0, 1.0))
    prob = counts.astype(np.float64)
    prob /= prob.sum() + 1e-12

    class0_weight = np.cumsum(prob)
    class0_moment = np.cumsum(prob * np.arange(n_bins))
    total_moment = class0_moment[-1]

    between_var = np.square(total_moment * class0_weight - class0_moment) / (
        class0_weight * (1.0 - class0_weight) + 1e-12
    )
    best_bin = int(np.nanargmax(between_var))
    threshold = 0.5 * (edges[best_bin] + edges[best_bin + 1])
    return float(np.clip(threshold, 0.0, 1.0))

def tissue_mask_from_S(rgb: np.ndarray, method: str = "otsu", fixed_thresh: float = 0.2) -> np.ndarray:
    """Boolean tissue mask: True where the HSV saturation of ``rgb`` exceeds a threshold.

    The RGB input is scaled by 1/255 before conversion. With ``method == "otsu"`` the
    threshold comes from otsu_threshold_01 on the saturation channel. Any other ``method``
    value uses ``fixed_thresh``.
    """
    saturation = rgb_to_hsv(rgb.astype(np.float32) / 255.0)[..., 1]
    if method == "otsu":
        threshold = otsu_threshold_01(saturation)
    else:
        threshold = float(fixed_thresh)
    return saturation > threshold

# ---------- Connected components with fallbacks ----------
def _label_8conn_flood_fill(mask_bool: np.ndarray) -> Tuple[np.ndarray, int]:
    """Pure-numpy 8-connected labelling via an explicit-stack flood fill (slow; last-resort fallback)."""
    labels = np.zeros(mask_bool.shape, dtype=np.int32)
    H, W = mask_bool.shape
    num = 0
    neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
    visited = np.zeros_like(mask_bool, dtype=bool)
    for y in range(H):
        for x in range(W):
            if mask_bool[y, x] and not visited[y, x]:
                num += 1  # components are numbered in raster order of their first pixel
                stack = [(y, x)]
                visited[y, x] = True
                labels[y, x] = num
                while stack:
                    cy, cx = stack.pop()
                    for dy, dx in neighbors:
                        ny, nx = cy + dy, cx + dx
                        if 0 <= ny < H and 0 <= nx < W and mask_bool[ny, nx] and not visited[ny, nx]:
                            visited[ny, nx] = True
                            labels[ny, nx] = num
                            stack.append((ny, nx))
    return labels, num


def connected_components_bool(mask_bool: np.ndarray) -> Tuple[np.ndarray, int]:
    """
    Label the 8-connected components of a 2-D boolean mask.

    Returns (labels, num_labels). ``labels`` is 0 on background and 1..num_labels on components.

    Backends are tried in order: scipy.ndimage, then skimage.measure, then a pure-numpy
    flood fill. A backend is skipped only when it cannot be imported. Genuine errors (for
    example a mask that is not 2-D) therefore surface immediately instead of silently
    falling through to the slow pure-Python path.
    """
    try:
        from scipy import ndimage as ndi
    except ImportError:
        ndi = None
    if ndi is not None:
        labels, num = ndi.label(mask_bool.astype(np.uint8), structure=np.ones((3, 3), dtype=np.uint8))
        return labels, int(num)

    try:
        from skimage.measure import label as sk_label
    except ImportError:
        sk_label = None
    if sk_label is not None:
        labels = sk_label(mask_bool, connectivity=2)  # connectivity=2 == 8-connected in 2-D
        return labels, int(labels.max())

    return _label_8conn_flood_fill(mask_bool)

def bboxes_from_cc_labels(labels: np.ndarray, num: int) -> List[List[int]]:
    """
    Bounding boxes of labelled components 1..num in a 2-D label image.

    Returns one ``[x_min, y_min, x_max, y_max]`` list (inclusive pixel coordinates, Python
    ints) per label k = 1..num, in label order. A label with no pixels yields ``[0, 0, 0, 0]``.
    Pixels whose label is outside 1..num are ignored.

    Implemented as a single pass over the foreground pixels using unbuffered ``ufunc.at``
    reductions. The previous approach did one full-image scan per label, which was
    O(num * H * W).
    """
    if num <= 0:
        return []
    ys, xs = np.nonzero((labels >= 1) & (labels <= num))
    ids = labels[ys, xs].astype(np.intp)

    big = np.iinfo(np.intp).max
    x_min = np.full(num + 1, big, dtype=np.intp)
    y_min = np.full(num + 1, big, dtype=np.intp)
    x_max = np.full(num + 1, -1, dtype=np.intp)
    y_max = np.full(num + 1, -1, dtype=np.intp)
    np.minimum.at(x_min, ids, xs)
    np.minimum.at(y_min, ids, ys)
    np.maximum.at(x_max, ids, xs)
    np.maximum.at(y_max, ids, ys)

    boxes = []
    for k in range(1, num + 1):
        if x_max[k] < 0:  # label k never occurs
            boxes.append([0, 0, 0, 0])
        else:
            boxes.append([int(x_min[k]), int(y_min[k]), int(x_max[k]), int(y_max[k])])
    return boxes

# ---------- Region coverage over tissue ----------
def region_percentages_over_tissue(
    reg: np.ndarray,
    tissue: np.ndarray,
    label_map: Optional[Dict[int, str]] = None,
) -> Dict[int, float]:
    """
    Percentage of tissue pixels that carry each region id.

    Args:
        reg: region-id image; must have the same shape as ``tissue``.
        tissue: boolean tissue mask.
        label_map: region ids to report (only its keys are used). Defaults to REGION_LABELS.

    Returns:
        ``{region_id: 100 * (#tissue pixels with that id) / (#tissue pixels)}``, with one
        entry per key of ``label_map``. All ids, including Exclude/Whitespace, are measured
        the same way. Every value is 0.0 when the tissue mask is empty.
    """
    if label_map is None:
        label_map = REGION_LABELS
    tissue_bool = np.asarray(tissue, dtype=bool)
    tissue_px = int(tissue_bool.sum())
    if tissue_px == 0:
        return {k: 0.0 for k in label_map}
    # Region ids of tissue pixels only; extracted once rather than once per label.
    reg_in_tissue = np.asarray(reg)[tissue_bool]
    return {k: 100.0 * int((reg_in_tissue == k).sum()) / tissue_px for k in label_map}

# ---------- Build prompt text ----------
PROMPT_TEMPLATE = """
You are an expert histopathology assistant.

Your job has two phases:
(1) Write the best possible description based **only on what is visually present in the image**.
(2) Repeat that same text **verbatim**, but append grounding tags using the provided metadata (regions and detected nuclei instances). Use metadata strictly for tagging and light quantification; do not let it change the wording of the image-based description.

# IMAGE + METADATA
Image: {image_filename}
Tissue area (px): {tissue_px}

Region coverage (% of tissue area):
{region_lines}

Detected cell instances from the nucleus mask (each has unique instance_id and bbox [x_min, y_min, x_max, y_max] in px):
{cell_lines}

REFERENCE LABELS:
  Regions: {region_label_map}
  Nuclei:  {nucleus_label_map}

# SILENT DELIBERATION (do not include in output)
- IMAGE-FIRST: Build your description solely from visible evidence (architecture, interfaces, stroma, cellularity, distribution, cytologic features). Do not import labels or IDs into wording.
- SALIENCE ORDER: (1) overall composition & dominant compartment(s); (2) architecture & interfaces; (3) stromal features; (4) cellularity & distribution (diffuse/focal/clustered); (5) cytology if clearly appreciable; (6) conspicuous negatives (e.g., no necrosis) if confidently visible; (7) notable artifacts if present.
- QUANTIFY (if visually supportable): use approximate counts/percentages/densities that are visually reasonable; later you may refine with metadata **without** changing the prose.
- SPATIAL LANGUAGE: use central/peripheral, adjacent to, interface/border, perivascular, subglandular, etc.
- GROUNDING RULE: In the grounded copy, add tags only to entities already described from the image. Never invent mentions to satisfy metadata. If metadata lacks an ID for a described entity, leave it untagged.

# OUTPUT (STRICT) — produce exactly these sections but give only the GROUNDED_DESCRIPTION

DESCRIPTION
Write 1–2 paragraphs of polished pathologist prose derived **only from the image**. Lead with dominant regions and composition; describe architecture and interfaces; summarize stromal qualities; characterize cellularity and distribution; add cytology only if clearly visible; include conspicuous negatives where appropriate. No IDs, no tags, no label names from metadata.

GROUNDED_DESCRIPTION
Repeat DESCRIPTION **verbatim**, but append grounding tags **inline after the first mention** of each corresponding entity:
- Tissue regions → [REGION:{{region_id}}]
- Specific single cells you explicitly reference → [CELL:{{instance_id}}]
- Aggregate cellular patterns (e.g., “dense lymphocytic infiltrate”) → one tag listing all representative instances: [CELL:{{id1}},{{id2}},{{id3}}]
Use only IDs present above. Do not alter wording. Do not invent tags.

"""






def format_region_lines(region_pcts: Dict[int, float]) -> str:
    """Render region coverages as ``- <id>: <name> = <pct>%`` lines (2 decimals), sorted by region id."""
    rows = (f"- {k}: {REGION_LABELS[k]} = {region_pcts[k]:.2f}%" for k in sorted(region_pcts))
    return "\n".join(rows)

def format_cell_lines(instances: List[Dict]) -> str:
    """Render cell instances one per line as ``- CELL:<id> | type=<type_id>(<type_name>) | bbox=<bbox>``.

    A falsy ``instances`` (empty list) yields a single placeholder line.
    """
    if not instances:
        return "- (no cell instances detected)"
    return "\n".join(
        f"- CELL:{cell['instance_id']} | type={cell['cell_type_id']}({cell['cell_type_name']}) "
        f"| bbox={cell['bbox']}"
        for cell in instances
    )

# ---------- Main curation ----------
def _load_skip_names(prev_json: Path) -> set:
    """Collect ``results[*].filename`` entries from ``prev_json``; those images are excluded from curation."""
    skip_names = set()
    if not prev_json.exists():
        return skip_names
    try:
        data = json.loads(prev_json.read_text(encoding="utf-8"))
        for r in data.get("results", []):
            if "filename" in r:
                skip_names.add(r["filename"])
    except Exception as e:
        # Best-effort: an unreadable skip list only means fewer images are skipped.
        print(f"Warning: could not parse {prev_json}: {e}")
    return skip_names


def _extract_cell_instances(nuc: np.ndarray, tissue: np.ndarray, min_component_size: int) -> List[Dict]:
    """Split each valid nucleus class (restricted to tissue) into connected components, one instance per component."""
    instances: List[Dict] = []
    min_size = max(min_component_size, 1)
    for val in NUCLEUS_VALID_VALUES:
        # Restrict to tissue to avoid whitespace specks.
        class_mask = (nuc == val) & tissue
        if not class_mask.any():
            continue
        labels, num = connected_components_bool(class_mask)
        if num == 0:
            continue
        # Areas and boxes for all components at once; label 0 is background.
        sizes = np.bincount(labels.ravel(), minlength=num + 1)
        boxes = bboxes_from_cc_labels(labels, num)
        for k in range(1, num + 1):
            if sizes[k] < min_size:
                continue  # drop tiny specks
            instances.append({
                "instance_id": len(instances) + 1,  # unique across all classes, starting at 1
                "cell_type_id": int(val),
                "cell_type_name": NUCLEUS_LABELS[val],
                "bbox": list(boxes[k - 1]),
                "area_px": int(sizes[k]),
            })
    return instances


def curate_prompts(
    images_dir: Path = IMAGES_DIR,
    masks_dir: Path = MASKS_DIR,
    prev_json: Path = PREV_JSON,
    out_jsonl: Path = OUT_JSONL,
    s_method: str = "otsu",
    s_fixed_thresh: float = 0.2,
    min_component_size: int = 5,  # drop tiny specks (px)
    num_samples_limit: Optional[int] = NUM_SAMPLES_LIMIT,
):
    """
    Build one grounded-caption prompt per image/mask pair and write them to ``out_jsonl`` as JSON Lines.

    Pairs whose image filename is listed in ``prev_json`` are removed *before*
    ``num_samples_limit`` is applied, so the limit caps the number of prompts written
    (None = no cap). Pairs whose image and mask sizes differ are reported and skipped.

    Args:
        images_dir, masks_dir: folders scanned by discover_pairs.
        prev_json: JSON with ``results[*].filename`` entries to skip (optional file).
        out_jsonl: output path (overwritten; UTF-8).
        s_method, s_fixed_thresh: forwarded to tissue_mask_from_S.
        min_component_size: components smaller than this many pixels are not reported as cells.
        num_samples_limit: maximum number of pairs to process after skipping.
    """
    all_pairs = discover_pairs(images_dir, masks_dir)
    print(f"Found {len(all_pairs)} image/mask pairs.")

    skip_names = _load_skip_names(prev_json)
    pairs = [(ipath, mpath) for ipath, mpath in all_pairs if ipath.name not in skip_names]
    if len(pairs) != len(all_pairs):
        print(f"Skipping {len(all_pairs) - len(pairs)} pair(s) listed in {prev_json}.")
    pairs = pairs[:num_samples_limit]  # None -> no limit

    n_written = 0
    # Explicit UTF-8: the prompt contains non-ASCII text and is dumped with ensure_ascii=False.
    with out_jsonl.open("w", encoding="utf-8") as fout:
        for ipath, mpath in pairs:
            rgb = load_rgb(ipath)
            m = load_mask(mpath)
            reg = m[..., 0]
            nuc = m[..., 1]

            tissue = tissue_mask_from_S(rgb, method=s_method, fixed_thresh=s_fixed_thresh)
            if tissue.shape != reg.shape:
                print(f"Warning: size mismatch for {ipath.name}: image {tissue.shape} vs mask {reg.shape}; skipping.")
                continue
            tissue_px = int(tissue.sum())

            region_pcts = region_percentages_over_tissue(reg, tissue)
            instances = _extract_cell_instances(nuc, tissue, min_component_size)

            prompt_text = PROMPT_TEMPLATE.format(
                image_filename=ipath.name,
                tissue_px=tissue_px,
                region_lines=format_region_lines(region_pcts),
                cell_lines=format_cell_lines(instances),
                region_label_map=json.dumps(REGION_LABELS, ensure_ascii=False),
                nucleus_label_map=json.dumps(NUCLEUS_LABELS, ensure_ascii=False),
            )

            # One JSONL record per image.
            record = {
                "image_path": str(ipath),
                "metadata": {
                    "tissue_px": tissue_px,
                    "region_percent_over_tissue": {str(k): round(v, 4) for k, v in region_pcts.items()},
                    "cell_instances": instances,  # list of dicts with instance_id, type, bbox, area_px
                    "region_labels": REGION_LABELS,
                    "nucleus_labels": NUCLEUS_LABELS,
                },
                "prompt": prompt_text,
            }
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
            n_written += 1

    print(f"Wrote {n_written} prompt(s) to {out_jsonl}")

# ---------- Run ----------
# Inside a notebook kernel __name__ is "__main__", so executing this cell runs the curation
# with the CONFIG defaults (the "Found ... / Wrote ..." lines below are its captured output).
if __name__ == "__main__":
    curate_prompts()


Found 5 image/mask pairs.
Wrote 5 prompt(s) to claude_grounded_prompts.jsonl


In [19]:
# send_to_claude.py
# Python 3.9+
import os
import json
import time
from pathlib import Path
from typing import Optional

# pip install anthropic==0.* (current SDK)
import anthropic
from anthropic.types import Message

# I/O: prompts produced by the curation cell go in; the same records plus model responses come out.
IN_JSONL = Path("claude_grounded_prompts.jsonl")
OUT_JSONL = Path("claude_grounded_prompts_with_responses.jsonl")

# Model id, overridable through the CLAUDE_MODEL environment variable.
# NOTE(review): the claude-3-5-sonnet family has been deprecated by Anthropic; confirm this alias
# is still served before a long run, or set CLAUDE_MODEL to a current model.
CLAUDE_MODEL = os.environ.get("CLAUDE_MODEL", "claude-3-5-sonnet-latest")

# Retry / throttling knobs used by call_claude() and main().
MAX_RETRIES = 5
INITIAL_BACKOFF_S = 2.0  # doubled after each retried failure
RATE_LIMIT_DELAY_S = float(os.environ.get("CLAUDE_RATE_DELAY_S", "0.0"))  # sleep before each request, e.g. 0.5

# Media types the Messages API accepts for base64 image blocks, keyed by file suffix.
_API_IMAGE_MEDIA_TYPES = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
}


def _image_content_block(image_path) -> dict:
    """Build a base64 ``image`` content block for the Messages API from a local image file.

    PNG/JPEG/GIF/WEBP bytes are sent unchanged. Any other suffix (e.g. the .tif/.tiff tiles
    the curation step accepts) is re-encoded to PNG with Pillow first.
    """
    import base64
    import io

    path = Path(image_path)
    media_type = _API_IMAGE_MEDIA_TYPES.get(path.suffix.lower())
    if media_type is None:
        from PIL import Image  # already a dependency of the curation cell

        with Image.open(path) as im:
            buf = io.BytesIO()
            im.convert("RGB").save(buf, format="PNG")
        raw = buf.getvalue()
        media_type = "image/png"
    else:
        raw = path.read_bytes()
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": base64.b64encode(raw).decode("ascii"),
        },
    }


def call_claude(
    client: anthropic.Anthropic,
    prompt_text: str,
    max_tokens: int = 2048,
    image_path: Optional[Path] = None,
) -> str:
    """
    Send one user turn to the Claude Messages API and return the reply's text blocks joined by newlines.

    Args:
        client: configured Anthropic client.
        prompt_text: the prompt text.
        max_tokens: cap on response tokens.
        image_path: optional local image to attach (placed before the text). The curated
            prompts ask for a description of what is visible in the image; without this,
            the model only ever sees the text metadata.

    Retry policy (up to MAX_RETRIES attempts, sleeping INITIAL_BACKOFF_S then doubling):
        - anthropic.RateLimitError: retried.
        - other anthropic.APIError with status_code >= 500, or with no status_code
          (treated as 500): retried.
        - anything else: re-raised immediately.
    """
    # Build the request content once, outside the retry loop.
    if image_path is not None:
        content = [_image_content_block(image_path), {"type": "text", "text": prompt_text}]
    else:
        content = prompt_text

    backoff = INITIAL_BACKOFF_S
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            msg: Message = client.messages.create(
                model=CLAUDE_MODEL,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": content}],
            )
            # The reply is a list of content blocks; keep only the text ones.
            return "\n".join(block.text for block in msg.content if block.type == "text").strip()
        except anthropic.RateLimitError:
            if attempt == MAX_RETRIES:
                raise
            time.sleep(backoff)
            backoff *= 2
        except anthropic.APIError as e:
            if getattr(e, "status_code", 500) >= 500 and attempt < MAX_RETRIES:
                time.sleep(backoff)
                backoff *= 2
                continue
            raise
    # Only reachable when MAX_RETRIES < 1; fail loudly instead of returning None.
    raise RuntimeError(f"call_claude made no attempts (MAX_RETRIES={MAX_RETRIES}).")


NUM_SAMPLES_LIMIT = 5  # for testing: max number of responses to request; None = process every prompt


def main():
    """
    Send each curated prompt in IN_JSONL to Claude, with its tile image attached, and write
    the input records plus responses to OUT_JSONL.

    At most NUM_SAMPLES_LIMIT responses are requested (None = no limit). Blank lines and
    records without a prompt do not count toward the limit.

    Raises:
        RuntimeError: if ANTHROPIC_API_KEY is not set.
        FileNotFoundError: if IN_JSONL does not exist.
    """
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        raise RuntimeError("Set ANTHROPIC_API_KEY in your environment.")

    client = anthropic.Anthropic(api_key=api_key)

    if not IN_JSONL.exists():
        raise FileNotFoundError(f"Input JSONL not found: {IN_JSONL}")

    n_in, n_out = 0, 0
    with IN_JSONL.open("r", encoding="utf-8") as fin, OUT_JSONL.open("w", encoding="utf-8") as fout:
        for line in fin:
            # Checked at the top so `continue` paths below cannot bypass the limit.
            if NUM_SAMPLES_LIMIT is not None and n_out >= NUM_SAMPLES_LIMIT:
                break
            line = line.strip()
            if not line:
                continue
            n_in += 1
            record = json.loads(line)

            prompt_text = record.get("prompt", "")
            if not prompt_text:
                # Nothing to send—skip
                continue

            # Optional: throttle to be nice to rate limits
            if RATE_LIMIT_DELAY_S > 0:
                time.sleep(RATE_LIMIT_DELAY_S)

            response_text = call_claude(
                client,
                prompt_text,
                max_tokens=1200,
                image_path=record.get("image_path"),
            )

            out_rec = {
                **record,
                "claude": {
                    "model": CLAUDE_MODEL,
                    "response": response_text,
                },
            }
            fout.write(json.dumps(out_rec, ensure_ascii=False) + "\n")
            n_out += 1

    print(f"Processed {n_in} prompts; wrote {n_out} responses to {OUT_JSONL}")

if __name__ == "__main__":
    main()


Processed 5 prompts; wrote 5 responses to claude_grounded_prompts_with_responses.jsonl
