# Agent D: Spot Detection (Spotiflow)

This notebook performs subpixel-accurate spot detection in multi-color smFISH channels using Spotiflow. 

### Workflow:
1. **Tuning**: Load a sample image and define detection parameters (Threshold, Sigma).
2. **Preview**: Visualize results to ensure real signals are captured and background is discarded.
3. **Batch**: Apply parameters to all FOVs and channels.
4. **Assignment**: Assign each spot to a Nucleus ID using the segmentation masks produced by the previous pipeline step (`02_Segmentation`).

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import tifffile
import matplotlib.pyplot as plt
from spotiflow.model import Spotiflow
from scipy.spatial import KDTree

# Default rendering settings
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 12)

## 1. Environment & Paths

Define where your MIPs and Segmentation masks are located.

In [None]:
# Spotiflow weights to load: a pre-trained model name (e.g., 'general')
# or a path to a custom model
model_name = "general"

# Pipeline folders: two inputs and the destination for this notebook's results
mip_dir = "./processed_data/01_MIPs"          # input images (MIPs)
seg_dir = "./processed_data/02_Segmentation"  # input segmentation masks
output_dir = "./processed_data/03_Spots"      # per-file spot tables are written here

# Make sure the destination exists before anything is written (no-op if present)
os.makedirs(output_dir, exist_ok=True)

## 2. Parameter Tuning (Preview Mode)

**ADJUST THESE VALUES** based on your visual inspection of the preview below.

- `prob_threshold`: Minimum model confidence required to keep a detection. Higher values are stricter: fewer spots are kept, typically the brighter, clearer ones.
- `sigma`: Approximate spot size (usually around 1.0 - 1.5).

In [None]:
# Pick a file and channel to preview.
# os.listdir() returns entries in arbitrary order and may include stray
# non-image files, so filter to .tif and sort to make the choice reproducible.
preview_candidates = sorted(f for f in os.listdir(mip_dir) if f.endswith(".tif"))
if not preview_candidates:
    raise FileNotFoundError(f"No .tif files found in {mip_dir}")
sample_file = preview_candidates[0]
preview_channel_idx = 0  # 647

# TUNING PARAMETERS
prob_threshold = 0.5
# NOTE(review): Spotiflow.predict() has no `sigma` keyword (sigma is a setting of
# the model itself), so it is not forwarded to predict() here. The variable is
# kept so that any cell still referencing `sigma` does not hit a NameError.
sigma = 1.0

# Load Model
model = Spotiflow.from_pretrained(model_name)

# Load and Predict
img = tifffile.imread(os.path.join(mip_dir, sample_file))
ch_img = img[preview_channel_idx]
spots, details = model.predict(ch_img, prob_thresh=prob_threshold)

print(f"Preview: Found {len(spots)} spots in {sample_file} (Ch {preview_channel_idx})")

# Visualization: spots are (y, x) rows, so column 1 goes on the horizontal axis.
# vmax at the 98th percentile keeps a few very bright pixels from washing out the image.
fig, ax = plt.subplots()
ax.imshow(ch_img, cmap='gray', vmax=np.percentile(ch_img, 98))
ax.scatter(spots[:, 1], spots[:, 0], s=20, facecolors='none', edgecolors='r', alpha=0.5)
ax.set_title(f"Preview: {sample_file} - {len(spots)} spots")
plt.show()


## 3. Batch Processing

Once satisfied with the parameters, run the detection on all files and channels.

In [None]:
# Column layout of every output CSV. Declared up front so that a file with
# zero detections still gets a header row and can be read back with pd.read_csv.
spot_columns = ['channel', 'y', 'x', 'intensity', 'nucleus_id']

# Sorted so the batch runs in a reproducible order (os.listdir order is arbitrary)
mip_files = sorted(f for f in os.listdir(mip_dir) if f.endswith(".tif"))

for filename in mip_files:
    print(f"Processing {filename}...")
    # splitext strips only the trailing extension; str.replace would also
    # rewrite a ".tif" that happens to occur earlier in the file name.
    stem = os.path.splitext(filename)[0]

    img = tifffile.imread(os.path.join(mip_dir, filename))
    if img.ndim != 3:
        raise ValueError(
            f"{filename}: expected a (channel, y, x) stack, got shape {img.shape}"
        )

    # Load segmentation mask to assign Nucleus IDs
    masks = tifffile.imread(os.path.join(seg_dir, f"{stem}_masks.tif"))
    if masks.shape != img.shape[1:]:
        print(f"  WARNING: mask shape {masks.shape} differs from image shape "
              f"{img.shape[1:]}; spots outside the mask are clamped to its border")

    channel_tables = []

    # Iterate through channels (excluding DAPI if it's the last one)
    # We assume DAPI is the last channel and shouldn't be processed for spots
    for c in range(img.shape[0] - 1):
        # predict() returns (spots, details); `details` is a namespace object,
        # so its fields are read as attributes rather than by subscript.
        spots, details = model.predict(img[c], prob_thresh=prob_threshold)
        spots = np.asarray(spots).reshape(-1, 2)  # (n, 2) rows of (y, x), also when n == 0
        if len(spots) == 0:
            continue

        # Round the subpixel coordinates to the nearest pixel for the mask
        # lookup and clamp them so border spots cannot index out of bounds.
        ry = np.clip(np.rint(spots[:, 0]).astype(int), 0, masks.shape[0] - 1)
        rx = np.clip(np.rint(spots[:, 1]).astype(int), 0, masks.shape[1] - 1)

        channel_tables.append(pd.DataFrame({
            'channel': c,
            'y': spots[:, 0],
            'x': spots[:, 1],
            # Using probability as pseudo-intensity (one value per detected spot)
            'intensity': np.asarray(details.prob),
            'nucleus_id': masks[ry, rx],
        }, columns=spot_columns))

    # Save CSV
    if channel_tables:
        df = pd.concat(channel_tables, ignore_index=True)
    else:
        df = pd.DataFrame(columns=spot_columns)
    csv_path = os.path.join(output_dir, f"{stem}_spots.csv")
    df.to_csv(csv_path, index=False)
    print(f"  Total spots found: {len(df)}")
