# Setup:
1. Create environment:
    In terminal, run:
    
    `conda env create -n caiman_test_env -f studio/app/optinist/wrappers/caiman/conda/caiman.yaml`

    `conda activate caiman_test_env`

2. Install some additional packages:

   `pip install pynwb imageio ipython jupyter notebook "pydantic<2.0.0" python-dotenv uvicorn xmltodict plotly scikit-image opencv-python`
  - If running in VS Code, you may need to restart it and/or select the correct environment with the "Python: Select Interpreter" command.

3. Run this notebook

### Run CaImAn on sample data to get the label image

In [1]:
import os
import sys
import uuid
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))


# Import OptiNiSt core data modules
from studio.app.dir_path import DIRPATH
from studio.app.common.dataclass import ImageData
from studio.app.optinist.dataclass import FluoData
# Import ROI detection modules
from studio.app.optinist.wrappers.caiman import motion_correction, cnmf
# Import OptiNiSt analysis modules
from studio.app.optinist.wrappers.optinist.dimension_reduction.pca import PCA

import numpy as np

# Sub-directory "1" of the default input root holds the sample movie; create it if missing.
input_dir = os.path.join(DIRPATH.INPUT_DIR, "1")
os.makedirs(input_dir, exist_ok=True)

# Short random tag (first 8 hex digits of a UUID4) used to name this run's output folders.
unique_id = uuid.uuid4().hex[:8]

# Wrap the sample TIFF path in ImageData, the object handed to caiman_mc further down.
input_file = os.path.join(input_dir, "sample_mouse2p_image.tiff")
sample_data = ImageData([input_file])

In [None]:
# Motion correction

# Options passed to caiman_mc as a plain dict.
motion_correction_params = dict(
    border_nan="copy",
    gSig_filt=None,
    is3D=False,
    max_deviation_rigid=3,
    max_shifts=[6, 6],
    min_mov=None,
    niter_rig=1,
    nonneg_movie=True,
    num_frames_split=80,
    num_splits_to_process_els=None,
    num_splits_to_process_rig=None,
    overlaps=[32, 32],
    pw_rigid=False,
    shifts_opencv=True,
    splits_els=14,
    splits_rig=14,
    strides=[96, 96],
    upsample_factor_grid=4,
    use_cuda=False,
)

# Results of this step go to <OUTPUT_DIR>/1/<unique_id>/caiman_mc_<unique_id>.
mc_function_id = "caiman_mc_" + unique_id
mc_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, mc_function_id)
os.makedirs(mc_output_dir, exist_ok=True)

# Run the wrapper on the sample movie and keep its return value for the next step.
ret_mc = motion_correction.caiman_mc(
    sample_data,
    mc_output_dir,
    motion_correction_params,
)

In [None]:
# CNMF roi detection

# Options passed to caiman_cnmf as a plain dict.
caiman_cnmf_params = dict(
    p=2,
    nb=2,
    merge_thr=0.85,
    stride=6,
    K=10,
    gSig=[4, 4],
    method_init="greedy_roi",
    ssub=2,
    tsub=2,
    roi_thr=0.9,
    do_refit=False,
    use_online=False,
    use_cnn=False,
)

# Results of this step go to <OUTPUT_DIR>/1/<unique_id>/caiman_cnmf_<unique_id>.
cnmf_function_id = "caiman_cnmf_" + unique_id
cnmf_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, cnmf_function_id)
os.makedirs(cnmf_output_dir, exist_ok=True)

# Detect ROIs on the motion-corrected images returned by caiman_mc.
cnmf_info = cnmf.caiman_cnmf(
    ret_mc["mc_images"],
    cnmf_output_dir,
    caiman_cnmf_params,
)

# Functions used in batch processing (adjusted slightly for notebook format)

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy import sparse
from scipy.io import loadmat
from scipy.signal import convolve
from skimage import measure
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

