# Stomach WSI Segmentation Prediction

Whole Slide Image (WSI) segmentation using an **Overlapping Patch + Weighted Blending** approach.

- Model trained at ~0.92 mpp (`model_mpp = 0.9188` in the configuration cell)
- Output mask at 4.0 mpp (`output_mpp` in the configuration cell)

In [None]:
import openslide
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
from glob import glob
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Configuration
# Sorted so that indexing into wsi_list is reproducible: glob() returns
# matches in arbitrary, filesystem-dependent order.
wsi_list = sorted(glob('../../data/IHC_HE_Pair_Data_GA_SS/PD-L1(HnE)/*.ndpi'))
model_mpp = 0.9188       # MPP (microns per pixel) the model was trained at
output_mpp = 4.0         # MPP of the output mask
image_size = 512         # Patch size (pixels) for model input
overlap_ratio = 0.7      # 70% overlap between adjacent patches
batch_size = 8           # Batch size for inference

# Class configuration: list position is the class index predicted by the model
class_list = [
    "Background",
    "Stroma",
    "Non_Tumor",
    "Tumor",
]
num_classes = len(class_list)

# Color map for visualization (class index -> RGB)
color_map = {
    0: (255, 255, 255),  # Background - White
    1: (0, 255, 0),      # Stroma - Green
    2: (0, 0, 255),      # Non_Tumor - Blue
    3: (255, 255, 0),    # Tumor - Yellow
}

# Device: prefer GPU index 1, but fall back to GPU 0 when the machine has
# fewer GPUs (an unconditional "cuda:1" fails on a single-GPU host as soon
# as a tensor is moved to it), and to CPU when CUDA is unavailable.
preferred_gpu = 1
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print(f"Number of WSIs: {len(wsi_list)}")
print(f"Model MPP: {model_mpp}, Output MPP: {output_mpp}")
print(f"Image size: {image_size}, Overlap ratio: {overlap_ratio}, Batch size: {batch_size}")

In [None]:
# Restore the trained segmentation network from its checkpoint
model_path = '../../model/Tumor_region_segmentation/stomach/ST_callback.pt'

# Architecture settings; no pretrained encoder weights are requested here
# because the checkpoint loaded below supplies the parameters.
architecture = dict(
    encoder_name="efficientnet-b5",
    encoder_weights=None,
    in_channels=3,
    classes=num_classes,
)
model = smp.DeepLabV3Plus(**architecture).to(device)

# map_location places the checkpoint tensors on the selected device
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # inference mode
print(f"Model loaded from: {model_path}")

