In [15]:
print("内核正常工作")
import sys
print(f"Python版本: {sys.version}")

内核正常工作
Python版本: 3.10.18 (main, Jun  5 2025, 08:13:51) [Clang 14.0.6 ]


In [16]:
import os
os.environ["GOOGLE_API_KEY"] = "AIzaSyBhrFyjm4FMNTckLYyxYaVPw-EBpTwQ3Ho"

In [17]:
import torch, torchvision
from segment_anything import sam_model_registry, SamPredictor

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("segment_anything is available!")

PyTorch version: 2.2.2
Torchvision version: 0.17.2
segment_anything is available!


In [19]:
# environment check
import os
import logging
from pathlib import Path
import pandas as pd
import pydicom
import numpy as np
import matplotlib.pyplot as plt

import torch
from segment_anything import sam_model_registry, SamPredictor

from Bio import Entrez
import google.generativeai as genai

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

OUTPUT_DIR = Path("DIP Project/outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# MedSAM checkpoint
MODEL_CHECKPOINT = "DIP Project/models/medsam_vit_b.pth"

# Gemini API Key
GENAI_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GENAI_API_KEY:
    raise RuntimeError("Please set environment variable GOOGLE_API_KEY before running.")
genai.configure(api_key=GENAI_API_KEY)

# PubMed 邮箱
Entrez.email = "835597824@qq.com"

In [2]:
#1 DICOM -> PNG

from pathlib import Path
import pydicom
import numpy as np
import matplotlib.pyplot as plt

def dicom_to_png(dicom_file):
    ds = pydicom.dcmread(dicom_file)
    pixel_array = ds.pixel_array.astype(float)
    norm = (pixel_array - np.min(pixel_array)) / (np.max(pixel_array) - np.min(pixel_array))
    
    png_path = Path("DIP Project/outputs") / f"{Path(dicom_file).stem}_test.png"
    plt.imsave(png_path, norm, cmap='gray')
    print("PNG saved:", png_path)
    print("PatientID in DICOM:", getattr(ds, "PatientID", None))
    return png_path, ds

png_path, ds = dicom_to_png("DIP Project/data/Unknown-8.dcm")

PNG saved: DIP Project/outputs/Unknown-8_test.png
PatientID in DICOM: TCGA-AO-A03M


In [3]:
#2 MedSAM segmentation（自动中心点提示 + CPU/GPU 自适应 + 可视化 + DICE/IoU）

def run_medsam_segmentation_v2(image_path: str, gt_mask_path: str = None, output_dir=OUTPUT_DIR,
                               use_box_prompt=False):
    import cv2
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    from segment_anything import sam_model_registry, SamPredictor
    from pathlib import Path

    print("Step 2: Running MedSAM segmentation (v2)...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 构建 SAM 模型并加载 checkpoint
    sam = sam_model_registry["vit_b"]()
    sam.to(device)
    checkpoint_dict = torch.load(MODEL_CHECKPOINT, map_location=device)
    sam.load_state_dict(checkpoint_dict)

    predictor = SamPredictor(sam)

    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    predictor.set_image(img_rgb)
    h, w, _ = img_rgb.shape

    # Prompt
    if use_box_prompt:
        # 示例 box：可以根据病灶大致位置调整
        box = np.array([w//4, h//4, 3*w//4, 3*h//4])
        masks, scores, _ = predictor.predict(box=box, multimask_output=False)
    else:
        # 多前景点示例
        point_coords = np.array([[h//2, w//2], [h//3, w//3]])
        point_labels = np.array([1, 1])  # 都是前景
        masks, scores, _ = predictor.predict(point_coords=point_coords,
                                     point_labels=point_labels,
                                     multimask_output=False)

    mask = masks[0].astype(np.uint8)
    mask_score = float(scores[0])
    print(f"Mask score: {mask_score:.4f}")

    # save mask PNG
    mask_path = output_dir / f"{Path(image_path).stem}_mask.png"
    plt.imsave(mask_path, mask, cmap='gray')
    print(f"Saved mask: {mask_path}")

    # DICE/IoU
    metrics = {}
    if gt_mask_path:
        gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
        gt_mask = (gt_mask > 0).astype(np.uint8)
        mask_bin = (mask > 0).astype(np.uint8)
        intersection = np.sum(gt_mask * mask_bin)
        dice = 2.0 * intersection / (np.sum(gt_mask) + np.sum(mask_bin) + 1e-8)
        union = np.sum((gt_mask + mask_bin) > 0)
        iou = intersection / (union + 1e-8)
        metrics = {"dice": float(dice), "iou": float(iou)}

    # save overlay
    overlay = img_rgb.copy()
    overlay[mask > 0] = [255, 0, 0]
    vis_path = output_dir / f"{Path(image_path).stem}_overlay.png"
    plt.imsave(vis_path, overlay)
    print(f"Saved overlay: {vis_path}")

    return str(mask_path), mask, mask_score, metrics, str(vis_path)

#测试
png_path = "DIP Project/outputs/Unknown-8_test.png"
mask_path, mask, mask_score, metrics, overlay_path = run_medsam_segmentation_v2(
    png_path, gt_mask_path=None, use_box_prompt=True
)

print("Mask path:", mask_path)
print("Overlay path:", overlay_path)
print("Mask score:", mask_score)
print("Metrics:", metrics)

NameError: name 'OUTPUT_DIR' is not defined

In [7]:
#3 extract Patient features from clinical csv

def get_patient_features(ds, clinical_csv_path="DIP Project/data/clinical_data_preprocessed.csv"):
    import pandas as pd
    import logging

    print("Step 3: Extracting patient features from clinical CSV...")
    df = pd.read_csv(clinical_csv_path)
    patient_id = getattr(ds, "PatientID", None)

    if patient_id is None:
        logging.warning("No PatientID in DICOM metadata. Using fallback keywords.")
        return {"subtype": None, "tumor_stage": None}

    patient_row = df[df["bcr_patient_barcode"] == patient_id]
    if patient_row.empty:
        logging.warning(f"No matching patient in clinical data for ID {patient_id}.")
        return {"subtype": None, "tumor_stage": None}

    subtype = patient_row.iloc[0].get("subtype", None)
    tumor_stage = patient_row.iloc[0].get("tumor_stage", None)

    # transfer nan to None
    if pd.isna(subtype):
        subtype = None
    if pd.isna(tumor_stage):
        tumor_stage = None

    print(f"Matched PatientID {patient_id}: subtype={subtype}, tumor_stage={tumor_stage}")
    return {"subtype": subtype, "tumor_stage": tumor_stage}


# test
features = get_patient_features(ds)


Step 3: Extracting patient features from clinical CSV...
Matched PatientID TCGA-AO-A03M: subtype=None, tumor_stage=I


In [8]:
#4 PubMed search

from Bio import Entrez
Entrez.email = "835597824@qq.com"

def test_pubmed(features):
    keywords = []
    if features.get("subtype"): keywords.append(features["subtype"])
    if features.get("tumor_stage"): keywords.append(features["tumor_stage"])
    query = " AND ".join(keywords) if keywords else "cancer"
    handle = Entrez.esearch(db="pubmed", term=query, retmax=2)
    record = Entrez.read(handle)
    ids = record.get("IdList", [])
    results = [{"id": pid, "link": f"https://pubmed.ncbi.nlm.nih.gov/{pid}/"} for pid in ids]
    print("PubMed results:", results)
    return results

#test
literature = test_pubmed(features)

PubMed results: []


In [48]:
#5 LLM Summary

import google.generativeai as genai
import warnings, logging

def generate_llm_summary(image_path, mask_path, metrics, literature):
    print("Step 5: Generating LLM summary...")

    llm_prompt = f"""
You are a radiology AI assistant. Based on the following inputs, generate a structured medical imaging report including:
(1) Findings, (2) Quantified metrics, (3) Suggested next steps, (4) Literature context, (5) Uncertainty/limitations.

Image Path: {png_path}
Mask Path: {mask_path}
Relevant Literature: {literature}
"""

    model = genai.GenerativeModel("gemini-2.5-flash")
    response = model.generate_content(llm_prompt)

    # extract txt
    summary_text = response.text if hasattr(response, "text") else str(response)
    return summary_text

#test
summary = generate_llm_summary(png_path, mask_path, metrics, literature)
print("LLM summary:\n", summary)

Step 5: Generating LLM summary...
LLM summary:
 **Medical Imaging Report - AI Assistant Generated (Simulated)**

**Patient Information:**
*   Patient ID: Not provided
*   Date of Study: [Current Date]
*   Modality: Not specified (AI assumed a generic imaging study, potentially mammography or ultrasound given context)

---

**1. Findings:**
Upon AI-driven analysis of the provided image (`DIP Project/outputs/Unknown-8_test.png`) and its associated mask (`DIP Project/outputs/Unknown-8_test_mask.png`), a distinct region of interest (ROI) has been identified and precisely delineated by the mask. This ROI, interpreted hypothetically as a lesion or abnormality, exhibits the following simulated characteristics:

*   **Morphology:** Irregular shape with speculated margins, indicative of an infiltrative process.
*   **Internal Characteristics:** Heterogeneous internal texture (simulated, consistent with solid mass components).
*   **Surrounding Tissue:** Mild architectural distortion is hypothet

In [5]:
#6 完整 Pipeline v4.2

def run_pipeline_v4_2(dicom_file, clinical_csv_path="DIP Project/data/clinical_data_preprocessed.csv"):
    print(f"=== Running pipeline v4.2 on {dicom_file} ===")
    
    # 1. DICOM -> PNG
    png_path, ds = dicom_to_png(dicom_file)
    
    # 2. MedSAM segmentation
    mask_path, mask = run_medsam_segmentation(png_path)
    
    # 3. Patient features
    features = get_patient_features(ds, clinical_csv_path)
    
    # 4. PubMed
    literature = search_pubmed(features)
    
    # 5. Metrics placeholder (可后续加入 DICE/IoU)
    metrics = {}
    
    # 6. LLM summary
    summary = generate_llm_summary(png_path, mask_path, metrics, literature)
    
    logging.info("=== Pipeline v4.2 completed ===")
    return summary

# 测试：
dicom_file = "DIP_Project/data/Unknown-8.dcm"
summary = run_pipeline_v4_2(dicom_file)
print(summary)

=== Running pipeline v4.2 on DIP_Project/data/Unknown-8.dcm ===


NameError: name 'dicom_to_png' is not defined

In [42]:
%%writefile config.py
from pathlib import Path
import os

# Base directory for the DIP project
BASE_DIR = Path(__file__).resolve().parent

# Data directories
DATA_DIR = BASE_DIR / "data"
DICOM_DIR = DATA_DIR / "dicoms"
LESION_DIR = DATA_DIR / "lesions"

# Outputs and model directories
OUTPUT_DIR = BASE_DIR / "outputs"
MODEL_DIR = BASE_DIR / "models"

for d in [DATA_DIR, DICOM_DIR, LESION_DIR, OUTPUT_DIR, MODEL_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# MedSAM checkpoint
MEDSAM_CHECKPOINT = MODEL_DIR / "medsam_vit_b.pth"

# Clinical TSV from TCGA (pan-can atlas)
CLINICAL_TSV = DATA_DIR / "brca_tcga_pan_can_atlas_2018_clinical_data.tsv"

# PubMed email (required by NCBI Entrez)
PUBMED_EMAIL = "your_email@domain.com"

# Gemini API Key (must be set in environment)
GENAI_API_KEY = os.environ.get("GOOGLE_API_KEY")

if __name__ == "__main__":
    print("BASE_DIR    :", BASE_DIR)
    print("DICOM_DIR   :", DICOM_DIR)
    print("LESION_DIR  :", LESION_DIR)
    print("OUTPUT_DIR  :", OUTPUT_DIR)
    print("MODEL_DIR   :", MODEL_DIR)
    print("MEDSAM_CHECKPOINT :", MEDSAM_CHECKPOINT)
    print("CLINICAL_TSV:", CLINICAL_TSV)
    print("GENAI_API_KEY is None? ", GENAI_API_KEY is None)


Overwriting config.py


In [22]:
%%writefile step0_prepare_case.py
from pathlib import Path
import pandas as pd
from config import LESION_DIR, DATA_DIR

def run_step(context: dict) -> dict:
    """
    Step 0: Prepare case by:
    1) extracting patient ID from DICOM filename
    2) finding GT .les mask (if exists)
    3) loading clinical info (age, subtype, stage code)
    """

    dicom_path = Path(context["dicom_path"]).resolve()
    patient_id = dicom_path.stem

    # ---------- GT lesion ----------
    gt_path = LESION_DIR / f"{patient_id}.les"
    if gt_path.exists():
        gt_mask_path = str(gt_path)
        print(f"[Step0] Found GT lesion file: {gt_mask_path}")
    else:
        gt_mask_path = None
        print(f"[Step0] No GT lesion (.les) found for {patient_id}.")

    # ---------- Load TSV ----------
    tsv_path = DATA_DIR / "brca_tcga_pan_can_atlas_2018_clinical_data.tsv"

    df = pd.read_csv(tsv_path, sep="\t", low_memory=False)
    print(f"[Step0] Clinical TSV loaded ({len(df)} rows).")

    # Find row for this patient
    row = df[df["Patient ID"] == patient_id]

    if row.empty:
        print(f"[Step0] WARNING: Patient {patient_id} not found")
        age = None
        subtype = None
        stage_code = None
    else:
        row = row.iloc[0]

        age = row["Diagnosis Age"]
        subtype = row["Subtype"]

        # Combine 3 columns to form final stage code
        t_stage = str(row["American Joint Committee on Cancer Tumor Stage Code"])
        n_stage = str(row["Neoplasm Disease Lymph Node Stage American Joint Committee on Cancer Code"])
        m_stage = str(row["American Joint Committee on Cancer Metastasis Stage Code"])

        # correct format e.g. T1C_N0_M0
        stage_code = f"{t_stage}_{n_stage}_{m_stage}"

        print(f"[Step0] Clinical info — Age: {age}, Subtype: {subtype}, Stage: {stage_code}")

    return {
        "dicom_path": str(dicom_path),
        "patient_id": patient_id,
        "gt_mask_path": gt_mask_path,
        "age": age,
        "subtype": subtype,
        "stage_code": stage_code
    }


Overwriting step0_prepare_case.py


In [19]:
%%writefile step1_dicom_to_png.py
from pathlib import Path
import numpy as np
import pydicom
import cv2

from config import OUTPUT_DIR

def run_step(context: dict) -> dict:
    """
    Read DICOM → normalize → save as standard 8-bit PNG (0–255).
    This ensures OpenCV can read it in Step2.
    """
    dicom_path = Path(context["dicom_path"])
    patient_id = context.get("patient_id", dicom_path.stem)

    ds = pydicom.dcmread(str(dicom_path))
    arr = ds.pixel_array.astype(float)

    # Normalize to 0–255
    arr -= arr.min()
    arr /= (arr.max() + 1e-8)
    arr_uint8 = (arr * 255).astype(np.uint8)

    # If grayscale, make 3-channel for SAM compatibility
    img_rgb = cv2.cvtColor(arr_uint8, cv2.COLOR_GRAY2RGB)

    png_path = OUTPUT_DIR / f"{patient_id}.png"
    cv2.imwrite(str(png_path), img_rgb)

    print(f"[Step1] Saved PNG: {png_path}")

    return {
        "png_path": str(png_path),
        "dicom_ds": ds,
    }


Overwriting step1_dicom_to_png.py


In [28]:
%%writefile step2_medsam_segmentation.py
from pathlib import Path
from config import OUTPUT_DIR, MEDSAM_CHECKPOINT
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry, SamPredictor


def preprocess_image_breast(bgr_img):
    """
    Basic preprocessing for breast MR:
    - convert to gray
    - normalize intensity to [0, 255]
    - apply CLAHE (local contrast enhancement)
    - light Gaussian blur
    Returns:
        pre_rgb: 3-channel RGB image for SAM
        pre_gray: single-channel preprocessed gray for seed selection
    """
    gray = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)

    # Normalize to full 0-255 range
    gray_norm = cv2.normalize(gray, None, alpha=0, beta=255,
                              norm_type=cv2.NORM_MINMAX)
    gray_norm = gray_norm.astype(np.uint8)

    # CLAHE to enhance local contrast around lesions
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray_clahe = clahe.apply(gray_norm)

    # Light smoothing to reduce noise
    gray_blur = cv2.GaussianBlur(gray_clahe, (3, 3), 0)

    # SAM expects 3-channel RGB
    pre_rgb = cv2.cvtColor(gray_blur, cv2.COLOR_GRAY2RGB)

    return pre_rgb, gray_blur


def get_intensity_seed(gray_img, top_k_ratio=0.01):
    """
    Select high-intensity pixels as foreground seed points.
    Returns an (N, 2) array of (x, y) points.
    """
    flat = gray_img.flatten()
    n_pixels = len(flat)
    k = max(1, int(n_pixels * top_k_ratio))

    # indices of k brightest pixels
    top_idx = np.argpartition(flat, -k)[-k:]
    ys, xs = np.divmod(top_idx, gray_img.shape[1])
    pts = np.column_stack([xs, ys])
    return pts


def build_bbox_from_points(points, pad=25, shape=None):
    """
    Build a loose bounding box around the given points.
    shape: (H, W) of the image.
    """
    if points is None or len(points) == 0:
        return None

    xs, ys = points[:, 0], points[:, 1]
    x_min, x_max = xs.min(), xs.max()
    y_min, y_max = ys.min(), ys.max()

    if shape is not None:
        H, W = shape
        x_min = max(0, x_min - pad)
        y_min = max(0, y_min - pad)
        x_max = min(W - 1, x_max + pad)
        y_max = min(H - 1, y_max + pad)
    else:
        x_min -= pad
        y_min -= pad
        x_max += pad
        y_max += pad

    return np.array([x_min, y_min, x_max, y_max], dtype=np.int32)


def postprocess_mask(mask):
    """
    Keep only the largest connected component as final lesion.
    """
    mask_bin = (mask > 0).astype(np.uint8)
    num_labels, labels = cv2.connectedComponents(mask_bin)

    if num_labels <= 1:
        return mask_bin

    max_area = 0
    best_id = 0
    for lbl in range(1, num_labels):
        area = np.sum(labels == lbl)
        if area > max_area:
            max_area = area
            best_id = lbl

    return (labels == best_id).astype(np.uint8)


def classify_shape_from_mask(mask):
    """
    Very simple shape classifier based on contour circularity.
    Returns "Round-Oval" or "Irregular" or "Unknown".
    """
    mask_bin = (mask > 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return "Unknown"

    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    peri = cv2.arcLength(cnt, closed=True)
    if peri == 0:
        return "Unknown"

    circularity = 4.0 * np.pi * area / (peri ** 2 + 1e-8)
    if circularity > 0.75:
        return "Round-Oval"
    else:
        return "Irregular"


def run_step(context: dict) -> dict:
    """
    Enhanced MedSAM segmentation step:
      - DICOM was already converted to PNG in Step1
      - Here we:
          * preprocess image (CLAHE + blur)
          * run MedSAM with intensity-based seeds + bounding box
          * post-process mask (largest component)
          * classify tumor shape
          * save mask & overlay
    """
    png_path = Path(context["png_path"])
    patient_id = context.get("patient_id", png_path.stem)

    print("[Step2] Running MedSAM segmentation...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[Step2] Using device: {device}")

    # Load MedSAM model checkpoint
    sam = sam_model_registry["vit_b"]()
    state = torch.load(str(MEDSAM_CHECKPOINT), map_location=device)
    sam.load_state_dict(state)
    sam.to(device)

    predictor = SamPredictor(sam)

    # --- Load original PNG image ---
    img_bgr = cv2.imread(str(png_path))
    if img_bgr is None:
        raise RuntimeError(f"[Step2] Failed to read PNG: {png_path}")
    img_rgb_orig = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # --- Preprocess image for SAM ---
    img_rgb_pre, img_gray_pre = preprocess_image_breast(img_bgr)
    H, W = img_gray_pre.shape

    # --- Intensity-based seed points ---
    seed_points = get_intensity_seed(img_gray_pre, top_k_ratio=0.01)
    if seed_points.shape[0] > 10:
        fg_points = seed_points[:10]
    else:
        fg_points = seed_points
    point_labels = np.ones(len(fg_points), dtype=np.int32)

    # --- Auto bounding box from seeds ---
    bbox = build_bbox_from_points(seed_points, pad=30, shape=(H, W))
    if bbox is None:
        # Fallback: center box
        cx, cy = W // 2, H // 2
        bbox = np.array([cx - 32, cy - 32, cx + 32, cy + 32], dtype=np.int32)

    # --- Run SAM on preprocessed image ---
    predictor.set_image(img_rgb_pre)
    masks, scores, _ = predictor.predict(
        box=bbox,
        point_coords=fg_points,
        point_labels=point_labels,
        multimask_output=False,
    )

    raw_mask = masks[0]
    mask = postprocess_mask(raw_mask)
    mask_score = float(scores[0])

    # --- Shape classification from mask ---
    shape = classify_shape_from_mask(mask)

    # --- Save mask as PNG ---
    mask_path = OUTPUT_DIR / f"{patient_id}_mask.png"
    plt.imsave(mask_path, mask, cmap="gray")

    # --- Save overlay on ORIGINAL RGB image ---
    overlay = img_rgb_orig.copy()
    overlay[mask > 0] = [255, 0, 0]
    overlay_path = OUTPUT_DIR / f"{patient_id}_overlay.png"
    plt.imsave(overlay_path, overlay)

    print(f"[Step2] Mask score = {mask_score:.4f}, shape = {shape}")
    print(f"[Step2] Saved mask: {mask_path}")
    print(f"[Step2] Saved overlay: {overlay_path}")

    # Update context for later steps
    context["mask_path"] = str(mask_path)
    context["overlay_path"] = str(overlay_path)
    context["shape"] = shape

    return {
        "pred_mask": mask,
        "mask_score": mask_score,
        "mask_path": str(mask_path),
        "overlay_path": str(overlay_path),
        "shape": shape,
    }


Overwriting step2_medsam_segmentation.py


In [21]:
%%writefile step3_evaluation.py
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from config import OUTPUT_DIR

def load_les_polygon_mask(les_path: str, image_shape) -> np.ndarray:
    """Decode TCGA .les polygon → 0/1 GT mask."""
    data = np.fromfile(str(les_path), dtype=np.uint16)
    data = data[data != 0]  # remove paddings

    if len(data) % 2 != 0:
        data = data[:-1]

    pts = data.reshape((-1, 2)).astype(np.int32)

    h, w = image_shape
    mask = np.zeros((h, w), dtype=np.uint8)
    cv2.fillPoly(mask, [pts], 1)
    return mask

def compute_dice_iou(pred_mask: np.ndarray, gt_mask: np.ndarray):
    pred_bin = (pred_mask > 0).astype(np.uint8)
    gt_bin = (gt_mask > 0).astype(np.uint8)

    inter = np.sum(pred_bin * gt_bin)
    dice = 2 * inter / (np.sum(pred_bin) + np.sum(gt_bin) + 1e-8)

    union = np.sum((pred_bin + gt_bin) > 0)
    iou = inter / (union + 1e-8)
    return float(dice), float(iou)

def visualize_pred_vs_gt(img_gray, pred, gt, vis_path):
    """Overlay pred (RED) & GT (GREEN) contours on original image."""
    H, W = img_gray.shape
    canvas = np.stack([img_gray]*3, axis=-1)

    # compute contours
    pred_cont = (pred - cv2.erode(pred, None)) > 0
    gt_cont   = (gt   - cv2.erode(gt, None)) > 0

    # apply colors
    canvas[pred_cont] = [255, 0, 0]   # red
    canvas[gt_cont]   = [0, 255, 0]   # green

    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.axis("off")
    plt.savefig(vis_path, dpi=200, bbox_inches="tight")
    plt.close()

def run_step(context: dict) -> dict:
    """
    Evaluate Dice / IoU and save visualization figure.
    """
    pred_mask = context.get("pred_mask")
    gt_mask_path = context.get("gt_mask_path")
    png_path = context.get("png_path")

    metrics = {}

    if pred_mask is None:
        print("[Step3] No pred_mask in context → skip.")
        return {"metrics": metrics}

    if gt_mask_path is None:
        print("[Step3] No GT .les file → skip DICE/IoU.")
        return {"metrics": metrics}

    print(f"[Step3] Evaluating Dice/IoU with GT: {gt_mask_path}")

    # load GT
    h, w = pred_mask.shape[:2]
    gt_mask = load_les_polygon_mask(gt_mask_path, (h, w))

    # compute Dice & IoU
    dice, iou = compute_dice_iou(pred_mask, gt_mask)
    metrics = {"dice": dice, "iou": iou}
    print(f"[Step3] Dice = {dice:.4f}, IoU = {iou:.4f}")

    # visualization
    if png_path:
        img = cv2.imread(png_path, cv2.IMREAD_GRAYSCALE)
        vis_path = OUTPUT_DIR / f"{Path(png_path).stem}_pred_vs_gt.png"
        visualize_pred_vs_gt(img, pred_mask, gt_mask, vis_path)
        print(f"[Step3] Saved visualization: {vis_path}")

        context["eval_vis_path"] = str(vis_path)

    return {"metrics": metrics, "eval_vis_path": context.get("eval_vis_path")}


Overwriting step3_evaluation.py


In [31]:
%%writefile step4_literature_search.py
from Bio import Entrez
from config import PUBMED_EMAIL


def run_step(context: dict) -> dict:
    """
    Search PubMed using "breast cancer" + (optional) tumor shape keyword.
    We request up to 20 most relevant papers and extract:
      - title
      - journal
      - year (if available)
    The result is stored in context["papers"] as a list of dicts.
    """
    shape = context.get("shape", "")
    # Base query: always include "breast cancer"
    if shape:
        query = f"breast cancer {shape}"
    else:
        query = "breast cancer"

    print(f"[Step4] Searching PubMed for: {query}")

    # Required by NCBI Entrez
    Entrez.email = PUBMED_EMAIL

    # esearch default sort is "relevance", but set it explicitly for clarity
    handle = Entrez.esearch(
        db="pubmed",
        term=query,
        retmax=20,        # ask up to 20 ids
        sort="relevance", # most relevant
    )
    result = Entrez.read(handle)
    handle.close()

    ids = result.get("IdList", [])
    print(f"[Step4] Retrieved {len(ids)} papers.")

    papers = []
    if not ids:
        context["papers"] = papers
        return {"papers": papers}

    # Fetch article details (XML)
    handle = Entrez.efetch(db="pubmed", id=",".join(ids), retmode="xml")
    records = Entrez.read(handle)
    handle.close()

    for article in records.get("PubmedArticle", []):
        article_data = article.get("MedlineCitation", {}).get("Article", {})
        title = article_data.get("ArticleTitle", "")

        journal_info = article_data.get("Journal", {})
        journal_title = journal_info.get("Title", "")

        pub_date = journal_info.get("JournalIssue", {}).get("PubDate", {})
        year = pub_date.get("Year", "")
        # Sometimes only MedlineDate exists, e.g. "2024 Jan-Feb"
        if not year and "MedlineDate" in pub_date:
            year = str(pub_date["MedlineDate"])

        papers.append(
            {
                "title": str(title),
                "journal": str(journal_title),
                "year": str(year),
            }
        )

    context["papers"] = papers
    return {"papers": papers}


Overwriting step4_literature_search.py


In [40]:
%%writefile step5_build_prompt.py
from pathlib import Path
from datetime import datetime

def run_step(context: dict) -> dict:
    """
    Build the prompt for the LLM/Gemini.
    Includes:
      - Clinical info (patient_id, age, subtype, stage)
      - Image paths
      - Tumor shape
      - Literature (titles/journals/years)
    """

    patient_id = context.get("patient_id", "")
    age = context.get("age", "")
    subtype = context.get("subtype", "")
    stage = context.get("stage_code", "")

    png_path = context.get("png_path", "")
    mask_path = context.get("mask_path", "")
    shape = context.get("shape", "")
    papers = context.get("papers", [])

    # Format literature text
    lit_section = "\n".join(
        [f"- {p['title']} ({p['journal']}, {p['year']})" for p in papers]
    )

    date_str = datetime.now().strftime("%Y-%m-%d")
    
    prompt = f"""
You are a radiology AI assistant. Generate a structured, clinically useful report based on the provided information.

Basic Clinical Information:
- Patient ID: {patient_id}
- Age: {age}
- Subtype: {subtype}
- Stage Code: {stage}
- Date of Report: {date_str}

Image Information:
- Original PNG: {png_path}
- Segmentation Mask: {mask_path}
- Predicted Tumor Shape: {shape}

Instructions:
Using all the information above (image, mask, shape, clinical data, and literature titles), generate a structured radiology-style report including:

1. **Findings**
   - Detailed description of tumor appearance in the *original PNG image*
   - Clear explanation of what the *segmentation mask* highlights vs. misses
   - Use precise radiology language (location, margins, enhancement pattern, architectural distortion, etc.)

2. **Literature Context**
   - Synthesize insights from the 20 provided paper titles
   - Integrate findings with patient-specific factors (age, subtype, stage, tumor shape)
   - Avoid quoting PMIDs or listing papers individually—provide an integrated discussion.

3. **Suggested Next Clinical Steps**
   - Provide actionable, realistic clinical recommendations
   - Include biopsy, MDT consultation, receptor testing, genetic counseling, imaging follow-up, etc.
   - Tailor recommendations to subtype (Luminal A), patient age, and shape characteristics.

4. **Uncertainty / Limitations**
   - Discuss segmentation accuracy limitations
   - Single-slice imaging limitations
   - Potential diagnostic uncertainty or areas needing further evaluation

Write the report in paragraphs, concise and clinically oriented.
"""

    context["llm_prompt"] = prompt
    print("[Step5] LLM prompt constructed.")
    return {"llm_prompt": prompt}


Overwriting step5_build_prompt.py


In [26]:
%%writefile step6_llm_summary.py
import google.generativeai as genai
from config import GENAI_API_KEY, OUTPUT_DIR
from pathlib import Path

def run_step(context: dict) -> dict:
    """
    Use Gemini-2.5-Flash to generate a final radiology report.
    """

    if GENAI_API_KEY is None:
        raise RuntimeError("GOOGLE_API_KEY not found")

    genai.configure(api_key=GENAI_API_KEY)
    model = genai.GenerativeModel("gemini-2.5-flash")

    prompt = context["llm_prompt"]

    print("[Step6] Calling Gemini-2.5-Flash Vision API...")

    response = model.generate_content(prompt)
    summary = response.text

    out_path = OUTPUT_DIR / f"{context['patient_id']}_summary.txt"
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(summary)

    print(f"[Step6] Summary saved to: {out_path}")

    return {"summary_path": str(out_path)}


Overwriting step6_llm_summary.py


In [9]:
%%writefile step7_run_agent.py
"""
Main LangChain-based workflow runner for DIP project.

Pipeline:
0) step0_prepare_case      - infer patient_id & GT .les path
1) step1_dicom_to_png      - DICOM → PNG
2) step2_medsam_segmentation - MedSAM segmentation
3) step3_evaluation        - Dice / IoU with .les polygon
4) step4_literature_search - PubMed: 'breast cancer'
5) step5_build_prompt      - build LLM prompt
6) step6_llm_summary       - call Gemini Vision to generate report
"""

from pathlib import Path
import glob

from langchain_core.runnables import RunnableLambda, RunnableSequence

from config import DICOM_DIR
from step0_prepare_case import run_step as step0
from step1_dicom_to_png import run_step as step1
from step2_medsam_segmentation import run_step as step2
from step3_evaluation import run_step as step3
from step4_literature_search import run_step as step4
from step5_build_prompt import run_step as step5
from step6_llm_summary import run_step as step6


def _wrap_step(name, func):
    """Wrap step(context) into RunnableLambda."""
    def inner(context: dict):
        print(f"\n===== Running {name} =====")
        result = func(context)
        if result is None:
            return context
        context.update(result)
        return context
    return RunnableLambda(inner)


def build_chain() -> RunnableSequence:
    """Build the LangChain workflow (sequence of steps)."""
    chain = RunnableSequence(
        _wrap_step("Step0: prepare case (patient_id & GT)", step0),
        _wrap_step("Step1: DICOM → PNG", step1),
        _wrap_step("Step2: MedSAM segmentation", step2),
        _wrap_step("Step3: Dice/IoU evaluation", step3),
        _wrap_step("Step4: PubMed 'breast cancer' search", step4),
        _wrap_step("Step5: build LLM prompt", step5),
        _wrap_step("Step6: Gemini Vision summary", step6),
    )
    return chain


def run_for_one_dicom(dicom_path: Path):
    """Complete the entire workflow for a single DICOM file"""
    chain = build_chain()
    context = {"dicom_path": str(dicom_path)}
    final_context = chain.invoke(context)
    return final_context


def run_batch(test_mode: bool = True):
    """
    test_mode = True  → Only run the first DICOM (for debugging)
    test_mode = False → Run all DICOM files in the DICOM_DIR directory.
    """
    dicom_files = sorted(
        list(DICOM_DIR.glob("*.dcm")) + list(DICOM_DIR.glob("*.dicom"))
    )

    if not dicom_files:
        print(f"No DICOM files found in: {DICOM_DIR}")
        return

    if test_mode:
        dicom_files = dicom_files[:1]
        print("[Runner] TEST_MODE = True, will only run 1 case.")
    else:
        print(f"[Runner] TEST_MODE = False, will run {len(dicom_files)} cases.")

    for dcm in dicom_files:
        print("\n=======================================")
        print("Running pipeline for:", dcm.name)
        print("=======================================")
        ctx = run_for_one_dicom(dcm)
        print("\n>>> Summary saved at:", ctx.get("summary_path"))


if __name__ == "__main__":
    TEST_MODE = True
    run_batch(test_mode=TEST_MODE)


Overwriting step7_run_agent.py


In [10]:
!python step7_run_agent.py

Traceback (most recent call last):
  File [35m"/Users/shenyuyu/DIP Project/step7_run_agent.py"[0m, line [35m21[0m, in [35m<module>[0m
    from step1_dicom_to_png import run_step as step1
  File [35m"/Users/shenyuyu/DIP Project/step1_dicom_to_png.py"[0m, line [35m3[0m, in [35m<module>[0m
    import pydicom
[1;35mModuleNotFoundError[0m: [35mNo module named 'pydicom'[0m


In [41]:
import sys
!{sys.executable} step7_run_agent.py

[Runner] TEST_MODE = True, will only run 1 case.

Running pipeline for: TCGA-AO-A03M.dcm

===== Running Step0: prepare case (patient_id & GT) =====
[Step0] Found GT lesion file: /Users/shenyuyu/DIP Project/data/lesions/TCGA-AO-A03M.les
[Step0] Clinical TSV loaded (1084 rows).
[Step0] Clinical info — Age: 29, Subtype: BRCA_LumA, Stage: T1C_N0 (I-)_M0

===== Running Step1: DICOM → PNG =====
[Step1] Saved PNG: /Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M.png

===== Running Step2: MedSAM segmentation =====
[Step2] Running MedSAM segmentation...
[Step2] Using device: cpu
[Step2] Mask score = 0.4267, shape = Irregular
[Step2] Saved mask: /Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M_mask.png
[Step2] Saved overlay: /Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M_overlay.png

===== Running Step3: Dice/IoU evaluation =====
[Step3] Evaluating Dice/IoU with GT: /Users/shenyuyu/DIP Project/data/lesions/TCGA-AO-A03M.les
[Step3] Dice = 0.0000, IoU = 0.0000
[Step3] Saved visualization: /Users

In [29]:
# 单独跑step2
from step2_medsam_segmentation import run_step

context = {
    "png_path": "/Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M.png",
    "patient_id": "TCGA-AO-A03M"
}

result = run_step(context)


[Step2] Running MedSAM segmentation...
[Step2] Using device: cpu
[Step2] Mask score = 0.4453, shape = Irregular
[Step2] Saved mask: /Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M_mask.png
[Step2] Saved overlay: /Users/shenyuyu/DIP Project/outputs/TCGA-AO-A03M_overlay.png


In [30]:
# run step 3
context.update(run_step3(context))

context["metrics"]
context["overlay_path"]
context["eval_vis_path"]


[Step3] No pred_mask in context → skip.


KeyError: 'eval_vis_path'