# SEPIO performed on a simulated sulcus
Define a gray-matter region as shown below, and attempt to classify which source is active.

In [None]:
# Import libraries
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.model_selection import train_test_split
from os import path
import sys
import pysensors as ps
from pysensors.classification import SSPOC
sys.path.insert(0, path.join('.'))
from modules.leadfield_importer import FieldImporter
import modules.disc_plotter as dp

# Point to unzipped dataset folder
# NOTE(review): placeholder path -- replace with the local dataset location.
folder = r"...\SEPIO_dataset"

# Access lead fields (one .npz per device) and per-electrode noise levels.
# Noise values are in uV rms; they are scaled by 1e-6 to volts when added.
disc_fields_file = path.join(folder, 'leadfields', 'DISC_250um_grid_dist_ref.npz')
disc_noise = 4.1  # uV rms
seeg_fields_file = path.join(folder, 'leadfields', 'SEEG_250um_grid_dist_ref.npz')
seeg_noise = 2.7  # uV rms
ecog_fields_file = path.join(folder, 'leadfields', 'ECOG_250um_grid_rotated.npz')
ecog_noise = 4.1  # uV rms

# Dipole moment magnitude in SI units, A*m (0.125e-9 A*m == 0.125 nA*m); the
# earlier "nAm" comment mislabelled the unit of the stored number. The value
# is matched to the lead-field voxel size; values used for other grids:
#   200 um: 0.08e-9   250 um: 0.125e-9   350 um: 0.245e-9   500 um: 0.5e-9
magnitude = 0.125e-9  # A*m, for the 250 um grid loaded above
num_trials = 30000  # Number of Monte-Carlo cycles to randomize noise
# NOTE(review): `car` is not referenced in the cells shown here; presumably a
# common-average-reference toggle -- confirm before relying on it.
car = False

In [None]:
# Convert a lead-field grid index to a physical position in mm.
# Defaults reproduce the original 250 um grid (4 voxels per mm).
def translate_index_to_position(x: int, y: int, z: int, *,
                                voxel_size_mm: float = 0.25,
                                origin: tuple = (20, 20, 0)):
    """Return the (x, y, z) position in mm of lead-field voxel index (x, y, z).

    With the defaults, x and y indices are offset so that index 20 maps to
    0 mm, z index 0 maps to 0 mm, and each index step is 0.25 mm (250 um).

    Args:
        x, y, z: Voxel indices into the lead-field grid.
        voxel_size_mm: Edge length of one voxel in mm.
        origin: Voxel index (x0, y0, z0) that maps to position (0, 0, 0).

    Returns:
        Tuple (x_mm, y_mm, z_mm) of floats.
    """
    x0, y0, z0 = origin
    return ((x - x0) * voxel_size_mm,
            (y - y0) * voxel_size_mm,
            (z - z0) * voxel_size_mm)

## Data Generation

### Get fields from ANSYS models:

In [None]:
### Initializing lead fields and moving into desired position relative to sources ###

# ---- DISC ----
# After duplicate_fields(), electrodes 128-255 are handled as a separate group
# (presumably the duplicated copy -- confirm in FieldImporter). That group is
# lowered by 1 in z, then all 256 electrodes are shifted +4 in y.
field_importer_disc = FieldImporter()
field_importer_disc.load(disc_fields_file)
field_importer_disc.duplicate_fields()
# Optional transforms used in other configurations:
#   field_importer_disc.rotate(theta=-45, electrodes=range(128,256))  # before the z shift
#   field_importer_disc.rotate(psi=15, electrodes=range(256))         # after the y shift
#   field_importer_disc.rotate(theta=90, electrodes=range(256))       # after the y shift
field_importer_disc.translate(z=-1, electrodes=range(128,256))
field_importer_disc.translate(y=4, electrodes=range(256))
disc_fields = field_importer_disc.fields
num_electrodes_disc = np.shape(disc_fields)[4]  # electrode count = size of axis 4