In [None]:
def create_gaussian_weight_mask(size, sigma=0.25):
    """
    Create a 2D Gaussian weight mask for smooth blending.

    The Gaussian is evaluated on a ``size`` x ``size`` grid spanning [-1, 1]
    on both axes and then min-max normalised, so the largest weight is 1 and
    the smallest (at the four corners) is exactly 0.

    Args:
        size: Side length of the square mask in pixels (positive integer).
        sigma: Standard deviation of the Gaussian, in units of the [-1, 1]
            grid (must be > 0).

    Returns:
        float32 array of shape (size, size) with values in [0, 1]. When the
        Gaussian is constant over the grid (e.g. ``size == 1``) a mask of
        ones is returned instead of NaNs.

    Raises:
        ValueError: If ``size`` < 1 or ``sigma`` <= 0.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    axis = np.linspace(-1, 1, size)
    xx, yy = np.meshgrid(axis, axis)

    # Gaussian distribution
    gaussian = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))

    # Normalize to [0, 1]. A zero value range (single pixel, or a Gaussian
    # that is numerically flat) would divide 0 by 0 and yield NaN weights,
    # so fall back to uniform weighting in that case.
    lowest = gaussian.min()
    value_range = gaussian.max() - lowest
    if value_range == 0:
        return np.ones((size, size), dtype=np.float32)
    gaussian = (gaussian - lowest) / value_range

    return gaussian.astype(np.float32)

def create_linear_weight_mask(size, margin_ratio=0.25):
    """
    Build a square blending mask whose weights ramp up linearly from each edge.

    The outermost ``int(size * margin_ratio)`` rows and columns on every side
    are scaled by 1/margin, 2/margin, ..., 1 going inward. Rows and columns
    are scaled independently, so a corner pixel receives the product of its
    row factor and its column factor. Returns a float32 (size, size) array.
    """
    ramp_width = int(size * margin_ratio)
    mask = np.full((size, size), 1.0, dtype=np.float32)

    # One factor per ring, outermost first
    ramp = [(depth + 1) / ramp_width for depth in range(ramp_width)]

    for depth, factor in enumerate(ramp):
        mirrored = -1 - depth  # same distance measured from the far edge
        mask[depth, :] *= factor
        mask[mirrored, :] *= factor
        mask[:, depth] *= factor
        mask[:, mirrored] *= factor

    return mask

# Build a sample blending mask at the configured patch size
weight_mask = create_gaussian_weight_mask(image_size, sigma=0.3)

# Show it as a heat map to inspect the centre-to-edge falloff
fig, ax = plt.subplots(figsize=(6, 5))
heatmap = ax.imshow(weight_mask, cmap='hot')
fig.colorbar(heatmap, ax=ax)
ax.set_title('Gaussian Weight Mask for Blending')
plt.show()

In [None]:
def get_wsi_mpp(slide):
    """
    Get MPP (microns per pixel) at level 0 from WSI metadata.

    Sources are tried in order:
      1. OpenSlide's ``openslide.mpp-x`` / ``openslide.mpp-y`` properties
         (their mean is returned when both are positive).
      2. The ``tiff.XResolution`` property, converted using
         ``tiff.ResolutionUnit``.
      3. A hard-coded default of 0.25 (printed as a warning).

    Args:
        slide: Object exposing a ``properties`` mapping (an OpenSlide slide).

    Returns:
        float: Microns per pixel.
    """
    props = slide.properties

    # 1) Standard OpenSlide MPP properties. Values arrive as strings; a
    #    missing or malformed value is treated as "not available".
    try:
        mpp_x = float(props.get(openslide.PROPERTY_NAME_MPP_X, 0))
        mpp_y = float(props.get(openslide.PROPERTY_NAME_MPP_Y, 0))
    except (TypeError, ValueError):
        mpp_x = mpp_y = 0.0
    if mpp_x > 0 and mpp_y > 0:
        return (mpp_x + mpp_y) / 2

    # 2) TIFF resolution tag (pixels per resolution unit).
    try:
        x_res = float(props.get('tiff.XResolution', 0))
    except (TypeError, ValueError):
        x_res = 0.0
    if x_res > 0:
        # NOTE(review): OpenSlide is expected to report the unit as 'inch' or
        # 'centimeter' -- confirm against the OpenSlide property docs. Any
        # value other than 'inch' (including a missing unit) keeps the
        # previous pixels-per-centimetre assumption.
        unit = str(props.get('tiff.ResolutionUnit', 'centimeter')).lower()
        microns_per_unit = 25400 if unit == 'inch' else 10000
        return microns_per_unit / x_res

    # 3) Default assumption for 40x objective
    print("Warning: Could not determine MPP, using default 0.25")
    return 0.25

def find_best_level(slide, target_mpp, base_mpp):
    """
    Choose the pyramid level to read from for a given target MPP.

    The wanted downsample is ``target_mpp / base_mpp``. Among levels whose
    downsample does not exceed it, the one closest to it is returned (the
    lowest-numbered level wins ties). Level 0 is returned when no level
    qualifies.
    """
    n_levels = slide.level_count
    downsamples = slide.level_downsamples
    wanted = target_mpp / base_mpp

    # (distance to wanted downsample, level) for every level that would not
    # force upsampling of the data that was read
    eligible = [
        (abs(downsamples[lvl] - wanted), lvl)
        for lvl in range(n_levels)
        if downsamples[lvl] <= wanted
    ]
    if not eligible:
        return 0

    _, chosen = min(eligible, key=lambda pair: pair[0])
    return chosen

In [None]:
def predict_wsi_overlapping(slide, model, device, model_mpp, output_mpp, 
                            patch_size=512, overlap_ratio=0.5, batch_size=8):
    """
    Predict on WSI using overlapping patches with weighted blending.

    Patches are laid out on a regular grid in level-0 coordinates, read from
    the pyramid level chosen by ``find_best_level``, resized to
    ``patch_size`` and run through ``model`` in batches. The softmax outputs
    are accumulated on a canvas at model resolution, weighted by a Gaussian
    mask, and divided by the accumulated weight. The per-pixel argmax is then
    resized to the output resolution.

    Patches in which at least 90% of the channel values exceed 220 are
    treated as background and skipped. Pixels never covered by a processed
    patch keep all-zero scores and therefore resolve to class index 0.

    The number of classes is taken from the module-level ``num_classes``.

    Args:
        slide: OpenSlide object
        model: Trained segmentation model
        device: torch device
        model_mpp: MPP at which model was trained
        output_mpp: Desired output mask MPP
        patch_size: Size of patches for model input
        overlap_ratio: Overlap ratio between adjacent patches, in [0, 1)
        batch_size: Batch size for inference

    Returns:
        prediction_mask: uint8 class-index mask of shape (output_h, output_w)
        thumbnail: RGB uint8 array with the same height/width as
            prediction_mask, for visualization
        prediction_sum: float32 array of shape
            (num_classes, model_res_h, model_res_w) holding the blended
            per-class scores at model resolution

    Raises:
        ValueError: If ``overlap_ratio`` is outside [0, 1).
    """
    if not 0 <= overlap_ratio < 1:
        raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")

    # Get WSI properties
    base_mpp = get_wsi_mpp(slide)
    wsi_w, wsi_h = slide.dimensions

    print(f"WSI dimensions: {wsi_w} x {wsi_h}")
    print(f"WSI base MPP: {base_mpp:.4f}")

    # Calculate scale factors
    read_scale = model_mpp / base_mpp  # Scale from base to model resolution
    output_scale = output_mpp / model_mpp  # Scale from model to output resolution

    # Patch size at level 0 (base resolution)
    patch_size_level0 = int(patch_size * read_scale)

    # Step size with overlap; clamped to >= 1 so the grid computation below
    # can never divide by zero.
    step_size = max(1, int(patch_size * (1 - overlap_ratio)))
    step_size_level0 = max(1, int(step_size * read_scale))

    # Output dimensions at model_mpp resolution
    model_res_w = int(wsi_w / read_scale)
    model_res_h = int(wsi_h / read_scale)

    # Output dimensions at output_mpp resolution
    output_w = int(model_res_w / output_scale)
    output_h = int(model_res_h / output_scale)

    print(f"Model resolution size: {model_res_w} x {model_res_h}")
    print(f"Output mask size: {output_w} x {output_h}")
    print(f"Patch size at level 0: {patch_size_level0}")
    print(f"Step size at level 0: {step_size_level0}")

    # Find best level for reading
    read_level = find_best_level(slide, model_mpp, base_mpp)
    level_downsample = slide.level_downsamples[read_level]
    # read_region takes its size in pixels of the level being read
    read_size = int(patch_size_level0 / level_downsample)

    print(f"Reading from level {read_level} (downsample: {level_downsample:.2f})")

    # Initialize accumulators at model resolution
    prediction_sum = np.zeros((num_classes, model_res_h, model_res_w), dtype=np.float32)
    weight_sum = np.zeros((model_res_h, model_res_w), dtype=np.float32)

    # Create weight mask
    weight_mask = create_gaussian_weight_mask(patch_size, sigma=0.3)

    # Calculate number of patches
    n_patches_x = max(1, int(np.ceil((wsi_w - patch_size_level0) / step_size_level0)) + 1)
    n_patches_y = max(1, int(np.ceil((wsi_h - patch_size_level0) / step_size_level0)) + 1)
    total_patches = n_patches_x * n_patches_y

    print(f"Total patches: {n_patches_x} x {n_patches_y} = {total_patches}")

    # Generate patch coordinates (level-0 top-left corners). The last
    # row/column is shifted back so the patch ends at the slide border, and
    # coordinates are clamped to 0 for slides smaller than one patch.
    patch_coords = []
    for y_idx in range(n_patches_y):
        y = max(0, min(y_idx * step_size_level0, wsi_h - patch_size_level0))
        for x_idx in range(n_patches_x):
            x = max(0, min(x_idx * step_size_level0, wsi_w - patch_size_level0))
            patch_coords.append((x, y))

    # Process patches in batches
    tf = ToTensor()
    failed_reads = 0
    first_read_error = None

    for batch_start in tqdm(range(0, len(patch_coords), batch_size), desc="Processing patches"):
        batch_coords = patch_coords[batch_start:batch_start + batch_size]
        batch_images = []
        valid_coords = []

        for (x, y) in batch_coords:
            # Read patch at the best level. A failed read skips the patch
            # (best effort) but is counted and reported after the loop
            # instead of being dropped silently.
            try:
                patch = slide.read_region(
                    (x, y),
                    read_level,
                    (read_size, read_size)
                ).convert('RGB')
            except Exception as exc:
                failed_reads += 1
                if first_read_error is None:
                    first_read_error = exc
                continue

            # Resize to model input size
            patch = patch.resize((patch_size, patch_size), Image.BILINEAR)

            # Skip mostly white (background) patches before paying for the
            # tensor conversion
            white_ratio = np.mean(np.array(patch) > 220)
            if white_ratio >= 0.9:
                continue

            batch_images.append(tf(patch))
            valid_coords.append((x, y))

        if len(batch_images) == 0:
            continue

        # Stack and predict
        batch_tensor = torch.stack(batch_images).to(device).float()

        with torch.no_grad():
            predictions = model(batch_tensor)
            predictions = F.softmax(predictions, dim=1)
            predictions = predictions.cpu().numpy()

        # Accumulate predictions with weights
        for i, (x, y) in enumerate(valid_coords):
            # Calculate position in model resolution
            x_model = int(x / read_scale)
            y_model = int(y / read_scale)

            # Get the region to update, cropped to the canvas
            x_end = min(x_model + patch_size, model_res_w)
            y_end = min(y_model + patch_size, model_res_h)

            patch_w = x_end - x_model
            patch_h = y_end - y_model

            if patch_w <= 0 or patch_h <= 0:
                continue

            # Add weighted prediction for all classes at once; the 2D weight
            # window broadcasts over the class axis.
            weights = weight_mask[:patch_h, :patch_w]
            prediction_sum[:, y_model:y_end, x_model:x_end] += \
                predictions[i, :num_classes, :patch_h, :patch_w] * weights
            weight_sum[y_model:y_end, x_model:x_end] += weights

    if failed_reads:
        print(f"Warning: {failed_reads} of {total_patches} patches could not be read "
              f"and were skipped (first error: {first_read_error!r})")

    # Normalize by weights (clamped in place to avoid division by zero)
    np.maximum(weight_sum, 1e-6, out=weight_sum)
    prediction_sum /= weight_sum

    # Get final prediction (argmax)
    prediction_model_res = np.argmax(prediction_sum, axis=0).astype(np.uint8)

    # Downsample to output resolution; NEAREST keeps values valid class indices
    prediction_mask = Image.fromarray(prediction_model_res).resize(
        (output_w, output_h),
        Image.NEAREST
    )
    prediction_mask = np.array(prediction_mask)

    # Get thumbnail for visualization. get_thumbnail treats the requested
    # size as a maximum and preserves the aspect ratio, so the result can
    # differ from the mask size; force an exact match so mask and thumbnail
    # can be indexed together.
    thumbnail = slide.get_thumbnail((output_w, output_h)).convert('RGB')
    if thumbnail.size != (output_w, output_h):
        thumbnail = thumbnail.resize((output_w, output_h), Image.BILINEAR)
    thumbnail = np.array(thumbnail)

    return prediction_mask, thumbnail, prediction_sum

In [None]:
def colorize_mask(mask, color_map):
    """
    Render a 2D class-index mask as an RGB uint8 image.

    Each pixel whose value is a key of ``color_map`` is painted with the
    mapped colour; pixels with any other value stay black.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    for label in color_map:
        canvas[mask == label] = color_map[label]

    return canvas

