In [2]:
import math
import numpy as np
import pandas as pd
import pickle
import os
import cv2
import seaborn as sns
import matplotlib.pyplot as plt

from scipy import signal, fftpack, stats

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [42]:
# Retrieve folders for transgenic strains

transgenics = os.listdir('CoBar-Dataset')
transgenics.remove('PR') # Control

In [43]:
def load_data(xp):
    '''
    Load data from an experiment (xp)
    
    Returns
    ------
    gen_dict
        General info on data (on/off periods, collisions...)
    
    data
        Raw data
    
    metadata
        Raw metadata
    '''
    
    # Load gendict
    genDict = np.load(f'CoBar-Dataset/{xp}/U3_f/genotype_dict.npy', allow_pickle=True).item()
    
    # Load data
    with open(f'CoBar-Dataset/{xp}/U3_f/{xp}_U3_f_trackingData.pkl', 'rb') as f:
        data = pickle.load(f)
    
    print(f'{xp} - Data dimension: {data.shape}')
    
    # Extract metadata
    metadata = np.array([list(item) for item in data.index.values])
    print(f'{xp} - Metadata dimension: {metadata.shape}')
    
    return genDict, data, metadata

In [44]:
def get_data_per_fly_per_xp():  
    '''
    Extract data for each fly in each experiment
    
    N: number of flies
    L: number of frames
    
    Returns
    ------
    dict_pretarsi_data
        A dictionary containing xy data for each pretarsus for each fly, size Lx12 for each fly
        key = experiment + fly index (according to tracking video)
    
    dict_metadata
        Corresponding metadata, size Lx6 for each fly
        key = experiment + fly index (according to tracking video)
    
    dict_pos_data
        Centroid xy data, size Lx2 for each fly
        key = experiment + fly index (according to tracking video)
    
    n_trial_data
        Get number of flies in each experiment
        key = experiment
    '''
    
    dict_pretarsi_data = {}
    dict_metadata = {}
    dict_pos_data = {}
    n_trial_data = {}
    
    stim_col = 1 # Metadata column with stimulation info ('on'/'off')
    xp_col = 3   # Metadata column with experiment info (time of experiment)
    fly_col = 4  # Metadata column with fly info (fly 0, 1 or 2) 
    
    for strain in transgenics:
        # Load data for given strain
        genDict, data, metadata = load_data(strain)
        
        # Extract pretarsi data
        pretarsi = ["LFclaw", "LHclaw", "LMclaw", "RFclaw", "RHclaw", "RMclaw"]
        pos = ["posx", "posy"]
        orientation = ["orientation"]
        pretarsi_data = data[pretarsi]   
        pos_data = data["center"][pos + orientation]
    
        # Gather all possible experiments and maximum number of flies
        xps = np.unique(metadata[:,xp_col])
        flies = np.unique(metadata[:,fly_col])
        
        n_trials = len(xps)*len(flies)

        for xp in xps:
            # Extract rows corresponding to current experiment
            xp_idx = np.where(metadata[:,xp_col] == xp)[0]
            
            # Extract corresponding metadata, pretarsi data and positional data
            xp_metadata = metadata[xp_idx]
            xp_pretarsi_data = pretarsi_data.iloc[xp_idx]
            xp_pos_data = pos_data.iloc[xp_idx]

            for fly in flies:
                # Extract rows corresponding to current fly
                fly_idx = np.where(xp_metadata[:,fly_col] == fly)[0]
                
                # Extract corresponding metadata and data for current fly
                xp_fly_metadata = xp_metadata[fly_idx]
                xp_fly_pretarsi_data = xp_pretarsi_data.iloc[fly_idx]
                xp_fly_pos_data = xp_pos_data.iloc[fly_idx]
                
                xp_fly_metadata = np.append(xp_fly_metadata, np.array(range(len(xp_fly_metadata))).reshape(-1,1), axis=1)
                
                
                # Sort timestamps, and re-arrange fly data for time stamps
                # Order = 'off0', 'on0', 'off1', 'on1', 'off2', 'on2', 'off3'
                if not(xp_fly_pretarsi_data.empty):
                    
                    dict_metadata[xp+fly] = np.array(sorted(xp_fly_metadata, key=lambda x: (int(x[stim_col][-1]), x[stim_col])))
                    idx_sort = np.array(list(map(int, dict_metadata[xp+fly][:,-1])))
                    dict_pretarsi_data[xp + fly] = np.array(xp_fly_pretarsi_data)[idx_sort,:] * 38/832                    
                    dict_pos_data[xp + fly] = np.array(xp_fly_pos_data)[idx_sort,:]
                    
                    # Only convert x, y pos positions to mm
                    dict_pos_data[xp + fly][:,:2] *= 38/832
                else:
                    n_trials -= 1
            
        n_trial_data[strain] = n_trials
        
        print(f'{strain}: {n_trials} trials')
    
    return dict_pretarsi_data, dict_metadata, dict_pos_data, n_trial_data