In [5]:
# PCA Analysis Section
def pca_analysis(fluorescence, roi_masks, params=None):
    """Run PCA on fluorescence traces, treating rows as cells and columns as time.

    Parameters
    ----------
    fluorescence : numpy.ndarray
        2-D array indexed as (cell, time point); ``shape[0]`` is used as the
        number of cells.
    roi_masks : any
        Not used by this function; accepted so the call signature matches the
        other analysis helpers in this notebook.
    params : dict, optional
        Recognised keys:

        * ``"n_components"`` -- passed to ``PCA``. Defaults to
          ``min(50, n_cells, n_timepoints)``. An integer larger than
          ``min(n_cells, n_timepoints)`` is clamped to that limit, because
          scikit-learn rejects it otherwise.
        * ``"standard_norm"`` -- when true (default) every cell's trace is
          centred and scaled to unit variance before PCA.

        The dict is only read, never modified.

    Returns
    -------
    dict
        ``scores`` (time x components), ``components`` (components x cells),
        ``explained_variance`` (percent per component) and
        ``has_sufficient_data``. With fewer than two cells the arrays are
        zero-filled placeholders and ``has_sufficient_data`` is ``False``.
    """
    n_cells = fluorescence.shape[0]
    print(f"PCA will use {n_cells} cells")

    # PCA needs at least two variables (cells) to be meaningful.
    if n_cells < 2:
        print("Not enough cells for PCA analysis (minimum 2 required)")
        return {
            'scores': np.zeros((1, 1)),
            'components': np.zeros((1, 1)),
            'explained_variance': np.zeros(1),
            'has_sufficient_data': False
        }

    # PCA is fitted on data.T (time x cells), so the number of components is
    # bounded by both the number of cells and the number of time points.
    n_timepoints = fluorescence.shape[1]
    max_components = min(n_cells, n_timepoints)

    # Start from the defaults and overlay whatever the caller supplied, so a
    # partial params dict (e.g. only "standard_norm") no longer raises KeyError.
    options = {"n_components": min(50, max_components), "standard_norm": True}
    if params is not None:
        options.update(params)

    n_components = options["n_components"]
    # Only integer counts are clamped; other values accepted by scikit-learn
    # (None, a float fraction, "mle") are passed through untouched.
    if (
        isinstance(n_components, (int, np.integer))
        and not isinstance(n_components, bool)
        and n_components > max_components
    ):
        print(
            f"Requested {n_components} components; "
            f"using {max_components} (limited by data shape)"
        )
        n_components = max_components

    if options["standard_norm"]:
        # Centre each cell's trace ...
        data = fluorescence - np.mean(fluorescence, axis=1, keepdims=True)
        # ... and scale it to unit variance; constant traces keep a divisor of 1
        # so they stay at zero instead of producing NaN.
        std_values = np.std(data, axis=1, keepdims=True)
        std_values[std_values == 0] = 1.0
        data = data / std_values
    else:
        data = fluorescence

    pca = PCA(n_components=n_components)
    scores = pca.fit_transform(data.T)  # time x components
    components = pca.components_  # components x cells
    explained_variance = pca.explained_variance_ratio_ * 100  # percent

    return {
        'scores': scores,
        'components': components,
        'explained_variance': explained_variance,
        'has_sufficient_data': True
    }

