    # Setup:
    1. Create environment:
        In a terminal, run:
        `conda env create -n suite2p_test_env -f studio/app/optinist/wrappers/suite2p/conda/suite2p.yaml`

        `conda activate suite2p_test_env`
    
    2. Install additional packages:
    
       `pip install pynwb imageio ipython jupyter notebook plotly "pydantic<2.0.0" python-dotenv uvicorn xmltodict bcrypt`
      - If running in VS Code, you may need to restart VS Code and/or select the correct environment with "Python: Select Interpreter".
    
    3. Run this notebook


In [1]:
import os
import sys
import uuid
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('.'))

# Import OptiNiSt core data modules
from studio.app.dir_path import DIRPATH
from studio.app.common.dataclass import ImageData
from studio.app.optinist.dataclass import FluoData, BehaviorData
# Import ROI detection modules
from studio.app.optinist.wrappers.suite2p import file_convert, registration, roi
# Import OptiNiSt analysis modules
from studio.app.optinist.wrappers.optinist.basic_neural_analysis.eta import ETA

import numpy as np
import pandas as pd

# Import visualization modules
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

# Input directory: subdirectory "1" of OptiNiSt's default input path
# (presumably the workspace ID used throughout this notebook — confirm).
input_dir = os.path.join(DIRPATH.INPUT_DIR, "1")
os.makedirs(input_dir, exist_ok=True)

# Short random ID (first 8 hex digits of a UUID4) that keeps this run's
# output directories separate from earlier runs.
unique_id = uuid.uuid4().hex[:8]

In [2]:
# Sample movie expected inside input_dir; ImageData is given a one-element
# list of file paths.
input_file = os.path.join(input_dir, "sample_mouse2p_image.tiff")
image_paths = [input_file]
sample_data = ImageData(image_paths)

In [3]:
# Parameters for Suite2p file conversion (one plane, one channel)
file_convert_params = {
    "nplanes": 1,
    "nchannels": 1,
    "force_sktiff": False,
    "batch_size": 500,
    "do_registration": 1,
}

In [4]:
# Output directory for the file-conversion step:
#   <OUTPUT_DIR>/1/<unique_id>/suite2p_convert_<unique_id>
convert_function_id = f"suite2p_convert_{unique_id}"
convert_output_dir = os.path.join(
    DIRPATH.OUTPUT_DIR,
    "1",
    unique_id,
    convert_function_id,
)
os.makedirs(convert_output_dir, exist_ok=True)

In [None]:
# Convert the input image for Suite2p; the 'ops' entry of the returned dict is
# what the registration step consumes.
ret_conv = file_convert.suite2p_file_convert(
    sample_data,
    convert_output_dir,
    file_convert_params,
)

In [6]:
# Parameters for Suite2p registration (motion correction)
registration_params = {
    # frame selection / raw-movie handling
    "frames_include": -1,
    "keep_movie_raw": False,
    "do_bidiphase": False,

    # rigid / non-rigid registration
    "smooth_sigma": 1.15,
    "smooth_sigma_time": 0,
    "bidiphase": 0,
    "maxregshift": 0.1,
    "maxregshiftNR": 5,
    "nonrigid": True,
    "block_size": [128, 128],
    "snr_thresh": 1.2,
    "functional_chan": 1,
    "align_by_chan": 1,
    "reg_tif": False,
    "th_badframes": 1.0,
    "diameter": 0,

    # 1P registration settings
    "1Preg": False,
    "spatial_hp_reg": 42,
    "pre_smooth": 0,
    "spatial_taper": 40,
    "bidi_corrected": False,
}

In [7]:
# Create output directory for registration:
#   <OUTPUT_DIR>/1/<unique_id>/suite2p_reg_<unique_id>
# Uses reg_function_id (not convert_function_id) so registration outputs get
# their own directory instead of being written into the conversion directory.
reg_function_id = f"suite2p_reg_{unique_id}"
reg_output_dir = os.path.join(DIRPATH.OUTPUT_DIR, "1", unique_id, reg_function_id)
os.makedirs(reg_output_dir, exist_ok=True)

In [None]:
# Motion-correct the movie, starting from the ops produced by file conversion
conversion_ops = ret_conv['ops']
ret_reg = registration.suite2p_registration(
    conversion_ops, reg_output_dir, registration_params
)