In [45]:
raw_pretarsi_data, raw_metadata, _, n_trial_data = get_data_per_fly_per_xp()

MDN - Data dimension: (28770, 70)
MDN - Metadata dimension: (28770, 6)
MDN: 12 trials
SS01049 - Data dimension: (26385, 70)
SS01049 - Metadata dimension: (26385, 6)
SS01049: 11 trials
SS01054 - Data dimension: (31162, 70)
SS01054 - Metadata dimension: (31162, 6)
SS01054: 13 trials
SS01540 - Data dimension: (26361, 70)
SS01540 - Metadata dimension: (26361, 6)
SS01540: 11 trials
SS02111 - Data dimension: (26396, 70)
SS02111 - Metadata dimension: (26396, 6)
SS02111: 11 trials
SS02279 - Data dimension: (28776, 70)
SS02279 - Metadata dimension: (28776, 6)
SS02279: 12 trials
SS02377 - Data dimension: (28764, 70)
SS02377 - Metadata dimension: (28764, 6)
SS02377: 12 trials
SS02608 - Data dimension: (28740, 70)
SS02608 - Metadata dimension: (28740, 6)
SS02608: 12 trials
SS02617 - Data dimension: (26355, 70)
SS02617 - Metadata dimension: (26355, 6)
SS02617: 11 trials


In [46]:
# Some keys to test
mdn = '200206_1539540'
turn = '200212_1620432'

In [47]:
def findOnPeriods(key, raw_metadata, display=False):
    '''
    Find On-Stimulation periods for a given experiment
    
    Parameters
    ------
    key
        Experiment key
    
    raw_metadata
        Raw metadata for each fly in each experiment
    
    display
        If true, displays on-stimulation intervals
    
    Returns
    ------
    numpy.ndarray
        All frames where the stimulation was ON
    '''
    
    metadata = raw_metadata[key]
    on_periods = ['on0', 'on1', 'on2']

    on_intervals = []

    for p in on_periods:
        start_period = np.where(metadata[:,1] == p)[0][0]
        end_period = np.where(metadata[:,1] == p)[0][-1]
        on_intervals.extend(list(range(start_period, end_period)))
        if display:
            print(f'{p}: {[start_period, end_period]}')
    return np.array(on_intervals)

In [48]:
def findStimulationData(raw_pretarsi_data, raw_metadata):
    '''
    Extract data where the stimulation was ON
    
    raw_pretarsi_data
        Raw xy data for each pretarsus, size Lx12 for each fly
    
    raw_metadata
        Corresponding metadata, size Lx6 for each fly
    
    Returns
    ------
    stim_data
        XY stimulation data for each pretarsus, size lx12 for each fly, l < L
    '''
    stim_data = {}
    
    for key, data in raw_pretarsi_data.items():
        on_idxs = findOnPeriods(key, raw_metadata)
        
        stim_data[key] = data[on_idxs,:]
        
    return stim_data