In [6]:
# K-means Analysis Section
def kmeans_analysis(fluorescence, roi_masks, params=None):
    """Cluster cells with k-means on their pairwise correlation matrix.

    Parameters
    ----------
    fluorescence : numpy.ndarray
        2-D array indexed as (cell, time point); ``shape[0]`` is used as the
        number of cells.
    roi_masks : any
        Not used by this function; accepted so the call signature matches the
        other analysis helpers in this notebook.
    params : dict, optional
        ``"n_clusters"`` (default 3) is used only when there are 2 or 3 cells;
        with more than 3 cells the number of clusters is chosen by silhouette
        score over k = 2 .. min(20, n_cells - 1). The dict is copied, so the
        caller's object is never modified.

    Returns
    -------
    dict
        Always contains ``labels``, ``corr_matrix``, ``silhouette_values``,
        ``k_optimal`` and ``has_sufficient_data``. With fewer than two cells
        the first two are placeholders, the next two are ``None`` and
        ``has_sufficient_data`` is ``False``.
    """
    n_cells = fluorescence.shape[0]
    print(f"KMeans will use {n_cells} cells")

    # Work on a private copy: the clamped n_clusters must not leak back into
    # the dict the caller passed in.
    options = dict(params) if params else {}
    options["n_clusters"] = min(options.get("n_clusters", 3), n_cells)

    if n_cells < 2:
        print("Not enough cells for KMeans clustering (minimum 2 required)")
        size = max(1, n_cells)
        return {
            'labels': np.zeros(size, dtype=int),
            'corr_matrix': np.ones((size, size), dtype=float),
            'silhouette_values': None,
            'k_optimal': None,
            'has_sufficient_data': False
        }

    corr_matrix = np.corrcoef(fluorescence)
    # A zero-variance trace has undefined (NaN) correlations, which KMeans and
    # silhouette_score reject. Treat those entries as "no correlation".
    corr_matrix = np.where(np.isnan(corr_matrix), 0.0, corr_matrix)

    # Candidate cluster counts: 2 .. 20, never reaching n_cells.
    k_range = range(2, min(21, n_cells))
    silhouette_values = []

    if n_cells > 3:
        for k in k_range:
            kmeans = KMeans(n_clusters=k, init='k-means++', n_init=10, random_state=42)
            labels = kmeans.fit_predict(corr_matrix)
            if np.unique(labels).size < 2:
                # Duplicate rows can collapse into one cluster; silhouette is
                # undefined there, so record the worst possible score.
                silhouette_values.append(-1.0)
            else:
                silhouette_values.append(silhouette_score(corr_matrix, labels))

        # Keep the k with the highest silhouette score (first one on ties).
        k_optimal = k_range[int(np.argmax(silhouette_values))]
    else:
        # Too few cells for a silhouette search: use the requested count.
        k_optimal = max(1, min(n_cells, options["n_clusters"]))

    kmeans = KMeans(n_clusters=k_optimal, init='k-means++', n_init=10, random_state=42)
    cluster_labels = kmeans.fit_predict(corr_matrix)

    return {
        'labels': cluster_labels,
        'corr_matrix': corr_matrix,
        'silhouette_values': silhouette_values if n_cells > 3 else None,
        'k_optimal': k_optimal,
        'has_sufficient_data': True
    }

In [7]:
def _pca_save_message_figure(path, message):
    """Write a figure to ``path`` that shows only ``message``, centred, with axes hidden."""
    fig = plt.figure()
    try:
        plt.text(
            0.5, 0.5, message,
            ha="center", va="center",
            transform=plt.gca().transAxes,
        )
        plt.axis("off")
        plt.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)


def _pca_save_weights_bar(path, weights, component_number):
    """Write a bar chart of one component's per-cell weights to ``path``."""
    fig = plt.figure()
    try:
        plt.bar(range(len(weights)), weights)
        plt.title(f"PC {component_number} Component Weights")
        plt.xlabel("Cell Index")
        plt.ylabel("Weight")
        plt.grid(True, alpha=0.3)
        plt.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)


def _pca_build_component_map(roi_masks, weights):
    """Paint each ROI's pixels with that ROI's component weight.

    The k-th smallest distinct non-NaN value in ``roi_masks`` receives
    ``weights[k]``; pixels that are NaN (or belong to an ID with no weight)
    stay NaN. Raises ``ValueError`` when nothing could be painted.
    """
    if np.any(np.isnan(roi_masks)):
        non_nan_mask = ~np.isnan(roi_masks)
    else:
        non_nan_mask = np.ones_like(roi_masks, dtype=bool)
    if not np.any(non_nan_mask):
        raise ValueError("No non-NaN values found in ROI mask")

    # Always a float canvas: np.full_like on an integer label image cannot hold NaN.
    component_map = np.full(roi_masks.shape, np.nan, dtype=float)

    valid_ids = np.unique(roi_masks[non_nan_mask])  # np.unique returns sorted values
    for idx, cell_id in enumerate(valid_ids[:len(weights)]):
        component_map[np.isclose(roi_masks, cell_id)] = weights[idx]

    if np.all(np.isnan(component_map)):
        raise ValueError("No valid values in component map")
    return component_map