# ---- SEEG ----
# Electrodes 0-17 are shifted +4 in y; the first two channels are then
# dropped, keeping channels 2 onward.
field_importer_seeg = FieldImporter()
field_importer_seeg.load(seeg_fields_file)
field_importer_seeg.translate(y=4, electrodes=range(18))
# Optional: rotate(psi=15) or rotate(theta=90) on electrodes range(18)
seeg_fields = field_importer_seeg.fields[:, :, :, :, 2:]
num_electrodes_seeg = np.shape(seeg_fields)[4]

# ---- ECoG ----
# Note: ECoG electrode sits at 0,0,5mm. Its lead field is cloned into an
# 11 x 11 grid of 121 contacts, which is then laid out below.
field_importer_ecog = FieldImporter()
field_importer_ecog.load(ecog_fields_file)
field_importer_ecog.duplicate_fields(count=121)  # number of ecog electrodes
num_electrodes_ecog = np.shape(field_importer_ecog.fields)[4]

# Columns: each consecutive block of 11 electrodes shares one x offset,
# stepping by 2 from x=-10 (electrodes 0-10) to x=+10 (electrodes 110-120).
for ecog_col in range(11):
    first = 11 * ecog_col
    field_importer_ecog.translate(x=2 * ecog_col - 10, electrodes=range(first, first + 11))

# Rows: electrodes sharing the same index modulo 11 share a (z, y) offset.
# Each successive row raises z by 2 and lowers y by 1, running from
# (z=-6, y=13) to (z=14, y=3).
# (An alternative layout started at (z=0, y=10) and ended at (z=20, y=0).)
for ecog_row in range(11):
    field_importer_ecog.translate(z=2 * ecog_row - 6, y=13 - ecog_row,
                                  electrodes=range(ecog_row, num_electrodes_ecog, 11))

# Translate whole ecog... currently centered at 0,0,18.
# Afterwards it spans 20-40 or 5.0-10mm: 1mm away from center of dipole on "flat" part.
field_importer_ecog.translate(y=-24, z=0, electrodes=range(num_electrodes_ecog))
ecog_fields = field_importer_ecog.fields[:, :, :, :, :]
num_electrodes_ecog = np.shape(ecog_fields)[4]

### Generate Dipoles

Notes:

- X is INTO the page

In [None]:
### Defining all dipoles in gray matter space ###
# The sulcus cross-section (y-z plane) is traced by a fixed sequence of voxels,
# each carrying an in-plane dipole. The same profile is repeated at every x
# slice. Each row below is
#     (y index, z index, y factor, z factor)
# and yields position [x_pos, y, z] with moment
#     [0, y_factor * magnitude, z_factor * magnitude].
# NOTE(review): the (-0.8, -0.8) rows have norm ~1.13*magnitude, while the
# +/-0.7 diagonals have ~0.99*magnitude -- confirm the 0.8 entries are intended.
sulcus_profile = [
    (16, 7, -0.7, -0.7),
    (15, 9, -0.7, -0.7),
    (14, 11, -0.7, -0.7),
    (13, 13, -1, 0),
    (14, 15, -0.7, 0.7),
    (16, 17, -0.7, 0.7),
    (17, 18, -1, 0),
    (17, 20, -0.8, -0.8),
    (16, 21, 0, -1),
    (14, 20, 0.7, -0.7),
    (13, 19, 0, -1),
    (11, 19, 0, -1),
    (10, 20, -0.7, -0.7),
    (9, 22, -0.7, -0.7),
    (8, 24, -1, 0),
    # CS starts
    (9, 26, -1, 0),
    (11, 27, -0.7, 0.7),
    (13, 28, -0.7, 0.7),
    (15, 29, -0.7, 0.7),
    (17, 30, -0.7, 0.7),
    (19, 31, -0.7, 0.7),
    # Fundus
    (21, 32, -0.7, 0.7),
    (22, 33, -1, 0),
    (22, 35, -0.8, -0.8),
    (21, 35, 0, -1),
    # End fundus
    (19, 35, 0.7, -0.7),
    (17, 34, 0.7, -0.7),
    (15, 33, 0.7, -0.7),
    (13, 32, 0.7, -0.7),
    (11, 31, 0.7, -0.7),
    (9, 30, 0.7, -0.7),
    (8, 29, 0, -1),
    # CS ends
    (6, 30, 0, -1),
    (5, 32, -0.7, -0.7),
    (4, 34, -0.7, -0.7),
    (3, 36, -0.7, -0.7),
]

