In [None]:
%load_ext autoreload
%autoreload 2
# %flow mode reactive

from typing import Any, Tuple, List, Dict
import warnings
import datetime

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import statsmodels.api as sm

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams
from aeon.analysis.block_plotting import gen_hex_grad, conv2d

# Load data

In [None]:
def _fetch_period_patch_data(key: Dict[str, str], period_start: Any, period_end: Any) -> pd.DataFrame:
    """
    Fetch BlockSubjectAnalysis.Patch rows whose ``block_start`` lies in [period_start, period_end].

    The ``experiment_name`` index level is dropped (``key`` already pins the experiment)
    and the remaining index levels are turned into ordinary columns.
    """
    period_query = (
        BlockSubjectAnalysis.Patch()
        & key
        & f'block_start >= "{period_start}"'
        & f'block_start <= "{period_end}"'
    )
    return (
        period_query.fetch(format="frame")
        .reset_index(level=["experiment_name"], drop=True)
        .reset_index()
    )


def _first_block_value(key: Dict[str, str], start: Any, attribute: str) -> Any:
    """
    Return ``attribute`` of the earliest Block (ordered by block_start) starting at or after ``start``.

    Returns None when no such block exists (or when the fetched value itself is NULL).
    """
    # Order explicitly: `limit=1` alone does not promise which matching row comes back.
    values = (Block & key & f"block_start >= '{start}'").fetch(attribute, order_by="block_start", limit=1)
    return values[0] if len(values) > 0 else None


