# New pipeline

In [None]:
import matplotlib.pyplot as plt
import cv2


import scipy

from pathlib import Path

import numpy as np

import h5py
import math


import pandas as pd

import holoviews as hv

import platform

import sys
sys.path.insert(0, "..")

from pathlib import Path

import cv2

import json


from Utilities.Utils import *
from Utilities.Processing import *

In [None]:
# Get the DataFolder: the lab-server mount point differs between macOS
# (Darwin) and Linux, so pick the path that matches the current OS.
_DATA_ROOTS = {
    "Darwin": "/Volumes/Ramdya-Lab/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos",
    "Linux": "/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos",
}

try:
    DataPath = Path(_DATA_ROOTS[platform.system()])
except KeyError:
    # Previously an unsupported OS left DataPath unbound and the cell failed
    # with an unrelated-looking NameError at the print below.
    raise RuntimeError(
        f"No DataPath configured for platform {platform.system()!r}"
    ) from None

print(DataPath)

In [None]:
# Select the experiment folders to analyse: every entry of DataPath whose
# lower-cased full path contains both "feedingstate" and "pm".
# NOTE(review): the match is done on the full path string, not only the
# folder name — confirm no parent directory name can contain "pm".
Folders = [
    folder
    for folder in DataPath.iterdir()
    if "feedingstate" in str(folder).lower() and "pm" in str(folder).lower()
]

Folders
    

In [None]:
# Build a dataframe that stores, for every tracked fly, the ball y trace plus
# the experiment / arena / corridor and the per-arena metadata variables.
Dataset = pd.DataFrame(columns=["Fly", "yball", "arena", "corridor"])

# Rows are collected in a list and turned into a frame once at the end:
# DataFrame.append was removed in pandas 2.0 and appending row by row inside
# the loop is quadratic.
fly_rows = []
Flynum = 0
for folder in Folders:
    # Read the Metadata.json file. From the indexing below: metadata["Variable"]
    # lists variable names and metadata["ArenaN"] lists the values for arena N
    # in the same order.
    with open(folder / "Metadata.json", "r") as f:
        metadata = json.load(f)
    variables = metadata["Variable"]
    metadata_dict = {}
    for var in variables:
        var_index = variables.index(var)
        metadata_dict[var] = {
            f"Arena{arena}": metadata[f"Arena{arena}"][var_index]
            for arena in range(1, 10)
        }

    print(metadata_dict)

    for file in folder.glob("**/*.analysis.h5"):
        with h5py.File(file, "r") as f:
            locations = f["tracks"][:].T

        # y coordinate (index 1 of the third axis) for every node and track;
        # the [:, 0, 0] slice stored below is treated as the ball trace.
        yball: np.ndarray = locations[:, :, 1, :]
        if "Flipped" in folder.name:
            # Load first, then negate. The old code negated the *previous*
            # file's array here and never loaded the current file at all.
            yball[:, 0, 0] = -yball[:, 0, 0]

        # Arena and corridor come from the grandparent and parent folder names
        arena = file.parent.parent.name
        corridor = file.parent.name

        # Metadata values of this fly's arena ("arena1" -> "Arena1")
        arena_key = arena.capitalize()
        arena_metadata = {var: metadata_dict[var][arena_key] for var in metadata_dict}

        Flynum += 1

        row = {
            "Fly": "Fly" + str(Flynum),
            "yball": yball[:, 0, 0],
            "experiment": folder.name,
            "arena": arena,
            "corridor": corridor,
        }
        row.update(arena_metadata)
        fly_rows.append(row)

Dataset = pd.concat([Dataset, pd.DataFrame(fly_rows)], ignore_index=True)



In [None]:
# Unpack yball positions: one row per (fly, frame) instead of one array per fly.
Dataset = Dataset.explode("yball")

# Frame counter restarts at 0 for every fly; time assumes 30 frames per second.
Dataset["Frame"] = Dataset.groupby("Fly").cumcount()
Dataset["time"] = Dataset["Frame"].div(30)

Dataset.head()

In [None]:
# Normalize the capitalization of the fed condition: "Fed" becomes "fed".
Dataset["FeedingState"] = Dataset["FeedingState"].replace({"Fed": "fed"})

# Experiments with 'Flipped' in the folder name get Orientation 'flipped' (next cell).


In [None]:
# Rows from experiments whose folder name contains 'Flipped' are marked as flipped.
flipped_rows = Dataset['experiment'].str.contains('Flipped')
Dataset.loc[flipped_rows, 'Orientation'] = 'flipped'

