In [1]:
# Optional dependency installs. Uncomment a line to install the package into
# the kernel's environment (`%pip` targets the running kernel, unlike `!pip`).

# gdown
#%pip install gdown --upgrade --quiet

# OpenAI API client
#%pip install openai

# tiktoken, used to count tokens
#%pip install tiktoken

# tifffile, used to read TIFF files
#%pip install tifffile

# scikit-image
#%pip install scikit-image

In [2]:
import base64
import csv
import getpass
import hashlib
import io
import math
import os
import re
import shutil
import tempfile
import textwrap
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import tifffile
import tiktoken
from matplotlib import pyplot as plt
from openai import OpenAI
from PIL import Image

In [3]:
def convert_tif_to_png(tif_path):
    """Convert a TIFF image to an 8-bit PNG written to a temporary file.

    Parameters:
        tif_path: path of the TIFF file, read with ``tifffile.imread``.

    Returns:
        str: path of the PNG file. It is created with ``delete=False``, so the
        caller is responsible for removing it when it is no longer needed.

    Raises:
        ValueError: if the array is neither 2-D nor 3-D with at least three
            entries along its last axis.
    """
    image_array = tifffile.imread(tif_path)

    # Normalize to 8-bit
    if image_array.dtype != np.uint8:
        # Clamp only negative values. Clipping the upper end to 255 *before*
        # rescaling would saturate every pixel above 255 in 16-bit/float data.
        scaled = np.clip(image_array.astype(np.float64), 0, None)
        peak = scaled.max() if scaled.size else 0
        # An all-zero image stays all-zero instead of dividing by zero.
        if peak > 0:
            scaled = scaled / peak * 255
        image_array = scaled.astype(np.uint8)

    # Shape handling
    if image_array.ndim == 2:
        pil_image = Image.fromarray(image_array, mode='L')
    elif image_array.ndim == 3 and image_array.shape[2] >= 3:
        # Keep the first three entries of the last axis (extra ones, e.g. alpha, are dropped).
        pil_image = Image.fromarray(image_array[:, :, :3], mode='RGB')
    else:
        raise ValueError(f"Unsupported image shape: {image_array.shape}")

    # Reserve a unique file name, then close the handle before writing so no
    # descriptor is leaked and the path can be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        png_path = temp_file.name

    try:
        pil_image.save(png_path, "PNG")
    except Exception:
        # Do not leave an empty placeholder file behind when saving fails.
        os.remove(png_path)
        raise
    return png_path

In [4]:
def process_llm_output(input_text, caption_width=80):
    """Reformat raw LLM output into a bulleted field list plus a wrapped caption.

    The text before the first ``**Caption:**`` marker is treated as
    ``key: value`` lines (lines without a colon are ignored); everything after
    it is the caption.

    Parameters:
        input_text: raw LLM output containing the ``**Caption:**`` marker.
        caption_width: maximum line width used when wrapping the caption.

    Returns:
        str: ``"Example Fields:"`` followed by one ``- key: value`` bullet per
        field line, a blank line, ``"Caption:"`` and the wrapped caption.

    Raises:
        ValueError: if ``input_text`` does not contain ``**Caption:**``.
    """
    marker = "**Caption:**"

    # Split at the FIRST marker only, so a caption that repeats the marker
    # no longer breaks the two-way split.
    fields_text, found, caption_text = input_text.partition(marker)
    if not found:
        raise ValueError(f"Expected the marker {marker!r} in the LLM output")

    # Process fields into bullet points
    bullet_lines = []
    for line in fields_text.strip().split("\n"):
        if ":" in line:
            key, value = line.split(":", 1)
            bullet_lines.append(f"- {key.strip()}: {value.strip()}\n")
    processed_fields = "Example Fields:\n" + "".join(bullet_lines)

    # Wrap the caption
    wrapped_caption = textwrap.fill(caption_text.strip(), width=caption_width)

    # Combine processed fields and wrapped caption
    return f"{processed_fields}\nCaption:\n{wrapped_caption}"

def count_tokens(text, model="gpt-4o"):
    """Return the number of tokens `text` encodes to under the tiktoken encoding for `model`."""
    tokenizer = tiktoken.encoding_for_model(model)
    token_ids = tokenizer.encode(text)
    return len(token_ids)

