"""
This example compares the classification performance of a
linear support vector machine (LinearSVC) trained with the
Riemannian Procrustes Analysis transfer-learning method (RPA, Rodrigues et al., 2018)
against the gold-standard subject-wise train-test cross-validation (calibration) method,
using real P300 BCI data.

Copyright © 2023, Fahim Doumi <fahim.doumi@outlook.fr> and Fatih Altindis <fthaltindis@gmail.com>
Team ViBS (head: Marco Congedo <marco.congedo@gipsa-lab.grenoble-inp.fr>)
GIPSA-lab, CNRS, Université Grenoble Alpes
License: BSD 3-Clause 

References:
P.L.C. Rodrigues, C. Jutten, M. Congedo (2018)
Riemannian procrustes analysis: transfer learning for brain–computer interfaces
IEEE Transactions on Biomedical Engineering, 66, 8, 2390-2401.
pdf: https://hal.science/hal-01971856/document
"""

In [None]:
import numpy as np
import pandas as pd
import warnings 
import matplotlib.pyplot as plt

from datetime import datetime
from joblib import dump, load
from tqdm import tqdm

from moabb.datasets import BNCI2014008, bi2013a, bi2014a, bi2014b, bi2015a
from moabb.paradigms import P300

from pyriemann.estimation import ERPCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.transfer import TLSplitter, TLCenter, TLStretch, TLRotate, TLClassifier
from pyriemann.transfer import encode_domains, decode_domains

from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.svm import LinearSVC
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.exceptions import ConvergenceWarning

In [None]:
# In this program, one session of one subject is treated as one target domain.
def get_subject_data(subject, session, X, y, metadata):
    """Return the epochs and labels of one subject and one session.

    Parameters
    ----------
    subject : int
        Subject identifier, compared against ``metadata["subject"]``.
    session : int
        Session number; rows whose ``metadata["session"]`` equals
        ``f"session_{session}"`` are selected.
    X : ndarray
        Epochs whose first axis is aligned with the rows of ``metadata``.
    y : ndarray
        Labels aligned with ``X``.
    metadata : pandas.DataFrame
        Table with ``"subject"`` and ``"session"`` columns, one row per epoch.

    Returns
    -------
    X_subject_session, y_subject_session : ndarray
        The rows of ``X`` and ``y`` belonging to the requested subject/session
        (empty arrays if there are none).
    """
    session_str = f'session_{session}'

    # Build the row mask once and reuse it for both X and y. Converting it to a
    # plain NumPy array makes the positional (not index-aligned) selection explicit.
    mask = ((metadata["subject"] == subject) & (metadata["session"] == session_str)).to_numpy()

    return X[mask], y[mask]

# Silence convergence messages so the experiment output stays readable:
# UserWarnings whose message starts with 'Convergence not reached.' and every
# scikit-learn ConvergenceWarning.
warnings.filterwarnings(action='ignore', message='Convergence not reached.', category=UserWarning)
warnings.filterwarnings(action="ignore", category=ConvergenceWarning)

In [20]:
# P300 paradigm and database to evaluate (choose the database you want to test)
paradigm = P300()
dataset = BNCI2014008()

# Epochs, labels and metadata of every subject of the database
X, y, metadata = paradigm.get_data(dataset)

In [21]:
# Selection of the source subject/session.
# This must be changed for each database: use the subject/session that obtains
# the best score with MOABB's WithinSessionEvaluation, using the same pipeline
# as the one used here (example in the repository: FindSource.ipynb).
source = 8
session_source = 0

subject_list = np.unique(metadata["subject"])  # all subjects of the database

# Every (subject, session) pair of the database. Session labels look like
# "session_0", so only the trailing number is kept.
all_pairs = []
for subject in subject_list:
    sessions = metadata.loc[metadata["subject"] == subject, "session"].unique()
    all_pairs.extend((subject, int(s.split('_')[-1])) for s in sessions)

# Fail early on a misconfigured source; otherwise the source data selected
# later would silently be empty.
if (source, session_source) not in all_pairs:
    raise ValueError(
        f"Source subject {source}, session {session_source} not found in the dataset."
    )

# Targets: every subject/session pair except the source one
target_list = [pair for pair in all_pairs if pair != (source, session_source)]

In [19]:
# Numbers of target-domain trials used for training (one experiment per value)
n_trials = [6, 12, 32, 48]

# Domain labels. The source domain is fixed. The target domain is only a
# placeholder here and is replaced for each target inside the TL loop.
source_domain = f'subject_{source:02}_session_{session_source}'
target_domain = ''