dipoles = []
positions = []

# Other x-slice sets tried: (27, 29, 31, 33) and (30,)
for x_pos in (11, 17, 23, 29):
    for y_idx, z_idx, y_factor, z_factor in sulcus_profile:
        positions.append(np.array([[x_pos, y_idx, z_idx]]))
        dipoles.append(np.array([[0, y_factor * magnitude, z_factor * magnitude]]))

assert len(dipoles) == len(positions)
# Convert every dipole's voxel index into a position (mm) for plotting.
dipole_positions = [
    translate_index_to_position(voxel[0], voxel[1], voxel[2])
    for source_block in positions
    for voxel in source_block
]

# Show the dipole locations in 3D (plotter shifted +1 in y before drawing).
dpt = dp.ElectrodePlotter(electrode_ids=[0,1], num_electrodes=1, dipoles=np.array(dipole_positions))
dpt.translate(y=1)
dpt.plot_3d(view=(0,0))

### Generate Voltages

In [None]:
### Calculate simulated voltages on each device ###
# Each Monte-Carlo trial activates one randomly chosen dipole. Every electrode
# records that dipole's lead-field response plus independent Gaussian noise at
# the device's rms level.
#
# Only len(dipoles) distinct sources exist, so the noiseless response of each
# source is computed once per device ("templates") and gathered by trial
# label. The previous version rebuilt identical gain matrices inside a
# num_trials-long loop. The noise distribution is unchanged; only the order in
# which random numbers are drawn differs.


def _source_templates(fields, num_electrodes, src_positions, src_moments):
    """Return the noiseless electrode voltages (V) produced by each source.

    Args:
        fields: lead-field array indexed as ``fields[x, y, z, :, :]`` with the
            electrode axis last (presumably shape (X, Y, Z, 3, num_electrodes)).
        num_electrodes: number of electrodes in ``fields``.
        src_positions: per-source integer voxel indices, each of shape (k, 3)
            (k == 1 for the dipoles defined above).
        src_moments: per-source dipole moments, same layout as src_positions.

    Returns:
        Array of shape (len(src_moments), num_electrodes); row i is
        ``S_i @ G_i`` where G_i is the gain matrix at source i's voxel(s) with
        NaNs replaced by 0.
    """
    templates = np.zeros((len(src_moments), num_electrodes))
    for i, (pos, moment) in enumerate(zip(src_positions, src_moments)):
        xi, yi, zi = pos.transpose()  # voxel index of each dipole in source i
        gain = np.reshape(fields[xi, yi, zi, :, :], (-1, num_electrodes))
        templates[i, :] = np.matmul(moment.flatten(), np.nan_to_num(gain))
    return templates


# Initialize labels: index of the active dipole in each trial
labels = np.random.randint(0, len(dipoles), num_trials)

# Noiseless per-source responses (V)
templates_disc = _source_templates(disc_fields, num_electrodes_disc, positions, dipoles)
templates_seeg = _source_templates(seeg_fields, num_electrodes_seeg, positions, dipoles)
templates_ecog = _source_templates(ecog_fields, num_electrodes_ecog, positions, dipoles)

# Per-trial device voltages (V): active source's response + Gaussian noise
v_disc = templates_disc[labels] + np.random.normal(scale=disc_noise*1e-6, size=(num_trials, num_electrodes_disc))
v_seeg = templates_seeg[labels] + np.random.normal(scale=seeg_noise*1e-6, size=(num_trials, num_electrodes_seeg))
v_ecog = templates_ecog[labels] + np.random.normal(scale=ecog_noise*1e-6, size=(num_trials, num_electrodes_ecog))