def create_overlay(image, mask, color_map, alpha=0.4):
    """
    Alpha-blend class colours onto a copy of ``image``.

    For every class in ``color_map`` except index 0 (background), pixels
    where ``mask`` equals that class become
    ``(1 - alpha) * pixel + alpha * colour``, truncated to uint8. The input
    image is not modified.
    """
    blended = image.copy()

    for label, rgb in color_map.items():
        # Background is left untouched
        if label == 0:
            continue

        region = mask == label
        if not region.any():
            continue

        tint = np.array(rgb, dtype=np.uint8)
        mixed = (1 - alpha) * blended[region] + alpha * tint
        blended[region] = mixed.astype(np.uint8)

    return blended

## Process a Single WSI

In [None]:
# Process a single WSI from the list.
# NOTE: index 1 selects the SECOND entry of wsi_list (indices start at 0);
# change wsi_index to process a different slide.
wsi_index = 1
if not 0 <= wsi_index < len(wsi_list):
    raise IndexError(
        f"wsi_index={wsi_index} is out of range: found {len(wsi_list)} WSI file(s)"
    )
wsi_path = wsi_list[wsi_index]
print(f"Processing: {os.path.basename(wsi_path)}")
print("="*60)

# Open slide (closed again in the final cell)
slide = openslide.OpenSlide(wsi_path)

