In [1]:
import mne
from mne.preprocessing import ICA
from mne.decoding import CSP
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score

<h1 style="color: red; font-weight: bold;">Preprocessing</h1>

## 1.) Loading data

In [2]:
def load_data(subject, session="T"):
    """
    Loads data, includes all channels including EEG and EOG.
    Args:
        subject (int): number of the subject
        session (str):T for training, E for evaluation
    returns: 
        raw (RawGDF): an mne object containing the raw data
        data_raw (list(list(float))): list containing each channel and corresponding data, channel x data samples 
        times (list): list containing each channel and corresponding times, channel x time samples
    """
    
    file_path = f"/Users/kelokomesu/School/COGS 189/final_project/data/A0{subject}{session}.gdf"
    raw = mne.io.read_raw_gdf(file_path, preload=True, eog=["EOG-left", "EOG-central", "EOG-right"], verbose=0)
    data_raw, times = raw.get_data(return_times=True)
    raw.rename_channels({'EEG-Fz': 'EEG-Fz','EEG-0': 'EEG-2','EEG-1': 'EEG-3','EEG-2': 'EEG-4','EEG-3': 'EEG-5','EEG-4': 'EEG-6','EEG-5': 'EEG-7','EEG-C3': 'EEG-C3','EEG-6': 'EEG-9','EEG-Cz': 'EEG-Cz','EEG-7': 'EEG-11','EEG-C4': 'EEG-C4','EEG-8': 'EEG-13','EEG-9': 'EEG-14','EEG-10': 'EEG-15','EEG-11': 'EEG-16','EEG-12': 'EEG-17','EEG-13': 'EEG-18','EEG-14': 'EEG-19','EEG-Pz': 'EEG-Pz','EEG-15': 'EEG-21','EEG-16': 'EEG-22','EOG-left': 'EOG-left','EOG-central': 'EOG-central','EOG-right': 'EOG-right'})

    return raw, data_raw, times

In [3]:
# test subject 1
sub1_raw, sub1_data_raw, sub1_times = load_data(1)

  next(self.gen)


## 2.) Filter data

In [4]:
def bandpass_filter(raw_data, bands):
    """
    Creates bandpass filters for raw data
    Args:
        raw_data (list(list(floats))): raw data from a single subject
        bands (list(tuples)): bands which need to be created
    Returns:
        filtered_data (list(RawGDF)): returns a list of RawGDF objects, each with their specified bandpass filter
    """

    filtered_data = []
    for band in bands:
        filtered_data.append(raw_data.filter(l_freq=band[0], h_freq=band[1], verbose=0))
        
    return filtered_data

In [5]:
# a list of RawGDF objects, each with thier own bandpass filter
sub1_filtered_list = bandpass_filter(sub1_raw,[(0,5),(5,40)])

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


## 3.) Artifact Detection and Removal (if needed)

### Manual Artifact Detection and Removal

In [6]:
def detect_artifact(raw_data, artifact):
    """
    Produces a number of plots to check if the current dataframe has major artifacts
    Args:
        raw_data (RawGDF): RawGDF object
        artifact (string): 'low freq drift'. 'power line noise', 'heartbeat', 'ocular'
    Returns:
        Plot based on specified artifact
    """
    data = raw_data.copy()
    if artifact == "low freq drift":
        return data.plot(duration=60, remove_dc=False, verbose=0) #direct current offset set to False
    elif artifact == "power line noise":
        return data.plot_psd(fmax=100, verbose=0) #
    elif artifact == "eog":
        eog_epochs = mne.preprocessing.create_eog_epochs(data, baseline=(-0.5, -0.2)) #finds eog artifacts and extracts into an Epoch object in one step
        eog_epochs.plot_image(combine="mean")
        # eog_epochs.average().plot_joint()

In [None]:
detect_artifact(sub1_filtered_list[1], "eog")

### Automatic Artifact Detection and Removal

