In [None]:
%load_ext autoreload
%autoreload 2

import datajoint as dj
import aeon
from aeon.io import api
from aeon.schema.schemas import social02
from aeon.dj_pipeline.analysis.block_analysis import * 
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import plotly.subplots as sp
import numpy as np
from plotly.subplots import make_subplots
from datetime import timedelta, datetime
from scipy import stats
from shapely.geometry import Point, Polygon
import random

# 0. Import and process data

Experiment to analyse:

In [None]:
# DataJoint restriction selecting the experiment analysed in this notebook.
key = dict(experiment_name="social0.2-aeon3")

In [None]:
# Arena identifier used to build the annotation and raw-data paths below.
arena = "AEON3"

Load social interaction data:

In [None]:
# TODO: replace with revised csv path
base_path = '/ceph/aeon/aeon/code/scratchpad/Orsi/'

# Per-arena annotation tables, relative to base_path.
tube_test_path = f'all_tube_test_videos/{arena}_tube_tests_revised_final.csv'
fights_path = f'all_fighting_videos/{arena}_fights.csv'
chasing_path = f'all_chasing_videos/{arena}_chases.csv'

# Read tube tests, fights and chases (in that order) into one DataFrame each.
tube_test_df, fights_df, chasing_df = (
    pd.read_csv(base_path + relative_path)
    for relative_path in (tube_test_path, fights_path, chasing_path)
)

Clean up social interaction data:

In [None]:
# Parse the annotation timestamps and tag every event with its behaviour type.
for behaviour_df, behaviour_name in (
    (chasing_df, 'chasing'),
    (fights_df, 'fighting'),
    (tube_test_df, 'tube_test'),
):
    for ts_column in ('start_timestamp', 'end_timestamp'):
        behaviour_df[ts_column] = pd.to_datetime(behaviour_df[ts_column], format='%Y-%m-%dT%H-%M-%S')
    behaviour_df['behavior_type'] = behaviour_name

# Harmonise the "winner" column name so all tables share a dominant_id column.
chasing_df = chasing_df.rename(columns={'chaser_id': 'dominant_id'})
tube_test_df = tube_test_df.rename(columns={'winner_id': 'dominant_id'})
# If the fights table has no dominant_id column, fill it with the 'NaN' placeholder string.
fights_df['dominant_id'] = fights_df.get('dominant_id', 'NaN')

# Stack all events; missing dominant IDs become the string 'NaN' so they can be coloured and filtered.
combined_df = pd.concat([chasing_df, fights_df, tube_test_df]).assign(
    dominant_id=lambda events: events['dominant_id'].fillna('NaN')
)
combined_df.head()

Add pre/post tube test data manually:

In [None]:
# Manually entered tube-test results (presumably from sessions before and after the
# social period): one row per trial, holding the winner's ID.
pre_tube_test = pd.DataFrame({
    'behavior_type': ['pre_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 2 + ['BAA-1104047'] * 8,
})

# Post tube tests are tagged 'post_tube_test' (not 'pre_tube_test') so the two sets of
# trials stay distinguishable if they are ever combined.
post_tube_test = pd.DataFrame({
    'behavior_type': ['post_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 1 + ['BAA-1104047'] * 9,
})

Get metadata-like info (experiment time bounds, experimental periods, animal IDs, light cycle and raw metadata):

In [None]:
# Experiment bounds: from midnight of the day of the earliest block start to midnight
# after the day of the latest block start.
# min/max are used because the row order returned by fetch() is not something to rely on.
block_starts = (BlockAnalysis() & key).fetch('block_start')
first_block_start = min(block_starts)
last_block_start = max(block_starts)
experiment_start = datetime(first_block_start.year, first_block_start.month, first_block_start.day)
# Add a timedelta instead of using `day + 1`, which raises ValueError on the last day of a month.
experiment_end = datetime(last_block_start.year, last_block_start.month, last_block_start.day) + timedelta(days=1)
experiment_end

In [None]:
# Experimental period boundaries (naive datetimes at midnight).
# TODO: find out where to get this from data; for now it's manual.
# Note: nothing covers social_end -> post_solo_start (24-25 Feb), so events in that window belong to no period.
pre_solo_start = experiment_start
pre_solo_end = social_start = datetime(2024, 2, 9)
social_end = datetime(2024, 2, 24)
post_solo_start = datetime(2024, 2, 25)
post_solo_end = experiment_end

In [None]:
# Animal IDs: every distinct dominant_id (sorted by np.unique) except the 'NaN' placeholder.
all_dominant_ids = np.unique(combined_df['dominant_id'])
unique_ids = all_dominant_ids[all_dominant_ids != 'NaN']
unique_ids[0]

In [None]:
# Light-cycle phase boundaries as 'HH:MM' strings, listed in clock order from 07:00.
# The 'day' phase starts in the evening and wraps past midnight.
twilight_start, twilight_end = '07:00', '08:00'
night_start, night_end = '08:00', '19:00'
dawn_start, dawn_end = '19:00', '20:00'
day_start, day_end = '20:00', '07:00'

In [None]:
# Load the experiment metadata from the raw data logged between experiment_start and
# the first block start (i.e. the first solo block), and show the first record.
raw_data_root = f'/ceph/aeon/aeon/data/raw/{arena}/social0.2'
metadata_df = api.load(raw_data_root, social02.Metadata, experiment_start, first_block_start)
metadata = metadata_df.iloc[0].metadata
metadata

Plotting style settings:

In [None]:
# Plot colours: one per animal (the 'NaN' placeholder marks events without a dominant
# animal) and one per interaction type.
id_color_map = {'NaN': 'grey', unique_ids[1]: 'purple', unique_ids[0]: 'green'}
behaviour_map = {'chasing': 'blue', 'fighting': 'red', 'tube_test': 'orange'}


# 1. Dominance interactions

## 1.1 Temporal patterns of interactions and stability of dominance

In [None]:
# Raster of every annotated interaction over the social period, coloured by the dominant
# animal, with the light cycle drawn as a coloured band at y = -1.
fig = px.scatter(
    combined_df,
    x='start_timestamp',
    y='behavior_type',
    color='dominant_id',
    title='Behavior Raster Plot',
    labels={'start_timestamp': 'Time', 'behavior_type': 'Behavior Type'},
    color_discrete_map=id_color_map
)

fig.update_xaxes(range=[social_start, social_end])
# The three behaviour categories are placed at y = 0..2; extend the y-range down so the
# light-cycle band at y = -1 is visible.
fig.update_yaxes(range=[-1.5, 2.5])


def at_clock_time(day, hhmm):
    """Return ``day`` with its hour and minute replaced by those of an 'HH:MM' string."""
    hour, minute = (int(part) for part in hhmm.split(':'))
    return day.replace(hour=hour, minute=minute)


# Draw one segment per light phase for each day of the social period.
current_day = social_start
while current_day < social_end:
    # The day phase wraps past midnight when it ends earlier on the clock than it starts
    # (full HH:MM comparison, so minute-level boundaries are handled too).
    if at_clock_time(current_day, day_end) < at_clock_time(current_day, day_start):
        day_end_date = current_day + timedelta(days=1)
    else:
        day_end_date = current_day

    light_phases = [
        (at_clock_time(current_day, twilight_start), at_clock_time(current_day, twilight_end), 'gray'),
        (at_clock_time(current_day, night_start), at_clock_time(current_day, night_end), 'black'),
        (at_clock_time(current_day, dawn_start), at_clock_time(current_day, dawn_end), 'gray'),
        (at_clock_time(current_day, day_start), at_clock_time(day_end_date, day_end), 'white'),
    ]
    for phase_start, phase_end, phase_color in light_phases:
        fig.add_shape(
            type="line",
            x0=phase_start,
            x1=phase_end,
            y0=-1,
            y1=-1,
            line=dict(color=phase_color, width=4)
        )

    current_day += timedelta(days=1)

fig.show()

In [None]:
# Count interactions of each behaviour type per 24 h bin across the social period.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')
combined_df['time_bin'] = pd.cut(combined_df['start_timestamp'], bins=time_bins)

interaction_counts = (
    combined_df
    .groupby(['time_bin', 'behavior_type'])
    .size()
    .reset_index(name='total_interactions')
)
# Label each bin by its left edge, formatted as 'YYYY-MM-DD HH:MM'.
interaction_counts['time_bin_start'] = (
    interaction_counts['time_bin']
    .apply(lambda interval: interval.left)
    .dt.strftime('%Y-%m-%d %H:%M')
)

fig = px.line(
    interaction_counts,
    x='time_bin_start',
    y='total_interactions',
    color='behavior_type',
    title='Number of Interactions Over Time for Each Behavior',
    labels={'time_bin_start': 'Time Bin', 'total_interactions': 'Number of Interactions'},
    color_discrete_map=behaviour_map
)
fig.show()

In [None]:
def win_fraction_table(events_df, label):
    """Return each dominant_id's share of the wins in ``events_df``.

    Columns: dominant_id, p_wins (share among non-missing dominant_id values) and
    behavior_type (set to ``label``).
    """
    shares = events_df['dominant_id'].value_counts(normalize=True).reset_index()
    shares.columns = ['dominant_id', 'p_wins']
    shares['behavior_type'] = label
    return shares


chaser_counts_chasing = win_fraction_table(chasing_df, 'Chasing')
chaser_counts_tube_test = win_fraction_table(tube_test_df, 'Tube Test')
chaser_counts_pre_tube_test = win_fraction_table(pre_tube_test, 'Pre Tube Test')
chaser_counts_post_tube_test = win_fraction_table(post_tube_test, 'Post Tube Test')

combined_counts = pd.concat([
    chaser_counts_chasing,
    chaser_counts_tube_test,
    chaser_counts_pre_tube_test,
    chaser_counts_post_tube_test,
])

# One point per (behaviour, animal): that animal's share of the wins.
fig = px.scatter(
    combined_counts,
    x='behavior_type',
    y='p_wins',
    color='dominant_id',
    title='Dominance per Behavior Type and Dominant ID',
    labels={'behavior_type': 'Behavior Type', 'p_wins': 'Proportion of wins'},
    color_discrete_map=id_color_map
)
fig.update_yaxes(range=[0, 1])  # proportions live in [0, 1]
fig.show()

In [None]:
# For each behaviour, count wins per animal in 24 h bins across the social period and
# plot both the raw counts and each animal's share of that day's wins.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')


def daily_winning_counts(events_df, bins):
    """Count wins per (time bin, dominant_id) and each animal's share of the bin's wins.

    Parameters
    ----------
    events_df : pandas.DataFrame
        Events with 'start_timestamp' and 'dominant_id' columns; not modified.
    bins : pandas.DatetimeIndex
        Bin edges passed to ``pd.cut`` (right-closed intervals, so an event exactly on
        the first edge, or outside the edges, is not counted).

    Returns
    -------
    pandas.DataFrame
        Columns time_bin, dominant_id, total_events, fraction_winning and day_number
        (1-based index of the bin).
    """
    binned = events_df.assign(time_bin=pd.cut(events_df['start_timestamp'], bins=bins))
    counts = binned.groupby(['time_bin', 'dominant_id']).size().reset_index(name='total_events')
    counts['fraction_winning'] = counts.groupby('time_bin')['total_events'].transform(lambda x: x / x.sum())
    counts['day_number'] = counts['time_bin'].cat.codes + 1
    return counts


# (column to plot, title prefix, y-axis label)
winning_plots = [
    ('total_events', 'Winning Events Over Time', 'Winning Events'),
    ('fraction_winning', 'Fraction of Winning Events Over Time', 'Proportion of Winning Events'),
]

for events_df, behaviour_label in [(tube_test_df, 'Tube test'), (chasing_df, 'Chasing')]:
    winning_counts = daily_winning_counts(events_df, time_bins)
    for y_column, title_prefix, y_label in winning_plots:
        fig = px.line(
            winning_counts,
            x='day_number',
            y=y_column,
            color='dominant_id',
            title=f'{title_prefix} ({behaviour_label})',
            labels={'day_number': 'Day Number', y_column: y_label},
            color_discrete_map=id_color_map
        )
        fig.show()

In [None]:
# Tube-test wins per animal in 24 h bins across the social period, with the x-axis
# labelled by the bin interval (as a string) rather than by day number.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')
tube_test_df['time_bin'] = pd.cut(tube_test_df['start_timestamp'], bins=time_bins)

winning_counts = (
    tube_test_df
    .groupby(['time_bin', 'dominant_id'])
    .size()
    .reset_index(name='total_events')
)
# Each animal's share of the wins within its bin.
winning_counts['fraction_winning'] = (
    winning_counts.groupby('time_bin')['total_events'].transform(lambda bin_counts: bin_counts / bin_counts.sum())
)
winning_counts['time_bin'] = winning_counts['time_bin'].astype(str)

for metric, plot_title, axis_label in (
    ('total_events', 'Winning Events Over Time (Tube test)', 'Winning Events'),
    ('fraction_winning', 'Fraction of Winning Events Over Time (Tube test)', 'Proportion of Winning Events'),
):
    fig = px.line(
        winning_counts,
        x='time_bin',
        y=metric,
        color='dominant_id',
        title=plot_title,
        labels={'time_bin': 'Time Bin', metric: axis_label},
        color_discrete_map=id_color_map
    )
    fig.show()

## 1.2 Chasing speed and dominance

In [None]:
# TODO: clean up identity swaps before calculating this (the hope is that it doesn't matter much, because the two animals are close to each other and move in sync)

In [None]:
# Per-sample speed of each subject during chases (presumably px/s; see the filter cell).
# Load the cached table if it exists; otherwise recompute it from tracked positions (slow).
chasing_speed_path = base_path + 'all_chasing_videos/chasing_speed.csv'
try:
    chasing_speed_data = pd.read_csv(chasing_speed_path)
except (FileNotFoundError, pd.errors.EmptyDataError):
    # A missing or zero-byte cache falls through to the recomputation below.
    chasing_speed_data = pd.DataFrame()

if chasing_speed_data.empty:
    print('No chasing speed data available')
    # Centroid tracks of every subject in every block that starts inside the social period.
    # NOTE(review): a block starting before social_start that runs into it is excluded.
    block_subjects = (
        BlockAnalysis.Subject.proj('position_x', 'position_y', 'position_timestamps')
        & key
        & f'block_start >= "{social_start}"'
        & f'block_start <= "{social_end}"'
    )
    block_subjects_dict = block_subjects.fetch(as_dict=True)

    # Convert each subject/block track once, rather than once per chase.
    subject_tracks = [
        (
            s['subject_name'],
            pd.to_datetime(s['position_timestamps']),
            np.asarray(s['position_x']),
            np.asarray(s['position_y']),
        )
        for s in block_subjects_dict
    ]

    # Keep only the position samples that fall inside a chase window [start, end].
    filtered_positions = []
    for start_time, end_time in zip(chasing_df['start_timestamp'], chasing_df['end_timestamp']):
        for subject_name, timestamps, xs, ys in subject_tracks:
            mask = (timestamps >= start_time) & (timestamps <= end_time)
            n_samples = int(mask.sum())
            if n_samples == 0:
                continue  # this block does not cover this chase
            filtered_positions.append(
                pd.DataFrame(
                    {
                        "subject_name": [subject_name] * n_samples,
                        "position_timestamps": timestamps[mask],
                        "position_x": xs[mask],
                        "position_y": ys[mask],
                        "start_time": [start_time] * n_samples,
                        "end_time": [end_time] * n_samples,
                    }
                )
            )

    if not filtered_positions:
        raise ValueError('No tracked positions fall inside any chase window')

    # ignore_index avoids duplicate row labels from the per-chase frames.
    subjects_positions_df = pd.concat(filtered_positions, ignore_index=True)

    # Speed = step distance / step duration. Differences are taken within each
    # (subject, chase) so a chase's first sample is not compared with the previous
    # chase's last sample (which produced spurious, very low speeds).
    per_chase = subjects_positions_df.groupby(['subject_name', 'start_time', 'end_time'])
    step_distance = np.hypot(per_chase['position_x'].diff(), per_chase['position_y'].diff())
    step_seconds = per_chase['position_timestamps'].diff().dt.total_seconds()
    subjects_positions_df["speed"] = step_distance / step_seconds
    chasing_speed_data = subjects_positions_df

In [None]:
# Keep only finite, plausible speeds: drop NaN, +/-inf and anything at or above 100 cm/s.
cm2px = 5.4  # pixels per centimetre
max_speed_threshold = 100 * cm2px  # 100 cm/s expressed in px/s
chasing_speed_data = chasing_speed_data[
    (chasing_speed_data['speed'] < max_speed_threshold)
    & chasing_speed_data['speed'].notna()
    & ~chasing_speed_data['speed'].isin([np.inf, -np.inf])
]


In [None]:
# Average speed per chase. Chases are keyed by their start time: first each subject's
# samples within a chase are averaged, then the subjects' means are averaged.
avg_speed_per_subject = (
    chasing_speed_data
    .groupby(['start_time', 'subject_name'])['speed']
    .mean()
    .reset_index()
    .rename(columns={'speed': 'avg_speed_per_subject'})
)
avg_speed_per_chase = (
    avg_speed_per_subject
    .groupby('start_time')['avg_speed_per_subject']
    .mean()
    .reset_index()
    .rename(columns={'avg_speed_per_subject': 'avg_speed_per_chase', 'start_time': 'start_timestamp'})
)

# The join keys may be strings when the speed table came from the CSV cache, so parse both sides.
chasing_df['start_timestamp'] = pd.to_datetime(chasing_df['start_timestamp'])
avg_speed_per_chase['start_timestamp'] = pd.to_datetime(avg_speed_per_chase['start_timestamp'])

# Attach the average speed to each chase and keep only the chases that have one.
avg_chasing_speed_df = (
    chasing_df
    .merge(avg_speed_per_chase, on='start_timestamp', how='left')
    .dropna(subset=['avg_speed_per_chase'])
)

In [None]:
# Box plot of the average chase speed for each chaser, plus Welch's t-test between the two animals.
fig = go.Figure()

# The loop variable is not named `dominant_id`: that name later holds the dominant animal,
# and re-running this cell would otherwise overwrite it.
for chaser_id in avg_chasing_speed_df['dominant_id'].unique():
    if pd.isna(chaser_id):
        color = 'gray'  # default colour for chases without a chaser ID
        # `== NaN` never matches, so select missing chasers with isna().
        chaser_df = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'].isna()]
    else:
        color = id_color_map.get(chaser_id, 'black')  # fall back to black for IDs not in the map
        chaser_df = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == chaser_id]

    label = str(chaser_id)  # str() so a NaN chaser still gets a named category
    fig.add_trace(go.Box(
        x=[label] * len(chaser_df),
        y=chaser_df['avg_speed_per_chase'],
        name=label,
        boxpoints='all',  # show individual points
        jitter=0.3,  # spread points horizontally to reduce overlap
        pointpos=-1.8,  # negative offset draws the points to the left of the box
        marker=dict(color=color)
    ))

fig.update_layout(
    title='Average Speed of Chases per Chaser ID',
    xaxis_title='Chaser ID',
    yaxis_title='Average Speed',
    boxmode='group',  # group box plots by chaser ID
    showlegend=False
)

# Welch's t-test (unequal variances) between the two animals' per-chase speeds.
chaser1_speeds = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == unique_ids[0]]['avg_speed_per_chase']
chaser2_speeds = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == unique_ids[1]]['avg_speed_per_chase']
t_stat, p_value_ttest = stats.ttest_ind(chaser1_speeds, chaser2_speeds, equal_var=False)