def _pca_save_spatial_map(path, component_map, component_number):
    """Write ``component_map`` to ``path`` using a diverging colormap centred on zero."""
    vmax = np.nanmax(np.abs(component_map))
    fig = plt.figure()
    try:
        im = plt.imshow(component_map, cmap="RdBu_r", vmin=-vmax, vmax=vmax)
        plt.colorbar(im, label="Component Weight")
        plt.title(f"PC {component_number} Spatial Map")
        plt.savefig(path, bbox_inches="tight")
    finally:
        # Close even if plotting fails so the caller's fallback starts clean.
        plt.close(fig)


def generate_pca_visualization(scores, explained_variance, components, roi_masks, output_dir):
    """Write PCA figures as PNG files into ``output_dir``.

    Parameters
    ----------
    scores : numpy.ndarray or None
        Indexed as (time, component).
    explained_variance : sequence
        Explained variance per component, plotted as a percentage.
    components : numpy.ndarray or None
        Indexed as (component, cell).
    roi_masks : numpy.ndarray or None
        Used for the spatial maps. Each distinct non-NaN value is treated as
        one ROI, and the k-th smallest value is paired with cell index k
        (NOTE(review): assumes ROI IDs sort in the same order as the cells in
        ``components`` -- confirm against the producer of the mask). When it is
        ``None``, has no ``shape``, or map creation fails, a bar chart of the
        weights is written instead.
    output_dir : str
        Existing directory that receives the PNG files.

    Files written
    -------------
    ``pca_analysis_variance.png``, ``pca_analysis.png``, ``pca_contribution.png``
    and, for each of the first ``min(50, n_components)`` components,
    ``pca_component_<k>_time.png`` and ``pca_component_<k>_spatial.png``.
    If there are fewer than two components, or the inputs are all zero or all
    NaN, the first three files plus the two files for component 1 are written
    as placeholder images carrying an explanatory message.

    Returns ``None``. If ``scores`` or ``components`` is ``None`` a warning is
    printed and nothing is written.
    """
    if components is None or scores is None:
        print("Warning: Missing PCA components or scores")
        return

    is_data_insufficient = (
        components.shape[0] < 2 or
        scores.shape[1] < 2 or
        np.allclose(components, 0, atol=1e-7) or
        np.allclose(scores, 0, atol=1e-7) or
        np.all(np.isnan(components)) or
        np.all(np.isnan(scores))
    )

    if is_data_insufficient:
        message = "Insufficient ROIs for PCA analysis.\nAt least 2 ROIs required."
        for filename in (
            "pca_analysis_variance.png",
            "pca_analysis.png",
            "pca_contribution.png",
            "pca_component_1_spatial.png",
            "pca_component_1_time.png",
        ):
            _pca_save_message_figure(os.path.join(output_dir, filename), message)
        return

    # Number of components that get their own time-course / spatial files
    num_components = min(50, components.shape[0], scores.shape[1])
    plots_to_show = 10  # Set to 10 as too many make legend illegible

    # 1. Explained variance
    fig = plt.figure()
    num_display = min(plots_to_show, len(explained_variance))
    plt.bar(range(1, num_display + 1), explained_variance[:num_display])
    plt.title("Explained Variance")
    plt.xlabel("Principal Component")
    plt.ylabel("Explained Variance (%)")
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(output_dir, "pca_analysis_variance.png"), bbox_inches="tight")
    plt.close(fig)

    # 2. Scatter of PC 1 vs PC 2, coloured by PC 3 when it exists.
    #    Exactly one figure is opened here, so none is left behind unclosed.
    fig = plt.figure()
    if scores.shape[1] >= 3:
        points = plt.scatter(
            scores[:, 0], scores[:, 1],
            c=scores[:, 2], cmap="viridis", alpha=0.7
        )
        plt.colorbar(points, label="PC 3")
    else:
        plt.scatter(scores[:, 0], scores[:, 1], alpha=0.7)
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    plt.savefig(os.path.join(output_dir, "pca_analysis.png"), bbox_inches="tight")
    plt.close(fig)

    # 3. Per-component time course and spatial map
    for i in range(num_components):
        component_number = i + 1

        fig = plt.figure()
        plt.plot(scores[:, i], linewidth=2)
        plt.title(f"PC {component_number} Time Course")
        plt.xlabel("Time")
        plt.ylabel("Component Value")
        plt.grid(True, alpha=0.3)
        plt.savefig(
            os.path.join(output_dir, f"pca_component_{component_number}_time.png"),
            bbox_inches="tight",
        )
        plt.close(fig)

        component_weights = components[i]  # Using actual weights, not absolute values
        spatial_path = os.path.join(
            output_dir, f"pca_component_{component_number}_spatial.png"
        )

        if roi_masks is not None and hasattr(roi_masks, "shape"):
            try:
                component_map = _pca_build_component_map(roi_masks, component_weights)
                _pca_save_spatial_map(spatial_path, component_map, component_number)
            except Exception as e:
                # Best effort: any failure falls back to a plain bar chart.
                print(f"Error creating spatial map for PC {component_number}: {str(e)}")
                _pca_save_weights_bar(spatial_path, component_weights, component_number)
        else:
            _pca_save_weights_bar(spatial_path, component_weights, component_number)

    # 4. Per-cell weights of the leading components, overlaid in one chart
    fig = plt.figure()
    top_n = min(plots_to_show, components.shape[0])
    for i in range(top_n):
        plt.bar(
            range(len(components[i])),
            components[i],  # Using actual weights, not absolute values
            alpha=0.7,
            label=f"PC {i+1}",
        )
    plt.xlabel("Cell Index")
    plt.ylabel("Component Weight")
    plt.title("PCA Component Contributions")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(output_dir, "pca_contribution.png"), bbox_inches="tight")
    plt.close(fig)


