# Setup:
1. Create environment:
    In a terminal, run:
    
    `conda env create -n caiman_test_env -f studio/app/optinist/wrappers/caiman/conda/caiman.yaml`

    `conda activate caiman_test_env`

2. Install some additional packages:

   `pip install pynwb imageio ipython jupyter notebook "pydantic<2.0.0" python-dotenv uvicorn xmltodict fastapi "python-jose[cryptography]" passlib python-multipart bcrypt plotly`
  - If running in VS Code, you may need to restart VS Code and/or select the `caiman_test_env` interpreter with the "Python: Select Interpreter" command.

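3. Place `sample_mouse2p_image.tiff` in the workspace input directory (`DIRPATH.INPUT_DIR/1/`; the first code cell creates it if missing), then run this notebook.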

In [31]:
import os
import sys
import uuid

# Presumably lets `studio` be imported whether the notebook is launched from the
# repository root or from one directory below it.
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))

from studio.app.optinist.wrappers.caiman import motion_correction, cnmf
from studio.app.common.dataclass import ImageData
from studio.app.dir_path import DIRPATH

# Workspace sub-directory used under both INPUT_DIR and OUTPUT_DIR.
WORKSPACE_ID = "1"

# Create input directory: <INPUT_DIR>/<WORKSPACE_ID>
input_dir = os.path.join(DIRPATH.INPUT_DIR, WORKSPACE_ID)
os.makedirs(input_dir, exist_ok=True)

# 8-char run id so repeated runs write to separate output folders.
unique_id = str(uuid.uuid4())[:8]


def _make_function_output_dir(function_id):
    """Create (if needed) and return <OUTPUT_DIR>/<WORKSPACE_ID>/<unique_id>/<function_id>."""
    path = os.path.join(DIRPATH.OUTPUT_DIR, WORKSPACE_ID, unique_id, function_id)
    os.makedirs(path, exist_ok=True)
    return path


mc_function_id = f"caiman_mc_{unique_id}"
mc_output_dir = _make_function_output_dir(mc_function_id)

cnmf_function_id = f"caiman_cnmf_{unique_id}"
cnmf_output_dir = _make_function_output_dir(cnmf_function_id)

In [32]:
# Input file path
# Input movie: must already exist in the workspace input directory created by the setup cell.
input_file = os.path.join(input_dir, "sample_mouse2p_image.tiff")
# Fail fast with an actionable message rather than an opaque error from the loader.
if not os.path.isfile(input_file):
    raise FileNotFoundError(
        f"Sample movie not found: {input_file!r}. "
        f"Copy sample_mouse2p_image.tiff into {input_dir!r} and re-run."
    )
sample_data = ImageData([input_file])

In [33]:
# Parameters forwarded to caiman_mc. Rigid correction is selected (pw_rigid=False),
# so the piecewise-rigid options (strides, overlaps, upsample_factor_grid,
# max_deviation_rigid, ...) are presumably ignored by CaImAn — TODO confirm.
motion_correction_params = dict(
    border_nan='copy',
    gSig_filt=None,
    is3D=False,
    max_deviation_rigid=3,
    max_shifts=[6, 6],
    min_mov=None,
    niter_rig=1,
    nonneg_movie=True,
    num_frames_split=80,
    num_splits_to_process_els=None,
    num_splits_to_process_rig=None,
    overlaps=[32, 32],
    pw_rigid=False,
    shifts_opencv=True,
    splits_els=14,
    splits_rig=14,
    strides=[96, 96],
    upsample_factor_grid=4,
    use_cuda=False,
)

In [None]:
# Run the CaImAn motion-correction wrapper on the sample movie, passing mc_output_dir as its
# output directory. Later cells read ret_mc['mc_images'] (a dict-like result).
ret_mc = motion_correction.caiman_mc(sample_data, mc_output_dir, motion_correction_params)

In [35]:
# Parameters forwarded to caiman_cnmf. With rf=None CaImAn presumably runs CNMF on the
# whole field of view (no patches), in which case `stride` has no effect — TODO confirm.
caiman_cnmf_params = dict(
    p=1,
    nb=2,
    merge_thr=0.85,
    rf=None,
    stride=6,
    K=4,
    gSig=[4, 4],
    method_init='greedy_roi',
    ssub=1,
    tsub=1,
    thr=0.9,
)

In [None]:
# Run the CaImAn CNMF wrapper on the motion-corrected movie from the previous cell,
# passing cnmf_output_dir as its output directory. The plotting cell reads the result's
# 'nwbfile' and 'fluorescence' entries.
ret_cnmf = cnmf.caiman_cnmf(ret_mc['mc_images'], cnmf_output_dir, caiman_cnmf_params)

In [None]:
# Plot CaImAn results in a 2x2 grid: local-correlation image, ROI label map,
# mean fluorescence across ROIs, and the first few individual ROI traces.
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots
from caiman.utils.visualization import local_correlations

N_ROI_TRACES = 5  # how many individual ROI traces to draw in the bottom-right panel


def _roi_label_image(rois, shape):
    """Return a float image holding each ROI's 1-based index on its pixels and NaN elsewhere.

    A pixel belongs to an ROI when its ``image_mask`` value is > 0. Overlapping pixels keep
    the later ROI's label, so overlaps never sum into a value that looks like another ROI.
    NaN background is left uncoloured by the heatmap.
    Assumes each ``image_mask`` has exactly ``shape`` — TODO confirm against the NWB output.
    """
    labels = np.full(shape, np.nan)
    for label, roi in enumerate(rois, start=1):
        labels[np.asarray(roi['image_mask']) > 0] = label
    return labels


# Get variables from results. A distinct name is used so this cell does not overwrite the
# setup cell's `cnmf_function_id`.
roi_function_id = next(iter(ret_cnmf['nwbfile']['ROI']))
roi_list = ret_cnmf['nwbfile']['ROI'][roi_function_id]
mc_images = ret_mc['mc_images'].data  # treated as (time, y, x) below
dims = mc_images.shape[1:]

# Create subplot figure
fig = make_subplots(rows=2, cols=2,
                    subplot_titles=('Correlation Image', 'All ROI Masks',
                                    'Mean Activity', 'Individual ROI Traces'),
                    vertical_spacing=0.12,
                    horizontal_spacing=0.1)

# 1. Local-correlation image (this is not a mean image). The transpose moves time from the
# first axis to the last, which is presumably what local_correlations expects by default.
Cn = local_correlations(mc_images.transpose(1, 2, 0))
fig.add_trace(
    go.Heatmap(z=Cn, colorscale='gray',
               showscale=False,
               name='Correlation Image',
               showlegend=False),
    row=1, col=1,
)

# 2. ROI masks as a label image (1..N on ROI pixels, background transparent)
combined_mask = _roi_label_image(roi_list, dims)
fig.add_trace(
    go.Heatmap(z=combined_mask,
               colorscale='viridis',
               showscale=True,
               name='ROI Masks',
               showlegend=False,
               colorbar=dict(title='ROI #', len=0.4, y=0.8)),
    row=1, col=2,
)

# 3. Mean activity across ROIs (rows of fluo_data are used as ROIs, columns as frames)
fluo_data = ret_cnmf['fluorescence'].data
mean_activity = np.mean(fluo_data, axis=0)
time_points = np.arange(len(mean_activity))

fig.add_trace(
    go.Scatter(x=time_points,
               y=mean_activity,
               mode='lines',
               name='Mean Activity',
               showlegend=False),
    row=2, col=1,
)

# 4. Individual ROI traces; cycle the palette so N_ROI_TRACES may exceed its length.
colors = px.colors.qualitative.Set3
for i in range(min(N_ROI_TRACES, fluo_data.shape[0])):
    fig.add_trace(
        go.Scatter(x=time_points,
                   y=fluo_data[i, :],
                   mode='lines',
                   name=f'ROI {i+1}',
                   line=dict(color=colors[i % len(colors)]),
                   showlegend=True,
                   legendgroup='roi_traces',
                   legendgrouptitle_text='ROI Traces'),
        row=2, col=2,
    )

# Layout: size, centred title and the legend (only the ROI traces appear in it).
fig.update_layout(
    height=800,
    width=1000,
    title=dict(text="CaImAn Analysis Results", x=0.5, y=0.95),
    showlegend=True,
    legend=dict(x=1.0, y=0.4, traceorder='grouped', tracegroupgap=5),
)

# Axis labels. Heatmaps draw row 0 at the bottom by default, so the image y-axes are
# reversed to show the images in array/imshow orientation (row 0 at the top).
for col in (1, 2):
    fig.update_xaxes(title_text="X Position", row=1, col=col)
    fig.update_yaxes(title_text="Y Position", autorange='reversed', row=1, col=col)
    fig.update_xaxes(title_text="Time Points", row=2, col=col)
    fig.update_yaxes(title_text="Fluorescence", row=2, col=col)

fig.show()