fig.add_annotation(
    x=0.5, y=1.05, xref='paper', yref='paper',
    text=f"T-test p-value: {p_value_ttest:.3f}",
    showarrow=False,
    font=dict(size=12, color='black'),
    align='center'
)

fig.show()

# 2. Dominance and foraging

In [None]:
# Dominant animal = the animal ID with the most wins in combined_df. Only real animal IDs
# are counted. Events without a dominant animal (e.g. fights) carry the 'NaN' placeholder,
# which would otherwise be picked whenever such events outnumber an animal's wins.
animal_wins = combined_df.loc[combined_df['dominant_id'].isin(unique_ids), 'dominant_id']
dominant_id_counts = animal_wins.value_counts()
dominant_id = dominant_id_counts.idxmax()
subordinate_id = unique_ids[0] if unique_ids[0] != dominant_id else unique_ids[1]


## 2.1 Pellet consumption

Get foraging data:

In [None]:
# Build one row per pellet delivery: time, subject, patch threshold, dominance rank and
# experimental period.
foraging_query = (
    BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
    & {"spinnaker_video_source_name": "CameraTop"}  # restrict to the CameraTop video source
    & key)

pellet_data = foraging_query.fetch()

# (period name, start, end), checked in order; both bounds are inclusive.
period_windows = [
    ('pre_solo', pre_solo_start, pre_solo_end),
    ('social', social_start, social_end),
    ('post_solo', post_solo_start, post_solo_end),
]

data = []
n_unassigned = 0
for entry in pellet_data:
    # Fields are read by position from the fetched record: entry[3] is used as the
    # subject ID, entry[5] as the pellet timestamps and entry[6] as the patch thresholds.
    subject_id = entry[3]
    pellet_timestamps = entry[5]
    patch_threshold = entry[6]

    for pellet_timestamp, threshold in zip(pellet_timestamps, patch_threshold):
        pellet_timestamp = pd.to_datetime(pellet_timestamp)
        period = next(
            (name for name, start, end in period_windows if start <= pellet_timestamp <= end),
            None,
        )
        if period is None:
            # Outside every window (e.g. between social_end and post_solo_start). Skip the pellet
            # rather than reuse the previous pellet's period, which would silently mislabel it.
            n_unassigned += 1
            continue

        data.append({
            'time': pellet_timestamp,
            'subject_id': subject_id,
            'threshold': threshold,
            'rank': 'dominant' if subject_id == dominant_id else 'subordinate',
            'period': period
        })

if n_unassigned:
    print(f'Skipped {n_unassigned} pellets that fall outside every period window')

pellet_df = pd.DataFrame(data)
pellet_df['time'] = pd.to_datetime(pellet_df['time'])



In [None]:
# Data cleaning: drop pellets delivered 2 s or less after the previous pellet of the same
# subject within the same period. The first pellet of each group has no predecessor and is kept.
pellet_df = pellet_df.sort_values(by=['subject_id', 'period', 'time'])
pellet_df['time_diff'] = (
    pellet_df.groupby(['subject_id', 'period'])['time'].diff().dt.total_seconds()
)
print(f'Number of pellets before filtering: {len(pellet_df)}')
keep_pellet = pellet_df['time_diff'].isna() | (pellet_df['time_diff'] > 2)
pellet_df = pellet_df.loc[keep_pellet]
print(f'Number of pellets after filtering: {len(pellet_df)}')