In [None]:
# Take the API key from the environment so it never has to be typed into (and
# saved with) the notebook. Only when OPENAI_API_KEY is unset, ask for it
# interactively without echoing the input.
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")
os.environ["OPENAI_API_KEY"] = openai_api_key

# The client is created without arguments; the key is supplied via OPENAI_API_KEY above.
client = OpenAI()

# Reference caption embedded verbatim at the end of the system prompt below.
# It demonstrates the expected output shape: an `[Image N]:` tag followed by a
# single paragraph describing the cell.
example_caption = (
    "[Image 11]: This single cell is a podocyte, obtained from a mouse kidney using Visium "
    "HD technology. Podocytes are crucial for the glomerular filtration barrier, " 
    "characterized by elongated foot processes that tightly adhere to glomerular " 
    "capillaries, forming a highly selective filtration interface. Unlike previous " 
    "samples, this region exhibits a more loosely arranged tissue structure, with " 
    "distinct gaps between cells and a high density of dark-stained nuclei " 
    "surrounding the glomerulus. The image also shows numerous red blood cells, " 
    "indicating abundant blood flow essential for sustaining efficient filtration. " 
    "Additionally, some podocytes appear morphologically irregular, potentially " 
    "undergoing structural remodeling to adapt to pressure changes. Podocyte " 
    "injury can lead to proteinuria and glomerulosclerosis, significantly " 
    "compromising kidney filtration function."
)

# System message sent with every request: sets the assistant's role, asks for
# exactly one caption per image tagged `[Image #]`, and appends the example
# caption in single quotes.
# NOTE(review): batch_llm_request labels each user entry `[Image Pair N]`,
# while this prompt asks for `[Image #]` — confirm the differing tags are intended.
system_prompt = (
    "You are a helpful, knowledgeable assistant who will explain the "
    "morphology of the cell image visible in the histopathology tissue patch. "
    "You must utilize the visual cues in the images and textual cues from given "
    "metadata. Also, include only ONE caption per image. Separate each caption " 
    "with the image number like `[Image #]`. Here is a sample caption for your reference: " 
    f"'{example_caption}'"
)

# Cell and Tissue Metadata
#  - [0]: cell type 
#  - [1]: cell disease state 
#  - [2]: tissue
#  - [3]: tissue disease state
#  - [4]: source (mouse or human)
#  - [5]: st technology
def cell_tissue_user_prompt(metadata):
    """Compose the per-sample user prompt from the positional metadata fields.

    `metadata` is indexed as: 0 cell type, 1 cell disease state, 2 tissue,
    3 tissue disease state, 4 source, 5 ST technology. A disease state equal
    to the string '1' is worded as cancerous; any other value is worded as
    non-cancerous (cell) / not cancerous (tissue).
    """
    cell_state = "cancerous" if metadata[1] == "1" else "non-cancerous"
    tissue_state = "cancerous" if metadata[3] == "1" else "not cancerous"

    sentences = [
        "Generate a concise, biologically accurate caption that describes the cellular morphology, including "
        "shape, size, structure, and function of the cell in the histopathology tissue patch.",
        f"The single cell image is a {cell_state} {metadata[0]} cell.",
        f"It is located in the {metadata[4]}'s {metadata[2]}, which is {tissue_state}.",
        f"The images were obtained using {metadata[5]} spatial transcriptomics technology.",
    ]
    return " ".join(sentences)