In [8]:
def _kmeans_save_message_figure(path, message):
    """Write a figure to ``path`` that shows only ``message``, centred, with axes hidden."""
    fig = plt.figure()
    try:
        plt.text(
            0.5, 0.5, message,
            ha="center", va="center",
            transform=plt.gca().transAxes,
        )
        plt.axis("off")
        plt.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)


def _kmeans_cluster_map_from_stack(roi_masks, cluster_ranks):
    """Build a 2-D cluster map from a 3-D mask stack indexed ``[:, :, roi]``.

    Pixels where ROI ``i`` is positive receive ``cluster_ranks[i] + 1``; later
    ROIs overwrite earlier ones where they overlap. Pixels belonging to no
    labelled ROI stay 0.
    """
    cluster_map = np.zeros(roi_masks.shape[:2])
    for i in range(min(len(cluster_ranks), roi_masks.shape[2])):
        cluster_map[roi_masks[:, :, i] > 0] = cluster_ranks[i] + 1
    return cluster_map


def _kmeans_cluster_map_from_image(roi_masks, cluster_ranks):
    """Build a 2-D cluster map from a 2-D mask.

    If the mask contains exactly one distinct ROI value per label it is treated
    as a label image and the k-th smallest value receives ``cluster_ranks[k] + 1``
    (the same ID-to-cell pairing generate_pca_visualization uses). ROI pixels
    are the non-NaN pixels when the mask contains NaN, otherwise the positive
    pixels. In every other case all positive pixels receive the most common
    cluster. Unassigned pixels stay 0.
    """
    nan_pixels = np.isnan(roi_masks)
    roi_pixels = ~nan_pixels if np.any(nan_pixels) else roi_masks > 0
    cluster_map = np.zeros(roi_masks.shape)

    cell_ids = np.unique(roi_masks[roi_pixels])  # sorted ascending
    if len(cell_ids) == len(cluster_ranks):
        # NOTE(review): assumes ROI IDs sort in the same order as the rows the
        # labels were computed from -- confirm against the mask producer.
        for idx, cell_id in enumerate(cell_ids):
            cluster_map[roi_masks == cell_id] = cluster_ranks[idx] + 1
    else:
        # Not a per-cell label image: paint the whole mask with the most common cluster.
        most_common = np.argmax(np.bincount(cluster_ranks))
        cluster_map[roi_masks > 0] = most_common + 1
    return cluster_map


