In [1]:
# Install gdown (if not already installed)
#!pip install gdown --upgrade --quiet

# install OpenAI API
#!pip install openai

# install tiktoken to count tokens
#!pip install tiktoken

# install package to handle tiff files 
#!pip install tifffile

#%pip install scikit-image

In [2]:
import io
import os
import re
import csv
import math
import base64
import shutil
import hashlib
import tifffile
import tiktoken
import textwrap
import numpy as np
import pandas as pd
from PIL import Image
from openai import OpenAI
from matplotlib import pyplot as plt
from typing import Dict, List, Tuple

In [3]:
def process_llm_output(input_text, caption_width=80):
    # Split the input text into fields and caption
    fields_text, caption_text = input_text.split("**Caption:**")

    # Process fields into bullet points
    fields_lines = fields_text.strip().split("\n")
    processed_fields = "Example Fields:\n"
    for line in fields_lines:
        if ":" in line:
            key, value = line.split(":", 1)
            processed_fields += f"- {key.strip()}: {value.strip()}\n"

    # Wrap the caption
    wrapped_caption = textwrap.fill(caption_text.strip(), width=caption_width)

    # Combine processed fields and wrapped caption
    formatted_output = f"{processed_fields}\nCaption:\n{wrapped_caption}"
    return formatted_output

def count_tokens(text, model="gpt-4o"):
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))

In [None]:
openai_api_key = "your_token_here"
os.environ["OPENAI_API_KEY"] = openai_api_key

client = OpenAI()

example_caption = (
    "[Image 11]: This single cell is a podocyte, obtained from a mouse kidney using Visium "
    "HD technology. Podocytes are crucial for the glomerular filtration barrier, " 
    "characterized by elongated foot processes that tightly adhere to glomerular " 
    "capillaries, forming a highly selective filtration interface. Unlike previous " 
    "samples, this region exhibits a more loosely arranged tissue structure, with " 
    "distinct gaps between cells and a high density of dark-stained nuclei " 
    "surrounding the glomerulus. The image also shows numerous red blood cells, " 
    "indicating abundant blood flow essential for sustaining efficient filtration. " 
    "Additionally, some podocytes appear morphologically irregular, potentially " 
    "undergoing structural remodeling to adapt to pressure changes. Podocyte " 
    "injury can lead to proteinuria and glomerulosclerosis, significantly " 
    "compromising kidney filtration function."
)

system_prompt = (
    "You are a helpful, knowledgeable assistant who will explain the "
    "morphology of the cell image visible in the histopathology tissue patch. "
    "You must utilize the visual cues in the images and textual cues from given "
    "metadata. Also, include only ONE caption per image. Separate each caption " 
    "with the image number like `[Image #]`. Here is a sample caption for your reference: " 
    f"'{example_caption}'"
)

# Cell and Tissue Metadata
#  - [0]: cell type 
#  - [1]: cell disease state 
#  - [2]: tissue
#  - [3]: tissue disease state
#  - [4]: source (mouse or human)
#  - [5]: st technology
def cell_tissue_user_prompt(metadata):
    return (
        "Generate a concise, biologically accurate caption that describes the cellular morphology, including "
        "shape, size, structure, and function of the cell in the histopathology tissue patch. The " 
        f"single cell image is a {'cancerous' if metadata[1]=='1' else 'non-cancerous'} {metadata[0]} cell. "
        f"It is located in the {metadata[4]}\'s {metadata[2]}, "
        f"which is {'cancerous' if metadata[3]=='1' else 'not cancerous'}. The images were "
        f"obtained using {metadata[5]} spatial transcriptomics technology."
    )

def encode_image(image_path):
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

