In [None]:
# =============================================================================
# IMPORTS AND SETUP
# =============================================================================
"""
Time-Series Analysis for Astrocyte Calcium Signaling Data

This notebook visualizes normalized calcium event features across three 
experimental conditions (Baseline, PSI/Drug, Washout) binned by frame groups.
It complements analysis.ipynb by providing temporal resolution of the data.

Prerequisites:
    - Run analysis.ipynb first to generate normalized CSV files
    - Normalized data should exist in Output__/<group>/ directories
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Capture the current working directory (this only reads it; it does not
# change it).
cwd = os.getcwd()

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def sort_and_group_dataframe_by_frame(df, group_size, min_frame, max_frame):
    """
    Bin calcium events by their starting frame into discrete time groups.

    Events are sorted by 'Starting Frame' and each one is labelled with the
    frame window it falls into, so that features can later be summarised per
    time bin.

    The input DataFrame is never modified: on success a sorted copy carrying
    the extra 'Frame Group' column is returned.

    Args:
        df: DataFrame containing calcium event data with a 'Starting Frame'
            column (values must be convertible by ``pd.to_numeric``).
        group_size: Number of frames per bin (e.g., 10 = bins 20-30, 30-40, ...).
            Must be a positive integer; NumPy integers are accepted as well.
        min_frame: Lower edge of the first bin (integer).
        max_frame: Upper limit for the bins (integer). Only complete bins are
            created, so the last edge is the largest
            ``min_frame + k * group_size`` that does not exceed ``max_frame``.

    Returns:
        On success: a new DataFrame sorted by 'Starting Frame' (index reset)
        with a categorical 'Frame Group' column holding labels such as
        "20-30". Events outside the binned range get NaN in 'Frame Group'.

        If ``group_size`` is invalid or the 'Starting Frame' column is missing,
        an error message is printed and the original ``df`` object is returned
        unchanged (without a 'Frame Group' column).

        If ``df`` has no rows, a warning is printed and a copy with a
        'Frame Group' column filled with 'N/A' is returned.

    Notes:
        Bins are right-closed, and the first bin also includes its left edge:
        with edges 20, 30, 40 a frame of exactly 30 is labelled "20-30".
        If fewer than one complete bin fits between ``min_frame`` and
        ``max_frame``, a single bin ``[min_frame, min_frame + group_size]``
        is used.

    Example:
        With group_size=10, min_frame=20, max_frame=100:
        Creates bins: "20-30", "30-40", "40-50", ..., "90-100"
    """
    # Validate inputs (np.int64 & co. are valid bin widths, not only built-in int)
    if not isinstance(group_size, (int, np.integer)) or group_size <= 0:
        print("Error: 'group_size' must be a positive integer.")
        return df

    if 'Starting Frame' not in df.columns:
        print("Error: 'Starting Frame' column not found.")
        return df

    # Normalise to a built-in int so arithmetic and labels behave uniformly
    group_size = int(group_size)

    # Build a sorted copy instead of editing the caller's DataFrame in place
    grouped = df.assign(**{'Starting Frame': pd.to_numeric(df['Starting Frame'])})
    grouped = grouped.sort_values(by='Starting Frame').reset_index(drop=True)

    # Handle empty DataFrame
    if grouped.empty:
        print("Warning: DataFrame is empty.")
        grouped['Frame Group'] = 'N/A'
        return grouped

    # Calculate bin edges (only complete groups)
    num_complete_groups = (max_frame - min_frame) // group_size
    final_edge = min_frame + (num_complete_groups * group_size)

    bin_edges = list(range(min_frame, final_edge + 1, group_size))

    # Ensure at least one valid bin
    if len(bin_edges) < 2:
        bin_edges = [min_frame, min_frame + group_size]

    # Create labels like "20-30", "30-40", etc.
    labels = [f"{lo}-{hi}" for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]

    # Apply binning; values outside the outer edges become NaN
    grouped['Frame Group'] = pd.cut(
        grouped['Starting Frame'],
        bins=bin_edges,
        labels=labels,
        right=True,
        include_lowest=True,
        duplicates='drop'
    )

    return grouped

In [None]:
# =============================================================================
# PLOTTING FUNCTIONS
# =============================================================================

# One subplot is drawn per entry. Each string must match a column name in the
# normalized CSVs exactly, so the spelling is kept verbatim (including
# "averge" - presumably that mirrors the CSV headers; verify before "fixing").
FEATURES_TO_PLOT = (
    # Morphological features
    [
        "Basic - Area",
        "Basic - Perimeter (only for 2D video)",
        "Basic - Circularity",
    ]
    # Calcium curve intensity metrics
    + [
        "Curve - Max Df",
        "Curve - Max Dff",
        "Curve - dat AUC",
        "Curve - df AUC",
        "Curve - dff AUC",
    ]
    # Temporal dynamics
    + [
        "Curve - Duration of visualized event overlay",
        "Curve - Duration 50% to 50% based on averge dF/F",
        "Curve - Duration 10% to 10% based on averge dF/F",
        "Curve - Rising duration 10% to 90% based on averge dF/F",
        "Curve - Decaying duration 90% to 10% based on averge dF/F",
    ]
    # Network/spatial features
    + [
        "Network - number of events in the same location",
        "Network - number of events in the same location with similar size only",
        "Network - maximum number of events appearing at the same time",
    ]
)

# Fixed colour per experimental condition, keyed by condition name
CONDITION_COLORS = dict(
    Baseline='#1f77b4',  # Blue
    PSI='#2ca02c',       # Green
    Washout='#d62728',   # Red
)


def get_sorted_frame_groups(df):
    """
    Return the distinct frame-group labels of ``df`` in numeric order.

    Args:
        df: DataFrame that may carry a 'Frame Group' column with labels
            of the form "<start>-<end>" (e.g., "20-30")

    Returns:
        List of unique, non-missing labels ordered by the integer before the
        first '-' in each label. An empty list is returned when the column
        is absent or holds no non-missing values.
    """
    if 'Frame Group' not in df.columns:
        return []

    present = df['Frame Group'].dropna()
    if present.empty:
        return []

    def _range_start(label):
        # "20-30" -> 20: order by the numeric start of the range
        return int(str(label).partition('-')[0])

    ordered = list(present.unique())
    ordered.sort(key=_range_start)
    return ordered


def _collect_group_medians(df, feature):
    """
    Return ``[(frame_group_label, median), ...]`` for one feature, in time order.

    Frame groups whose median is NaN are left out. An empty list is returned
    when the feature column is missing or the frame has no frame groups.
    """
    if feature not in df.columns:
        return []

    group_medians = []
    for frame_group in get_sorted_frame_groups(df):
        median_val = df.loc[df['Frame Group'] == frame_group, feature].median()
        if pd.notna(median_val):
            group_medians.append((frame_group, median_val))
    return group_medians


def _draw_condition_medians(ax, x_positions, medians, color, label):
    """Draw one condition's median points, their mean line and the mean ± std band."""
    ax.plot(x_positions, medians, 'o', color=color,
            alpha=0.8, markersize=6, label=label)

    # Pad the span by 0.3 on both sides so a single point still gets a visible line/band
    x_left = x_positions[0] - 0.3
    x_right = x_positions[-1] + 0.3

    # Overall average of the per-bin medians
    avg = np.mean(medians)
    ax.hlines(y=avg, xmin=x_left, xmax=x_right,
              colors=color, linestyles='--', linewidth=2, alpha=0.9)

    # Shaded variation region (mean ± std of the medians; np.std default, ddof=0)
    std = np.std(medians)
    ax.fill_between([x_left, x_right], avg - std, avg + std,
                    color=color, alpha=0.1)