In [None]:
def create_raster_plot(df, id_color_map):
    """Show pellet deliveries as rasters in five stacked panels.

    The panels are: pre-solo for each animal (rows 1-2), post-solo for each animal
    (rows 3-4) and one shared social panel (row 5).

    Parameters
    ----------
    df : pandas.DataFrame
        Pellet table with 'time', 'subject_id' and 'period' columns, where period is
        one of 'pre_solo', 'post_solo' or 'social'.
    id_color_map : dict
        Maps subject_id to a plot colour.

    Uses the notebook-level ``unique_ids``: ``unique_ids[0]`` goes in the first row of
    each solo pair and any other ID in the second. Displays the figure and returns None.
    """
    first_id = unique_ids[0]
    fig = make_subplots(
        rows=5, cols=1, shared_xaxes=False,
        subplot_titles=(
            f'Pre Solo - {unique_ids[0]}', f'Pre Solo - {unique_ids[1]}',
            f'Post Solo - {unique_ids[0]}', f'Post Solo - {unique_ids[1]}',
            'Social',
        )
    )

    # period -> (row for unique_ids[0], row for any other ID, hide legend entry?)
    period_layout = {
        'pre_solo': (1, 2, True),
        'post_solo': (3, 4, True),
        'social': (5, 5, False),
    }
    for period, (first_row, other_row, hide_legend) in period_layout.items():
        period_df = df[df['period'] == period]
        for subject_id in period_df['subject_id'].unique():
            trace = px.scatter(
                period_df[period_df['subject_id'] == subject_id],
                x='time',
                y='subject_id',
                color='subject_id',
                labels={'time': 'Time'},
                color_discrete_map=id_color_map
            ).data[0]
            if hide_legend:
                trace.showlegend = False
            target_row = first_row if subject_id == first_id else other_row
            fig.add_trace(trace, row=target_row, col=1)

    fig.update_layout(title_text='Pellet Raster Plots', showlegend=True, height=800)
    # Blank the y-axis titles on all five panels.
    for panel_row in range(1, 6):
        fig.update_yaxes(title_text='', row=panel_row, col=1)
    fig.show()


# Create the raster plot for pre_solo, post_solo, and social periods
create_raster_plot(pellet_df, id_color_map)

In [None]:
# Define a function to process each period
def process_period(period_df, row, time_bin_size, plot_type='pellet_count', daily=False,
                   target_fig=None, color_map=None):
    """Bin one period's pellet events in time and add the resulting traces to one subplot row.

    Parameters:
    - period_df: DataFrame with 'time' (datetime) and 'subject_id' columns; plot_type
      'threshold' additionally needs a 'threshold' column. The frame is not modified.
    - row: 1-based subplot row that the traces are added to (column 1).
    - time_bin_size: pandas offset alias used for binning, e.g. '1h' or '24h'.
    - plot_type: 'pellet_count' (pellets per bin per subject) or 'threshold'
      (mean threshold per bin with SEM error bars).
    - daily: only used with 'pellet_count'; plot each calendar day against hour of day as a
      faint line, plus the per-hour mean across days as a bold line.
    - target_fig: plotly figure to draw into; defaults to the notebook-global `fig`.
    - color_map: mapping subject_id -> colour; defaults to the notebook-global `id_color_map`.

    Returns None. An empty `period_df` adds nothing.

    Raises:
    - ValueError if `plot_type` is not 'pellet_count' or 'threshold'.
    """
    if plot_type not in ('pellet_count', 'threshold'):
        raise ValueError(f"Unsupported plot_type {plot_type!r}; expected 'pellet_count' or 'threshold'")
    if period_df.empty:
        return

    # Fall back to the notebook globals so existing call sites keep working unchanged.
    if target_fig is None:
        target_fig = fig
    if color_map is None:
        color_map = id_color_map

    # Bin by time (on the 'time' column, no index mutation) and by subject
    grouper = [pd.Grouper(key='time', freq=time_bin_size), 'subject_id']
    if plot_type == 'pellet_count':
        binned_df = period_df.groupby(grouper).size().reset_index(name='pellet_count')
    else:
        binned_df = (
            period_df.groupby(grouper)['threshold']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'average_threshold', 'std': 'std_threshold', 'count': 'n'})
        )
        binned_df['sem_threshold'] = binned_df['std_threshold'] / np.sqrt(binned_df['n'])

    # Complete grid of (bin, subject) so that bins without any event are represented
    start_time = period_df['time'].min().floor(time_bin_size)
    end_time = period_df['time'].max().ceil(time_bin_size)
    time_range = pd.date_range(start=start_time, end=end_time, freq=time_bin_size)
    subject_ids = binned_df['subject_id'].unique()
    complete_df = (
        pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
        .to_frame(index=False)
        .merge(binned_df, on=['time', 'subject_id'], how='left')
    )

    if plot_type == 'threshold':
        # A bin with no pellets has no threshold: leave it NaN so it shows as a gap instead
        # of a fake threshold of 0. Only the sample count is genuinely 0 there.
        complete_df['n'] = complete_df['n'].fillna(0).astype(int)
        for subject_id in complete_df['subject_id'].unique():
            subject_df = complete_df[complete_df['subject_id'] == subject_id]
            target_fig.add_trace(go.Scatter(
                x=subject_df['time'],
                y=subject_df['average_threshold'],
                mode='lines+markers',
                name=f'{subject_id} Average Threshold',
                line=dict(color=color_map[subject_id]),
                error_y=dict(
                    type='data',
                    array=subject_df['sem_threshold'],
                    visible=True
                )
            ), row=row, col=1)
        return

    # Pellet counts: an empty bin really means zero pellets
    complete_df['pellet_count'] = complete_df['pellet_count'].fillna(0).astype(int)

    if daily:
        complete_df['hour_of_day'] = complete_df['time'].dt.hour
        complete_df['date'] = complete_df['time'].dt.date
        average_pellets_per_hour = complete_df.groupby(['subject_id', 'hour_of_day'])['pellet_count'].mean().reset_index()

        # Individual days as faint lines
        for subject_id in complete_df['subject_id'].unique():
            subject_data = complete_df[complete_df['subject_id'] == subject_id]
            for date in subject_data['date'].unique():
                daily_data = subject_data[subject_data['date'] == date]
                target_fig.add_trace(go.Scatter(
                    x=daily_data['hour_of_day'],
                    y=daily_data['pellet_count'],
                    mode='lines',
                    name=f'{subject_id} (daily)',
                    line=dict(color=color_map[subject_id], width=1),
                    opacity=0.2,
                    showlegend=False
                ), row=row, col=1)

        # Mean across days as a more prominent line
        for subject_id in average_pellets_per_hour['subject_id'].unique():
            avg_data = average_pellets_per_hour[average_pellets_per_hour['subject_id'] == subject_id]
            target_fig.add_trace(go.Scatter(
                x=avg_data['hour_of_day'],
                y=avg_data['pellet_count'],
                mode='lines',
                name=f'{subject_id} (average)',
                line=dict(color=color_map[subject_id], width=2),
                showlegend=False
            ), row=row, col=1)
        return

    for subject_id in subject_ids:
        subject_df = complete_df[complete_df['subject_id'] == subject_id]
        line_plot = px.line(
            subject_df,
            x='time',
            y='pellet_count',
            color='subject_id',
            labels={'pellet_count': 'Pellet Count', 'time': 'Time', 'subject_id': 'Subject ID'},
            color_discrete_map=color_map
        ).data[0]
        line_plot.showlegend = False
        target_fig.add_trace(line_plot, row=row, col=1)


In [None]:
# Per-subject patch data (CameraTop results) for blocks starting between
# 2024-02-26 00:00:00 and 2024-02-27 00:00:00.
key = {"experiment_name": "social0.2-aeon3"}
camera_restriction = {"spinnaker_video_source_name": "CameraTop"}
window_start, window_end = "2024-02-26 00:00:00", "2024-02-27 00:00:00"

patch_fields = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
foraging_query = (
    patch_fields
    & camera_restriction
    & key
    & f'block_start >= "{window_start}"'
    & f'block_start <= "{window_end}"'
)

foraging_query

In [None]:
# Filter the DataFrame for pre_solo, post_solo, and social periods
pre_solo_df = pellet_df[pellet_df['period'] == 'pre_solo']
post_solo_df = pellet_df[pellet_df['period'] == 'post_solo']
social_df = pellet_df[pellet_df['period'] == 'social']

# Define unique IDs
unique_ids = pellet_df['subject_id'].unique()
# The 5-row layout (one solo row per subject and period, then one shared social row) assumes a pair.
assert len(unique_ids) == 2, f'Expected exactly 2 subjects, found {len(unique_ids)}: {list(unique_ids)}'