# Print slide info
print(f"Dimensions: {slide.dimensions}")
print(f"Level count: {slide.level_count}")
print(f"Level dimensions: {slide.level_dimensions}")
print(f"Level downsamples: {slide.level_downsamples}")

In [None]:
# Segment the opened slide with the settings from the configuration cell
inference_settings = {
    "model_mpp": model_mpp,
    "output_mpp": output_mpp,
    "patch_size": image_size,
    "overlap_ratio": overlap_ratio,
    "batch_size": batch_size,
}
prediction_mask, thumbnail, prob_maps = predict_wsi_overlapping(
    slide, model, device, **inference_settings
)

# Report the shapes of the two arrays used for visualisation
print(f"\nPrediction mask shape: {prediction_mask.shape}")
print(f"Thumbnail shape: {thumbnail.shape}")

In [None]:
# Visualize results: thumbnail, colorized mask, overlay and class histogram
fig, axes = plt.subplots(2, 2, figsize=(20, 20))
(ax_thumb, ax_mask), (ax_overlay, ax_hist) = axes

# Derived images (also reused by the save cell further down)
colored_mask = colorize_mask(prediction_mask, color_map)
overlay = create_overlay(thumbnail, prediction_mask, color_map, alpha=0.5)

# The three image panels share the same drawing steps
image_panels = (
    (ax_thumb, thumbnail,
     f'Original WSI Thumbnail\n({thumbnail.shape[1]}x{thumbnail.shape[0]} @ {output_mpp} mpp)'),
    (ax_mask, colored_mask, 'Prediction Mask (Colorized)'),
    (ax_overlay, overlay, 'Overlay (Original + Prediction)'),
)
for panel_ax, panel_image, panel_title in image_panels:
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title, fontsize=14)
    panel_ax.axis('off')

