In [None]:
import matplotlib.pyplot as plt
import cv2


import scipy

from pathlib import Path

import numpy as np

import h5py
import math


import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

#import datashader as ds


import holoviews as hv
from holoviews import opts
from holoviews.plotting.util import process_cmap
from bokeh.palettes import Viridis256
from holoviews.operation.datashader import datashade, shade, dynspread
hv.extension('bokeh')


import platform

import sys
sys.path.insert(0, "..")
sys.path.insert(0, "../../..")


from pathlib import Path

import cv2

import json

from bokeh.models import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.palettes import Spectral11
from bokeh.io import output_notebook
import iqplot
import bokeh.io
bokeh.io.output_notebook()
from bokeh.plotting import figure, show
from bokeh.models import LinearColorMapper, ColorBar

from Utilities.Utils import *
from Utilities.Processing import *
from Utilities.Ballpushing_utils import *

# Data loading

In [None]:
# Get the DataFolder

if platform.system() == "Darwin":
    DataPath = Path("/Volumes/Ramdya-Lab/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos")
# Linux Datapath
if platform.system() == "Linux":
    DataPath = Path("/mnt/labserver/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos")

print(DataPath)

In [None]:
# Experiment folders whose *name* contains "tnt", "tracked" and "pm" (case-insensitive).
# - Matching on the folder name (not the full path) keeps the parent directories out of the test.
# - Plain files are skipped: the loading cell opens `<folder>/Metadata.json`.
# - Sorted, because iterdir() order is filesystem-dependent and the loading cell numbers the
#   flies ("Fly 1", "Fly 2", ...) in folder order.
Folders = sorted(
    folder
    for folder in DataPath.iterdir()
    if folder.is_dir() and all(tag in folder.name.lower() for tag in ("tnt", "tracked", "pm"))
)

Folders

In [None]:
# Output folder for the figures, on the same share as DataPath
# (…/DURRIEU_Matthias/Experimental_data/MultiMazeRecorder/Videos -> …/DURRIEU_Matthias/Pictures/RasterPlots),
# so it also resolves correctly on macOS instead of always pointing at the Linux mount.
SavePath = DataPath.parents[2] / "Pictures" / "RasterPlots"


In [None]:
import traceback

from Utilities.Ballpushing_utils import *


def _load_arena_metadata(folder, n_arenas=9):
    """Read ``folder / "Metadata.json"`` and reshape it per variable and arena.

    Args:
        folder: experiment folder containing ``Metadata.json``.
        n_arenas: number of ``ArenaN`` entries to read (``Arena1`` .. ``Arena{n_arenas}``).

    Returns:
        ``{variable: {"arena1": value, ...}}``. Variables are listed in
        ``metadata["Variable"]``; ``metadata["ArenaN"][i]`` is the value of the i-th
        variable for that arena. Arena keys are lower-cased so they can be looked up
        with the arena folder names used in the loop below.
    """
    with open(folder / "Metadata.json", "r") as f:
        metadata = json.load(f)
    return {
        var: {
            f"arena{arena}": metadata[f"Arena{arena}"][var_index]
            for arena in range(1, n_arenas + 1)
        }
        for var_index, var in enumerate(metadata["Variable"])
    }


def _find_tracking_files(directory):
    """Return ``(flypath, ballpath)`` for the ``*.analysis.h5`` tracking files in ``directory``.

    Returns ``None`` when either file is missing. Candidates are sorted so the choice
    is deterministic when several files match.
    """
    fly_files = sorted(directory.glob("*tracked_fly*.analysis.h5"))
    # "*tracked*" also matches the fly file, so exclude it explicitly; otherwise the fly
    # tracking could silently be used as the ball tracking.
    ball_files = sorted(
        path
        for path in directory.glob("*tracked*.analysis.h5")
        if "tracked_fly" not in path.name
    )
    if not fly_files or not ball_files:
        return None
    return fly_files[0], ball_files[0]


Dataset_list = []
Flycount = 0
failed_videos = []  # (video name, formatted traceback) for every video whose processing raised

