# Ripple-associated CG spectral analysis
Adapted from the `spectral_connectivity` introductory tutorial: https://github.com/Eden-Kramer-Lab/spectral_connectivity/blob/master/examples/Intro_tutorial.ipynb

<br>

### Imports

In [2]:
import os
import re
import glob
import pandas as pd
import numpy as np
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
from spectral_connectivity import Multitaper, Connectivity

warnings.filterwarnings('ignore')

<br>

### Define functions

In [32]:
def split_by_phase(df):
    '''
    Split a dataframe into one subset per task phase.

    df, Dataframe with a 'phase' column.

    Returns a 4-tuple (sample_df, delay_df, test_df, iti_df) holding the rows
    whose phase is 'Sample', 'Delay', 'Test' and 'ITI' respectively, each with
    the 'phase' column removed. Rows with any other phase value are discarded.
    '''
    phase_order = ('Sample', 'Delay', 'Test', 'ITI')
    return tuple(
        df[df['phase'] == phase_name].drop(columns=['phase'])
        for phase_name in phase_order
    )





def prepare_for_multitaper(df, n_tts):
    '''
    Rearrange long-format tetrode data into the 3-D array expected by Multitaper.

    df, Dataframe with one row per (relative_timestamp, ripple_nr) pair, the
        columns 'start_time', 'end_time' and 'timestamp' (dropped here), and
        one signal column per tetrode named 'TT1' ... 'TT<n_tts>'.
    n_tts, int - Number of tetrodes (may vary with dataset)

    Returns a numpy array of shape (n_time_samples, n_ripples, n_tts), i.e.
    (time, trials, signals). The first axis follows the sorted
    relative_timestamp values, the second the sorted ripple_nr values and the
    last TT1 ... TT<n_tts>. Missing (timestamp, ripple) combinations are NaN.
    The input dataframe is not modified.
    '''

    # Normalise relative timestamp to correct small jitters
    # (disabled; if timestamps jitter between ripples, pivot produces extra
    # rows padded with NaN)
    #df['relative_timestamp']=df.relative_timestamp.round(8)

    # Drop unnecessary columns
    to_drop = ['start_time', 'end_time', 'timestamp']
    df = df.drop(to_drop, axis=1)

    # Create TT list
    tts = ['TT{}'.format(x) for x in range(1, n_tts + 1)]

    # Pivot to wide format. With several value columns pandas orders the
    # resulting MultiIndex columns as (tetrode, ripple_nr), with the tetrode
    # as the OUTER level: TT1/r0, TT1/r1, ..., TT2/r0, ...
    # A flat reshape to (time, ripple, tetrode) would therefore interleave
    # tetrodes and ripples, so each tetrode's (time, ripple) block is selected
    # explicitly and stacked on the last axis.
    wide = df.pivot(
        index="relative_timestamp",
        columns="ripple_nr",
        values=tts,
    )
    reshaped = np.stack([wide[tt].to_numpy() for tt in tts], axis=-1)

    return reshaped


def process_multitaper(df, fs, window, step, time_halfbandwidth_product=None):
    '''
    Create a spectral_connectivity Multitaper object.

    df : array passed straight to Multitaper as the time series (in this
        notebook, the output of prepare_for_multitaper).
    fs : sampling frequency, forwarded as `sampling_frequency`.
    window : forwarded as `time_window_duration` (seconds).
    step : forwarded as `time_window_step` (seconds).
    time_halfbandwidth_product : optional. When None (the default) the argument
        is not passed at all, so Multitaper's own default applies.
    '''
    # Only forward the optional argument when the caller set it, so the
    # default call is identical to the original one.
    optional_kwargs = {}
    if time_halfbandwidth_product is not None:
        optional_kwargs['time_halfbandwidth_product'] = time_halfbandwidth_product

    multitaper = Multitaper(
        df,
        sampling_frequency=fs,
        time_window_duration=window,
        time_window_step=step,
        **optional_kwargs,
    )

    return multitaper


def create_connectivity(m, exp_type, nblocks):
    '''
    Build a Connectivity object from an existing Multitaper object.

    m : Multitaper instance (e.g. from process_multitaper).
    exp_type : forwarded as `expectation_type` (e.g. "trials_tapers").
    nblocks : forwarded as `blocks`.
    '''
    return Connectivity.from_multitaper(m, expectation_type=exp_type, blocks=nblocks)


def get_conn(df, n_tts, fs, window, step, exp_type, nblocks):
    '''
    Run the whole pipeline: reshape -> Multitaper -> Connectivity.

    Prints the shape of the reshaped (time, trials, signals) array as a
    progress check, then returns ONLY the Connectivity object (a single value,
    not a tuple).
    '''
    time_series = prepare_for_multitaper(df, n_tts)
    print(time_series.shape)

    return create_connectivity(
        process_multitaper(time_series, fs, window, step),
        exp_type,
        nblocks,
    )
    

<br>

### Open datasets

In [4]:
# Pre-processed inputs are read from a folder relative to the notebook.
main_path = 'PreProcessedData'
cg_path = os.path.join(main_path, 'cg_data.csv')
ripple_path = os.path.join(main_path, 'cg_analysis_ripple_library.csv')

cg_data = pd.read_csv(cg_path, index_col=False)
ripple_data = pd.read_csv(ripple_path, index_col=False)

In [5]:
# Collapse the two test sub-phases into a single 'Test' phase label.
test_subphases = ['Test (Pre-choice)', 'Test (Past-choice)']
cg_data.loc[cg_data['phase'].isin(test_subphases), 'phase'] = 'Test'

<br>

### Split dataset according to sampling rates 

In [6]:
# Session 20191113131818 is currently the only session recorded at 20 kHz.
# To list its ripple numbers: ripple_data.loc[ripple_data.session_code == 20191113131818, 'ripple_nr'].unique()

In [7]:
# Data with different sampling rates will be processed separately
# 20191113131818 corresponds to the only dataset collected at 20K and therefore with 2k samples/s after processing
cg_data_2k = cg_data[cg_data.ripple_nr.between(0,74)]
cg_data_3k = cg_data[~cg_data.ripple_nr.between(0,74)]

In [8]:
# Number of distinct ripple events per phase (2k data)
cg_data_2k.groupby('phase')['ripple_nr'].nunique()

phase
Delay     29
ITI        6
Sample    14
Test      26
Name: ripple_nr, dtype: int64

In [9]:
# Number of distinct ripple events per phase (3k data)
cg_data_3k.groupby('phase')['ripple_nr'].nunique()

phase
Delay     114
ITI        98
Sample    186
Test      155
Name: ripple_nr, dtype: int64

<br>

## Create connectivity object

#### 1. Prepare data to create the multitaper object:

"If we have three dimensions, dimension 1 is time, dimension 2 is trials, and dimensions 3 is signals. It is important to know note that dimension 2 now has a different meaning in that it represents trials and not signals now. Dimension 3 is now the signals dimension. We will show an example of this later."

time_series : array, shape (n_time_samples, n_trials, n_signals)


#### 2. Create multitaper object

Controls the duration of the segment of time the transformation is computed on (seconds)
w_duration = 0.1

Controls how far the time window slides between successive estimates (seconds).
Setting the step smaller than the time window duration makes consecutive windows overlap.
step = 0.02

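Controls the frequency resolution of the Fourier-transformed signal.
Setting this parameter defines the default number of tapers used in the transform
(number of tapers = 2 * time_halfbandwidth_product - 1).
TODO: study this parameter further. For a fixed product, the half-bandwidth (frequency resolution) is time_halfbandwidth_product / w_duration, so changing the window duration above also changes the effective frequency resolution.
Note: none of the `get_conn` calls in this notebook pass this value, so `Multitaper`'s default is used.
time_halfbandwidth_product = 10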


#### 3. Create connectivity object
The Connectivity class computes the frequency-domain connectivity measures from the Fourier coefficients.

In [10]:
# Split by phase
sample_2k, delay_2k, test_2k, iti_2k = split_by_phase(cg_data_2k)
sample_3k, delay_3k, test_3k, iti_3k = split_by_phase(cg_data_3k)

In [11]:
# Number of tetrodes (may vary with dataset)
n_tts = 14

# Averaging for the connectivity estimates: can be over time, trials, tapers
# or any combination of two or three. Default is by trials and tapers.
exp_type = "trials_tapers"

# get_conn returns a single Connectivity object, so assign it directly;
# unpacking it into three names fails on a fresh kernel.
iti_2k_conn = get_conn(
    iti_2k, n_tts,
    fs=2000, window=0.1, step=0.02, exp_type=exp_type, nblocks=10
)


(6529, 6, 14)


### Power calculation from conn object

In [12]:
# Check power shape (n time windows, frequencies, n_tetrodes)
iti_2k_power = iti_2k_conn.power()
iti_2k_power.shape

(159, 100, 14)

In [13]:
# ---- Experiment: check whether splitting the ripples (trials) into smaller batches avoids running out of memory ----

In [14]:
# Reshape the 3k ITI events to (n_time_samples, n_ripples, n_tetrodes) for the memory experiment.
# NOTE: despite its name, reshaped_df is a NumPy array, not a DataFrame.
reshaped_df = prepare_for_multitaper(iti_3k, n_tts)

In [15]:
# (n_time_samples, n_ripples, n_tetrodes)
reshaped_df.shape

(33733, 98, 14)

In [24]:
#X1, X2 = np.split(reshaped_df, [80], axis=1)
#X1.shape, X2.shape

In [1]:
# Requires the function-definition and parameter cells above to have been run first.
# NOTE(review): step=0.05 here vs step=0.02 for the 2k ITI data, so the two power
# estimates have different time-window spacing - confirm this is intended before
# combining 2k and 3k results.
iti_3k_conn = get_conn(iti_3k, n_tts, fs=3000, window=0.1, step=0.05, exp_type=exp_type, nblocks=50)

NameError: name 'get_conn' is not defined

In [None]:
# Power shape: (n time windows, frequencies, n_tetrodes)
iti_3k_conn.power().shape

### Power spectral density plot - Heatmap

In [None]:
# PSD plot - averaged across tetrodes (14 tetrodes -> 1 response). PENDING: do this after checking variability.
# 1 plot per SWR location


In [None]:
# Transform data into something more readable

# Combine datasets

# Combine 3ks
# Combine 2ks with 3ks
# Output 4 datasets

# Average across tts


# Average across events



# Create variability metric



# Plot per location

