# Reading data

In [None]:
import mne
import numpy as np
import pandas as pd
import re
from brainvision import read_raw_brainvision as read_raw_brainvision_local
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

In [None]:
# Marker label -> integer event code, shared by every recording so that codes
# stay comparable across files. The code of a label is its 1-based position in
# the tuple below, so new labels must only ever be appended at the end.
unified_events_dict = {
    label: code
    for code, label in enumerate(
        (
            # 3-B-NOSTOP markers (codes 1-2)
            'Stimulus/ 3-B-NOSTOP-L',
            'Stimulus/ 3-B-NOSTOP-R',
            # 3-B-STOP<n> markers, n = 1..7 (codes 3-30)
            'Stimulus/ 3-B-STOP1-SE-L',
            'Stimulus/ 3-B-STOP1-SE-R',
            'Stimulus/ 3-B-STOP1-SS-L',
            'Stimulus/ 3-B-STOP1-SS-R',
            'Stimulus/ 3-B-STOP2-SE-L',
            'Stimulus/ 3-B-STOP2-SE-R',
            'Stimulus/ 3-B-STOP2-SS-L',
            'Stimulus/ 3-B-STOP2-SS-R',
            'Stimulus/ 3-B-STOP3-SE-L',
            'Stimulus/ 3-B-STOP3-SE-R',
            'Stimulus/ 3-B-STOP3-SS-L',
            'Stimulus/ 3-B-STOP3-SS-R',
            'Stimulus/ 3-B-STOP4-SE-L',
            'Stimulus/ 3-B-STOP4-SE-R',
            'Stimulus/ 3-B-STOP4-SS-L',
            'Stimulus/ 3-B-STOP4-SS-R',
            'Stimulus/ 3-B-STOP5-SE-L',
            'Stimulus/ 3-B-STOP5-SE-R',
            'Stimulus/ 3-B-STOP5-SS-L',
            'Stimulus/ 3-B-STOP5-SS-R',
            'Stimulus/ 3-B-STOP6-SE-L',
            'Stimulus/ 3-B-STOP6-SE-R',
            'Stimulus/ 3-B-STOP6-SS-L',
            'Stimulus/ 3-B-STOP6-SS-R',
            'Stimulus/ 3-B-STOP7-SE-L',
            'Stimulus/ 3-B-STOP7-SE-R',
            'Stimulus/ 3-B-STOP7-SS-L',
            'Stimulus/ 3-B-STOP7-SS-R',
            # 3-R-B-NOSTOP markers (codes 31-32)
            'Stimulus/ 3-R-B-NOSTOP-L',
            'Stimulus/ 3-R-B-NOSTOP-R',
            # 3-R-B-STOP<n> markers, SE only (codes 33-46)
            'Stimulus/ 3-R-B-STOP1-SE-L',
            'Stimulus/ 3-R-B-STOP1-SE-R',
            'Stimulus/ 3-R-B-STOP2-SE-L',
            'Stimulus/ 3-R-B-STOP2-SE-R',
            'Stimulus/ 3-R-B-STOP3-SE-L',
            'Stimulus/ 3-R-B-STOP3-SE-R',
            'Stimulus/ 3-R-B-STOP4-SE-L',
            'Stimulus/ 3-R-B-STOP4-SE-R',
            'Stimulus/ 3-R-B-STOP5-SE-L',
            'Stimulus/ 3-R-B-STOP5-SE-R',
            'Stimulus/ 3-R-B-STOP6-SE-L',
            'Stimulus/ 3-R-B-STOP6-SE-R',
            'Stimulus/ 3-R-B-STOP7-SE-L',
            'Stimulus/ 3-R-B-STOP7-SE-R',
            # 3-STOP<n> markers, n = 1..7 (codes 47-74)
            'Stimulus/ 3-STOP1-SE-L',
            'Stimulus/ 3-STOP1-SE-R',
            'Stimulus/ 3-STOP1-SS-L',
            'Stimulus/ 3-STOP1-SS-R',
            'Stimulus/ 3-STOP2-SE-L',
            'Stimulus/ 3-STOP2-SE-R',
            'Stimulus/ 3-STOP2-SS-L',
            'Stimulus/ 3-STOP2-SS-R',
            'Stimulus/ 3-STOP3-SE-L',
            'Stimulus/ 3-STOP3-SE-R',
            'Stimulus/ 3-STOP3-SS-L',
            'Stimulus/ 3-STOP3-SS-R',
            'Stimulus/ 3-STOP4-SE-L',
            'Stimulus/ 3-STOP4-SE-R',
            'Stimulus/ 3-STOP4-SS-L',
            'Stimulus/ 3-STOP4-SS-R',
            'Stimulus/ 3-STOP5-SE-L',
            'Stimulus/ 3-STOP5-SE-R',
            'Stimulus/ 3-STOP5-SS-L',
            'Stimulus/ 3-STOP5-SS-R',
            'Stimulus/ 3-STOP6-SE-L',
            'Stimulus/ 3-STOP6-SE-R',
            'Stimulus/ 3-STOP6-SS-L',
            'Stimulus/ 3-STOP6-SS-R',
            'Stimulus/ 3-STOP7-SE-L',
            'Stimulus/ 3-STOP7-SE-R',
            'Stimulus/ 3-STOP7-SS-L',
            'Stimulus/ 3-STOP7-SS-R',
        ),
        start=1,
    )
}