In [5]:
def load_data(root_dir):
    all_prompts = []
    all_images = []
    all_cell_ids = []
    all_metadata = []

    expected_cols = [
        "cell-type",
        "cell disease state",
        "tissue",
        "tissue disease state",
        "source",
        "st technology"
    ]

    folders = sorted(os.listdir(root_dir))
    for folder in folders:
        folder_path = os.path.join(root_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        point_id = folder  # e.g., point_13442
        cell_img = os.path.join(folder_path, f"{point_id}.png")
        tissue_img = os.path.join(folder_path, f"{point_id}_tissue.png")
        attr_path = os.path.join(folder_path, "attribute.csv")

        if not (os.path.exists(cell_img) and os.path.exists(tissue_img) and os.path.exists(attr_path)):
            continue

        try:
            df = pd.read_csv(attr_path)
            if df.empty:
                continue

            row = df.iloc[0]
            if not all(col in row for col in expected_cols):
                continue

            metadata = [str(row[col]) for col in expected_cols]
            prompt = cell_tissue_user_prompt(metadata)

        except Exception as e:
            print(f"Skipping {folder} due to error: {e}")
            continue

        all_cell_ids.append(point_id)
        all_metadata.append(metadata)
        all_prompts.append(prompt)
        all_images.append([cell_img, tissue_img])

    return all_cell_ids, all_metadata, all_prompts, all_images

In [6]:
def create_batches(cell_ids, metadata_list, prompts, images, batch_size=10):
    batches = []
    for i in range(0, len(prompts), batch_size):
        batch = (
            cell_ids[i:i+batch_size],
            metadata_list[i:i+batch_size],
            prompts[i:i+batch_size],
            images[i:i+batch_size]
        )
        batches.append(batch)
    return batches

In [7]:
import csv

def append_to_csv(csv_path, cell_ids, metadata_list, captions):
    header = ["cell_id", "cell-type", "cell disease state", "tissue",
              "tissue disease state", "source", "st technology", "caption"]
    file_exists = os.path.exists(csv_path)

    with open(csv_path, mode='a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)

        # Write header once
        if not file_exists:
            writer.writerow(header)

        for cid, meta, cap in zip(cell_ids, metadata_list, captions):
            writer.writerow([cid] + meta + [cap])

In [8]:
def batch_llm_request(client, system_prompt, prompt_list, image_lists):
    """
    Sends a single GPT-4o request containing multiple (prompt + 2 images) entries.

    Parameters:
        client: OpenAI client
        system_prompt: str, system message with task definition
        prompt_list: list of strings, user prompts for each cell-tissue pair
        image_lists: list of [cell_img_path, tissue_img_path] for each entry

    Returns:
        response: OpenAI API response
    """
    messages = [{"role": "system", "content": system_prompt}]
    user_content = []

    for i, (prompt, image_pair) in enumerate(zip(prompt_list, image_lists)):
        cell_img_path, tissue_img_path = image_pair

        # Add prompt text with identifier
        user_content.append({
            "type": "text",
            "text": f"[Image Pair {i+1}] {prompt}"
        })

        # Add cell image
        user_content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{encode_image(cell_img_path)}"
            }
        })

        # Add tissue image
        user_content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{encode_image(tissue_img_path)}"
            }
        })

    messages.append({"role": "user", "content": user_content})

    try:
        response = client.chat.completions.create(
            model="gpt-4.1",
            messages=messages
        )
        return response
    except Exception as e:
        print(f"Error in batch request: {e}")
        return None

In [None]:
def run_captioning_pipeline(root_dir, batch_size=10, save_path="captions_output_a.csv"):
    cell_ids, metadata_list, prompts, image_paths = load_data(root_dir)
    print(f"Loaded {len(prompts)} valid entries.")

    batches = create_batches(cell_ids, metadata_list, prompts, image_paths, batch_size=batch_size)

    responses = []
    token_counts = []

    for idx, (cid_batch, meta_batch, prompt_batch, image_batch) in enumerate(batches):
        print(f"\nBatch #{idx+1}/{len(batches)}")
        response = batch_llm_request(client, system_prompt, prompt_batch, image_batch)
        responses.append(response)

        if response:
            # Extract captions
            text = response.choices[0].message.content
            captions = [line.strip() for line in text.split("\n") if line.strip()]

            if len(captions) != len(cid_batch):
                print(f"⚠️ Warning: Mismatch in caption count (got {len(captions)}, expected {len(cid_batch)}). Skipping save.")
                continue

            # Save to CSV
            append_to_csv(save_path, cid_batch, meta_batch, captions)
            print(f"✅ Batch {idx+1} saved to {save_path}")

            # Track token usage
            if hasattr(response, "usage"):
                tokens = response.usage.total_tokens
                token_counts.append(tokens)
                print(f"Tokens used: {tokens}")
        else:
            print(f"❌ Batch {idx+1} failed. Response is None.")

    return responses, token_counts

In [10]:
root_path = "../../../../../extra/zhanglab0/xil43/HD/a/point/"

responses, token_counts = run_captioning_pipeline(root_path, batch_size=10)

Loaded 1000 valid entries.

Batch #1/100
✅ Batch 1 saved to captions_output.csv
Tokens used: 6648

Batch #2/100
✅ Batch 2 saved to captions_output.csv
Tokens used: 6750

Batch #3/100
✅ Batch 3 saved to captions_output.csv
Tokens used: 6717

Batch #4/100
✅ Batch 4 saved to captions_output.csv
Tokens used: 6739

Batch #5/100
✅ Batch 5 saved to captions_output.csv
Tokens used: 6717

Batch #6/100
✅ Batch 6 saved to captions_output.csv
Tokens used: 6732

Batch #7/100
✅ Batch 7 saved to captions_output.csv
Tokens used: 6768

Batch #8/100
✅ Batch 8 saved to captions_output.csv
Tokens used: 6753

Batch #9/100
✅ Batch 9 saved to captions_output.csv
Tokens used: 6761

Batch #10/100
✅ Batch 10 saved to captions_output.csv
Tokens used: 6814

Batch #11/100
✅ Batch 11 saved to captions_output.csv
Tokens used: 6713

Batch #12/100
✅ Batch 12 saved to captions_output.csv
Tokens used: 6724

Batch #13/100
✅ Batch 13 saved to captions_output.csv
Tokens used: 6825