In [9]:
# Set parameters for Suite2p ROI detection and signal extraction
suite2p_roi_params = {
    # main settings
    'tau':  1.0,              # this is the main parameter for deconvolution
    'fs': 10.0,               # sampling rate (PER PLANE, e.g. for 12-plane recordings it will be around 2.5)

    # classification parameters
    'soma_crop': True,        # crop dendrites for cell classification stats like compactness

    # cell detection settings
    'high_pass': 100,         # running-mean subtraction with a window of size 'high_pass' (use low values for 1P)
    'sparse_mode': True,      # whether or not to run sparse_mode
    'max_overlap': 0.75,      # cells with more overlap than this get removed during triage before refinement
    'nbinned': 5000,          # max number of binned frames for cell detection
    'spatial_scale': 0,       # 0: multi-scale; 1: 6 pixels 2: 12 pixels 3: 24 pixels 4: 48 pixels
    'threshold_scaling': 1.0, # adjust the automatically determined threshold by this scalar multiplier
    'max_iterations': 20,     # maximum number of iterations to do cell detection

    # 1P settings
    'spatial_hp_detect': 25,  # window for spatial high-pass filtering for neuropil subtraction before detection

    # output settings
    # Classifier-probability threshold applied before signal extraction.
    # NOTE(review): set to 0 here, which presumably disables pre-classification
    # (suite2p's documented example value is 0.3) — confirm against suite2p docs.
    'preclassify': 0.,

    # ROI extraction parameters
    'allow_overlap': False,      # pixels that are overlapping are thrown out (False) or added to both ROIs (True)
    'inner_neuropil_radius': 2,  # number of pixels to keep between ROI and neuropil donut
    'min_neuropil_pixels': 350,  # minimum number of pixels in the neuropil

    # deconvolution settings
    'neucoeff': .7,              # neuropil coefficient
}

In [10]:
# Output directory for ROI detection:
#   <OUTPUT_DIR>/1/<unique_id>/suite2p_roi_<unique_id>
roi_function_id = f"suite2p_roi_{unique_id}"
roi_output_dir = os.path.join(
    DIRPATH.OUTPUT_DIR,
    "1",
    unique_id,
    roi_function_id,
)
os.makedirs(roi_output_dir, exist_ok=True)

In [None]:
# Detect ROIs and extract their signals, starting from the registration ops
registered_ops = ret_reg['ops']
ret_roi = roi.suite2p_roi(
    registered_ops, roi_output_dir, suite2p_roi_params
)

In [None]:
# Plot results of registration and ROI extraction

# Get data from the OptiNiSt data classes returned by suite2p_roi
max_proj_img = ret_roi['max_proj'].data  # max projection image (labelled as such below)
roi_data = ret_roi['cell_roi'].data      # cell ROI image
F = ret_roi['fluorescence'].data         # fluorescence traces; rows are indexed by ROI below
iscell = ret_roi['iscell'].data          # per-ROI cell classification

# Rows of F whose ROI is classified as a cell
cell_indices = np.where(iscell == 1)[0]

# Create subplot figure
fig = make_subplots(rows=2, cols=2,
                    subplot_titles=('Max Projection', 'ROI Masks',
                                    'Mean Activity', 'Individual ROI Traces'),
                    vertical_spacing=0.12,
                    horizontal_spacing=0.1)

# 1. Max projection image
fig.add_trace(
    go.Heatmap(z=max_proj_img,
               colorscale='gray',
               showscale=False,
               name='Max Projection'),
    row=1, col=1,
)

# 2. ROI masks
fig.add_trace(
    go.Heatmap(z=roi_data,
               colorscale='viridis',
               showscale=True,
               name='ROI Masks',
               colorbar=dict(title='ROI #',
                             len=0.4,
                             y=0.8)),
    row=1, col=2,
)

# 3. Mean activity across classified cells (the same cells shown in panel 4).
#    Fall back to all ROIs when nothing is classified as a cell, so the mean is
#    never taken over an empty selection (which would yield NaNs and a warning).
traces_for_mean = F[cell_indices] if len(cell_indices) > 0 else F
mean_activity = np.mean(traces_for_mean, axis=0)
time_points = np.arange(len(mean_activity))

fig.add_trace(
    go.Scatter(x=time_points,
               y=mean_activity,
               mode='lines',
               name='Mean Activity',
               showlegend=False),
    row=2, col=1,
)

# 4. Individual traces of every classified cell, cycling through the Set3
#    palette when there are more cells than colors.
palette = px.colors.qualitative.Set3
for i, cell_idx in enumerate(cell_indices):
    fig.add_trace(
        go.Scatter(x=time_points,
                   y=F[cell_idx, :],
                   mode='lines',
                   name=f'ROI {cell_idx+1}',
                   line=dict(color=palette[i % len(palette)],
                             width=1),        # thin lines
                   opacity=0.5,               # transparency so overlapping traces stay readable
                   legendgroup='roi_traces',
                   legendgrouptitle_text='ROI Traces',
                   showlegend=(i < 10)),      # only the first 10 traces get legend entries
        row=2, col=2,
    )

# Update layout
fig.update_layout(
    height=800,
    width=1000,
    title=dict(
        text="Suite2P Analysis Results",
        x=0.5,
        y=0.95
    ),
    showlegend=True,
    legend=dict(
        x=1.0,
        y=0.1,
        traceorder='grouped',
        tracegroupgap=5
    )
)

# Update axes labels. The two image panels get a reversed y-axis so row 0 is
# drawn at the top, as in the source image (Plotly heatmaps otherwise place
# row 0 at the bottom, showing the image upside-down).
fig.update_xaxes(title_text="X Position", row=1, col=1)
fig.update_yaxes(title_text="Y Position", autorange="reversed", row=1, col=1)
fig.update_xaxes(title_text="X Position", row=1, col=2)
fig.update_yaxes(title_text="Y Position", autorange="reversed", row=1, col=2)
fig.update_xaxes(title_text="Time Points", row=2, col=1)
fig.update_yaxes(title_text="Fluorescence", row=2, col=1)
fig.update_xaxes(title_text="Time Points", row=2, col=2)
fig.update_yaxes(title_text="Fluorescence", row=2, col=2)

