In [None]:
%load_ext autoreload
%autoreload 2

import datajoint as dj
import aeon
from aeon.io import api
from aeon.schema.schemas import social02
from aeon.dj_pipeline.analysis.block_analysis import * 
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import plotly.subplots as sp
import numpy as np
from plotly.subplots import make_subplots
from datetime import timedelta, datetime
from scipy import stats
from shapely.geometry import Point, Polygon
import random

# 0. Import and process data

Experiment to analyse:

In [None]:
# DataJoint restriction selecting the experiment analysed in this notebook.
key = dict(experiment_name="social0.2-aeon3")

In [None]:
# Arena name; used to build the annotation CSV names and the raw-data path.
arena = "AEON3"

Load social interaction data:

In [None]:
# TODO: replace with revised csv path
base_path = '/ceph/aeon/aeon/code/scratchpad/Orsi/'

# Per-arena interaction tables, relative to base_path.
tube_test_path = f'all_tube_test_videos/{arena}_tube_tests_revised_final.csv'
fights_path = f'all_fighting_videos/{arena}_fights.csv'
chasing_path = f'all_chasing_videos/{arena}_chases.csv'

# Read the tube-test, fighting and chasing tables (in that order).
tube_test_df, fights_df, chasing_df = (
    pd.read_csv(base_path + relative_path)
    for relative_path in (tube_test_path, fights_path, chasing_path)
)

Clean up social interaction data:

In [None]:
# Parse the start/end timestamps of every interaction table.
# The CSVs write times as e.g. 2024-02-10T13-45-07 (dashes inside the time part).
timestamp_format = '%Y-%m-%dT%H-%M-%S'
for interaction_df in (chasing_df, fights_df, tube_test_df):
    for column in ('start_timestamp', 'end_timestamp'):
        interaction_df[column] = pd.to_datetime(interaction_df[column], format=timestamp_format)

# Tag each row with the behavior it came from.
chasing_df['behavior_type'] = 'chasing'
fights_df['behavior_type'] = 'fighting'
tube_test_df['behavior_type'] = 'tube_test'

# Use a common 'dominant_id' column for the "winner" of every behavior:
# the chaser in a chase, the winner in a tube test. fights_df may not have such a column,
# in which case it is filled with the 'NaN' placeholder string.
chasing_df = chasing_df.rename(columns={'chaser_id': 'dominant_id'})
tube_test_df = tube_test_df.rename(columns={'winner_id': 'dominant_id'})
fights_df['dominant_id'] = fights_df.get('dominant_id', 'NaN')

# Stack the three tables. ignore_index avoids duplicate row labels, which would otherwise
# repeat 0..n-1 once per source table.
combined_df = pd.concat([chasing_df, fights_df, tube_test_df], ignore_index=True)
# Missing winners become the string 'NaN' so they can be coloured and filtered as a category.
combined_df['dominant_id'] = combined_df['dominant_id'].fillna('NaN')
combined_df.head()

Add pre/post tube-test data manually:

In [None]:
# Tube-test outcomes entered by hand: one row per trial, with 'dominant_id' holding the
# trial winner (same convention as tube_test_df, where winner_id was renamed to dominant_id).
pre_tube_test = pd.DataFrame({
    'behavior_type': ['pre_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 2 + ['BAA-1104047'] * 8,
})

# The post-period trials must carry their own label, not 'pre_tube_test'.
post_tube_test = pd.DataFrame({
    'behavior_type': ['post_tube_test'] * 10,
    'dominant_id': ['BAA-1104045'] * 1 + ['BAA-1104047'] * 9,
})

Get metadata-like information (experiment window, period boundaries, animal IDs, light cycle, arena metadata):

In [None]:
# Experiment window: from midnight of the day the first block starts to midnight after
# the day the last block starts. Only block_start is fetched, explicitly ordered, so
# [0] / [-1] really are the first and last blocks.
block_starts = (BlockAnalysis() & key).fetch('block_start', order_by='block_start')
first_block_start = block_starts[0]
last_block_start = block_starts[-1]

experiment_start = datetime(first_block_start.year, first_block_start.month, first_block_start.day)
# Add a timedelta instead of building day + 1, which raises ValueError at the end of a month.
experiment_end = (
    datetime(last_block_start.year, last_block_start.month, last_block_start.day)
    + timedelta(days=1)
)
experiment_end

In [None]:
# Experiment periods (boundaries entered by hand).
# TODO: derive these boundaries from the data instead of hard-coding them.
# NOTE(review): social ends on 2024-02-24 but post-solo starts on 2024-02-25, which leaves
# a one-day gap — confirm whether that is intentional.
pre_solo_start = experiment_start
pre_solo_end = social_start = datetime(2024, 2, 9)
social_end = datetime(2024, 2, 24)
post_solo_start = datetime(2024, 2, 25)
post_solo_end = experiment_end

In [None]:
# Animal IDs: the sorted distinct winners, minus the 'NaN' placeholder used for no winner.
all_dominant_ids = np.unique(combined_df['dominant_id'])
unique_ids = all_dominant_ids[all_dominant_ids != 'NaN']
unique_ids[0]

In [None]:
# Light-cycle boundaries as 'HH:MM' strings on a 24 h clock.
# NOTE(review): "night" spans 08:00-19:00 (presumably a reversed light cycle), and
# "twilight" precedes it while "dawn" follows it — confirm these labels are not swapped.
night_start, night_end = '08:00', '19:00'
twilight_start, twilight_end = '07:00', '08:00'
dawn_start, dawn_end = '19:00', '20:00'
day_start, day_end = '20:00', '07:00'  # wraps past midnight into the next day

In [None]:
# Load the Metadata stream logged between experiment_start and the first block start,
# and keep the payload of its first record.
raw_data_dir = f'/ceph/aeon/aeon/data/raw/{arena}/social0.2'
metadata_records = api.load(raw_data_dir, social02.Metadata, experiment_start, first_block_start)
metadata = metadata_records.iloc[0].metadata
metadata

Plotting style settings:

In [None]:
# Fixed colours per animal, plus grey for the 'NaN' placeholder (interactions without a winner).
id_color_map = {'NaN': 'grey'}
id_color_map[unique_ids[1]] = 'purple'
id_color_map[unique_ids[0]] = 'green'

# Fixed colours per behavior type.
behaviour_map = dict(chasing='blue', fighting='red', tube_test='orange')


# 1. Temporal pattern and stability of dominance interactions

In [None]:
def _at_clock_time(day, hhmm):
    """Return ``day`` with its hour and minute replaced by those of the 'HH:MM' string ``hhmm``."""
    hour, minute = (int(part) for part in hhmm.split(':'))
    return day.replace(hour=hour, minute=minute)


# Raster of interaction start times: one row per behavior type, coloured by winner.
fig = px.scatter(
    combined_df,
    x='start_timestamp',
    y='behavior_type',
    color='dominant_id',
    title='Behavior Raster Plot',
    labels={'start_timestamp': 'Time', 'behavior_type': 'Behavior Type'},
    color_discrete_map=id_color_map
)

# Restrict the time axis to the social period.
fig.update_xaxes(range=[social_start, social_end])
# Extend the y range below the behavior rows so the light-cycle bar drawn at y = -1 is visible.
fig.update_yaxes(range=[-1.5, 2.5])

# The "day" period wraps past midnight when its end hour is earlier than its start hour.
day_wraps_midnight = int(day_end.split(':')[0]) < int(day_start.split(':')[0])

# Draw one light-cycle bar per calendar day of the social period.
current_day = social_start
while current_day < social_end:
    day_end_time = _at_clock_time(current_day, day_end)
    if day_wraps_midnight:
        day_end_time += timedelta(days=1)

    # (colour, start, end) of each light-cycle segment, in drawing order.
    segments = [
        ("gray", _at_clock_time(current_day, twilight_start), _at_clock_time(current_day, twilight_end)),
        ("black", _at_clock_time(current_day, night_start), _at_clock_time(current_day, night_end)),
        ("gray", _at_clock_time(current_day, dawn_start), _at_clock_time(current_day, dawn_end)),
        ("white", _at_clock_time(current_day, day_start), day_end_time),
    ]
    for color, segment_start, segment_end in segments:
        fig.add_shape(
            type="line",
            x0=segment_start,
            x1=segment_end,
            y0=-1,
            y1=-1,
            line=dict(color=color, width=4)
        )

    current_day += timedelta(days=1)

fig.show()

In [None]:
# Daily (24 h) bin edges spanning the social period.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')

# Assign every interaction to its day bin, then count interactions per day and behavior.
combined_df['time_bin'] = pd.cut(combined_df['start_timestamp'], bins=time_bins)
interaction_counts = (
    combined_df
    .groupby(['time_bin', 'behavior_type'])
    .size()
    .reset_index(name='total_interactions')
)

# Label each bin by its left edge, formatted as a string for the x axis.
bin_left_edges = interaction_counts['time_bin'].apply(lambda interval: interval.left)
interaction_counts['time_bin_start'] = bin_left_edges.dt.strftime('%Y-%m-%d %H:%M')

fig = px.line(
    interaction_counts,
    x='time_bin_start',
    y='total_interactions',
    color='behavior_type',
    color_discrete_map=behaviour_map,
    title='Number of Interactions Over Time for Each Behavior',
    labels={'time_bin_start': 'Time Bin', 'total_interactions': 'Number of Interactions'},
)
fig.show()

In [None]:
def _win_fractions(events_df, label):
    """Return each dominant_id's share of the wins in ``events_df``.

    Result columns: dominant_id, p_wins (fractions summing to 1; NaN ids are excluded by
    value_counts), behavior_type (= ``label``).
    """
    fractions = events_df['dominant_id'].value_counts(normalize=True).reset_index()
    # Rename positionally: the column names produced by reset_index differ between pandas versions.
    fractions.columns = ['dominant_id', 'p_wins']
    fractions['behavior_type'] = label
    return fractions


chaser_counts_chasing = _win_fractions(chasing_df, 'Chasing')
chaser_counts_tube_test = _win_fractions(tube_test_df, 'Tube Test')
chaser_counts_pre_tube_test = _win_fractions(pre_tube_test, 'Pre Tube Test')
chaser_counts_post_tube_test = _win_fractions(post_tube_test, 'Post Tube Test')

combined_counts = pd.concat([
    chaser_counts_chasing,
    chaser_counts_tube_test,
    chaser_counts_pre_tube_test,
    chaser_counts_post_tube_test
])

fig = px.scatter(
    combined_counts,
    x='behavior_type',
    y='p_wins',
    color='dominant_id',
    title='Dominance per Behavior Type and Dominant ID',
    labels={'behavior_type': 'Behavior Type', 'p_wins': 'Proportion of wins'},
    color_discrete_map=id_color_map
)

# Proportions lie in [0, 1].
fig.update_yaxes(range=[0, 1])

fig.show()

In [None]:
# Per-animal win counts and win fractions in daily (24 h) bins over the social period.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')


def plot_daily_wins(events_df, behavior_label, bins):
    """Bin ``events_df`` by start time into ``bins`` and plot wins per animal per bin.

    Writes a ``time_bin`` column onto ``events_df`` so later cells can reuse it.
    Shows two line plots against a 1-based day number: absolute win counts, and each
    animal's fraction of the bin's wins. ``behavior_label`` is appended to the titles.

    Returns the per-(time_bin, dominant_id) table with total_events, fraction_winning
    and day_number columns.
    """
    events_df['time_bin'] = pd.cut(events_df['start_timestamp'], bins=bins)
    counts = events_df.groupby(['time_bin', 'dominant_id']).size().reset_index(name='total_events')
    counts['fraction_winning'] = (
        counts.groupby('time_bin')['total_events'].transform(lambda x: x / x.sum())
    )
    # Category codes are 0-based bin indices, so +1 gives a day number starting at 1.
    counts['day_number'] = counts['time_bin'].cat.codes + 1

    for y_column, title, y_label in (
        ('total_events', 'Winning Events Over Time', 'Winning Events'),
        ('fraction_winning', 'Fraction of Winning Events Over Time', 'Proportion of Winning Events'),
    ):
        fig = px.line(
            counts,
            x='day_number',
            y=y_column,
            color='dominant_id',
            title=f'{title} ({behavior_label})',
            labels={'day_number': 'Day Number', y_column: y_label},
            color_discrete_map=id_color_map
        )
        fig.show()

    return counts


plot_daily_wins(tube_test_df, 'Tube test', time_bins)
winning_counts = plot_daily_wins(chasing_df, 'Chasing', time_bins)

In [None]:
# Daily (24 h) tube-test wins per animal, with each bin labelled by its interval string.
time_bins = pd.date_range(start=social_start, end=social_end, freq='24h')
tube_test_df['time_bin'] = pd.cut(tube_test_df['start_timestamp'], bins=time_bins)

winning_counts = (
    tube_test_df
    .groupby(['time_bin', 'dominant_id'])
    .size()
    .reset_index(name='total_events')
)
# Each animal's share of the wins within its bin.
winning_counts['fraction_winning'] = (
    winning_counts.groupby('time_bin')['total_events'].transform(lambda counts: counts / counts.sum())
)
# Interval labels such as "(2024-02-09, 2024-02-10]" used directly as x-axis categories.
winning_counts['time_bin'] = winning_counts['time_bin'].astype(str)

for y_column, plot_title, y_label in (
    ('total_events', 'Winning Events Over Time (Tube test)', 'Winning Events'),
    ('fraction_winning', 'Fraction of Winning Events Over Time (Tube test)', 'Proportion of Winning Events'),
):
    fig = px.line(
        winning_counts,
        x='time_bin',
        y=y_column,
        color='dominant_id',
        title=plot_title,
        labels={'time_bin': 'Time Bin', y_column: y_label},
        color_discrete_map=id_color_map
    )
    fig.show()

## 1.1 Chasing speed and dominance

In [None]:
# TODO: clean up identity swaps before calculating this (hopefully it doesn't matter much, because during a chase the two animals are close to each other and move in sync)

In [None]:
# Per-sample speeds during chases. Use the cached CSV when it exists; otherwise compute the
# speeds from the tracked centroids (slow).
chasing_speed_path = base_path + 'all_chasing_videos/chasing_speed.csv'
try:
    chasing_speed_data = pd.read_csv(chasing_speed_path)
except FileNotFoundError:
    # A missing cache must fall through to the computation below instead of aborting the cell.
    chasing_speed_data = pd.DataFrame()

if chasing_speed_data.empty:
    print('No cached chasing speed data found; computing it from position data (this may take a while)')

    # Centroid tracks of every subject in the social-period blocks.
    block_subjects = (
        BlockAnalysis.Subject.proj('position_x', 'position_y', 'position_timestamps')
        & key
        & f'block_start >= "{social_start}"'
        & f'block_start <= "{social_end}"'
    )
    block_subjects_dict = block_subjects.fetch(as_dict=True)

    chasing_df['start_timestamp'] = pd.to_datetime(chasing_df['start_timestamp'])
    chasing_df['end_timestamp'] = pd.to_datetime(chasing_df['end_timestamp'])

    # Convert each subject track once, rather than once per chase.
    subject_tracks = [
        (
            s["subject_name"],
            pd.to_datetime(s["position_timestamps"]),
            np.asarray(s["position_x"]),
            np.asarray(s["position_y"]),
        )
        for s in block_subjects_dict
    ]

    # Keep only the samples that fall inside a chase, tagged with that chase's start/end.
    filtered_positions = []
    for _, row in chasing_df.iterrows():
        start_time = row['start_timestamp']
        end_time = row['end_timestamp']
        for subject_name, timestamps, xs, ys in subject_tracks:
            in_chase = (timestamps >= start_time) & (timestamps <= end_time)
            n_samples = int(in_chase.sum())
            if n_samples == 0:
                continue
            filtered_positions.append(pd.DataFrame({
                "subject_name": [subject_name] * n_samples,
                "position_timestamps": timestamps[in_chase],
                "position_x": xs[in_chase],
                "position_y": ys[in_chase],
                "start_time": [start_time] * n_samples,
                "end_time": [end_time] * n_samples,
            }))

    subjects_positions_df = pd.concat(filtered_positions, ignore_index=True)

    # Instantaneous speed = step distance / step duration, differenced within each
    # (subject, chase) only. Grouping by subject alone would diff the first sample of a chase
    # against the last sample of the previous chase. The first sample of each chase gets NaN.
    # Units follow the position columns (presumably pixels/s — confirm).
    per_chase = subjects_positions_df.groupby(["subject_name", "start_time"])
    step_distance = np.hypot(per_chase["position_x"].diff(), per_chase["position_y"].diff())
    step_seconds = per_chase["position_timestamps"].diff().dt.total_seconds()
    subjects_positions_df["speed"] = step_distance / step_seconds
    chasing_speed_data = subjects_positions_df

In [None]:
# Drop NaN and infinite speeds (e.g. the first sample of a chase, or zero-length time steps)
# and implausibly fast samples.
cm2px = 5.4  # presumably pixels per cm — TODO confirm against the arena calibration
# 100 cm/s expressed in pixels/s, assuming the speed column is in pixels/s (not cm/s).
max_speed_threshold = 100 * cm2px
chasing_speed_data = chasing_speed_data[
    np.isfinite(chasing_speed_data['speed'])
    & (chasing_speed_data['speed'] < max_speed_threshold)
]


In [None]:
# Average speed per chase, where chases are identified by their start time.
# First take the mean over samples for each subject within a chase.
avg_speed_per_subject = (
    chasing_speed_data
    .groupby(['start_time', 'subject_name'])['speed']
    .mean()
    .rename('avg_speed_per_subject')
    .reset_index()
)

# Then average the per-subject means within each chase, keyed by 'start_timestamp'
# so the result can be joined onto chasing_df.
avg_speed_per_chase = (
    avg_speed_per_subject
    .groupby('start_time')['avg_speed_per_subject']
    .mean()
    .rename('avg_speed_per_chase')
    .rename_axis('start_timestamp')
    .reset_index()
)

# Align the key dtypes (the cached CSV stores times as strings), join, and keep only chases with a speed.
chasing_df['start_timestamp'] = pd.to_datetime(chasing_df['start_timestamp'])
avg_speed_per_chase['start_timestamp'] = pd.to_datetime(avg_speed_per_chase['start_timestamp'])
avg_chasing_speed_df = (
    chasing_df
    .merge(avg_speed_per_chase, on='start_timestamp', how='left')
    .dropna(subset=['avg_speed_per_chase'])
)

In [None]:
# Box plot of per-chase average speed, one box per chaser, plus a Welch t-test
# (unequal variances) comparing the two animals.
fig = go.Figure()

# Loop variable is deliberately not called `dominant_id`: that name holds the dominant
# animal in section 2, and re-running this cell must not overwrite it.
for chaser_id in avg_chasing_speed_df['dominant_id'].unique():
    if pd.isna(chaser_id):
        # NaN never compares equal to itself, so select chases without a chaser via isna().
        color = 'gray'
        chaser_df = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'].isna()]
    else:
        color = id_color_map.get(chaser_id, 'black')  # fall back to black for unmapped ids
        chaser_df = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == chaser_id]

    fig.add_trace(go.Box(
        x=[str(chaser_id)] * len(chaser_df),
        y=chaser_df['avg_speed_per_chase'],
        name=str(chaser_id),  # str() so a NaN chaser still gets a readable label
        boxpoints='all',  # show every chase as a point
        jitter=0.3,  # spread points horizontally to reduce overlap
        pointpos=-1.8,  # negative values draw the points to the left of the box
        marker=dict(color=color)
    ))

