In [20]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import cebra
from sklearn.preprocessing import StandardScaler

# Pick the compute device: CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load the pooled contact dataset from its feather file into a DataFrame
contact_feather_path = "/mnt/upramdya_data/MD/MultiMazeRecorder/Datasets/Skeleton_TNT/250106_FinalEventCutoffData_norm/contact_data/250106_Pooled_contact_data.feather"
Contact_data = pd.read_feather(contact_feather_path)

In [6]:
# Preview the first rows of the loaded table as a quick sanity check
Contact_data.head()

Unnamed: 0,index,frame,time,object,x_Head,y_Head,x_Thorax,y_Thorax,x_Abdomen,y_Abdomen,...,experiment,Nickname,Brain region,Date,Genotype,Period,FeedingState,Orientation,Light,Crossing
0,0,535,18.448276,ball_1,67.948608,291.797211,67.731575,303.940399,67.61805,324.499023,...,231208_TNT_Fine_2_Videos_Tracked,23292 (ORCO-GAL4),Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1
1,1,536,18.482759,ball_1,68.463242,287.871277,68.001106,299.813629,67.53537,320.325226,...,231208_TNT_Fine_2_Videos_Tracked,23292 (ORCO-GAL4),Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1
2,2,537,18.517241,ball_1,71.965149,284.093201,71.674355,295.817657,68.011589,316.210266,...,231208_TNT_Fine_2_Videos_Tracked,23292 (ORCO-GAL4),Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1
3,3,538,18.551724,ball_1,71.875038,280.341827,71.749344,292.238403,68.141876,312.562561,...,231208_TNT_Fine_2_Videos_Tracked,23292 (ORCO-GAL4),Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1
4,4,539,18.586207,ball_1,68.433395,279.709137,68.357704,291.553284,68.070953,312.044952,...,231208_TNT_Fine_2_Videos_Tracked,23292 (ORCO-GAL4),Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1


# Data preparation 

Here we start by creating a key column that serves as a unique identifier for each contact in the dataset (one number per unique combination of `fly` and `contact_index`).

Then we remove anything that is not tracking data. We also remove the raw ball centre so that only the preprocessed tracking data is kept.

Finally, we save the data to an HDF5 (`.h5`) file and load it with CEBRA.

In [5]:
# List every column name of the dataset, one per line
for column_name in Contact_data.columns:
    print(column_name)
    

index
frame
time
object
x_Head
y_Head
x_Thorax
y_Thorax
x_Abdomen
y_Abdomen
x_Rfront
y_Rfront
x_Lfront
y_Lfront
x_Rmid
y_Rmid
x_Lmid
y_Lmid
x_Rhind
y_Rhind
x_Lhind
y_Lhind
x_Rwing
y_Rwing
x_Lwing
y_Lwing
contact_index
x_centre
y_centre
euclidean_distance
x_centre_preprocessed
y_centre_preprocessed
fly
flypath
experiment
Nickname
Brain region
Date
Genotype
Period
FeedingState
Orientation
Light
Crossing


In [10]:
# Give every unique (fly, contact_index) pair its own integer id and store it
# in a new 'trial' column
trial_keys = ['fly', 'contact_index']
trial_groups = Contact_data.groupby(trial_keys)
Contact_data['trial'] = trial_groups.ngroup()



In [13]:
# For each trial, shift the time column so that the trial's earliest time becomes 0.
# transform('min') broadcasts each trial's minimum back onto its rows using the
# built-in aggregation, so no Python lambda has to be called once per trial.
trial_start_time = Contact_data.groupby('trial')['time'].transform('min')
Contact_data['time_shifted'] = Contact_data['time'] - trial_start_time

In [15]:
# Show the column index to confirm which columns the frame now holds
Contact_data.columns