# Splitter over the target domain: each of the `n_splits` shuffles draws a
# stratified training partition of the target trials, and the remaining target
# trials are used for testing. RPA additionally trains on all source trials;
# the calibration pipeline uses only the target training partition.
seed = 50  # fixed seed for reproducible splits
n_splits = 5  # number of random train/test splits of the target domain
tl_cv = TLSplitter(
    cv=StratifiedShuffleSplit(random_state=seed, n_splits=n_splits),
    target_domain=target_domain,
)

# Base classifier used by the TL pipelines
clf_base = LinearSVC(class_weight="balanced", tol=1e-6)


def _empty_score_table():
    """Return a fresh ``{target: {method: {n_trials: []}}}`` nested dict of lists."""
    return {
        target: {
            meth: {trials: [] for trials in n_trials}
            for meth in ['rpa', 'calibration']
        }
        for target in target_list
    }


# One store per scorer (add more stores for more scorers)
cumulative_scores_bac = _empty_score_table()
cumulative_scores_roc = _empty_score_table()

In [5]:
# Source domain: load its epochs with `get_subject_data` and compute the ERP
# covariances used to train the transfer-learning pipelines.
X_source, y_source = get_subject_data(source, session_source, X, y, metadata)
if len(X_source) == 0:
    raise ValueError(
        f"No epochs found for source subject {source}, session {session_source}."
    )

# ERP covariances built with the "Target" class prototype of the source
# ('lwf' estimator; original note: use SVD if N >= 32).
erpcov_source = ERPCovariances(classes=["Target"], estimator='lwf')
cov_source = erpcov_source.fit_transform(X_source, y_source)

# One domain label per source trial. Reusing `source_domain` keeps these labels
# identical to the ones given to the TL estimators (domain_weight keys).
d_list_source = np.full(len(X_source), source_domain)

# Encode the domain into the labels for transfer learning
cov_source_enc, y_source_enc = encode_domains(cov_source, y_source, d_list_source)

In [None]:
# TL LOOP: for every target (one subject/session) and every training size,
# compare RPA (source + re-aligned target) with calibration (target only).
from sklearn.base import clone


def _one_hot(labels, classes):
    """Encode ``labels`` as a boolean indicator matrix with one column per class.

    Columns follow the order of ``classes``. Encoding true and predicted labels
    with the same ``classes`` guarantees matching shapes and column order.
    """
    return np.asarray(labels)[:, np.newaxis] == np.asarray(classes)[np.newaxis, :]


def _score_predictions(y_true, y_pred):
    """Return ``(balanced_accuracy, roc_auc)`` for hard label predictions.

    Both one-hot encodings are built from the classes present in ``y_true``.
    Encoding the predictions with their own unique values would give a
    mismatched shape, and an exception in ``roc_auc_score``, whenever the
    classifier predicts a single class.

    NOTE(review): with two classes, the AUC of hard 0/1 predictions equals the
    balanced accuracy. Continuous decision scores would give a more informative
    AUC, if the TL classifier exposes them.
    """
    classes = np.unique(y_true)
    bac = balanced_accuracy_score(y_true, y_pred)
    roc = roc_auc_score(_one_hot(y_true, classes), _one_hot(y_pred, classes))
    return bac, roc


