In [None]:
import sys
import os
import seaborn as sns
sys.path.append("../")
from src.data.session import load_all_data, to_dF_F
from scipy.spatial.transform import Rotation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
%load_ext autoreload
%autoreload 2

## Load Data

In [None]:
# The project root can be overridden through the RSCMINHD_PROJECT_DIR
# environment variable so the notebook is not tied to one machine; the
# original location is kept as the fallback.
project_dir = os.environ.get(
    "RSCMINHD_PROJECT_DIR",
    "/Users/anacarolinabotturabarros/PycharmProjects/RSCMINHD/",
)
animal = 'cohoHDC1_mB1'
session = "2023_03_10/12_51_18"

# load_all_data returns a mapping; only the entries used below are unpacked.
data_dict = load_all_data(animal, session, project_dir)
img_data = data_dict['img_data']
head_data = data_dict['head_data']
img_time_stamps = data_dict['img_time_stamps']

### Calculate dF/F and normalised calcium traces

We still need to figure out which of these is the best measure to use.

In [None]:
# Build two per-unit traces: dF/F (project helper) and min-max scaled C.
group_by_unit = img_data.groupby('unit_id')
df_units, normalised_c = [], []
for unit_id, unit_rows in group_by_unit:
    unit_dff = to_dF_F(unit_rows)
    unit_dff.name = unit_id
    df_units.append(unit_dff)

    # Rescale this unit's C trace with its own extrema (a constant trace
    # makes the denominator zero).
    c_trace = unit_rows['C']
    c_low, c_high = c_trace.min(), c_trace.max()
    normalised_c.append((c_trace - c_low) / (c_high - c_low))

# pd.concat(df_units, ignore_index=True, axis=1)
# The per-unit pieces are stacked back together and written as new columns
# (presumably to_dF_F keeps the row index of its input - TODO confirm).
img_data['df/f'] = pd.concat(df_units)
img_data['norm_C'] = pd.concat(normalised_c)

## Plot a unit

Plotting some units might give some idea of what we have

In [None]:
# TODO: Plot all units
unit = 2

# `unit` is a position in the list of unique unit ids (order of first
# appearance); the same rows feed all three figures below.
selected_unit_id = img_data['unit_id'].unique()[unit]
unit_data = img_data.loc[img_data['unit_id'] == selected_unit_id]

# raw C trace
plt.figure(figsize=(10, 2))
sns.lineplot(data=unit_data, x='frame', y='C')

# dF/F derived from C
plt.figure(figsize=(10, 2))
sns.lineplot(data=unit_data, x='frame', y='df/f')

# C after min-max normalisation
plt.figure(figsize=(10, 2))
sns.lineplot(data=unit_data, x='frame', y='norm_C')

## Working with head orientation data

In [None]:
# Overlay three quaternion components (qx, qy, qz) against the time stamp on
# a single set of axes.
ax = sns.lineplot(data=head_data, x='Time Stamp (ms)', y='qx')
for component in ('qy', 'qz'):
    sns.lineplot(data=head_data, x='Time Stamp (ms)', y=component, ax=ax)
ax.set(ylabel='Direction')

Here we try to bin the different variables into a certain number of bins

In [None]:
bins = 11
# `bins` evenly spaced edges over [-1, 1] give bins - 1 intervals, shared by
# all four components. pd.cut is right-closed by default, so a value exactly
# equal to the lowest edge (-1) is left unbinned (NaN).
quat_edges = np.linspace(-1, 1, bins)
for component in ('qx', 'qy', 'qz', 'qw'):
    head_data['binned_' + component] = pd.cut(head_data[component], quat_edges)

Here we plot the count of occurrences per bin for each of the variables

In [None]:

# One count panel per quaternion component, sharing the y axis.
fig, axs = plt.subplots(1, 4, sharey=True, figsize=(20, 5))
# Name the components explicitly instead of slicing head_data.columns by
# position: the positional slice also contains the 'binned_*' columns and
# depends on the column order of the loaded file.
head_qua_vars = ['qx', 'qy', 'qz', 'qw']

for ax, var in zip(axs, head_qua_vars):
    binned = head_data['binned_' + var]
    sns.countplot(x=binned, ax=ax)
    # Label every panel with its own bin categories.
    ax.set_xticklabels(binned.cat.categories, rotation=45, ha='right')
    ax.set(ylabel=None)

axs[0].set_ylabel('Count')