def plot_median(baseline, psi, washout, group_size, min_frame, max_frame, title=None):
    """
    Create time-series plots showing median feature values across frame groups.

    Generates a multi-panel figure (4 columns) with one subplot per entry of
    FEATURES_TO_PLOT. Within a subplot the three conditions are laid out left
    to right (Baseline, PSI, Washout) and each shows:
    - One point per frame group: the median of the feature in that group
    - A horizontal dashed line at the mean of those medians
    - A shaded band covering mean ± std of those medians
    - A vertical dashed separator after each condition except the last

    Conditions that contribute no points to a subplot (feature column missing,
    no frame groups, or only NaN medians) are skipped entirely, so they add
    neither a separator nor an empty gap. Subplots without any data show a
    "No data available" note and no legend.

    The input DataFrames are copied before binning and are not modified.
    The figure is displayed with ``plt.show()``; nothing is returned.

    Args:
        baseline: DataFrame with baseline condition data
        psi: DataFrame with PSI/drug condition data
        washout: DataFrame with washout condition data
        group_size: Number of frames per time bin
        min_frame: Starting frame for analysis
        max_frame: Ending frame for analysis
        title: Optional figure title (drawn as the figure's suptitle)
    """
    # Bin events by frame group for each condition, keeping display order.
    # Copies are passed so the caller's frames stay untouched by the binning step.
    conditions = [
        (cond_name,
         sort_and_group_dataframe_by_frame(frame.copy(), group_size, min_frame, max_frame),
         CONDITION_COLORS[cond_name])
        for cond_name, frame in (("Baseline", baseline), ("PSI", psi), ("Washout", washout))
    ]
    last_cond_idx = len(conditions) - 1

    # Setup figure grid
    num_plots = len(FEATURES_TO_PLOT)
    ncols = 4
    nrows = (num_plots + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(22, nrows * 5), squeeze=False)
    axes = axes.flatten()

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    # Plot each feature (the grid always has at least num_plots axes)
    for ax, feature in zip(axes, FEATURES_TO_PLOT):
        # Track x-axis positions across all conditions
        all_xtick_labels = []
        all_xtick_positions = []
        current_x = 1

        # Plot each condition sequentially
        for cond_idx, (cond_name, df, color) in enumerate(conditions):
            group_medians = _collect_group_medians(df, feature)
            if not group_medians:
                # Nothing to draw: no points, no separator, no gap
                continue

            x_positions = list(range(current_x, current_x + len(group_medians)))
            medians = [median_val for _, median_val in group_medians]

            all_xtick_labels.extend(label for label, _ in group_medians)
            all_xtick_positions.extend(x_positions)
            current_x += len(group_medians)

            _draw_condition_medians(ax, x_positions, medians, color, cond_name)

            # Add separator line between conditions (leaves one empty slot)
            if cond_idx < last_cond_idx:
                ax.axvline(x=current_x - 0.5, color='grey', linestyle='--', linewidth=1)
                current_x += 1

        # Configure axes
        ax.set_title(feature, fontsize=11, wrap=True)
        ax.set_ylabel("Median Value", fontsize=9)
        ax.grid(True, linestyle='--', alpha=0.6)

        if all_xtick_positions:
            ax.set_xticks(all_xtick_positions)
            ax.set_xticklabels(all_xtick_labels, rotation=90, ha="center", fontsize=8)
            # Only build a legend when something labelled was drawn; calling
            # legend() on an empty axes makes matplotlib emit a warning.
            ax.legend(loc='best', fontsize=8)
        else:
            ax.text(0.5, 0.5, "No data available", ha='center', va='center',
                    transform=ax.transAxes, color='gray')

        ax.margins(x=0.02)

    # Remove unused subplots
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()

