This notebook contains analyses correlating social interaction events and their outcomes, e.g. tube tests, chasing, fighting. 
For this we will need:
- CSVs exported from event detection analyses, containing start and end timestamps/frames and the ID of the dominant/winner mouse. 
- DJ BlockAnalysis pipeline, which has tools for computing patch preference over time
- external tube test data

In [None]:
# Social02 exp timeline

# Aeon3
#  2024-01-31 : 2024-02-03 - BAA-1104045 pre solo
#  2024-02-05 : 2024-02-08 - BAA-1104047 pre solo (dominant)
#  2024-02-09 : 2024-02-23 - BAA-1104045, BAA-1104047 social
#  2024-02-25 : 2024-02-28 - BAA-1104045 post solo
#  2024-02-28 : 2024-03-02 - BAA-1104047 post solo


# Aeon4
#  2024-01-31 : 2024-02-03 - BAA-1104048 pre solo (dominant)
#  2024-02-05 : 2024-02-08 - BAA-1104049 pre solo
#  2024-02-09 : 2024-02-23 - BAA-1104048, BAA-1104049 social
#  2024-02-25 : 2024-02-28 - BAA-1104048 post solo
#  2024-02-28 : 2024-03-02 - BAA-1104049 post solo

In [None]:

import os
import psutil

# Report the resident set size (RSS) of the kernel process running this notebook.
pid = os.getpid()
process = psutil.Process(pid)
memory_info = process.memory_info()
rss_megabytes = memory_info.rss / 1024 / 1024
print(f"RAM used by the notebook process: {rss_megabytes:.2f} MB")
 

# 0. Import and process data

In [None]:
%load_ext autoreload
%autoreload 2

import datajoint as dj
import aeon
from aeon.schema.schemas import social02
from aeon.dj_pipeline.analysis.block_analysis import * #this connects to database and imports all tables in block_analysis

import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import numpy as np
from plotly.subplots import make_subplots
from datetime import timedelta, datetime
from aeon.io import api
from scipy import stats
from shapely.geometry import Point, Polygon

We will get all data from one experiment and arena - so 2 mice, 2 weeks of data. 

Get social interaction times and outcomes from heuristics-based detection. 

In [None]:
# TODO: replace with revised csv path
# Folder holding the exported event CSVs; the three relative paths below are
# appended to it (plain string concatenation) when the files are read.
base_path = '/ceph/aeon/aeon/code/scratchpad/Orsi/'
tube_test_path = 'all_tube_test_videos/AEON3_tube_tests_revised_final.csv'
fights_path = 'all_fighting_videos/AEON3_fights.csv'
chasing_path = 'all_chasing_videos/AEON3_chases.csv'

In [None]:
# Load the three event tables (chases, fights, tube tests) exported by the
# event-detection analyses.
chasing_df = pd.read_csv(f"{base_path}{chasing_path}")
fights_df = pd.read_csv(f"{base_path}{fights_path}")
tube_test_df = pd.read_csv(f"{base_path}{tube_test_path}")

In [None]:
# Parse the start/end timestamp strings of every event table into datetimes.
_event_timestamp_format = '%Y-%m-%dT%H-%M-%S'
for _events in (chasing_df, fights_df, tube_test_df):
    for _column in ('start_timestamp', 'end_timestamp'):
        _events[_column] = pd.to_datetime(_events[_column], format=_event_timestamp_format)

In [None]:
# Preview the first rows of the chasing table.
chasing_df.head()

In [None]:
# Assess the quality of the chasing data: what share of chases has no chaser assigned?
n_chases = len(chasing_df)
n_unassigned = len(chasing_df[chasing_df['chaser_id'].isna()])
print(f"{n_unassigned / n_chases} ratio of chases with non-assigned chaser_ids")
print(n_chases)

# Pick row positions whose chaser assignment is checked manually.
# `random.randint(0, n)` includes n itself, which is not a valid row position, and
# may repeat values; sampling from range(n) gives distinct, valid positions.
# A fixed seed keeps the list stable, so the rows that were reviewed stay traceable.
import random
n_to_check = min(30, n_chases)  # around 5% of the data
random_numbers = random.Random(0).sample(range(n_chases), n_to_check)
sorted_random_numbers = sorted(random_numbers)
print(sorted_random_numbers)

# After revising a few: load the manually reviewed copy of the chasing table.
revised_chasing_path = 'all_chasing_videos/AEON3_chases_revised.csv'
revised_chasing_df = pd.read_csv(base_path + revised_chasing_path)
_revision = revised_chasing_df['revision']

# Events checked manually (any non-empty 'revision' entry)
print(int(_revision.notna().sum()))
# Checked chases whose assignment was fine
good_count = int((_revision == 'Ok').sum())
print(good_count)
# Checked chases whose assignment was wrong.
# NOTE(review): '47' / '45' presumably are the last digits of the corrected chaser ID - confirm.
_wrong_counts = {label: int((_revision == label).sum()) for label in ('47', '45')}
print(_wrong_counts['47'])
print(_wrong_counts['45'])

# Error rate = wrong / (good + wrong). The parentheses matter: without them the
# expression evaluates as (wrong / good) + wrong.
error_count = sum(_wrong_counts.values())
_n_verdicts = good_count + error_count
error_rate = error_count / _n_verdicts if _n_verdicts else float('nan')
print(error_rate)

In [None]:
# Preview the first rows of the tube test table.
tube_test_df.head()

In [None]:
# Preview the first rows of the fights table.
fights_df.head()

Bring the CSV data for the three behaviours into the same format.

In [None]:
# Tag every event table with the behaviour it describes.
for _events, _behavior in (
    (chasing_df, 'chasing'),
    (fights_df, 'fighting'),
    (tube_test_df, 'tube_test'),
):
    _events['behavior_type'] = _behavior

In [None]:
# Give all three tables a common 'dominant_id' column: the chaser for chases,
# the winner for tube tests, and the placeholder string 'NaN' for fights when
# the fights table has no such column.
chasing_df = chasing_df.rename(columns={'chaser_id': 'dominant_id'})
tube_test_df = tube_test_df.rename(columns={'winner_id': 'dominant_id'})
if 'dominant_id' not in fights_df.columns:
    fights_df['dominant_id'] = 'NaN'


In [None]:
# Display the fights table to check the 'behavior_type' and 'dominant_id' columns.
fights_df

In [None]:
# Stack the three event tables and mark events without an identified dominant
# animal with the string 'NaN' (used as a key in the colour map).
combined_df = pd.concat([chasing_df, fights_df, tube_test_df]).assign(
    dominant_id=lambda events: events['dominant_id'].fillna('NaN')
)
combined_df.head()

In [None]:
# Sorted IDs of the animals that appear as dominant, without the 'NaN' placeholder.
_all_dominant_ids = np.unique(combined_df['dominant_id'])
unique_ids = _all_dominant_ids[_all_dominant_ids != 'NaN']
unique_ids

Get foraging data from DJ.

In [None]:
# Restriction applied to the DataJoint queries below: only rows of this experiment.
key = {"experiment_name": "social0.2-aeon3"}

In [None]:
# Handle to the block_analysis schema.
# NOTE(review): `get_schema_name` is not imported by name in this notebook; presumably
# it arrives via the star import from block_analysis - confirm.
schema = dj.schema(get_schema_name("block_analysis"))

In [None]:
# Bounds used to restrict `block_start` when querying blocks of the social period.
social_start_day = '2024-02-09 00:00:00'
social_end_day = '2024-02-23 23:00:00'

In [None]:
# Start/end of the social period; used as x-axis limits and as the edges of the
# 24 h time bins in the plots below.
social_start_time = '2024-02-09 16:00:00'
social_end_time = '2024-02-23 13:00:00'

In [None]:
# Bounds used to restrict `block_start` when querying the solo periods before
# and after the social period.
pre_solo_start_day = '2024-01-31 00:00:00'
pre_solo_end_day= '2024-02-08 23:00:00'
post_solo_start_day = '2024-02-25 00:00:00'
post_solo_end_day = '2024-03-02 23:00:00'

In [None]:
# Define light cycle periods
# Clock times as 'HH:MM' strings; they are split on ':' when the light-cycle bar
# is drawn under the raster plots. `day_end` is earlier than `day_start`, i.e. the
# "day" segment wraps past midnight.
# NOTE(review): night during 08:00-19:00 looks like a reversed light cycle - confirm.
night_start = '08:00'
night_end = '19:00'
twilight_start = '07:00'
twilight_end = '08:00'
dawn_start = '19:00'
dawn_end = '20:00'
day_start = '20:00'
day_end = '07:00'

Load metadata:


In [None]:
# Take the experiment metadata from the first Metadata record found between the
# experiment start and 2024-02-11 15:57:42.
exp_start = pd.Timestamp("2024-01-31")

_metadata_records = api.load(
    '/ceph/aeon/aeon/data/raw/AEON3/social0.2',
    social02.Metadata,
    exp_start,
    pd.Timestamp('2024-02-11 15:57:42'),
)
metadata = _metadata_records.iloc[0].metadata
metadata

Add pre/post tube test data manually:

In [None]:
# Add pre/post tube test data manually.
# Each table has one row per tube test trial: the behaviour label and the ID of
# the winning (dominant) animal.