def encode_image(image_path):
    """Read the file at `image_path` and return its bytes as a base64 text string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

In [6]:
def load_data_from_xenium_single_folder(folder_path):
    """Load every sample found under one dataset folder.

    The folder name supplies the fixed metadata and must have at least three
    underscore-separated parts: ``<st technology>_<source>_<tissue>``
    (e.g. ``Xenium_Mouse_Femur``). Each sub-directory is one sample and must
    contain ``cell.tif``, ``patch.tif`` and ``attributes.csv``; only the first
    row of the CSV is read. Samples with missing files, an empty CSV, missing
    columns or a failed TIFF conversion are reported and skipped.

    Parameters:
        folder_path: path of the dataset folder.

    Returns:
        tuple: four parallel lists ``(cell_ids, metadata, prompts, images)``.
        Each metadata entry is ``[cell_type, cell_disease_state, tissue,
        tissue_disease_state, source, st_technology]``; each images entry is
        ``[cell_png_path, tissue_png_path]``.

    Raises:
        ValueError: if the folder name has fewer than three parts.
    """
    all_cell_ids = []
    all_metadata = []
    all_prompts = []
    all_images = []

    required_cols = ["cell_id", "cell_type", "cell_disease_state", "tissue_disease_state"]

    # Parse fixed metadata from folder name
    base_name = os.path.basename(folder_path.rstrip("/"))
    parts = base_name.split("_")
    if len(parts) < 3:
        raise ValueError(f"Folder name '{base_name}' must be like 'Xenium_Mouse_Femur'")

    st_tech = parts[0]
    source = parts[1]
    tissue = parts[2]

    # Sorted so samples are processed in a reproducible order.
    sample_dirs = sorted(
        f for f in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, f))
    )

    print(f"📂 Found {len(sample_dirs)} samples in {base_name}")

    for sample_id in sample_dirs:
        sample_path = os.path.join(folder_path, sample_id)
        cell_tif = os.path.join(sample_path, "cell.tif")
        tissue_tif = os.path.join(sample_path, "patch.tif")
        attr_csv = os.path.join(sample_path, "attributes.csv")

        if not all(os.path.exists(p) for p in [cell_tif, tissue_tif, attr_csv]):
            print(f"⛔ Missing files in {sample_id}")
            continue

        try:
            df = pd.read_csv(attr_csv)
            if df.empty:
                # Report the skip like every other skip path instead of dropping the sample silently.
                print(f"⛔ Empty attributes.csv in {sample_id}")
                continue
            row = df.iloc[0]
            if not all(col in row for col in required_cols):
                print(f"⛔ Missing columns in {sample_id}")
                continue

            cell_id = str(row["cell_id"])
            cell_type = str(row["cell_type"])
            cell_dis = str(row["cell_disease_state"])
            tissue_dis = str(row["tissue_disease_state"])

        except Exception as e:
            print(f"⛔ Error reading {sample_id}/attributes.csv: {e}")
            continue

        # Convert .tif to .png
        cell_png = None
        try:
            cell_png = convert_tif_to_png(cell_tif)
            tissue_png = convert_tif_to_png(tissue_tif)
        except Exception as e:
            # The cell PNG may already exist when the tissue conversion fails;
            # remove it so a skipped sample leaves no temporary file behind.
            if cell_png is not None and os.path.exists(cell_png):
                os.remove(cell_png)
            print(f"⛔ TIFF conversion failed for {sample_id}: {e}")
            continue

        # Compose metadata and prompt (order matches what cell_tissue_user_prompt indexes).
        metadata = [
            cell_type,
            cell_dis,
            tissue,
            tissue_dis,
            source,
            st_tech
        ]

        prompt = cell_tissue_user_prompt(metadata)

        all_cell_ids.append(cell_id)
        all_metadata.append(metadata)
        all_prompts.append(prompt)
        all_images.append([cell_png, tissue_png])

    print(f"✅ Loaded {len(all_prompts)} samples from {base_name}")
    return all_cell_ids, all_metadata, all_prompts, all_images


In [7]:
def create_batches(cell_ids, metadata_list, prompts, images, batch_size=10):
    """Chunk the four parallel sequences into consecutive batches.

    Returns a list of ``(cell_ids, metadata, prompts, images)`` tuples, each
    holding at most `batch_size` entries. The number of batches is determined
    by ``len(prompts)``.
    """
    columns = (cell_ids, metadata_list, prompts, images)
    return [
        tuple(column[start:start + batch_size] for column in columns)
        for start in range(0, len(prompts), batch_size)
    ]

In [8]:
import csv

def append_to_csv(csv_path, cell_ids, metadata_list, captions):
    """Append one row per (cell id, metadata, caption) triple to a CSV file.

    The header row is written first when the file does not exist yet or
    exists but is empty.

    Parameters:
        csv_path: path of the CSV file to append to.
        cell_ids: sequence of cell identifiers.
        metadata_list: sequence of per-cell metadata sequences; their values
            fill the columns between ``cell_id`` and ``caption``.
        captions: sequence of caption strings.

    The three sequences are combined with ``zip``, so extra entries in a
    longer sequence are ignored.
    """
    header = ["cell_id", "cell-type", "cell disease state", "tissue",
              "tissue disease state", "source", "st technology", "caption"]

    # A zero-length file (e.g. pre-created, or left by an interrupted run)
    # needs the header just as much as a missing one.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    with open(csv_path, mode='a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)

        # Write header once
        if needs_header:
            writer.writerow(header)

        for cid, meta, cap in zip(cell_ids, metadata_list, captions):
            # Unpacking accepts any sequence for `meta`, not only lists.
            writer.writerow([cid, *meta, cap])

In [9]:
def _image_data_url(image_path):
    """Return a base64 data URL for the image, with a MIME type matching its file extension."""
    mime_by_extension = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    extension = os.path.splitext(image_path)[1].lower()
    # Unknown extensions keep the type that used to be hard-coded.
    mime_type = mime_by_extension.get(extension, "image/jpeg")
    return f"data:{mime_type};base64,{encode_image(image_path)}"


def batch_llm_request(client, system_prompt, prompt_list, image_lists, model="gpt-4.1"):
    """
    Sends a single chat-completion request containing multiple (prompt + 2 images) entries.

    Parameters:
        client: OpenAI client
        system_prompt: str, system message with task definition
        prompt_list: list of strings, user prompts for each cell-tissue pair
        image_lists: list of [cell_img_path, tissue_img_path] for each entry
        model: str, name of the model to query (defaults to "gpt-4.1")

    Returns:
        response: OpenAI API response, or None if the request raised an
        exception (the error is printed).
    """
    messages = [{"role": "system", "content": system_prompt}]
    user_content = []

    for i, (prompt, image_pair) in enumerate(zip(prompt_list, image_lists)):
        cell_img_path, tissue_img_path = image_pair

        # Add prompt text with identifier
        user_content.append({
            "type": "text",
            "text": f"[Image Pair {i+1}] {prompt}"
        })

        # Add the cell image, then the tissue image. The data URL declares the
        # real file type (the pipeline produces PNG files, not JPEG).
        for img_path in (cell_img_path, tissue_img_path):
            user_content.append({
                "type": "image_url",
                "image_url": {"url": _image_data_url(img_path)}
            })

    messages.append({"role": "user", "content": user_content})

    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages
        )
        return response
    except Exception as e:
        # Best effort: callers treat None as "this batch failed" and carry on.
        print(f"Error in batch request: {e}")
        return None

In [15]:
from pathlib import Path

def _split_captions(text):
    """Split an LLM reply into one caption per `[Image N]` / `[Image Pair N]` tag.

    Each caption keeps its tag and has its lines joined with single spaces.
    When the reply contains no tag at the start of any line, fall back to one
    caption per non-empty line.
    """
    tag = re.compile(
        r"^[ \t*_#>-]*\[Image(?:\s+Pair)?\s*#?\s*\d+\]",
        re.MULTILINE | re.IGNORECASE,
    )
    starts = [match.start() for match in tag.finditer(text)]
    if not starts:
        return [line.strip() for line in text.split("\n") if line.strip()]

    # Text before the first tag (e.g. an introductory sentence) is dropped.
    bounds = starts + [len(text)]
    captions = []
    for begin, end in zip(bounds, bounds[1:]):
        pieces = (piece.strip() for piece in text[begin:end].split("\n"))
        captions.append(" ".join(piece for piece in pieces if piece))
    return captions


def run_captioning_pipeline_single_output(root_dir, batch_size=10):
    """
    Runs the captioning pipeline for a single organ directory.
    Outputs a single CSV file named captions_{source}_{organ}.csv.

    Parameters:
        root_dir: dataset folder named like 'Xenium_Source_Organ'
        batch_size: number of samples sent per LLM request

    Returns:
        (responses, token_counts): the raw response of every batch (None for
        failed requests) and the total token count of every batch that
        reported usage, including batches whose captions could not be saved.

    Raises:
        ValueError: if the folder name has fewer than three '_'-separated parts.
    """

    # === Step 1: Load data ===
    # NOTE(review): the image paths returned here are temporary PNG files that
    # this function never removes — consider cleaning them up after the run.
    cell_ids, metadata_list, prompts, image_paths = load_data_from_xenium_single_folder(root_dir)
    print(f"✅ Loaded {len(prompts)} entries from {root_dir}")

    # === Step 2: Derive output filename from folder name ===
    folder_name = Path(root_dir.rstrip("/")).name
    parts = folder_name.split("_")
    if len(parts) < 3:
        raise ValueError(f"Folder name '{folder_name}' must be of the form 'Xenium_Source_Organ'")

    source = parts[1].lower()
    organ = parts[2].lower()
    csv_path = f"captions_{source}_{organ}.csv"

    # === Step 3: Create batches ===
    batches = create_batches(cell_ids, metadata_list, prompts, image_paths, batch_size=batch_size)

    responses = []
    token_counts = []

    # === Step 4: Process each batch ===
    for idx, (cid_batch, meta_batch, prompt_batch, image_batch) in enumerate(batches):
        print(f"\n🌀 Processing Batch #{idx + 1}/{len(batches)}")

        response = batch_llm_request(client, system_prompt, prompt_batch, image_batch)
        responses.append(response)

        if response is None:
            print(f"❌ GPT call failed for Batch #{idx + 1}")
            continue

        # Record usage before validating the captions: a skipped batch has
        # still consumed tokens. `usage` may be absent or None.
        usage = getattr(response, "usage", None)
        tokens = usage.total_tokens if usage is not None else None
        if tokens is not None:
            token_counts.append(tokens)

        # The message content can be None; treat that as an empty reply.
        text = response.choices[0].message.content or ""

        # Split on the image tags rather than on raw newlines, so a caption
        # spanning several lines (or blank-line separators) is not miscounted.
        captions = _split_captions(text)

        if len(captions) != len(cid_batch):
            print(f"⚠️ Warning: Mismatch in caption count (got {len(captions)}, expected {len(cid_batch)}). Skipping batch.")
        else:
            append_to_csv(csv_path, cid_batch, meta_batch, captions)
            print(f"✅ Saved Batch #{idx + 1} to {csv_path}")

        if tokens is not None:
            print(f"🔢 Tokens used: {tokens}")

    return responses, token_counts

In [None]:
# Dataset folder to caption. Swap the last segment to process another dataset.
root_path = (
    "../../../../../extra/zhanglab0/xil43/Xenium/CellNet_data/prepare_captioning/"
    "Xenium_Mouse_Femur"
    # "Xenium_Human_Prostate"
    # "Xenium_Human_Ovary"
)

# Run the full pipeline; keeps the raw responses and the per-batch token counts.
responses, token_counts = run_captioning_pipeline_single_output(
    root_dir=root_path,
    batch_size=10,
)

📂 Found 1000 samples in Xenium_Mouse_Femur
✅ Loaded 1000 samples from Xenium_Mouse_Femur
✅ Loaded 1000 entries from ../../../../../extra/zhanglab0/xil43/Xenium/CellNet_data/prepare_captioning/Xenium_Mouse_Femur

🌀 Processing Batch #1/100
✅ Saved Batch #1 to captions_mouse_femur.csv
🔢 Tokens used: 6970

🌀 Processing Batch #2/100
✅ Saved Batch #2 to captions_mouse_femur.csv
🔢 Tokens used: 6824

🌀 Processing Batch #3/100
✅ Saved Batch #3 to captions_mouse_femur.csv
🔢 Tokens used: 6812

🌀 Processing Batch #4/100

🌀 Processing Batch #5/100
✅ Saved Batch #5 to captions_mouse_femur.csv
🔢 Tokens used: 6837

🌀 Processing Batch #6/100
✅ Saved Batch #6 to captions_mouse_femur.csv
🔢 Tokens used: 6852

🌀 Processing Batch #7/100
✅ Saved Batch #7 to captions_mouse_femur.csv
🔢 Tokens used: 6784

🌀 Processing Batch #8/100
✅ Saved Batch #8 to captions_mouse_femur.csv
🔢 Tokens used: 6729

🌀 Processing Batch #9/100
✅ Saved Batch #9 to captions_mouse_femur.csv
🔢 Tokens used: 6763

🌀 Processing Batch #10/10


---