# Subplot row -> data drawn in that row; identical for every figure below
period_panels = [
    (1, pre_solo_df[pre_solo_df['subject_id'] == unique_ids[0]]),
    (2, pre_solo_df[pre_solo_df['subject_id'] == unique_ids[1]]),
    (3, post_solo_df[post_solo_df['subject_id'] == unique_ids[0]]),
    (4, post_solo_df[post_solo_df['subject_id'] == unique_ids[1]]),
    (5, social_df),
]
period_subplot_titles = (
    f'Pre Solo - {unique_ids[0]}',
    f'Pre Solo - {unique_ids[1]}',
    f'Post Solo - {unique_ids[0]}',
    f'Post Solo - {unique_ids[1]}',
    f'Social - {unique_ids[0]} and {unique_ids[1]}'
)

# One entry per figure: binning, plot type, layout, and whether per-row y-axis titles are cleared
period_figure_specs = [
    dict(time_bin_size='1h', plot_type='pellet_count', daily=False, clear_y_titles=True,
         layout=dict(title_text='Pellets during Pre, Post Solo, and Social Periods (1h bins)',
                     showlegend=True, height=1000)),
    dict(time_bin_size='24h', plot_type='pellet_count', daily=False, clear_y_titles=True,
         layout=dict(title_text='Pellets during Pre, Post Solo, and Social Periods (24h bins)',
                     showlegend=True, height=1000)),
    dict(time_bin_size='24h', plot_type='threshold', daily=False, clear_y_titles=False,
         layout=dict(title='Average Threshold with SEM Error Bars during Pre, Post Solo, and Social Periods (24h bins)',
                     xaxis_title='Time', yaxis_title='Average Threshold', showlegend=False, height=1000)),
    # Daily overlay: x values are hour of day (0-23), not timestamps
    dict(time_bin_size='1h', plot_type='pellet_count', daily=True, clear_y_titles=False,
         layout=dict(title='Pellets throughout the day',
                     xaxis_title='Hour of Day', yaxis_title='Pellet Count', showlegend=False, height=1000)),
]

for spec in period_figure_specs:
    # `fig` must stay a notebook-level name: process_period draws into the global `fig` by default
    fig = make_subplots(rows=5, cols=1, shared_xaxes=False, subplot_titles=period_subplot_titles)
    for row, panel_df in period_panels:
        process_period(panel_df, row, spec['time_bin_size'], spec['plot_type'], daily=spec['daily'])

    fig.update_layout(**spec['layout'])
    if spec['clear_y_titles']:
        # Remove y-axis labels
        for i in range(1, 6):
            fig.update_yaxes(title_text='', row=i, col=1)

    fig.show()

## 2.2 Patch types

Get patch preference data:

In [None]:
preference_query = (
    BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
    * BlockSubjectAnalysis.Preference.proj('final_preference_by_wheel','final_preference_by_time')
    * BlockAnalysis.Patch.proj('patch_rate', 'patch_offset')
    & key
    & f'block_start >= "{social_start}"'
    & f'block_start <= "{social_end}"'
)
preference_data = preference_query.fetch()

# Nominal patch rates used in this experiment -> difficulty label.
patch_rate_types = {0.01: 'easy', 0.0033: 'medium', 0.002: 'hard'}


def classify_patch_rate(patch_rate):
    """Return 'easy'/'medium'/'hard' for a known nominal patch rate, else 'unknown'.

    Rates come back from the database as floats (or possibly decimals), so they are matched
    with a tolerance instead of exact `==`, which silently misses rounded values.
    """
    if pd.isna(patch_rate):
        return 'unknown'
    rate = float(patch_rate)
    for nominal_rate, label in patch_rate_types.items():
        if np.isclose(rate, nominal_rate):
            return label
    return 'unknown'


# One record per (block, patch, subject) row of the query
preference_list = []
for entry in preference_data:
    # Positional fields follow the fetch order of the joined query above;
    # indices 5-6 (pellet_timestamps, patch_threshold) are not needed here.
    subject_id = entry[3]
    patch_rate = entry[9]
    preference_list.append({
        'experiment_name': entry[0],
        'block_start': entry[1],
        'patch_name': entry[2],
        'subject_id': subject_id,
        'pellet_count': entry[4],
        'final_preference_by_wheel': entry[7],
        'final_preference_by_time': entry[8],
        'patch_rate': patch_rate,
        'patch_offset': entry[10],
        'rank': 'dominant' if subject_id == dominant_id else 'subordinate',
        'patch_type': classify_patch_rate(patch_rate),
    })

# Convert the list of dictionaries into a DataFrame
preference_df = pd.DataFrame(preference_list)

# Convert the 'block_start' column to datetime
preference_df['block_start'] = pd.to_datetime(preference_df['block_start'])

# Display the DataFrame
preference_df.head()

In [None]:
# Keep only blocks in which every subject earned at least `pellet_threshold` pellets in total,
# summed over all patches of that block.
pellet_threshold = 10

# Pellets per (block, patch, subject)
grouped_df = preference_df.groupby(['block_start', 'patch_name', 'subject_id']).agg({'pellet_count': 'sum'}).reset_index()

unique_patches = preference_df['patch_name'].unique()
unique_subjects = preference_df['subject_id'].unique()

# Total pellets per (block, subject) across patches; a subject absent from a block counts as 0
pellets_per_block = (
    grouped_df.groupby(['block_start', 'subject_id'])['pellet_count'].sum()
    .unstack('subject_id', fill_value=0)
    .reindex(columns=unique_subjects, fill_value=0)
)
block_passes = (pellets_per_block >= pellet_threshold).all(axis=1)
unique_blocks = block_passes.index[block_passes].to_numpy()

# Per-patch totals of the qualifying blocks (kept for inspection)
filtered_groups = grouped_df[grouped_df['block_start'].isin(unique_blocks)]

# Create a DataFrame with all possible combinations of block_start, patch_name, and subject_id
all_combinations = pd.MultiIndex.from_product([unique_blocks, unique_patches, unique_subjects], names=['block_start', 'patch_name', 'subject_id']).to_frame(index=False)

# Merge the all_combinations DataFrame with the original DataFrame to retain all columns.
# NOTE: combinations missing from preference_df get 0 in every column, text columns included.
preference_df_filtered = all_combinations.merge(preference_df, on=['block_start', 'patch_name', 'subject_id'], how='left').fillna(0)

preference_df_filtered.head()

In [None]:
def plot_variable_by_patch_rate(df, variable, title, yaxis_title, id_color_map, category_order=('easy', 'medium', 'hard')):
    """
    Plots a box plot of a given variable per patch type, one box series per subject.

    Parameters:
    - df: DataFrame with 'subject_id', 'patch_type' and `variable` columns
    - variable: The variable to plot (column name in the DataFrame)
    - title: Title of the plot
    - yaxis_title: Title of the y-axis
    - id_color_map: Dictionary mapping subject IDs to colors (every subject in `df` must be a key)
    - category_order: Order of patch types on the x-axis (default ('easy', 'medium', 'hard'));
      a tuple default avoids one mutable list being shared between calls
    """
    # Initialize the figure
    fig = go.Figure()

    # Add box plot for each subject
    for subject_id in df['subject_id'].unique():
        subject_df = df[df['subject_id'] == subject_id]
        fig.add_trace(go.Box(
            x=subject_df['patch_type'],
            y=subject_df[variable],
            name=subject_id,
            boxpoints='all',  # Show individual points
            jitter=0.25,  # Add some jitter for individual points
            pointpos=-1.4,  # Negative values place the individual points to the left of each box
            marker=dict(color=id_color_map[subject_id])
        ))

    # Update layout
    fig.update_layout(
        title=title,
        xaxis_title='Patch Rate',
        yaxis_title=yaxis_title,
        boxmode='group',  # Group box plots by patch_type
        xaxis=dict(categoryorder='array', categoryarray=list(category_order)),
        showlegend=True
    )

    # Show the plot
    fig.show()


In [None]:
# One box plot per variable; the two preference plots get distinct titles so they can be told apart
patch_rate_plots = [
    ('pellet_count', 'Pellets per Patch Rate', 'Pellet Count'),
    ('final_preference_by_time', 'Preference Index (time) per Patch Rate', 'Preference Index (time)'),
    ('final_preference_by_wheel', 'Preference Index (wheel) per Patch Rate', 'Preference Index (wheel)'),
]