In [None]:
# Save the long-format Dataset as a feather file via checksave (from Utilities).
# NOTE(review): the file name says "TNTScreen_4exps" while the folders loaded
# above are FeedingState PM experiments — confirm the name is intended.
savepath = Path("/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Datasets")
output_file = savepath / "230821_TNTScreen_4exps.feather"
checksave(path=output_file, object="dataframe", file=Dataset)

In [None]:
# Folder on the lab server where processed datasets are written.
savepath = Path("/mnt/labserver") / "DURRIEU_Matthias" / "Experimental_data" / "MultiMazeRecorder" / "Datasets"



In [None]:
# Preview the first five rows.
Dataset.head(5)

In [None]:
# Per-fly summary: last time point reached and number of rows (frames).
per_fly = Dataset.groupby("Fly")
max_time = per_fly["time"].max()
num_rows = per_fly.size()

# Each header is printed on its own line, followed by the corresponding Series.
print("Maximum time value for each fly:", max_time, sep="\n")
print()
print("Number of rows for each fly:", num_rows, sep="\n")


In [None]:
# Show every distinct fly identifier present in the dataset.
fly_ids = Dataset['Fly'].unique()
print(fly_ids)

In [None]:
# Mean ball y position per (Genotype, time).
# NOTE(review): several later cells read 'FeedingState', 'Period' and 'Light'
# from GroupedDF, but only 'Genotype' and 'time' are kept here — those cells
# raise KeyError unless the grouping is extended.
GroupedDF = (
    Dataset.groupby(["Genotype", "time"])["yball"]
    .mean()
    .reset_index()
)

GroupedDF.head()


In [None]:
# Get all unique values of the column FeedingState.
# GroupedDF is built as a (Genotype, time) mean and carries no 'FeedingState'
# column (KeyError), so read the states from the per-frame Dataset instead.
feeding_states = Dataset['FeedingState'].unique()

print(feeding_states)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One facet row per Period; inside each facet, yball vs time with the third
# positional column ('FeedingState') used by lineplot as hue.
# NOTE(review): requires GroupedDF to contain 'Period' and 'FeedingState'.
g = sns.FacetGrid(data=GroupedDF, row='Period')
g.map(sns.lineplot, 'time', 'yball', 'FeedingState')
g.add_legend()

# Every facet is drawn with an inverted y-axis.
for facet_ax in g.axes.flat:
    facet_ax.invert_yaxis()

plt.show()


In [None]:
# Preview the first five rows of the grouped data.
GroupedDF.head(5)

In [None]:
import matplotlib

# Print the seaborn version, then the matplotlib version.
for plotting_lib in (sns, matplotlib):
    print(plotting_lib.__version__)

In [None]:
import matplotlib.pyplot as plt

# Stacked panels sharing the time axis, one per Period, one line per FeedingState.
# NOTE(review): nrows is fixed at 2, so this assumes exactly two periods.
fig, axes = plt.subplots(nrows=2, sharex=True)

for row_idx, period in enumerate(GroupedDF['Period'].unique()):
    period_df = GroupedDF[GroupedDF['Period'] == period]
    panel = axes[row_idx]
    for state in period_df['FeedingState'].unique():
        state_df = period_df[period_df['FeedingState'] == state]
        panel.plot(state_df['time'], state_df['yball'], label=state)
    panel.set_title(period)
    panel.invert_yaxis()

# Only the top panel carries a legend.
axes[0].legend()

plt.show()


In [None]:
import matplotlib.pyplot as plt

# Light-on subset for this figure only. This cell used to reassign GroupedDF
# itself, which silently removed every light-off row for all later cells
# (e.g. the next plot defines colours for 'off' curves that could then never appear).
LightOnDF = GroupedDF[GroupedDF['Light'] == 'on']

# Stacked panels sharing the time axis, one per Period.
# NOTE(review): nrows is fixed at 2, so this assumes exactly two periods.
fig, axes = plt.subplots(nrows=2, sharex=True)

for i, period in enumerate(LightOnDF['Period'].unique()):
    data = LightOnDF[LightOnDF['Period'] == period]
    for feeding_state in data['FeedingState'].unique():
        subset = data[data['FeedingState'] == feeding_state]
        axes[i].plot(subset['time'], subset['yball'], label=feeding_state)
    axes[i].set_title(period)
    axes[i].invert_yaxis()

