In [None]:
# ==============================================================================
# Tutorial 1: Data Preprocessing Pipeline
# ==============================================================================
# Goal: Demonstrate how PhenoSSP processes raw mIF images into cell-centered patches.
# Steps:
# 1. Load a raw multiplex immunofluorescence (mIF) image.
# 2. Perform whole-cell segmentation using DeepCell Mesmer.
# 3. Extract 64x64 pixel single-cell patches based on segmentation masks.
# ==============================================================================


import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tifffile
from skimage import measure
from deepcell.applications import Mesmer  # Requires deepcell installed

# --- 1. Configuration & Setup ---
# Resolve paths relative to this file so the tutorial does not depend on where it
# is launched from. `__file__` exists when this runs as a .py script but not inside
# a Jupyter kernel, so fall back to the current working directory there.
# (Passing the quoted string "__file__" to abspath always resolved to the cwd.)
if "__file__" in globals():
    current_dir = os.path.dirname(os.path.abspath(__file__))
else:
    current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
demo_data_dir = os.path.join(project_root, 'demo_data')

# Input image location and output directory for the extracted patches.
# Only the output directory is created here; the input image must already exist.
raw_image_path = os.path.join(demo_data_dir, 'raw_images', 'sample_01.ome.tiff')
output_patch_dir = os.path.join(demo_data_dir, 'patches')
os.makedirs(output_patch_dir, exist_ok=True)

# Channel order of the demo image (index i in the image <-> CHANNEL_NAMES[i]).
CHANNEL_NAMES = ['DAPI', 'CD8', 'FoxP3', 'PanCK', 'PD1', 'CD4', 'CD3']
# Side length, in pixels, of the square single-cell patches.
PATCH_SIZE = 64

print(f"üìÇ Data Source: {raw_image_path}")
print(f"üìÇ Output Dir: {output_patch_dir}")