for variable, title, yaxis_title in patch_rate_plots:
    plot_variable_by_patch_rate(
        df=preference_df_filtered,
        variable=variable,
        title=title,
        yaxis_title=yaxis_title,
        id_color_map=id_color_map
    )

In [None]:
def plot_daily_avg_variable(df, variable, id_color_map, title='Average Daily Value by Day'):
    """
    Plots the average daily value of a specified variable by day for each subject and patch type.

    The first and last calendar days in the data are dropped as incomplete. `df` is not
    modified (the day is derived on a copy).

    Parameters:
    - df: DataFrame with 'block_start', 'subject_id', 'patch_type' and `variable` columns
    - variable: The variable to plot (column name in the DataFrame)
    - id_color_map: Dictionary mapping subject IDs to colors
    - title: Title of the plot (default is 'Average Daily Value by Day')
    """
    # Mean of the variable per calendar day, subject and patch type; assign() works on a copy
    daily_avg_df = (
        df.assign(day=pd.to_datetime(df['block_start']).dt.date)
        .groupby(['day', 'subject_id', 'patch_type'])
        .agg({variable: 'mean'})
        .reset_index()
    )

    # Exclude incomplete days
    first_day = daily_avg_df['day'].min()
    last_day = daily_avg_df['day'].max()
    filtered_daily_avg_df = daily_avg_df[(daily_avg_df['day'] != first_day) & (daily_avg_df['day'] != last_day)]

    # Create the plot
    fig = px.line(
        filtered_daily_avg_df,
        x='day',
        y=variable,
        color='subject_id',
        line_dash='patch_type',
        title=title,
        labels={'day': 'Day', variable: variable.replace('_', ' ').title()},
        category_orders={'patch_type': ['easy', 'medium', 'hard']},
        color_discrete_map=id_color_map
    )

    # Show the plot
    fig.show()

In [None]:
# Distinct titles so the wheel- and time-based preference plots can be told apart
plot_daily_avg_variable(preference_df_filtered, 'pellet_count', id_color_map, title='Average Daily Pellets by Day')
plot_daily_avg_variable(preference_df_filtered, 'final_preference_by_wheel', id_color_map, title='Average Daily Preference Index (wheel) by Day')
plot_daily_avg_variable(preference_df_filtered, 'final_preference_by_time', id_color_map, title='Average Daily Preference Index (time) by Day')

# 3. Dominance and other variables

## 3.1 Weight

In [None]:
weight_query = (
    BlockAnalysis.Subject.proj('weights', 'weight_timestamps')
    & {"spinnaker_video_source_name": "CameraTop"}  # only results from the CameraTop video source
    & key
    & f'block_start >= "{social_start}"'
    & f'block_start <= "{social_end}"'
)

weight_data = weight_query.fetch()

# One record per weight sample. Positional fields as used here:
# entry[2] = subject id, entry[3] = weights, entry[4] = their timestamps.
data = [
    {'time': time, 'weight': weight, 'subject_id': entry[2]}
    for entry in weight_data
    for time, weight in zip(entry[4], entry[3])
]

# Explicit columns keep the frame well-formed (and the next line valid) when no rows come back
weight_df = pd.DataFrame(data, columns=['time', 'weight', 'subject_id'])

# Convert the 'time' column to datetime
weight_df['time'] = pd.to_datetime(weight_df['time'])


In [None]:
# Data cleaning: drop weight readings of 25 g or less as implausible.
# TODO: come up with a better way to clean the weight data.
min_plausible_weight_g = 25
weight_df = weight_df.loc[weight_df['weight'] > min_plausible_weight_g]

In [None]:
# Raw weight samples per subject over the social period
weight_axis_labels = {'weight': 'Weight (g)', 'time': 'Time', 'subject_id': 'Subject ID'}
fig = px.line(
    weight_df,
    x='time', y='weight', color='subject_id',
    title='Weight during social period',
    labels=weight_axis_labels,
    color_discrete_map=id_color_map,
)
fig.show()

In [None]:
# Daily (24h-bin) mean weight per subject
time_weight_df = weight_df.set_index('time', inplace=False)

# Select 'weight' explicitly so only the numeric column is averaged; the subject_id grouping
# key and the time bins come back as columns after reset_index
avg_weight_df = time_weight_df.groupby('subject_id')['weight'].resample('24h').mean().reset_index()

# Line plot of the daily means
fig = px.line(
    avg_weight_df,
    x='time',
    y='weight',
    color='subject_id',
    title='Average Weight during Social Period (Daily)',
    labels={'weight': 'Weight (g)', 'time': 'Time', 'subject_id': 'Subject ID'},
    color_discrete_map=id_color_map
)

# Show the plot
fig.show()

# 4. Social interactions and foraging

## 4.1 Temporal relationship between interactions and foraging

In [None]:
# Prepare interaction data for heatmaps

# Index events by start time so they can be resampled (skipped if already indexed)
if 'start_timestamp' not in combined_df.index.names:
    combined_df.set_index('start_timestamp', inplace=True)

# Number of events of each behaviour type per 1-hour bin
events_per_hour = (
    combined_df.groupby('behavior_type')
    .resample('1h')
    .size()
    .reset_index(name='event_count')
)

# Min-max normalise each behaviour type's counts to [0, 1] so the types are comparable
counts_by_type = events_per_hour.groupby('behavior_type')['event_count']
type_min = counts_by_type.transform('min')
type_range = counts_by_type.transform('max') - type_min
events_per_hour['normalized_event_count'] = (events_per_hour['event_count'] - type_min) / type_range



In [None]:
# Prepare pellet data for heatmaps: pellets per subject per time bin over the social period,
# min-max normalised per subject
time_bin_size = '1h'

# Index a copy by time so pellet_df itself is left untouched
pellet_df_copy = pellet_df.set_index('time')

# Bin the data by time and count the number of pellets per bin per subject
binned_pellet_df = pellet_df_copy.groupby([pd.Grouper(freq=time_bin_size), 'subject_id']).size().reset_index(name='pellet_count')

# Complete grid of bins over the social period. With hourly bins, pd.Grouper (default
# origin 'start_day') labels bins on the hour, so the grid must start on an hour boundary too;
# an unaligned social_start would otherwise match no bins in the merge below (all zeros).
time_range = pd.date_range(
    start=pd.Timestamp(social_start).floor(time_bin_size),
    end=social_end,
    freq=time_bin_size,
)

# Get unique subject IDs
subject_ids = binned_pellet_df['subject_id'].unique()

# Every (bin, subject) combination, filled with 0 where no pellet was delivered
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
pellets_per_hour = (
    pd.DataFrame(index=multi_index).reset_index()
    .merge(binned_pellet_df, on=['time', 'subject_id'], how='left')
    .fillna(0)
)
pellets_per_hour['pellet_count'] = pellets_per_hour['pellet_count'].astype(int)

# Min-max normalise each subject's counts to [0, 1] (NaN if a subject's count never varies)
counts_by_subject = pellets_per_hour.groupby('subject_id')['pellet_count']
subject_min = counts_by_subject.transform('min')
subject_range = counts_by_subject.transform('max') - subject_min
pellets_per_hour['normalized_pellet_count'] = (pellets_per_hour['pellet_count'] - subject_min) / subject_range


In [None]:
# Attach per-subject pellet counts to every hourly event row.
# Events have no subject_id, so the join is on 'time' only: each event row is repeated
# once per subject present in pellets_per_hour at that time.
events_per_hour = events_per_hour.rename(columns={'start_timestamp': 'time'})
merged_df = events_per_hour.merge(pellets_per_hour, on=['time'], how='left')


In [None]:
# Pivot to one column per (measure, subject, behaviour type); rows are 1h time bins
heatmap_pivot = merged_df.pivot_table(
    index='time',
    columns=['subject_id', 'behavior_type'],
    values=['normalized_event_count', 'normalized_pellet_count']
)

# Flattened copy with unique string labels (kept for inspection)
heatmap_data = heatmap_pivot.copy()
heatmap_data.columns = [f'{val}_{subj}_{beh}' for val, subj, beh in heatmap_pivot.columns]