# Add a legend to the first axis
axes[0].legend()

plt.show()


In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))

# Colour per (feeding state, light condition); light-off variants use paler tones.
colors = {('fed', 'on'): 'C0', ('fed', 'off'): 'lightblue',
          ('starved', 'on'): 'C1', ('starved', 'off'): 'lightgreen',
          ('starved_noWater', 'on'): 'C2', ('starved_noWater', 'off'): 'pink'}

# One curve per (feeding state, light) combination of the PM period.
data = GroupedDF[GroupedDF['Period'] == 'PM']
for feeding_state in data['FeedingState'].unique():
    state_rows = data[data['FeedingState'] == feeding_state]
    for light in state_rows['Light'].unique():
        light_rows = state_rows[state_rows['Light'] == light]
        label = f'{feeding_state} - Light {light}'
        # NOTE(review): both light states are drawn solid; the original
        # `'-' if light == 'on' else '-'` suggests '--' may have been meant for 'off'.
        ax.plot(light_rows['time'], light_rows['yball'], linestyle='-',
                color=colors[(feeding_state, light)], label=label)

ax.set_title('PM')
ax.invert_yaxis()
ax.legend()

plt.show()


In [None]:
# Preview the first five rows of the grouped data.
GroupedDF.head(5)

In [None]:
from scipy import stats
import numpy as np


def confint(x, alpha=0.05):
    """Return the two-sided Student-t confidence interval of the mean of ``x``.

    Args:
        x: sequence of numeric samples.
        alpha: significance level; the interval has confidence ``1 - alpha``.

    Returns:
        A ``(lower, upper)`` tuple, or ``(nan, nan)`` when ``x`` has fewer
        than two values (the standard error is undefined).
    """
    n_obs = len(x)
    if n_obs < 2:
        return (np.nan, np.nan)

    confidence = 1 - alpha
    return stats.t.interval(confidence, n_obs - 1, loc=np.mean(x), scale=stats.sem(x))

# Apply confint to every (Period, time) group of the per-frame data.
# `Dataset` is the per-frame table (the name `DataFrame` used here before is
# undefined); yball is cast to float because explode leaves it as object dtype.
ci_per_group = (
    Dataset.assign(yball=Dataset['yball'].astype(float))
    .groupby(['Period', 'time'])['yball']
    .apply(confint)
)

# confint returns a (lower, upper) tuple per group: split it into two columns.
# (Renaming the 3-column reset_index result with 4 names raised ValueError.)
confint_df = ci_per_group.reset_index()
ci_bounds = pd.DataFrame(
    confint_df.pop('yball').tolist(),
    index=confint_df.index,
    columns=['yball_lower', 'yball_upper'],
)
confint_df = confint_df.join(ci_bounds)

# Merge the bounds into the grouped dataframe.
# NOTE(review): GroupedDF as built earlier in this notebook is grouped by
# Genotype/time and has no 'Period' column — confirm before running.
GroupedDF = pd.merge(GroupedDF, confint_df, on=['Period', 'time'], how='left')


In [None]:
# yball over time, one line per Period.
sns.lineplot(data=GroupedDF, x='time', y='yball', hue='Period')

# Shade each Period's [yball_lower, yball_upper] band.
for _period, period_rows in GroupedDF.groupby('Period'):
    plt.fill_between(
        period_rows['time'],
        period_rows['yball_lower'],
        period_rows['yball_upper'],
        alpha=0.1,
    )

plt.show()

In [None]:
# Mutants: mean ball y position per (Genotype, time), with the keys as columns.
GroupedDF_TNT = Dataset.groupby(['Genotype', 'time'], as_index=False)['yball'].mean()


In [None]:
# Number of distinct flies per Genotype (used as n in legend labels).
sample_size = Dataset.groupby('Genotype')['Fly'].agg('nunique')



In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Number of distinct flies per Genotype, shown in the legend.
sample_size = Dataset.groupby('Genotype')['Fly'].nunique()

plt.figure(figsize=(12, 6))

# Mean yball over time, one line per Genotype.
sns.lineplot(data=GroupedDF, x='time', y='yball', hue='Genotype', linewidth=1)

# Invert the y-axis
plt.gca().invert_yaxis()

# Append the sample size to each legend label. Labels are matched by their
# text rather than by position, so a legend ordered differently from
# sample_size.index (or containing extra entries) cannot get the wrong n.
legend = plt.legend()
for text in legend.texts:
    genotype = text.get_text()
    if genotype in sample_size.index:
        text.set_text(f'{genotype} (n = {sample_size[genotype]})')