In [None]:
def remove_unnecessary_events(raw, events_names=('Time 0/', 'New Segment/')):
    """Return a copy of ``raw`` whose annotations omit the named event types.

    The annotations are converted to events, the unwanted event codes are
    dropped, and a fresh set of annotations is rebuilt from what remains.

    Parameters
    ----------
    raw
        Recording whose annotations should be filtered. It is not modified;
        the filtered annotations are attached to a copy.
    events_names
        Iterable of annotation descriptions to remove. Descriptions that do
        not occur in ``raw`` are ignored instead of raising ``KeyError``.

    Returns
    -------
    A copy of ``raw`` carrying only the remaining annotations.

    Notes
    -----
    The annotations are rebuilt from point events, so whatever durations the
    original annotations had are not carried over.
    """
    events, events_id = mne.events_from_annotations(raw)

    # A name that is absent from this recording simply has nothing to remove.
    codes_to_remove = [events_id[name] for name in events_names if name in events_id]

    # Boolean masking keeps the (n_events, 3) layout even when no event
    # survives; rebuilding the array from a list of rows would not.
    filtered_events = events[~np.isin(events[:, 2], codes_to_remove)]

    # annotations_from_events expects code -> description.
    new_event_desc = {
        code: str(name)
        for name, code in events_id.items()
        if code not in codes_to_remove
    }

    # Event samples produced by events_from_annotations are offset by
    # raw.first_samp, while annotations without an orig_time are measured from
    # the start of the data, so the offset has to be taken out again.
    new_annotations = mne.annotations_from_events(
        filtered_events,
        sfreq=raw.info['sfreq'],
        event_desc=new_event_desc,
        first_samp=raw.first_samp,
    )

    return raw.copy().set_annotations(new_annotations)

## Load data

In [None]:
# Folder and file stems (no extension) of the two recordings to analyse.
# NOTE(review): `dir` shadows the Python builtin of the same name; the name is
# kept because the loading cell reads it.
dir, filename, filename2 = (
    '../data/sst_old',
    'AD1406_SST14_Artif Rej 75',
    'AB2407_SST14_new_Artif Rej 75',
)

# Alternative dataset, kept for quick switching:
# dir = '../data/gng'
# filename = 'GNG_AA0303-64 el'

In [None]:
# Load both recordings with preload=True; each header path is built from the
# folder and file stem chosen in the configuration cell.
raw = read_raw_brainvision_local(f'{dir}/{filename}.vhdr', preload=True)
print(raw.get_data().shape)
raw2 = read_raw_brainvision_local(f'{dir}/{filename2}.vhdr', preload=True)
print(raw2.get_data().shape)

# Optional clean-up of segment/time-zero markers, currently disabled.
# raw = remove_unnecessary_events(raw)
# raw2 = remove_unnecessary_events(raw2)

# NOTE(review): per the MNE documentation, concatenate_raws() extends the first
# recording in place and returns it, so `raw` presumably refers to the
# concatenated data from here on; later cells that read `raw` depend on this.
raw_concatenated = mne.concatenate_raws([raw, raw2])
print(raw_concatenated.get_data().shape)
# Convert annotations to events using the fixed codes of unified_events_dict.
events, event_id = mne.events_from_annotations(raw_concatenated, unified_events_dict)

In [None]:
# Open the browser plot of the concatenated recording; `duration` is the
# length of the visible time window (seconds, per the MNE plot documentation).
fig = raw_concatenated.plot(duration=360)

In [None]:
# Translate raw marker labels into hierarchical event names
def map_event(event):
    """Map a raw marker label to a general, slash-separated event name.

    Markers carrying a STOP level keep that number as the last path component
    (for example ``'go/stop/3'``). Labels matching none of the known marker
    families yield ``'unknown'``.
    """
    # (regular expression, replacement); a replacement is either a fixed
    # string or a callable that builds the name from the match object.
    rules = (
        (r'^Stimulus/ 3-B-NOSTOP.*', 'go/nostop'),
        (r'^Stimulus/ 3-B-STOP(\d+).*', lambda m: f'go/stop/{m.group(1)}'),
        (r'^Stimulus/ 3-R-B-NOSTOP.*', 'response/correct'),
        (r'^Stimulus/ 3-R-B-STOP(\d+).*', lambda m: f'response/incorrect/{m.group(1)}'),
        (r'^Stimulus/ 3-STOP(\d+).*', lambda m: f'stop/{m.group(1)}'),
    )

    for pattern, replacement in rules:
        found = re.fullmatch(pattern, event)
        if found is None:
            continue
        if callable(replacement):
            return replacement(found)
        return replacement

    # No marker family matched
    return 'unknown'

