# Setup:
1. Create environment:
    In terminal, run:
    
    `conda env create -n caiman_test_env -f studio/app/optinist/wrappers/caiman/conda/caiman.yaml`

    `conda activate caiman_test_env`

2. Install some additional packages:

   `pip install pynwb imageio ipython jupyter notebook "pydantic<2.0.0" python-dotenv uvicorn xmltodict plotly`
  - If running in VS Code, you may need to restart it and/or select the correct environment via the "Python: Select Interpreter" command.

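3. Place the sample recording `sample_mouse2p_image.tiff` in the workspace input directory (`DIRPATH.INPUT_DIR/1/`, which the first code cell creates), then run this notebook.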

In [106]:
import os
import sys
import uuid
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))

# Import OptiNiSt core data modules
from studio.app.dir_path import DIRPATH
from studio.app.common.dataclass import ImageData
from studio.app.optinist.dataclass import FluoData
# Import ROI detection modules
from studio.app.optinist.wrappers.caiman import motion_correction, cnmf
# Import OptiNiSt analysis modules
from studio.app.optinist.wrappers.optinist.dimension_reduction.pca import PCA

import numpy as np

# Import visualization modules
from caiman.utils.visualization import local_correlations
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

# Workspace "1" input directory under OptiNiSt's default data root
# (created up-front so the sample recording can be dropped in).
input_dir = os.path.join(DIRPATH.INPUT_DIR, "1")
os.makedirs(input_dir, exist_ok=True)

# Short random run ID (first 8 hex digits of a UUID4) so repeated runs
# write to separate output folders.
unique_id = uuid.uuid4().hex[:8]

In [122]:
# Input file path: the sample two-photon recording is expected in `input_dir`.
input_file = os.path.join(input_dir, "sample_mouse2p_image.tiff")
if not os.path.isfile(input_file):
    # Fail early with an actionable message instead of an opaque error from ImageData.
    raise FileNotFoundError(
        f"Sample image not found: {input_file}. "
        "Copy sample_mouse2p_image.tiff into this directory and re-run."
    )
sample_data = ImageData([input_file])

In [123]:
# Motion-correction settings passed straight to the CaImAn wrapper.
# pw_rigid=False selects rigid (not piecewise-rigid) correction; the
# strides/overlaps entries are kept for completeness.
motion_correction_params = dict(
    border_nan='copy',
    gSig_filt=None,
    is3D=False,
    max_deviation_rigid=3,
    max_shifts=[6, 6],
    min_mov=None,
    niter_rig=1,
    nonneg_movie=True,
    num_frames_split=80,
    num_splits_to_process_els=None,
    num_splits_to_process_rig=None,
    overlaps=[32, 32],
    pw_rigid=False,
    shifts_opencv=True,
    splits_els=14,
    splits_rig=14,
    strides=[96, 96],
    upsample_factor_grid=4,
    use_cuda=False,
)

In [124]:
# Motion-correction outputs go to <OUTPUT_DIR>/1/<unique_id>/caiman_mc_<unique_id>
mc_function_id = "caiman_mc_" + unique_id
mc_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, mc_function_id)
os.makedirs(mc_output_dir, exist_ok=True)

In [None]:
# Run CaImAn motion correction on the sample movie; later cells read the
# corrected movie from the returned dict's 'mc_images' entry.
ret_mc = motion_correction.caiman_mc(
    sample_data,
    mc_output_dir,
    motion_correction_params,
)

In [111]:
# CNMF ROI-detection settings passed straight to the CaImAn wrapper.
# rf=None runs on the whole field of view rather than in patches.
caiman_cnmf_params = dict(
    p=1,
    nb=2,
    merge_thr=0.85,
    rf=None,
    stride=6,
    K=4,
    gSig=[4, 4],
    method_init='greedy_roi',
    ssub=1,
    tsub=1,
    thr=0.9,
)

In [112]:
# CNMF outputs go to <OUTPUT_DIR>/1/<unique_id>/caiman_cnmf_<unique_id>
cnmf_function_id = "caiman_cnmf_" + unique_id
cnmf_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, cnmf_function_id)
os.makedirs(cnmf_output_dir, exist_ok=True)

In [None]:
# Detect ROIs with CNMF on the motion-corrected movie produced above.
ret_cnmf = cnmf.caiman_cnmf(
    ret_mc['mc_images'],
    cnmf_output_dir,
    caiman_cnmf_params,
)

In [None]:
# Plot results of motion correction and CNMF

# Get variables from results.
# The ROI dict is presumably keyed by the CNMF function ID. Keep that key under
# its own name so the `cnmf_function_id` used for the output directory is not
# silently overwritten when this cell runs.
roi_key = list(ret_cnmf['nwbfile']['ROI'].keys())[0]
roi_list = ret_cnmf['nwbfile']['ROI'][roi_key]
mc_images = ret_mc['mc_images'].data
dims = mc_images.shape[1:]  # spatial dims; frames are stacked on axis 0

# Create subplot figure
fig = make_subplots(rows=2, cols=2,
                    subplot_titles=('Local Correlation Image', 'All ROI Masks',
                                    'Mean Activity', 'Individual ROI Traces'),
                    vertical_spacing=0.12,
                    horizontal_spacing=0.1)

# 1. Local correlation image. This is what `local_correlations` returns (not a
#    mean image), so the panel is titled accordingly. The movie is moved from
#    (T, H, W) to (H, W, T) because CaImAn's default expects time on the last axis.
correlation_image = local_correlations(mc_images.transpose(1, 2, 0))
fig.add_trace(
    go.Heatmap(z=correlation_image, colorscale='gray',
               showscale=False,
               name='Local Correlation Image',
               showlegend=False),
    row=1, col=1
)

# 2. ROI Masks: label image where each pixel holds the 1-based number of the
#    ROI covering it.
combined_mask = np.zeros(dims)
for roi_number, roi in enumerate(roi_list, start=1):
    # Assign the label instead of adding it. Summing gave overlapping pixels a
    # non-existent ROI number, and non-binary (weighted) masks gave fractional
    # labels.
    roi_pixels = np.asarray(roi['image_mask']) > 0
    combined_mask = np.where(roi_pixels, roi_number, combined_mask)
# NaN cells are not drawn by plotly heatmaps, so the background stays empty
# instead of being painted with the colour for label 0.
combined_mask[combined_mask == 0] = np.nan

fig.add_trace(
    go.Heatmap(z=combined_mask,
               colorscale='viridis',
               showscale=True,
               name='ROI Masks',
               showlegend=False,
               colorbar=dict(title='ROI #',
                             len=0.4,
                             y=0.8)),
    row=1, col=2
)

# 3. Mean Activity across ROIs (fluorescence is indexed as [roi, time] below)
fluo_data = ret_cnmf['fluorescence'].data
mean_activity = np.mean(fluo_data, axis=0)
time_points = np.arange(len(mean_activity))

fig.add_trace(
    go.Scatter(x=time_points,
               y=mean_activity,
               mode='lines',
               name='Mean Activity',
               showlegend=False,
               legendgroup='mean_activity',
               legendgrouptitle_text='Mean Activity'),
    row=2, col=1
)

# 4. Individual ROI Traces (only the first few, to keep the panel readable)
max_traces = 5
colors = px.colors.qualitative.Set3
for i in range(min(max_traces, fluo_data.shape[0])):
    fig.add_trace(
        go.Scatter(x=time_points,
                   y=fluo_data[i, :],
                   mode='lines',
                   name=f'ROI {i+1}',
                   line=dict(color=colors[i % len(colors)]),
                   showlegend=True,
                   legendgroup='roi_traces',
                   legendgrouptitle_text='ROI Traces'),
        row=2, col=2
    )

# Layout: title plus a grouped legend placed beside the bottom-right panel
fig.update_layout(
    height=800,
    width=1000,
    title=dict(
        text="CaImAn Analysis Results",
        x=0.5,
        y=0.95
    ),
    showlegend=True,
    legend=dict(
        x=1.0,
        y=0.4,
        traceorder='grouped',
        tracegroupgap=5
    ),
)

# go.Heatmap puts row 0 at the bottom by default. Reverse the y axis of the
# image panels so they display in the usual image orientation (row 0 at top).
fig.update_yaxes(autorange='reversed', row=1, col=1)
fig.update_yaxes(autorange='reversed', row=1, col=2)

# Axis labels per (row, col) panel
axis_titles = {
    (1, 1): ("X Position", "Y Position"),
    (1, 2): ("X Position", "Y Position"),
    (2, 1): ("Time Points", "Fluorescence"),
    (2, 2): ("Time Points", "Fluorescence"),
}
for (row, col), (x_title, y_title) in axis_titles.items():
    fig.update_xaxes(title_text=x_title, row=row, col=col)
    fig.update_yaxes(title_text=y_title, row=row, col=col)

fig.show()

In [115]:
# PCA settings for the OptiNiSt PCA wrapper; the nested 'PCA' entry holds
# the estimator arguments (n_components=None keeps all components).
pca_params = dict(
    transpose=True,
    standard_mean=True,
    standard_std=True,
    PCA=dict(n_components=None, whiten=False),
)

In [116]:
# PCA outputs go to <OUTPUT_DIR>/1/<unique_id>/pca_<unique_id>
pca_function_id = "pca_" + unique_id
pca_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, pca_function_id)
os.makedirs(pca_output_dir, exist_ok=True)

In [None]:
# PCA analysis on the CNMF fluorescence traces.
# The raw trace array is re-wrapped in a FluoData container for the PCA wrapper,
# and the CNMF 'iscell' output is passed alongside it.
fluo_data_obj = FluoData(ret_cnmf['fluorescence'].data, file_name="fluorescence")
pca_results = PCA(fluo_data_obj, pca_output_dir, ret_cnmf['iscell'], pca_params)

# Report the shape of every PCA output except the NWB payload
printable_results = {
    name: result for name, result in pca_results.items() if name != 'nwbfile'
}
print("\nPCA Results:")
for name, result in printable_results.items():
    print(f"\n{name}:")
    print(f"Shape: {result.data.shape}")

In [None]:
# Plot PCA results

# Create a 2x2 figure. All four panels get a title; previously only the first
# two had one.
fig = make_subplots(rows=2, cols=2,
                    subplot_titles=('Explained Variance', 'PCA Projection',
                                    'PC1 Contributions', 'PC1 Cumulative Contributions'))

# 1. Explained variance per component, shown in percent. This assumes the
#    wrapper returns ratios in [0, 1], hence the x100 and the 0-100 axis range.
evr_data = pca_results['explained_variance'].data.flatten()
x_vals = np.arange(1, len(evr_data) + 1)  # 1-based principal component numbers

fig.add_trace(
    go.Bar(x=x_vals,
           y=evr_data * 100,
           name='Explained Variance',
           showlegend=False),
    row=1, col=1
)

# 2. PCA projection (PC1 vs PC2), coloured by time point.
# The transpose makes rows time points and columns components (for the sample
# data this gave (2000, 4)).
# NOTE(review): assumes 'projectedNd' is stored as (n_components, n_timepoints);
# confirm against the PCA wrapper.
proj_data = pca_results['projectedNd'].data.T
time_points = np.arange(len(proj_data))

fig.add_trace(
    go.Scatter(x=proj_data[:, 0],  # PC1
               y=proj_data[:, 1],  # PC2
               mode='markers',
               marker=dict(
                   size=3,
                   color=time_points,
                   colorscale='Viridis',
                   showscale=True,
                   colorbar=dict(
                       title='Time Point',
                       len=0.4,
                       y=0.8,
                   )
               ),
               showlegend=False,
               name='PC Trajectory'),
    row=1, col=2
)

# 3. Contributions (loadings) of the first principal component.
# Each bar is one PCA input feature, not one component. With transpose=True the
# features are presumably the ROIs (NOTE(review): confirm).
contrib_data = pca_results['contribution'].data
fig.add_trace(
    go.Bar(x=np.arange(len(contrib_data[0])),
           y=contrib_data[0],  # first PC's contributions
           name='PC1 Contributions',
           showlegend=False),
    row=2, col=1
)

# 4. Cumulative contributions, first row.
# NOTE(review): if the wrapper accumulates over components (axis 0), row 0 is
# identical to panel 3. Verify the cumsum axis, and consider plotting the last
# row instead.
cumsum_data = pca_results['cumsum_contribution'].data
fig.add_trace(
    go.Bar(x=np.arange(len(cumsum_data[0])),
           y=cumsum_data[0],
           name='PC1 Cumulative',
           showlegend=False),
    row=2, col=2
)

# Axis labels per (row, col) panel
axis_titles = {
    (1, 1): ("Principal Component", "Explained Variance (%)"),
    (1, 2): ("PC1", "PC2"),
    (2, 1): ("Feature (ROI) Index", "Contribution"),
    (2, 2): ("Feature (ROI) Index", "Cumulative Contribution"),
}
for (row, col), (x_title, y_title) in axis_titles.items():
    fig.update_xaxes(title_text=x_title, row=row, col=col)
    fig.update_yaxes(title_text=y_title, row=row, col=col)

# Percent scale for the explained-variance panel only
fig.update_yaxes(range=[0, 100], row=1, col=1)

fig.update_layout(
    height=800,
    width=1000,
    showlegend=True,
    title="PCA Analysis of CaImAn Results",
)

fig.show()