fig.update_layout(
    title='Average Speed of Chases per Chaser ID',
    xaxis_title='Chaser ID',
    yaxis_title='Average Speed',
    boxmode='group',
    showlegend=False
)

chaser1_speeds = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == unique_ids[0]]['avg_speed_per_chase']
chaser2_speeds = avg_chasing_speed_df[avg_chasing_speed_df['dominant_id'] == unique_ids[1]]['avg_speed_per_chase']

t_stat, p_value_ttest = stats.ttest_ind(chaser1_speeds, chaser2_speeds, equal_var=False)

fig.add_annotation(
    x=0.5, y=1.05, xref='paper', yref='paper',
    text=f"T-test p-value: {p_value_ttest:.3f}",
    showarrow=False,
    font=dict(size=12, color='black'),
    align='center'
)

fig.show()

# 2. Dominance and foraging

In [None]:
# The dominant animal is the one that won the most interactions (chases + tube tests).
# Exclude the 'NaN' placeholder (interactions without a winner, e.g. fights); otherwise
# it can out-count both animals and be picked as "dominant".
dominant_id_counts = combined_df.loc[combined_df['dominant_id'] != 'NaN', 'dominant_id'].value_counts()
dominant_id = dominant_id_counts.idxmax()
subordinate_id = unique_ids[0] if unique_ids[0] != dominant_id else unique_ids[1]