plt.show()

# Confints

In [None]:
# Bootstrapped CI of yball for every (Genotype, time) group using draw_bs_ci
# (from Utilities) with 300 replicates; extra keyword arguments of
# SeriesGroupBy.apply are forwarded to the function.
Confints_BS = Dataset.groupby(['Genotype', 'time'])['yball'].apply(
    draw_bs_ci, n_reps=300
)


In [None]:

# Turn the (Genotype, time) index back into regular columns.
Confints_BS_Process = Confints_BS.reset_index(drop=False)

In [None]:
# Split each bootstrapped (lower, upper) pair stored in the 'yball' column
# into two columns, ci_lower and ci_upper.
Confints_BS_Process[["ci_lower", "ci_upper"]] = pd.DataFrame(
    Confints_BS_Process["yball"].tolist(), index=Confints_BS_Process.index
)

# Attach the bounds to GroupedDF by (Genotype, time) keys. The previous
# row-index assignment silently misaligned values whenever GroupedDF was
# filtered or ordered differently from Confints_BS. Existing ci columns are
# dropped first so re-running the cell does not create _x/_y duplicates.
GroupedDF = GroupedDF.drop(columns=["ci_lower", "ci_upper"], errors="ignore").merge(
    Confints_BS_Process[["Genotype", "time", "ci_lower", "ci_upper"]],
    on=["Genotype", "time"],
    how="left",
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))

# Mean yball over time, one line per Genotype.
sns.lineplot(data=GroupedDF, x='time', y='yball', hue='Genotype', linewidth=1)

# Shade the bootstrapped confidence interval of each genotype. observed=True
# skips unused categories when Genotype is categorical, which would otherwise
# draw empty bands and shift the fill colours.
for genotype, data in GroupedDF.groupby('Genotype', observed=True):
    plt.fill_between(data['time'], data['ci_lower'], data['ci_upper'], alpha=0.2)

# Invert the y-axis
plt.gca().invert_yaxis()

# Append the sample size (computed in an earlier cell) to each legend label,
# matched by label text rather than by position.
legend = plt.legend()
for text in legend.texts:
    genotype = text.get_text()
    if genotype in sample_size.index:
        text.set_text(f'{genotype} (n = {sample_size[genotype]})')

plt.show()


In [None]:
# Keep only the three genotypes of interest.
genotypes_of_interest = ['TNTxTH', 'TNTxE-PG', 'PR']
SubGroup = GroupedDF.loc[GroupedDF['Genotype'].isin(genotypes_of_interest)]

In [None]:
# Set the figure size
plt.figure(figsize=(12, 6))

# Mean yball over time for the selected genotypes only.
sns.lineplot(data=SubGroup, x='time', y='yball', hue='Genotype', linewidth=1)

# Shade each genotype's bootstrapped confidence interval (observed=True skips
# unused categories if Genotype is categorical).
for genotype, data in SubGroup.groupby('Genotype', observed=True):
    plt.fill_between(data['time'], data['ci_lower'], data['ci_upper'], alpha=0.2)

# Invert the y-axis
plt.gca().invert_yaxis()

# Append the sample size to each legend label. sample_size covers ALL
# genotypes, so the old positional zip paired these three legend entries with
# the first three genotypes of sample_size.index — wrong n values. Match by
# label text instead.
legend = plt.legend()
for text in legend.texts:
    genotype = text.get_text()
    if genotype in sample_size.index:
        text.set_text(f'{genotype} (n = {sample_size[genotype]})')

plt.show()


In [None]:
# Preview the first five rows of the grouped data.
GroupedDF.head(5)

In [None]:
# Store Genotype as a pandas categorical column (categories = sorted distinct values).
GroupedDF['Genotype'] = pd.Categorical(GroupedDF['Genotype'])

In [None]:
# Save the grouped (per Genotype and time) dataframe via checksave.
# The file name ends in "_GroupedDF", but this cell previously passed
# `Dataset`, so the full per-frame table was written under that name.
savepath = Path("/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Datasets")
checksave(
    path=savepath.joinpath("230821_TNTScreen_4exps_GroupedDF.feather"),
    object="dataframe",
    file=GroupedDF,
)

# New Method for more efficient dataframes