# Tube tests run before the social period: 10 trials, 2 won by BAA-1104045 and 8 by BAA-1104047
pre_tube_test = pd.DataFrame({
    'behavior_type': ['pre_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 2 + ['BAA-1104047'] * 8,
})

# Tube tests run after the social period: 10 trials, 1 won by BAA-1104045 and 9 by BAA-1104047.
# The label is 'post_tube_test' (it used to be a copy-pasted 'pre_tube_test').
post_tube_test = pd.DataFrame({
    'behavior_type': ['post_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 1 + ['BAA-1104047'] * 9,
})



# 1. Temporal pattern and stability of dominance interactions

Raster of chasing per subject over time:

In [None]:

# Create a dictionary for color mapping
# id_color_map: colour per animal ID; the 'NaN' key is the placeholder string used
# for events without an identified dominant animal.
id_color_map = {
    'NaN': 'grey',
    'BAA-1104047': 'purple',
    'BAA-1104045': 'green'
} 
# behaviour_map: colour per behaviour type, keyed by the 'behavior_type' labels.
behaviour_map = {
    'chasing': 'blue',
    'fighting': 'red',
    'tube_test': 'orange'
} 


In [None]:
# Parse the period boundary strings into datetime objects for day-by-day iteration.
_period_format = '%Y-%m-%d %H:%M:%S'

social_start_datetime = datetime.strptime(social_start_time, _period_format)
social_end_datetime = datetime.strptime(social_end_time, _period_format)

pre_solo_start_datetime = datetime.strptime(pre_solo_start_day, _period_format)
pre_solo_end_datetime = datetime.strptime(pre_solo_end_day, _period_format)
post_solo_start_datetime = datetime.strptime(post_solo_start_day, _period_format)
post_solo_end_datetime = datetime.strptime(post_solo_end_day, _period_format)

In [None]:

# Create the raster plot: one marker per event at its start time, one row per
# behaviour type, coloured by the dominant animal.
fig = px.scatter(
    combined_df,
    x='start_timestamp',
    y='behavior_type',
    color='dominant_id',
    title='Behavior Raster Plot',
    labels={'start_timestamp': 'Time', 'behavior_type': 'Behavior Type'},
    color_discrete_map=id_color_map
)

# Set x-axis limits to the social period
fig.update_xaxes(range=[social_start_time, social_end_time])
# Set y-axis limits; the lower bound leaves room for the light-cycle bar drawn at y = -0.5
fig.update_yaxes(range=[-1, 2.5])


def _clock_on_day(day, clock):
    """Return `day` with its hour and minute replaced by the 'HH:MM' string `clock`."""
    hour, minute = (int(part) for part in clock.split(':'))
    return day.replace(hour=hour, minute=minute)


# (start clock, end clock, line colour) of each light-cycle segment, in drawing order
_light_cycle_segments = (
    (twilight_start, twilight_end, "gray"),
    (night_start, night_end, "black"),
    (dawn_start, dawn_end, "gray"),
    (day_start, day_end, "white"),
)

# Iterate over every calendar day touched by the social period. Starting from
# midnight of the first day and including the last day makes sure the final
# (partial) day also gets its light-cycle bar.
current_day = social_start_datetime.replace(hour=0, minute=0, second=0, microsecond=0)
while current_day <= social_end_datetime:
    for _segment_start, _segment_end, _segment_color in _light_cycle_segments:
        segment_x0 = _clock_on_day(current_day, _segment_start)
        segment_x1 = _clock_on_day(current_day, _segment_end)
        # A segment whose end clock is not after its start clock wraps past
        # midnight: it ends at that clock time on the following day.
        if segment_x1 <= segment_x0:
            segment_x1 += timedelta(days=1)
        # Horizontal line marking this part of the light cycle
        fig.add_shape(
            type="line",
            x0=segment_x0,
            x1=segment_x1,
            y0=-0.5,
            y1=-0.5,
            line=dict(color=_segment_color, width=4)
        )
    # Move to the next day
    current_day += timedelta(days=1)

# Show the plot
fig.show()

In [None]:
# TODO: could also do this based not on discrete events but on the duration of time spent interacting
# Edges of consecutive 24 h bins, starting at the beginning of the social period
time_bins = pd.date_range(start=social_start_time, end=social_end_time, freq='24h')

# Assign every event to the bin containing its start time
combined_df['time_bin'] = pd.cut(combined_df['start_timestamp'], bins=time_bins)

# Number of events per (time bin, behaviour type)
_events_per_bin = combined_df.groupby(['time_bin', 'behavior_type']).size()
interaction_counts = _events_per_bin.reset_index(name='total_interactions')

# Label every bin by its left edge, formatted as a string for plotting
_bin_left_edges = interaction_counts['time_bin'].apply(lambda interval: interval.left)
interaction_counts['time_bin_start'] = _bin_left_edges
interaction_counts['time_bin_start'] = interaction_counts['time_bin_start'].dt.strftime('%Y-%m-%d %H:%M')

print(interaction_counts)


In [None]:


# Events per 24 h bin over time, one line per behaviour type
_interaction_plot_options = dict(
    x='time_bin_start',
    y='total_interactions',
    color='behavior_type',
    title='Number of Interactions Over Time for Each Behavior',
    labels={'time_bin_start': 'Time Bin', 'total_interactions': 'Number of Interactions'},
    color_discrete_map=behaviour_map,
)
fig = px.line(interaction_counts, **_interaction_plot_options)

fig.show()

Proportion of winning and losing in two animals:

In [None]:
def _count_wins(events):
    """Return a (dominant_id, count) table with the number of events won per animal."""
    counts = events['dominant_id'].value_counts().reset_index()
    counts.columns = ['dominant_id', 'count']
    return counts


def _win_count_scatter(counts, title, count_label):
    """Scatter of win counts per animal, coloured with the shared ID colour map."""
    return px.scatter(
        counts,
        x='dominant_id',
        y='count',
        color='dominant_id',
        title=title,
        labels={'dominant_id': 'Chaser ID', 'count': count_label},
        color_discrete_map=id_color_map
    )


# Wins per animal for chases, in-arena tube tests, and the pre/post tube tests
chaser_counts_chasing = _count_wins(chasing_df)
chaser_counts_tube_test = _count_wins(tube_test_df)
chaser_counts_pre_tube_test = _count_wins(pre_tube_test)
chaser_counts_post_tube_test = _count_wins(post_tube_test)

# One plotly express figure per data set; only their traces are reused below
fig_chasing = _win_count_scatter(
    chaser_counts_chasing, 'Count of Chases per Chaser ID', 'Count of Chases'
)
fig_tube_test = _win_count_scatter(
    chaser_counts_tube_test, 'Count of Tube Tests per Chaser ID', 'Count of Tube Tests'
)
fig_pre_tube_test = _win_count_scatter(
    chaser_counts_pre_tube_test, 'Count of Tube Tests per Chaser ID', 'Count of Tube Tests'
)
fig_post_tube_test = _win_count_scatter(
    chaser_counts_post_tube_test, 'Count of Tube Tests per Chaser ID', 'Count of Tube Tests'
)

# Four panels side by side, in the same order as the subplot titles
fig = make_subplots(rows=1, cols=4, subplot_titles=('Chasing', 'Tube Test', 'Pre Tube Test', 'Post Tube Test'))
_panels = (
    (fig_chasing, chaser_counts_chasing),
    (fig_tube_test, chaser_counts_tube_test),
    (fig_pre_tube_test, chaser_counts_pre_tube_test),
    (fig_post_tube_test, chaser_counts_post_tube_test),
)

# Copy the traces of every express figure into its panel
for _panel_col, (_panel_fig, _) in enumerate(_panels, start=1):
    for trace in _panel_fig['data']:
        fig.add_trace(trace, row=1, col=_panel_col)

# Each panel's y-axis runs from 0 to 10 above its largest count
for _panel_col, (_, _panel_counts) in enumerate(_panels, start=1):
    fig.update_yaxes(range=[0, _panel_counts['count'].max() + 10], row=1, col=_panel_col)

fig.update_layout(title_text='Count of Chases and Tube Tests per Chaser ID')

fig.show()

In [None]:
def _win_fractions(events, behavior_label):
    """Return the fraction of events won per animal, tagged with `behavior_label`."""
    fractions = events['dominant_id'].value_counts(normalize=True).reset_index()
    fractions.columns = ['dominant_id', 'p_wins']
    fractions['behavior_type'] = behavior_label
    return fractions


# Fraction of wins per animal for each data set
chaser_counts_chasing = _win_fractions(chasing_df, 'Chasing')
chaser_counts_tube_test = _win_fractions(tube_test_df, 'Tube Test')
chaser_counts_pre_tube_test = _win_fractions(pre_tube_test, 'Pre Tube Test')
chaser_counts_post_tube_test = _win_fractions(post_tube_test, 'Post Tube Test')

# Stack the four tables into one long table for plotting
combined_counts = pd.concat([
    chaser_counts_chasing,
    chaser_counts_tube_test,
    chaser_counts_pre_tube_test,
    chaser_counts_post_tube_test,
])

# One marker per (behaviour, animal): proportion of events won
_fraction_plot_options = dict(
    x='behavior_type',
    y='p_wins',
    color='dominant_id',
    title='Dominance per Behavior Type and Dominant ID',
    labels={'behavior_type': 'Behavior Type', 'p_wins': 'Proportion of wins'},
    color_discrete_map=id_color_map,
)
fig = px.scatter(combined_counts, **_fraction_plot_options)

fig.show()

Time binned measure of dominance:

In [None]:
# For every 24 h time bin, count the chases led by each animal and express them
# as a fraction of all chases in that bin; plot both over time.
time_bins = pd.date_range(start=social_start_time, end=social_end_time, freq='24h')
# Assign every chase to the bin containing its start time
chasing_df['time_bin'] = pd.cut(chasing_df['start_timestamp'], bins=time_bins)

# Number of chases per (time bin, chaser)
winning_counts = (
    chasing_df
    .groupby(['time_bin', 'dominant_id'])
    .size()
    .reset_index(name='total_events')
)
# Share of each bin's chases that was led by each animal
_chases_per_bin = winning_counts.groupby('time_bin')['total_events']
winning_counts['fraction_winning'] = _chases_per_bin.transform(lambda counts: counts / counts.sum())

# Use the interval labels as plain strings on the x-axis
winning_counts['time_bin'] = winning_counts['time_bin'].astype(str)

# First the raw counts, then the per-bin fractions
for _y_column, _plot_title, _y_label in (
    ('total_events', 'Winning Events Over Time (Chasing)', 'Winning Events'),
    ('fraction_winning', 'Fraction of Winning Events Over Time (Chasing)', 'Proportion of Winning Events'),
):
    fig = px.line(
        winning_counts,
        x='time_bin',
        y=_y_column,
        color='dominant_id',
        title=_plot_title,
        labels={'time_bin': 'Time Bin', _y_column: _y_label},
        color_discrete_map=id_color_map
    )
    fig.show()

In [None]:
# For every 24 h time bin, count the tube tests won by each animal and express
# them as a fraction of all tube tests in that bin; plot both over time.
time_bins = pd.date_range(start=social_start_time, end=social_end_time, freq='24h')
# Assign every tube test to the bin containing its start time
tube_test_df['time_bin'] = pd.cut(tube_test_df['start_timestamp'], bins=time_bins)

# Number of tube tests per (time bin, winner)
winning_counts = (
    tube_test_df
    .groupby(['time_bin', 'dominant_id'])
    .size()
    .reset_index(name='total_events')
)
# Share of each bin's tube tests that was won by each animal
_tests_per_bin = winning_counts.groupby('time_bin')['total_events']
winning_counts['fraction_winning'] = _tests_per_bin.transform(lambda counts: counts / counts.sum())

# Use the interval labels as plain strings on the x-axis
winning_counts['time_bin'] = winning_counts['time_bin'].astype(str)

# First the raw counts, then the per-bin fractions
for _y_column, _plot_title, _y_label in (
    ('total_events', 'Winning Events Over Time (Tube test)', 'Winning Events'),
    ('fraction_winning', 'Fraction of Winning Events Over Time (Tube test)', 'Proportion of Winning Events'),
):
    fig = px.line(
        winning_counts,
        x='time_bin',
        y=_y_column,
        color='dominant_id',
        title=_plot_title,
        labels={'time_bin': 'Time Bin', _y_column: _y_label},
        color_discrete_map=id_color_map
    )
    fig.show()


Chasing speed and dominance:

In [None]:
# Restriction to a single block by its start time. In the cells visible here it is
# only referenced from a commented-out line of the position query, i.e. not applied.
block_restriction = {'block_start':'2024-02-10 16:08:15.027999'}

In [None]:
# Speed of chasing and dominance - some chasing events are slower, so possibly a
# different behaviour. Maybe the subordinate is the chaser more often in these slow chases?

# Centroid tracks (x, y, timestamps) per subject and block ...
_subject_tracks = BlockAnalysis.Subject.proj('position_x', 'position_y', 'position_timestamps')
# ... restricted to this experiment and to blocks starting within the social period.
# (Optionally add `& block_restriction` to look at a single block.)
block_subjects = (
    _subject_tracks
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)
block_subjects


In [None]:
# Fetch the restricted position query; each element is later accessed by attribute
# name (e.g. s['position_timestamps']), one element per subject and block.
block_subjects_dict = block_subjects.fetch(as_dict=True)

In [None]:
# Convert chasing_df timestamps to datetime
chasing_df['start_timestamp'] = pd.to_datetime(chasing_df['start_timestamp'])
chasing_df['end_timestamp'] = pd.to_datetime(chasing_df['end_timestamp'])

_position_columns = [
    "subject_name", "position_timestamps", "position_x", "position_y", "start_time", "end_time"
]

# Convert every block's arrays once, up front. Doing the timestamp conversion inside
# the chase loop repeated it (number of chases) times for each block.
# Each entry: (subject, timestamps, x, y, first timestamp, last timestamp).
_block_tracks = []
for s in block_subjects_dict:
    _track_times = pd.to_datetime(s['position_timestamps'])
    if len(_track_times) == 0:
        continue  # a block without tracking data cannot contribute positions
    _block_tracks.append((
        s["subject_name"],
        _track_times,
        np.asarray(s["position_x"]),
        np.asarray(s["position_y"]),
        _track_times.min(),
        _track_times.max(),
    ))

# One DataFrame per (chase, subject block) that actually has samples inside the chase
filtered_positions = []

# Loop through each chase so we only keep positions that fall within a chase
for start_time, end_time in zip(chasing_df['start_timestamp'], chasing_df['end_timestamp']):
    for _subject_name, _track_times, _track_x, _track_y, _first_time, _last_time in _block_tracks:
        # Cheap rejection of blocks that do not overlap the chase at all
        if _last_time < start_time or _first_time > end_time:
            continue

        # Samples recorded between the start and the end of the chase (inclusive)
        mask = (_track_times >= start_time) & (_track_times <= end_time)
        _n_samples = int(mask.sum())
        if _n_samples == 0:
            continue  # skip instead of appending an empty frame

        filtered_positions.append(pd.DataFrame(
            {
                "subject_name": [_subject_name] * _n_samples,
                "position_timestamps": _track_times[mask],
                "position_x": _track_x[mask],
                "position_y": _track_y[mask],
                "start_time": [start_time] * _n_samples,
                "end_time": [end_time] * _n_samples,
            }
        ))

# Concatenate the pieces into a single DataFrame; fall back to an empty table with
# the expected columns when no position fell inside any chase
if filtered_positions:
    subjects_positions_df = pd.concat(filtered_positions)
else:
    subjects_positions_df = pd.DataFrame(columns=_position_columns)

In [None]:
# Instantaneous speed of each subject during each chase: distance between
# consecutive position samples divided by the time between them.
# The differences are taken within (subject, chase). Grouping by subject alone
# would also difference the first sample of a chase against the last sample of
# the previous chase, which adds a spurious near-zero speed to every chase.
_positions = subjects_positions_df.reset_index(drop=True)  # unique row labels for alignment
_per_chase_track = _positions.groupby(["subject_name", "start_time"], sort=False)
_step_distance = np.hypot(
    _per_chase_track["position_x"].diff(),
    _per_chase_track["position_y"].diff(),
)
_step_seconds = _per_chase_track["position_timestamps"].diff().dt.total_seconds()
# Rows are in the same order as subjects_positions_df, so assign by position
subjects_positions_df["speed"] = (_step_distance / _step_seconds).to_numpy()

# Get rid of NaNs (first sample of every track), infinite values (zero time step)
# and unrealistically high speeds.
# NOTE(review): cm2px presumably is pixels per centimetre, so the threshold below is
# 100 cm/s expressed in position units per second (not in cm/s) - confirm.
cm2px = 5.4
max_speed_threshold = 100 * cm2px
_speed = subjects_positions_df['speed']
speed_df = subjects_positions_df[np.isfinite(_speed) & (_speed < max_speed_threshold)]
speed_df


In [None]:
# Save the positions with their speed column to CSV. Note that this writes
# subjects_positions_df (all samples), not the filtered speed_df.
subjects_positions_df.to_csv('/ceph/aeon/aeon/code/scratchpad/Orsi/all_chasing_videos/chasing_speed.csv', index=False)

In [None]:
# Step 1: mean speed of every subject within every chase. A chase is identified by
# its start time, carried in the 'start_time' column.
_mean_speed_by_chase_and_subject = speed_df.groupby(['start_time', 'subject_name'])['speed'].mean()
avg_speed_per_subject = _mean_speed_by_chase_and_subject.reset_index().rename(
    columns={'speed': 'avg_speed_per_subject'}
)

# Step 2: average the subjects' mean speeds to get one value per chase. The key is
# renamed to 'start_timestamp' so it matches the chasing table.
avg_speed_per_chase = (
    avg_speed_per_subject
    .groupby('start_time')['avg_speed_per_subject']
    .mean()
    .reset_index()
    .rename(columns={
        'avg_speed_per_subject': 'avg_speed_per_chase',
        'start_time': 'start_timestamp',
    })
)



In [None]:
# Attach the average speed to every chase (matched on the chase start time) and
# keep only the chases for which a speed could be computed.
chasing_speed_df = (
    chasing_df
    .merge(avg_speed_per_chase, on='start_timestamp', how='left')
    .dropna(subset=['avg_speed_per_chase'])
)
chasing_speed_df

In [None]:
# Did we lose a lot of chases? Compare the table shapes before and after adding speed.
for _table in (chasing_df, chasing_speed_df):
    print(_table.shape)

In [None]:
# Chases (with a speed) for which no chaser could be assigned
_chaser_missing = chasing_speed_df['dominant_id'].isna()
nan_count = _chaser_missing.sum()
print(f"Number of NaN values in dominant_id: {nan_count}")

In [None]:
# The two chaser IDs whose chase speeds are compared
chaser1_id = 'BAA-1104045'
chaser2_id = 'BAA-1104047'

# Average chase speeds of the chases led by each of the two animals
chaser1_speeds = chasing_speed_df.loc[chasing_speed_df['dominant_id'] == chaser1_id, 'avg_speed_per_chase']
chaser2_speeds = chasing_speed_df.loc[chasing_speed_df['dominant_id'] == chaser2_id, 'avg_speed_per_chase']

# Welch's t-test (unequal variances) between the two groups
t_stat, p_value_ttest = stats.ttest_ind(chaser1_speeds, chaser2_speeds, equal_var=False)
p_value_ttest


In [None]:
# Plot the average speed of chases per chaser ID: one box (with all points) per chaser.
fig = go.Figure()

# The loop variable is deliberately not called `dominant_id`: that name holds the
# dominant animal's ID in the foraging section and must not be overwritten here.
for chaser_id in chasing_speed_df['dominant_id'].unique():
    if pd.isna(chaser_id):
        # NaN never compares equal to anything (including itself), so chases
        # without an assigned chaser have to be selected with isna()
        chaser_rows = chasing_speed_df['dominant_id'].isna()
        color = 'gray'  # Default color for chases without an assigned chaser
    else:
        chaser_rows = chasing_speed_df['dominant_id'] == chaser_id
        color = id_color_map.get(chaser_id, 'black')  # Fallback to black if the id is not in the map

    chaser_label = str(chaser_id)  # Convert NaN to a string so it displays properly
    chaser_speeds = chasing_speed_df.loc[chaser_rows, 'avg_speed_per_chase']
    fig.add_trace(go.Box(
        x=[chaser_label] * len(chaser_speeds),
        y=chaser_speeds,
        name=chaser_label,
        boxpoints='all',  # Show individual points
        jitter=0.3,  # Add some jitter to avoid overlap
        pointpos=-1.8,  # Negative offset: draw the individual points to the left of the box
        marker=dict(color=color)
    ))

# Update layout
fig.update_layout(
    title='Average Speed of Chases per Chaser ID',
    xaxis_title='Chaser ID',
    yaxis_title='Average Speed',
    boxmode='group',  # Group box plots by chaser ID
    showlegend=False
)

# Annotate the figure with the p-value of the t-test computed in the previous cell
fig.add_annotation(
    x=0.5, y=1.05, xref='paper', yref='paper',
    text=f"T-test p-value: {p_value_ttest:.3f}",
    showarrow=False,
    font=dict(size=12, color='black'),
    align='center'
)

# Show the plot
fig.show()

Takeaway: 
- dominance based on the outcome of these interactions is stable throughout
- the number of fights and chases increases throughout, while tube tests are most frequent on the first day
- both tube test and chasing outcomes are more varied on the first day and stabilise afterwards
- the speed of chases is variable, and when the more dominant mouse is the chaser, chases are faster
- this stability suggests these behaviours are a good measure of dominance

# 2. Dominance and foraging

In [None]:
# IDs of the dominant and the subordinate animal; `dominant_id` decides the
# 'dominant'/'subordinate' rank label given to every pellet below.
dominant_id = 'BAA-1104047' 
subordinate_id = 'BAA-1104045'

Based on interactions we know which one is dominant and this seems to be stable. So does dominance affect foraging behaviour? In solo vs social?

Number of pellets per subject:

In [None]:
# Pellet data per block, patch and subject ...
_patch_pellets = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
# ... restricted to the CameraTop video source, this experiment, and blocks that
# start within the social period.
foraging_query = (
    _patch_pellets
    & {"spinnaker_video_source_name": "CameraTop"}
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)
foraging_query


In [None]:
# SQL conditions selecting blocks that start in the pre-solo or in the post-solo period
_pre_solo_blocks = f'(block_start >= "{pre_solo_start_day}" AND block_start <= "{pre_solo_end_day}")'
_post_solo_blocks = f'(block_start >= "{post_solo_start_day}" AND block_start <= "{post_solo_end_day}")'

# Same pellet query as for the social period, but for the two solo periods
_patch_pellets = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
solo_foraging_query = (
    _patch_pellets
    & {"spinnaker_video_source_name": "CameraTop"}
    & key
    & f'{_pre_solo_blocks} OR {_post_solo_blocks}'
)
solo_foraging_query

In [None]:
# Fetch the pellet data of the social period
pellet_data = foraging_query.fetch()

# Positions of the fields read from every fetched entry
_SUBJECT_FIELD, _TIMESTAMPS_FIELD, _THRESHOLD_FIELD = 3, 5, 6

# One record per pellet: delivery time, subject, patch threshold, and the
# subject's rank (dominant vs subordinate)
data = [
    {
        'time': pellet_timestamp,
        'subject_id': entry[_SUBJECT_FIELD],
        'threshold': threshold,
        'rank': 'dominant' if entry[_SUBJECT_FIELD] == dominant_id else 'subordinate',
    }
    for entry in pellet_data
    for pellet_timestamp, threshold in zip(entry[_TIMESTAMPS_FIELD], entry[_THRESHOLD_FIELD])
]

# Build the table in one go and make sure 'time' is a datetime column
pellet_df = pd.DataFrame(data)
pellet_df['time'] = pd.to_datetime(pellet_df['time'])

print(pellet_df)


In [None]:
# filter out pellets that are less than 2s after previos pellet for teh subject
pellet_df['time_diff'] = pellet_df.groupby('subject_id')['time'].diff()
pellet_df['time_diff'] = pellet_df['time_diff'].dt.total_seconds()
pellet_df = pellet_df[pellet_df['time_diff'] > 2]
pellet_df

In [None]:
# Fetch the pellet data of the solo periods
solo_pellet_data = solo_foraging_query.fetch()

# One record per pellet: delivery time, subject, patch threshold, and the
# subject's rank (dominant vs subordinate).
# Entry fields by position: 3 = subject, 5 = pellet timestamps, 6 = patch thresholds.
data = []
for entry in solo_pellet_data:
    subject_id = entry[3]
    pellet_timestamps = entry[5]
    patch_threshold = entry[6]

    for pellet_timestamp, threshold in zip(pellet_timestamps, patch_threshold):
        data.append({
            'time': pellet_timestamp,
            'subject_id': subject_id,
            'threshold': threshold,
            'rank': 'dominant' if subject_id == dominant_id else 'subordinate',
        })

# Convert the list of records into a DataFrame with a datetime 'time' column
solo_pellet_df = pd.DataFrame(data)
solo_pellet_df['time'] = pd.to_datetime(solo_pellet_df['time'])

# Display the DataFrame
print(solo_pellet_df)

# Filter out pellets that come less than 2 s after the same subject's previous pellet.
# The rows follow the fetch order (block, patch), not time, so sort by time first;
# otherwise "previous row" is not "previous pellet".
solo_pellet_df = solo_pellet_df.sort_values('time', kind='stable').reset_index(drop=True)
solo_pellet_df['time_diff'] = solo_pellet_df.groupby('subject_id')['time'].diff().dt.total_seconds()
# A subject's first pellet has no predecessor (NaN difference) and is kept
solo_pellet_df = solo_pellet_df[solo_pellet_df['time_diff'].isna() | (solo_pellet_df['time_diff'] > 2)]
solo_pellet_df

In [None]:
# Label every pellet with the experimental period it belongs to.
pellet_df = pellet_df.assign(period='social')

# The solo table holds pellets from both solo periods, so the label is decided per
# row. (A single comparison on the first row would give every solo pellet the same
# label.) Solo pellets are either before the social period or after it, so anything
# earlier than the start of the post-solo period is pre-solo.
_post_solo_start = pd.Timestamp(post_solo_start_day)
solo_pellet_df = solo_pellet_df.assign(
    period=np.where(solo_pellet_df['time'] < _post_solo_start, 'pre_solo', 'post_solo')
)

# Social and solo pellets in one table
all_pellet_df = pd.concat([pellet_df, solo_pellet_df])
all_pellet_df

In [None]:

def create_raster_plot(df, period, start_time, end_time, start_datetime, end_datetime, id_color_map):
    """Show a pellet raster for one period, with light-cycle bars drawn underneath.

    Args:
        df: pellet table with 'time', 'rank', 'subject_id' and 'period' columns.
        period: value of df['period'] to plot; also used (capitalised) in the title.
        start_time, end_time: x-axis limits.
        start_datetime, end_datetime: first day (inclusive) and end (exclusive) of the
            day-by-day loop that draws the light-cycle bars.
        id_color_map: mapping subject_id -> colour.

    The light-cycle boundaries are read from the notebook-level "HH:MM" strings
    night_start/night_end, twilight_start/twilight_end, dawn_start/dawn_end and
    day_start/day_end.
    """
    def at_clock(day, clock):
        # Put an "HH:MM" clock string onto the given day (other fields are left as they are)
        parts = clock.split(':')
        return day.replace(hour=int(parts[0]), minute=int(parts[1]))

    raster_fig = px.scatter(
        df[df['period'] == period],
        x='time',
        y='rank',
        color='subject_id',
        title=f'Pellet Raster Plot - {period.capitalize()}',
        labels={'time': 'Time', 'behavior_type': 'Behavior Type'},
        color_discrete_map=id_color_map,
    )
    raster_fig.update_xaxes(range=[start_time, end_time])

    day_cursor = start_datetime
    while day_cursor < end_datetime:
        night_span = (at_clock(day_cursor, night_start), at_clock(day_cursor, night_end))
        twilight_span = (at_clock(day_cursor, twilight_start), at_clock(day_cursor, twilight_end))
        dawn_span = (at_clock(day_cursor, dawn_start), at_clock(day_cursor, dawn_end))

        day_begin = at_clock(day_cursor, day_start)
        if int(day_end.split(':')[0]) < int(day_start.split(':')[0]):
            # Day phase ends on the next calendar day: the bar runs one day past day_cursor
            day_finish = day_cursor + timedelta(days=1)
        else:
            day_finish = at_clock(day_cursor, day_end)

        # Bars are added in the order twilight, night, dawn, day, all at y = -0.5
        segments = (
            (twilight_span, "gray"),
            (night_span, "black"),
            (dawn_span, "gray"),
            ((day_begin, day_finish), "white"),
        )
        for (bar_start, bar_end), bar_color in segments:
            raster_fig.add_shape(
                type="line",
                x0=bar_start,
                x1=bar_end,
                y0=-0.5,
                y1=-0.5,
                line=dict(color=bar_color, width=4),
            )

        day_cursor += timedelta(days=1)

    raster_fig.show()

# Plot settings per experimental phase: x-axis limits plus the datetime span
# over which the light-cycle bars are drawn
periods = {
    'pre_solo': dict(
        start_time=pre_solo_start_day,
        end_time=pre_solo_end_day,
        start_datetime=pre_solo_start_datetime,
        end_datetime=pre_solo_end_datetime,
    ),
    'post_solo': dict(
        start_time=post_solo_start_day,
        end_time=post_solo_end_day,
        start_datetime=post_solo_start_datetime,
        end_datetime=post_solo_end_datetime,
    ),
    'social': dict(
        start_time=social_start_time,
        end_time=social_end_time,
        start_datetime=social_start_datetime,
        end_datetime=social_end_datetime,
    ),
}

# One raster per phase; the dict keys match create_raster_plot's parameter names
for period, params in periods.items():
    create_raster_plot(all_pellet_df, period, id_color_map=id_color_map, **params)

In [None]:
# Width of each time bin
time_bin_size = '1h'

# Time-indexed copy, so pellet_df itself is left untouched
pellet_df_copy = pellet_df.set_index('time')

# Count the pellets per bin and subject (only bins that contain pellets appear here)
binned_pellet_df = (
    pellet_df_copy
    .groupby([pd.Grouper(freq=time_bin_size), 'subject_id'])
    .size()
    .reset_index(name='pellet_count')
)

# Complete grid of bins for the social period. The start is floored to the bin size so
# the grid labels coincide with the bin edges produced by pd.Grouper; with an
# unaligned start the merge below would match nothing and every count would become 0.
time_range = pd.date_range(
    start=pd.Timestamp(social_start_time).floor(time_bin_size),
    end=social_end_time,
    freq=time_bin_size,
)

# Get unique subject IDs
subject_ids = binned_pellet_df['subject_id'].unique()

# All combinations of bin and subject
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
complete_df = multi_index.to_frame(index=False)

# Attach the observed counts; bins without pellets get an explicit 0
complete_df = complete_df.merge(binned_pellet_df, on=['time', 'subject_id'], how='left')
complete_df['pellet_count'] = complete_df['pellet_count'].fillna(0).astype(int)

# Line plot of the binned counts (raw counts, no interpolation)
fig = px.line(
    complete_df,
    x='time',
    y='pellet_count',
    color='subject_id',
    title='Pellets during Social Period (1h bins)',
    labels={'pellet_count': 'Pellet Count', 'time': 'Time', 'subject_id': 'Subject ID'},
    color_discrete_map=id_color_map
)

# Show the plot
fig.show()

In [None]:
def _count_pellets_per_bin(df, start_time, end_time, time_bin_size):
    """Return pellet counts per subject on a complete grid of time bins (0 where no pellets)."""
    # Keep only pellets inside the requested window and index them by time
    period_df = df[(df['time'] >= start_time) & (df['time'] <= end_time)].set_index('time')

    # Count pellets per bin and subject (only non-empty bins appear here)
    binned_df = (
        period_df
        .groupby([pd.Grouper(freq=time_bin_size), 'subject_id'])
        .size()
        .reset_index(name='pellet_count')
    )

    # Floor the grid start to the bin size so its labels coincide with the bin edges
    # produced by pd.Grouper; otherwise the merge matches nothing and all counts are 0.
    time_range = pd.date_range(
        start=pd.Timestamp(start_time).floor(time_bin_size),
        end=end_time,
        freq=time_bin_size,
    )
    subject_ids = binned_df['subject_id'].unique()
    grid = pd.MultiIndex.from_product(
        [time_range, subject_ids], names=['time', 'subject_id']
    ).to_frame(index=False)

    complete_df = grid.merge(binned_df, on=['time', 'subject_id'], how='left')
    complete_df['pellet_count'] = complete_df['pellet_count'].fillna(0).astype(int)

    # Number the bins 1..N separately for each subject
    complete_df['time_bin'] = complete_df.groupby('subject_id').cumcount() + 1
    return complete_df


def create_binned_plot(df, period, start_time, end_time, id_color_map, time_bin_size = '1h'):
    """Plot pellet counts per time bin and subject for one period.

    Args:
        df: pellet table with 'time' and 'subject_id' columns.
        period: period name, used (capitalised) in the plot title only; the rows are
            selected by start_time / end_time, not by a period column.
        start_time, end_time: inclusive time window to plot.
        id_color_map: mapping subject_id -> colour.
        time_bin_size: pandas frequency string for the bin width (default '1h').
    """
    complete_df = _count_pellets_per_bin(df, start_time, end_time, time_bin_size)

    # The x-axis is the bin number, not the wall-clock time
    fig = px.line(
        complete_df,
        x='time_bin',
        y='pellet_count',
        color='subject_id',
        # Report the bin size actually used instead of a fixed "1h"
        title=f'Pellets during {period.capitalize()} Period ({time_bin_size} bins)',
        labels={'pellet_count': 'Pellet Count', 'time_bin': 'Time Bin', 'subject_id': 'Subject ID'},
        color_discrete_map=id_color_map
    )

    # Show the plot
    fig.show()

# Time window of each experimental phase
periods = {
    'pre_solo': dict(start_time=pre_solo_start_day, end_time=pre_solo_end_day),
    'post_solo': dict(start_time=post_solo_start_day, end_time=post_solo_end_day),
    'social': dict(start_time=social_start_time, end_time=social_end_time),
}

# One binned plot per phase (default bin size); the dict keys match the
# start_time / end_time parameters of create_binned_plot
for period, params in periods.items():
    create_binned_plot(all_pellet_df, period, id_color_map=id_color_map, **params)

In [None]:
# Same per-phase pellet-count plots as above, but with 24 h bins
for period, params in periods.items():
    create_binned_plot(
        all_pellet_df, period, params['start_time'], params['end_time'], id_color_map, time_bin_size='24h'
    )

In [None]:
# Inspect the time-indexed copy of pellet_df left behind by the most recently run binning cell
pellet_df_copy

Foraging during the experiment:

In [None]:
# Daily bins
time_bin_size = '24h'

# Time-indexed copy, so pellet_df itself is left untouched
pellet_df_copy = pellet_df.copy().set_index('time')

# Pellets per day and subject (only days on which pellets occurred appear here)
binned_pellet_df = (
    pellet_df_copy
    .groupby([pd.Grouper(freq=time_bin_size), 'subject_id'])
    .size()
    .reset_index(name='pellet_count')
)

# Day grid from the day after the social start to the day before the social end
first_grid_day = pd.to_datetime(social_start_day) + pd.Timedelta(days=1)
last_grid_day = pd.to_datetime(social_end_day) - pd.Timedelta(days=1)
time_range = pd.date_range(start=first_grid_day, end=last_grid_day, freq=time_bin_size)

# Every (day, subject) combination, with 0 where no pellets were counted
subject_ids = binned_pellet_df['subject_id'].unique()
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
complete_df = (
    multi_index.to_frame(index=False)
    .merge(binned_pellet_df, on=['time', 'subject_id'], how='left')
    .fillna(0)
)
complete_df['pellet_count'] = complete_df['pellet_count'].astype(int)

# Line plot of the daily counts
fig = px.line(
    complete_df,
    x='time',
    y='pellet_count',
    color='subject_id',
    title='Pellets during Social Period (24h bins)',
    labels={'pellet_count': 'Pellet Count', 'time': 'Time', 'subject_id': 'Subject ID'},
    color_discrete_map=id_color_map,
)
fig.show()

Pellet threshold distributions:

In [None]:

# Define the time bin size (24 hours)
time_bin_size = '24h'

# Time-indexed copy, so pellet_df itself is left untouched
pellet_df_copy = pellet_df.set_index('time')

# Mean, standard deviation and sample count of the threshold per bin and subject
binned_threshold_df = (
    pellet_df_copy
    .groupby([pd.Grouper(freq=time_bin_size), 'subject_id'])['threshold']
    .agg(['mean', 'std', 'count'])
    .reset_index()
    .rename(columns={'mean': 'average_threshold', 'std': 'std_threshold', 'count': 'n'})
)

# Standard error of the mean (NaN when a bin holds a single pellet, since std is undefined)
binned_threshold_df['sem_threshold'] = binned_threshold_df['std_threshold'] / np.sqrt(binned_threshold_df['n'])

# Day grid from the day after the social start to the day before the social end
time_range = pd.date_range(
    start=pd.to_datetime(social_start_day) + pd.Timedelta(days=1),
    end=pd.to_datetime(social_end_day) - pd.Timedelta(days=1),
    freq=time_bin_size,
)

# Get unique subject IDs
subject_ids = binned_threshold_df['subject_id'].unique()

# All combinations of day and subject
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
complete_df = multi_index.to_frame(index=False)

# Attach the statistics. Days without pellets keep NaN for the mean and SEM so they show
# up as gaps in the line; filling them with 0 would plot a fake "threshold of 0".
complete_df = complete_df.merge(binned_threshold_df, on=['time', 'subject_id'], how='left')
complete_df['n'] = complete_df['n'].fillna(0).astype(int)

# Create a figure
fig = go.Figure()

# One line with SEM error bars per subject
for subject_id in complete_df['subject_id'].unique():
    subject_df = complete_df[complete_df['subject_id'] == subject_id]

    fig.add_trace(go.Scatter(
        x=subject_df['time'],
        y=subject_df['average_threshold'],
        mode='lines+markers',
        name=f'{subject_id} Average Threshold',
        line=dict(color=id_color_map[subject_id]),
        error_y=dict(
            type='data',
            array=subject_df['sem_threshold'],  # Error bar data (standard error of the mean)
            visible=True
        )
    ))

# Update layout
fig.update_layout(
    title='Average Threshold with SEM Error Bars during Social Period (24h bins)',
    xaxis_title='Time',
    yaxis_title='Average Threshold',
    showlegend=True
)

# Show the plot
fig.show()


Foraging during the day:

In [None]:
# Define the time bin size (1 hour)
time_bin_size = '1h'

# Time-indexed copy, so pellet_df itself is left untouched
pellet_df_copy = pellet_df.set_index('time')

# Count the pellets per bin and subject (only bins that contain pellets appear here)
binned_pellet_df = (
    pellet_df_copy
    .groupby([pd.Grouper(freq=time_bin_size), 'subject_id'])
    .size()
    .reset_index(name='pellet_count')
)

# Complete grid of bins for the social period. The start is floored to the bin size so
# the grid labels coincide with the bin edges produced by pd.Grouper; with an
# unaligned start the merge below would match nothing and every count would become 0.
time_range = pd.date_range(
    start=pd.Timestamp(social_start_time).floor(time_bin_size),
    end=social_end_time,
    freq=time_bin_size,
)

# Get unique subject IDs
subject_ids = binned_pellet_df['subject_id'].unique()

# All combinations of bin and subject, with 0 where no pellets were counted
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
complete_df = multi_index.to_frame(index=False)
complete_df = complete_df.merge(binned_pellet_df, on=['time', 'subject_id'], how='left')
complete_df['pellet_count'] = complete_df['pellet_count'].fillna(0).astype(int)

# Split each bin label into calendar date and hour of day
complete_df['hour_of_day'] = complete_df['time'].dt.hour
complete_df['date'] = complete_df['time'].dt.date

# Average number of pellets per hour of day for each mouse
average_pellets_per_hour = complete_df.groupby(['subject_id', 'hour_of_day'])['pellet_count'].mean().reset_index()


# Create the figure
fig = go.Figure()

# Faint line per subject and day. Each day is its own trace, so the line does not
# jump back from hour 23 to hour 0 of the next day.
for subject_id in complete_df['subject_id'].unique():
    subject_data = complete_df[complete_df['subject_id'] == subject_id]
    daily_label = f'{subject_id} (daily)'
    for day_number, (_, day_data) in enumerate(subject_data.groupby('date')):
        fig.add_trace(go.Scatter(
            x=day_data['hour_of_day'],
            y=day_data['pellet_count'],
            mode='lines',
            name=daily_label,
            legendgroup=daily_label,       # toggle all days of a subject together
            showlegend=(day_number == 0),  # one legend entry per subject
            line=dict(color=id_color_map[subject_id], width=1),
            opacity=0.2
        ))

# Add the average data with more prominent lines
for subject_id in average_pellets_per_hour['subject_id'].unique():
    avg_data = average_pellets_per_hour[average_pellets_per_hour['subject_id'] == subject_id]
    fig.add_trace(go.Scatter(
        x=avg_data['hour_of_day'],
        y=avg_data['pellet_count'],
        mode='lines',
        name=f'{subject_id} (average)',
        line=dict(color=id_color_map[subject_id], width=2)
    ))

# Update layout
fig.update_layout(
    title='Pellets throughout the day',
    xaxis_title='Time',
    yaxis_title='Pellet Count',
    legend_title='Subject ID'
)

# Show the plot
fig.show()

Patch preference:

In [None]:
# Pellet attributes per subject and patch, restricted to `key` and to blocks
# that start within the social period
patch_pellets = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
preference_query = (
    patch_pellets & key & f'block_start >= "{social_start_day}"' & f'block_start <= "{social_end_day}"'
)
preference_query

In [None]:
# Join pellet attributes, final preference indices and patch settings, then restrict
# to `key` and to blocks that start within the social period
patch_pellets = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
patch_preferences = BlockSubjectAnalysis.Preference.proj('final_preference_by_wheel','final_preference_by_time')
patch_settings = BlockAnalysis.Patch.proj('patch_rate', 'patch_offset')
preference_query = (
    (patch_pellets * patch_preferences * patch_settings)
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)
preference_query

In [None]:
# Run the preference query and load every matching row into memory
preference_data = preference_query.fetch()


In [None]:
# Nominal patch rates and the difficulty label each one stands for
PATCH_TYPE_BY_RATE = {0.01: 'easy', 0.0033: 'medium', 0.002: 'hard'}
# Absolute tolerance when matching a stored rate against a nominal rate
PATCH_RATE_TOLERANCE = 1e-4


def _patch_type_from_rate(patch_rate):
    """Map a patch rate to 'easy' / 'medium' / 'hard', or 'unknown' if it matches none.

    Rates are compared with a tolerance rather than with ==, so a value that went
    through a lossy float round-trip is still recognised.
    """
    try:
        rate_value = float(patch_rate)
    except (TypeError, ValueError):
        return 'unknown'
    for nominal_rate, label in PATCH_TYPE_BY_RATE.items():
        if abs(rate_value - nominal_rate) <= PATCH_RATE_TOLERANCE:
            return label
    return 'unknown'


preference_columns = [
    'experiment_name', 'block_start', 'patch_name', 'subject_id', 'pellet_count',
    'final_preference_by_wheel', 'final_preference_by_time', 'patch_rate', 'patch_offset',
    'rank', 'patch_type',
]

# One row per fetched record. Fields are read by position; positions 5 and 6
# (pellet timestamps and thresholds) are not needed here.
preference_list = []
for entry in preference_data:
    subject_id = entry[3]
    patch_rate = entry[9]
    preference_list.append({
        'experiment_name': entry[0],
        'block_start': entry[1],
        'patch_name': entry[2],
        'subject_id': subject_id,
        'pellet_count': entry[4],
        'final_preference_by_wheel': entry[7],
        'final_preference_by_time': entry[8],
        'patch_rate': patch_rate,
        'patch_offset': entry[10],
        'rank': 'dominant' if subject_id == dominant_id else 'subordinate',
        'patch_type': _patch_type_from_rate(patch_rate),
    })

# Explicit columns keep the frame usable even when nothing was fetched
preference_df = pd.DataFrame(preference_list, columns=preference_columns)

# Convert the 'block_start' column to datetime
preference_df['block_start'] = pd.to_datetime(preference_df['block_start'])

# Display the DataFrame
preference_df.head()

In [None]:
# Keep only blocks in which every subject obtained at least `pellet_threshold`
# pellets overall, i.e. summed over all patches of the block.
pellet_threshold = 10

# Total pellets per block and subject (rows = blocks, columns = subjects).
# A subject without any row in a block shows up as NaN and fails the comparison below.
pellets_per_block_subject = (
    preference_df
    .groupby(['block_start', 'subject_id'])['pellet_count']
    .sum()
    .unstack('subject_id')
)
# NOTE(review): the previous version kept a block as soon as any single (patch, subject)
# row reached the threshold, which did not match the stated intent - confirm.
block_is_good = (pellets_per_block_subject >= pellet_threshold).all(axis=1)

# Get unique values for block_start, patch_name, and subject_id
unique_blocks = pellets_per_block_subject.index[block_is_good]
unique_patches = preference_df['patch_name'].unique()
unique_subjects = preference_df['subject_id'].unique()

# All combinations of good block, patch and subject
all_combinations = pd.MultiIndex.from_product(
    [unique_blocks, unique_patches, unique_subjects],
    names=['block_start', 'patch_name', 'subject_id'],
).to_frame(index=False)

# Left-merge to retain all columns of the original rows
preference_df_filtered = all_combinations.merge(
    preference_df, on=['block_start', 'patch_name', 'subject_id'], how='left'
)

# Missing measurements count as 0. Descriptive columns (patch_type, rank, ...) are left
# as NaN so that no bogus "0" category appears in the plots.
measurement_columns = ['pellet_count', 'final_preference_by_wheel', 'final_preference_by_time']
preference_df_filtered[measurement_columns] = preference_df_filtered[measurement_columns].fillna(0)

preference_df_filtered

In [None]:
# Left-to-right order of the patch types on the x-axis
category_order = ['easy', 'medium', 'hard']

fig = go.Figure()

# One box trace per subject: pellet count per patch type, raw points drawn beside the box
for subject_id, subject_df in preference_df_filtered.groupby('subject_id', sort=False):
    pellet_box = go.Box(
        x=subject_df['patch_type'],
        y=subject_df['pellet_count'],
        name=subject_id,
        boxpoints='all',
        jitter=0.25,
        pointpos=-1.4,
        marker={'color': id_color_map[subject_id]},
    )
    fig.add_trace(pellet_box)

# Grouped boxes, ordered easy -> medium -> hard
fig.update_layout(
    title='Pellets per Patch Rate',
    xaxis_title='Patch Rate',
    yaxis_title='Pellet Count',
    boxmode='group',
    xaxis={'categoryorder': 'array', 'categoryarray': category_order},
    showlegend=True,
)

fig.show()


In [None]:
# Left-to-right order of the patch types on the x-axis
category_order = ['easy', 'medium', 'hard']

fig = go.Figure()

# One box trace per subject: time-based preference index per patch type,
# raw points drawn beside the box
for subject_id, subject_df in preference_df_filtered.groupby('subject_id', sort=False):
    time_preference_box = go.Box(
        x=subject_df['patch_type'],
        y=subject_df['final_preference_by_time'],
        name=subject_id,
        boxpoints='all',
        jitter=0.25,
        pointpos=-1.4,
        marker={'color': id_color_map[subject_id]},
    )
    fig.add_trace(time_preference_box)

# Grouped boxes, ordered easy -> medium -> hard
fig.update_layout(
    title='Preference Index per Patch Rate',
    xaxis_title='Patch Rate',
    yaxis_title='Preference Index (time)',
    boxmode='group',
    xaxis={'categoryorder': 'array', 'categoryarray': category_order},
    showlegend=True,
)

fig.show()

In [None]:
# Left-to-right order of the patch types on the x-axis
category_order = ['easy', 'medium', 'hard']

fig = go.Figure()

# One box trace per subject: wheel-based preference index per patch type,
# raw points drawn beside the box
for subject_id, subject_df in preference_df_filtered.groupby('subject_id', sort=False):
    wheel_preference_box = go.Box(
        x=subject_df['patch_type'],
        y=subject_df['final_preference_by_wheel'],
        name=subject_id,
        boxpoints='all',
        jitter=0.25,
        pointpos=-1.4,
        marker={'color': id_color_map[subject_id]},
    )
    fig.add_trace(wheel_preference_box)

# Grouped boxes, ordered easy -> medium -> hard
fig.update_layout(
    title='Preference Index per Patch Rate',
    xaxis_title='Patch Rate',
    yaxis_title='Preference Index (wheel)',
    boxmode='group',
    xaxis={'categoryorder': 'array', 'categoryarray': category_order},
    showlegend=True,
)

fig.show()

In [None]:

# Calendar day of each block (block_start is coerced to datetime first)
preference_df_filtered['block_start'] = pd.to_datetime(preference_df_filtered['block_start'])
preference_df_filtered['day'] = preference_df_filtered['block_start'].dt.date

# Mean wheel-based preference index per day, subject and patch type
daily_avg_df = (
    preference_df_filtered
    .groupby(['day', 'subject_id', 'patch_type'])
    .agg({'final_preference_by_wheel': 'mean'})
    .reset_index()
)

# The first and last day are only partially covered, so leave them out
first_day, last_day = daily_avg_df['day'].min(), daily_avg_df['day'].max()
filtered_daily_avg_df = daily_avg_df[~daily_avg_df['day'].isin([first_day, last_day])]

# One line per subject (colour) and patch type (dash style)
fig = px.line(
    filtered_daily_avg_df,
    x='day',
    y='final_preference_by_wheel',
    color='subject_id',
    line_dash='patch_type',
    title='Average Daily Preference Index by Day',
    labels={'day': 'Day', 'final_preference_by_wheel': 'Average Preference Index'},
    category_orders={'patch_type': ['easy', 'medium', 'hard']},
    color_discrete_map=id_color_map,
)

fig.show()

In [None]:
# Calendar day of each block (block_start is coerced to datetime first)
preference_df_filtered['block_start'] = pd.to_datetime(preference_df_filtered['block_start'])
preference_df_filtered['day'] = preference_df_filtered['block_start'].dt.date

# Mean pellet count per day, subject and patch type
daily_avg_df = (
    preference_df_filtered
    .groupby(['day', 'subject_id', 'patch_type'])
    .agg({'pellet_count': 'mean'})
    .reset_index()
)

# The first and last day are only partially covered, so leave them out
first_day, last_day = daily_avg_df['day'].min(), daily_avg_df['day'].max()
filtered_daily_avg_df = daily_avg_df[~daily_avg_df['day'].isin([first_day, last_day])]

# One line per subject (colour) and patch type (dash style)
fig = px.line(
    filtered_daily_avg_df,
    x='day',
    y='pellet_count',
    color='subject_id',
    line_dash='patch_type',
    title='Average Daily Pellets by Day',
    labels={'day': 'Day', 'pellet_count': 'Pellets'},
    category_orders={'patch_type': ['easy', 'medium', 'hard']},
    color_discrete_map=id_color_map,
)

fig.show()

Weight:

In [None]:
# Weight readings per subject for blocks that start within the social period.
# NOTE(review): this reuses the name `foraging_query` for a weight query; kept so that
# cells relying on the name keep working.
foraging_query = (
    BlockAnalysis.Subject.proj('weights', 'weight_timestamps')
    & {"spinnaker_video_source_name": "CameraTop"} # restriction on the video source name
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)
weight_data = foraging_query.fetch()

# Flatten to one row per weight reading.
# Fields are read by position: 2 = subject, 3 = weights, 4 = weight timestamps.
data = []
for entry in weight_data:
    subject_id = entry[2]
    weights = entry[3]
    weight_timestamps = entry[4]

    for weight_time, weight in zip(weight_timestamps, weights):
        data.append({'time': weight_time, 'weight': weight, 'subject_id': subject_id})

# Explicit columns keep the frame usable even when nothing was fetched
weight_df = pd.DataFrame(data, columns=['time', 'weight', 'subject_id'])

# Convert the 'time' column to datetime
weight_df['time'] = pd.to_datetime(weight_df['time'])

# Display the DataFrame
print(weight_df)

In [None]:
# Inspect the raw fetch result that weight_df was built from
weight_data

In [None]:
# Data cleaning: keep only weight readings above 25 g
weight_df = weight_df.loc[weight_df['weight'].gt(25)]

In [None]:

# Every weight reading over time, one line per subject
fig = px.line(
    weight_df,
    x='time',
    y='weight',
    color='subject_id',
    title='Weight during social period',
    labels={'weight': 'Weight (g)', 'time': 'Time', 'subject_id': 'Subject ID'},
    color_discrete_map=id_color_map,
)
fig.show()

In [None]:
#TODO: check raw data
# Index the readings by time so they can be resampled
time_weight_df = weight_df.set_index('time')

# Mean weight per subject in 24 h bins. Only the 'weight' column is resampled, so the
# grouping column is not aggregated as well; reset_index() then turns subject_id and
# time back into columns.
avg_weight_df = (
    time_weight_df
    .groupby('subject_id')['weight']
    .resample('24h')
    .mean()
    .reset_index()
)

# Line plot of the daily means (days without readings are NaN, i.e. gaps)
fig = px.line(
    avg_weight_df,
    x='time',
    y='weight',
    color='subject_id',
    title='Average Weight during Social Period (Daily)',
    labels={'weight': 'Weight (g)', 'time': 'Time', 'subject_id': 'Subject ID'},
    color_discrete_map=id_color_map
)

# Show the plot
fig.show()

Takeaways:
- dominant and subordinate mouse forage similar amounts, daily pattern similar, preference similar
TODO: who starts foraging bouts?

# 3. Social interactions and foraging

## 3.1 Do social events influence foraging behaviour?

Pellet counts after social interactions vs random times:

In [None]:
# Keep only interactions that are not followed by another interaction:
# no other event may start within `time_window_length` minutes after the event's start
# (events starting at exactly the same time count as "another interaction").
time_window_length = 5  # in minutes

window = pd.Timedelta(minutes=time_window_length).to_timedelta64()
# assumes start_timestamp holds timezone-naive datetimes - TODO confirm
event_starts = pd.to_datetime(combined_df['start_timestamp'])
has_start = event_starts.notna().to_numpy()
start_values = event_starts.to_numpy()[has_start]
sorted_starts = np.sort(start_values)

# For every event, count the events (itself included) that start in [start, start + window].
# Counting by position in the sorted array, not by index label, stays correct when
# combined_df carries duplicate index labels.
first_in_window = np.searchsorted(sorted_starts, start_values, side='left')
past_window = np.searchsorted(sorted_starts, start_values + window, side='right')
events_in_window = np.zeros(len(combined_df), dtype=int)
events_in_window[has_start] = past_window - first_in_window

# An event is "good" when it is the only one in its own window. Events without a start
# time are dropped. Selecting rows in one step keeps the column dtypes and avoids
# growing a DataFrame row by row.
good_events_df = combined_df[events_in_window == 1].copy()
        


In [None]:
# Inspect the events that have no other event starting within the following time window
good_events_df

Get active times that are not sleeping:

Only do this in good foraging blocks, where enough pellets were obtained (the code below keeps blocks with more than 15 pellets in total):

In [None]:
# Pellet counts and timestamps per patch and subject for blocks that start within the
# social period. The query is only defined here; the next cell fetches it, so the
# database is not hit twice for the same rows.
foraging_query = (
    BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps')
    & {"spinnaker_video_source_name": "CameraTop"} # restriction on the video source name
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)

In [None]:
# A block counts as a good foraging block when more than this many pellets were
# obtained in it in total (all patches and subjects together)
MIN_PELLETS_PER_BLOCK = 15

# Fetch the data
pellet_data = foraging_query.fetch()

# One row per fetched record.
# NOTE(review): this rebinds `pellet_df`, which earlier cells use for the per-pellet
# table; re-running those cells after this one will see this block-level frame instead.
pellet_df = pd.DataFrame(pellet_data, columns=['experiment_name', 'block_start', 'patch_name', 'subject_id', 'pellet_count', 'pellet_timestamps'])

# Total pellets per block
grouped_pellet_df = pellet_df.groupby('block_start')['pellet_count'].sum().reset_index()

# Keep the blocks above the threshold
good_blocks_df = grouped_pellet_df[grouped_pellet_df['pellet_count'] > MIN_PELLETS_PER_BLOCK]

# Extract block_start values as a list
good_block_starts = good_blocks_df['block_start'].tolist()

# Format the block starts as strings with microsecond precision
good_block_starts_str = [block_start.strftime('%Y-%m-%d %H:%M:%S.%f') for block_start in good_block_starts]
print(len(good_block_starts_str))

In [None]:
# Quote every block start and join them into the body of a SQL IN (...) clause
block_starts_sql_list = ', '.join(map('"{}"'.format, good_block_starts_str))
len(block_starts_sql_list)

In [None]:
# Subject positions for the social period, limited to the blocks listed in
# block_starts_sql_list
subject_positions = BlockAnalysis.Subject.proj('position_x', 'position_y', 'position_timestamps')
block_subjects = (
    subject_positions
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
    & f'block_start IN ({block_starts_sql_list})'
)
block_subjects

Get position data for one block:

In [None]:
# Work on a single example block
block_restriction = {'block_start':'2024-02-20 10:34:17.001984'}

# Subject positions (x, y, timestamps) for that block within the social period
subject_positions = BlockAnalysis.Subject.proj('position_x', 'position_y', 'position_timestamps')
block_subjects = (
    subject_positions
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
    & block_restriction
)
block_subjects

In [None]:
# Fetch the restricted rows as dicts keyed by attribute name (one dict per row)
block_subjects_dict = block_subjects.fetch(as_dict=True)

In [None]:
# Build one long position table: a frame per fetched subject row, stacked on top of
# each other. subject_name is repeated for every timestamp of that subject.
subjects_positions_df = pd.concat(
    [
        pd.DataFrame(
            {
                "subject_name": [s["subject_name"]] * len(s["position_timestamps"]),
                "position_timestamps": s["position_timestamps"],
                "position_x": s["position_x"],
                "position_y": s["position_y"],
            }
        )
        for s in block_subjects_dict
    ]
)
# Timestamps stay a regular column here; they are not used as the index
subjects_positions_df

In [None]:
# Sometimes both tracks are assigned to the same ID.
# If one subject has several rows for a timestamp and the other subject has none for that
# timestamp, every row after the first is handed over to the other subject.
subject0_id = 'BAA-1104045'
subject1_id = 'BAA-1104047'

# .copy() so the relabelling below edits independent frames, not slices of subjects_positions_df
positions_subject0 = subjects_positions_df[subjects_positions_df['subject_name'] == subject0_id].copy()
positions_subject1 = subjects_positions_df[subjects_positions_df['subject_name'] == subject1_id].copy()

# Number of rows per timestamp for each subject
count_subject0 = positions_subject0['position_timestamps'].value_counts()
count_subject1 = positions_subject1['position_timestamps'].value_counts()

# Timestamps where one subject has multiple rows and the other has none
duplicate_subject0 = count_subject0[count_subject0 > 1].index
missing_subject1 = duplicate_subject0.difference(count_subject1.index)

duplicate_subject1 = count_subject1[count_subject1 > 1].index
missing_subject0 = duplicate_subject1.difference(count_subject0.index)

# Subject 0 rows to hand over: timestamp is missing for subject 1 and the row is not the
# first one at that timestamp (the first row stays with subject 0)
extra_rows_subject0 = (
    positions_subject0['position_timestamps'].isin(missing_subject1)
    & positions_subject0.duplicated(subset='position_timestamps', keep='first')
)
positions_subject0.loc[extra_rows_subject0, 'subject_name'] = subject1_id

# Same in the other direction
extra_rows_subject1 = (
    positions_subject1['position_timestamps'].isin(missing_subject0)
    & positions_subject1.duplicated(subset='position_timestamps', keep='first')
)
positions_subject1.loc[extra_rows_subject1, 'subject_name'] = subject0_id

corrected_positions_subject0 = positions_subject0.reset_index(drop=True)
corrected_positions_subject1 = positions_subject1.reset_index(drop=True)

# Recombine and order by time
corrected_positions_df = (
    pd.concat([corrected_positions_subject0, corrected_positions_subject1])
    .sort_values(by='position_timestamps')
    .reset_index(drop=True)
)

# Display the corrected DataFrame
print(corrected_positions_df)


In [None]:
# Drop timestamps at which one of the two subjects has no position
# (timestamps are converted to datetime first, in case they are not already)
corrected_positions_df['position_timestamps'] = pd.to_datetime(corrected_positions_df['position_timestamps'])

# Set of timestamps seen for each subject
timestamps_subjects = corrected_positions_df.groupby('subject_name')['position_timestamps'].apply(set)

# Timestamps for which both subjects have data
common_timestamps = set(timestamps_subjects['BAA-1104045']) & set(timestamps_subjects['BAA-1104047'])

# Keep only the rows at those shared timestamps
has_both_subjects = corrected_positions_df['position_timestamps'].isin(common_timestamps)
corrected_positions_df = corrected_positions_df[has_both_subjects]

# Display the filtered DataFrame
print(corrected_positions_df)

In [None]:
# Split the recording into chunks of consecutive timestamps.
# A new chunk starts wherever the gap to the previous row exceeds max_gap.
max_gap = pd.Timedelta(seconds=0.5)

corrected_positions_df = corrected_positions_df.sort_values('position_timestamps')

# Gap between each row and the row before it (NaT for the first row)
corrected_positions_df['time_diff'] = corrected_positions_df['position_timestamps'].diff()

# Running count of over-long gaps = chunk number of each row
starts_new_chunk = corrected_positions_df['time_diff'].gt(max_gap)
corrected_positions_df['chunk'] = starts_new_chunk.cumsum()

# One group per chunk, processed separately further down
chunks = corrected_positions_df.groupby('chunk')
len(chunks)

In [None]:
# Detect ID swaps.
# Within each chunk the two animals are followed frame by frame. A frame is flagged as swapped
# when crossing the two labels gives a shorter total displacement from the last known
# positions than keeping the labels as they are.
subject0_id = 'BAA-1104045'
subject1_id = 'BAA-1104047'

all_swaps = []       # swap timestamps, in order of detection
seen_swaps = set()   # same timestamps, for fast duplicate checks

for chunk_id, chunk_df in chunks:
    subject0_rows = chunk_df[chunk_df['subject_name'] == subject0_id]
    subject1_rows = chunk_df[chunk_df['subject_name'] == subject1_id]

    positions_subject0 = subject0_rows[['position_x', 'position_y']].values
    positions_subject1 = subject1_rows[['position_x', 'position_y']].values

    # chunk_df interleaves the rows of both subjects, so the timestamp of frame i has to be
    # looked up among one subject's rows, not at row i of chunk_df.
    frame_timestamps = subject0_rows['position_timestamps'].tolist()

    # Check the pairing before any indexing happens
    assert len(positions_subject0) == len(positions_subject1), (
        f"chunk {chunk_id}: {len(positions_subject0)} rows for {subject0_id} "
        f"but {len(positions_subject1)} rows for {subject1_id}"
    )
    if len(positions_subject0) == 0:
        continue

    # Last known position of each animal (identity follows the animal, not the label)
    last_known_pos0 = positions_subject0[0]
    last_known_pos1 = positions_subject1[0]

    for i in range(1, len(positions_subject0)):
        # Total displacement if labels are kept vs. if labels are crossed
        keep_cost = (
            np.linalg.norm(positions_subject0[i] - last_known_pos0)
            + np.linalg.norm(positions_subject1[i] - last_known_pos1)
        )
        swap_cost = (
            np.linalg.norm(positions_subject0[i] - last_known_pos1)
            + np.linalg.norm(positions_subject1[i] - last_known_pos0)
        )

        if keep_cost <= swap_cost:
            last_known_pos0 = positions_subject0[i]
            last_known_pos1 = positions_subject1[i]
        else:
            # Labels are crossed in this frame
            last_known_pos0 = positions_subject1[i]
            last_known_pos1 = positions_subject0[i]
            swap_timestamp = frame_timestamps[i]
            if swap_timestamp not in seen_swaps:
                seen_swaps.add(swap_timestamp)
                all_swaps.append(swap_timestamp)

# Display the detected swaps
print(f"Detected ID swaps: {len(all_swaps)}")

In [None]:
# Plot the x position of both subjects (before swap correction) with the detected ID swaps overlaid
positions_subject0 = corrected_positions_df[corrected_positions_df['subject_name'] == 'BAA-1104045']
positions_subject1 = corrected_positions_df[corrected_positions_df['subject_name'] == 'BAA-1104047']

fig = go.Figure()

# One line per subject
for subject_id, subject_positions, line_colour in (
    ('BAA-1104045', positions_subject0, 'blue'),
    ('BAA-1104047', positions_subject1, 'red'),
):
    fig.add_trace(go.Scatter(
        x=subject_positions['position_timestamps'],
        y=subject_positions['position_x'],
        mode='lines',
        name=f'Subject {subject_id}',
        line=dict(color=line_colour)
    ))

# ID swap markers, drawn at the height of subject 0's maximum x so they sit above the traces
swap_y_value = positions_subject0['position_x'].max() if not positions_subject0.empty else 0
fig.add_trace(go.Scatter(
    x=all_swaps,
    y=[swap_y_value] * len(all_swaps),
    mode='markers',
    marker=dict(color='green', symbol='x', size=10),
    name='ID swap',  # otherwise the legend shows an anonymous "trace 2"
))

fig.update_layout(
    title='X Position of Both Subjects with ID Swaps Overlaid',
    xaxis_title='Time',
    # Positions are used as pixel coordinates elsewhere in this notebook (cm thresholds are
    # multiplied by cm2px, speeds are labelled pixels/second), so the axis is labelled in px.
    yaxis_title='X Position (px)',
    legend_title='Subjects',
    template='plotly_white'
)

fig.show()

In [None]:
# Undo the detected ID swaps: at every swap timestamp the two subject names are exchanged.
swap_timestamps_set = set(all_swaps)

# Work on a copy so corrected_positions_df keeps the uncorrected labels
corrected_positions_swapped_df = corrected_positions_df.copy()

swapped_name = {'BAA-1104045': 'BAA-1104047', 'BAA-1104047': 'BAA-1104045'}

# Vectorised replacement for a row-by-row iterrows loop (one row per subject per tracked frame)
at_swap = corrected_positions_swapped_df['position_timestamps'].isin(swap_timestamps_set)
corrected_positions_swapped_df.loc[at_swap, 'subject_name'] = (
    corrected_positions_swapped_df.loc[at_swap, 'subject_name']
    # any other subject name is left untouched
    .map(lambda name: swapped_name.get(name, name))
)

print(corrected_positions_swapped_df)


In [None]:
# Plot the x position of both subjects after swap correction, with the ID swap timestamps overlaid
positions_subject0 = corrected_positions_swapped_df[corrected_positions_swapped_df['subject_name'] == 'BAA-1104045']
positions_subject1 = corrected_positions_swapped_df[corrected_positions_swapped_df['subject_name'] == 'BAA-1104047']

fig = go.Figure()

# One line per subject
for subject_id, subject_positions, line_colour in (
    ('BAA-1104045', positions_subject0, 'blue'),
    ('BAA-1104047', positions_subject1, 'red'),
):
    fig.add_trace(go.Scatter(
        x=subject_positions['position_timestamps'],
        y=subject_positions['position_x'],
        mode='lines',
        name=f'Subject {subject_id}',
        line=dict(color=line_colour)
    ))

# ID swap markers, drawn at the height of subject 0's maximum x so they sit above the traces
swap_y_value = positions_subject0['position_x'].max() if not positions_subject0.empty else 0
fig.add_trace(go.Scatter(
    x=all_swaps,
    y=[swap_y_value] * len(all_swaps),
    mode='markers',
    marker=dict(color='green', symbol='x', size=10),
    name='ID swap',  # otherwise the legend shows an anonymous "trace 2"
))

fig.update_layout(
    # Title differs from the pre-correction plot so the two figures can be told apart
    title='X Position of Both Subjects after ID Swap Correction (swap times overlaid)',
    xaxis_title='Time',
    # Positions are used as pixel coordinates elsewhere in this notebook (cm thresholds are
    # multiplied by cm2px, speeds are labelled pixels/second), so the axis is labelled in px.
    yaxis_title='X Position (px)',
    legend_title='Subjects',
    template='plotly_white'
)

fig.show()

In [None]:
# Nest region as a polygon built from the first four corner points in the metadata
nest_corners = metadata.ActiveRegion.NestRegion.ArrayOfPoint
nest_polygon = Polygon(
    [(int(nest_corners[corner]["X"]), int(nest_corners[corner]["Y"])) for corner in range(4)]
)


def is_in_nest(x, y):
    """Return whether nest_polygon contains the point (x, y)."""
    return nest_polygon.contains(Point(x, y))


# Row-wise test of every position against the nest polygon
inside_nest = corrected_positions_swapped_df.apply(
    lambda row: is_in_nest(row["position_x"], row["position_y"]), axis=1
)

# Keep the in-nest rows; reset_index() moves the previous row labels into a column
# and gives the result a fresh 0..n-1 index
in_nest_position_df = corrected_positions_swapped_df[inside_nest].reset_index()
in_nest_position_df

In [None]:
# Speed is only kept when the time between consecutive samples (seconds) lies in this range
min_time_diff = 0.002
max_time_diff = 5

# Seconds elapsed since the previous sample of the same subject
time_diffs = (
    in_nest_position_df.reset_index()
    .groupby("subject_name")["position_timestamps"]
    .diff()
    .dt.total_seconds()
)

# Distance moved since the previous sample of the same subject
step_vectors = in_nest_position_df.groupby("subject_name")[["position_x", "position_y"]].diff()
step_lengths = step_vectors.apply(np.linalg.norm, axis=1)

# Speed = distance / elapsed time
in_nest_position_df["speed"] = step_lengths / time_diffs

# Blank out speeds computed across gaps that are too long or too short
in_nest_position_df.loc[time_diffs > max_time_diff, "speed"] = np.nan
in_nest_position_df.loc[time_diffs < min_time_diff, "speed"] = np.nan

in_nest_position_df

In [None]:
# Discard rows whose speed is +/-inf (treated as tracking errors).
# NOTE: set_index(..., inplace=True) changes in_nest_position_df itself; on a second run of this
# cell 'position_timestamps' is no longer a column, so the frame has to be rebuilt first.
# NOTE(review): the speed cell already sets speed to NaN where the time difference is below
# min_time_diff, so zero-time-difference infs have presumably been removed already - confirm.
in_nest_position_df.set_index('position_timestamps', inplace=True)
# Rows with infinite speed; reset_index() turns the timestamp index back into a column
inf_timestamps = in_nest_position_df[in_nest_position_df['speed'].isin([np.inf, -np.inf])][['subject_name', 'speed']].reset_index()
# Unique (subject, timestamp) pairs where speed is inf
inf_pairs = inf_timestamps[['subject_name', 'position_timestamps']].drop_duplicates()
# Anti-join: the left merge tags every row in '_merge'; 'left_only' rows have no matching inf pair
cleaned_df = in_nest_position_df.merge(inf_pairs, on=['subject_name', 'position_timestamps'], how='left', indicator=True)
# Keep only the rows without an inf pair and drop the helper column
cleaned_df = cleaned_df[cleaned_df['_merge'] == 'left_only'].drop(columns=['_merge'])
cleaned_df

In [None]:
# Smooth the speed with a rolling mean over 50 rows per subject (1 s at 50 samples per second)
window_size = 50

speed_by_subject = cleaned_df.groupby('subject_name')['speed']
rolling_mean_speed = speed_by_subject.rolling(window=window_size, min_periods=1).mean()

# Drop the subject level added by the groupby so the values align with cleaned_df's own index
cleaned_df['smoothed_speed'] = rolling_mean_speed.reset_index(level=0, drop=True)

# Reset index for plotting
cleaned_df = cleaned_df.reset_index()
cleaned_df

In [None]:

# Line plot of the smoothed speed over time, one coloured line per subject
speed_plot_labels = {"smoothed_speed": "Speed (pixels/second)", "position_timestamps": "Time"}

fig = px.line(
    cleaned_df,
    x="position_timestamps",
    y="smoothed_speed",
    color="subject_name",
    color_discrete_map=id_color_map,
    labels=speed_plot_labels,
    title="Speed of subjects",
)

fig.show()

In [None]:
# From here on subjects_positions_df refers to cleaned_df (same object, no copy);
# this replaces the raw position frame that was bound to this name earlier in the notebook.
subjects_positions_df = cleaned_df

In [None]:
# Define thresholds
cm2px = 5.4
fps = 50
speed_threshold = 5 * cm2px  # 5 cm/s multiplied by cm2px, i.e. expressed in pixels/s
frame_threshold = 30 * fps  # 0.5 minutes at 50 frames per second

subjects_positions_df = subjects_positions_df.reset_index(drop=True)

grouped = subjects_positions_df.groupby('subject_name')


def filter_by_speed_threshold(group):
    """Flag the rows of one subject's data as active or not.

    A row counts as "below threshold" when its smoothed speed is under speed_threshold or
    missing. A row is active when fewer than frame_threshold of the frame_threshold rows
    centred on it are below threshold. Rows too close to either end for a full window get a
    NaN rolling sum and therefore end up inactive.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows of a single subject; must contain a 'smoothed_speed' column.

    Returns
    -------
    pandas.DataFrame
        A copy of `group` with the added columns 'below_threshold', 'below_threshold_rolling',
        'meets_frame_threshold' and 'active'. The input frame is not modified.
    """
    # Work on a copy: the group handed in must not be mutated
    group = group.copy()
    # True where speed is below the threshold (or unknown)
    group['below_threshold'] = (group['smoothed_speed'] < speed_threshold) | (group['smoothed_speed'].isna())
    # Number of below-threshold rows in the centred window
    group['below_threshold_rolling'] = group['below_threshold'].rolling(window=frame_threshold, center=True).sum()
    # True unless the whole window is below threshold
    group['meets_frame_threshold'] = group['below_threshold_rolling'] < frame_threshold
    group['active'] = group['meets_frame_threshold']
    return group


# Apply the function to every subject and stack the results.
# Iterating the groups (instead of groupby.apply) keeps the 'subject_name' column regardless of
# how the installed pandas version treats grouping columns inside apply.
flagged_groups = [filter_by_speed_threshold(group) for _, group in grouped]
if flagged_groups:
    filtered_subjects_positions_df = pd.concat(flagged_groups).reset_index(drop=True)
else:
    # No subjects at all: return the (empty) frame with the extra columns added
    filtered_subjects_positions_df = filter_by_speed_threshold(subjects_positions_df).reset_index(drop=True)


filtered_subjects_positions_df

In [None]:
# Two stacked panels (one per subject): smoothed speed as a line, active timestamps as green dots
panel_subjects = ("BAA-1104045", "BAA-1104047")
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, subplot_titles=panel_subjects)

# Per-subject data
subject1_df = filtered_subjects_positions_df[filtered_subjects_positions_df['subject_name'] == 'BAA-1104045']
subject2_df = filtered_subjects_positions_df[filtered_subjects_positions_df['subject_name'] == 'BAA-1104047']

# Rows flagged as active
active_subject1_df = subject1_df[subject1_df["active"]]
active_subject2_df = subject2_df[subject2_df["active"]]

panels = zip(panel_subjects, (subject1_df, subject2_df), (active_subject1_df, active_subject2_df))
for panel_row, (subject, speed_df, active_df) in enumerate(panels, start=1):
    # Speed line
    fig.add_trace(
        go.Scatter(
            x=speed_df["position_timestamps"],
            y=speed_df["smoothed_speed"],
            mode='lines',
            name=f'Speed {subject}',
            line=dict(color=id_color_map[subject])
        ),
        row=panel_row, col=1
    )

    # Active markers, only when there is at least one active timestamp
    if not active_df.empty:
        fig.add_trace(
            go.Scatter(
                x=active_df["position_timestamps"],
                y=[-1000] * len(active_df),  # fixed dummy height so the dots form a strip under the line
                mode='markers',
                marker=dict(
                    color='green',
                    size=5,
                    symbol='circle'
                ),
                name=f'Active Status {subject}'
            ),
            row=panel_row, col=1
        )

    fig.update_yaxes(title_text="Speed (pixels/second)", row=panel_row, col=1)

fig.update_xaxes(title_text="Time", row=2, col=1)

fig.update_layout(
    title_text="Speed filtered for Subjects",
    height=600,
    showlegend=False
)

fig.show()

In [None]:
# Pellet counts and timestamps per subject and patch for the social period
pellet_attributes = BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps')
camera_restriction = {"spinnaker_video_source_name": "CameraTop"}  # video source name used to restrict the table
after_social_start = f'block_start >= "{social_start_day}"'
before_social_end = f'block_start <= "{social_end_day}"'

foraging_query = pellet_attributes & camera_restriction & key & after_social_start & before_social_end
pellet_data = foraging_query.fetch()

For now, take a shortcut and just choose timestamps during the night where no events occur (rather than getting the actual awake times).

In [None]:
# Minimum spacing required between a random timestamp and any other timestamp
time_window = timedelta(minutes=5)  # 5 minutes

# Event start times in ascending order, re-indexed from 0
start_timestamps = combined_df['start_timestamp'].sort_values().reset_index(drop=True)

# Night boundaries: 'HH:MM' strings parsed into datetime.time objects
clock_format = '%H:%M'
night_start_time = datetime.strptime(night_start, clock_format).time()
night_end_time = datetime.strptime(night_end, clock_format).time()

# Function to check if a timestamp is valid
def is_valid_timestamp(new_timestamp, existing_timestamps, time_window, night_start_time, night_end_time):
    """Decide whether `new_timestamp` may be used as a control ("no event") timestamp.

    Parameters
    ----------
    new_timestamp : datetime-like
        Candidate timestamp; must support ``.time()`` and subtraction of other timestamps.
    existing_timestamps : iterable of datetime-like
        Timestamps the candidate has to keep its distance from.
    time_window : datetime.timedelta
        Minimum distance to every existing timestamp.
    night_start_time, night_end_time : datetime.time
        Start (inclusive) and end (exclusive) of the night period. When the start is later
        than the end, the period is taken to span midnight (e.g. 19:00 to 07:00).

    Returns
    -------
    bool
        True when the candidate lies in the night period and is at least `time_window`
        away from every existing timestamp, otherwise False.
    """
    new_time = new_timestamp.time()

    # Night-period check; a period that wraps past midnight consists of two pieces
    if night_start_time <= night_end_time:
        in_night = night_start_time <= new_time < night_end_time
    else:
        in_night = new_time >= night_start_time or new_time < night_end_time
    if not in_night:
        return False

    # Reject candidates closer than time_window to any existing timestamp
    window_seconds = time_window.total_seconds()
    return all(
        abs((new_timestamp - ts).total_seconds()) >= window_seconds
        for ts in existing_timestamps
    )


# Draw as many random control timestamps as there are good events.
# Each accepted timestamp lies in the night period and is at least `time_window` away from
# every event start and from every previously accepted random timestamp.
# NOTE: np.random is not seeded here, so the drawn timestamps differ between runs.
valid_timestamps = []
min_timestamp = start_timestamps.min()
max_timestamp = start_timestamps.max()

n_needed = len(good_events_df)
span_seconds = int((max_timestamp - min_timestamp).total_seconds())
if n_needed > 0 and span_seconds <= 0:
    raise ValueError("Event start timestamps span less than one second; cannot draw random timestamps.")

# Timestamps that new draws must keep away from (event starts + accepted draws).
# A plain list is appended to instead of rebuilding and re-sorting a Series on every draw.
taken_timestamps = list(start_timestamps)

# Upper bound on the number of draws, so the loop cannot spin forever when no
# (or almost no) valid timestamp exists
max_attempts = max(10_000, 1_000 * n_needed)
attempts = 0

while len(valid_timestamps) < n_needed:
    if attempts >= max_attempts:
        raise RuntimeError(
            f"Only {len(valid_timestamps)} of {n_needed} random timestamps found after {max_attempts} draws; "
            "check the night period and time_window."
        )
    attempts += 1
    random_timestamp = min_timestamp + timedelta(seconds=np.random.randint(0, span_seconds))
    if is_valid_timestamp(random_timestamp, taken_timestamps, time_window, night_start_time, night_end_time):
        valid_timestamps.append(random_timestamp)
        taken_timestamps.append(random_timestamp)

# As before, start_timestamps ends up holding event starts plus accepted draws, sorted
if valid_timestamps:
    start_timestamps = (
        pd.concat([start_timestamps, pd.Series(valid_timestamps)]).sort_values().reset_index(drop=True)
    )

# Convert the list of valid timestamps to a DataFrame
valid_timestamps_df = pd.DataFrame(valid_timestamps, columns=['random_timestamp'])

# Display the DataFrame
valid_timestamps_df

In [None]:
# Combine event timestamps and random control timestamps into one table

# Good events: flagged event=True, keeping behaviour type and dominant id
good_events_df_with_event = (
    good_events_df[['start_timestamp', 'behavior_type', 'dominant_id']]
    .rename(columns={'start_timestamp': 'start_timestamps', 'behavior_type': 'behaviour_type'})
    .assign(event=True)
    [['start_timestamps', 'event', 'behaviour_type', 'dominant_id']]
)

# Random control timestamps: flagged event=False
valid_timestamps_df_with_event = (
    valid_timestamps_df[['random_timestamp']]
    .rename(columns={'random_timestamp': 'start_timestamps'})
    .assign(event=False)
)

# Stack both tables and order by start time
all_timestamps_df = (
    pd.concat([good_events_df_with_event, valid_timestamps_df_with_event])
    .sort_values(by='start_timestamps')
    .reset_index(drop=True)
)

# Each window ends time_window_length minutes after its start
all_timestamps_df['end_timestamps'] = all_timestamps_df['start_timestamps'] + timedelta(minutes=time_window_length)

# Display the new DataFrame
all_timestamps_df

In [None]:
# Get pellet data: one row per delivered pellet
foraging_query = (
    BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps')
    & {"spinnaker_video_source_name": "CameraTop"}  # video source name used to restrict the table
    & key
    & f'block_start >= "{social_start_day}"'
    & f'block_start <= "{social_end_day}"'
)
pellet_data = foraging_query.fetch()

# Fields of each fetched entry are read by position: [3] is used as the subject id and
# [5] as that entry's pellet timestamps.
# NOTE(review): positions depend on the attribute order of the fetched table - confirm.
# `dominant_id` is read from notebook state; make sure it holds the intended subject
# before running this cell.
data = [
    {
        'time': pellet_timestamp,
        'subject_id': entry[3],
        'rank': 'dominant' if entry[3] == dominant_id else 'subordinate',
    }
    for entry in pellet_data
    for pellet_timestamp in entry[5]
]

# Explicit columns so an empty fetch still gives a frame with the expected columns
pellet_df = pd.DataFrame(data, columns=['time', 'subject_id', 'rank'])

# Convert the 'time' column to datetime
pellet_df['time'] = pd.to_datetime(pellet_df['time'])

# Display the DataFrame
print(pellet_df)


In [None]:
# Count the pellets of every subject inside each time window of all_timestamps_df
results = []

# Every subject that appears in the pellet data
all_subject_ids = pellet_df['subject_id'].unique()

for _, window in all_timestamps_df.iterrows():
    start_time = window['start_timestamps']
    end_time = window['end_timestamps']

    # Pellets delivered within the window (both ends included)
    pellet_data = pellet_df[pellet_df['time'].between(start_time, end_time)]

    # Pellets per subject in this window
    total_pellets = pellet_data['subject_id'].value_counts()

    # One result row per subject, carrying all window columns; subjects without pellets get 0
    results.extend(
        {
            **window.to_dict(),
            'subject_id': subject_id,
            'pellet_number': total_pellets.get(subject_id, 0),
        }
        for subject_id in all_subject_ids
    )

results_df = pd.DataFrame(results)

# Windows that do not come from an event have no behaviour type: label them 'no_interaction'
results_df['behaviour_type'] = results_df['behaviour_type'].fillna('no_interaction')

# Display the new DataFrame
results_df

In [None]:
# Category order for 'behaviour_type': 'no_interaction' first, then the other observed types
observed_behaviours = results_df['behaviour_type'].unique()
category_order = ['no_interaction', *(x for x in observed_behaviours if x != 'no_interaction')]

# Box plot with all individual data points shown
pellet_box_labels = {
    'subject_id': 'Subject ID',
    'pellet_number': 'Pellets 5min post event',
    'behaviour_type': 'Social interaction type',
}
known_subject_rows = results_df[results_df['subject_id'] != 'nan']

fig = px.box(
    known_subject_rows,
    x='behaviour_type',
    y='pellet_number',
    color='subject_id',
    points='all',
    title='Pellets post events',
    labels=pellet_box_labels,
    category_orders={'behaviour_type': category_order},
    color_discrete_map=id_color_map,
)

# Show the plot
fig.show()

In [None]:
# Keep rows of known subjects with at least one pellet
is_known_subject = results_df['subject_id'] != 'nan'
has_pellets = results_df['pellet_number'] != 0
filtered_df = results_df[is_known_subject & has_pellets]

# Box plot with all individual data points shown
pellet_box_labels = {
    'subject_id': 'Subject ID',
    'pellet_number': 'Pellets 5min post event',
    'behaviour_type': 'Social interaction type',
}

fig = px.box(
    filtered_df,
    x='behaviour_type',
    y='pellet_number',
    color='subject_id',
    points='all',
    title='Pellets post events (if at least 1 pellet)',
    labels=pellet_box_labels,
    category_orders={'behaviour_type': category_order},
    color_discrete_map=id_color_map,
)

# Show the plot
fig.show()

In [None]:
# Box plot of pellets per window, event windows vs. random control windows
event_box_labels = {
    'subject_id': 'Subject ID',
    'pellet_number': 'Pellets 5min post event',
    'event': 'Post social interation',
}
known_subject_rows = results_df[results_df['subject_id'] != 'nan']

fig = px.box(
    known_subject_rows,
    x='event',
    y='pellet_number',
    color='subject_id',
    points='all',
    title='Pellets post events',
    labels=event_box_labels,
    category_orders={'behaviour_type': category_order},
    color_discrete_map=id_color_map,
)

# Show the plot
fig.show()

## 3.2 Does the outcome of social interactions influence foraging behaviour?

Pellets and tube test wins:

In [None]:
# Pellets obtained by each subject in the 20 minutes after each chase,
# labelled by the role (dominant / subordinate) the subject had in that chase.
post_event_window = pd.Timedelta(minutes=20)
known_ids = set(unique_ids)

results = []
for _, chase in chasing_df.iterrows():
    # Loop-local name: the notebook-level `dominant_id` is left untouched
    chase_dominant_id = str(chase['dominant_id'])
    other_ids = known_ids - {chase_dominant_id}

    # Skip chases whose dominant id is not one of the known subjects (e.g. a missing id that
    # became 'nan'): the subordinate cannot be determined and would be picked arbitrarily.
    if chase_dominant_id not in known_ids or not other_ids:
        continue
    subordinate_id = other_ids.pop()

    # 20-minute period starting at the chase
    start_time = chase['start_timestamp']
    end_time = start_time + post_event_window

    # Pellets delivered in that period (both ends included), counted per subject
    pellets_in_window = pellet_df[(pellet_df['time'] >= start_time) & (pellet_df['time'] <= end_time)]
    total_pellets = pellets_in_window['subject_id'].value_counts()

    # One row for the dominant and one for the subordinate subject; no pellets counts as 0
    for outcome, subject_id in (('dominant', chase_dominant_id), ('subordinate', subordinate_id)):
        results.append({
            'start': start_time,
            'end': end_time,
            'outcome': outcome,
            'subject_id': subject_id,
            'pellets': total_pellets.get(subject_id, 0)
        })

# Explicit columns so the frame has the expected columns even when no chase qualified
results_df = pd.DataFrame(results, columns=['start', 'end', 'outcome', 'subject_id', 'pellets'])

# Box plot: pellets after a chase per subject, split by the subject's role in the chase
fig = px.box(
    results_df[results_df['subject_id'] != 'nan'],
    x='subject_id',
    y='pellets',
    color='outcome',
    title='Pellets post chase',
    labels={'subject_id': 'Subject ID', 'pellets': 'Pellets 20min post event', 'outcome': 'Outcome of event'},
    points='all',
)

# Show the plot
fig.show()

In [None]:
# Preview the first rows of results_df as left by the most recently run cell
results_df.head()


In [None]:

# Pellets obtained by each subject in the 20 minutes after each tube test,
# labelled by the role (dominant / subordinate) the subject had in that test.
post_event_window = pd.Timedelta(minutes=20)
known_ids = set(unique_ids)

results = []
for _, test in tube_test_df.iterrows():
    # Loop-local name: the notebook-level `dominant_id` is left untouched
    test_dominant_id = str(test['dominant_id'])
    other_ids = known_ids - {test_dominant_id}

    # Skip tests whose dominant id is not one of the known subjects (e.g. a missing id that
    # became 'nan'): the subordinate cannot be determined and would be picked arbitrarily.
    if test_dominant_id not in known_ids or not other_ids:
        continue
    subordinate_id = other_ids.pop()

    # 20-minute period starting at the tube test
    start_time = test['start_timestamp']
    end_time = start_time + post_event_window

    # Pellets delivered in that period (both ends included), counted per subject
    pellets_in_window = pellet_df[(pellet_df['time'] >= start_time) & (pellet_df['time'] <= end_time)]
    total_pellets = pellets_in_window['subject_id'].value_counts()

    # One row for the dominant and one for the subordinate subject; no pellets counts as 0
    for outcome, subject_id in (('dominant', test_dominant_id), ('subordinate', subordinate_id)):
        results.append({
            'start': start_time,
            'end': end_time,
            'outcome': outcome,
            'subject_id': subject_id,
            'pellets': total_pellets.get(subject_id, 0)
        })

# Explicit columns so the frame has the expected columns even when no test qualified
results_df = pd.DataFrame(results, columns=['start', 'end', 'outcome', 'subject_id', 'pellets'])

# Box plot: pellets after a tube test per subject, split by the subject's role in the test
fig = px.box(
    results_df[results_df['subject_id'] != 'nan'],
    x='subject_id',
    y='pellets',
    color='outcome',
    title='Pellets post tube test',
    labels={'subject_id': 'Subject ID', 'pellets': 'Pellets 20min post event', 'outcome': 'Outcome of event'},
    points='all',
)

# Show the plot
fig.show()


Heatmap of the foraging / interactions / time relationship:

In [None]:
# Normalised measure of interactions, step 1: events per hour.
# Count the events of each behaviour type in 1-hour bins.
# resample() is called without `on=`, so it bins on the index of combined_df, which therefore
# has to be a datetime index at this point.
# '1h' is the same alias used for the pellet bins below (the upper-case 'H' form is deprecated).
events_per_hour = combined_df.groupby('behavior_type').resample('1h').size().reset_index(name='event_count')

# Display the result
print(events_per_hour)

In [None]:
# Min-max normalise the hourly event counts within each behaviour type (range 0-1)
counts_by_behaviour = events_per_hour.groupby('behavior_type')['event_count']

# Smallest and largest hourly count of each row's behaviour type
min_event_count = counts_by_behaviour.transform('min')
max_event_count = counts_by_behaviour.transform('max')

# (count - min) / (max - min)
events_per_hour['normalized_event_count'] = (
    (events_per_hour['event_count'] - min_event_count) / (max_event_count - min_event_count)
)

# Display the result
events_per_hour

In [None]:
# Line plot of the normalised hourly event count over time, one line per behaviour type
interaction_plot_labels = {
    'start_timestamp': 'Time',
    'normalized_event_count': 'Normalised count',
    'behavior_type': 'Behaviour',
}

fig = px.line(
    events_per_hour,
    x='start_timestamp',
    y='normalized_event_count',
    color='behavior_type',
    color_discrete_map=behaviour_map,
    labels=interaction_plot_labels,
    title='Normalised interactions over time',
)

# Show the plot
fig.show()

In [None]:
# Pellet counts per subject in fixed time bins, including bins without pellets
time_bin_size = '1h'

# set_index returns a new frame, so pellet_df itself stays unchanged
pellet_df_copy = pellet_df.set_index('time')

# Number of pellets per (time bin, subject)
time_bins = pd.Grouper(freq=time_bin_size)
binned_pellet_df = pellet_df_copy.groupby([time_bins, 'subject_id']).size().reset_index(name='pellet_count')

# All bins between the start and end of the social period
# NOTE(review): the merge below only matches when these bin edges coincide with the edges
# produced by the Grouper - confirm social_start_time is aligned to the bin size.
time_range = pd.date_range(start=social_start_time, end=social_end_time, freq=time_bin_size)

# Subjects present in the binned data
subject_ids = binned_pellet_df['subject_id'].unique()

# Every (time bin, subject) combination as a two-column frame
multi_index = pd.MultiIndex.from_product([time_range, subject_ids], names=['time', 'subject_id'])
complete_df = multi_index.to_frame(index=False)

# Attach the counts; combinations without pellets get 0
complete_df = complete_df.merge(binned_pellet_df, on=['time', 'subject_id'], how='left').fillna(0)

# fillna leaves the counts as floats: cast back to int
complete_df['pellet_count'] = complete_df['pellet_count'].astype(int)
complete_df

In [None]:
# Min-max normalise the binned pellet counts within each subject (range 0-1)
counts_by_subject = complete_df.groupby('subject_id')['pellet_count']

# Smallest and largest binned count of each row's subject
min_pellet_count = counts_by_subject.transform('min')
max_pellet_count = counts_by_subject.transform('max')

# (count - min) / (max - min)
complete_df['normalized_pellet_count'] = (
    (complete_df['pellet_count'] - min_pellet_count) / (max_pellet_count - min_pellet_count)
)

# Display the result
complete_df

In [None]:
# Add the pellet counts to the hourly event table.
# The time column is renamed in place so both tables share the key 'time'.
events_per_hour.rename(columns={'start_timestamp': 'time'}, inplace=True)

# Left merge on 'time' only: every event row is paired with each subject's row for that bin
merged_df = events_per_hour.merge(complete_df, on='time', how='left')

# Display the result
merged_df

In [None]:
# Per-subject heatmaps: one figure per subject, one subplot row per behaviour type, each
# showing the normalised event count and the normalised pellet count over time.
import plotly.express as px
import plotly.subplots as sp

# The left merge can leave rows without a subject; those are skipped
subjects = merged_df['subject_id'].dropna().unique()
behavior_types = merged_df['behavior_type'].unique()

# Measurement columns and their display labels, in matching order
measurement_columns = ['normalized_event_count', 'normalized_pellet_count']
measurement_labels = ['Normalized Event Count', 'Normalized Pellet Count']

n_rows = len(behavior_types)

for subject in subjects:
    # One subplot row per behaviour type
    fig = sp.make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        subplot_titles=[str(behavior_type) for behavior_type in behavior_types],
    )

    for i, behavior_type in enumerate(behavior_types, start=1):
        # Rows of this subject and behaviour type
        subject_behavior_df = merged_df[(merged_df['subject_id'] == subject) & (merged_df['behavior_type'] == behavior_type)]

        # One row per time bin (pivot_table averages duplicates by default)
        heatmap_data = subject_behavior_df.pivot_table(
            index='time',
            values=measurement_columns
        )
        # pivot_table drops columns that are entirely NaN; reindex so the two heatmap rows
        # always exist and always line up with measurement_labels
        heatmap_data = heatmap_data.reindex(columns=measurement_columns)

        # Heatmap traces are built directly and share one colour axis, so the colour scale set
        # in the layout below applies to every subplot
        fig.add_trace(
            go.Heatmap(
                z=heatmap_data.T.values,
                x=heatmap_data.index,
                y=measurement_labels,
                coloraxis='coloraxis',
            ),
            row=i, col=1
        )

    fig.update_layout(
        title=f'Heatmaps for Subject: {subject}',
        coloraxis=dict(colorscale='Viridis'),
        template='plotly_white',
        height=300 * max(n_rows, 1)  # 300 px per behaviour type
    )
    # x axes are shared, so the title goes on the bottom row only
    fig.update_xaxes(title_text='Time (1h Bins)', row=max(n_rows, 1), col=1)
    fig.update_yaxes(title_text='Measurement')

    # Show the plot
    fig.show()