In [None]:
# =============================================================================
# DATA LOADING
# =============================================================================
"""
Load normalized data for each experimental group.
These files are generated by running analysis.ipynb first.

Groups:
    - WT: Wild Type (control)
    - AV: Antagonist Volinanserin  
    - IP: IP3R2 cKO (calcium signaling knockout)
    - CE: CalEx (calcium exchanger)
"""

# Root folder that holds one sub-folder of normalized CSVs per experimental group
OUTPUT_DIR = 'Output__'


def _load_condition_frames(folder, prefix):
    """
    Read the three normalized CSVs of one experimental group.

    Args:
        folder: Sub-folder of OUTPUT_DIR holding the group's files
        prefix: File-name prefix of the group (e.g., "WT")

    Returns:
        Tuple ``(baseline_df, drug_df, washout_df)`` read from
        ``<OUTPUT_DIR>/<folder>/<prefix>_<condition>_normalized.csv``.

    Raises:
        Whatever ``pd.read_csv`` raises when a file is missing or unreadable.
    """
    return tuple(
        pd.read_csv(f'{OUTPUT_DIR}/{folder}/{prefix}_{condition}_normalized.csv')
        for condition in ('baseline', 'drug', 'washout')
    )


# --- Wild Type (Control) ---
wt_base_df, wt_psi_df, wt_wash_df = _load_condition_frames('WT', 'WT')

# --- Antagonist Volinanserin ---
av_base_df, av_psi_df, av_wash_df = _load_condition_frames('Antagonist- Volinanserin', 'AV')

# --- IP3R2 cKO ---
ip_base_df, ip_psi_df, ip_wash_df = _load_condition_frames('IP3R2 cKO', 'IP')

# --- CalEx ---
# Note: CE data files must be generated first by running analysis.ipynb.
# Left disabled until those files exist; uncomment the line below to load them.
# ce_base_df, ce_psi_df, ce_wash_df = _load_condition_frames('CalEx', 'CE')

In [None]:
# =============================================================================
# ANALYSIS PARAMETERS
# =============================================================================

# Frame window used for every condition - do not change.
MIN_FRAME = 20       # Start frame (excludes early recording artifacts)
MAX_FRAME = 100      # End frame (excludes late recording artifacts)

# Width of each time bin, in frames
GROUP_SIZE = 10

In [None]:
# =============================================================================
# GENERATE PLOTS
# =============================================================================
"""
Generate time-series plots for each experimental group.
Each plot shows all 16 features across Baseline -> PSI -> Washout conditions.
"""

# (figure title, baseline, PSI/drug, washout) for each group, in display order
groups_to_plot = [
    # Wild Type (Control)
    ("Wild Type (WT)", wt_base_df, wt_psi_df, wt_wash_df),
    # Antagonist Volinanserin
    ("Antagonist Volinanserin (AV)", av_base_df, av_psi_df, av_wash_df),
    # IP3R2 cKO
    ("IP3R2 cKO (IP)", ip_base_df, ip_psi_df, ip_wash_df),
    # CalEx (uncomment when data is available)
    # ("CalEx (CE)", ce_base_df, ce_psi_df, ce_wash_df),
]

# One figure per group, all with the same binning parameters
for group_title, group_base, group_psi, group_wash in groups_to_plot:
    plot_median(group_base, group_psi, group_wash, GROUP_SIZE, MIN_FRAME, MAX_FRAME,
                title=group_title)