def _kmeans_save_cluster_map(path, cluster_map, colors, title):
    """Write ``cluster_map`` (0 = background, k = cluster k) to ``path``."""
    n_clusters = len(colors)
    fig = plt.figure()
    try:
        im = plt.imshow(
            np.ma.masked_equal(cluster_map, 0),  # hide background / unlabelled pixels
            cmap=ListedColormap(colors),
            interpolation="nearest",
            # Fixed limits put value k in colour bin k-1 even when some cluster
            # has no visible pixels, so the image always matches the legend.
            vmin=0.5,
            vmax=n_clusters + 0.5,
        )
        colorbar = plt.colorbar(im, ticks=np.arange(1, n_clusters + 1))
        colorbar.set_label("Cluster")

        handles = [
            plt.Rectangle((0, 0), 1, 1, color=colors[i])
            for i in range(n_clusters)
        ]
        plt.legend(
            handles,
            [f"Cluster {i+1}" for i in range(n_clusters)],
            loc="upper right",
            bbox_to_anchor=(1.3, 1),
        )
        plt.title(title)
        plt.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)


def generate_kmeans_visualization(labels, corr_matrix, fluorescence, roi_masks, output_dir):
    """Write k-means clustering figures as PNG files into ``output_dir``.

    Parameters
    ----------
    labels : array-like of int or None
        Cluster label per cell.
    corr_matrix : numpy.ndarray or None
        Square cell-by-cell matrix; shown reordered so cells of the same
        cluster are adjacent.
    fluorescence : numpy.ndarray or None
        Indexed as (cell, time). The time-course figure is written only when it
        has at least ``len(labels)`` rows; the first ``len(labels)`` rows are used.
    roi_masks : numpy.ndarray or None
        Either a 3-D stack indexed ``[:, :, roi]`` or a 2-D image (see
        ``_kmeans_cluster_map_from_image``). Other shapes produce no spatial map.
    output_dir : str
        Existing directory that receives the PNG files.

    Files written
    -------------
    ``clustering_analysis.png``, ``cluster_time_courses.png`` and
    ``cluster_spatial_map.png``. With fewer than two labels, or a missing or
    too-small ``corr_matrix``, all three are placeholder images with an
    explanatory message.

    Returns ``None``. If ``labels`` is ``None`` or empty a warning is printed
    and nothing is written. A failure while building the spatial map is
    printed, not raised.
    """
    if labels is None or len(labels) == 0:
        print("Warning: Missing cluster labels")
        return

    labels = np.asarray(labels)

    if len(labels) < 2 or corr_matrix is None or corr_matrix.shape[0] < 2:
        message = "Insufficient ROIs for k-means clustering.\nAt least 2 ROIs required."
        for filename in (
            "clustering_analysis.png",
            "cluster_time_courses.png",
            "cluster_spatial_map.png",
        ):
            _kmeans_save_message_figure(os.path.join(output_dir, filename), message)
        return

    # Group cells of the same cluster next to each other in the heatmap.
    sort_idx = np.argsort(labels)
    sorted_corr_matrix = corr_matrix[sort_idx][:, sort_idx]

    unique_clusters = np.unique(labels)
    n_clusters = len(unique_clusters)
    colors = plt.cm.jet(np.linspace(0, 1, n_clusters))
    # Position of each cell's label within unique_clusters, so "Cluster k"
    # means the same thing in every figure even if label values are not 0..k-1.
    cluster_ranks = np.searchsorted(unique_clusters, labels)

    # 1. Correlation matrix heatmap
    fig = plt.figure()
    im = plt.imshow(sorted_corr_matrix, cmap="jet")
    plt.colorbar(im)
    plt.title(f"K-means Clustering (k={n_clusters})")
    plt.xlabel("Cells")
    plt.ylabel("Cells")
    plt.savefig(os.path.join(output_dir, "clustering_analysis.png"), bbox_inches="tight")
    plt.close(fig)

    # 2. Mean time courses by cluster
    if fluorescence is not None and fluorescence.shape[0] >= len(labels):
        # Boolean indexing needs exactly one row per label; extra rows would
        # raise IndexError, so only the first len(labels) rows are used.
        traces = fluorescence[:len(labels)]
        fig = plt.figure()
        for i in range(n_clusters):
            plt.plot(
                np.mean(traces[cluster_ranks == i], axis=0),
                color=colors[i], linewidth=2, label=f"Cluster {i+1}"
            )
        plt.title("Mean Time Course by Cluster")
        plt.xlabel("Time")
        plt.ylabel("Fluorescence")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(os.path.join(output_dir, "cluster_time_courses.png"), bbox_inches="tight")
        plt.close(fig)

    # 3. Spatial cluster map - attempt only if roi_masks has a shape
    if roi_masks is not None and hasattr(roi_masks, "shape"):
        map_path = os.path.join(output_dir, "cluster_spatial_map.png")
        try:
            if len(roi_masks.shape) == 3:
                cluster_map = _kmeans_cluster_map_from_stack(roi_masks, cluster_ranks)
                _kmeans_save_cluster_map(map_path, cluster_map, colors, "Cluster Spatial Map")
            elif len(roi_masks.shape) == 2:
                cluster_map = _kmeans_cluster_map_from_image(roi_masks, cluster_ranks)
                _kmeans_save_cluster_map(map_path, cluster_map, colors, "Cluster Assignments")
        except Exception as e:
            # Best effort: the other figures are already on disk.
            print(f"Could not create cluster spatial map: {str(e)}")