### Transform quaternion data into Euler angles

There are two ways of calculating the Euler angles: a hand-written function (adapted to deal with long-format data) or scipy's `Rotation` class.

In [None]:
import math
 
def euler_from_quaternion(x, y, z, w):
    """
    Convert quaternion components into Euler angles (roll, pitch, yaw).

    roll is rotation around x in radians (counterclockwise)
    pitch is rotation around y in radians (counterclockwise)
    yaw is rotation around z in radians (counterclockwise)

    The computation is vectorised with NumPy, so the components may be
    scalars or equally shaped sequences (lists, arrays, pandas Series).
    Sequence inputs are combined by position.

    :param x: x component(s) of the quaternion
    :param y: y component(s) of the quaternion
    :param z: z component(s) of the quaternion
    :param w: scalar (real) component(s) of the quaternion
    :return: ``(roll_x, pitch_y, yaw_z)`` in radians; three lists for
        sequence input, three floats when every component is a scalar
    :rtype: tuple
    """
    scalar_input = all(np.ndim(v) == 0 for v in (x, y, z, w))
    x, y, z, w = (np.asarray(v, dtype=float) for v in (x, y, z, w))

    roll_x = np.arctan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x * x + y * y))

    # Clamp to [-1, 1] so rounding error cannot push arcsin out of its domain.
    pitch_y = np.arcsin(np.clip(2.0 * (w * y - z * x), -1.0, 1.0))

    yaw_z = np.arctan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))

    if scalar_input:
        return float(roll_x), float(pitch_y), float(yaw_z)
    return roll_x.tolist(), pitch_y.tolist(), yaw_z.tolist()  # in radians

In [None]:
# Hand-written conversion (radians); print the first sample's three angles.
x, y, z = euler_from_quaternion(
    *(head_data[component] for component in ('qx', 'qy', 'qz', 'qw'))
)
print(x[0], y[0], z[0])

In [None]:
# Same conversion through scipy. `quat` holds one row per component, so it is
# transposed to one quaternion per row (x, y, z, w order) for from_quat.
quat = np.array([head_data[component] for component in ('qx', 'qy', 'qz', 'qw')])
quat_rows = quat.T
R = Rotation.from_quat(quat_rows)
print(quat_rows[0])
print(R.as_quat()[0])
# The angles kept for the rest of the notebook are in degrees ...
euler_data = R.as_euler('xyz', degrees=True)
# ... while this check prints the first sample in radians.
print(R.as_euler('xyz')[0])

Combine the Euler angles with the original head orientation data and save the result

In [None]:
# Attach the three Euler-angle columns of euler_data. Assigning the NumPy
# columns is positional, so it does not rely on head_data having a fresh
# 0..N-1 index (pd.concat(axis=1) aligns on labels), and re-running the cell
# overwrites the columns instead of appending duplicates.
head_data = head_data.assign(
    roll_x=euler_data[:, 0],
    pitch_y=euler_data[:, 1],
    yaw_z=euler_data[:, 2],
)

In [None]:
# bin euler variables
bins = 61
# (column, limit): each angle is cut with `bins` evenly spaced edges over
# [-limit, limit], i.e. bins - 1 intervals. pd.cut is right-closed by default,
# so a value exactly equal to -limit is left unbinned (NaN).
for angle, limit in (('roll_x', 180), ('pitch_y', 90), ('yaw_z', 180)):
    head_data['binned_' + angle] = pd.cut(
        head_data[angle], np.linspace(-limit, limit, bins)
    )

In [None]:
# Keep the row index as an explicit 'frame' column before saving.
head_data['frame'] = head_data.index
# Create the output folder on first use so to_csv does not fail on a missing
# directory.
processed_dir = os.path.join(project_dir, 'data/processed', animal, session)
os.makedirs(processed_dir, exist_ok=True)
head_data.to_csv(os.path.join(processed_dir, 'headOrientationProc.csv'))

## Finding intervals for the different bins

In [None]:
import itertools
from operator import itemgetter

def find_intervals(data):
    """Collapse runs of consecutive integers into closed intervals.

    :param data: iterable of integers (e.g. frame numbers) in ascending order
    :return: one ``pd.Interval(first, last, closed='both')`` per run of at
        least two consecutive values; isolated values are skipped
    :rtype: list
    """
    ranges = []
    # Within a run of consecutive values, position minus value is constant,
    # so that difference serves as the grouping key.
    runs = itertools.groupby(enumerate(data), key=lambda pair: pair[0] - pair[1])
    for _, run in runs:
        values = [value for _, value in run]
        if len(values) < 2:
            continue
        ranges.append(pd.Interval(values[0], values[-1], closed='both'))
    return ranges

