# Tiled Segmentation for Large Images

This notebook processes images that are too large to segment in a single pass. It splits a large image into overlapping tiles and lets you segment the tiles with an external tool (e.g., Cellpose). It then merges the tile masks back into a single mask, deduplicating cells that appear in the overlap regions.

## Workflow
1. Load large pseudochannel image
2. Configure tiling parameters
3. Split and save tiles
4. **[Manual step]** Segment tiles externally (Cellpose GUI, etc.)
5. Load segmented tile masks
6. Merge masks with cell deduplication
7. Save merged mask

## Setup

In [None]:
import sys
from pathlib import Path

# Add src to path if running from notebooks folder
src_path = Path("../src").resolve()
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

from tiling import (
    TileInfo,
    compute_tile_grid,
    extract_tile,
    split_image,
    save_tile_info,
    load_tile_info,
    load_tile_masks,
    merge_tile_masks,
    relabel_mask,
)
from tiling.merge import save_merged_mask

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import tifffile

%matplotlib inline

In [None]:
# ==== CONFIGURATION ====

# Every input and output lives under one working directory.
_DATA_ROOT = Path("/mnt/Vol_c/JP_segmentation")

# Input pseudochannel image with 1 or 2 channels.
# A 2-channel image is expected as (2, H, W): nuclear and membrane channels.
IMAGE_PATH = _DATA_ROOT / "cellpose_input.ome.tif"

# Tiling parameters, in pixels
TILE_SIZE = 10_000  # tile size before the overlap is added
OVERLAP = 400       # overlap between neighbouring tiles; aim for roughly twice the cell diameter

# Output locations
TILES_DIR = _DATA_ROOT / "tiles"                             # split tiles (and later their masks)
MERGED_OUTPUT = _DATA_ROOT / "cellpose_input_masks.ome.tif"  # final merged mask

# Merging parameters
IOU_THRESHOLD = 0.5  # minimum IoU for two cells in an overlap to count as the same cell

In [None]:
image = tifffile.imread(str(IMAGE_PATH))

print(f"Loaded image: {IMAGE_PATH.name}")
print(f"  Shape: {image.shape}")
print(f"  Dtype: {image.dtype}")

# Normalise the array to (H, W) or channels-first (C, H, W) and record the
# spatial shape (H, W) together with the number of channels.
if image.ndim not in (2, 3):
    raise ValueError(f"Unexpected image dimensions: {image.ndim}")

if image.ndim == 2:
    n_channels, image_shape = 1, image.shape
elif image.shape[0] <= 4:
    # A short leading axis (<= 4) is taken to be the channel axis already.
    n_channels, image_shape = image.shape[0], image.shape[1:]
else:
    # Otherwise the last axis is treated as channels and moved to the front
    # (transpose returns a view, so no pixel data is copied).
    n_channels, image_shape = image.shape[2], image.shape[:2]
    image = image.transpose(2, 0, 1)

print(f"  Image size: {image_shape[1]} x {image_shape[0]} pixels ({n_channels} channel(s))")
print(f"  Total pixels: {image_shape[0] * image_shape[1]:,}")

## 1. Load Image

## 2. Configure Tiling

Visualize the tile grid before splitting.

In [None]:
# Compute tile grid
tile_infos = compute_tile_grid(image_shape, TILE_SIZE, OVERLAP)

# Row/column indices are zero-based, so the grid size is the largest index + 1.
n_rows = 1 + max(info.row for info in tile_infos)
n_cols = 1 + max(info.col for info in tile_infos)

print(f"Tile grid: {n_rows} rows x {n_cols} cols = {len(tile_infos)} tiles")
print(f"Tile size: {TILE_SIZE} px, Overlap: {OVERLAP} px")

# Distinct tile shapes present in the grid (tiles at the image border may differ).
shapes = {info.shape for info in tile_infos}
print(f"Tile shapes: {shapes}")

In [None]:
# Visualize tile grid on image thumbnail
fig, ax = plt.subplots(figsize=(10, 10))

# Down-sample to roughly 500 px per side for display; multi-channel images
# are shown through their first channel only.
step_y = max(1, image_shape[0] // 500)
step_x = max(1, image_shape[1] // 500)
first_plane = image if image.ndim == 2 else image[0]
thumbnail = first_plane[::step_y, ::step_x]

# `extent` maps the thumbnail back onto full-resolution pixel coordinates,
# so tile rectangles can be drawn in original-image units.
ax.imshow(thumbnail, cmap='gray', extent=[0, image_shape[1], image_shape[0], 0])

palette = plt.cm.tab10(np.linspace(0, 1, 10))
for tile in tile_infos:
    # Horizontally/vertically adjacent tiles differ by one in row + col,
    # so they never get the same colour.
    tile_color = palette[(tile.row + tile.col) % 10]

    outline = patches.Rectangle(
        (tile.x_start, tile.y_start),
        tile.width,
        tile.height,
        linewidth=2,
        edgecolor=tile_color,
        facecolor='none',
        alpha=0.8,
    )
    ax.add_patch(outline)

    center_x = tile.x_start + tile.width / 2
    center_y = tile.y_start + tile.height / 2
    label_box = {'boxstyle': 'round', 'facecolor': tile_color, 'alpha': 0.7}
    ax.text(
        center_x,
        center_y,
        f"({tile.row},{tile.col})",
        ha='center',
        va='center',
        fontsize=8,
        color='white',
        fontweight='bold',
        bbox=label_box,
    )

ax.set_title(f"Tile Grid: {n_rows}x{n_cols} = {len(tile_infos)} tiles")
ax.set_xlabel("X (pixels)")
ax.set_ylabel("Y (pixels)")
plt.tight_layout()
plt.show()

In [None]:
# Visualize overlap regions
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(thumbnail, cmap='gray', extent=[0, image_shape[1], image_shape[0], 0])

# Shade, in translucent red, the strips along each tile's left and top edges
# whose widths are given by overlap_left / overlap_top.
for tile in tile_infos:
    origin = (tile.x_start, tile.y_start)
    strips = []
    if tile.overlap_left > 0:
        strips.append((tile.overlap_left, tile.height))
    if tile.overlap_top > 0:
        strips.append((tile.width, tile.overlap_top))
    for strip_width, strip_height in strips:
        ax.add_patch(
            patches.Rectangle(
                origin,
                strip_width,
                strip_height,
                facecolor='red',
                alpha=0.3,
            )
        )

ax.set_title(f"Overlap Regions (red) - {OVERLAP} px overlap")
ax.set_xlabel("X (pixels)")
ax.set_ylabel("Y (pixels)")
plt.tight_layout()
plt.show()

## 3. Split and Save Tiles

In [None]:
# Make sure the destination exists before any tile is written into it.
TILES_DIR.mkdir(parents=True, exist_ok=True)

# Cut the image into overlapping tiles and write each one to TILES_DIR.
tiles, tile_infos = split_image(
    image,
    tile_size=TILE_SIZE,
    overlap=OVERLAP,
    output_dir=TILES_DIR,
    filename_pattern="tile_r{row}_c{col}.tif",
)
print(f"Saved {len(tiles)} tiles to {TILES_DIR}")

# Persist the grid layout next to the tiles; the merge section reloads it
# from tile_info.json, so merging does not depend on this kernel's state.
tile_info_path = save_tile_info(
    tile_infos,
    TILES_DIR / "tile_info.json",
    image_shape=image_shape,
    tile_size=TILE_SIZE,
    overlap=OVERLAP,
)
print(f"Saved tile metadata to {tile_info_path}")

In [None]:
# List generated tiles.
# The pattern "tile_r*_c*.tif" also matches segmentation outputs such as
# "tile_r0_c0_cp_masks.tif". Exclude those so the count stays correct when
# this cell is re-run after masks have been written to the same directory.
tile_files = sorted(
    f for f in TILES_DIR.glob("tile_r*_c*.tif") if not f.stem.endswith("_cp_masks")
)
print(f"Generated {len(tile_files)} tile files:")
for f in tile_files[:6]:
    print(f"  {f.name}")
if len(tile_files) > 6:
    print(f"  ... and {len(tile_files) - 6} more")

## 4. Segment Tiles Externally

**This is a manual step.** Open the tiles in your preferred segmentation tool:

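### Option A: Cellpose command line
```bash
cellpose --dir <tiles_directory> --save_tif
```
This writes a `tile_r{row}_c{col}_cp_masks.tif` file next to each tile. Add the model and channel options you would otherwise choose in the GUI (compare Option B). If you segment in the Cellpose GUI instead, save each result under that same name.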

### Option B: Cellpose Python
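```python
from cellpose import models
import numpy as np
import tifffile

model = models.Cellpose(gpu=True, model_type='cyto3')

# sorted() lists the tiles before any mask is written, and existing
# *_cp_masks.tif files are skipped so a re-run does not segment the masks themselves.
for tile_path in sorted(TILES_DIR.glob('tile_r*_c*.tif')):
    if tile_path.stem.endswith('_cp_masks'):
        continue
    img = tifffile.imread(tile_path)
    masks, flows, styles, diams = model.eval(img, channels=[2, 1])
    
    mask_path = tile_path.parent / f"{tile_path.stem}_cp_masks.tif"
    tifffile.imwrite(mask_path, masks.astype(np.uint32))
```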

### Option C: Other segmentation tools
Any tool that produces integer label masks will work: 0 for background and one unique positive integer per cell. Save each mask in the tiles directory using the pattern:
```
tile_r{row}_c{col}_cp_masks.tif
```

**After segmentation, continue to the next section.**

In [None]:
# Check which masks exist.
mask_files = sorted(TILES_DIR.glob("tile_r*_c*_cp_masks.tif"))
print(f"Found {len(mask_files)} mask files (expected {len(tile_infos)}):")

# Compare by tile name rather than by file count. Stale masks left over from an
# earlier run with a different grid could otherwise make the counts match even
# though some of the current tiles have no mask.
existing = {f.stem.replace('_cp_masks', '') for f in mask_files}
expected = {f"tile_r{t.row}_c{t.col}" for t in tile_infos}
missing = expected - existing
unexpected = existing - expected

if not mask_files:
    print("\n  [!] No mask files found. Run segmentation first.")
elif missing:
    print(f"\n  [!] Missing {len(missing)} mask files.")
    print(f"  Missing: {sorted(missing)[:5]}{'...' if len(missing) > 5 else ''}")
else:
    print("  All tiles have masks. Ready to merge!")

if unexpected:
    # Masks that match no tile in the current grid (e.g. from a previous tiling run).
    print(
        f"  [!] {len(unexpected)} mask file(s) do not match any current tile: "
        f"{sorted(unexpected)[:5]}{'...' if len(unexpected) > 5 else ''}"
    )

## 5. Load Segmented Tile Masks

In [None]:
# Reload the grid layout written by the split step, so this section can also
# be run on its own in a fresh kernel.
tile_infos, metadata = load_tile_info(TILES_DIR / "tile_info.json")

print(f"Loaded metadata for {len(tile_infos)} tiles")
# Missing entries are reported as 'unknown' instead of raising.
for field_label, field_key in (
    ("Original image shape", "image_shape"),
    ("Tile size", "tile_size"),
    ("Overlap", "overlap"),
):
    print(f"  {field_label}: {metadata.get(field_key, 'unknown')}")

In [None]:
# Load all tile masks (a dict keyed by (row, col)).
tile_masks = load_tile_masks(
    TILES_DIR,
    tile_infos,
    mask_pattern="tile_r{row}_c{col}_cp_masks.tif",
)

print(f"Loaded {len(tile_masks)} tile masks")
if len(tile_masks) < len(tile_infos):
    # Tiles without a mask presumably contribute no cells to the merge.
    print(
        f"  [!] {len(tile_infos) - len(tile_masks)} tile(s) have no mask loaded; "
        "the merged mask may be incomplete."
    )

# Show cell counts per tile
total_cells = 0
for (row, col), mask in sorted(tile_masks.items()):
    # Count the non-zero labels. `len(np.unique(mask)) - 1` undercounts by one
    # when a tile happens to contain no background (label 0) pixels.
    n_cells = int(np.count_nonzero(np.unique(mask)))
    total_cells += n_cells
    print(f"  Tile ({row},{col}): {n_cells} cells")

print(f"\nTotal cells across all tiles: {total_cells}")
print("(Note: cells in overlap regions are counted multiple times)")

## 6. Merge Masks

Stitch tile masks back together, deduplicating cells in overlap regions.

In [None]:
# Get the original image shape from the metadata; the merged mask is allocated at this size.
if 'image_shape' not in metadata:
    raise KeyError("tile_info.json has no 'image_shape' entry; cannot size the merged mask")
original_shape = tuple(metadata['image_shape'])

print(f"Merging {len(tile_masks)} tiles into {original_shape} image...")
print(f"Using IoU threshold: {IOU_THRESHOLD}")

# Merge tiles, deduplicating cells that match across overlap regions.
merged_mask = merge_tile_masks(
    tile_masks,
    tile_infos,
    original_shape,
    iou_threshold=IOU_THRESHOLD,
)

# Count non-zero labels, so the result is also correct if no background pixel remains.
n_cells = int(np.count_nonzero(np.unique(merged_mask)))
print("\nMerge complete!")
print(f"  Output shape: {merged_mask.shape}")
print(f"  Total cells: {n_cells}")
print(f"  Cells removed as duplicates: {total_cells - n_cells}")

In [None]:
# Visualize merged mask
fig, axes = plt.subplots(1, 2, figsize=(14, 7))
img_ax, mask_ax = axes

# Multi-channel images are previewed through their first channel.
img_display = image[0] if image.ndim != 2 else image

# One stride for both panels (long side about 1000 px) keeps them aligned.
scale = max(1, max(img_display.shape) // 1000)

img_ax.imshow(img_display[::scale, ::scale], cmap='gray')
img_ax.set_title('Original Image (downsampled)')
img_ax.axis('off')

mask_display = merged_mask[::scale, ::scale]
mask_ax.imshow(mask_display, cmap='nipy_spectral', interpolation='nearest')
mask_ax.set_title(f'Merged Mask ({n_cells} cells)')
mask_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Zoom into an overlap region to verify merge quality
if len(tile_infos) > 1:
    # Prefer a tile that overlaps on both its left and top edges; otherwise
    # fall back to any tile with at least one overlap.
    test_tile = next(
        (t for t in tile_infos if t.overlap_left > 0 and t.overlap_top > 0), None
    )
    if test_tile is None:
        test_tile = next(
            (t for t in tile_infos if t.overlap_left > 0 or t.overlap_top > 0), None
        )

    if test_tile is None:
        print("No tile has an overlap region (is OVERLAP 0?) - nothing to visualize.")
    else:
        margin = 100
        has_left = test_tile.overlap_left > 0
        has_top = test_tile.overlap_top > 0

        # Along an axis with overlap, the crop is margin + overlap + margin.
        # Along an axis without overlap (single-row or single-column grids), use
        # a window of that same size instead of a `margin`-pixel sliver at the edge.
        span = 2 * margin + max(test_tile.overlap_left, test_tile.overlap_top)
        if has_top:
            y1 = max(0, test_tile.y_start - margin)
            y2 = min(original_shape[0], test_tile.y_start + test_tile.overlap_top + margin)
        else:
            y1 = test_tile.y_start
            y2 = min(original_shape[0], y1 + span)
        if has_left:
            x1 = max(0, test_tile.x_start - margin)
            x2 = min(original_shape[1], test_tile.x_start + test_tile.overlap_left + margin)
        else:
            x1 = test_tile.x_start
            x2 = min(original_shape[1], x1 + span)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Image region
        ax = axes[0]
        if image.ndim == 2:
            region_img = image[y1:y2, x1:x2]
        else:
            region_img = image[0, y1:y2, x1:x2]
        ax.imshow(region_img, cmap='gray')
        # Dashed lines mark the tile's leading edges. Only edges that actually
        # overlap a neighbour are drawn, and only the first line gets a legend entry.
        boundary_label = 'Tile boundary'
        if has_left:
            ax.axvline(test_tile.x_start - x1, color='red', linestyle='--', label=boundary_label)
            boundary_label = '_nolegend_'
        if has_top:
            ax.axhline(test_tile.y_start - y1, color='red', linestyle='--', label=boundary_label)
        ax.set_title(f'Overlap Region - Tile ({test_tile.row},{test_tile.col})')
        ax.legend()

        # Mask region
        ax = axes[1]
        region_mask = merged_mask[y1:y2, x1:x2]
        ax.imshow(region_mask, cmap='nipy_spectral', interpolation='nearest')
        if has_left:
            ax.axvline(test_tile.x_start - x1, color='white', linestyle='--')
        if has_top:
            ax.axhline(test_tile.y_start - y1, color='white', linestyle='--')
        ax.set_title('Merged Mask (check for duplicates)')

        plt.tight_layout()
        plt.show()
else:
    print("Only one tile - no overlap to visualize.")

## 7. Save Merged Mask

In [None]:
# Write the merged label image to disk (compressed).
output_path = save_merged_mask(merged_mask, MERGED_OUTPUT, compress=True)

summary_lines = [
    f"Saved merged mask to: {output_path}",
    f"  Shape: {merged_mask.shape}",
    f"  Cells: {n_cells}",
]
print("\n".join(summary_lines))

# Report the on-disk size as a quick sanity check that the write succeeded.
size_mb = output_path.stat().st_size / 2 ** 20
print(f"  File size: {size_mb:.1f} MB")

## Summary Statistics

In [None]:
# Optional: uncomment the line below to reload the merged mask from disk
# (e.g. after a kernel restart), so the statistics cells can run without re-merging.
# merged_mask = tifffile.imread(MERGED_OUTPUT)

In [None]:
# Compute cell statistics.
# np.unique with return_counts yields every label and its pixel count in one
# pass. This avoids the full-size temporary array that
# ndimage.sum(np.ones_like(merged_mask), ...) needs, which is large for whole-slide masks.
all_labels, label_counts = np.unique(merged_mask, return_counts=True)
is_cell = all_labels != 0  # label 0 is background
cell_labels = all_labels[is_cell]
cell_sizes = label_counts[is_cell]

print("Cell Size Statistics:")
print(f"  Total cells: {len(cell_sizes)}")
if len(cell_sizes) == 0:
    # min/max/mean are undefined for an empty set of cells.
    print("  (no cells in merged mask - remaining statistics skipped)")
else:
    print(f"  Mean size: {np.mean(cell_sizes):.1f} pixels")
    print(f"  Median size: {np.median(cell_sizes):.1f} pixels")
    print(f"  Min size: {np.min(cell_sizes):.0f} pixels")
    print(f"  Max size: {np.max(cell_sizes):.0f} pixels")
    print(f"  Std dev: {np.std(cell_sizes):.1f} pixels")

In [None]:
# Cell size histogram
if len(cell_sizes) == 0:
    print("No cells to plot.")
else:
    fig, ax = plt.subplots(figsize=(10, 4))

    # bins=50 produces the same 51 evenly spaced edges over [min, max] as
    # np.linspace(min, max, 51). When every cell has the same size, NumPy widens
    # the range by 0.5 on each side instead of producing zero-width bins.
    counts, bin_edges = np.histogram(cell_sizes, bins=50)
    ax.bar(bin_edges[:-1], counts, width=np.diff(bin_edges), align='edge', edgecolor='black', alpha=0.7)

    median_size = np.median(cell_sizes)
    mean_size = np.mean(cell_sizes)
    ax.axvline(median_size, color='red', linestyle='--', label=f'Median: {median_size:.0f}')
    ax.axvline(mean_size, color='green', linestyle='--', label=f'Mean: {mean_size:.0f}')

    ax.set_xlabel('Cell Size (pixels)')
    ax.set_ylabel('Count')
    ax.set_title('Cell Size Distribution')
    ax.legend()

    plt.tight_layout()
    plt.show()

## Cleanup (Optional)

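Remove the tiles directory after a successful merge. This deletes everything in it, including the tiles, the tile masks, and `tile_info.json`. Only do this once the merged mask has been saved.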

In [None]:
# Deleting the tiles directory is irreversible: it also removes the tile masks
# and tile_info.json. It is therefore opt-in, so "Run All" cannot wipe out the
# segmentation results. Set DELETE_TILES = True to clean up.
import shutil

DELETE_TILES = False

if not DELETE_TILES:
    print(f"Keeping tiles directory: {TILES_DIR} (set DELETE_TILES = True to remove it)")
elif not MERGED_OUTPUT.exists():
    # Never delete the tile masks before the merged result exists on disk.
    print(f"Not removing {TILES_DIR}: merged mask {MERGED_OUTPUT} does not exist yet.")
else:
    shutil.rmtree(TILES_DIR)
    print(f"Removed tiles directory: {TILES_DIR}")