### Use the CaImAn results computed above, or add your own CaImAn data produced using caiman.ipynb

In [9]:
# Placeholder for custom data: set `A_or` to the path of your own CNMF output here.
# By default the CNMF result computed earlier in this notebook is used.
roi_masks, timecourse = (
    cnmf_info[key].data for key in ("cell_roi", "fluorescence")
)

In [None]:
import os

# Pull the ROI data and the fluorescence traces out of the CNMF result.
roi_masks = cnmf_info["cell_roi"].data
timecourse = cnmf_info["fluorescence"].data

# All figures are written below this folder (created on first use).
output_dir = "./test_plots"
os.makedirs(output_dir, exist_ok=True)


# --- PCA: analyse, then plot only if the analysis had enough cells ---
print("\n===== Running PCA Analysis =====")
pca_results = pca_analysis(timecourse, roi_masks)
if not pca_results['has_sufficient_data']:
    print("PCA analysis skipped due to insufficient data.")
else:
    generate_pca_visualization(
        scores=pca_results['scores'],
        explained_variance=pca_results['explained_variance'],
        components=pca_results['components'],
        roi_masks=roi_masks,
        output_dir=output_dir,
    )
    print("PCA analysis and visualization completed.")

# --- KMeans: same pattern as above ---
print("\n===== Running KMeans Analysis =====")
kmeans_results = kmeans_analysis(timecourse, roi_masks)
if not kmeans_results['has_sufficient_data']:
    print("KMeans analysis skipped due to insufficient data.")
else:
    generate_kmeans_visualization(
        labels=kmeans_results['labels'],
        corr_matrix=kmeans_results['corr_matrix'],
        fluorescence=timecourse,
        roi_masks=roi_masks,
        output_dir=output_dir,
    )
    print("KMeans analysis and visualization completed.")

print("\nAll analyses completed. Results saved to:", output_dir)