In [None]:


# Combined heatmap: per-subject pellet counts and per-behavior event counts
# over time, in a single figure.

# Pivot to one row per time bin and one column per
# (measurement, subject, behavior type) combination
heatmap_data = merged_df.pivot_table(
    index='time',
    columns=['subject_id', 'behavior_type'],
    values=['normalized_event_count', 'normalized_pellet_count']
)

# Flatten the MultiIndex into '<measurement>_<subject>_<behavior>' labels
heatmap_data.columns = [f'{val}_{subj}_{beh}' for val, subj, beh in heatmap_data.columns]

# Pellet counts per subject (one column per behavior type; averaged below)
pellet_counts_47 = heatmap_data.filter(like='normalized_pellet_count_BAA-1104047')
pellet_counts_45 = heatmap_data.filter(like='normalized_pellet_count_BAA-1104045')

# Event counts per behavior type, averaged across subjects.
# The behavior name is also part of the pellet-count column labels, so the
# selection is restricted to the event-count columns first; filtering the
# whole table by behavior name alone would mix pellet counts into these rows.
# NOTE(review): the previous version averaged over the mixed columns and then
# multiplied by 2. That factor is dropped here so the rows stay on the scale
# of normalized_event_count -- confirm no extra display gain was intended.
event_count_columns = heatmap_data.filter(like='normalized_event_count')
event_counts_chasing = event_count_columns.filter(like='chasing').mean(axis=1)
event_counts_fighting = event_count_columns.filter(like='fighting').mean(axis=1)
event_counts_tube_test = event_count_columns.filter(like='tube_test').mean(axis=1)

# Combine the data into a single DataFrame for plotting
heatmap_final = pd.DataFrame({
    'Pellet Count BAA-1104047': pellet_counts_47.mean(axis=1),
    'Pellet Count BAA-1104045': pellet_counts_45.mean(axis=1),
    'Chasing Events': event_counts_chasing,
    'Fighting Events': event_counts_fighting,
    'Tube Test Events': event_counts_tube_test
}, index=heatmap_data.index)

# Create the heatmap
fig = px.imshow(
    heatmap_final.T,  # Transpose to have the measurements on the y-axis
    labels={'x': 'Time (1h Bins)', 'y': 'Measurement'},
    color_continuous_scale='Viridis',
    aspect='auto',
    zmin=0, zmax=1  # Set the color scale to range from 0 to 1
)

# Update layout
fig.update_layout(
    title='Heatmap of Pellet Counts and Events for Both Subjects Over Time',
    xaxis_title='Time (1h Bins)',
    yaxis_title='Measurement',
    template='plotly_white',
    height=500  # Adjust height as needed
)

# Show the plot
fig.show()