# uV conversion
v_disc = np.multiply(v_disc, 1e6)
# If DISC is rotated, a column-order correction may be needed, e.g.:
#   indices = np.r_[240:256, 128:240]  # last 16 channels moved to the front
#   v_disc[:, 128:256] = v_disc[:, indices]
v_seeg = np.multiply(v_seeg, 1e6)
v_ecog = np.multiply(v_ecog, 1e6)
# Multi-device cases: ECoG channels come first, then the depth device
v_disc_ecog = np.hstack((v_ecog, v_disc))
v_seeg_ecog = np.hstack((v_ecog, v_seeg))

## Classification

In [None]:
### Assign data and labels for train and test datasets ###

# Data/label assignment
X_disc = v_disc
X_seeg = v_seeg
X_ecog = v_ecog
X_disc_ecog = v_disc_ecog
X_seeg_ecog = v_seeg_ecog
y = labels

# Data splits. All device sets are simulated from the same trials (same row
# -> same active dipole), so one shared set of train/test indices is used.
# Every configuration is then scored on the same held-out trials. Drawing an
# independent split per device set added split-to-split variance to the
# cross-device comparison.
split_seed = None  # set to an int for a reproducible split
train_idx, test_idx = train_test_split(np.arange(len(y)), test_size=0.20, random_state=split_seed)


def _split(X):
    """Return (X_train, X_test, y_train, y_test) for X using the shared split."""
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


X_train_disc, X_test_disc, y_train_disc, y_test_disc = _split(X_disc)
X_train_seeg, X_test_seeg, y_train_seeg, y_test_seeg = _split(X_seeg)
X_train_ecog, X_test_ecog, y_train_ecog, y_test_ecog = _split(X_ecog)
X_train_disc_ecog, X_test_disc_ecog, y_train_disc_ecog, y_test_disc_ecog = _split(X_disc_ecog)
X_train_seeg_ecog, X_test_seeg_ecog, y_train_seeg_ecog, y_test_seeg_ecog = _split(X_seeg_ecog)

In [None]:
### Perform SEPIO ###
# Basic settings
threshold = 1  # NOTE(review): not passed to SSPOC or update_sensors in this cell
l1_penalty = 1e-3
# Basis settings template. SSPOC.fit() fits the basis object it was given in
# place and keeps a reference to it, so each model gets its own deep copy.
# Handing this one instance to all five models made every model's `.basis`
# alias the object fitted last (on the SEEG+ECoG data).
basis = ps.basis.SVD(n_basis_modes=16)

# Sensor count step size and maximum
step = 8
max_sensors = 256
sensor_range = np.arange(step, max_sensors+1, step)


def _fit_sspoc(X_train, y_train):
    """Fit an SSPOC classifier on (X_train, y_train) with a private copy of `basis`."""
    from copy import deepcopy
    return SSPOC(l1_penalty=l1_penalty, basis=deepcopy(basis)).fit(X_train, y_train)


# Train datasets on SEPIO
disc_model = _fit_sspoc(X_train_disc, y_train_disc)
seeg_model = _fit_sspoc(X_train_seeg, y_train_seeg)
ecog_model = _fit_sspoc(X_train_ecog, y_train_ecog)
disc_ecog_model = _fit_sspoc(X_train_disc_ecog, y_train_disc_ecog)
seeg_ecog_model = _fit_sspoc(X_train_seeg_ecog, y_train_seeg_ecog)
sensors_optimal = disc_ecog_model.selected_sensors  # Specific sensor set to DiSc_ECoG

# Accuracy arrays: entry i is the test accuracy using sensor_range[i] sensors
accs_disc = []
accs_seeg = []
accs_ecog = []
accs_disc_ecog = []
accs_seeg_ecog = []