Index(['index', 'frame', 'time', 'object', 'x_Head', 'y_Head', 'x_Thorax',
       'y_Thorax', 'x_Abdomen', 'y_Abdomen', 'x_Rfront', 'y_Rfront',
       'x_Lfront', 'y_Lfront', 'x_Rmid', 'y_Rmid', 'x_Lmid', 'y_Lmid',
       'x_Rhind', 'y_Rhind', 'x_Lhind', 'y_Lhind', 'x_Rwing', 'y_Rwing',
       'x_Lwing', 'y_Lwing', 'contact_index', 'x_centre', 'y_centre',
       'euclidean_distance', 'x_centre_preprocessed', 'y_centre_preprocessed',
       'fly', 'flypath', 'experiment', 'Nickname', 'Brain region', 'Date',
       'Genotype', 'Period', 'FeedingState', 'Orientation', 'Light',
       'Crossing', 'trial', 'time_shifted'],
      dtype='object')

In [None]:
# Prepare a list of columns to keep.
# Each tracked keypoint contributes an x_<name> and a y_<name> column, in that order.
# NOTE(review): this list is not applied to Contact_data in the cells shown here —
# confirm whether the column filtering step is still intended.
kept_keypoints = ['Head', 'Thorax', 'Abdomen', 'Rfront', 'Lfront', 'Rmid',
                  'Lmid', 'Rhind', 'Lhind', 'Rwing', 'Lwing']
kept_keypoint_columns = [f'{axis}_{keypoint}' for keypoint in kept_keypoints for axis in ('x', 'y')]

columns_to_keep = [
    'frame',
    'time',
    *kept_keypoint_columns,
    'contact_index',
    'euclidean_distance',
    'x_centre_preprocessed',
    'y_centre_preprocessed',
    'fly',
    'flypath',
    'experiment',
    'Nickname',
    'Brain region',
    'Date',
    'Genotype',
    'Period',
    'FeedingState',
    'Orientation',
    'Light',
    'Crossing',
    'trial',
    'time_shifted',
]

In [14]:
# Preview the frame again to check the added 'trial' and 'time_shifted' columns
Contact_data.head()

Unnamed: 0,index,frame,time,object,x_Head,y_Head,x_Thorax,y_Thorax,x_Abdomen,y_Abdomen,...,Brain region,Date,Genotype,Period,FeedingState,Orientation,Light,Crossing,trial,time_shifted
0,0,535,18.448276,ball_1,67.948608,291.797211,67.731575,303.940399,67.61805,324.499023,...,Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1,22106,0.0
1,1,536,18.482759,ball_1,68.463242,287.871277,68.001106,299.813629,67.53537,320.325226,...,Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1,22106,0.034483
2,2,537,18.517241,ball_1,71.965149,284.093201,71.674355,295.817657,68.011589,316.210266,...,Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1,22106,0.068966
3,3,538,18.551724,ball_1,71.875038,280.341827,71.749344,292.238403,68.141876,312.562561,...,Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1,22106,0.103448
4,4,539,18.586207,ball_1,68.433395,279.709137,68.357704,291.553284,68.070953,312.044952,...,Olfaction,231208,TNTxZ1649,PM15,starved_noWater,std,on,1,22106,0.137931


In [25]:
# Features (coordinates): x/y of every tracked keypoint, followed by the
# preprocessed centre coordinates
feature_keypoints = ['Head', 'Thorax', 'Abdomen', 'Rfront', 'Lfront', 'Rmid',
                     'Lmid', 'Rhind', 'Lhind', 'Rwing', 'Lwing']
feature_columns = [f'{axis}_{keypoint}' for keypoint in feature_keypoints for axis in ('x', 'y')]
feature_columns += ['x_centre_preprocessed', 'y_centre_preprocessed']

# Time dimension (per-trial time, starting at 0 for each trial)
time_column = 'time_shifted'

# Metadata (you can choose relevant columns for auxiliary variables)
metadata_columns = [
    'frame',
    'fly',
    'flypath',
    'experiment',
    'Nickname',
    'Brain region',
    'Date',
    'Genotype',
    'Period',
    'FeedingState',
    'Orientation',
    'Light',
    'Crossing',
    'euclidean_distance',
]

# Pull the selections out of the frame: X and T as arrays, metadata as a DataFrame
X = Contact_data[feature_columns].values
T = Contact_data[time_column].values
metadata = Contact_data[metadata_columns]