In [None]:
# Select every experiment folder whose lower-cased full path contains
# "feedingstate" (all periods, no AM/PM filter).
Folders = [
    folder
    for folder in DataPath.iterdir()
    if 'feedingstate' in str(folder).lower()
]

Folders

In [None]:
# Build a dataframe that stores, for every tracked fly, the ball y trace plus
# the experiment / arena / corridor and the per-arena metadata as categoricals.
Dataset = pd.DataFrame(columns=["Fly", "yball", "arena", "corridor"])

# One single-row frame per fly, concatenated once after the loop
# (calling pd.concat inside the loop copies the whole table every iteration).
fly_frames = []
Flynum = 0
for folder in Folders:
    print(f"Adding experiment {folder} to the dataset...")
    # Read the Metadata.json file. From the indexing below: metadata["Variable"]
    # lists variable names and metadata["ArenaN"] lists the values for arena N
    # in the same order.
    with open(folder / "Metadata.json", "r") as f:
        metadata = json.load(f)
    variables = metadata["Variable"]
    metadata_dict = {}
    for var in variables:
        var_index = variables.index(var)
        metadata_dict[var] = {
            f"Arena{arena}": metadata[f"Arena{arena}"][var_index]
            for arena in range(1, 10)
        }

    print(metadata_dict)

    for file in folder.glob("**/*.analysis.h5"):
        with h5py.File(file, "r") as f:
            locations = f["tracks"][:].T

        # y coordinate (index 1 of the third axis) for every node and track;
        # the [:, 0, 0] slice stored below is treated as the ball trace.
        yball: np.ndarray = locations[:, :, 1, :]
        if "Flipped" in folder.name:
            # Load first, then negate. The old code negated the *previous*
            # file's array here and never loaded the current file at all.
            yball[:, 0, 0] = -yball[:, 0, 0]

        # Arena and corridor come from the grandparent and parent folder names
        arena = file.parent.parent.name
        corridor = file.parent.name

        # Metadata values of this fly's arena ("arena1" -> "Arena1")
        arena_key = arena.capitalize()
        arena_metadata = {var: pd.Categorical([metadata_dict[var][arena_key]]) for var in metadata_dict}

        Flynum += 1

        data = {"Fly": pd.Categorical(["Fly" + str(Flynum)]),
                "yball": [list(yball[:, 0, 0])],
                "experiment": pd.Categorical([folder.name]),
                "arena": pd.Categorical([arena]),
                "corridor": pd.Categorical([corridor])}
        data.update(arena_metadata)
        fly_frames.append(pd.DataFrame(data))

Dataset = pd.concat([Dataset, *fly_frames], ignore_index=True)

In [None]:
# Unpack yball positions: one row per frame instead of one list per fly.
Dataset = Dataset.explode("yball")

# Per-fly frame counter, converted to seconds assuming 30 frames per second.
Dataset["Frame"] = Dataset.groupby("Fly").cumcount()
Dataset["time"] = Dataset["Frame"] / 30

# explode repeats the original index for every element; renumber 0..n-1.
Dataset = Dataset.reset_index(drop=True)

Dataset.head()

In [None]:
# Replace all occurrences of "Fed" with "fed" in the 'FeedingState' column
Dataset['FeedingState'] = Dataset['FeedingState'].replace('Fed', 'fed')

# A categorical column only accepts values among its categories, so
# 'flipped' must be registered before it is assigned; the dtype check keeps
# this working when concat produced an object column instead.
orientation = Dataset['Orientation']
if isinstance(orientation.dtype, pd.CategoricalDtype) and 'flipped' not in orientation.cat.categories:
    Dataset['Orientation'] = orientation.cat.add_categories(['flipped'])

# If there is 'Flipped' in the folder name, set the corresponding 'Orientation' to 'flipped'
# (astype(str) keeps the mask boolean even if 'experiment' holds missing values).
flipped_rows = Dataset['experiment'].astype(str).str.contains('Flipped')
Dataset.loc[flipped_rows, 'Orientation'] = 'flipped'

In [None]:
# Save the per-frame FeedingState dataset as a feather file via checksave (from Utilities).
savepath = Path("/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Datasets")
output_file = savepath / "230822_FeedingState.feather"
checksave(path=output_file, object="dataframe", file=Dataset)

In [None]:
# Bootstrapped CI of yball for each (Genotype, time) group via draw_bs_ci
# (from Utilities) with its default settings.
# NOTE(review): any multi-threading / progress-bar behavior lives inside
# draw_bs_ci and is not visible here; the earlier cell passed n_reps=300.
genotype_time_groups = Dataset.groupby(['Genotype', 'time'])
Confints_BS = genotype_time_groups['yball'].apply(draw_bs_ci)