Get foraging data:

In [None]:
# Pellet deliveries per block / patch / subject, restricted to the top-camera source.
foraging_query = (
    BlockSubjectAnalysis.Patch.proj('pellet_count', 'pellet_timestamps', 'patch_threshold')
    & {"spinnaker_video_source_name": "CameraTop"}  # video source to keep
    & key
)

pellet_data = foraging_query.fetch()


def _experiment_period(timestamp):
    """Return 'pre_solo', 'social' or 'post_solo' for ``timestamp``, or None if it is in no period.

    Bounds are inclusive and checked in this order, so a timestamp on a shared boundary
    is assigned to the earlier period.
    """
    if pre_solo_start <= timestamp <= pre_solo_end:
        return 'pre_solo'
    if social_start <= timestamp <= social_end:
        return 'social'
    if post_solo_start <= timestamp <= post_solo_end:
        return 'post_solo'
    return None


# One record per delivered pellet.
data = []
for entry in pellet_data:
    # NOTE(review): positional access assumes the record's fields are ordered
    # (experiment_name, block_start, patch_name, subject, pellet_count, pellet_timestamps,
    # patch_threshold) — confirm against the table definition.
    experiment_name = entry[0]
    block_start = entry[1]
    patch_name = entry[2]
    subject_id = entry[3]
    pellet_count = entry[4]
    pellet_timestamps = entry[5]
    patch_threshold = entry[6]

    for pellet_timestamp, threshold in zip(pellet_timestamps, patch_threshold):
        pellet_timestamp = pd.to_datetime(pellet_timestamp)
        # Pellets outside every period (e.g. in a gap between periods) get None rather than
        # silently reusing the previous pellet's period label.
        period = _experiment_period(pellet_timestamp)

        data.append({
            'time': pellet_timestamp,
            'subject_id': subject_id,
            'threshold': threshold,
            'rank': 'dominant' if subject_id == dominant_id else 'subordinate',
            'period': period
        })