# Reduce a general event name to its coarse category
def categorize_type(event_general):
    """Return ``'go'``, ``'response'`` or ``'stop'`` for a general event name.

    The category is found by substring search; ``'unknown'`` is returned when
    none of the three words occurs in ``event_general``.
    """
    # Order matters: names such as 'go/stop/3' contain both 'go' and 'stop'
    # and have to be classified as 'go'.
    for category in ('go', 'response', 'stop'):
        if category in event_general:
            return category
    return 'unknown'

# SSD value for each STOP level: 100 at level 1, rising by 50 per level up to
# 400 at level 7 (presumably milliseconds; confirm against the task setup).
ssd_mapping = {level: 100 + 50 * (level - 1) for level in range(1, 8)}

# Look up the SSD value of an incorrect-response event
def map_ssd(event_general):
    """Return the SSD value for a ``'response/incorrect/<level>'`` event name.

    Every other event name, and any level missing from ``ssd_mapping``, gives
    ``np.nan`` so that the column can be treated as having missing values.
    """
    found = re.search(r'response/incorrect/(\d+)', event_general)
    if found is None:
        return np.nan
    return ssd_mapping.get(int(found.group(1)), np.nan)

In [None]:
# One row per event, named after the three columns of the `events` array.
# NOTE(review): MNE documents the middle column of an events array as the
# previous trigger value rather than a duration; confirm before using it.
events_df = pd.DataFrame(events, columns=['latency', 'duration', 'id'])

# Reverse lookup: event code -> marker label
id_to_event = {code: label for label, code in event_id.items()}
events_df['event'] = events_df['id'].map(id_to_event)

# Derive every descriptive column from the general event name
general_names = events_df['event'].apply(map_event)
events_df['event_general'] = general_names
events_df['type'] = general_names.apply(categorize_type)
events_df['ssd'] = general_names.apply(map_ssd)
# 'correct' / 'incorrect' for response events, 'n-a' for everything else
events_df['response_type'] = general_names.str.extract(
    r'response/(correct|incorrect)', expand=False
).fillna('n-a')

# Centre the SSD on its mean over the events that have one (NaNs are skipped)
ssd_mean = np.nanmean(events_df['ssd'])
events_df['ssd_centered'] = events_df['ssd'].sub(ssd_mean)

In [None]:
# Display the assembled event table
events_df

In [None]:
# Save the event table next to the notebook. The file is named after the first
# recording only, although the table holds the events of both recordings.
events_csv_path = f'{filename}_events.csv'
events_df.to_csv(events_csv_path)

In [None]:
# Show only the incorrect-response events
events_df.loc[events_df['event_general'].str.contains('response/incorrect/')]

---
## Unfold

In [None]:
# Julia's package manager, exposed through juliacall
from juliacall import Pkg as jlPkg

# Work inside the Julia environment stored in the current folder
jlPkg.activate(".")

# List what is installed already (nothing on a fresh environment)
print(jlPkg.status())

# Registered packages are added by name; UnfoldDecode comes from its repository
for package_name in ("Unfold", "DataFrames"):
    jlPkg.add(package_name)
jlPkg.add(url="https://github.com/unfoldtoolbox/UnfoldDecode.jl")

In [None]:
from juliacall import Main as jl

# seval() evaluates a string of Julia code; use it to load each package
for using_statement in ("using DataFrames", "using Unfold", "using UnfoldDecode"):
    jl.seval(using_statement)

# Python-side handles to the loaded Julia modules
Unfold, UnfoldDecode = jl.Unfold, jl.UnfoldDecode

Create model

In [None]:
# The event latencies were extracted from `raw_concatenated`, so the data
# matrix is taken from that same object. Reading `raw` here only gave matching
# data because mne.concatenate_raws extends its first input in place.
data = raw_concatenated.get_data()
print(raw_concatenated.info['sfreq'])
print(data.shape)

In [None]:
# Copy of the event table in which events without an SSD get a centred SSD of
# 0, i.e. the mean of the uncentred values.
events_df_no_nans = events_df.assign(
    ssd_centered=events_df['ssd_centered'].fillna(0)
)
events_df_no_nans