# Detecting start and end of maze for each video

## Test with one video

In [None]:
# Single test video: arena 1, corridor 1 of the 230721 FeedingState PM experiment.
Videopath = (
    Path('/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos')
    / '230721_Feedingstate_4_PM_Videos_Tracked'
    / 'arena1'
    / 'corridor1'
    / 'corridor1.mp4'
)

In [None]:
# Grab only the first frame of the video, then release the capture handle.
cap = cv2.VideoCapture(Videopath.as_posix())
ret, frame = cap.read()
cap.release()

if ret and frame is not None:
    # Grayscale conversion, then display
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    plt.imshow(frame)
elif not ret:
    print("Error: Could not read frame from video")
else:
    print("Error: Frame is None")


In [None]:
# Sum pixel intensities along each image row (one value per row) and plot the profile.
rows = np.sum(frame, axis=1)
plt.plot(rows)

In [None]:
# Index of the row with the lowest summed intensity, marked in red on the profile.
min_row = np.argmin(rows)
plt.plot(rows)
plt.axvline(min_row, color='red')


In [None]:
# Show the frame with horizontal bars at min_row - 30 (red) and min_row - 320 (blue).
plt.imshow(frame)
for offset, bar_colour in ((30, 'red'), (320, 'blue')):
    plt.axhline(min_row - offset, color=bar_colour)

In [None]:
# plot the frame but move the bars locations. For the first bar, move it to the right, for the second bar, move it to the left.
# NOTE(review): `peaks` is not defined in any visible cell of this notebook, so
# this cell raises NameError as written. Presumably it held row indices from a
# peak-finding step that was removed — TODO confirm / recompute before running.
# Even-indexed peaks get a horizontal line 50 rows further down (peak + 50),
# odd-indexed peaks one 50 rows further up (peak - 50); since the lines are
# horizontal, "right"/"left" above presumably means down/up in the image.

plt.imshow(frame)
for i, peak in enumerate(peaks):
    if i % 2 == 0:
        plt.axhline(peak + 50, color='red', alpha=0.5)
    else:
        plt.axhline(peak - 50, color='red', alpha=0.5)

# Test with all videos of an experiment

In [None]:
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import median_filter, gaussian_filter

# Set the path to the main folder (one experiment; videos are searched recursively)
main_folder = Path('/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos/230721_Feedingstate_4_PM_Videos_Tracked')

# Preprocessed first frame and detected minimum row of every readable video
frames = []
min_rows = []

# Row sums below this value are set to 0 before looking for the minimum
threshold = 100

# Sorted so the grid order is reproducible (rglob order depends on the filesystem)
for Videopath in sorted(main_folder.rglob('*.mp4')):
    # Read only the first frame of the video
    cap = cv2.VideoCapture(Videopath.as_posix())
    ret, frame = cap.read()
    cap.release()

    if not ret:
        print(f"Error: Could not read frame from video {Videopath}")
        continue
    if frame is None:
        print(f"Error: Frame is None for video {Videopath}")
        continue

    # Grayscale, then median + Gaussian filtering to smooth out noise
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    frame = median_filter(frame, size=3)
    frame = gaussian_filter(frame, sigma=1)

    # Row with the smallest thresholded summed intensity
    summed_pixel_values = frame.sum(axis=1)
    summed_pixel_values[summed_pixel_values < threshold] = 0
    min_row = np.argmin(summed_pixel_values)

    frames.append(frame)
    min_rows.append(min_row)

# 6 columns and at least 9 rows as before; the grid grows (instead of raising
# IndexError) when there are more than 54 videos.
ncols = 6
nrows = max(9, -(-len(frames) // ncols))

fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20 * nrows / 9))