In [None]:
def do_ICA(filtered_data, n_components, remove):
    data = filtered_data.copy()
    ica = ICA(n_components=n_components, random_state=97, method="fastica")
    ica.fit(data)
    ecg_idx, ecg_scores = ica.find_bads_ecg(data)
    return ecg_idx, ecg_scores

    # # Automatically detect eye-blink artifacts using EOG channels
    # eog_indices, eog_scores = ica.find_bads_eog(raw)
    # # Plot correlation with EOG signals
    # ica.plot_scores(eog_scores)
    # # Remove detected components
    # if remove:
    #     ica.exclude = eog_indices
    #     # Apply ICA cleaning to EEG data
    #     raw_cleaned = ica.apply(raw)
    #     return raw_cleaned

In [None]:
do_ICA(sub1_filtered_list[0], 20, True)

## 4.) Epoching

In [7]:
def create_epoch(filtered_data, min, max):
    """
    Creating epochs from filtered and artifact clean data
    Args:
        filtered_data (RawGDF): gdf object
    Returns:
        epochs (epoch): mne epoch object for this subject and filter
    """

    events = mne.events_from_annotations(filtered_data, verbose=0)

    event_dict = {"rejected trial":1,"eye movements":2,"idling eeg (eyes open)":3,"idling eeg (eyes closed)":4,"new run":5,"start trial":6,"left":7,"right":8,"feet":9,"tongue":10}
    
    epochs = mne.Epochs(
        filtered_data.copy(), # raw filtered data
        events[0], # event array
        tmin=min, # time before event, baseline correction (seconds)
        tmax=max, # time after event (seconds)
        event_repeated="drop", # what to do when multiple events have the same starting time
        preload=True, # preload data into memory for faster processing
        event_id=event_dict, # mapping event id to the actual event description
        verbose=0
        )
    
    return epochs

