# Brain-Supervised Image Editing
## Supplementary Material
### Code

# Load dependencies

In [1]:
import numpy as np
import scipy as sp
import mne
from os import listdir
from os.path import join
import pandas as pd
import sklearn
import re
import pickle
from tqdm.notebook import tqdm
import sys

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from functools import reduce
import matplotlib
import matplotlib.pyplot as plt
import mne
import pandas as pd
from os import listdir, path

# Keys of the eight classification tasks; these are the same strings as the
# "Level Two" keys of the preprocessed dataset described in the notebook text.
TASK_LIST = [
    'facecat/female', 'facecat/male',
    'facecat/nosmile', 'facecat/smiles',
    'facecat/old', 'facecat/young',
    'facecat/blond', 'facecat/darkhaired',
]

# Load Preprocessed EEG data

In [20]:
import lzma

# Load the preprocessed EEG dataset from the LZMA-compressed pickle.
# The context manager closes the handle even if unpickling raises, which the
# manual open()/close() pair did not guarantee.
# NOTE(review): pickle.load can execute arbitrary code; only run this on a
# data file obtained from a trusted source.
with lzma.open('data/facecat.xz', 'rb') as file:
    db = pickle.load(file)

EEG data have already been preprocessed such that they are cleaned and segmented into epochs. An epoch corresponds to all data produced by a single stimulus viewing event. 

The data are structured as a large nested dictionary:

Level One: 30 dictionaries, keys are integers from 0 to 29 - one for each participant in the study.

Level Two: 8 dictionaries, keys are strings corresponding to task grouping and semantic feature - ['facecat/blond', 'facecat/darkhaired', 'facecat/female', 'facecat/male', 'facecat/nosmile', 'facecat/old', 'facecat/smiles', 'facecat/young']

Level Three: 4 keys, keys are "test_x", "test_y", "train_x", "train_y", corresponding to the train and test set data and labels.

Level Four (labels): A numpy 2D array (X, Y), where X corresponds to labels as integers 0 and 1, indicating non-target (0) and target (1) classes. Thus for "facecat/female", a label of 1 indicates the associated EEG data correspond to viewing a female face, whereas a label of 0 indicates the EEG data correspond to viewing a male face. Y corresponds to the specific stimulus ID/latent ID tied to the corresponding epoch.

Level Four (data): A numpy.ndarray of shape (X, Y=5, Z=32). X corresponds to the epochs, Y corresponds to averaged measured EEG voltages at 0-100ms, 100-200ms, 200-300ms, 300-400ms, and 400-500ms. Z corresponds to each of the 32 channels used to record the EEG data for each epoch.

Note: Due to artifact removal procedures during pre-processing (eye/head movements and other sources of noise), the size of the test and train sets varies between participants.

# Load in the latents used for CelebGAN generation

In [42]:
# Load the pickled latents.
# The context manager closes the handle even if unpickling raises, which the
# manual open()/close() pair did not guarantee.
# NOTE(review): pickle.load can execute arbitrary code; only run this on a
# data file obtained from a trusted source.
with open('data/latents.pkl', 'rb') as file:
    latents = pickle.load(file)

# Reshape Data so it is anonymized (100ms averages), with channel names

In [None]:
 # Results Loop
# Per-participant, per-task LDA classification.
# Fills results_dict[user][task] with the predicted probabilities, the test
# labels, the image ids and the event types for every task in TASK_LIST.
# NOTE(review): `epos`, `epoch_to_lda_data` and `split_average` are not defined
# in the visible cells — confirm they are created earlier in the notebook.
results_dict = {}
for user in tqdm(range(len(epos))):
    results_dict[user] = {}
    lda_data = epoch_to_lda_data(epos[user], times=[-0.2, 0.9])

    for task in TASK_LIST:
        task_data = lda_data[task]

        test_x = split_average(task_data['test_x'], n_split=25)
        train_x = split_average(task_data['train_x'], n_split=25)

        # Collapse axes 1 and 2 into a single feature axis so each row is one
        # flat feature vector for the classifier.
        train_x = train_x.reshape(
            train_x.shape[0], train_x.shape[1] * train_x.shape[2]
        )
        test_x = test_x.reshape(
            test_x.shape[0], test_x.shape[1] * test_x.shape[2]
        )

        # Label containers expose `.rel`, `.img` and `.event`
        # (presumably columns of a label table — verify against epoch_to_lda_data).
        test_labels = task_data['test_y']
        test_y = test_labels.rel
        train_y = task_data['train_y'].rel
        img_ids = test_labels.img
        event_type = test_labels.event

        # LDA with the least-squares solver and automatic shrinkage.
        lda = LDA(solver='lsqr', shrinkage='auto').fit(train_x, list(train_y))

        # Keep probability column 1 only (presumably the target class, label 1 —
        # confirm against lda.classes_).
        yhat = lda.predict_proba(test_x)[:, 1]

        results_dict[user][task] = {
            "yhat": yhat,
            "y": test_y,
            "img_id": list(img_ids),
            "event": event_type,
        }

In [None]:
# ROC AUC of the stored predictions for participant 0 on the 'facecat/female'
# task; the bare last expression is displayed as the cell output.
female_result = results_dict[0]['facecat/female']
sklearn.metrics.roc_auc_score(list(female_result['y']), female_result['yhat'], average="macro", multi_class='ovo')