for folder in Folders:
    metadata_dict = _load_arena_metadata(folder)
    # Sorted so that the "Fly N" numbering below is reproducible across runs
    files = sorted(folder.glob("**/*.mp4"))

    for file in files:
        # Arena and corridor come from the grandparent (arena) and parent (corridor) folder names
        arena = file.parent.parent.name
        corridor = file.parent.name

        Genotype = metadata_dict["Genotype"][arena]
        Date = metadata_dict["Date"][arena]
        Light = metadata_dict["Light"][arena]
        FeedingState = metadata_dict["FeedingState"][arena]
        Period = metadata_dict["Period"][arena]

        # Two values stored next to the video; `start` is the reference for yball_relative
        start, end = np.load(file.parent / 'coordinates.npy')

        tracking = _find_tracking_files(file.parent)
        if tracking is None:
            continue  # fly or ball tracking missing for this video
        flypath, ballpath = tracking

        vidname = f"{Genotype}_{Date}_Light_{Light}_{FeedingState}_{Period}_{arena}_{corridor}"

        try:
            # Extract interaction events and mark them in the DataFrame
            data = extract_interaction_events(ballpath, flypath, mark_in_df=True)
            data["start"] = start
            data["end"] = end
            data["Genotype"] = Genotype
            data["Date"] = Date
            data["arena"] = arena
            data["corridor"] = corridor
            Flycount += 1
            data["Fly"] = f'Fly {Flycount}'
            # Absolute offset of yball_smooth from `start`, gaps filled by linear interpolation
            data['yball_relative'] = (
                (data['yball_smooth'] - data['start']).abs().interpolate(method='linear')
            )
            Dataset_list.append(data)
        except Exception:
            # Keep processing the other videos, but record the failure instead of discarding it
            failed_videos.append((vidname, traceback.format_exc()))

if failed_videos:
    print(f"{len(failed_videos)} video(s) could not be processed (tracebacks in failed_videos):")
    for name, _ in failed_videos:
        print(f"  - {name}")

if not Dataset_list:
    raise RuntimeError("No video could be processed; see failed_videos for the tracebacks.")

# Concatenate all dataframes in the list into a single dataframe
Dataset = pd.concat(Dataset_list, ignore_index=True)

Dataset.head()

# Example with 1 fly

In [None]:
# Select data from the first fly. `.copy()` makes it an independent frame, so adding
# helper columns to it later does not write into a view of Dataset (SettingWithCopyWarning).
Dataset_solo = Dataset[Dataset['Fly'] == 'Fly 1'].copy()

Dataset_solo.head()

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))

# 1-D raster of one fly: every sample sits at y = 0 and is coloured by yball_relative
points = ax.scatter(
    Dataset_solo['time'],
    np.zeros_like(Dataset_solo['time']),
    c=Dataset_solo['yball_relative'],
    cmap='viridis',
)
fig.colorbar(points, ax=ax, label='yball_relative')
ax.set_xlabel('time')

In [None]:
# with bokeh: the same 1-D raster of one fly, coloured by yball_relative
color_mapper = LinearColorMapper(
    palette='Viridis256',
    low=Dataset_solo['yball_relative'].min(),
    high=Dataset_solo['yball_relative'].max(),
)

# All points are drawn at y = 0. The helper column goes into a new frame, so this
# plotting cell does not mutate Dataset_solo itself.
source = ColumnDataSource(Dataset_solo.assign(zeros=np.zeros_like(Dataset_solo['time'])))

p = figure(x_axis_label='time')
# `scatter` draws screen-sized markers; `circle()` without a radius is deprecated since Bokeh 3.4
p.scatter('time', 'zeros', color={'field': 'yball_relative', 'transform': color_mapper}, source=source)

color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12, location=(0, 0), title='yball_relative')
p.add_layout(color_bar, 'right')

# output_notebook() was already called in the imports cell
show(p)

In [None]:
# HoloViews version: points at (time, yball_relative), coloured by yball_relative
scatter = hv.Points(
    Dataset_solo,
    kdims=['time', 'yball_relative'],
    vdims=['yball_relative'],
).opts(color='yball_relative', cmap='viridis', colorbar=True)

hv.extension('bokeh')
scatter


In [None]:
# First three flies. `.copy()` gives an independent frame, because a FlyIndex column is
# added to it later (avoids writing into a view of Dataset / SettingWithCopyWarning).
Dataset_three = Dataset[Dataset['Fly'].isin(['Fly 1', 'Fly 2', 'Fly 3'])].copy()

Dataset_three.head()

In [None]:
from bokeh.models import FixedTicker

# Map each fly to a row index (0, 1, 2, ...) in order of first appearance
flies = Dataset_three['Fly'].unique()
fly_to_index = {fly: i for i, fly in enumerate(flies)}
# `assign` returns a new frame, so this is safe even if Dataset_three is a view of Dataset
Dataset_three = Dataset_three.assign(FlyIndex=Dataset_three['Fly'].map(fly_to_index))


In [None]:
# Show the fly names and their row indices, one per line
print(flies, fly_to_index, sep="\n")

In [None]:
# Colour scale spans the yball_relative values of the three plotted flies
# (previously `high` was taken from Dataset_solo by mistake).
color_mapper = LinearColorMapper(
    palette='Viridis256',
    low=Dataset_three['yball_relative'].min(),
    high=Dataset_three['yball_relative'].max(),
)

p = figure(x_axis_label='time')