In [None]:
# --- 2. Load Raw Image ---
def load_and_preprocess_image(image_path, nuclear_channel=0,
                              membrane_channels=(1, 3, 4, 5, 6), n_channels=7):
    """
    Loads a multi-channel .tiff image and builds the two-channel Mesmer input.

    Args:
        image_path: Path to the image file; it is read with ``tifffile``.
        nuclear_channel: Index of the nuclear marker (default 0, DAPI in this
            tutorial's channel order).
        membrane_channels: Indices summed into one membrane channel. The defaults
            select CD8, PanCK, PD1, CD4 and CD3 in this tutorial's channel order.
        n_channels: Channel count used to recognise a channels-first layout
            (C, H, W), which is transposed to (H, W, C).

    Returns:
        (image, mesmer_input): ``image`` is the loaded array with channels last;
        ``mesmer_input`` has shape (1, H, W, 2) ordered [nuclear, membrane].

    Raises:
        FileNotFoundError: ``image_path`` does not exist.
        ValueError: the loaded array is not 3-D, or has fewer channels than the
            requested indices need.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Sample image not found at {image_path}. Please place a .tiff file there.")

    with tifffile.TiffFile(image_path) as tif:
        image = tif.asarray()

    # Channels-first (C, H, W) -> channels-last (H, W, C).
    if image.ndim == 3 and image.shape[0] == n_channels:
        image = np.transpose(image, (1, 2, 0))

    # Fail early with a readable message instead of an IndexError further down.
    if image.ndim != 3:
        raise ValueError(
            f"Expected a 3-D (H, W, C) or (C, H, W) image, got shape {image.shape}."
        )
    membrane_idx = list(membrane_channels)  # a list triggers fancy indexing on the last axis
    required = max([nuclear_channel, *membrane_idx]) + 1
    if image.shape[-1] < required:
        raise ValueError(
            f"Image has {image.shape[-1]} channels but channel index {required - 1} was requested."
        )

    print(f"‚úÖ Loaded Image Shape: {image.shape}")

    # Mesmer takes one nuclear and one membrane channel; the selected membrane
    # markers are collapsed into a single channel by summation.
    nuclear = image[..., nuclear_channel]
    membrane = np.sum(image[..., membrane_idx], axis=-1)

    # (H, W, 2) plus a leading batch axis -> (1, H, W, 2)
    mesmer_input = np.stack((nuclear, membrane), axis=-1)[np.newaxis, ...]

    return image, mesmer_input

# Run Load: `raw_image` is the loaded image and `segmentation_input` the batched
# two-channel array; both are reused by the segmentation and extraction cells below.
raw_image, segmentation_input = load_and_preprocess_image(raw_image_path)

In [None]:
# --- 3. Cell Segmentation (Mesmer) ---
def run_segmentation(input_data, app=None, image_mpp=0.5):
    """
    Runs DeepCell Mesmer to generate a cell label mask.

    Args:
        input_data: Array forwarded unchanged to ``app.predict``. This notebook
            supplies shape (1, H, W, 2) ordered [nuclear, membrane].
        app: Optional pre-built application exposing
            ``predict(data, image_mpp=...)``. When omitted a new ``Mesmer()`` is
            created; pass one in to reuse a model across calls.
        image_mpp: Microns-per-pixel value forwarded to ``predict``. The default
            0.5 usually corresponds to 20x magnification.

    Returns:
        The label mask taken from the first batch item and first output channel
        of the prediction.
    """
    if app is None:
        print("ü§ñ Initializing Mesmer Model...")
        app = Mesmer()

    print("‚è≥ Running Segmentation (this may take a moment)...")
    predictions = app.predict(input_data, image_mpp=image_mpp)

    # Drop the batch axis and keep the first output channel (cell mask).
    segmentation_mask = predictions[0, ..., 0]

    # Count distinct non-zero labels (0 is treated as background). Subtracting 1
    # from the unique count would under-count when no background pixel is present.
    n_cells = int(np.count_nonzero(np.unique(segmentation_mask)))
    print(f"‚úÖ Segmentation Complete. Found {n_cells} cells.")
    return segmentation_mask

# Run Segmentation
mask = run_segmentation(segmentation_input)

# [Visualization 1] Show Segmentation Result
# Two panels side by side: the nuclear channel on the left, the label mask on the right.
seg_fig, (dapi_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))

dapi_ax.set_title("Nuclear Channel (DAPI)")
dapi_ax.imshow(raw_image[..., 0], cmap='gray')
dapi_ax.axis('off')

mask_ax.set_title("Segmentation Mask")
# Nearest-neighbour interpolation avoids blending neighbouring label values.
mask_ax.imshow(mask, cmap='nipy_spectral', interpolation='nearest')
mask_ax.axis('off')

seg_fig.tight_layout()
plt.show()

In [None]:
# --- 4. Patch Extraction ---
def _crop_centered_patch(image, cy, cx, patch_size):
    """Return a zero-padded (patch_size, patch_size, C) crop of `image` around (cy, cx)."""
    half_size = patch_size // 2

    # Derive the far edge from the near edge so the window is exactly
    # `patch_size` pixels wide even when `patch_size` is odd.
    min_y, min_x = cy - half_size, cx - half_size
    max_y, max_x = min_y + patch_size, min_x + patch_size

    # Clip the window to the image; the part outside stays zero in the patch.
    img_h, img_w = image.shape[:2]
    src_min_y, src_max_y = max(0, min_y), min(img_h, max_y)
    src_min_x, src_max_x = max(0, min_x), min(img_w, max_x)

    # Destination offsets mirror how far the window was clipped on each side.
    dst_min_y, dst_min_x = src_min_y - min_y, src_min_x - min_x
    dst_max_y = dst_min_y + (src_max_y - src_min_y)
    dst_max_x = dst_min_x + (src_max_x - src_min_x)

    patch = np.zeros((patch_size, patch_size, image.shape[2]), dtype=image.dtype)
    patch[dst_min_y:dst_max_y, dst_min_x:dst_max_x, :] = \
        image[src_min_y:src_max_y, src_min_x:src_max_x, :]
    return patch


def extract_patches(image, mask, save_dir, patch_size=64, max_cells=5, min_area=20):
    """
    Crops cell-centered patches from the raw image based on the segmentation mask.

    Args:
        image: Channels-last array of shape (H, W, C).
        mask: Label mask of shape (H, W), passed to ``measure.regionprops``.
        save_dir: Directory for the ``cell_<label>.npy`` files; created if missing.
        patch_size: Side length of the square patch in pixels (odd or even).
        max_cells: Number of leading regions to consider, counted before the area
            filter. The default 5 keeps the tutorial fast; ``None`` processes all.
        min_area: Regions with a smaller area are skipped as artifacts.

    Returns:
        List of ``(cell_id, patch)`` tuples with ``patch`` in (H, W, C) layout.
        Each saved file holds the same patch transposed to (C, H, W) for PyTorch.

    Raises:
        ValueError: ``image`` is not 3-D, ``mask`` does not match the image's
            height and width, or ``patch_size`` is not positive.
    """
    if image.ndim != 3:
        raise ValueError(f"Expected image of shape (H, W, C), got {image.shape}.")
    if mask.shape != image.shape[:2]:
        raise ValueError(
            f"Mask shape {mask.shape} does not match image height/width {image.shape[:2]}."
        )
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}.")

    print(f"‚úÇÔ∏è Extracting {patch_size}x{patch_size} patches...")

    # Do not rely on the caller having created the output directory.
    os.makedirs(save_dir, exist_ok=True)

    props = measure.regionprops(mask)
    if max_cells is not None:
        props = props[:max_cells]

    extracted_patches = []

    for prop in props:
        # Filter small artifacts
        if prop.area < min_area:
            continue

        # Centroid is (row, col); truncate to integer pixel coordinates.
        cy, cx = map(int, prop.centroid)
        patch = _crop_centered_patch(image, cy, cx, patch_size)

        # Save as (Channels, H, W) for PyTorch
        cell_id = prop.label
        save_path = os.path.join(save_dir, f"cell_{cell_id}.npy")
        np.save(save_path, np.transpose(patch, (2, 0, 1)))

        extracted_patches.append((cell_id, patch))

    print(f"‚úÖ Extracted and saved {len(extracted_patches)} patches to {save_dir}")
    return extracted_patches

# Run Extraction: `patches` is a list of (cell_id, patch) tuples; the .npy files are
# written to `output_patch_dir`. The default patch_size of 64 is used here.
patches = extract_patches(raw_image, mask, output_patch_dir)


In [None]:
# --- 5. Visualization of Patches ---
# Display the extracted patches to verify content
if patches:
    # squeeze=False always returns a 2-D axes array, so one patch needs no special case.
    fig, axes = plt.subplots(1, len(patches), figsize=(15, 3), squeeze=False)

    # Composite colour assignment, looked up by marker name so it follows
    # CHANNEL_NAMES instead of hard-coded indices: R=PanCK, G=CD8, B=DAPI.
    rgb_channels = [CHANNEL_NAMES.index(name) for name in ('PanCK', 'CD8', 'DAPI')]

    for ax, (cell_id, patch) in zip(axes[0], patches):
        # Size the canvas from the patch itself so any patch_size works, not only 64.
        rgb_patch = np.zeros(patch.shape[:2] + (3,))
        for rgb_idx, channel_idx in enumerate(rgb_channels):
            channel = patch[..., channel_idx]
            # Per-channel max normalisation; the epsilon avoids division by zero
            # for channels that are empty in this patch.
            rgb_patch[..., rgb_idx] = channel / (channel.max() + 1e-8)

        ax.imshow(np.clip(rgb_patch * 1.5, 0, 1)) # Slight brighten
        ax.set_title(f"Cell ID: {cell_id}")
        ax.axis('off')

    plt.suptitle("Extracted Single-Cell Patches (Composite View)")
    plt.show()