fig.show()

In [None]:
# Load the behavior CSV and print array shapes at each conversion step so
# dimension mismatches with the fluorescence data are visible before ETA.
behavior_file = os.path.join(input_dir, "sample_mouse2p_behavior.csv")

# Raw CSV, read without a header row (header=None)
raw_df = pd.read_csv(behavior_file, header=None)
print("Raw DataFrame shape:", raw_df.shape)

# Plain NumPy array
behavior_array = raw_df.values
print("Numpy array shape:", behavior_array.shape)

# OptiNiSt BehaviorData wrapper
behavior_data = BehaviorData(behavior_array)
print("BehaviorData shape:", behavior_data.data.shape)

# Fluorescence traces from the Suite2p ROI step, wrapped as FluoData for ETA
fluorescence_array = ret_roi['fluorescence'].data
print("Fluorescence data shape:", fluorescence_array.shape)
fluo_data = FluoData(fluorescence_array, file_name="fluorescence")

In [15]:
# Parameters for the event-triggered average (ETA) analysis
eta_params = {
    "transpose_x": True,
    "transpose_y": False,
    # Column of the behavior CSV that holds the event signal; adjust for your file
    "event_col_index": 1,
    "trigger_type": "up",
    "trigger_threshold": 0,
    # Window bounds relative to each event
    "pre_event": -10,
    "post_event": 10,
}

In [16]:
# Output directory for ETA: <OUTPUT_DIR>/1/<unique_id>/eta_<unique_id>
eta_function_id = f"eta_{unique_id}"
eta_output_dir = os.path.join(
    DIRPATH.OUTPUT_DIR,
    "1",
    unique_id,
    eta_function_id,
)
os.makedirs(eta_output_dir, exist_ok=True)

In [None]:
# Event-triggered average of the fluorescence traces around behavior events;
# the Suite2p cell classification is passed through as the 4th argument.
iscell_data = ret_roi['iscell']
eta_results = ETA(
    fluo_data, behavior_data, eta_output_dir, iscell_data, eta_params
)

In [None]:
# Create ETA results plot

# ROI (row of the ETA mean array) shown in the top panel
ROI_to_plot = 0

# 1. Data for the mean trace and its error band
mean_data = eta_results['mean'].data
# NOTE(review): confirm whether ETA stores the standard deviation or the SEM in `.std`
std_data = eta_results['mean'].std
# Plain list, so `time_points + time_points[::-1]` below concatenates instead of
# adding element-wise (which it would do if `.index` were a NumPy array).
time_points = list(eta_results['mean'].index)

fig = make_subplots(rows=2, cols=1,
                    subplot_titles=(f'Event-Triggered Average (ROI {ROI_to_plot})',
                                    'Event-Triggered Average Heatmap'),
                    vertical_spacing=0.2)

# Plot mean line
fig.add_trace(
    go.Scatter(
        x=time_points,
        y=mean_data[ROI_to_plot],
        mode='lines',
        name=f'ROI: {ROI_to_plot}',
        line=dict(color='rgb(31, 119, 180)'),
        showlegend=True,
    ),
    row=1, col=1
)

# Add error band (mean ± std) for the SAME ROI as the mean line: trace the
# upper edge forward and the lower edge backward, then fill the closed shape.
roi_mean = np.asarray(mean_data[ROI_to_plot])
roi_std = np.asarray(std_data[ROI_to_plot])
upper = list(roi_mean + roi_std)
lower = list(roi_mean - roi_std)
fig.add_trace(
    go.Scatter(
        x=time_points + time_points[::-1],
        y=upper + lower[::-1],
        fill='toself',
        fillcolor='rgba(31, 119, 180, 0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        showlegend=False,
        name='Standard Deviation',
    ),
    row=1, col=1
)

# 2. Plot heatmap of the (normalized) mean response of every ROI
heatmap_data = eta_results['mean_heatmap'].data
heatmap_cols = eta_results['mean_heatmap'].columns

fig.add_trace(
    go.Heatmap(
        z=heatmap_data,
        x=heatmap_cols,
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(
            title='Normalized<br>Response',
            len=0.4,
            y=0.2,
        ),
    ),
    row=2, col=1
)

# Update layout
fig.update_layout(
    height=800,
    width=1000,
    title_text="Event-Triggered Average Analysis",
    showlegend=True
)

# Update axes labels
fig.update_xaxes(title_text="Time relative to event (frames)", row=1, col=1)
fig.update_yaxes(title_text="Fluorescence", row=1, col=1)
fig.update_xaxes(title_text="Time relative to event (frames)", row=2, col=1)
fig.update_yaxes(title_text="ROI #", row=2, col=1)

fig.show()