# merged_df repeats each subject's pellet count once per behaviour type, and each behaviour's
# event count once per subject. Select the measure FIRST (so pellet and event values never
# mix), then average over the duplicated level: one row per subject for pellets and one row
# per behaviour type for events.
pellet_rows = heatmap_pivot['normalized_pellet_count'].T.groupby(level='subject_id').mean().T
event_rows = heatmap_pivot['normalized_event_count'].T.groupby(level='behavior_type').mean().T


def _event_row(behavior_type):
    """Normalised hourly event counts for one behaviour type (all NaN if it never occurs)."""
    if behavior_type in event_rows.columns:
        return event_rows[behavior_type]
    return pd.Series(np.nan, index=heatmap_pivot.index)


event_counts_chasing = _event_row('chasing')
event_counts_fighting = _event_row('fighting')
event_counts_tube_test = _event_row('tube_test')

# One heatmap row per subject's pellet counts (no hard-coded IDs), then one per interaction type
heatmap_final = pd.DataFrame(
    {
        **{f'Pellet Count {subject_id}': pellet_rows[subject_id] for subject_id in pellet_rows.columns},
        'Chasing Events': event_counts_chasing,
        'Fighting Events': event_counts_fighting,
        'Tube Test Events': event_counts_tube_test,
    },
    index=heatmap_pivot.index,
)

# Create the heatmap
fig = px.imshow(
    heatmap_final.T,  # Transpose to have the measurements on the y-axis
    labels={'x': 'Time (1h Bins)', 'y': 'Measurement'},
    color_continuous_scale='Viridis',
    aspect='auto',
    zmin=0, zmax=1  # Normalised values lie in [0, 1]
)

# Update layout
fig.update_layout(
    title='Pellet Counts and Events During Social Period',
    xaxis_title='Time (1h Bins)',
    yaxis_title=None,
    coloraxis_colorbar=dict(title='Normalized Value'),
    template='plotly_white',
    height=500  # Adjust height as needed
)

# Show the plot
fig.show()

Compare pellet counts after interactions with pellet counts at active times when no interactions occur.

In [None]:
# Find "isolated" interactions: events with no other event starting within
# `time_window_length` minutes after (or at exactly the same time as) their own start.
time_window_length = 5  # in minutes
isolation_window = pd.Timedelta(minutes=time_window_length)

# Undo the start_timestamp index set for the hourly resampling (guarded so re-running is safe)
if 'start_timestamp' in combined_df.index.names:
    combined_df = combined_df.reset_index()

# Work positionally (combined_df may carry duplicate index labels). After a stable sort,
# another event lies in [t, t + window] iff the next start is within the window, or the
# previous start is identical (a tie). NaT comparisons are False, so NaT events count as isolated.
event_starts = combined_df['start_timestamp'].to_numpy()
sort_order = np.argsort(event_starts, kind='stable')
sorted_starts = pd.Series(event_starts[sort_order])
followed_within_window = (sorted_starts.shift(-1) - sorted_starts) <= isolation_window
tied_with_previous = (sorted_starts - sorted_starts.shift(1)) == pd.Timedelta(0)

is_isolated = np.empty(len(sort_order), dtype=bool)
is_isolated[sort_order] = ~(followed_within_window | tied_with_previous).to_numpy()

# Rows keep their original order, index labels and dtypes (no row-by-row concat)
good_events_df = combined_df[is_isolated].copy()

In [None]:
# For the baseline, choose random timestamps during lights-off where no interaction occurs nearby.
# NOTE(review): "not in nest" was also mentioned as a proxy for awake times, but only the
# lights-off criterion is implemented.

def is_valid_timestamp(new_timestamp, existing_timestamps, time_window, night_start, night_end):
    """Return True if `new_timestamp` is in the night period and far from all existing timestamps.

    Parameters:
    - new_timestamp: candidate timestamp (pandas Timestamp or datetime).
    - existing_timestamps: iterable of timestamps the candidate must not be near.
    - time_window: timedelta; a candidate strictly closer than this to any existing timestamp
      is rejected.
    - night_start, night_end: datetime.time bounds of the night (start inclusive, end exclusive).
      If night_start > night_end, the night is taken to wrap past midnight (e.g. 19:00-07:00);
      previously such a night could never match, so the sampling loop never terminated.
    """
    new_time = new_timestamp.time()

    # Check if the timestamp falls within the night period
    if night_start <= night_end:
        in_night = night_start <= new_time < night_end
    else:
        in_night = new_time >= night_start or new_time < night_end
    if not in_night:
        return False

    # Reject candidates within the time window of any existing timestamp
    for ts in existing_timestamps:
        if abs((new_timestamp - ts).total_seconds()) < time_window.total_seconds():
            return False
    return True

# Sorted start times of all interactions; accepted baseline times are added as we go so that
# baseline windows do not overlap each other either
start_timestamps = combined_df['start_timestamp'].sort_values().reset_index(drop=True)
time_window = pd.Timedelta(minutes=time_window_length)

# Fixed seed so the random baseline (and everything derived from it) is reproducible
baseline_seed = 42
rng = np.random.default_rng(baseline_seed)

# Parse the lights-off bounds once rather than on every draw
night_start_time = datetime.strptime(night_start, '%H:%M').time()
night_end_time = datetime.strptime(night_end, '%H:%M').time()

# Draw as many valid random timestamps as there are isolated events
valid_timestamps = []
min_timestamp = start_timestamps.min()
max_timestamp = start_timestamps.max()
span_seconds = int((max_timestamp - min_timestamp).total_seconds())

# Bound the rejection sampling so an impossible constraint fails loudly instead of looping forever
max_attempts = 1000 * max(len(good_events_df), 1)
attempts = 0
while len(valid_timestamps) < len(good_events_df):
    if attempts >= max_attempts:
        raise RuntimeError(
            f'Found only {len(valid_timestamps)} of {len(good_events_df)} baseline timestamps after '
            f'{max_attempts} draws; check night_start/night_end and time_window_length.'
        )
    attempts += 1
    random_timestamp = min_timestamp + timedelta(seconds=int(rng.integers(0, span_seconds)))
    if is_valid_timestamp(random_timestamp, start_timestamps, time_window, night_start_time, night_end_time):
        valid_timestamps.append(random_timestamp)
        start_timestamps = pd.concat([start_timestamps, pd.Series([random_timestamp])]).sort_values().reset_index(drop=True)

# Convert the list of valid timestamps to a DataFrame
valid_timestamps_df = pd.DataFrame(valid_timestamps, columns=['random_timestamp'])

In [None]:
# Combine post-event windows (event=True) and baseline windows (event=False) into one table.
good_events_start_timestamps = good_events_df['start_timestamp']
good_events_behaviour_type = good_events_df['behavior_type']
good_events_dominant_id = good_events_df['dominant_id']

# Isolated interactions keep their behaviour type and dominant animal
good_events_df_with_event = pd.DataFrame(
    {
        'start_timestamps': good_events_start_timestamps,
        'event': True,
        'behaviour_type': good_events_behaviour_type,
        'dominant_id': good_events_dominant_id,
    }
)

# Random baseline timestamps carry no interaction information
valid_timestamps_df_with_event = pd.DataFrame(
    {'start_timestamps': valid_timestamps_df['random_timestamp'], 'event': False}
)

# Chronological table of window starts, each window lasting time_window_length minutes
window_length = timedelta(minutes=time_window_length)
all_timestamps_df = (
    pd.concat([good_events_df_with_event, valid_timestamps_df_with_event])
    .sort_values(by='start_timestamps')
    .reset_index(drop=True)
)
all_timestamps_df['end_timestamps'] = all_timestamps_df['start_timestamps'] + window_length

all_timestamps_df

In [None]:
# Count each subject's pellets inside every post-event / baseline window.
results = []
all_subject_ids = pellet_df['subject_id'].unique()