# One row per fly (y = FlyIndex); each sample coloured by yball_relative
source = ColumnDataSource(Dataset_three)
p.scatter('time', 'FlyIndex', color={'field': 'yball_relative', 'transform': color_mapper}, source=source)

color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12, location=(0, 0), title='yball_relative')
p.add_layout(color_bar, 'right')

# One tick per fly, labelled with the fly name. `major_label_overrides` maps tick value -> label,
# i.e. the inverse of fly_to_index; fly_to_index itself is left untouched.
p.yaxis.ticker = FixedTicker(ticks=[int(i) for i in fly_to_index.values()])
p.yaxis.major_label_overrides = {int(i): str(fly) for fly, i in fly_to_index.items()}

show(p)


In [None]:
# Display the three-fly raster again (bokeh.plotting.show is bokeh.io.show)
show(p)

In [None]:
# Map each fly to a row index; `assign` builds a new frame instead of writing into a view of Dataset
flies = Dataset_three['Fly'].unique()
fly_to_index = {fly: i for i, fly in enumerate(flies)}
Dataset_three = Dataset_three.assign(FlyIndex=Dataset_three['Fly'].map(fly_to_index))

# Points at (time, FlyIndex), coloured by yball_relative
scatter = hv.Points(Dataset_three, kdims=['time', 'FlyIndex'], vdims=['yball_relative'])
scatter = scatter.opts(color='yball_relative', cmap='viridis', colorbar=True)

hv.extension('bokeh')
scatter


In [None]:
# 2D grid: one row per FlyIndex, one column per time point (pivot_table averages duplicates)
grid = Dataset_three.pivot_table(index='FlyIndex', columns='time', values='yball_relative')

# hv.Image accepts an (xs, ys, zs) tuple with zs of shape (len(ys), len(xs)), so no xarray
# conversion is needed (`xr` is not imported in the imports cell, which made this cell fail).
image = hv.Image(
    (grid.columns.values, grid.index.values, grid.values),
    kdims=['time', 'FlyIndex'],
    vdims=['yball_relative'],
)

image = image.opts(cmap='viridis', colorbar=True)

hv.extension('bokeh')
image


In [None]:
from bokeh.models import LinearColorMapper, BasicTicker, ColorBar

# Fly x time matrix; the pivot fixes the row/column order of the categorical axes
df_pivot = Dataset_three.pivot(index='Fly', columns='time', values='yball_relative')

# Categorical ranges need string factors
flies = [str(fly) for fly in df_pivot.index]
times = [str(time) for time in df_pivot.columns]

mapper = LinearColorMapper(palette='Viridis256', low=Dataset_three.yball_relative.min(), high=Dataset_three.yball_relative.max())

p = figure(title="yball_relative values (flies vs time)",
           x_range=times, y_range=flies,
           x_axis_location="above",
           tools="hover", toolbar_location='below')

# On categorical axes the glyph coordinates must be the same string factors as the ranges;
# the numeric 'time' column matches no factor, so nothing would be drawn.
source = ColumnDataSource(Dataset_three.assign(
    time_factor=Dataset_three['time'].map(str),
    fly_factor=Dataset_three['Fly'].map(str),
))

# One rectangle per (fly, time) cell
p.rect(x="time_factor", y="fly_factor", width=1, height=1,
       source=source,
       fill_color={'field': 'yball_relative', 'transform': mapper},
       line_color=None)

# Colour bar on the right. yball_relative is not a percentage (no "%" format), and the default
# ticker avoids requesting one tick per unique value.
color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="5pt",
                     ticker=BasicTicker(),
                     label_standoff=6, border_line_color=None, location=(0, 0))
p.add_layout(color_bar, 'right')

show(p)


In [None]:
hv.extension('bokeh')

# Same Fly x time matrix, drawn as a HoloViews HeatMap with a reversed Viridis colormap
df_pivot = Dataset_three.pivot(index='Fly', columns='time', values='yball_relative')
cmap = process_cmap('Viridis256_r', provider='bokeh')

heatmap = hv.HeatMap((df_pivot.columns, df_pivot.index, df_pivot.values))
heatmap.opts(opts.HeatMap(cmap=cmap, colorbar=True, tools=['hover'], width=900, height=400))


In [None]:
Plotlist = []
# Reversed Viridis colormap shared by all genotype heatmaps
cmap = process_cmap('Viridis256_r', provider='bokeh')
Genotypes = Dataset['Genotype'].unique()

for genotype in Genotypes:
    # Fly x time matrix of yball_relative for the current genotype
    df_genotype = Dataset[Dataset['Genotype'] == genotype]
    df_pivot = df_genotype.pivot(index='Fly', columns='time', values='yball_relative')

    heatmap = hv.HeatMap((df_pivot.columns, df_pivot.index, df_pivot.values))
    heatmap.opts(opts.HeatMap(cmap=cmap, colorbar=True, tools=['hover'], width=900, height=900,
                              title=genotype, xlabel="Time(s)", ylabel="", fontscale=1.5))

    # Keep the plot: the layout cell assembles Plotlist into a grid and indexes into it,
    # which fails when nothing was appended.
    Plotlist.append(heatmap)
    # Also save each plot as a separate file
    hv.save(heatmap, str(SavePath / f"heatmap_{genotype}.png"), fmt='png')
    