for target in target_list:
    subject, session = target
    print("Subject:", subject, "Session:", session)

    # Data of the current target; all its trials share one domain label
    X_target, y_target = get_subject_data(subject, session, X, y, metadata)
    target_domain = f'subject_{subject:02}_session_{session}'
    domains = np.full(len(X_target), target_domain)

    # Encode the domain into the labels for transfer learning
    X_enc, y_enc = encode_domains(X_target, y_target, domains)

    # The target domain does not depend on the training size: set it once per target
    tl_cv.target_domain = target_domain

    for trials in tqdm(n_trials):
        # Scores of each CV split, per method, for this training size
        scores_cv_bac = {meth: [] for meth in ['rpa', 'calibration']}
        scores_cv_roc = {meth: [] for meth in ['rpa', 'calibration']}

        # Absolute number of target trials in the training partition
        tl_cv.cv.train_size = trials
        print(f"Number of trials from target domain (Subject: {subject}, Session: {session}) for training is {trials}")

        for train_idx, test_idx in tl_cv.split(X_enc, y_enc):
            # Split the target domain into training and testing partitions
            X_enc_train, X_enc_test = X_enc[train_idx], X_enc[test_idx]
            y_enc_train, y_enc_test = y_enc[train_idx], y_enc[test_idx]

            # ERP covariances: the "Target" prototype is estimated on the target
            # training partition of this split, and the test trials are
            # transformed with that same prototype ('lwf'; original note: use
            # SVD if N >= 32). X_enc holds only this target's trials, so
            # train_idx also indexes y_target directly.
            erpcov_target = ERPCovariances(classes=["Target"], estimator='lwf')
            cov_train_enc = erpcov_target.fit_transform(X_enc_train, y_target[train_idx])
            X_test = erpcov_target.transform(X_enc_test)

            # Plain (domain-free) labels of the test partition, shared by both methods
            _, y_true, _ = decode_domains(X_enc_test, y_enc_test)

            # Training data for RPA: all source trials + target training partition
            X_train = np.concatenate((cov_source_enc, cov_train_enc))
            y_train = np.concatenate((y_source_enc, y_enc_train))

            # (1) RPA pipeline: recenter, stretch, and rotate.
            # The classifier is trained with points from the source only.
            # Each pipeline gets its own copy of the base classifier so fitting
            # one never overwrites the other's fitted model.
            pipeline_rpa = make_pipeline(
                TLCenter(target_domain=target_domain),
                TLStretch(
                    target_domain=target_domain,
                    final_dispersion=1,
                    centered_data=True,
                ),
                TLRotate(target_domain=target_domain, metric='riemann'),
                TangentSpace(metric="riemann"),
                TLClassifier(
                    target_domain=target_domain,
                    estimator=clone(clf_base),
                    domain_weight={source_domain: 1.0, target_domain: 0.0},
                ),
            )
            pipeline_rpa.fit(X_train, y_train)
            bac, roc = _score_predictions(y_true, pipeline_rpa.predict(X_test))
            scores_cv_bac['rpa'].append(bac)
            scores_cv_roc['rpa'].append(roc)

            # (2) Calibration: use only data from the target training partition.
            # The classifier is trained only with points from the target domain.
            pipeline_cal = make_pipeline(
                TangentSpace(metric="riemann"),
                TLClassifier(
                    target_domain=target_domain,
                    estimator=clone(clf_base),
                    domain_weight={source_domain: 0.0, target_domain: 1.0},
                ),
            )
            pipeline_cal.fit(cov_train_enc, y_enc_train)
            bac, roc = _score_predictions(y_true, pipeline_cal.predict(X_test))
            scores_cv_bac['calibration'].append(bac)
            scores_cv_roc['calibration'].append(roc)

        # Average score of each pipeline over the CV splits
        for meth in ['rpa', 'calibration']:
            cumulative_scores_bac[target][meth][trials].append(np.mean(scores_cv_bac[meth]))
            cumulative_scores_roc[target][meth][trials].append(np.mean(scores_cv_roc[meth]))

In [13]:
def _average_per_target(cumulative):
    """Map ``{target: {method: {n_trials: [scores]}}}`` to the mean of each list."""
    return {
        target: {
            meth: {
                trials: np.mean(cumulative[target][meth][trials])
                for trials in n_trials
            }
            for meth in ['rpa', 'calibration']
        }
        for target in target_list
    }


# Average score of each method and each n_trials, per target
average_scores_roc = _average_per_target(cumulative_scores_roc)
average_scores_bac = _average_per_target(cumulative_scores_bac)

In [14]:
# Mean score over all targets, for each method and each n_trials
average_scores_all_targets_bac = {
    meth: {
        trials: np.mean([average_scores_bac[target][meth][trials] for target in target_list])
        for trials in n_trials
    }
    for meth in ['rpa', 'calibration']
}

average_scores_all_targets_roc = {
    meth: {
        trials: np.mean([average_scores_roc[target][meth][trials] for target in target_list])
        for trials in n_trials
    }
    for meth in ['rpa', 'calibration']
}

In [None]:
# Saving the scores.
# `path` is the output directory. The previous default '/../../' resolved to
# the filesystem root, and string concatenation required a trailing slash.
from pathlib import Path

path = '.'  # your output directory
database_name = dataset.__class__.__name__
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

output_dir = Path(path)
output_dir.mkdir(parents=True, exist_ok=True)

# File stem -> object; each file is named <database>_<stem>_<timestamp>.joblib
results_to_save = {
    'average_scores_roc': average_scores_roc,
    'average_scores_bac': average_scores_bac,
    'average_scores_all_targets_bac': average_scores_all_targets_bac,
    'average_scores_all_targets_roc': average_scores_all_targets_roc,
}
for result_name, result in results_to_save.items():
    dump(result, str(output_dir / f'{database_name}_{result_name}_{timestamp}.joblib'))

# To reload saved scores later, e.g.:
# average_scores_bac = load(str(Path(path) / f'{database_name}_average_scores_bac_{timestamp}.joblib'))

# Plots and tables: change the score variables if you are using more than one scorer.

In [None]:
# Plot of the mean score over all targets versus the number of target
# training trials, one line per method.
fig, ax = plt.subplots(figsize=(10, 5))
positions = list(range(len(n_trials)))