In [None]:
# Sampling rate handed to the FIR bases. It is read from the recording rather
# than hard-coded (previously a fixed 64), because the τ windows are given in
# seconds and only map onto the right samples when sfreq matches the data.
model_sfreq = float(raw_concatenated.info['sfreq'])
# Integral rates are written without a decimal point, so a 64 Hz recording
# produces exactly the Julia code that was used before.
sfreq_jl = int(model_sfreq) if model_sfreq.is_integer() else model_sfreq

# Specify model formula: one FIR basis (time window in seconds) per event type
bf_go = jl.seval(f"bf_go = firbasis(τ = (-0.2, 0.5), sfreq = {sfreq_jl})")
bf_stop = jl.seval(f"bf_stop = firbasis(τ = (-0.2, 0.5), sfreq = {sfreq_jl})")
bf_response = jl.seval(f"bf_response = firbasis(τ = (-0.1, 0.5), sfreq = {sfreq_jl})")

# Go and stop events get an intercept only; response events are additionally
# modelled by response type, centred SSD and their interaction.
formula_go = jl.seval("f_go = @formula 0 ~ 1")
formula_stop = jl.seval("f_stop = @formula 0 ~ 1")
formula_res = jl.seval("f_res = @formula 0 ~ 1 + response_type + ssd_centered + response_type*ssd_centered")

# Event type => (formula, basis); the keys are the values of the `type` column
bfDict = jl.seval("[ \"go\" => (f_go, bf_go), \"stop\" => (f_stop, bf_stop),  \"response\" => (f_res, bf_response)]")

# Convert the Python columns to Julia arrays
# NOTE(review): MNE sample indices are 0-based and include raw.first_samp,
# while Julia arrays are 1-based; confirm whether `latency` needs an offset
# before it is used to index into `data`.
type_column = jl.seval("Vector{String}")(events_df_no_nans['type'].tolist())
response_type_column = jl.seval("Vector{String}")(events_df_no_nans['response_type'].tolist())
ssd_centered_column = jl.seval("Vector{Float64}")(events_df_no_nans['ssd_centered'].tolist())
latency_column = jl.seval("Vector{Int64}")(events_df_no_nans['latency'].tolist())

# Create the Julia DataFrame
events_df_jl = jl.DataFrame(type=type_column, latency=latency_column, ssd_centered=ssd_centered_column, response_type=response_type_column)

# Fit Unfold model
m = Unfold.fit(
    Unfold.UnfoldModel,
    bfDict,
    events_df_jl,
    data,
    # solver = jl.seval("(x, y) -> Unfold.solver_default(x, y; stderror = true)"),
    eventcolumn = "type",
)

# Keep both matrix views of the fitted model for inspection
model_matrix = Unfold.modelmatrix(m)
design_matrix = Unfold.designmatrix(m)

In [None]:
# Write the fitted model matrix to disk for inspection outside the notebook
model_matrix_df = pd.DataFrame(model_matrix.to_numpy())
model_matrix_df.to_csv('design_matrix.csv')

In [None]:
def jl_results_to_python(results_jl):
    """Copy the columns of a coefficient table into a pandas DataFrame.

    ``results_jl`` must expose the attributes ``channel``, ``coefname``,
    ``estimate``, ``eventname``, ``group``, ``stderror`` and ``time``; each one
    becomes the DataFrame column of the same name, in that order.
    """
    column_names = (
        'channel',
        'coefname',
        'estimate',
        'eventname',
        'group',
        'stderror',
        'time',
    )
    return pd.DataFrame({name: getattr(results_jl, name) for name in column_names})

In [None]:
# Pull the coefficient table of the fitted model out of Julia and convert it
# to a pandas DataFrame; the bare name on the last line displays it.
results_jl = Unfold.coeftable(m)
results_py = jl_results_to_python(results_jl)
results_py

In [None]:
# Restrict the coefficient table to the rows whose channel value is 1
results_channel = results_py[results_py.channel == 1]

# One subset per event type
results_go = results_channel[results_channel.eventname == 'go']
results_stop = results_channel[results_channel.eventname == 'stop']
results_response = results_channel[results_channel.eventname == 'response']

# Three side-by-side panels sharing the y axis
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

# (axes, rows to draw, column used for colouring, panel title). Go and stop
# are coloured by event name; the response panel draws one line per coefficient.
panels = (
    (ax1, results_go, 'eventname', 'Go'),
    (ax2, results_stop, 'eventname', 'Stop'),
    (ax3, results_response, 'coefname', 'Response'),
)

for axis, subset, hue_column, panel_title in panels:
    sns.lineplot(
        x=subset.time,
        y=subset.estimate,
        hue=subset[hue_column],
        ax=axis,
    )
    axis.set(xlabel='Time [s]', ylabel='Coefficient estimate', title=panel_title)

plt.tight_layout()
plt.show()