In [None]:
# Arrange the per-genotype plots in a near-square grid: ceil(sqrt(n)) rows, and enough
# columns to hold all n plots. Guard first: an empty Plotlist would give rows == 0 and
# a ZeroDivisionError below.
if not Plotlist:
    raise ValueError("Plotlist is empty: fill it with the per-genotype plots before building the layout.")

rows = int(np.ceil(np.sqrt(len(Plotlist))))
cols = int(np.ceil(len(Plotlist) / rows))

layout = hv.Layout(Plotlist).cols(cols)

# Save the layout as a PNG file
hv.save(layout, str(SavePath / '231005_TNTBroad.png'), fmt='png')

In [None]:
# Display the first plot collected in Plotlist (raises IndexError if Plotlist is empty)
Plotlist[0]

In [None]:
Plotlist = []
# Reversed Viridis colormap (kept for consistency with the heatmap cells; the Image options below don't use it)
cmap = process_cmap('Viridis256_r', provider='bokeh')
Genotypes = Dataset['Genotype'].unique()

for genotype in Genotypes:
    # Fly x time matrix of yball_relative for the current genotype
    df_genotype = Dataset[Dataset['Genotype'] == genotype]
    df_pivot = df_genotype.pivot(index='Fly', columns='time', values='yball_relative')

    # Replace the "Fly N" labels by the integer N. The pattern is a raw string: '\d' in a
    # plain string literal is an invalid escape sequence (SyntaxWarning on recent Python).
    df_pivot.index = pd.Series(df_pivot.index).str.extract(r'(\d+)', expand=False).astype(int)
    # The pivot ordered the labels as strings ("Fly 10" before "Fly 2"); re-sort numerically
    # so the image rows are in ascending fly order.
    # NOTE(review): hv.Image expects evenly spaced coordinates; fly numbers within a genotype
    # may have gaps — confirm the image renders as intended, or switch to hv.HeatMap.
    df_pivot = df_pivot.sort_index()

    # Fly x time image of yball_relative
    image = hv.Image((df_pivot.columns, df_pivot.index, df_pivot.values))
    image.opts(opts.Image(tools=['hover'], width=300, height=300, title=genotype, xlabel="Time(s)", ylabel=""))

    Plotlist.append(image)

# Raster plots of events

In [None]:
# Build the event dataset from the selected folders with generate_dataset
# (presumably provided by the Utilities star imports — not defined in this notebook).
# NOTE(review): this rebinds Dataset, replacing the frame built in the data-loading section.
Dataset = generate_dataset(Folders)
Dataset.head()

In [None]:
# Mark interaction events in the dataset; the following cells read an 'Event' column from the result.
# NOTE(review): the loading cell calls extract_interaction_events(ballpath, flypath, ...), whereas
# here it receives a DataFrame — confirm which signature the current helper supports.
# NOTE(review): Dataset is rebound to its own transformed value, so running this cell twice
# applies the extraction twice; re-run the generate_dataset cell first.
Dataset = extract_interaction_events(Dataset, mark_in_df=True)

Dataset.head()

In [None]:
# Distinct values of the Event column, in order of first appearance
pd.unique(Dataset['Event'])

In [None]:
# Distinct genotypes, in order of first appearance
Genotypes = pd.unique(Dataset['Genotype'])
Genotypes

In [None]:
# Reversed Viridis colormap (the event heatmaps below use their own white/red colormap)
cmap = process_cmap('Viridis256_r', provider='bokeh')

Genotypes = Dataset['Genotype'].unique()

for genotype in Genotypes:
    # Fly x time matrix for this genotype: 1 where an Event value is present, 0 where it is missing
    event_matrix = (
        Dataset[Dataset['Genotype'] == genotype]
        .pivot(index='Fly', columns='time', values='Event')
        .notnull()
        .astype(int)
    )

    event_heatmap = hv.HeatMap((event_matrix.columns, event_matrix.index, event_matrix.values))
    event_heatmap.opts(
        opts.HeatMap(
            cmap=['white', 'red'],
            colorbar=False,
            tools=['hover'],
            width=900,
            height=900,
            title=f"{genotype} Events",
            xlabel="Time(s)",
            ylabel="",
            fontscale=1.5,
        )
    )

    # One PNG per genotype
    hv.save(event_heatmap, f"{SavePath}/event_heatmap_{genotype}.png", fmt='png')