# (model, X_train, y_train, X_test, y_test, electrodes available, accuracy list)
experiments = [
    (disc_model, X_train_disc, y_train_disc, X_test_disc, y_test_disc,
     num_electrodes_disc, accs_disc),
    (seeg_model, X_train_seeg, y_train_seeg, X_test_seeg, y_test_seeg,
     num_electrodes_seeg, accs_seeg),
    (ecog_model, X_train_ecog, y_train_ecog, X_test_ecog, y_test_ecog,
     num_electrodes_ecog, accs_ecog),
    (disc_ecog_model, X_train_disc_ecog, y_train_disc_ecog, X_test_disc_ecog, y_test_disc_ecog,
     num_electrodes_disc + num_electrodes_ecog, accs_disc_ecog),
    (seeg_ecog_model, X_train_seeg_ecog, y_train_seeg_ecog, X_test_seeg_ecog, y_test_seeg_ecog,
     num_electrodes_seeg + num_electrodes_ecog, accs_seeg_ecog),
]

# Test accuracy on each device set and store accuracies
for s in sensor_range:
    for model, X_tr, y_tr, X_te, y_te, n_available, accs in experiments:
        # Make sure we don't try to use more electrodes than each device set has available
        if s > n_available:
            continue
        # Re-select the top-s sensors and refit the classifier on them
        model.update_sensors(n_sensors=s, xy=(X_tr, y_tr), quiet=True)
        y_pred = model.predict(X_te[:, model.selected_sensors])
        accs.append(metrics.accuracy_score(y_te, y_pred))

In [None]:
### Plot accuracy results ###
# Each accuracy list covers only sensor counts up to that configuration's
# electrode total, i.e. the first (total // step) entries of sensor_range.
accuracy_curves = (
    ('DISC', num_electrodes_disc, accs_disc),
    ('SEEG', num_electrodes_seeg, accs_seeg),
    ('DISC+ECOG', num_electrodes_disc + num_electrodes_ecog, accs_disc_ecog),
    ('SEEG+ECOG', num_electrodes_seeg + num_electrodes_ecog, accs_seeg_ecog),
    ('ECOG', num_electrodes_ecog, accs_ecog),
)
for curve_label, n_available, curve_accs in accuracy_curves:
    plt.plot(sensor_range[:n_available // step], curve_accs, '-o', label=curve_label)

plt.xlabel('Number of sensors')
plt.ylabel('Accuracy')
plt.ylim((0, 1))
plt.title('Accuracy vs. # Sensors (Test)')
plt.legend()
plt.savefig(path.join(folder, 'outputs', '10_SEPIO_sulcus.png'), dpi=300)

In [None]:
### Print benchmark points ###
# {device} @ {number_of_sensors}: {classification_accuracy}
# The sensor count is looked up from sensor_range rather than typed by hand.
# The hard-coded labels had drifted from the entries actually printed:
# accs_disc_ecog results were labelled "SEEG+ECoG", and with 121 ECoG contacts
# and step 8 the last ECoG entry is the 120-sensor result, not 128.


def _print_benchmark(device, accs, index):
    """Print '<device> @ <n_sensors>: <accuracy>' for ``accs[index]``.

    The evaluation loop appends one accuracy per sensor_range entry until the
    configuration runs out of electrodes, so ``accs`` aligns with the leading
    ``len(accs)`` entries of sensor_range. Negative indices therefore count
    back from that configuration's own last entry.
    """
    n_sensors = sensor_range[:len(accs)][index]
    print(f"{device} @ {n_sensors}:", np.round(accs[index], 2))


_print_benchmark("SEEG", accs_seeg, -1)
_print_benchmark("DiSc", accs_disc, -1)
_print_benchmark("ECoG", accs_ecog, 5)
_print_benchmark("ECoG", accs_ecog, -1)
_print_benchmark("SEEG+ECoG", accs_seeg_ecog, -1)
_print_benchmark("DiSc+ECoG", accs_disc_ecog, 17)
_print_benchmark("DiSc+ECoG", accs_disc_ecog, -1)