for meth in ('rpa', 'calibration'):
    mean_scores = [average_scores_all_targets_bac[meth][t] for t in n_trials]
    ax.plot(positions, mean_scores, label=meth, lw=3.0)

# Titles, labels and limits
ax.set_title(f"Results for {database_name}")
ax.set_xlabel('Number of training trials in target domain')
ax.set_ylabel('Classification score')
ax.set_ylim(0.48, 0.6)  # change values if needed
ax.legend(loc='lower right')

# One x tick per training size, labeled with the number of trials
ax.set_xticks(positions)
ax.set_xticklabels(n_trials)

plt.show()

In [None]:
# Plot for each n_trials: highly recommended if you are using all subjects.
# One panel per training size; each x position is one target (subject, session).

# Sorted list of targets, reused by the plots and tables below
subjects_sorted = sorted(target_list)

# Method whose scores define the x-axis order. Both methods are drawn in this
# same order so each tick label matches the points plotted above it. Sorting
# each method independently would leave the labels valid for only one line.
order_method = 'calibration'

# squeeze=False keeps `axs` 2-D even when n_trials has a single entry
fig, axs = plt.subplots(len(n_trials), 1, figsize=(10, 5 * len(n_trials)), squeeze=False)
fig.suptitle(f"Results for {database_name}", fontsize=16, y=1.0)

for ax, trials in zip(axs[:, 0], n_trials):
    # Targets in ascending order of the reference method's score (ties broken by target)
    sorted_subjects = [
        subj
        for _, subj in sorted(
            (average_scores_bac[tgt][order_method][trials], tgt) for tgt in subjects_sorted
        )
    ]
    positions = range(len(sorted_subjects))

    for meth in ['rpa', 'calibration']:
        scores = [average_scores_bac[subj][meth][trials] for subj in sorted_subjects]
        ax.plot(positions, scores, label=meth, lw=3.0)

    # Titles, labels, limits and per-target tick labels
    ax.set_title(f"Training trials in target domain: {int(trials)}")
    ax.set_xlabel('Subject, Session')
    ax.set_ylabel('Classification score')
    ax.legend(loc='lower right')
    ax.set_ylim(0.48, 0.6)  # change values if needed
    ax.set_xticks(positions)
    ax.set_xticklabels(sorted_subjects)

# Adjust layout
plt.tight_layout()
plt.subplots_adjust(right=1.5)  # change values if needed
plt.show()


In [None]:
# Plot for each subject/session: not recommended if you use all subjects of the database.
# One panel per target, with the score versus the number of target training trials.
# squeeze=False keeps `axs` 2-D even when there is a single target.
fig, axs = plt.subplots(len(subjects_sorted), 1, figsize=(10, 5 * len(subjects_sorted)), squeeze=False)
fig.suptitle(f"Results for {database_name}", fontsize=16, y=1.0)

for ax, subj in zip(axs[:, 0], subjects_sorted):
    # One line per method: scores for all training sizes
    for meth in ['rpa', 'calibration']:
        scores = [average_scores_bac[subj][meth][trials] for trials in n_trials]
        ax.plot(range(len(n_trials)), scores, label=meth, lw=3.0)

    # Titles, labels and limits
    ax.set_title(f"Subject {subj}")
    ax.set_xlabel('Number of training trials in target domain')
    ax.set_ylabel('Classification score')
    ax.legend(loc='lower right')
    ax.set_ylim(0.48, 0.6)  # change values if needed

    # One x tick per training size, labeled with the number of trials
    ax.set_xticks(range(len(n_trials)))
    ax.set_xticklabels(n_trials)

# Adjust layout
plt.tight_layout()
plt.show()


In [None]:
# Table for each n_trials with all targets (one row per subject/session).
# The previous version indexed with the undefined name `targets` (NameError)
# instead of the loop variable.
target_index = pd.MultiIndex.from_tuples(subjects_sorted, names=['subject', 'session'])

for trials in n_trials:
    data = {
        meth: [average_scores_bac[subj][meth][trials] for subj in subjects_sorted]
        for meth in ['rpa', 'calibration']
    }
    df = pd.DataFrame(data, index=target_index)
    print(f"Number of trials: {trials}")
    print(df)
    print("\n---\n")

In [None]:
# Table of the mean score over all targets for each number of training trials:
# one column per method, one row per n_trials.
data = {
    meth: [average_scores_all_targets_bac[meth][trials] for trials in n_trials]
    for meth in ['rpa', 'calibration']
}

# Build the DataFrame indexed by the number of trials, then name that index
df_all_subjects = pd.DataFrame(data, index=n_trials)
df_all_subjects.index.name = 'Number of trials'

# Display the table
print(df_all_subjects)