pellet_df = pd.DataFrame(data)
pellet_df['time'] = pd.to_datetime(pellet_df['time'])

n_unassigned = int(pellet_df['period'].isna().sum())
if n_unassigned:
    print(f'{n_unassigned} pellets fall outside every defined period (period=None)')

# Data cleaning: drop pellets delivered less than 2 s after the same subject's previous pellet.
# Rows were appended record by record (block/patch grouped), so per-subject order is not
# guaranteed to be chronological. Sort first so diffs are never negative. A subject's
# first pellet has no predecessor (NaN diff) and is kept.
min_pellet_interval_s = 2
pellet_df = pellet_df.sort_values('time', kind='stable').reset_index(drop=True)
pellet_df['time_diff'] = pellet_df.groupby('subject_id')['time'].diff().dt.total_seconds()
pellet_df = pellet_df[pellet_df['time_diff'].isna() | (pellet_df['time_diff'] > min_pellet_interval_s)]

pellet_df.head()

In [None]:
def create_raster_plot(df, id_color_map):
    """Show pellet-delivery rasters for the pre-solo, post-solo and social periods.

    Five stacked subplots:
    - rows 1-2: pre_solo, one row for ``unique_ids[0]`` and one for any other subject;
    - rows 3-4: the same for post_solo;
    - row 5: social, with every subject in one row and shown in the legend.

    ``df`` needs 'time', 'subject_id' and 'period' columns. Relies on the notebook-level
    ``unique_ids`` for subplot titles and row assignment. Displays the figure; returns None.
    """
    fig = make_subplots(
        rows=5, cols=1, shared_xaxes=False,
        subplot_titles=(
            f'Pre Solo - {unique_ids[0]}',
            f'Pre Solo - {unique_ids[1]}',
            f'Post Solo - {unique_ids[0]}',
            f'Post Solo - {unique_ids[1]}',
            'Social',
        )
    )

    # period -> (row for unique_ids[0], row for any other subject, hide legend entry?)
    rows_by_period = {
        'pre_solo': (1, 2, True),
        'post_solo': (3, 4, True),
        'social': (5, 5, False),
    }

    for period_name, (first_id_row, other_id_row, hide_legend) in rows_by_period.items():
        period_df = df[df['period'] == period_name]
        if period_df.empty:
            continue
        for subject_id in period_df['subject_id'].unique():
            trace = px.scatter(
                period_df[period_df['subject_id'] == subject_id],
                x='time',
                y='subject_id',
                color='subject_id',
                labels={'time': 'Time'},
                color_discrete_map=id_color_map
            ).data[0]
            if hide_legend:
                trace.showlegend = False
            target_row = first_id_row if subject_id == unique_ids[0] else other_id_row
            fig.add_trace(trace, row=target_row, col=1)

    fig.update_layout(
        title_text='Pellet Raster Plots',
        showlegend=True
    )

    # The subject id is already in the subplot titles, so blank every y-axis title.
    for subplot_row in range(1, 6):
        fig.update_yaxes(title_text='', row=subplot_row, col=1)

    fig.show()

# Draw the pellet rasters for all three experiment periods.
create_raster_plot(df=pellet_df, id_color_map=id_color_map)