def get_intervals(head_data, variable):
    """Find, for every bin of ``variable``, the frame intervals spent in it.

    :param head_data: head orientation data; needs a ``'frame'`` column and
        the ``variable`` column
    :type head_data: pd.DataFrame
    :param variable: name of the (binned) column in head_data to group by
    :type variable: str
    :return: mapping from each group label to the list of frame intervals
        returned by ``find_intervals`` for that group's frames
    :rtype: dict
    """
    frames_by_bin = head_data[[variable, 'frame']].groupby(variable)
    return {
        bin_label: find_intervals(rows['frame'])
        for bin_label, rows in frames_by_bin
    }
     

In [None]:
# Frame intervals per bin, first for one quaternion component (qx) and then
# for yaw_z, one of the Euler angles.
intervals, intervals_yaw = (
    get_intervals(head_data, binned_column)
    for binned_column in ('binned_qx', 'binned_yaw_z')
)


## Calculate the area under the curve for all intervals per bin and per unit 

In [None]:
def calculate_auc(unit_data, intervals, value_col='norm_C'):
    """Sum the area under each unit's trace over the frame intervals of every bin.

    For each bin, the trace is integrated with the trapezoidal rule (x = frame
    number) inside every interval of that bin, and the areas are added up.

    :param unit_data: iterable of ``(unit_id, rows)`` pairs, e.g.
        ``img_data.groupby('unit_id')``; ``rows`` needs a ``'frame'`` column
        and the ``value_col`` column
    :type unit_data: pandas GroupBy or iterable of tuples
    :param intervals: mapping from bin label to a list of ``pd.Interval``
        frame intervals (both ends are treated as included)
    :type intervals: dict
    :param value_col: column holding the trace to integrate
    :type value_col: str
    :return: table with one row per unit and one column per bin
    :rtype: pd.DataFrame
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    # Read every unit once: this supports one-shot iterables and avoids
    # re-splitting the groupby and re-slicing the frame for every bin.
    units = [
        (name, unit['frame'].to_numpy(), unit[value_col].to_numpy())
        for name, unit in unit_data
    ]

    all_auc = pd.DataFrame()
    for key, bin_intervals in intervals.items():
        auc_dict = {}
        for name, frames, values in units:
            auc = 0.0
            for itv in bin_intervals:
                inside = (frames >= itv.left) & (frames <= itv.right)
                auc += trapezoid(values[inside], frames[inside])
            auc_dict[name] = auc
        all_auc[key] = pd.Series(auc_dict)

    return all_auc

def calculate_int_length(intervals):
    """Total the number of frames covered by the intervals of each bin.

    :param intervals: mapping from bin label to a list of intervals
    :type intervals: dict
    :return: mapping from bin label to the sum of ``right - left + 1`` over
        that bin's intervals
    :rtype: dict
    """
    def _frames_covered(bin_intervals):
        spans = pd.IntervalIndex(bin_intervals)
        # + 1 because both interval ends count as frames.
        return np.sum(spans.right - spans.left + 1)

    return {
        bin_label: _frames_covered(bin_intervals)
        for bin_label, bin_intervals in intervals.items()
    }

In [None]:
# Direction tuning on the qx bins: area under norm_C per unit and bin,
# normalised by the time spent in each bin.
# TODO: make this a function - get direction tunning
all_auc = calculate_auc(img_data.groupby('unit_id'), intervals)
interval_lengths = pd.Series(calculate_int_length(intervals), name='interval_lengths')
# Divide every bin column by the total number of frames spent in that bin
# (the Series is matched against the column labels).
dir_tuning = all_auc.div(interval_lengths, axis='columns')

In [None]:
# Same computation on the yaw bins. Yaw is used because it seems to be the
# angle that corresponds to head direction - this should still be checked
# properly (comparing with actual data?).
all_auc_yaw = calculate_auc(img_data.groupby('unit_id'), intervals_yaw)
interval_yaw_lengths = pd.Series(
    calculate_int_length(intervals_yaw), name='interval_lengths'
)
# Divide every bin column by the total number of frames spent in that bin.
dir_tuning_yaw = all_auc_yaw.div(interval_yaw_lengths, axis='columns')

In [None]:
# Display the qx tuning table (one row per unit, one column per bin).
dir_tuning

## Visualise direction tuning

In [None]:
# TODO: make this a function for plotting a single ROI, with options (radial, linear)
# Long format: one row per (unit, yaw bin). NOTE: this adds 'unit_id' to
# dir_tuning_yaw in place, as its last column.
dir_tuning_yaw['unit_id'] = dir_tuning_yaw.index
new_melt = pd.melt(dir_tuning_yaw, id_vars='unit_id')

unit = 90
# .copy() makes plot_unit an independent frame, so adding columns to it later
# does not raise SettingWithCopyWarning.
plot_unit = new_melt[new_melt.unit_id == new_melt.unit_id.unique()[unit]].copy()
bin_labels = plot_unit.variable.astype(str)
ax = sns.lineplot(plot_unit, x=bin_labels, y='value')
ax.set_xticklabels(bin_labels, rotation=45, ha='right')
# The x axis shows the yaw bins (dir_tuning_yaw), not qx.
ax.set(ylabel='Direction Tunning', xlabel='Direction yaw_z')

In [None]:
# Polar version of the single-unit tuning curve. The angle is the left edge of
# each bin converted from degrees to radians. assign() builds a new frame, so
# nothing is written into a slice of new_melt (no SettingWithCopyWarning).
plot_unit = plot_unit.assign(
    degree=pd.IntervalIndex(plot_unit['variable']).left * np.pi / 180
)
g = sns.FacetGrid(
    plot_unit,
    subplot_kws=dict(projection='polar'),
    height=4.5,
    sharex=False,
    sharey=False,
    despine=False,
)
g.map(sns.lineplot, "degree", "value")

In [None]:
# All units overlaid on one polar axis, one hue per unit. The angle is the
# left edge of each bin, converted from degrees to radians.
bin_left_edges = pd.IntervalIndex(new_melt['variable']).left
new_melt['degree'] = bin_left_edges * np.pi / 180
g = sns.FacetGrid(
    new_melt,
    hue='unit_id',
    subplot_kws=dict(projection='polar'),
    height=4.5,
    sharex=False,
    sharey=False,
    despine=False,
)
g.map(sns.lineplot, "degree", "value")

In [None]:
# TODO: make this a function - plot all ROIs
# One polar panel per unit, ten panels per row. The angle is the left edge of
# each bin, converted from degrees to radians.
bin_left_edges = pd.IntervalIndex(new_melt['variable']).left
new_melt['degree'] = bin_left_edges * np.pi / 180
g = sns.FacetGrid(
    new_melt,
    col='unit_id',
    col_wrap=10,
    hue='unit_id',
    subplot_kws=dict(projection='polar'),
    height=4.5,
    sharex=False,
    sharey=False,
    despine=False,
)
g.map(sns.lineplot, "degree", "value")


In [None]:
# Create the figure folder on first use so savefig does not fail on a missing
# directory.
figure_dir = os.path.join(project_dir, "reports/figures", animal, session)
os.makedirs(figure_dir, exist_ok=True)
g.savefig(os.path.join(figure_dir, "all_ROIs.png"))

In [None]:
# Bins x units table of yaw tuning. 'unit_id' is dropped by name: slicing off
# the last column by position would silently discard a direction bin whenever
# 'unit_id' has not been added (or is not last).
df = dir_tuning_yaw.drop(columns='unit_id', errors='ignore').T
# Preferred direction = bin with the largest tuning value for each unit.
pref_dir = df.idxmax()
order = pref_dir.sort_values().index
# Rows are units, sorted by preferred direction.
ax = sns.heatmap(df.T.loc[order.to_list(), :])
# Rotate the labels the heatmap chose itself, so the number of labels always
# matches the number of ticks and no other cell's variable is needed.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set(ylabel='Units', xlabel='Direction')

In [None]:
# Number of units preferring each yaw bin. 'unit_id' is dropped by name
# rather than by position, so no direction bin is lost if the column is
# missing or not last.
df = dir_tuning_yaw.drop(columns='unit_id', errors='ignore').T
pref_dir = df.idxmax()
# order=df.index keeps every bin on the x axis, including bins no unit prefers.
ax = sns.countplot(x=pref_dir.values, order=df.index)
t = ax.set_xticklabels(df.index, rotation=45, ha='right')

## Checking direction tuning on first and second halves

In [None]:
# Split the head-orientation rows at the midpoint. .loc slices by label and
# includes both ends, hence `half - 1` for the first half.
half = len(head_data) // 2
first_half_rows = head_data.loc[:half - 1]
second_half_rows = head_data.loc[half:]
intervals_yaw_1of2 = get_intervals(first_half_rows, 'binned_yaw_z')
intervals_yaw_2of2 = get_intervals(second_half_rows, 'binned_yaw_z')

In [None]:
# Yaw tuning restricted to the first half of the session (frames below
# `half`). Yaw is used because it seems to correspond to head direction -
# this should still be checked properly (comparing with actual data?).
first_half_img = img_data.loc[img_data['frame'] < half]
all_auc_yaw_1of2 = calculate_auc(first_half_img.groupby('unit_id'), intervals_yaw_1of2)
interval_yaw_lengths_1of2 = pd.Series(
    calculate_int_length(intervals_yaw_1of2), name='interval_lengths'
)
# Divide every bin column by the total number of frames spent in that bin.
dir_tuning_yaw_1of2 = all_auc_yaw_1of2.div(interval_yaw_lengths_1of2, axis='columns')

In [None]:
# Yaw tuning restricted to the second half of the session (frames from
# `half` on). Yaw is used because it seems to correspond to head direction -
# this should still be checked properly (comparing with actual data?).
second_half_img = img_data.loc[img_data['frame'] >= half]
all_auc_yaw_2of2 = calculate_auc(second_half_img.groupby('unit_id'), intervals_yaw_2of2)
interval_yaw_lengths_2of2 = pd.Series(
    calculate_int_length(intervals_yaw_2of2), name='interval_lengths'
)
# Divide every bin column by the total number of frames spent in that bin.
dir_tuning_yaw_2of2 = all_auc_yaw_2of2.div(interval_yaw_lengths_2of2, axis='columns')

In [None]:
# Long-format tuning table for each half, tagged with the half it came from.
# NOTE: 'unit_id' is added to both tuning tables in place.
dir_tuning_yaw_1of2['unit_id'] = dir_tuning_yaw_1of2.index
new_melt_1of2 = pd.melt(dir_tuning_yaw_1of2, id_vars='unit_id').assign(half=1)

dir_tuning_yaw_2of2['unit_id'] = dir_tuning_yaw_2of2.index
new_melt_2of2 = pd.melt(dir_tuning_yaw_2of2, id_vars='unit_id').assign(half=2)

# Stack both halves; the angle is the left edge of each bin in radians.
df2plot = pd.concat([new_melt_1of2, new_melt_2of2])
df2plot['degree'] = pd.IntervalIndex(df2plot['variable']).left * np.pi / 180

# One polar panel per unit with the two halves overlaid as separate hues.
g = sns.FacetGrid(
    df2plot,
    col='unit_id',
    col_wrap=10,
    hue='half',
    subplot_kws=dict(projection='polar'),
    height=4.5,
    sharex=False,
    sharey=False,
    despine=False,
)
g.map(sns.lineplot, "degree", "value")
g.add_legend(loc='lower right', fontsize=32)


In [None]:
# Create the figure folder on first use so savefig does not fail on a missing
# directory.
figure_dir = os.path.join(project_dir, "reports/figures", animal, session)
os.makedirs(figure_dir, exist_ok=True)
g.savefig(os.path.join(figure_dir, "all_ROIs_sessionHalves.png"))

## Checking representation of direction bins on the first and second halves

In [None]:
# Side-by-side table of the total interval length per yaw bin in each half,
# with the bin itself kept as a regular 'direction' column.
int_lengths = pd.concat(
    [interval_yaw_lengths_1of2, interval_yaw_lengths_2of2], axis=1
).set_axis(['first', 'second'], axis=1)
int_lengths['direction'] = int_lengths.index

In [None]:
# Long format: one row per (direction bin, half), with the total length in
# 'value' and the half in 'variable'.
int_lengths_flat = pd.melt(int_lengths, id_vars='direction')
sns.lineplot(
    int_lengths_flat,
    x=int_lengths_flat['direction'].astype(str),
    y=int_lengths_flat['value'],
    hue=int_lengths_flat['variable'],
)

In [None]:
# Bar plot of the wide-form table (presumably one bar per numeric column,
# i.e. 'first' and 'second' - TODO confirm how seaborn treats 'direction').
sns.barplot(int_lengths)