In [26]:
# 4. Prepare data lists for MultiCEBRA: three independent, initially empty lists
# that will receive one entry per trial
X_list, T_list, auxiliary_list = [], [], []

In [17]:
# Standardise every feature column over the whole dataset.
# StandardScaler comes from the import cell at the top of the notebook.
# NOTE(review): X_normalized is not used by the later cells shown here, which
# scale each trial separately — confirm whether it is still needed.
scaler = StandardScaler()
X_normalized = scaler.fit_transform(X)

In [27]:
# 6. Iterate through trials and build the per-trial inputs (one entry per trial
# in each list).
# The lists are (re)created here so that re-running this cell does not append
# every trial a second time.
X_list = []
T_list = []
auxiliary_list = []

# Metadata columns that get one-hot encoded; the remaining metadata columns are
# passed through as numbers.
categorical_columns = ['fly', 'flypath', 'experiment', 'Nickname', 'Brain region', 'Date',
                       'Genotype', 'Period', 'FeedingState', 'Orientation', 'Light', 'Crossing']

# groupby splits the frame in a single pass instead of re-filtering the whole
# table once per trial; sort=False keeps the trials in order of first
# appearance, which is the order .unique() would give.
for trial, trial_data in Contact_data.groupby('trial', sort=False):
    # Features: standardised within the trial. A scaler local to this trial is
    # used so the dataset-wide `scaler` object is not refitted as a side effect.
    X_trial = StandardScaler().fit_transform(trial_data[feature_columns].values)
    X_list.append(X_trial)

    # Time
    T_trial = trial_data[time_column].values
    T_list.append(T_trial)

    # Auxiliary data (including time), forced to float so the stacked array is
    # numeric rather than object-dtype.
    # NOTE(review): the dummies are computed per trial; if a categorical column
    # holds a single value within a trial its dummy column is constant —
    # confirm whether a dataset-wide encoding was intended.
    auxiliary_frame = pd.get_dummies(trial_data[metadata_columns], columns=categorical_columns, dtype=float)
    auxiliary_trial = np.column_stack((T_trial, auxiliary_frame.to_numpy(dtype=float)))
    auxiliary_list.append(auxiliary_trial)

In [29]:
# Set up CEBRA model
# The hyper-parameters are gathered in one dict so they can be inspected or
# tweaked in a single place before the estimator is built.
cebra_settings = dict(
    model_architecture='offset10-model',
    batch_size=512,
    learning_rate=3e-4,
    temperature=1,
    output_dimension=3,  # You can adjust this
    max_iterations=10000,
    device=device,  # 'cuda' or 'cpu', chosen in the first cell
    conditional='time_delta',
)
model = cebra.CEBRA(**cebra_settings)

In [None]:
# Train on all trials at once: one feature array and one label array per trial.
# CEBRA.fit takes the labels as positional arguments after X; it has no
# `auxiliary` keyword, so the label list is passed positionally.
model.fit(X_list, auxiliary_list)

In [None]:
# A model fitted on a list of sessions has to be told which session an array
# belongs to, so the trials are embedded one by one;
# embeddings[i] is the embedding of X_list[i].
embeddings = [
    model.transform(X_trial, session_id=session_id)
    for session_id, X_trial in enumerate(X_list)
]

In [6]:
# Save it as h5: write the frame under the key "df" in table format,
# overwriting any existing file at that path
contact_h5_path = "/mnt/upramdya_data/MD/MultiMazeRecorder/Datasets/Skeleton_TNT/250106_FinalEventCutoffData_norm/contact_data/250106_Pooled_contact_data.h5"
Contact_data.to_hdf(
    contact_h5_path,
    key="df",
    mode="w",
    format="table",
)

In [None]:

# Read the saved h5 file back through CEBRA's own loader.
# NOTE(review): the saved table also holds non-numeric columns — confirm that
# cebra.load_data returns what is expected for this file.
cebra_h5_path = "/mnt/upramdya_data/MD/MultiMazeRecorder/Datasets/Skeleton_TNT/250106_FinalEventCutoffData_norm/contact_data/250106_Pooled_contact_data.h5"
cebra_df = cebra.load_data(cebra_h5_path)