# Class distribution: pixel count per class, 0 for classes absent from the mask
unique, counts = np.unique(prediction_mask, return_counts=True)
observed = dict(zip(unique.tolist(), counts))
class_counts = {name: observed.get(idx, 0) for idx, name in enumerate(class_list)}
total_pixels = prediction_mask.size

bar_heights = [class_counts[name] for name in class_list]
colors = [np.array(color_map[idx])/255.0 for idx in range(num_classes)]
bars = ax_hist.bar(class_list, bar_heights, color=colors, edgecolor='black')
ax_hist.set_ylabel('Pixel Count')
ax_hist.set_title('Class Distribution', fontsize=14)
ax_hist.tick_params(axis='x', rotation=45)

# Add percentage labels above each bar
for bar, count in zip(bars, bar_heights):
    percentage = 100 * count / total_pixels
    ax_hist.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                 f'{percentage:.1f}%', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Draw one labelled colour swatch per class as a legend
fig, ax = plt.subplots(figsize=(10, 2))
for slot, class_idx in enumerate(color_map):
    swatch_rgb = np.array(color_map[class_idx])/255.0
    # Unit-width bar at position `slot`, with the class name centred on it
    ax.barh(0, 1, left=slot, color=swatch_rgb, edgecolor='black', linewidth=2)
    ax.text(slot+0.5, 0, class_list[class_idx], ha='center', va='center',
            fontsize=12, fontweight='bold')