for _, window_row in all_timestamps_df.iterrows():
    # Pellets delivered within [start, end], both ends inclusive
    in_window = pellet_df['time'].between(window_row['start_timestamps'], window_row['end_timestamps'])
    pellets_by_subject = pellet_df.loc[in_window, 'subject_id'].value_counts()

    # One record per subject carrying every window column; subjects without pellets get 0
    for subject_id in all_subject_ids:
        record = window_row.to_dict()
        record['subject_id'] = subject_id
        record['pellet_number'] = pellets_by_subject.get(subject_id, 0)
        results.append(record)

event_pellets_df = pd.DataFrame(results)

# Baseline windows have no behaviour type; label them explicitly
event_pellets_df['behaviour_type'] = event_pellets_df['behaviour_type'].fillna('no_interaction')

event_pellets_df

In [None]:
# Only keep windows (interactions and baseline samples) in which every subject had a pellet
grouped = event_pellets_df.groupby('start_timestamps')


def filter_group(group):
    """True when every subject's pellet count in the window is non-null and non-zero."""
    return group['pellet_number'].notna().all() and (group['pellet_number'] != 0).all()


# groupby.filter keeps the rows of passing groups (original order and index) and returns an
# empty frame when nothing passes, where pd.concat([]) would raise ValueError
filtered_df = grouped.filter(filter_group)


In [None]:
# x-axis order: the no-interaction control first, then each social interaction type
category_order = ['no_interaction'] + [x for x in filtered_df['behaviour_type'].unique() if x != 'no_interaction']

# Both figures exclude the 'nan' subject id (presumably pellets not attributed to a
# tracked subject) so they show the same subjects.
pellets_plot_df = filtered_df[filtered_df['subject_id'] != 'nan']
# Windows are time_window_length minutes long (see the end_timestamps computation)
pellet_axis_label = f'Pellets {time_window_length}min post event'

# Box plot of pellets per subject, split by interaction type
fig = px.box(
    pellets_plot_df,
    x='behaviour_type',
    y='pellet_number',
    color='subject_id',
    title='Pellets post events (if at least 1 pellet)',
    labels={'subject_id': 'Subject ID', 'pellet_number': pellet_axis_label, 'behaviour_type': 'Social interaction type'},
    category_orders={'behaviour_type': category_order},
    color_discrete_map=id_color_map,
    points='all',
)

fig.show()


# Same data collapsed onto the `event` flag (False for the random control windows)
fig = px.box(
    pellets_plot_df,
    x='event',
    y='pellet_number',
    color='subject_id',
    title='Pellets after social interaction vs no interaction',
    labels={'subject_id': 'Subject ID', 'pellet_number': pellet_axis_label, 'event': 'Post social interaction'},
    color_discrete_map=id_color_map,
    points='all',
)

fig.show()




## 4.2 Outcome of interactions and foraging

In [None]:

def _pellets_after_events(events_df, pellet_df, unique_ids, time_window):
    """Count pellets eaten by the dominant and by the subordinate subject after each event.

    Parameters
    ----------
    events_df : pd.DataFrame
        One row per scored event, with `start_timestamp` and `dominant_id` columns
        (e.g. `chasing_df`, `tube_test_df`).
    pellet_df : pd.DataFrame
        Pellet events with `time` and `subject_id` columns.
    unique_ids : iterable of str
        Ids of the tracked subjects. Assumed to be a pair (TODO confirm): with
        more ids the subordinate is ambiguous and such events are skipped.
    time_window : int or float
        Window length in minutes, starting at each event's start_timestamp.

    Returns
    -------
    pd.DataFrame
        Columns start_timestamps, end_timestamps, outcome ('dominant' /
        'subordinate'), subject_id and pellet_number, with two rows per kept event.
        Only events after which both subjects ate at least one pellet are kept.
        If nothing qualifies, the frame is empty but still has these columns.

    Events whose dominant_id is not one of `unique_ids` (e.g. a missing value,
    which `str()` turns into 'nan') have no defined subordinate and are skipped.
    They are not given an arbitrary "subordinate" subject.
    """
    columns = ['start_timestamps', 'end_timestamps', 'outcome', 'subject_id', 'pellet_number']
    subject_ids = set(unique_ids)
    rows = []
    for _, event in events_df.iterrows():
        dominant_id = str(event['dominant_id'])
        others = subject_ids - {dominant_id}
        if dominant_id not in subject_ids or len(others) != 1:
            continue
        (subordinate_id,) = others

        start_time = event['start_timestamp']
        end_time = start_time + pd.Timedelta(minutes=time_window)
        in_window = (pellet_df['time'] >= start_time) & (pellet_df['time'] <= end_time)
        total_pellets = pellet_df.loc[in_window, 'subject_id'].value_counts()

        for subject_id, outcome in ((dominant_id, 'dominant'), (subordinate_id, 'subordinate')):
            rows.append({
                'start_timestamps': start_time,
                'end_timestamps': end_time,
                'outcome': outcome,
                'subject_id': subject_id,
                'pellet_number': total_pellets.get(subject_id, 0),
            })

    # Explicit columns so groupby below works even when no event was kept
    outcome_df = pd.DataFrame(rows, columns=columns)
    # Keep only events after which both subjects ate at least one pellet
    return outcome_df.groupby('start_timestamps').filter(
        lambda group: group['pellet_number'].notna().all() and (group['pellet_number'] != 0).all()
    )


def _plot_pellets_by_outcome(outcome_df, event_name, time_window):
    """Box plot of post-event pellets per subject, coloured by dominant/subordinate outcome."""
    fig = px.box(
        outcome_df[outcome_df['subject_id'] != 'nan'],
        x='subject_id',
        y='pellet_number',
        color='outcome',
        title=f'Pellets post {event_name} based on outcome (all events)',
        labels={'subject_id': 'Subject ID', 'pellet_number': f'Pellets {time_window}min post event', 'outcome': 'Outcome of event'},
        points='all',
    )
    fig.show()


# Pellets after chases, per subject, split by whether it was the chaser (dominant)
results_df = _pellets_after_events(chasing_df, pellet_df, unique_ids, time_window)
_plot_pellets_by_outcome(results_df, 'chase', time_window)

# Pellets after tube tests, per subject, split by whether it won (dominant)
results_df = _pellets_after_events(tube_test_df, pellet_df, unique_ids, time_window)
_plot_pellets_by_outcome(results_df, 'tube test', time_window)

In [None]:
# Label each (window, subject) row relative to the event's dominant_id:
# 'dominant' if the subject is the dominant animal, 'subordinate' if it is another
# tracked subject, None otherwise (e.g. the no_interaction control windows, which
# have no dominant_id). Vectorized, so it also works on an empty frame.
in_pair = (
    event_pellets_df['dominant_id'].isin(unique_ids)
    & event_pellets_df['subject_id'].isin(unique_ids)
)
is_dominant = event_pellets_df['dominant_id'] == event_pellets_df['subject_id']
outcome = pd.Series(None, index=event_pellets_df.index, dtype='object')
outcome.loc[in_pair & is_dominant] = 'dominant'
outcome.loc[in_pair & ~is_dominant] = 'subordinate'
event_pellets_df['outcome'] = outcome


def _plot_window_pellets_by_outcome(outcome_df, title):
    """Box plot of per-subject pellet counts in each event window, coloured by outcome."""
    fig = px.box(
        outcome_df,
        x='subject_id',
        y='pellet_number',
        color='outcome',
        title=title,
        # Windows are time_window_length minutes long (see end_timestamps)
        labels={'subject_id': 'Subject ID', 'pellet_number': f'Pellets {time_window_length}min post event'},
        points='all',
    )
    fig.show()


filtered_chasing_df = event_pellets_df[(event_pellets_df['behaviour_type'] == 'chasing') & in_pair]
_plot_window_pellets_by_outcome(
    filtered_chasing_df,
    'Pellets post chasing based on outcome (only if no events in time window)',
)

filtered_tube_test_df = event_pellets_df[(event_pellets_df['behaviour_type'] == 'tube_test') & in_pair]
_plot_window_pellets_by_outcome(
    filtered_tube_test_df,
    'Pellets post tube test based on outcome (only if no events in time window)',
)