# Validate Suite2p Binary Extraction

This notebook validates that frames were correctly extracted from raw ScanImage TIFFs to Suite2p `data_raw.bin` files.

**Test case:**
- Raw data: `\\rbo-s1\S1_DATA\lbm\kbarber\2025-11-04-mk311\raw\green`
- Suite2p output: `\\rbo-s1\S1_DATA\lbm\kbarber\2025-11-04-mk311\suite2p\plane{XX}_stitched\data_raw.bin`

**Validation steps:**
1. Load raw data with mbo_utilities
2. Load each plane's data_raw.bin
3. Compare frame-by-frame for each z-plane
4. Check for frame order issues, duplicates, or missing frames
5. Compute correlation metrics

In [None]:
import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import mbo_utilities as mbo
from scipy.stats import pearsonr

%matplotlib inline

## 1. Load Raw Data

In [None]:
# Path to raw ScanImage TIFFs
raw_data_path = Path(r"\\rbo-s1\S1_DATA\lbm\kbarber\2025-11-04-mk311\raw\green")

print(f"Loading raw data from: {raw_data_path}")
print(f"Path exists: {raw_data_path.exists()}")
print()

# Fail early with a clear message rather than an opaque error from the reader.
if not raw_data_path.exists():
    raise FileNotFoundError(f"Raw data path not found: {raw_data_path}")

# Load with FFT phase correction (matching what was used during extraction)
arr = mbo.imread(raw_data_path, fix_phase=True, use_fft=True)

# These metadata attributes are treated as optional (the original code already guarded
# `rois` / `num_channels`); guard `filenames` the same way for consistency.
filenames = getattr(arr, "filenames", None)

print("Raw data loaded:")
print(f"  Shape: {arr.shape}")
print(f"  dtype: {arr.dtype}")
print(f"  Number of files: {len(filenames) if filenames is not None else 'N/A'}")
print(f"  ROIs: {getattr(arr, 'rois', 'N/A')}")
print(f"  Channels per volume: {getattr(arr, 'num_channels', 'N/A')}")

## 2. Identify Suite2p Planes

In [None]:
# Suite2p output directory (one `plane<N>_stitched` subfolder per z-plane)
suite2p_dir = Path(r"\\rbo-s1\S1_DATA\lbm\kbarber\2025-11-04-mk311\suite2p")


def _plane_sort_key(plane_dir):
    """Natural sort key: order by the integer N in `plane<N>`, then by name.

    Directories without a number sort after numbered ones.
    """
    match = re.search(r"plane(\d+)", plane_dir.name)
    if match:
        return (0, int(match.group(1)), plane_dir.name)
    return (1, 0, plane_dir.name)


# Numeric ordering matters: downstream cells map plane order to raw-frame offsets, and a
# plain string sort would put `plane10` before `plane2` when names are not zero-padded.
plane_dirs = sorted(
    (d for d in suite2p_dir.glob("plane*_stitched") if d.is_dir()),
    key=_plane_sort_key,
)

if not plane_dirs:
    # glob() on a missing/unreachable directory is silently empty; surface it here.
    raise FileNotFoundError(f"No plane*_stitched directories found in {suite2p_dir}")

print(f"Found {len(plane_dirs)} plane directories:")
for plane_dir in plane_dirs:
    data_raw_bin = plane_dir / "data_raw.bin"
    ops_npy = plane_dir / "ops.npy"
    status = "✓" if data_raw_bin.exists() and ops_npy.exists() else "✗"
    print(f"  {status} {plane_dir.name}")

## 3. Frame Organization Analysis

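For multi-plane data, the raw frames are organized as `(T*Z, Y, X)`, with consecutive frames cycling through the Z z-planes.
Each plane's binary should therefore contain every Z-th raw frame, starting at that plane's index: plane `z` (0-based) holds raw frames `z, z+Z, z+2Z, ...`.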

In [None]:
# Determine z-plane organization.
# The index mapping used throughout this notebook assumes the raw array is (T*Z, Y, X)
# with consecutive frames cycling through the planes, and that the number of planes
# equals the number of plane directories found above.
if len(arr.shape) != 3:
    raise ValueError(
        f"Expected raw data shaped (T*Z, Y, X) but got {arr.shape}; "
        "the frame-index mapping in this notebook assumes interleaved 3D data."
    )

num_planes = len(plane_dirs)
if num_planes == 0:
    # Avoid a bare ZeroDivisionError below.
    raise ValueError(f"No plane directories found in {suite2p_dir}; cannot infer plane count.")

total_frames = arr.shape[0]
frames_per_plane, leftover_frames = divmod(total_frames, num_planes)

print(f"Data organization:")
print(f"  Total frames in raw data: {total_frames}")
print(f"  Number of z-planes: {num_planes}")
print(f"  Expected frames per plane: {frames_per_plane}")
print(f"  Expected frame indices for plane 0: 0, {num_planes}, {num_planes*2}, ...")
print(f"  Expected frame indices for plane 1: 1, {num_planes+1}, {num_planes*2+1}, ...")

if leftover_frames:
    # With total = q*Z + r, planes 0..r-1 receive q+1 frames and the rest receive q.
    print(
        f"  ⚠ {total_frames} frames is not a multiple of {num_planes} planes: "
        f"planes 0..{leftover_frames - 1} are expected to hold {frames_per_plane + 1} frames"
    )

## 4. Load Suite2p Binaries

In [None]:
# Load all planes. A plane that fails to load is reported and skipped; a plane's
# interleave offset is its position in `plane_dirs`, not its position in `suite2p_arrays`.
suite2p_arrays = {}
raw_frame_shape = tuple(arr.shape[-2:])

for plane_idx, plane_dir in enumerate(tqdm(plane_dirs, desc="Loading Suite2p binaries")):
    plane_name = plane_dir.name
    try:
        s2p_arr = mbo.imread(plane_dir / "data_raw.bin")
    except Exception as e:
        # Best-effort: keep going so the remaining planes can still be validated.
        print(f"\nFailed to load {plane_name}: {e}")
        continue

    suite2p_arrays[plane_name] = s2p_arr
    print(f"\n{plane_name}:")
    print(f"  Shape: {s2p_arr.shape}")
    print(f"  dtype: {s2p_arr.dtype}")

    # Missing/extra-frame check: this plane should hold raw frames
    # plane_idx, plane_idx + num_planes, ... (all < total_frames).
    expected_frames = len(range(plane_idx, total_frames, num_planes))
    if s2p_arr.shape[0] != expected_frames:
        print(f"  ⚠ Frame count mismatch: {s2p_arr.shape[0]} in binary vs {expected_frames} expected")
    bin_frame_shape = tuple(s2p_arr.shape[-2:])
    if bin_frame_shape != raw_frame_shape:
        print(f"  ⚠ Frame size mismatch: {bin_frame_shape} in binary vs {raw_frame_shape} in raw data")

print(f"\nSuccessfully loaded {len(suite2p_arrays)} planes")

## 5. Frame-by-Frame Comparison

Compare each plane's frames against the corresponding frames in the raw data, reporting exact equality, Pearson correlation, and the mean and maximum absolute difference.

In [None]:
def compare_frames(raw_frame, bin_frame):
    """Compare a raw frame against the corresponding Suite2p binary frame.

    Both inputs are converted to float64, which represents integer pixel values
    exactly (up to 2**53), so ``exact_match`` is not weakened by the cast.

    Parameters
    ----------
    raw_frame, bin_frame : array-like
        Frames to compare; must have identical shapes.

    Returns
    -------
    dict
        ``exact_match`` (bool); ``correlation`` (Pearson r as float: 1.0 for identical
        frames, NaN when undefined because a non-identical frame is constant);
        ``mae`` (mean absolute difference) and ``max_diff`` (max absolute difference).

    Raises
    ------
    ValueError
        If the two frames have different shapes.
    """
    raw = np.asarray(raw_frame, dtype=np.float64)
    binf = np.asarray(bin_frame, dtype=np.float64)
    if raw.shape != binf.shape:
        # Flattening mismatched frames would either fail inside pearsonr or silently
        # correlate pixels from different locations.
        raise ValueError(f"Frame shape mismatch: raw {raw.shape} vs binary {binf.shape}")

    exact_match = bool(np.array_equal(raw, binf))
    abs_diff = np.abs(raw - binf)

    if exact_match:
        # Identical frames correlate perfectly, even when constant (where r is undefined).
        corr = 1.0
    elif raw.std() == 0 or binf.std() == 0:
        # Pearson r is undefined for constant input; report NaN explicitly instead of
        # letting pearsonr warn. Callers should treat NaN as a failed comparison.
        corr = float("nan")
    else:
        corr, _ = pearsonr(raw.ravel(), binf.ravel())
        corr = float(corr)

    return {
        'exact_match': exact_match,
        'correlation': corr,
        'mae': float(abs_diff.mean()),
        'max_diff': float(abs_diff.max()),
    }


print("Starting frame-by-frame validation...")
print("This may take a while for large datasets.\n")

In [None]:
# Validate first few frames of each plane as a quick check
num_test_frames = 10
quick_results = []
# A plane's interleave offset is its position among ALL plane directories, not its
# position in `suite2p_arrays`: a plane that failed to load would shift the latter and
# make every later plane compare against the wrong raw frames.
plane_names = [d.name for d in plane_dirs]

for plane_name, s2p_arr in suite2p_arrays.items():
    plane_idx = plane_names.index(plane_name)
    print(f"\nQuick check: {plane_name} (first {num_test_frames} frames)")

    for frame_idx in range(min(num_test_frames, s2p_arr.shape[0])):
        # Get corresponding raw frame index (every num_planes-th frame, starting at plane_idx)
        raw_frame_idx = plane_idx + (frame_idx * num_planes)
        if raw_frame_idx >= total_frames:
            print(f"  Frame {frame_idx} (raw {raw_frame_idx}): ✗ beyond end of raw data ({total_frames} frames)")
            break

        # Load and compare frames
        metrics = compare_frames(arr[raw_frame_idx], s2p_arr[frame_idx])

        quick_results.append({
            'plane': plane_name,
            'plane_idx': plane_idx,
            'bin_frame_idx': frame_idx,
            'raw_frame_idx': raw_frame_idx,
            **metrics
        })

        status = "✓" if metrics['exact_match'] else f"✗ (corr={metrics['correlation']:.4f}, mae={metrics['mae']:.2f})"
        print(f"  Frame {frame_idx} (raw {raw_frame_idx}): {status}")

# Explicit columns keep the statistics below well-defined even if nothing was compared.
df_quick = pd.DataFrame(quick_results, columns=[
    'plane', 'plane_idx', 'bin_frame_idx', 'raw_frame_idx',
    'exact_match', 'correlation', 'mae', 'max_diff',
])
print("\n" + "="*60)
print("Quick Check Summary:")
print(f"  Total frames checked: {len(df_quick)}")
print(f"  Exact matches: {df_quick['exact_match'].sum()} / {len(df_quick)}")
print(f"  Mean correlation: {df_quick['correlation'].mean():.6f}")
print(f"  Mean MAE: {df_quick['mae'].mean():.4f}")
print(f"  Max difference: {df_quick['max_diff'].max():.4f}")

## 6. Full Validation (All Frames)

**Warning:** This will load and compare every frame. For large datasets, this may take significant time and memory.

In [None]:
# Option to run full validation (set to True to enable)
RUN_FULL_VALIDATION = False

if RUN_FULL_VALIDATION:
    full_results = []
    # A plane's interleave offset is its position among ALL plane directories, not its
    # position in `suite2p_arrays` (a plane that failed to load would shift the latter).
    plane_names = [d.name for d in plane_dirs]
    # plane_name -> (frames in binary, frames expected from the raw data)
    count_mismatches = {}

    for plane_name, s2p_arr in suite2p_arrays.items():
        plane_idx = plane_names.index(plane_name)
        n_bin = s2p_arr.shape[0]
        # Raw frames belonging to this plane: plane_idx, plane_idx + num_planes, ... < total_frames
        n_expected = len(range(plane_idx, total_frames, num_planes))
        if n_bin != n_expected:
            count_mismatches[plane_name] = (n_bin, n_expected)

        print(f"\nValidating {plane_name} ({n_bin} frames)...")

        # Only compare frames present in both sources; extra binary frames would
        # otherwise index past the end of the raw data and abort the whole run.
        for frame_idx in tqdm(range(min(n_bin, n_expected)), desc=f"{plane_name}"):
            raw_frame_idx = plane_idx + (frame_idx * num_planes)

            raw_frame = arr[raw_frame_idx]
            bin_frame = s2p_arr[frame_idx]

            metrics = compare_frames(raw_frame, bin_frame)

            full_results.append({
                'plane': plane_name,
                'plane_idx': plane_idx,
                'bin_frame_idx': frame_idx,
                'raw_frame_idx': raw_frame_idx,
                **metrics
            })

    df_full = pd.DataFrame(full_results, columns=[
        'plane', 'plane_idx', 'bin_frame_idx', 'raw_frame_idx',
        'exact_match', 'correlation', 'mae', 'max_diff',
    ])

    print("\n" + "="*60)
    print("FULL VALIDATION RESULTS:")
    print(f"  Total frames checked: {len(df_full)}")
    print(f"  Exact matches: {df_full['exact_match'].sum()} / {len(df_full)}")
    print(f"  Mean correlation: {df_full['correlation'].mean():.6f}")
    print(f"  Min correlation: {df_full['correlation'].min():.6f}")
    print(f"  Undefined correlations (NaN): {int(df_full['correlation'].isna().sum())}")
    print(f"  Mean MAE: {df_full['mae'].mean():.4f}")
    print(f"  Max MAE: {df_full['mae'].max():.4f}")
    print(f"  Max difference: {df_full['max_diff'].max():.4f}")

    if count_mismatches:
        print("\n⚠ Frame count mismatches (binary vs expected):")
        for mismatch_plane, (n_in_bin, n_exp) in count_mismatches.items():
            print(f"  {mismatch_plane}: {n_in_bin} vs {n_exp}")

    # `~(corr >= 0.99)` also flags NaN (undefined) correlations, which `corr < 0.99` misses.
    suspicious = df_full[~(df_full['correlation'] >= 0.99)]
    if df_full.empty:
        print("\n⚠ No frames were compared")
    elif len(suspicious) > 0:
        print(f"\n⚠ Found {len(suspicious)} frames with correlation < 0.99 or undefined:")
        print(suspicious[['plane', 'bin_frame_idx', 'raw_frame_idx', 'correlation', 'mae']].to_string())
    else:
        print("\n✓ All frames have correlation >= 0.99")
else:
    print("Full validation disabled. Set RUN_FULL_VALIDATION = True to run.")

## 7. Visual Comparison

Display side-by-side comparison of raw vs. binary frames.

In [None]:
# Choose a plane and frame to visualize
test_plane = "plane01_stitched"
test_frame_idx = 5  # Frame index in the binary file

if test_plane in suite2p_arrays:
    s2p_arr = suite2p_arrays[test_plane]
    # Interleave offset = position among all plane directories (robust to failed loads).
    plane_idx = [d.name for d in plane_dirs].index(test_plane)
    raw_frame_idx = plane_idx + (test_frame_idx * num_planes)

    if test_frame_idx >= s2p_arr.shape[0] or raw_frame_idx >= total_frames:
        raise IndexError(
            f"Frame {test_frame_idx} (raw {raw_frame_idx}) is out of range for {test_plane} "
            f"({s2p_arr.shape[0]} binary frames, {total_frames} raw frames)"
        )

    raw_frame = np.asarray(arr[raw_frame_idx])
    bin_frame = np.asarray(s2p_arr[test_frame_idx])

    metrics = compare_frames(raw_frame, bin_frame)

    # Shared display range for the raw and binary panels: independent per-panel limits
    # would normalize away a scale/offset difference, hiding exactly what we check for.
    vmax = max(np.percentile(raw_frame, 99.5), np.percentile(bin_frame, 99.5))

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Raw frame
    im0 = axes[0].imshow(raw_frame, cmap='gray', vmin=0, vmax=vmax)
    axes[0].set_title(f"Raw Frame {raw_frame_idx}\n(from ScanImage TIFFs)")
    axes[0].axis('off')
    plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

    # Binary frame
    im1 = axes[1].imshow(bin_frame, cmap='gray', vmin=0, vmax=vmax)
    axes[1].set_title(f"Binary Frame {test_frame_idx}\n({test_plane}/data_raw.bin)")
    axes[1].axis('off')
    plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

    # Difference
    diff = np.abs(raw_frame.astype(np.float32) - bin_frame.astype(np.float32))
    im2 = axes[2].imshow(diff, cmap='hot', vmin=0, vmax=np.percentile(diff, 99.5))
    axes[2].set_title(f"Absolute Difference\nMAE: {metrics['mae']:.2f}, Corr: {metrics['correlation']:.6f}")
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

    print(f"\nComparison metrics:")
    print(f"  Exact match: {metrics['exact_match']}")
    print(f"  Correlation: {metrics['correlation']:.6f}")
    print(f"  Mean absolute error: {metrics['mae']:.4f}")
    print(f"  Max absolute difference: {metrics['max_diff']:.4f}")
else:
    print(f"Plane {test_plane} not found in loaded arrays")

## 8. Test for Frame Order Issues

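Check whether any frames appear to be out of order by comparing temporal neighbors: the correlation between consecutive frames should be the same in the raw data and in the binary, so a divergence points to reordered, duplicated, or missing frames.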

In [None]:
# Test temporal consistency for one plane
test_plane = "plane01_stitched"
num_temporal_tests = 50  # Test first N frames

if test_plane in suite2p_arrays:
    s2p_arr = suite2p_arrays[test_plane]
    # Interleave offset = position among all plane directories (robust to failed loads).
    plane_idx = [d.name for d in plane_dirs].index(test_plane)
    # Never index past the end of the binary or of this plane's share of the raw frames.
    n_raw_plane = len(range(plane_idx, total_frames, num_planes))
    n_frames = min(num_temporal_tests, s2p_arr.shape[0], n_raw_plane)

    temporal_results = []

    print(f"Testing temporal consistency for {test_plane}...\n")

    # Carry the previous frame forward so each frame is read only once.
    raw_prev = bin_prev = None
    for t in range(n_frames):
        raw_t = arr[plane_idx + (t * num_planes)]
        bin_t = s2p_arr[t]

        if t > 0:
            # Correlation between consecutive frames (should be similar for raw and binary)
            raw_corr, _ = pearsonr(np.ravel(raw_t), np.ravel(raw_prev))
            bin_corr, _ = pearsonr(np.ravel(bin_t), np.ravel(bin_prev))

            temporal_results.append({
                'frame_idx': t,
                'raw_temporal_corr': raw_corr,
                'bin_temporal_corr': bin_corr,
                'corr_diff': abs(raw_corr - bin_corr)
            })

        raw_prev, bin_prev = raw_t, bin_t

    df_temporal = pd.DataFrame(
        temporal_results,
        columns=['frame_idx', 'raw_temporal_corr', 'bin_temporal_corr', 'corr_diff'],
    )

    if df_temporal.empty:
        print(f"Need at least 2 frames for a temporal check; only {n_frames} available.")
    else:
        print(f"Temporal consistency check:")
        print(f"  Mean raw temporal correlation: {df_temporal['raw_temporal_corr'].mean():.4f}")
        print(f"  Mean binary temporal correlation: {df_temporal['bin_temporal_corr'].mean():.4f}")
        print(f"  Mean correlation difference: {df_temporal['corr_diff'].mean():.6f}")
        print(f"  Max correlation difference: {df_temporal['corr_diff'].max():.6f}")

        # Plot temporal correlations
        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(df_temporal['frame_idx'], df_temporal['raw_temporal_corr'], 'b-', label='Raw data', alpha=0.7)
        ax.plot(df_temporal['frame_idx'], df_temporal['bin_temporal_corr'], 'r--', label='Binary data', alpha=0.7)
        ax.set_xlabel('Frame index')
        ax.set_ylabel('Correlation with previous frame')
        ax.set_title(f'Temporal Consistency Check: {test_plane}')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        # Flag potential frame order issues; `~(<= 0.01)` also flags NaN differences.
        suspicious_order = df_temporal[~(df_temporal['corr_diff'] <= 0.01)]
        if len(suspicious_order) > 0:
            print(f"\n⚠ Found {len(suspicious_order)} frames with suspicious temporal correlation:")
            print(suspicious_order.to_string())
        else:
            print("\n✓ No temporal ordering issues detected")
else:
    print(f"Plane {test_plane} not found in loaded arrays")

## 9. Summary Report

In [None]:
print("="*80)
print("VALIDATION SUMMARY")
print("="*80)
print(f"\nRaw data: {raw_data_path}")
print(f"  Total frames: {arr.shape[0]}")
print(f"  Shape: {arr.shape}")
print(f"\nSuite2p output: {suite2p_dir}")
print(f"  Number of planes: {len(suite2p_arrays)}")
print(f"  Frames per plane: {frames_per_plane}")

n_checked = len(df_quick)
n_exact = int(df_quick['exact_match'].sum())
exact_pct = 100.0 * n_exact / n_checked if n_checked else 0.0
# Undefined (NaN) correlations must count as failures: comparisons against NaN are
# False, and Series.min() skips NaN, so a min()-based threshold would let them pass.
corr_ok = df_quick['correlation'] > 0.99

print(f"\nQuick validation ({num_test_frames} frames per plane):")
print(f"  Frames checked: {n_checked}")
print(f"  Exact matches: {n_exact} ({exact_pct:.1f}%)")
print(f"  Mean correlation: {df_quick['correlation'].mean():.6f}")
print(f"  Mean MAE: {df_quick['mae'].mean():.4f}")

if n_checked == 0:
    # all() on an empty column is True; never report PASSED when nothing was compared.
    print("\n✗ FAILED: No frames were compared (were any Suite2p planes loaded?)")
elif df_quick['exact_match'].all():
    print("\n✓ PASSED: All tested frames match exactly!")
elif (df_quick['correlation'] > 0.999).all():
    print("\n✓ PASSED: All frames have very high correlation (>0.999)")
    print("  Small differences may be due to floating-point precision or phase correction.")
elif corr_ok.all():
    print("\n⚠ WARNING: Frames have high correlation (>0.99) but are not exact matches.")
    print("  This may indicate minor differences in processing or precision.")
else:
    print("\n✗ FAILED: Some frames have low (<=0.99) or undefined correlation")
    print("  This suggests frames may be shuffled, corrupted, or improperly extracted.")
    print("\n  Frames with correlation <= 0.99 or undefined:")
    print(df_quick.loc[~corr_ok, ['plane', 'bin_frame_idx', 'raw_frame_idx', 'correlation']].to_string())

print("\n" + "="*80)