# Detectron2 Mask Generation for Sperm Cell Analysis

This notebook generates binary segmentation masks from SEM images using Detectron2 instance segmentation.

**Requirements:**
- Google Colab (free GPU access)
- SEM image stack (TIFF format)
- Trained Detectron2 model weights (`.pth` file; contact the authors to obtain them)

**Outputs:**
- Binary mask stacks (one multi-page TIFF per class, saved as `<OUTPUT_DIR>/<class_name>/<class_name>_stack.tif`)
- Visualization overlays

**Time:** ~2-7 minutes per 200-slice stack (with GPU)

## Step 1: Install Dependencies

In [None]:
# Install the third-party packages needed by the notebook (`-q` keeps pip quiet).
!pip install pyyaml==5.1 torch torchvision tifffile opencv-python -q
# Clone the Detectron2 repository into ./detectron2 under the current working directory.
!git clone 'https://github.com/facebookresearch/detectron2' -q
# Install the cloned checkout in editable mode.
!cd detectron2 && pip install -e . -q
# NOTE(review): this message is printed unconditionally; it does not verify
# that the shell commands above actually succeeded.
print('✅ Dependencies installed')

## Step 2: Import and Setup

In [None]:
# Standard library.
import os
import sys

# Put the Detectron2 checkout cloned in Step 1 at the front of the import path
# before `detectron2` is imported below.
sys.path.insert(0, '/content/detectron2')

# Third-party packages (one import per line).
import torch
import detectron2
import numpy as np
import cv2
import tifffile

# Detectron2 pieces used by the rest of the notebook.
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

# Report the runtime so the user can confirm a GPU is attached.
print(f'✅ PyTorch: {torch.__version__}')
print(f'✅ CUDA: {torch.cuda.is_available()}')

## Step 3: Mount Google Drive

In [None]:
# Mount the user's Google Drive at /content/drive so the model weights,
# the input stack and the output folder (configured in later steps) are reachable.
from google.colab import drive
drive.mount('/content/drive')
print('✅ Google Drive mounted')

## Step 4: Configure Model (USER INPUT REQUIRED)

In [None]:
# Location of the trained weights file; replace the placeholder path with your own.
MODEL_WEIGHTS_PATH = '/content/drive/MyDrive/path/to/your/model_final.pth'
# Number of classes assigned to cfg.MODEL.ROI_HEADS.NUM_CLASSES in Step 5.
NUM_CLASSES = 6
# Class labels, looked up by predicted class id in Step 7 (id 0 -> 'mitochondria', ...).
# NOTE(review): the order presumably has to match the label order used at
# training time — confirm against the training configuration.
CLASS_NAMES = ['mitochondria', 'pseudopod', 'unfused_MO', 'MO', 'sperm_cell', 'nucleus']
# Assigned to cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST in Step 5.
SCORE_THRESHOLD = 0.05

## Step 5: Initialize Predictor

In [None]:
# Build the inference configuration and predictor.
#
# Reads MODEL_WEIGHTS_PATH, NUM_CLASSES, CLASS_NAMES and SCORE_THRESHOLD from
# Step 4 and defines `cfg`, `predictor` and `metadata` for the later cells.

# A class-count mismatch is only reported, not treated as fatal, because
# Step 7 already ignores predicted ids that have no entry in CLASS_NAMES.
if len(CLASS_NAMES) != NUM_CLASSES:
    print(f'⚠️ NUM_CLASSES is {NUM_CLASSES} but CLASS_NAMES has {len(CLASS_NAMES)} entries')

# Fail early with a clear message when a local weights path is wrong.
# Paths containing a URL scheme are left for Detectron2 to resolve.
if '://' not in MODEL_WEIGHTS_PATH and not os.path.isfile(MODEL_WEIGHTS_PATH):
    raise FileNotFoundError(f'Model weights not found: {MODEL_WEIGHTS_PATH}')

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
cfg.MODEL.WEIGHTS = MODEL_WEIGHTS_PATH
cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
# Run on the GPU when one is attached, otherwise fall back to the CPU instead
# of failing at model construction.
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
predictor = DefaultPredictor(cfg)