ax.set_xlim(0, len(color_map))
ax.set_ylim(-0.5, 0.5)
ax.set_title('Class Color Legend', fontsize=14)
ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Visualize the blended per-class probability maps side by side
fig, axes = plt.subplots(1, num_classes, figsize=(20, 5))

# PIL expects (width, height); match the prediction mask's resolution
display_size = (prediction_mask.shape[1], prediction_mask.shape[0])

for class_ax, class_name, class_probs in zip(axes, class_list, prob_maps):
    # Downsample this class's map to output resolution
    resized = np.array(Image.fromarray(class_probs).resize(display_size, Image.BILINEAR))

    heat = class_ax.imshow(resized, cmap='hot', vmin=0, vmax=1)
    class_ax.set_title(f'{class_name}\nProbability Map', fontsize=12)
    class_ax.axis('off')
    plt.colorbar(heat, ax=class_ax, fraction=0.046, pad=0.04)

plt.suptitle('Class Probability Maps', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

In [None]:
# Save results (optional)
output_dir = '../../results/wsi_predictions/stomach/'
os.makedirs(output_dir, exist_ok=True)

# Every output file shares the prefix <output_dir><slide name without extension>
wsi_name = os.path.splitext(os.path.basename(wsi_path))[0]
output_stem = os.path.join(output_dir, wsi_name)

# Raw class-index mask
np.save(output_stem + '_mask.npy', prediction_mask)

# Colorized mask and overlay as PNG images
for suffix, rgb_image in (('_colored_mask.png', colored_mask),
                          ('_overlay.png', overlay)):
    Image.fromarray(rgb_image).save(output_stem + suffix)

print(f"Results saved to: {output_dir}")

In [None]:
# Close the slide opened earlier in the notebook; `slide` must not be used
# for further reads after this point.
slide.close()
print("Done!")

## Summary

This notebook performs WSI segmentation using:

1. **Overlapping Patches**: Adjacent patches overlap (70% with the configured `overlap_ratio = 0.7`) to ensure smooth predictions across patch boundaries
2. **Weighted Blending**: Gaussian weight mask gives higher weight to patch centers, lower weight to edges
3. **Multi-scale Processing**: Reads at the optimal pyramid level, processes at model resolution (`model_mpp`, 0.9188 mpp), outputs at `output_mpp` (4.0 mpp)

### Output
- Prediction mask at `output_mpp` (4.0 mpp) resolution
- Colored visualization
- Overlay on original thumbnail
- Class probability maps