for i, (frame, min_row) in enumerate(zip(frames, min_rows)):
    ax = axs[i // ncols, i % ncols]
    ax.imshow(frame, cmap='gray', vmin=0, vmax=255)
    # Detected minimum (yellow) and the rows at min_row - 30 (red) / min_row - 320 (blue)
    ax.axhline(min_row, color='yellow')
    ax.axhline(min_row - 30, color='red')
    ax.axhline(min_row - 320, color='blue')

# Remove the axis of each subplot and draw them closer together
for ax in axs.flat:
    ax.axis("off")
plt.subplots_adjust(wspace=0, hspace=0)


In [None]:
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import median_filter, gaussian_filter

# Set the path to the main folder (one experiment; videos are searched recursively)
main_folder = Path('/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos/230721_Feedingstate_4_PM_Videos_Tracked')

# Preprocessed first frame, detected minimum row and path of every readable video
frames = []
min_rows = []
video_paths = []

# Row sums below this value are set to 0 before looking for the minimum
threshold = 100

# Sorted so the grid order is reproducible (rglob order depends on the filesystem)
for Videopath in sorted(main_folder.rglob('*.mp4')):
    # Read only the first frame of the video
    cap = cv2.VideoCapture(Videopath.as_posix())
    ret, frame = cap.read()
    cap.release()

    if not ret:
        print(f"Error: Could not read frame from video {Videopath}")
        continue
    if frame is None:
        print(f"Error: Frame is None for video {Videopath}")
        continue

    # Grayscale, then median + Gaussian filtering to smooth out noise
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    frame = median_filter(frame, size=3)
    frame = gaussian_filter(frame, sigma=1)

    # Row with the smallest thresholded summed intensity
    summed_pixel_values = frame.sum(axis=1)
    summed_pixel_values[summed_pixel_values < threshold] = 0
    min_row = np.argmin(summed_pixel_values)

    # Save the [start, end] rows next to the video. This used to happen inside
    # the plotting loop, so videos beyond the 54-panel grid never got a file.
    # NOTE(review): values are negative when min_row < 320 — confirm.
    np.save(Videopath.parent / 'coordinates.npy', [min_row - 30, min_row - 320])

    frames.append(frame)
    min_rows.append(min_row)
    video_paths.append(Videopath)

# 6 columns and at least 9 rows as before; the grid grows (instead of raising
# IndexError) when there are more than 54 videos.
ncols = 6
nrows = max(9, -(-len(frames) // ncols))

fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20 * nrows / 9))

for i, (frame, min_row) in enumerate(zip(frames, min_rows)):
    ax = axs[i // ncols, i % ncols]
    ax.imshow(frame, cmap='gray', vmin=0, vmax=255)
    # Start (red) and end (blue) rows saved above
    ax.axhline(min_row - 30, color='red')
    ax.axhline(min_row - 320, color='blue')

# Remove the axis of each subplot and draw them closer together
for ax in axs.flat:
    ax.axis("off")
plt.subplots_adjust(wspace=0, hspace=0)

# Save the grid image in the main folder
plt.savefig(main_folder / 'coordinates_grid.png')


# Data import with relative yball computation

In [None]:
import json
import h5py
import numpy as np
import pandas as pd
from pathlib import Path

# Set the path to the data folder
# NOTE(review): data_folder is not used below; the experiment folders come from `Folders`.
data_folder = Path('/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos')

# Build a dataframe that stores the ball y positions, the maze start/end rows
# and the arena and corridor numbers as metadata
Dataset = pd.DataFrame(columns=["Fly", "yball", "arena", "corridor", "start", "end"])

# One single-row frame per fly, concatenated once after the loop
# (calling pd.concat inside the loop copies the whole table every iteration).
fly_frames = []
Flynum = 0
for folder in Folders:
    print(f"Adding experiment {folder} to the dataset...")
    # Read the Metadata.json file. From the indexing below: metadata["Variable"]
    # lists variable names and metadata["ArenaN"] lists the values for arena N
    # in the same order.
    with open(folder / "Metadata.json", "r") as f:
        metadata = json.load(f)
    variables = metadata["Variable"]
    metadata_dict = {}
    for var in variables:
        var_index = variables.index(var)
        metadata_dict[var] = {
            f"Arena{arena}": metadata[f"Arena{arena}"][var_index]
            for arena in range(1, 10)
        }

    print(metadata_dict)

    for file in folder.glob("**/*.analysis.h5"):
        with h5py.File(file, "r") as f:
            locations = f["tracks"][:].T

        # y coordinate (index 1 of the third axis) for every node and track;
        # the [:, 0, 0] slice stored below is treated as the ball trace.
        yball: np.ndarray = locations[:, :, 1, :]
        if "Flipped" in folder.name:
            # Load first, then negate. The old code negated the *previous*
            # file's array here and never loaded the current file at all.
            yball[:, 0, 0] = -yball[:, 0, 0]

        # Arena and corridor come from the grandparent and parent folder names
        arena = file.parent.parent.name
        corridor = file.parent.name

        # Metadata values of this fly's arena ("arena1" -> "Arena1")
        arena_key = arena.capitalize()
        arena_metadata = {var: pd.Categorical([metadata_dict[var][arena_key]]) for var in metadata_dict}

        Flynum += 1

        # [start, end] rows written next to the video by the coordinate-detection step
        start, end = np.load(file.parent / 'coordinates.npy')

        data = {"Fly": pd.Categorical(["Fly" + str(Flynum)]),
                "yball": [list(yball[:, 0, 0])],
                "experiment": pd.Categorical([folder.name]),
                "arena": pd.Categorical([arena]),
                "corridor": pd.Categorical([corridor]),
                # Plain ints, not Categorical: start is used in arithmetic below,
                # and pandas Categorical columns do not support subtraction.
                "start": [int(start)],
                "end": [int(end)]}
        data.update(arena_metadata)
        fly_frames.append(pd.DataFrame(data))

Dataset = pd.concat([Dataset, *fly_frames], ignore_index=True)

# Explode yball column to have one row per timepoint
Dataset = Dataset.explode('yball')
Dataset['yball'] = Dataset['yball'].astype(float)

# yball_relative: absolute distance of the ball from the maze start row.
# NOTE(review): yball is negated for "Flipped" experiments while start is not,
# so there this equals |-y - start| — confirm that is intended.
Dataset['yball_relative'] = (Dataset['yball'] - Dataset['start'].astype(float)).abs()

# Per-fly frame counter, converted to seconds assuming 30 frames per second
Dataset["Frame"] = Dataset.groupby("Fly").cumcount()
Dataset["time"] = Dataset["Frame"] / 30

Dataset.reset_index(drop=True, inplace=True)

Dataset.head()

In [None]:
# Replace all occurrences of "Fed" with "fed" in the 'FeedingState' column
Dataset["FeedingState"] = Dataset["FeedingState"].replace("Fed", "fed")

# Register 'flipped' as a category before assigning it. Guarded on dtype
# (concatenating per-fly Categoricals with different categories yields an
# object column with no .cat accessor) and on presence (re-running the cell
# would otherwise raise ValueError for a duplicate category).
orientation = Dataset["Orientation"]
if isinstance(orientation.dtype, pd.CategoricalDtype) and "flipped" not in orientation.cat.categories:
    Dataset["Orientation"] = orientation.cat.add_categories(["flipped"])

# If there is 'Flipped' in the foldername, replace the corresponding 'Orientation' with 'flipped'
flipped_rows = Dataset["experiment"].astype(str).str.contains("Flipped")
Dataset.loc[flipped_rows, "Orientation"] = "flipped"


In [None]:
# Mean relative ball position per (Genotype, time), with the keys as columns.
GroupedDF_TNT = Dataset.groupby(['Genotype', 'time'], as_index=False)['yball_relative'].mean()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Number of distinct flies per Genotype, shown in the legend.
sample_size = Dataset.groupby('Genotype')['Fly'].nunique()

plt.figure(figsize=(12, 6))

# Mean yball_relative (distance from the maze start) over time, one line per Genotype.
sns.lineplot(data=GroupedDF_TNT, x='time', y='yball_relative', hue='Genotype', linewidth=1)

# Append the sample size to each legend label, matched by label text rather
# than by position so entries cannot receive another genotype's n.
legend = plt.legend()
for text in legend.texts:
    genotype = text.get_text()
    if genotype in sample_size.index:
        text.set_text(f'{genotype} (n = {sample_size[genotype]})')

plt.show()

In [None]:
# Frame-to-frame change of yball_relative within each fly (NaN on each fly's first frame).
Dataset['yball_relative_diff'] = Dataset.groupby('Fly')['yball_relative'].diff()

# Cumulative distance pushed (positive steps) and pulled (negative steps) per fly.
# Vectorized clip replaces the row-wise DataFrame.apply over every column of
# every row; NaN steps count as 0, exactly as the previous `> 0` / `< 0` tests did.
push_steps = Dataset['yball_relative_diff'].clip(lower=0).fillna(0)
pull_steps = (-Dataset['yball_relative_diff']).clip(lower=0).fillna(0)
Dataset['cumulative_push'] = push_steps.groupby(Dataset['Fly']).cumsum()
Dataset['cumulative_pull'] = pull_steps.groupby(Dataset['Fly']).cumsum()