Batch #14/100
✅ Batch 14 saved to caption


---


In [21]:
responses

[ChatCompletion(id='chatcmpl-BY16F4WcTzuIX31eKV6D00eeMwcXx', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="1. This single cell is a podocyte from a non-cancerous mouse kidney, captured using Visium HD technology. Podocytes exhibit distinctive foot processes crucial for the glomerular filtration barrier. The tissue section shows a dense arrangement of these cells around glomerular structures, with dark-stained nuclei and well-defined cytoplasmic extensions.\n\n2. This non-cancerous proximal tubule cell from a mouse kidney is visualized using Visium HD technology. Proximal tubule cells have abundant cytoplasm and a brush border for reabsorption. The tissue section depicts tightly packed tubular cells with elongated nuclei, indicative of active transport and filtration function.\n\n3. This is a podocyte from a non-cancerous mouse kidney, imaged with Visium HD technology. Podocytes form the filtration barrier with their interdigitating

In [None]:
def show_batch(this_batch):
    print(f"Prompt: {this_batch[0]}")
    image_paths = this_batch[1]
    
    num_subplots = len(image_paths)
    fig, axes = plt.subplots(math.ceil(num_subplots/10), 10, figsize=(10,3))  # 2 rows, 5 columns
    axes = axes.flatten()  # make it easier to index
    for i, path in enumerate(image_paths):
        img = Image.open(path)
        axes[i].imshow(img)
        axes[i].axis('off')
        axes[i].set_title(f'Image {i+1}')
    plt.tight_layout()
    plt.show()

#this_batch = batches[3]
#show_batch(this_batch)

In [22]:
text_responses = [response.choices[0].message.content for response in responses]
llm_outputs = ' '.join(text_responses)
#outputs = [x for x in re.split(r'###|\d+\.', llm_outputs) if x.strip()]
#text_responses = [x for x in re.split(r'\n\n', llm_outputs) if x.strip()]
#text_responses = [x for x in re.split(r'\d+', llm_outputs) if x.strip()]
outputs = []
for text_resp in text_responses:
    indiv = [x.strip() for x in re.split(r'\d+', text_resp) if len(x) > 10]
    outputs += indiv    
        
outputs

['. This single cell is a podocyte from a non-cancerous mouse kidney, captured using Visium HD technology. Podocytes exhibit distinctive foot processes crucial for the glomerular filtration barrier. The tissue section shows a dense arrangement of these cells around glomerular structures, with dark-stained nuclei and well-defined cytoplasmic extensions.',
 '. This non-cancerous proximal tubule cell from a mouse kidney is visualized using Visium HD technology. Proximal tubule cells have abundant cytoplasm and a brush border for reabsorption. The tissue section depicts tightly packed tubular cells with elongated nuclei, indicative of active transport and filtration function.',
 '. This is a podocyte from a non-cancerous mouse kidney, imaged with Visium HD technology. Podocytes form the filtration barrier with their interdigitating foot processes. The histological section reveals a dense clustering of podocytes with prominent nuclei and visible capillary loops, demonstrating their role in 

In [15]:
# Define CSV file path
csv_filename = f"Captions/captions_{num_batches}b_{batch_size}sz_{start_dir_idx}-{end_dir_idx-1}_dirs.csv"

# Identify rows to be removed (indices of captions that contain "unable" or "blank")
rows_to_be_removed = [i for i, resp in enumerate(outputs) if "unable" in resp or "blank" in resp]

# Extract valid captions (removing those in `rows_to_be_removed`)
captions = [resp for i, resp in enumerate(outputs) if i not in rows_to_be_removed]

# Flatten image_dict to [(directory, image), ...]
image_list = [(directory, image) for directory, images in image_dict.items() for image in images]

# Filter images based on the same indices in `rows_to_be_removed`
filtered_image_list = [img for i, img in enumerate(image_list) if i not in rows_to_be_removed]

# Ensure there are enough captions for the images
num_images = len(filtered_image_list)
num_captions = len(captions)

print("  # of images:", num_images)
print("# of captions:", num_captions)

if num_captions < num_images:
    print(f"Warning: Fewer captions ({num_captions}) than images ({num_images}). Some images will be left without captions.")
elif num_captions > num_images:
    print(f"Warning: More captions ({num_captions}) than images ({num_images}). Some captions will be unused.")


# Ensure both lists have the same length by truncating to the minimum size
min_length = min(len(filtered_image_list), len(captions))
filtered_image_list = filtered_image_list[:min_length]
captions = captions[:min_length]

# Pair filtered images with captions
paired_data = [
    (dir_dict[directory], image.split("/")[-1], captions[i])
    for i, (directory, image) in enumerate(filtered_image_list)
]

# Write to CSV
with open(csv_filename, mode="w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file)
    writer.writerow(["Directory_Name", "Patch_Name", "Caption"])  # Header
    writer.writerows(paired_data)

print(f"Data saved to {csv_filename}")

  # of images: 498
# of captions: 498
Data saved to Captions/captions_50b_10sz_20-29_dirs.csv