In [8]:
sub1_epochs_list = [create_epoch(i, -0.5, 2) for i in sub1_filtered_list]

Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Multiple event values for single event times found. Keeping the first occurrence and dropping all others.
Not setting metadata
585 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 585 events and 626 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Multiple event values for single event times found. Keeping the first occurrence and dropping all others.
Not setting metadata
585 matching events found
Setting baseline interval to [-0.5, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 585 events and 626 original time points ...
1 bad epochs dropped


<h1 style="color: red; font-weight: bold;">Classification</h1>

# 5.) Feature Extraction

In [8]:
def do_CSP(epoch_data, n_components, labels):
    """
    Extracts spatial features from epoched trails from specified classes. Features are also standardized.
    Args:
        epoch_data (epoch): mne epoch object
    Returns:
        dataset (list(list(int))): dataset as a np array with the last entry in each list as the label

    """
    d = epoch_data.copy()

    epoch_list = []
    for label in labels:
        epoch_list.append(d[label])
    data = mne.concatenate_epochs(epoch_list, verbose=0)

    X = data.pick("eeg").get_data()
    y = data.events[:,-1]

    csp = CSP(n_components=n_components)
    X_csp = csp.fit_transform(X, y)

    dataset = np.hstack((X_csp, y.reshape(-1,1)))
    return dataset


In [10]:
sub1lr_dataset = do_CSP(sub1_epochs_list[1], 6, ["left", "right"])

Not setting metadata
144 matching events found
Applying baseline correction (mode: mean)
Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list)


    Using tolerance 1.5e-05 (2.2e-16 eps * 22 dim * 3e+09  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


In [None]:
# list of tuples (X data, y data) each tuple has one subject



# 6.) Creating Training, Validation, and Test Datasets

In [9]:
# normal train, val, test
def create_tvt_datasets(dataset, trainr, valtestr):
    """
    Creating test, validation, and test datasets for classification
    Args:
        epoch_data (epoch): epoch object of a single subject
        trainr, valtestr (float): ratio of splits
        label (list(str)): name of the class we are interested in for this dataset, "left", "right", "feet", "toungue"
    """

    data = dataset.copy()

    # Get the number of epochs
    n_epochs = len(data)
    
    # Generate shuffled indices
    indices = np.arange(n_epochs)
    np.random.seed(42)
    np.random.shuffle(indices)

    # Compute split points
    train_end = int(n_epochs * trainr)
    val_end = train_end + int(n_epochs * valtestr)

    # Split indices
    train_indices = indices[:train_end]
    val_indices = indices[train_end:val_end]
    test_indices = indices[val_end:]

    # Create subsets
    epochs_train = data[train_indices]
    epochs_val = data[val_indices]
    epochs_test = data[test_indices]

    return epochs_train, epochs_val, epochs_test

In [12]:
sub1lr_train, sub1lr_val, sub1lr_test = create_tvt_datasets(sub1lr_dataset, 0.7, 0.15)

# 7.) Classify

In [10]:
def LDA_fit(train_data):
    X_train = train_data[:,:-1]
    y_train = train_data[:,-1]

    # Scale data for LDA
    scaler = StandardScaler()
    scaler.fit(X_train)
    X_train_scaled = scaler.transform(X_train)
    # Initialize and train LDA classifier
    lda = LinearDiscriminantAnalysis()
    return lda.fit(X_train_scaled, y_train), scaler


In [14]:
sub1_fitted_model, sub1_scaler = LDA_fit(sub1lr_train)

In [11]:
def LDA_predict(model, X, scaler=None):
    """
    Takes in a model and 
    """
    X_test = X[:,:-1]

    # scale X test data from a fitted train scaler
    X_test_scaled = scaler.transform(X_test)

    # Predict the labels for test data
    y_pred = model.predict(X_test_scaled)

    return y_pred

In [16]:
sub1_preds = LDA_predict(sub1_fitted_model, sub1lr_test, sub1_scaler)

In [17]:
accuracy = accuracy_score(sub1_preds, sub1lr_test[:,-1])

In [18]:
accuracy

0.8695652173913043

<h1 style="color: red; font-weight: bold;">Classification</h1>

In [12]:
def cross_subject_classification_pipeline(bsub, tsubs, lf, hf, emin, emax, csp_comp,):
    """
    Complete cross subject classification and evaluation
    Args:
        subject_base (int): baseline subject for model to be trained on
        subject_test (list(ints)): list of subjects to test base model against
        lf (int): lower frequency to filter raw data
        hf (int): high frequency to filter raw data
        emin (int/float): time in seconds before cue, for baseline correction
        emax (int/float): time in seconds after cue
        csp_comp (int): number of csp components to extract

    """
    # base subject load data
    bsub_raw, bsub_data_raw, bsub_times = load_data(bsub)

    # base subject filter data
    bsub_filtered = bandpass_filter(bsub_raw,[(lf,hf)])

    # base subject epoch data
    bsub_epoch = create_epoch(bsub_filtered[0], emin, emax)

    # base subject feature extraction 
    bsub_dataset = do_CSP(bsub_epoch, csp_comp, ["left", "right"])

    # base subject create train val, test sets
    bsub_train, bsublr_val, bsub_test = create_tvt_datasets(bsub_dataset, 0.7, 0.15)

    # base subject fit LDA model
    bsub_model, bsub_scaler, = LDA_fit(bsub_train)

    # base subject LDA model predictions
    bsub_preds = LDA_predict(bsub_model, bsub_train, bsub_scaler)

    results = []

    # base subject prediction on its own train test split
    bsub_train_y = bsub_train[:,-1]
    bsub_bsub_accuracy = accuracy_score(bsub_train_y, bsub_preds)

    results.append(f"bsub->bsub: {bsub_bsub_accuracy}")

    for tsub in tsubs:
        print(f"start: {tsub}")
        tsub_raw, _, _ = load_data(tsub)

        # test subject filter data
        tsub_filtered = bandpass_filter(tsub_raw,[(lf,hf)])

        # test subject epoch data
        tsub_epoch = create_epoch(tsub_filtered[0], emin, emax)

        # test subject feature extraction 
        tsub_dataset = do_CSP(tsub_epoch, csp_comp, ["left", "right"])

        # using the base subject model on the whole test subject dataset
        tsub_preds = LDA_predict(bsub_model, tsub_dataset, bsub_scaler)
        tsub_dataset_y = tsub_dataset[:,-1]
        
        results.append(f"bsub->sub{tsub}:{accuracy_score(tsub_dataset_y, tsub_preds)}")
        print(f"end: {tsub}")
    return results

In [50]:
cross_subject_classification_pipeline(
    bsub=2,
    tsubs=[1,3,5,6,7,8,9],
    lf=4,
    hf=50,
    emin=-1,
    emax=4,
    csp_comp=6
)

Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 6e-05 (2.2e-16 eps * 22 dim * 1.2e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
start: 1
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A01T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 5.9e-05 (2.2e-16 eps * 22 dim * 1.2e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 1
start: 3
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A03T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 7.6e-05 (2.2e-16 eps * 22 dim * 1.5e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 3
start: 5
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A05T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 686119  =      0.000 ...  2744.476 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 5.3e-05 (2.2e-16 eps * 22 dim * 1.1e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 5
start: 6
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A06T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 678979  =      0.000 ...  2715.916 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 8.1e-05 (2.2e-16 eps * 22 dim * 1.7e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 6
start: 7
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A07T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 681070  =      0.000 ...  2724.280 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 5.2e-05 (2.2e-16 eps * 22 dim * 1.1e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 7
start: 8
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A08T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 675269  =      0.000 ...  2701.076 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 9.2e-05 (2.2e-16 eps * 22 dim * 1.9e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 8
start: 9
Extracting EDF parameters from /Users/kelokomesu/School/COGS 189/final_project/data/A09T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 673327  =      0.000 ...  2693.308 secs...


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 9.2e-05 (2.2e-16 eps * 22 dim * 1.9e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.
end: 9


['bsub->bsub0.76',
 'bsub->sub1:0.3194444444444444',
 'bsub->sub3:0.6319444444444444',
 'bsub->sub5:0.4027777777777778',
 'bsub->sub6:0.6388888888888888',
 'bsub->sub7:0.7222222222222222',
 'bsub->sub8:0.7708333333333334',
 'bsub->sub9:0.2916666666666667']

# Formatting Data for RL

In [13]:
def format_for_jacob(bsub):

    # base subject load data
    bsub_raw, _, _ = load_data(bsub)

    # base subject filter data
    bsub_filtered = bandpass_filter(bsub_raw,[(4,40)])

    # base subject epoch data
    bsub_epoch = create_epoch(bsub_filtered[0], -1, 5)

    # base subject feature extraction 
    bsub_dataset = do_CSP(bsub_epoch, 6, ["left", "right"])

    return [bsub_dataset[:,:-1], bsub_dataset[:,-1]]


In [14]:
X_data = []
y = []
for subject in [1,2,3]:
    X_data.append(format_for_jacob(subject)[0])
    y.append(format_for_jacob(subject)[1])

  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 6.5e-05 (2.2e-16 eps * 22 dim * 1.3e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 6.5e-05 (2.2e-16 eps * 22 dim * 1.3e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 6.8e-05 (2.2e-16 eps * 22 dim * 1.4e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 6.8e-05 (2.2e-16 eps * 22 dim * 1.4e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 8.2e-05 (2.2e-16 eps * 22 dim * 1.7e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


  next(self.gen)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


Computing rank from data with rank=None


  data = mne.concatenate_epochs(epoch_list, verbose=0)


    Using tolerance 8.2e-05 (2.2e-16 eps * 22 dim * 1.7e+10  max singular value)
    Estimated rank (data): 22
    data: rank 22 computed from 22 data channels with 0 projectors
Reducing data rank from 22 -> 22
Estimating class=7 covariance using EMPIRICAL
Done.
Estimating class=8 covariance using EMPIRICAL
Done.


In [72]:
np.array(X_data).shape

(3, 144, 6)

In [74]:
np.array(y)

array([[7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 8., 8., 8., 8., 8., 8., 8., 8.,
        8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
        8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
        8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.,
        8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.],
       [7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 8., 8., 8., 8., 8., 8.,

In [15]:
clf = LinearDiscriminantAnalysis()

In [16]:
from eeg_env_agent_test import EEGFeatureEnv

env = EEGFeatureEnv(X_train=X_data[0], y_train=y[0],other_subjects_X=X_data[1:], other_subjects_y=y[1:], classifier=clf)

In [17]:
from stable_baselines3 import PPO

In [None]:
model = PPO('MlpPolicy', env, 
            learning_rate=1e-4, 
            gamma=0.995,
            gae_lambda=0.9,
            batch_size=256,
            clip_range=0.1,
            verbose=1)
model.learn(total_timesteps=5000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