In [49]:
def buildStimulationArray(stim_data, nCoords=2):
    '''
    Build array from dictionary data containg stimulation xy positions for each fly
    
    stim_data
        Dictionary of "on" xy data for each fly
    
    nCoords
        Parameter to adjust size of matrix
        nCoords = 2 if no wavelet transformation is applied
        nCoords = 40 else
    
    Returns
    stim_array
        Transformation of stim_data into a numpy array: size Nx(l*K)
        N: number of flies
        l: number of on-stimulation frames
        K: number of features
    '''
    
    nLegs = 6
    min_nFrames = stim_data[min(stim_data, key=lambda x: stim_data[x].shape[0])].shape[0]
    
    stim_array = np.zeros((len(stim_data), min_nFrames*nLegs*nCoords))
    
    for i, (key, data) in enumerate(stim_data.items()):
        stim_array[i,:] = data[:min_nFrames,:].flatten()
    
    return stim_array

In [50]:
stim_data = findStimulationData(raw_pretarsi_data, raw_metadata)

In [51]:
def findWavelets(stim_data):
    '''
    Apply wavelet transformation
    
    We used 20 frequencies from 1 to 40 Hz, 40 being the Nyquist frequency (Frame rate: 80 fps)
    
    Parameters
    ------
    stim_data
        A dictionary with raw data for each fly
        key = experiment + fly index (according to tracking video)
    
    Returns
    ------
    wavelet_data
        A dictionary with wavelet transform of data for each fly
    '''
    
    wavelet_data = {}

    n_scales = 20
    fps = 80
    f_min = 1
    f_max = fps/2
    #logvar_thresh = -6.

    for (key, data) in stim_data.items():
        wavelet_data[key] = np.zeros((list(data.shape) + [n_scales]))

        # Wavelet transformation
        for i in range(data.shape[1]):
            sig = abs(signal.cwt(data[:,i], signal.morlet2, np.geomspace(f_min, f_max, n_scales)).T)

            wavelet_data[key][:,i,:] = sig

        wavelet_data[key] = wavelet_data[key].reshape(wavelet_data[key].shape[0], wavelet_data[key].shape[1]*wavelet_data[key].shape[2])

        # Frame-normalization
        for t in range(wavelet_data[key].shape[0]):
            wavelet_data[key][t,:] = wavelet_data[key][t,:]/(wavelet_data[key][t,:].sum())
    
    return wavelet_data

In [52]:
wavelet_data = findWavelets(stim_data)

In [53]:
# Build array from wavelet data
stim_array = buildStimulationArray(wavelet_data, nCoords=40)

In [54]:
# Extract classes for each fly (1 = MDN, 2 = SS01049...)

classes = [] # List of colours attributed to each strain for later plotting
strains = [] # Strain corresponding to each fly
unique_classes = [] # Unique version of 'classes'

for i, key in enumerate(n_trial_data.keys()):
    #print(i)
    #print(n_trial_data[key])
    #print(n_trial_data[key] * [i])
    unique_classes.append(sns.color_palette()[i])
    for j in range(n_trial_data[key]):
        classes.append(sns.color_palette()[i])
        strains.append(key)
unique_classes = np.array(unique_classes)
classes = np.array(classes)
strains = np.array(strains)

In [55]:
# Embed stimulation array in a 2D spaceusing TSNE
print(stim_array.shape)
embedded_array = TSNE(n_components=2).fit_transform(stim_array)

(105, 168960)


In [1]:
g = sns.jointplot(embedded_array[:,0], embedded_array[:,1], kind="kde")

#Clear the axes containing the scatter plot
g.ax_joint.cla()

# set the current axis to be the joint plot's axis
plt.sca(g.ax_joint)

# plt.scatter takes a 'c' keyword for color
# you can also pass an array of floats and use the 'cmap' keyword to
# convert them into a colormap
sc = plt.scatter(embedded_array[:,0], embedded_array[:,1], c=classes, label=strains)
lp = lambda i: plt.plot([], color=unique_classes[i], mec="none",
                        label=list(n_trial_data.keys())[i], ls="", marker="o")[0]
handles = [lp(i) for i in np.arange(len(np.unique(strains)))]
plt.legend(handles=handles, ncol=2, prop={'size': 8})
plt.tight_layout()

NameError: name 'sns' is not defined