# Register the class names under a dedicated metadata entry. The built-in
# 'coco_2014_val' entry already carries the COCO class names, and Detectron2
# refuses to overwrite a metadata attribute with a different value.
_metadata_name = 'sperm_sem_masks'
# Drop a stale entry so the cell can be re-run after CLASS_NAMES is edited.
if _metadata_name in MetadataCatalog.list():
    MetadataCatalog.remove(_metadata_name)
metadata = MetadataCatalog.get(_metadata_name)
metadata.thing_classes = CLASS_NAMES
print('✅ Model loaded')

## Step 6: Configure Paths (USER INPUT REQUIRED)

In [None]:
# TIFF image stack read by tifffile.imread in Step 7; replace the placeholder path.
STACK_PATH = '/content/drive/MyDrive/path/to/your/image_stack.tif'
# Root output folder; Step 7 creates it along with one sub-folder per class name.
OUTPUT_DIR = '/content/drive/MyDrive/path/to/output/masks'

## Step 7: Process Stack

In [None]:
# Run the predictor on every slice of the stack and collect, per class, one
# binary mask (uint8, 0 or 255) per slice in `combined_masks_dict`.

# Create the output root and one sub-folder per class; Step 8 writes into them.
os.makedirs(OUTPUT_DIR, exist_ok=True)
for class_name in CLASS_NAMES:
    os.makedirs(os.path.join(OUTPUT_DIR, class_name), exist_ok=True)

image_stack = tifffile.imread(STACK_PATH)
# A single-page grayscale TIFF loads as a 2-D array. Without a leading slice
# axis the loop below would iterate over image rows instead of slices.
if image_stack.ndim == 2:
    image_stack = image_stack[np.newaxis, ...]
num_slices = image_stack.shape[0]
print(f'✅ Loaded {num_slices} slices')

# class name -> list of per-slice masks, in slice order.
combined_masks_dict = {class_name: [] for class_name in CLASS_NAMES}
for slice_idx in range(num_slices):
    slice_img = image_stack[slice_idx]
    if slice_img.ndim == 2:
        # The predictor is given a 3-channel image, so replicate the gray channel.
        # NOTE(review): pixel values are passed through unscaled; if the stack
        # is not 8-bit, confirm this matches how the model was trained.
        slice_img_color = cv2.cvtColor(slice_img, cv2.COLOR_GRAY2BGR)
    else:
        # NOTE(review): multi-channel slices are passed through unchanged;
        # confirm the channel order is what the predictor expects.
        slice_img_color = slice_img
    outputs = predictor(slice_img_color)
    pred_classes = outputs['instances'].pred_classes.cpu().numpy()
    pred_masks = outputs['instances'].pred_masks.cpu().numpy()
    H, W = slice_img_color.shape[:2]
    for class_idx, class_name in enumerate(CLASS_NAMES):
        combined_mask = np.zeros((H, W), dtype=np.uint8)
        # Select every instance mask predicted as this class in one step and
        # OR them together. Predicted ids with no CLASS_NAMES entry never match
        # any class_idx, so they are ignored.
        class_masks = pred_masks[pred_classes == class_idx]
        if len(class_masks) > 0:
            combined_mask[class_masks.astype(bool).any(axis=0)] = 255
        combined_masks_dict[class_name].append(combined_mask)
    if (slice_idx + 1) % 10 == 0:
        print(f'Processed {slice_idx + 1}/{num_slices}')
print('✅ Processing complete')

## Step 8: Save Masks

In [None]:
# Write one multi-page TIFF per class to
# <OUTPUT_DIR>/<class_name>/<class_name>_stack.tif.
for class_name in combined_masks_dict:
    mask_list = combined_masks_dict[class_name]
    # Classes with no collected slices are skipped; no file is written for them.
    if len(mask_list) > 0:
        stack_filename = os.path.join(OUTPUT_DIR, class_name, f'{class_name}_stack.tif')
        # Stack the per-slice masks along a new leading axis before writing.
        stack_array = np.asarray(mask_list)
        tifffile.imwrite(stack_filename, stack_array)
        print(f'✅ Saved {class_name}: {len(mask_list)} slices')
print(f'\n✅ All masks saved to {OUTPUT_DIR}')