def load_experiment_data(
    key: Dict[str, str], 
    pre_social_start: Any, 
    pre_social_end: Any, 
    social_start: Any, 
    social_end: Any, 
    post_social_start: Any, 
    post_social_end: Any
) -> Tuple[List[Dict[str, Any]], pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Loads experiment data for given time periods.

    All period boundaries may be strings or datetimes; they are interpolated into
    DataJoint restriction strings.

    Args:
        key (dict): The key to filter the experiment data (e.g. {"experiment_name": ...}).
        pre_social_start (str | datetime): The start time for the pre-social period.
        pre_social_end (str | datetime): The end time for the pre-social period.
        social_start (str | datetime): The start time for the social period.
        social_end (str | datetime): The end time for the social period.
        post_social_start (str | datetime): The start time for the post-social period.
        post_social_end (str | datetime): The end time for the post-social period.

    Returns:
        tuple: A tuple containing:
            - patch_info (list of dict): block_start, patch_name, patch_rate and patch_offset
              for every BlockAnalysis.Patch row between the first block at/after
              ``pre_social_start`` and the end of the first block at/after ``post_social_end``.
            - block_subject_patch_data_pre_social (pd.DataFrame): Data for the pre-social period.
            - block_subject_patch_data_social (pd.DataFrame): Data for the social period.
            - block_subject_patch_data_post_social (pd.DataFrame): Data for the post-social period.

    Raises:
        ValueError: If no block starts at or after ``pre_social_start``.
    """
    exp_start = _first_block_value(key, pre_social_start, "block_start")
    if exp_start is None:
        raise ValueError(f"No blocks found for {key} starting at or after {pre_social_start}.")

    exp_end = _first_block_value(key, post_social_end, "block_end")
    if exp_end is None:
        # No (finished) block begins after the post-social period; the period end itself
        # still bounds every block that starts inside the analysed periods.
        exp_end = post_social_end

    patch_info = (
        BlockAnalysis.Patch
        & key
        & f'block_start >= "{exp_start}"'
        & f'block_start <= "{exp_end}"'
    ).fetch('block_start', "patch_name", "patch_rate", "patch_offset", as_dict=True)

    return (
        patch_info,
        _fetch_period_patch_data(key, pre_social_start, pre_social_end),
        _fetch_period_patch_data(key, social_start, social_end),
        _fetch_period_patch_data(key, post_social_start, post_social_end),
    )

In [None]:
experiments = [
    {"name": "social0.2-aeon3", "pre_social_start": '2024-01-31 10:00:00', "pre_social_end": '2024-02-08 15:00:00', "social_start": '2024-02-09 16:00:00', "social_end": '2024-02-23 13:00:00', "post_social_start": '2024-02-25 16:00:00', "post_social_end": '2024-03-02 14:00:00'},
    {"name": "social0.2-aeon4", "pre_social_start": '2024-01-31 09:00:00', "pre_social_end": '2024-02-08 15:00:00', "social_start": '2024-02-09 16:00:00', "social_end": '2024-02-23 12:00:00', "post_social_start": '2024-02-25 16:00:00', "post_social_end": '2024-03-02 13:00:00'},
    {"name": "social0.3-aeon3", "pre_social_start": '2024-06-08 18:00:00', "pre_social_end": '2024-06-17 13:00:00', "social_start": '2024-06-25 10:00:00', "social_end": '2024-07-06 13:00:00', "post_social_start": '2024-07-07 15:00:00', "post_social_end": '2024-07-14 14:00:00'},
    {"name": "social0.3-aeon4", "pre_social_start": '2024-06-08 18:00:00', "pre_social_end": '2024-06-17 14:00:00', "social_start": '2024-06-19 11:00:00', "social_end": '2024-07-03 14:00:00', "post_social_start": '2024-07-04 10:00:00', "post_social_end": '2024-07-13 12:00:00'},
    {"name": "social0.4-aeon3", "pre_social_start": '2024-08-16 16:00:00', "pre_social_end": '2024-08-24 10:00:00', "social_start": '2024-08-28 10:00:00', "social_end": '2024-09-09 13:00:00', "post_social_start": '2024-09-09 17:00:00', "post_social_end": '2024-09-22 16:00:00'},
    {"name": "social0.4-aeon4", "pre_social_start": '2024-08-16 14:00:00', "pre_social_end": '2024-08-24 10:00:00', "social_start": '2024-08-28 09:00:00', "social_end": '2024-09-09 01:00:00', "post_social_start": '2024-09-09 14:00:00', "post_social_end": '2024-09-22 16:00:00'}
]

# The top-level `import datetime` binds the *module*, which has no `strptime`; calling
# `datetime.strptime` only works if the star import above happens to rebind `datetime`
# to the class. Import the class under a private alias so parsing does not depend on that.
from datetime import datetime as _datetime_cls

TIME_FORMAT = '%Y-%m-%d %H:%M:%S'
PERIOD_FIELDS = (
    "pre_social_start", "pre_social_end",
    "social_start", "social_end",
    "post_social_start", "post_social_end",
)

patch_info_dict = {}
block_subject_patch_data_social_first_half_dict = {}
block_subject_patch_data_social_dict = {}
block_subject_patch_data_post_social_dict = {}

for exp in experiments:
    key = {"experiment_name": exp["name"]}
    (pre_social_start, pre_social_end,
     social_start, social_end,
     post_social_start, post_social_end) = (
        _datetime_cls.strptime(exp[field], TIME_FORMAT) for field in PERIOD_FIELDS
    )

    patch_info, _, block_subject_patch_data_social, block_subject_patch_data_post_social = load_experiment_data(
        key, pre_social_start, pre_social_end, social_start, social_end, post_social_start, post_social_end
    )

    # Drop rows where patch_name contains 'PatchDummy'
    block_subject_patch_data_social = block_subject_patch_data_social[~block_subject_patch_data_social['patch_name'].str.contains('PatchDummy')]
    block_subject_patch_data_post_social = block_subject_patch_data_post_social[~block_subject_patch_data_post_social['patch_name'].str.contains('PatchDummy')]

    # Check post-social data only has one subject being tracked per block
    max_num_subjects = block_subject_patch_data_post_social.groupby('block_start')['subject_name'].nunique().max()
    if max_num_subjects > 1:
        warnings.warn(f"Post social data for {exp['name']} has more than one subject being tracked per block. Data needs to be fixed or cleaned.")

    patch_info_dict[exp["name"]] = patch_info
    # Blocks starting before the midpoint of the social period form the "first half".
    social_midpoint = social_start + (social_end - social_start) / 2
    block_subject_patch_data_social_first_half_dict[exp["name"]] = block_subject_patch_data_social[block_subject_patch_data_social['block_start'] < social_midpoint]
    block_subject_patch_data_social_dict[exp["name"]] = block_subject_patch_data_social
    block_subject_patch_data_post_social_dict[exp["name"]] = block_subject_patch_data_post_social

# Combine loaded data (fresh RangeIndex: per-experiment indices would otherwise collide)
block_subject_patch_data_social_first_half_combined = pd.concat(block_subject_patch_data_social_first_half_dict.values(), ignore_index=True) # TODO: take even less? Match the num of days in post social?
block_subject_patch_data_social_combined = pd.concat(block_subject_patch_data_social_dict.values(), ignore_index=True)
block_subject_patch_data_post_social_combined = pd.concat(block_subject_patch_data_post_social_dict.values(), ignore_index=True)

# Create block plots

### 1. Wheel distance spun per block, averaged by the number of mice

In [None]:
# Plot: total wheel distance spun per block, averaged over the subjects tracked in that block.


def _final_wheel_distance(cumsum: Any) -> Any:
    """Return the last element of a cumulative wheel-distance array, or 0 if it is missing or empty."""
    if isinstance(cumsum, np.ndarray) and len(cumsum) > 0:
        return cumsum[-1]
    return 0


def _wheel_distance_per_subject(block_subject_data: pd.DataFrame) -> pd.DataFrame:
    """
    Sum ``final_wheel_cumsum`` over every (subject, patch) row of each block and divide by
    the number of distinct subjects in that block.

    Returns a frame with ``block_start`` and ``final_wheel_cumsum`` columns.
    """
    # NOTE(review): the combined frames carry no experiment_name column, so blocks from
    # different experiments sharing a block_start would be merged here — confirm this
    # cannot happen, or carry experiment_name through the loading step.
    per_block = block_subject_data.groupby('block_start').agg(
        total_distance=('final_wheel_cumsum', 'sum'),
        n_subjects=('subject_name', 'nunique'),
    )
    # Divide by the actual subject count instead of a hard-coded 2 (social) / 1 (post-social).
    return (
        (per_block['total_distance'] / per_block['n_subjects'])
        .rename('final_wheel_cumsum')
        .reset_index()
    )


for block_subject_data in (block_subject_patch_data_social_combined, block_subject_patch_data_post_social_combined):
    block_subject_data['final_wheel_cumsum'] = (
        block_subject_data['wheel_cumsum_distance_travelled'].apply(_final_wheel_distance)
    )

wheel_total_dist_averaged_social = _wheel_distance_per_subject(block_subject_patch_data_social_combined)
wheel_total_dist_averaged_post_social = _wheel_distance_per_subject(block_subject_patch_data_post_social_combined)

wheel_total_dist_averaged_social['condition'] = 'social'
wheel_total_dist_averaged_post_social['condition'] = 'post_social'
wheel_total_dist_averaged = pd.concat([wheel_total_dist_averaged_social, wheel_total_dist_averaged_post_social])

fig = px.box(
    wheel_total_dist_averaged,
    x="condition",
    y="final_wheel_cumsum",
    points="all",
    title="Wheel Distance Spun Per Block Averaged By Number Of Subjects",
    labels={"final_wheel_cumsum": "Wheel Distance Spun Per Block (cm)"},
)
fig.show()

### 2. Probability of being in the hard patch as a function of pellet number within a block

In [None]:
import pandas as pd
import numpy as np

def compute_patch_probabilities(
    df: pd.DataFrame
) -> pd.DataFrame:
    """
    Compute, per inter-pellet interval, the fraction of in-patch samples spent in each patch.

    For every (block_start, subject_name) group, the subject's pellet timestamps (pooled
    over all of its patch rows) split the block into consecutive intervals
    [pellet_k, pellet_{k+1}]. For each patch, the ``in_patch_timestamps`` falling inside an
    interval are counted (both ends inclusive, so a timestamp exactly on a pellet time is
    counted in both neighbouring intervals) and divided by the total over all patches.

    Args:
        df (pd.DataFrame): One row per (block_start, subject_name, patch_name) with
            array-like ``pellet_timestamps`` and ``in_patch_timestamps`` columns.

    Returns:
        pd.DataFrame: Columns ``block_start``, ``subject_name``, integer ``pellet_number``
        (1-based index of the interval's opening pellet) and one ``prob_in_<patch>`` column
        per patch. Intervals with no in-patch samples get probability 0 for every patch.
        Groups with fewer than two distinct pellets contribute no rows; if no group
        qualifies, an empty frame with the three key columns is returned.

    Raises:
        ValueError: If a (block_start, subject_name, patch_name) combination has more than one row.
    """
    key_columns = ['block_start', 'subject_name', 'pellet_number']
    group_frames = []

    for (block_start, subject_name), block_data in df.groupby(['block_start', 'subject_name']):
        # np.unique already returns the distinct values sorted.
        pellet_timestamps = np.unique([ts for sublist in block_data['pellet_timestamps'] for ts in sublist])
        if len(pellet_timestamps) < 2:
            continue
        interval_starts = pellet_timestamps[:-1]
        interval_ends = pellet_timestamps[1:]

        # Per patch: number of in-patch samples with interval_start <= ts <= interval_end.
        counts = {}
        for patch in block_data['patch_name'].unique():
            patch_data = block_data[block_data['patch_name'] == patch]
            if patch_data.shape[0] != 1:
                raise ValueError("More than one row per block start, subject, patch combination.")
            in_patch_ts = np.sort(patch_data.iloc[0]['in_patch_timestamps'])
            if len(in_patch_ts) > 0:
                counts[patch] = (
                    np.searchsorted(in_patch_ts, interval_ends, side='right')
                    - np.searchsorted(in_patch_ts, interval_starts, side='left')
                )
            else:
                counts[patch] = np.zeros(len(interval_starts), dtype=int)

        count_matrix = pd.DataFrame(counts)
        totals = count_matrix.sum(axis=1)
        # Mask zero totals as NaN to avoid 0/0, then report probability 0 for those intervals.
        probabilities = (
            count_matrix.div(totals.where(totals > 0), axis=0)
            .fillna(0.0)
            .add_prefix('prob_in_')
        )
        probabilities.insert(0, 'pellet_number', np.arange(1, len(pellet_timestamps)))
        probabilities.insert(0, 'subject_name', subject_name)
        probabilities.insert(0, 'block_start', block_start)
        group_frames.append(probabilities)

    if not group_frames:
        return pd.DataFrame(columns=key_columns)
    return pd.concat(group_frames, ignore_index=True)

def extract_hard_patch_probabilities(
    prob_per_patch: pd.DataFrame, 
    patch_info: List[Dict[str, Any]], 
    patch_rate: float = 0.002
) -> pd.DataFrame:
    """
    Compute the probabilities for hard patches where the patch rate matches a specified value.

    For every entry of ``patch_info`` whose ``patch_rate`` matches ``patch_rate`` (within a
    small relative tolerance), the rows of ``prob_per_patch`` for that block are taken and
    their ``prob_in_<patch_name>`` column is renamed to ``prob_in_hard_patch``.

    Args:
        prob_per_patch (pd.DataFrame): DataFrame containing probabilities per patch
            (columns block_start, subject_name, pellet_number, prob_in_<patch>...).
        patch_info (List[Dict[str, Any]]): List of dictionaries with block_start,
            patch_name and patch_rate keys.
        patch_rate (float): The patch rate to filter by. Default is 0.002.

    Returns:
        pd.DataFrame: Columns block_start, subject_name, pellet_number and
        prob_in_hard_patch. Empty (with those columns) when nothing matches.
    """
    output_columns = ['block_start', 'subject_name', 'pellet_number', 'prob_in_hard_patch']

    # Tolerant comparison: exact `==` fails for rates that went through a lower-precision
    # representation (e.g. float32) or arrive as Decimal.
    hard_patches = [
        patch_dict for patch_dict in patch_info
        if patch_dict['patch_rate'] is not None
        and np.isclose(float(patch_dict['patch_rate']), patch_rate, rtol=1e-6, atol=0.0)
    ]

    results = []
    if 'block_start' in prob_per_patch.columns:
        for hard_patch in hard_patches:
            in_block = prob_per_patch['block_start'] == hard_patch['block_start']
            # patch_info may list blocks outside this period; skip them instead of indexing
            # a probability column that may not exist for this period.
            if not in_block.any():
                continue
            prob_column = f"prob_in_{hard_patch['patch_name']}"
            hard_patch_data = prob_per_patch.loc[
                in_block, ['block_start', 'subject_name', 'pellet_number', prob_column]
            ].rename(columns={prob_column: 'prob_in_hard_patch'})
            results.append(hard_patch_data)

    if not results:
        return pd.DataFrame(columns=output_columns)
    return pd.concat(results, ignore_index=True)


In [None]:
def _mean_hard_patch_prob_by_pellet(prob_hard_patch: pd.DataFrame) -> pd.DataFrame:
    """
    Average ``prob_in_hard_patch`` over all rows sharing a ``pellet_number``.

    Only the probability column is aggregated: the input also carries ``subject_name``
    (strings), and ``groupby(...).mean()`` over the whole frame raises TypeError on
    non-numeric columns in pandas >= 2.0.

    Returns a frame with ``pellet_number`` and ``prob_in_hard_patch`` columns.
    """
    return prob_hard_patch.groupby('pellet_number', as_index=False)['prob_in_hard_patch'].mean()


prob_per_patch_social_first_half_dict = {}
prob_per_patch_social_dict = {}
prob_per_patch_post_social_dict = {}
prob_hard_patch_social_first_half_dict = {}
prob_hard_patch_social_dict = {}
prob_hard_patch_post_social_dict = {}
prob_hard_patch_mean_social_first_half_dict = {}
prob_hard_patch_mean_social_dict = {}
prob_hard_patch_mean_post_social_dict = {}

for exp in experiments:
    exp_name = exp["name"]
    block_subject_patch_data_social_first_half = block_subject_patch_data_social_first_half_dict[exp_name]
    block_subject_patch_data_social = block_subject_patch_data_social_dict[exp_name]
    block_subject_patch_data_post_social = block_subject_patch_data_post_social_dict[exp_name]
    patch_info = patch_info_dict[exp_name]

    # Compute patch probabilities
    prob_per_patch_social_first_half_dict[exp_name] = compute_patch_probabilities(block_subject_patch_data_social_first_half)
    prob_per_patch_social_dict[exp_name] = compute_patch_probabilities(block_subject_patch_data_social)
    prob_per_patch_post_social_dict[exp_name] = compute_patch_probabilities(block_subject_patch_data_post_social)

    # Extract hard patch probabilities
    prob_hard_patch_social_first_half_dict[exp_name] = extract_hard_patch_probabilities(prob_per_patch_social_first_half_dict[exp_name], patch_info)
    prob_hard_patch_social_dict[exp_name] = extract_hard_patch_probabilities(prob_per_patch_social_dict[exp_name], patch_info)
    prob_hard_patch_post_social_dict[exp_name] = extract_hard_patch_probabilities(prob_per_patch_post_social_dict[exp_name], patch_info)

    # Calculate the mean hard patch probability per pellet number
    prob_hard_patch_mean_social_first_half_dict[exp_name] = _mean_hard_patch_prob_by_pellet(prob_hard_patch_social_first_half_dict[exp_name])
    prob_hard_patch_mean_social_dict[exp_name] = _mean_hard_patch_prob_by_pellet(prob_hard_patch_social_dict[exp_name])
    prob_hard_patch_mean_post_social_dict[exp_name] = _mean_hard_patch_prob_by_pellet(prob_hard_patch_post_social_dict[exp_name])

# Combine the results (pooled over every subject/block/experiment row, not a mean of per-experiment means)
prob_hard_patch_social_first_half_combined = pd.concat(prob_hard_patch_social_first_half_dict.values(), ignore_index=True)
prob_hard_patch_social_combined = pd.concat(prob_hard_patch_social_dict.values(), ignore_index=True)
prob_hard_patch_post_social_combined = pd.concat(prob_hard_patch_post_social_dict.values(), ignore_index=True)

prob_hard_patch_mean_social_first_half_combined = _mean_hard_patch_prob_by_pellet(prob_hard_patch_social_first_half_combined)
prob_hard_patch_mean_social_combined = _mean_hard_patch_prob_by_pellet(prob_hard_patch_social_combined)
prob_hard_patch_mean_post_social_combined = _mean_hard_patch_prob_by_pellet(prob_hard_patch_post_social_combined)

In [None]:
def analyze_patch_probabilities(
    data: pd.DataFrame, 
    label: str,
    max_pellets: int = 35,
) -> Tuple[Any, np.ndarray]:
    """
    Fit an OLS line of hard-patch probability against pellet number and report the slope.

    Only the first ``max_pellets`` rows (after sorting by ``pellet_number``) are used.
    The slope's p-value and the full statsmodels summary are printed.

    Args:
        data (pd.DataFrame): DataFrame with ``pellet_number`` and ``prob_in_hard_patch`` columns.
        label (str): Name of the condition, used only in the printed output.
        max_pellets (int): Number of leading pellet numbers included in the fit
            (default 35, the value previously hard-coded).

    Returns:
        tuple: A tuple containing:
            - model: The fitted results object returned by ``sm.OLS(...).fit()``.
            - y_pred (np.ndarray): The fitted values for the rows used in the fit.
    """
    subset = data.sort_values('pellet_number').iloc[:max_pellets]
    X = subset['pellet_number'].to_numpy(dtype=float)
    y = subset['prob_in_hard_patch'].to_numpy(dtype=float)
    # Always add the intercept column (the default 'skip' drops it when X is constant),
    # so that pvalues[1] is guaranteed to be the slope.
    X_with_constant = sm.add_constant(X, has_constant='add')
    model = sm.OLS(y, X_with_constant).fit()
    y_pred = model.predict(X_with_constant)
    # Get the p-value for the slope (it's the second value in pvalues)
    p_value = model.pvalues[1]
    print(f"P-value for the {label} slope: {p_value}")
    # Print full statistical summary
    print(f"{label} model summary: ", model.summary())
    return model, y_pred

# Fit each condition in turn (first half, full social, post-social); the dict keeps call order.
_regression_fits = {
    condition_label: analyze_patch_probabilities(condition_means, condition_label)
    for condition_label, condition_means in (
        ("social first half", prob_hard_patch_mean_social_first_half_combined),
        ("social", prob_hard_patch_mean_social_combined),
        ("post-social", prob_hard_patch_mean_post_social_combined),
    )
}
model_social_first_half, y_pred_social_first_half = _regression_fits["social first half"]
model_social, y_pred_social = _regression_fits["social"]
model_post_social, y_pred_post_social = _regression_fits["post-social"]

In [None]:
# Mean hard-patch probability per pellet number for each condition, plus its OLS fit.
# Each regression line is drawn against the x-values it was fitted on: the first
# len(y_pred) pellet numbers of the *same* condition's data.
condition_traces = [
    ('First Half of Social Data', 'Social First Half Linear Regression Line', 'blue',
     prob_hard_patch_mean_social_first_half_combined, y_pred_social_first_half),
    ('Social Data', 'Social Linear Regression Line', 'red',
     prob_hard_patch_mean_social_combined, y_pred_social),
    ('Post Social Data', 'Post Social Linear Regression Line', '#00CC96',
     prob_hard_patch_mean_post_social_combined, y_pred_post_social),
]

fig = go.Figure()

# Data curves first, then regression lines, to keep the legend order.
for data_name, _, color, mean_df, _ in condition_traces:
    fig.add_trace(go.Scatter(
        x=mean_df['pellet_number'],
        y=mean_df['prob_in_hard_patch'],
        mode='lines',
        name=data_name,
        line=dict(color=color),
    ))

for _, fit_name, color, mean_df, y_pred in condition_traces:
    fig.add_trace(go.Scatter(
        x=mean_df['pellet_number'].iloc[:len(y_pred)],
        y=y_pred,
        mode='lines',
        name=fit_name,
        line=dict(color=color, dash='dash'),
    ))

fig.update_layout(
    title='Probability of being in hard patch over time',
    xaxis_title='Pellet number in block',
    yaxis_title='Hard patch probability'
)

fig.show()

In [None]:
# Save the figure as an SVG file
# fig.write_image("hard_patch_probability.svg")

In [None]:
# Per-experiment grid of mean hard-patch probability curves, one panel per experiment.
MAX_PELLETS_SHOWN = 35  # same pellet window as the regression fits

conditions = [
    ('First Half of Social Data', 'blue', prob_hard_patch_mean_social_first_half_dict),
    ('Social Data', 'red', prob_hard_patch_mean_social_dict),
    ('Post Social Data', '#00CC96', prob_hard_patch_mean_post_social_dict),
]

# Define the number of rows and columns for the subplot grid
num_experiments = len(experiments)
num_cols = 2
num_rows = (num_experiments + num_cols - 1) // num_cols  # ceiling division

fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=[exp["name"] for exp in experiments])

for i, exp in enumerate(experiments):
    exp_name = exp["name"]
    row, col = divmod(i, num_cols)
    for condition_name, color, mean_dict in conditions:
        mean_df = mean_dict[exp_name].iloc[:MAX_PELLETS_SHOWN]
        fig.add_trace(
            go.Scatter(
                x=mean_df['pellet_number'],
                y=mean_df['prob_in_hard_patch'],
                mode='lines',
                name=condition_name,
                line=dict(color=color),
                legendgroup=condition_name,  # one legend entry toggles the condition in every panel
                showlegend=(i == 0),
            ),
            row=row + 1, col=col + 1,
        )

fig.update_xaxes(title_text='Pellet number in block')
fig.update_yaxes(title_text='Hard patch probability')
fig.update_layout(
    title='Probability of being in hard patch over time, per experiment',
    height=800,
    width=1000,
)
fig.show()