# Agent A: Colocalization Analysis

This final notebook calculates co-localization between different smFISH channels and summarizes the results per nucleus.

### Goals:
1. Load spots and nucleus assignments from Step 3.
2. Perform pair-wise co-localization analysis between all channels within each nucleus.
3. Summarize counts per nucleus: total spots per channel and co-localized spots per channel pair.
4. Export final CSV results.

In [None]:
import os
import pandas as pd
import numpy as np
from scipy.spatial import KDTree
from itertools import combinations
import matplotlib.pyplot as plt

# Default rendering settings
%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)

## 1. Parameters

Define the input/output directories and the maximum distance (in pixels) at which spots from two channels are considered co-localized.

In [None]:
# Directory layout: Step 3 spot tables are read from here ...
spot_dir = "./processed_data/03_Spots"
# ... and the Step 4 analysis results are written here (created if missing).
output_dir = "./processed_data/04_Analysis"
os.makedirs(name=output_dir, exist_ok=True)

# Co-localization radius in pixels; compared directly against distances
# computed from the spot coordinate columns.
coloc_dist_threshold = 2.5

## 2. Analysis Functions

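Helper that counts how many spots in the first set have at least one spot of the second set within the distance threshold, using a KD-tree nearest-neighbour query. The count is directional: swapping the two sets can give a different number, because several spots of the first set may match the same spot of the second.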

In [None]:
def find_colocalized(
    spots1: pd.DataFrame,
    spots2: pd.DataFrame,
    threshold: float,
    coord_cols: tuple = ('y', 'x'),
) -> int:
    """Count the spots in ``spots1`` that have a neighbour in ``spots2`` within ``threshold``.

    For every row of ``spots1`` a single nearest-neighbour query is run against a
    KD-tree built on ``spots2``. The row counts as co-localized when that nearest
    neighbour lies within ``threshold``. The count is directional and not
    one-to-one: several ``spots1`` rows may match the same ``spots2`` spot, so
    ``find_colocalized(a, b, t)`` need not equal ``find_colocalized(b, a, t)``.

    Args:
        spots1: Query spots; must contain the columns named in ``coord_cols``.
        spots2: Reference spots; must contain the columns named in ``coord_cols``.
        threshold: Maximum neighbour distance, in the same units as the
            coordinate columns (pixels in this notebook).
        coord_cols: Coordinate columns used for the (Euclidean) distance.
            Defaults to ``('y', 'x')``. Pass e.g. ``('z', 'y', 'x')`` for 3-D
            spots; all axes must then share the same units.

    Returns:
        Number of ``spots1`` rows with a ``spots2`` spot in range; 0 if either
        frame is empty.
    """
    if spots1.empty or spots2.empty:
        return 0

    cols = list(coord_cols)
    tree2 = KDTree(spots2[cols].to_numpy())
    # k=1 nearest neighbour per query spot. When distance_upper_bound is given,
    # SciPy reports an infinite distance for queries with no neighbour in range.
    # NOTE(review): SciPy uses the bound for pruning; a neighbour at exactly
    # `threshold` may not be counted — treat the bound as effectively exclusive.
    distances, _ = tree2.query(spots1[cols].to_numpy(), distance_upper_bound=threshold)
    return int(np.count_nonzero(np.isfinite(distances)))

## 3. Process All Files

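Iterate over every `*_spots.csv` file from Step 3. For each nucleus (background, `nucleus_id == 0`, is excluded), compute per-channel spot counts and pairwise co-localization counts, then save the combined table as a CSV.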

In [None]:
# Process files in sorted order: os.listdir() returns entries in an arbitrary,
# filesystem-dependent order, which would make the output row order unstable.
spot_files = sorted(f for f in os.listdir(spot_dir) if f.endswith("_spots.csv"))
if not spot_files:
    print(f"WARNING: no '*_spots.csv' files found in {spot_dir}")

all_summaries = []

for filename in spot_files:
    print(f"Analyzing {filename}...")
    df_spots = pd.read_csv(os.path.join(spot_dir, filename))

    # The channel set (and hence the output columns) is derived per file.
    # NOTE(review): a channel with no spots anywhere in a file yields NaN rather
    # than 0 for that file's rows in the combined table — confirm this is intended.
    channels = sorted(df_spots['channel'].unique())
    channel_pairs = list(combinations(channels, 2))

    # nucleus_id 0 is background and is excluded. NOTE(review): nuclei without
    # any assigned spot are absent from the spot table and therefore get no row.
    nucleus_spots = df_spots[df_spots['nucleus_id'] > 0]

    # One groupby pass (sort=False keeps first-appearance order, like .unique())
    # replaces a full-table boolean filter for every nucleus.
    for nuc_id, nuc_spots in nucleus_spots.groupby('nucleus_id', sort=False):
        # Split this nucleus by channel once and reuse the subsets below;
        # channels absent from the nucleus fall back to an empty frame.
        by_channel = {ch: grp for ch, grp in nuc_spots.groupby('channel')}
        no_spots = nuc_spots.iloc[0:0]

        summary = {
            'filename': filename,
            'nucleus_id': nuc_id
        }

        # Total spots per channel (co-localized spots included).
        for ch in channels:
            summary[f'ch{ch}_count'] = len(by_channel.get(ch, no_spots))

        # Directional co-localization: number of ch1 spots with a ch2 spot within
        # coloc_dist_threshold (ch1 precedes ch2 in sorted channel order).
        for ch1, ch2 in channel_pairs:
            summary[f'ch{ch1}_ch{ch2}_coloc'] = find_colocalized(
                by_channel.get(ch1, no_spots),
                by_channel.get(ch2, no_spots),
                coloc_dist_threshold,
            )

        all_summaries.append(summary)

# Build the final table once from the list of per-nucleus records.
df_final = pd.DataFrame(all_summaries)
summary_path = os.path.join(output_dir, "final_summary_per_nucleus.csv")
df_final.to_csv(summary_path, index=False)

print(f"\nAnalysis complete. Results saved to: {summary_path}")
display(df_final.head())
