# Phantom Processing with SEPIO
This script processes recorded phantom data, computes the equivalent simulation, generates comparative metrics, and performs SEPIO on the resulting voltage arrays.

## Load Libraries & Primary Settings

In [None]:
### Libraries ###
import matplotlib.pyplot as plt
import numpy as np
from sklearn import metrics
from sklearn.model_selection import train_test_split
from os import path
import sys
import pysensors as ps
from pysensors.classification import SSPOC
import warnings
from modules.leadfield_importer import FieldImporter
import scipy as sp
from scipy import signal
from scipy.io import savemat
from scipy.fftpack import rfft, fftfreq
import math

sys.path.insert(0, path.join('.'))

In [None]:
### Phantom Data Settings ###
# Targets phantom data files to be loaded later

# Change target directory
folder = r"...\SEPIO_dataset"
output = path.join(folder,'outputs')
datadir = path.join(folder,'phantom')
# Acquire amplifier data file from "Format_data-v1.m" or "noise_analysis..._v7.m", run, go to variables, and save variable
file = ['dev1-d1','dev2-d1',
        'dev1-d2','dev2-d2',
        'dev1-d3','dev2-d3',
        'dev1-d4','dev2-d4',
        'dev1-d5','dev2-d5',
        'dev1-d6','dev2-d6',
        'dev1-d7','dev2-d7',
        'dev1-d8','dev2-d8',
        'dev1-d9','dev2-d9',
        'dev1-d10','dev2-d10',
        'dev1-d11','dev2-d11'] # Point to file names; !! Must be in device pairs per dipole !!
# Enforce the pairing requirement up front instead of failing later while stacking
if len(file) % 2 != 0:
    raise ValueError("`file` must list recordings in (device 1, device 2) pairs per dipole.")
dev_flip = np.array([False,True]) # Whether device channels need to be flipped (index 0: device 1, index 1: device 2)
labels = np.arange(0,len(file)//2) # !! Make sure labels match the order of dipoles !!

# Full path of each recording, in the same order as `file`
dataset = [path.join(datadir, f + '.mat') for f in file]

# Update bad channels to be removed for any new devices used
drop = np.array([63,88,127]) # Includes both devices (index 0-127)

# Choose downsampling
original_sampling_rate = 20000 # Hz
desired_downsample = 2000 # Hz

# Choose filter
# 2-119 Hz is generally clean data; use low frequencies for slow wave activity (1 Hz square)
band = True
band_freq = [2.,118.] # Hz (+-1); converted to normalized frequency
filter_order = 5 # Scipy butter filter order

# Common average referencing; applies to both simulation and phantom
car = True

# Alignment of the two devices recorded stimulation
# Shifts the start position to `alignPhase` radians after the signal maximum found
# within one (averaged) period, using a cosine convention (maximum = 0 rad)
align = True
alignPhase = np.pi/2 # Phase in radians relative to maximum (after); assumes a cosine wave
alignFreq = 20 # Hz; what frequency is the main signal
startcut = 5 # s; time cut from the front of signal
period = desired_downsample//alignFreq # Index count for one full period
if desired_downsample%alignFreq != 0:
    # A non-integer period makes the integer `period` drift against the real signal
    warnings.warn("WARNING: Make the downsample rate an integer multiple of the driven frequency (alignFreq)!")

# Reduce overall signal power
red_factor = 0.2 # Linear scale on total power to match typical DiSc signal magnitude

In [None]:
### Simulation Settings ###
# Loads lead field files

# CHANGE: Select file
device = 'DISC'  # Options: {'DISC', 'IMEC', 'SEEG'}
fields_file = path.join(folder,'leadfields', 'DISC_30mm_p2-5Sm_MacChr.npz')

# Get fields
field_importer = FieldImporter()
field = field_importer.load(fields_file)
fields = field_importer.fields
num_electrodes = np.shape(fields)[4] # Electrode count is the 5th axis of the lead field
scale = 0.5 # mm; voxel size

# Common average reference is defined above by the `car` variable

# Define dipoles
magnitude = 20e-9 # Amp*(voxel size in mm); *2 for peak-to-peak; Divide by dipole separation (in mm) to get dipole in nA*m; 20e-9/4e-3 = 5mA (10M)
min_voltage = 0  # uV
noise = 0 # uV rms; typically 4.1 uV; Only applied during voltage calculation; Can be applied in post
num_trials = 10**3

# Generate Monopolar Leadfield
# Vector magnitude of each vector <x,y,z> (axis 3) yields a Volt/Amp LF
monofield = np.sqrt(fields[:,:,:,0,:]**2 + fields[:,:,:,1,:]**2 + fields[:,:,:,2,:]**2)
print(monofield.shape)

# Drop channels to make a 64-channel device
# Keep every other channel (0, 2, ..., 126) to pattern the 128-ch lead field into a 64-ch lead field
# Sets of two rows dropped
dropCh = np.arange(64)
monofield64 = monofield[:,:,:,dropCh*2].astype(float)

# Bad channel removal; MUST match above drop for phantom
drop1 = np.array([63]) # device 1
drop2 = np.array([24,63])+64 # device 2
# np.array_equal compares shape and values and returns False (instead of raising)
# when the lengths differ; note `.all` without a call would always be truthy.
if not np.array_equal(drop, np.append(drop1,drop2)):
    print("Error: Bad channel drops for phantom and simulation do not match!")


## Phantom Data Processing

### Load and pre-process

In [None]:
### Functions to load and clean data ###


def load_clean(dat):
    """Load one amplifier recording and pre-process it.

    Steps, each controlled by module-level settings: decimate from
    `original_sampling_rate` to `desired_downsample`; band-pass filter when
    `band` is true (`band_freq`, `filter_order`); common average reference when
    `car` is true; and, when `align` is true, trim the front so the recording
    starts `alignPhase` radians after the signal maximum (`startcut`, `period`).

    Parameters
    ----------
    dat : str
        Path to a .mat file holding an `amplifier_data` variable. Axis 0 is
        treated as channels and the last axis as time.

    Returns
    -------
    final : ndarray
        Filtered (if `band`), referenced (if `car`) and trimmed (if `align`)
        data, shape (channels, samples).
    time : ndarray
        Time vector in seconds with the same number of samples as `final`.
    data : ndarray
        Like `final` but without the band-pass filter.
    fft : ndarray
        `scipy.fftpack.rfft` of the decimated data, taken before referencing
        and trimming.
    freq : ndarray
        `fftfreq` values for the untrimmed length.
    carVal : ndarray
        Per-sample common average removed from `final` (zeros if `car` is
        False); untrimmed.

    Raises
    ------
    KeyError
        If the file has no `amplifier_data` variable.
    ValueError
        If the recording is too short for the alignment averaging window.
    """
    print("Starting load and clean:",dat)
    mat = sp.io.loadmat(dat)
    if 'amplifier_data' not in mat:
        raise KeyError(f"'amplifier_data' not found in {dat}")
    data = mat['amplifier_data']

    # Downsample (decimate applies an anti-aliasing filter before subsampling)
    factor = int(original_sampling_rate/desired_downsample)
    data = signal.decimate(data,factor,axis=-1)

    # Assign time array
    total_time = data.shape[-1]/desired_downsample
    time = np.linspace(0,total_time,data.shape[-1])

    # Filtering w/ butterworth (second-order sections); honor the `band` switch
    if band:
        low = band_freq[0]/(0.5*desired_downsample)
        high = band_freq[1]/(0.5*desired_downsample)
        sos = signal.butter(filter_order,[low,high],analog=False,btype='bandpass',output='sos')
        final = signal.sosfilt(sos,data)
    else:
        final = np.copy(data)

    # FFT of the decimated (unfiltered, unreferenced) data
    # NOTE(review): the sample spacing passed to fftfreq is doubled; confirm this is
    # intended for scipy.fftpack's packed rfft layout.
    length = data.shape[1]
    freq = fftfreq(length,d=(time[1]-time[0])*2)
    fft = rfft(data)

    # Apply common average referencing: subtract the across-channel mean at every sample
    carVal = np.zeros((data.shape[1]))
    if car:
        print("~CAR Selected~")
        carVal = np.nanmean(final,axis=0)
        data -= np.nanmean(data,axis=0)
        final -= carVal

    # Return bad channels to zero after CAR
    data = np.nan_to_num(data)
    final = np.nan_to_num(final)

    # Align if chosen
    if align:
        print("~Alignment Selected~")

        # Start with the desired initial cut
        t = startcut*desired_downsample

        # Sum data over `avgperiods` periods to average out noise
        avgperiods = 20 # Increase to improve averaging. Too large may cause errors due to phase shift over time.
        needed = t + avgperiods*period
        if needed > final.shape[1]:
            raise ValueError(f"Recording too short to align: need {needed} samples after decimation, have {final.shape[1]}")
        avgdata = np.zeros((final.shape[0],period))
        for p in range(avgperiods):
            avgdata += final[:,t+p*period:t+(p+1)*period]

        # RMS across channels at each sample of the averaged period. Squaring discards
        # sign, so the selected extreme may be the positive or the negative peak.
        rms = np.sqrt(np.mean(np.square(avgdata),axis=0))
        # Find point of highest signal and adjust to that index
        maxSignalIndex = np.argmax(rms)
        initialPhase = (2*np.pi)*(period-maxSignalIndex)/period
        print("Initial phase (rad.):",initialPhase)

        # Add phase shift to desired start, wrapped into one period
        phaseShift = maxSignalIndex + int(alignPhase*period/(2*np.pi))
        if phaseShift >= period:
            phaseShift -= period
        print("Phase shift (index):",phaseShift)
        t += phaseShift

        # Adjust to new start index; explicit length keeps this correct even when t == 0
        print('Trimming front to index:',t)
        time = time[:time.shape[0]-t]
        data = data[:,t:]
        final = final[:,t:]

    # Observe data shape and removed channels
    print("Shapes should match:", '\n', "Data - ",data.shape)
    return final, time, data, fft, freq, carVal

In [None]:
### Run load and clean function; Trim data; Combine per device ###

# Recordings must come in (device 1, device 2) pairs per dipole for the stacking below
if len(dataset) % 2 != 0:
    raise ValueError("Expected an even number of recordings: one (device 1, device 2) pair per dipole.")

# Loop over all files provided for preprocessing
length = np.array(())
data = []
carVals = []
for i,d in enumerate(dataset):
    #Load data
    f, t, _, _, _, carVal = load_clean(d)

    # Flip electrode order if necessary; even file indices are device 1, odd are device 2
    if dev_flip[i%2]:
        f = np.flip(f,axis=0)

    # Add to arrays
    data.append(f)
    length = np.append(length,int(t.shape[0]))
    carVals.append(carVal)

# Trim files down to shortest length
minlen = int(np.min(length))
for i,d in enumerate(data):
    data[i] = data[i][:,:minlen]

# Stack data for 2-device pairs (pairings of 0&1, 2&3, ...): device 1 rows first, then device 2
# Stacked is a list of Numpy arrays list[ (ch,timepoints) , (ch,timepoints) , ...] for each dipole, ID'd by label[index]
stacked = [np.vstack((data[i],data[i+1])) for i in range(0,len(data),2)]

# Check results
print('In total:\nData is of the shape:', stacked[0].shape,'\nTime (s):',minlen/desired_downsample,'\nLengths are:',length)


In [None]:
### Re-order data to column-wise  and drop channels ###

# Old order from my own script (unknown source)
# Recorded channel i within a 64-channel device is moved to row order[i]
order = np.array([28,24,20,16,12,8,4,0,
                  2,6,10,14,18,22,26,30,
                  31,27,23,19,15,11,7,3,
                  1,5,9,13,17,21,25,29,
                  35,39,43,47,51,55,59,63,
                  61,57,53,49,45,41,37,33,
                  32,36,40,44,48,52,56,60,
                  62,58,54,50,46,42,38,34])

# np.copy turns the list of equal-shaped arrays into one (dipole, channel, time) array;
# reads below come from `stacked`, so partially written rows never feed later moves.
ordered = np.copy(stacked)
for d in np.arange(len(ordered)): # Each dipole
    for i,n in enumerate(order): # Sort by order
        if dev_flip[0]: # Flip sorting for device 1 (rows 0-63)
            ordered[d][63-n] = stacked[d][63-i] # Sort device 1
        else:
            ordered[d][n] = stacked[d][i] # Sort device 1
        if dev_flip[1]: # Flip sorting for device 2 (rows 64-127)
            ordered[d][127-n] = stacked[d][127-i] # Sort device 2
        else:
            # Device 2 occupies rows 64-127, so its base offset is 64
            ordered[d][64+n] = stacked[d][64+i] # Sort device 2

# Mark bad channels with NaN so later steps can ignore or zero them
for d in np.arange(len(ordered)):
    for ch in drop:
        ordered[d][ch,:] = np.nan

In [None]:
### Option to save filtered data
save = False
# Separate name so the dataset root `folder` from the settings cell is not overwritten
save_folder = r"C:\Users\willi\Documents\NEI\Phantom\P02.26.2024-filtered-noCAR"

if save:
    for dev in range(2): # dev for each device ID
        for d in range(len(ordered)): # d for each dipole ID
            # Slice end is exclusive: (dev+1)*64 keeps all 64 channels of this device
            savemat(path.join(save_folder,f"d{d+1}-dev{dev+1}.mat"),{"array":ordered[d][dev*64:(dev+1)*64,:]})

In [None]:
###[Settings] Visualize a few cycles; All channels ###
# Overlays `visPeriods` periods of the column-ordered data for each dipole in `dipole`.
visPeriods = 4 # 2-4 looks good (window scales off of `period`, so alignFreq must be correct)
skipSec = 0 # how many seconds to skip
dipole = [0,1,2,3,4,5,6,7,8,9,10] # Dipole index
visChannels = np.array([0,8,24,32]) # Show only specific channels
visChannels = np.array([-1]) # Sentinel: a negative sum plots all channels; comment out to use the list above. Note - very messy

skipTime = skipSec*desired_downsample # First sample index to plot
t = visPeriods*period # Number of samples to plot
# Time axis (s): t*alignFreq/visPeriods == period*alignFreq, i.e. the sampling rate when it divides evenly
T = np.arange(t)/(t*alignFreq/visPeriods) + skipSec
if np.sum(visChannels) < 0:
    for d in dipole:
        x = np.arange(stacked[d].shape[0]) # array of all channel IDs (stacked and ordered share a shape)
        for ch in x:
            plt.plot(T,ordered[d][ch,skipTime:skipTime+t])
else:
    for d in dipole:
        for ch in visChannels:
            plt.plot(T,ordered[d][ch,skipTime:skipTime+t])
plt.xlabel("Time (s.)")
plt.ylabel("Channel magnitude (uV)")
plt.show()

In [None]:
###[Settings] Spectral assessment
# Check if filters are appropriate to remove excess noise
# One subplot row per dipole (11 hard-coded), one column per device; every channel's
# PSD is overlaid using matplotlib's `psd` estimator.

# Plot assignment
fig, ax = plt.subplots(11,2,figsize=(10,20))

for dipole in range(11):
    # Calculate FFT
    # Note: these hold time-domain channels (rows 0-63 device 1, 64-127 device 2);
    # `psd` performs the spectral estimate. `frequencies` is not used below.
    fft_dev1 = ordered[dipole][:64,:]
    fft_dev2 = ordered[dipole][64:,:]
    frequencies = np.linspace(0,desired_downsample/2,fft_dev1.shape[1])

    # FFT plot
    for ch in range(64): # ch for each channel
        ax[dipole,0].psd(fft_dev1[ch,:],Fs=desired_downsample,NFFT=1024)
        ax[dipole,1].psd(fft_dev2[ch,:],Fs=desired_downsample,NFFT=1024)
    # Figure text
    for i in range(2): # i for each plot
        ax[dipole,i].set_title(f"(Dipole {dipole}, Device {i+1})")
        ax[dipole,i].set_ylabel("Power/Frequency\n(dB/Hz)")
        ax[dipole,i].set_xlabel("Frequency (Hz)")
        ax[dipole,i].set_xlim([0,200])
        ax[dipole,i].set_ylim([-60,60])

fig.suptitle(f"Power Spectral Density")
#plt.subplots_adjust(wspace=.1,hspace=1)
plt.tight_layout()
plt.show()

In [None]:
###[Settings] Peak voltage; Column-wise sorted ###
# Averages, over `duration` periods, the sample a quarter period into each period for one
# device, then plots it as 8 columns of 8 channels.
dipole = 0 # Dipole index, not ID
device = 1 # Device index (note: overwrites the lead-field `device` name from the simulation settings)
skipSec = 0 # Number of seconds to skip
duration = 1 #periods (integer)
phaseSh = skipSec*desired_downsample # Skip expressed in samples
channels = 64 # how many channels per device

values = np.zeros((ordered[0][:channels,0].shape))
for i in range(duration):
    values += ordered[dipole][device*channels:(device+1)*channels,phaseSh+i*period+period//4]
    # Change in shifted relative to ordered
    #vo = ordered[dipole][:channels,phaseSh+i*period+period//4]
    #vs = shifted[dipole][:channels,phaseSh+i*period+period//4]
    #values += (vs-vo)
values = values/(duration)


# Split data into columns and color code
# Each column is 8 consecutive channels; dropped (NaN) channels appear as gaps
colors = plt.cm.rainbow(np.linspace(0,1,channels//8))
y = np.zeros((channels//8,8))
x = np.zeros((channels//8,8))
for i in range(channels//8):
    y[i] = values[i*8:(i+1)*8]
for i in range(channels//8):
    x[i] = np.arange(i*8,(i+1)*8)

# Plotting
for i in range(channels//8):
    plt.scatter(x[i],y[i],color = colors[i])
    plt.plot(x[i],y[i],color = colors[i])
plt.xlabel("Channels (#)")
plt.ylabel("Peak voltage (@1/4 period) (uV)")
plt.show()

In [None]:
###[Settings] Peak voltage - both devices; Column-wise sorted ###
# Same as the single-device view but over all 128 channels (16 columns of 8 channels).
dipole = 0 # Dipole index, not ID
skipSec = 0 # Number of seconds to skip
duration = 100 #periods (integer)
phaseSh = skipSec*desired_downsample # Skip expressed in samples

values = np.zeros((ordered[0][:,0].shape))
for i in range(duration):
    values += ordered[dipole][:,phaseSh+i*period+period//4]
    # Change in shifted relative to ordered
    #vo = ordered[dipole][:channels,phaseSh+i*period+period//4]
    #vs = shifted[dipole][:channels,phaseSh+i*period+period//4]
    #values += (vs-vo)
values = values/(duration)


# Split data into columns and color code
# `y` and `colors` are reused by the aligned-depth plot
colors = plt.cm.rainbow(np.linspace(0,1,16))
y = np.zeros((16,8))
x = np.zeros((16,8))
for i in range(16):
    y[i] = values[i*8:(i+1)*8]
for i in range(16):
    x[i] = np.arange(i*8,(i+1)*8)

# Plotting
for i in range(16):
    plt.scatter(x[i],y[i],color = colors[i])
    plt.plot(x[i],y[i],color = colors[i])
plt.xlabel("Channels (#)")
plt.ylabel("Peak voltage (@1/4 period) (uV)")
plt.show()

In [None]:
### Plotting with aligned depth ###
# Overlays the 16 columns (from `y`/`colors` of the previous cell) against row index 0-7
for i in range(16):
    plt.scatter(range(8),y[i],color = colors[i])
    plt.plot(range(8),y[i],color = colors[i])
plt.xlabel("Row/depth")
plt.ylabel("Peak voltage (@1/4 period) (uV)")
plt.show()

### Sort Data for SEPIO

In [None]:
### SEPIO on a snapshot of peak values for each cycle ###
# Extract peak of each cycle in the form (128,) and use that instead of the whole period
# This makes determination of important electrodes much easier

# Count trials per label (whole periods available; all dipoles share the trimmed length)
count = (ordered[0].shape[1]//period)
# Count total trials
count_total = count*len(ordered)

# Data of shape (trial #, electrode #) e.x. (10000,128), and labels of shape (trial #)
X = np.zeros((count_total,ordered[0].shape[0]))
Y = np.zeros((count_total))

# Sample each period a quarter period (pi/2 rad) after its aligned start
offset = period//4
for i in range(len(ordered)):
    # i is the dipole index, used directly as the class label
    for n in np.arange(count):
        sample = count*i + n
        Y[sample] = i
        X[sample] = ordered[i][:,offset+period*n]
yp = Y.astype(int)

# Remove NaN and scale down signal
# NaN (dropped) channels become 0; voltages are multiplied by red_factor (power scales by red_factor**2)
Xp = np.nan_to_num(X)*red_factor

## Simulation Data Processing

### Generate Voltage Map

In [None]:
###[Settings] Visualize one axis of the Leadfield ###
# Used to determine centrality of fields relative to z axis
# Plots every channel's monopolar lead field along z at the fixed voxel (x, y).
# Note: overwrites the globals x and y used by the peak-voltage plots.
x,y,z = 22,30,10
# /1000 converts the V/A lead field to uV/nA
plt.plot(monofield64[x,y,:,:]/1000)
plt.xlabel('Chosen axis')
plt.ylabel('Lead Field (uV/nA)')

In [None]:
### Set up device Leadfields ###
# Both devices currently use the same 64-channel monopolar lead field; separate copies
# keep the slots independent so a different device type can be assigned to either.
LF1 = np.copy(monofield64) # device 1 LF
LF2 = np.copy(monofield64) # device 2 LF

In [None]:
### Define Phantom Locations ###
"""
All locations are defined by <x,y,z> (3,) relative to global axes
In phantom #1 (2x MacChr, 11x dipoles, cortical sulcus arrangement),
    images are taken en-face from far above with the central dipole toward the camera.
    Locations are measured in ImageJ relative to the centerline at the base of the mount.
    Positive x is measured right, toward device 1.
    Positive y is measured manually by caliper - depth of dipoles in the en-face images.
        -3mm, 0mm, or 3mm in the first phantom setup.
    Positive z is measured normal to the mount internal surface.
Devices are measured by a starting and ending point of the electrode span.
These relative locations are then used to align the dipole lead locations relative to each device and
    assign the corresponding voxel location for each lead field.
Device probe rotation is taken into account last based on pre-mounting measurements of probe rotation
    relative to PCB enclosure.
"""
### Define devices
# N Devices ((top_x,top_y,top_z),(bottom_x,bottom_y,bottom_z)) in mm (top=proximal, bottom=distal)
dev_pos = np.array([[[3.22,0.30,4.04],[0.47,0.66,19.29]],
                    [[-4.43,0.52,3.27],[-2.24,1.97,17.94]]
                    ],dtype=float)

# Device on-axis rotation; Measured value; Assigned in degrees (converted to radians)
#!!! This MUST be assigned by viewing down the probe (from distal perspective) and measure
#   rotation counter-clockwise. Both device angles must account for column 1 aligning with +x axis
#   e.g. to the right when imaging from the side with dipole 6 protruding toward the camera.
dev_rot = np.array([0,180])*np.pi/180

# Offset vectors from center of dev_pos locations to center of leadfield file; One per device
# MacChr is offset z-9, so z+9 is to the LF center
#! Currently (1/26/24) only uses z axis as an on-axis offset; defined in voxels, then converted to mm by `scale` (mm per voxel)
dev_offset = np.array([[0,0,9],[0,0,9]],dtype=float)*scale

# Assign device leadfields; useful when mixing device types
# Stacked into one array, so every device lead field must share the same shape
dev_lf = np.array([LF1,LF2],dtype=float)

# Check that data is provided for all devices; does not check each variable individually
if dev_pos.shape[0] != dev_rot.shape[0] or dev_pos.shape[0] != dev_offset.shape[0] or dev_pos.shape[0] != dev_lf.shape[0]:
    print("Device information appears incorrect. Shape mismatch.")

### Define dipoles
# N Dipoles (((pos_x,pos_y,pos_z),(neg_x,neg_y,neg_z)),...) in mm
dipoles = np.array([
    [[1.59,4.60,9.46],[5.51,3.97,11.10]],#dipole 1
    [[2.58,3.67,3.01],[5.29,3.97,6.62]], #dipole 2, etc. 
    [[0.47,4.60,10.77],[2.11,3.04,13.51]],
    [[1.08,-2.90,10.45],[4.26,-2.19,13.68]],
    [[1.78,-3.67,6.02],[4.39,-2.82,8.45]],
    [[-0.43,0.77,13.10],[-1.03,1.26,17.08]],
    [[-1.74,-2.71,9.29],[-6.02,-3.26,9.85]],
    [[-1.42,-3.34,3.14],[-4.47,-3.10,7.01]],
    [[-1.81,-2.99,12.13],[-4.26,-3.26,15.27]],
    [[-1.59,4.41,9.74],[-5.53,4.49,10.32]],
    [[-1.23,3.70,2.95],[-3.20,3.73,6.65]]
],dtype=float)
# Dipole magnitudes; monopoles can be made by assigning a zero magnitude here to one of the pair
dipmag = np.array([
    [-magnitude,magnitude]
    ])
# Repeat dipole magnitudes for every dipole if a matching pattern is desired
dipmag = np.repeat(dipmag,dipoles.shape[0],axis=0)
# Check that dipoles each have a magnitude assigned
if dipoles.shape[0] != dipmag.shape[0]:
    print("Number of dipole location and magnitudes mismatch.")

# Function to convert locations to a device reference frame and LF indices
def translate(devnum):
    """Express every dipole pole relative to the midpoint of device `devnum`.

    Reads the module-level `dipoles` (mm, indexed [dipole, pole, axis]) and
    `dev_pos`. Each monopole is shifted on its own (no rotation here), so the
    later rotation step can treat the poles as independent points.

    Returns a new array; `dipoles` itself is left untouched. The reference
    point is the midpoint of the measured electrode span, not the lead-field
    midpoint.
    """
    top = dev_pos[devnum, 0]
    bottom = dev_pos[devnum, 1]
    span_mid = (bottom + top) / 2

    shifted = np.copy(dipoles)
    shifted[...] = dipoles - span_mid
    return shifted

# Function to rotate the device reference frame on-axis to the device
def rotate(translated,devnum):
    """Rotate device-frame monopole positions onto the lead-field grid and snap them to voxels.

    First the positions are rotated by the device's x-z tilt (derived from
    `dev_pos`), then by its on-axis rotation `dev_rot` in the x-y plane. They
    are then shifted by `dev_offset` (mm), converted to voxels with `scale`
    (mm per voxel) and referenced to the lead-field corner.

    Parameters
    ----------
    translated : ndarray, indexed [dipole, pole, axis]
        Positions in mm relative to the device span midpoint, as returned by
        `translate`. Not modified.
    devnum : int
        Device index into `dev_pos`, `dev_rot`, `dev_offset` and `dev_lf`.

    Returns
    -------
    rotated_vox : ndarray of int
        Nearest voxel index of each monopole. Bounds are not checked here.
    displacement : ndarray, indexed [dipole, pole]
        Distance in um each monopole moved when snapped to its voxel.
    """
    rotated = np.copy(translated)

    # Device tilt in the x-z plane from its top/bottom end points
    xz_angle = -np.arctan((dev_pos[devnum,1,2]-dev_pos[devnum,0,2])/(dev_pos[devnum,1,0]-dev_pos[devnum,0,0]))
    # On-axis rotation of the device within its enclosure (radians)
    xy_angle = dev_rot[devnum]

    # Lead-field midpoint in voxels. It is added after positions are divided by `scale`,
    # so it must be a voxel count (shape/2), not mm; float avoids truncating it.
    LF_mid = np.asarray(dev_lf[devnum].shape[0:3],dtype=float)/2

    for i in range(rotated.shape[0]): # for each dipole pair
        for k in range(rotated.shape[1]): # for each monopole in the dipole pair
            x0, y0, z0 = rotated[i,k]
            # Squared length before rotation, used as a sanity check below
            vec = x0**2 + y0**2 + z0**2
            # Rotate in XZ plane
            x1 = x0*np.cos(xz_angle) + z0*np.sin(xz_angle)
            z1 = -x0*np.sin(xz_angle) + z0*np.cos(xz_angle)
            # Rotate in XY plane
            x2 = x1*np.cos(xy_angle) + y0*np.sin(xy_angle)
            y2 = -x1*np.sin(xy_angle) + y0*np.cos(xy_angle)
            rotated[i,k] = (x2, y2, z1)
            # Rotations preserve length; a change indicates an error
            if abs(vec-(x2**2 + y2**2 + z1**2)) > 0.5:
                print("!!! A dipole vector location has changed magnitude!")

    # Adjust to LF midpoint based on defined offset (mm); Now aligned to z-axis
    rotated -= dev_offset[devnum]

    # Convert mm to LF voxels relative to the bottom corner
    rotated = rotated/scale + LF_mid

    # Convert to integers to act as voxel indices
    rotated_vox = np.round(rotated).astype(int)

    # Displacement error for moving monopoles to voxel centers, scaled to um
    displacement = rotated - rotated_vox
    displacement = np.sqrt(displacement[:,:,0]**2 + displacement[:,:,1]**2 + displacement[:,:,2]**2)
    displacement = np.round(displacement*scale*1000,1)

    return rotated_vox, displacement

In [None]:
### Translate and rotate sources ###
dev0_dipoles = translate(devnum=0)
dev1_dipoles = translate(devnum=1)
dev0_dipoles,dev0_displacement = rotate(dev0_dipoles,devnum=0)
dev1_dipoles,dev1_displacement = rotate(dev1_dipoles,devnum=1)

### Catch out-of-bounds monopoles ###
# A monopole outside the lead-field grid is clamped onto the nearest border voxel
# (editing dev0_dipoles/dev1_dipoles in place through `dips`). Its magnitude is reduced
# linearly with the distance moved and reaches zero at 6 voxels.
# Without clamping, negative indices would silently wrap in the voltage calculation.
# Note: `dipmag` has no device axis, so a reduction triggered by either device
# applies to both devices' voltages.
dips = [dev0_dipoles,dev1_dipoles]
for i,dip in enumerate(dips):
    lf_shape = dev_lf[i].shape[:3]
    for d in range(dip.shape[0]): # For each dipole 'd'
        for m in range(dip.shape[1]): #For each monopole 'm'
            for k in range(dip.shape[2]): #For each axis 'k'
                last = lf_shape[k] - 1 # Largest valid voxel index on this axis
                if dip[d,m,k] < 0: # If below zero plane
                    dif = -dip[d,m,k] # Distance outside bounds
                    dip[d,m,k] = 0
                elif dip[d,m,k] > last: # If above max plane
                    dif = dip[d,m,k] - last # Distance outside bounds (positive)
                    dip[d,m,k] = last
                else:
                    continue
                dipmag[d,m] = dipmag[d,m]*(1-np.min([1,dif/6]))
                print("D",d,"m",m,"moved by",dif)

In [None]:
### Print displacement error values ###
# This is the distance each monopole point is moved to the center of the nearest voxel in the leadfield space
# (values come from `rotate`; moves made by the out-of-bounds clamping are not included).
# Assumes 11 dipoles.
print('Device 1')
for d in range(11):
    print(f'D{d}-1: {dev0_displacement[d,0]} um; D{d}-2: {dev0_displacement[d,1]} um')
print('\nDevice 2')
for d in range(11):
    print(f'D{d}-1: {dev1_displacement[d,0]} um; D{d}-2: {dev1_displacement[d,1]} um')
print(f'\nMean displacement\nDevice 1: {round(np.mean(dev0_displacement))} um\nDevice 2: {round(np.mean(dev1_displacement))} um')

In [None]:
### Calculate voltages ###
# Superposes magnitude * lead-field gain of every monopole at its voxel for both devices:
# columns [0, LF1.shape[-1]) are device 1 and the remaining columns device 2.
# Note: negative voxel indices would wrap silently (NumPy indexing), so out-of-bounds
# monopoles must already be clamped.

# Assign arrays
# (overwrites the `num_electrodes` and `labels` set in earlier cells)
num_electrodes = LF1.shape[-1]*2
v = np.zeros((num_trials*dev0_dipoles.shape[0], num_electrodes))
labels = np.zeros((num_trials*dev0_dipoles.shape[0]))

# Cycle through all sources and trials to fill voltage matrix
# Rows d*num_trials ... (d+1)*num_trials-1 belong to dipole d; they are identical
# copies unless the noise line below is enabled.
for d in range(dev0_dipoles.shape[0]): # For each dipole 'd'
    for t in range(num_trials): # For each trial 't'
        for m in range(dev0_dipoles.shape[1]): #For each monopole 'm'

            # Make device voltages
            # G0/G1: per-electrode gain of each device at this monopole's voxel
            G0 = LF1[dips[0][d,m,0], dips[0][d,m,1], dips[0][d,m,2], :]
            G1 = LF2[dips[1][d,m,0], dips[1][d,m,1], dips[1][d,m,2], :]
            #G = np.reshape(monofield[x, y, z, :], (-1, num_electrodes))
            G0 = np.nan_to_num(G0)
            G1 = np.nan_to_num(G1)

            v[d*num_trials+t, :LF1.shape[-1]] += dipmag[d,m]*G0
            v[d*num_trials+t, LF1.shape[-1]:] += dipmag[d,m]*G1
        
        # Save labels
        labels[d*num_trials+t] = d
labels = labels.astype('int')

v = np.multiply(v, 1e6)  # uV conversion
#v += np.random.normal(scale=noise,size=v.shape) # Noise profile to all electrodes
print("v shape:", v.shape)
print("labels shape:", labels.shape)
print("Unique labels:",np.unique(labels))
# Untouched copies taken before channel drops, polarity correction and CAR are applied to `v`
v2 = np.copy(v)
v3 = np.copy(v)

In [None]:
### Apply NaN to drop channel voltages ###
# Dropping channels that are not suitable on the physical device used for phantom data collection
# Helps to match the physical to simulated case
# Only `v` is changed; the copies v2/v3 taken earlier still contain every channel.
drops = np.append(drop1,drop2)
z = np.nan
for d in drops:
    v[:,d] = z

In [None]:
###[Settings] Polarity correction ###
# Flips polarity for specific dipoles and devices to match phantom arrangement
# Necessary since initial phantom recording phase can be either 0 or pi radians
# `1` returns the same array, `-1` flips the voltage values relative to zero volts
# Note: Not necessary for SEPIO performance. A difficult assessment to make
#       so consider dropping this step in future use unless comparing recording to
#       simulation directly.

# Set up as polarity[dipole,devnum] = flip state (1 or -1)

# Currently set up for Feb. 26, 2024 dataset
polarity = np.array([
    [1,1],
    [-1,1],
    [-1,1],
    [-1,1],
    [-1,1],
    [-1,-1],
    [1,1],
    [-1,1],
    [1,-1],
    [1,1],
    [-1,1]
])

# Rows for dipole d span d*num_trials ... (d+1)*num_trials; columns are 64 per device
for d in np.unique(labels):
    for devnum in np.arange(2):
        v[d*num_trials:(d+1)*num_trials,devnum*64:(devnum+1)*64] *= polarity[d,devnum]

In [None]:
### Common Average Reference ###
# Subtracts, per device, the mean voltage across its channels from each trial (row),
# ignoring NaN (dropped) channels
# `car` is assigned for both datasets
# Mismatch of CAR will cause more drastic differences in SEPIO performance

if car:
    # keepdims broadcasting works for any number of trials (no hard-coded row count)
    v[:,:64] -= np.nanmean(v[:,:64],axis=1,keepdims=True) # device 1
    v[:,64:] -= np.nanmean(v[:,64:],axis=1,keepdims=True) # device 2

### Sort Data for SEPIO

In [None]:
### Load data ###
# Simulated SEPIO inputs: NaN -> 0 on a copy, with per-trial dipole labels.
# NOTE(review): this reads `v2`, the copy taken before channel drops, polarity correction
# and CAR were applied to `v`, while the settings cell says CAR applies to both simulation
# and phantom — confirm `v2` (rather than `v`) is intended here.
Xs = np.nan_to_num(np.copy(v2))
ys = labels

## Corrections

In [None]:
### Correcting for inherent differences in phantom and simulation ###
# Namely (1) total RMS power and (2) inherent noise in phantom recording
background_noise = 4.1 #uV

# Add background DiSc noise from phantom
# Independent Gaussian samples (std = background_noise) per entry; the RNG is not seeded
Xs += np.random.normal(scale=background_noise,size=Xs.shape)

# Adjust total RMS power of simulation up to phantom
# Ratio of root-sum-squares over all trials and channels, taken after the noise above is added
scale_factor = np.sqrt(np.sum(Xp**2))/np.sqrt(np.sum(Xs**2))
Xs *= scale_factor

In [None]:
### Correcting phantom wrap ###
# Each device thin film has some wrapping error measured with microscope following construction
# This number accounts for the mean column shift required per device
# Only important in plotting selected sensors and calculating matched sensor regions
dev1_rot = 3 # columns
dev2_rot = -2 # columns

# Shift phantom by above column rotation
# A column is 8 consecutive channels, so N columns is a circular roll of N*8 channels
# within each device. This modifies Xp in place for all later cells.
Xp[:,:64] = np.roll(Xp[:,:64],dev1_rot*8,axis=-1)
Xp[:,64:] = np.roll(Xp[:,64:],dev2_rot*8,axis=-1)

## Signal Statistics

### Compute SNR

In [None]:
### Settings ###
# Figure Directory
figdir = r'C:\Users\willi\Documents\NEI\SEPIO'

# Noise levels
noise_disc = 3.5 # uV; 3.5 uV for P3; 4.1 uV for P2
noise_seeg = 2.7 # typically 2.7 uV; this produces a relevant value to the disc setting

# vSEEG production
# nrows must divide 16, the depth count used when building virtual SEEG rings
nrows = 2 # number of disc rows to average for one virtual seeg ring; 1,2,4,8 only; 2 is best for seeg equivalence

In [None]:
### Compute statistics ###
# Common average reference (disabled); the subtraction was:
#   X - np.repeat(np.mean(X,axis=1),128).reshape((X.shape[0],128))
# np.copy keeps Xs/Xp intact: a plain `Xs_car = Xs` aliases the array, so the in-place
# noise addition below would also add noise to Xs and Xp.
Xs_car = np.copy(Xs)
Xp_car = np.copy(Xp)

# Scale phantom data mean down to match simulation (plots will be properly scaled in x-axis/SNR)
#Xp_car = Xp_car*np.mean(np.abs(Xs_car))/np.mean(np.abs(Xp_car))

# Noise addition
Xs_car += np.random.normal(scale=noise_disc,size=Xs_car.shape)
Xp_car += np.random.normal(scale=noise_disc,size=Xp_car.shape)

# Compute virtual SEEG: average DiSc sensors into 16//nrows rings
# NOTE(review): s//16 and s%16 treat the 128 channels as 8 groups of 16 depths, while
# other cells (peak-voltage plots, the 8-channel column roll) treat each column as 8
# consecutive channels — confirm the intended ring geometry.
Xs_vseeg = np.zeros((Xs.shape[0],16//nrows))
Xp_vseeg = np.zeros((Xp.shape[0],16//nrows))
for t in range(Xs_vseeg.shape[0]): # t for each trial
    for s in range(128): # s for each DiSc sensor
        column = s//16
        depth = s - column*16
        Xs_vseeg[t,depth//nrows] += Xs_car[t,s]
for t in range(Xp_vseeg.shape[0]): # t for each trial
    for s in range(128): # s for each DiSc sensor
        column = s//16
        depth = s - column*16
        Xp_vseeg[t,depth//nrows] += Xp_car[t,s]
Xs_vseeg *= 1/(nrows*8) # divide by number of added sensors per ring
Xp_vseeg *= 1/(nrows*8)

# Convert to SNR (power ratio: squared amplitude over squared noise level)
Xs_car = np.square(np.abs(Xs_car/noise_disc))
Xp_car = np.square(np.abs(Xp_car/noise_disc))
Xs_vseeg = np.square(np.abs(Xs_vseeg/noise_seeg))
Xp_vseeg = np.square(np.abs(Xp_vseeg/noise_seeg))

# Reshape arrays to be (dipole_class,all_SNR_data_per_class); assumes 11 classes stored
# in equal-length consecutive trial blocks
striallen = Xs.shape[0]//11
ptriallen = Xp.shape[0]//11
Xs_stats = np.zeros((11,striallen*128))
Xp_stats = np.zeros((11,ptriallen*128))
Xs_vstats = np.zeros((11,striallen*128//(nrows*8)))
Xp_vstats = np.zeros((11,ptriallen*128//(nrows*8)))
for t in np.arange(11):
    Xs_stats[t,:] = Xs_car[t*striallen:(t+1)*striallen,:].flatten()
    Xs_vstats[t,:] = Xs_vseeg[t*striallen:(t+1)*striallen,:].flatten()
    Xp_stats[t,:] = Xp_car[t*ptriallen:(t+1)*ptriallen,:].flatten()
    Xp_vstats[t,:] = Xp_vseeg[t*ptriallen:(t+1)*ptriallen,:].flatten()

# Limit SNR to a min/max
# Note: the boolean-mask filtering below returns 1-D arrays, so the per-class layout
# built above is pooled across all classes.
limit = 10 # SNR limits
Xp_stats[Xp_stats>limit] = np.nan
Xp_stats[Xp_stats<-limit] = np.nan
Xp_vstats[Xp_vstats>limit] = np.nan
Xp_vstats[Xp_vstats<-limit] = np.nan
Xp_stats = Xp_stats[~np.isnan(Xp_stats)]
Xp_vstats = Xp_vstats[~np.isnan(Xp_vstats)]
Xs_stats[Xs_stats>limit] = np.nan
Xs_stats[Xs_stats<-limit] = np.nan
Xs_vstats[Xs_vstats>limit] = np.nan
Xs_vstats[Xs_vstats<-limit] = np.nan
Xs_stats = Xs_stats[~np.isnan(Xs_stats)]
Xs_vstats = Xs_vstats[~np.isnan(Xs_vstats)]

### Dipole Position Values

In [None]:
# Functions for shortest path from point to line segment and angle between vectors
def short_path(A, B, C):
    """Return the shortest distance from point A to the line segment BC.

    Parameters
    ----------
    A : sequence of float
        The point (e.g. a dipole pole), as x/y/z coordinates.
    B, C : sequence of float
        End points of the segment (e.g. the device ends); order agnostic.

    Returns
    -------
    float
        Euclidean distance from A to the nearest point on segment BC, in the
        same units as the inputs. Points beyond either end are measured to
        that end point.

    Raises
    ------
    ValueError
        If B and C are coincident (zero-length segment).
    """
    BC = [c - b for c, b in zip(C, B)]
    BC_length = math.sqrt(sum(bc**2 for bc in BC))
    if BC_length == 0:
        raise ValueError("Points B and C are coincident.")
    BC_norm = [bc / BC_length for bc in BC]

    # Distance along BC (from B) of A's perpendicular projection
    t = sum((a - b) * u for a, b, u in zip(A, B, BC_norm))
    # Clamp to the segment so the closest point never leaves BC
    t = min(max(t, 0.0), BC_length)

    # Closest point P on the segment, then |AP|
    P = [b + u * t for b, u in zip(B, BC_norm)]
    return math.sqrt(sum((p - a)**2 for p, a in zip(P, A)))

def normalize(vector):
    magnitude = math.sqrt(sum(i**2 for i in vector))
    return [i/magnitude for i in vector]

def dot_product(A, B):
    return sum(a*b for a, b in zip(A, B))

def angle(A, B):
    # Normalize the vectors
    A_norm = normalize(A)
    B_norm = normalize(B)

    # Calculate the dot product
    dot_prod = dot_product(A_norm, B_norm)

    # No need to divide dot product by magnitudes since they are normalized

    # Convert the cosine of the angle to radians
    radians = math.acos(dot_prod)

    # Convert radians to degrees
    degrees = math.degrees(radians)

    return degrees

def face_angle(A1, A2, B, C):
    # A1/A2 are the dipole points
    # B and C are the ends of the device (order agnostic)

    # Dipole midpoint
    A = (A1+A2)/2
    
    # Step 1: Calculate vectors AB and AC
    AB = [b - a for b, a in zip(B, A)]
    AC = [c - a for c, a in zip(C, A)]
    
    # Step 2: Calculate BC vector
    BC = [c - b for c, b in zip(C, B)]
    
    # Check for division by zero
    if len(set(BC)) == 1 and BC[0]!= 0:
        raise ValueError("Points B and C are coincident.")
    
    # Normalize BC vector
    BC_length = math.sqrt(sum(bc**2 for bc in BC))
    BC_norm = [bc / BC_length for bc in BC]
    
    # Project AB and AC onto BC
    proj_AB_BC = sum(a * b for a, b in zip(AB, BC_norm))
    proj_AC_BC = sum(a * b for a, b in zip(AC, BC_norm))
    
    # Find the closest point P on the line segment BC
    if proj_AB_BC > proj_AC_BC:
        P = [b - bc * proj_AB_BC for b, bc in zip(B, BC_norm)]
    else:
        P = [c - bc * proj_AC_BC for c, bc in zip(C, BC_norm)]
    
    # Step 3: Make vector from A to P
    AP = P - A

    # Find angle between dipole vector and AP angle
    face = angle((A2-A1),AP)
    if face > 90:
        face = 180 - face
    
    return face

In [None]:
### Determine values
# Position values taken from simulation assignments
# dev_pos [dev#,top/bottom,x/y/z]
# dipoles [dip#,top/bottom,x/y/z]
n_dev = 2 # number of devices (the row labels below assume two)
n_dip = 11 # number of dipoles

# Desired variables; Distance to dipole, parallel angle, en-face angle
distances = np.zeros((n_dev,n_dip)) # [dev#,dip#]; distance in mm from the device segment to the nearer dipole endpoint
angles = np.zeros((n_dev,n_dip)) # [dev#,dip#]; angle between device and dipole vectors in degrees
face = np.zeros((n_dev,n_dip)) # [dev#,dip#]; angle between straight line to each device and the dipole vector; degrees from en-face

for dev in range(n_dev): # dev for each device ID
    dev_top, dev_bottom = dev_pos[dev,0], dev_pos[dev,1]
    for d in range(n_dip): # d for each dipole ID
        distances[dev,d] = min(short_path(dipoles[d,0],dev_top,dev_bottom),
                               short_path(dipoles[d,1],dev_top,dev_bottom))
        angles[dev,d] = angle(dev_top-dev_bottom,dipoles[d,0]-dipoles[d,1])
        face[dev,d] = face_angle(dipoles[d,0],dipoles[d,1],dev_top,dev_bottom)

# Format data for table
distances = np.round(distances,2)
angles = np.round(angles,1)
face = np.round(face,1)
data = np.vstack([distances,angles,face])

# Header row: a blank cell above the row labels, then one column per dipole.
# Dipoles are numbered from 1 to match the data file names (dev*-d1 ... dev*-d11).
columns = [''] + [f'Dipole {d}' for d in range(1,n_dip+1)]
rows = ['Distance (mm)\nDev 1','Distance (mm)\nDev 2','Parallel Angle (deg.)\nDev 1',
        'Parallel Angle (deg.)\nDev 2','En-face angle (deg.)\nDev1','En-face angle (deg.)\nDev2']

cell_text = [columns]
for row_label, row_values in zip(rows,data): # one table row per data row
    cell_text.append(np.hstack([np.array([row_label]),row_values]).flatten())

# Print basic RMS
print('RMS distances\nDevice 1:',round(np.sqrt(np.mean(data[0]**2)),2),'mm\nDevice 2:',round(np.sqrt(np.mean(data[1]**2)),2),'mm')
print('RMS parallel angles\nDevice 1:',round(np.sqrt(np.mean(data[2]**2)),1),'deg.\nDevice 2:',round(np.sqrt(np.mean(data[3]**2)),1),'deg.')
print('RMS en-face angles\nDevice 1:',round(np.sqrt(np.mean(data[4]**2)),1),'deg.\nDevice 2:',round(np.sqrt(np.mean(data[5]**2)),1),'deg.')

In [None]:
# Plot the supplemental data table (no axes) and write it to the output folder
fig, ax = plt.subplots()
ax.axis('off')
table = ax.table(cell_text, loc='center')
fig.tight_layout()

fig.savefig(path.join(output, '9_supplemental_table.png'), transparent=True, dpi=600)
plt.show()

## SEPIO
Training and classification

### Load SEPIO Functions

In [None]:
### Cut to desired sensors ###
def cut_sensors(X,sensors):
    """Return a float64 copy of X that keeps only the columns listed in `sensors`.

    Every other column is filled with the constant 0.1 rather than 0 (a zero offset
    that avoids model solve errors). X itself is not modified.
    """
    keep = list(sensors)
    masked = np.full(X.shape, 0.1)
    masked[:, keep] = X[:, keep]
    return masked

In [None]:
### Monte Carlo Train ###
# Uses local method instead of SEPIO package due to differences in processing.
def MC_train(cut,sensors,X,y,Xt,yt,nbasis,randomN):
    """
    Performs Monte-Carlo style training on SEPIO with a random train/test split and fresh noise for each cycle.

    `cut` must be boolean to select if only given `sensors` (array) are to be used.
    X and y are the dataset and labels, respectively.
    Xt and yt are the test dataset and labels - unused if identical to X (no crossing).
    nbasis assigns the number of basis modes to be used.
    randomN reduces sensor set to a random N selection if supplied a value above 0.

    Also reads the notebook-level settings `MCcount` (number of cycles), `noise`
    (std. of the Gaussian noise added to both splits), `l1` (SSPOC L1 penalty) and
    `testsplit` (test fraction passed to train_test_split).

    Returns (MCcoefs, MCaccs, MCaccs2):
      MCcoefs - SSPOC sensor coefficients averaged over cycles; shape (n_channels,)
                for <= 2 classes, else (n_channels, n_classes).
      MCaccs  - mean test accuracy at each evaluated sensor count (1, 5, 9, ...).
      MCaccs2 - mean training accuracy at the same sensor counts.

    Raises ValueError if MCcount is less than 1.
    """
    MC = int(MCcount)
    if MC < 1:
        raise ValueError("MCcount must be at least 1.")
    n_channels = X.shape[1]

    # Define basis mode selection; SVD, Identity, or RandomProjection
    basis = ps.basis.SVD(n_basis_modes=int(nbasis))
    n_classes = np.unique(y).shape[0]
    if n_classes <= 2:
        MCcoefs = np.zeros((n_channels,))
    else:
        MCcoefs = np.zeros((n_channels,n_classes))

    # Sensor counts at which accuracy is scored
    if cut:
        sensor_range = np.arange(1,sensors.shape[0],4)
    elif randomN > 0:
        sensor_range = np.arange(1,randomN,4)
    else:
        sensor_range = np.arange(1,n_channels,4)
    MCaccs = np.zeros((sensor_range.shape[0]))
    MCaccs2 = np.zeros((sensor_range.shape[0]))

    if cut:
        print("Reducing to",sensors.shape[0],"sensors.\n",sensors)
    # Cross datasets only when the test set really differs from the training set
    # (compare whole arrays, which also handles differing shapes).
    cross = not (X is Xt or np.array_equal(X, Xt))
    if cross:
        print("Crossed datasets.")

    # Iterate MC
    for i in range(MC):
        print("Starting run",i+1,"of",MC)

        # Split train and test data
        if cross:
            X_train, _, y_train, _ = train_test_split(X, y, test_size=testsplit)
            _, X_test, _, y_test = train_test_split(Xt, yt, test_size=testsplit)
        else:
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=testsplit)

        # Add noise (builds new arrays, so integer-typed inputs also work)
        X_train = X_train + np.random.normal(scale=noise,size=X_train.shape)
        X_test = X_test + np.random.normal(scale=noise,size=X_test.shape)

        # Reduce to randomN if desired
        if randomN > 0:
            rand_indices = np.random.choice(n_channels,randomN,replace=False)
            print("Random sensors:",np.sort(rand_indices))
            X_test = cut_sensors(X_test,rand_indices)
            X_train = cut_sensors(X_train,rand_indices)

        # Reducing to top N sensors
        if cut:
            X_train = cut_sensors(X_train,sensors)
            X_test = cut_sensors(X_test,sensors)

        # Silence solver warnings for this fit/evaluation only, rather than
        # disabling warnings for the rest of the session.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = SSPOC(l1_penalty=l1,basis=basis).fit(X_train, y_train)
            # Record coefficients
            MCcoefs += model.sensor_coef_
            # Record test/train accuracies at each sensor count
            for j, s in enumerate(sensor_range):
                model.update_sensors(n_sensors=s, xy=(X_train, y_train), quiet=True)
                sensor = model.selected_sensors
                MCaccs[j] += metrics.accuracy_score(y_test, model.predict(X_test[:, sensor]))
                MCaccs2[j] += metrics.accuracy_score(y_train, model.predict(X_train[:, sensor]))

    MCcoefs /= MC
    MCaccs /= MC
    MCaccs2 /= MC

    return MCcoefs, MCaccs, MCaccs2

### Process SEPIO

Xp and yp are the phantom dataset and its labels.

Xs and ys are the simulated dataset and its labels.

In [None]:
### SEPIO Settings ###
# Note: MCcount requires a total of ~10 sec. per cycle for this entire section
# If `choose = 2`, requires ~1 min. per cycle for the entire section
choose = 2 # Number of dipoles to turn on per-class; 1 or 2 only!
nbasis = 10 # Number of basis modes available (8-14 is a good range using SVD selection)
N = 16 # Number of top and random sensors to use
MCcount = 10 # Monte-Carlo training cycles
testsplit = 0.4 # Fraction of trials held out for testing each cycle
l1 = 0.001 # SSPOC L1 penalty
noise = 4.1*3 # uV; standard of 4.1 for DiSc; x3-5 for decent training difficulty

# Check SNR stats for Phantom and Simulation
# SNR = (RMS of the whole dataset / (noise + background_noise))**2, i.e. a power ratio.
# NOTE(review): the two noise amplitudes are summed linearly; independent noise
# sources would combine as sqrt(noise**2 + background_noise**2) - confirm intent.
print('SNR Statistics\n\t\tMean')
for snr_label, snr_data in (('Phantom\t\t',Xp),('Simulation\t',Xs)):
    signal_rms = np.sqrt(np.mean(snr_data**2))
    print(snr_label,np.round(np.square(signal_rms/(noise+background_noise)),2))

In [None]:
### Assign temp values and trials for choose state ###
# Allows using the original 11 dipole classes or mixing any choose-2 combination, forming 55 dipole classes.
# Assumes Xs/Xp hold an equal number of consecutive trials per dipole, in dipole order.
from itertools import combinations

n_dipoles = 11
if choose == 2:
    print("Choose 2 dipole trials - 55 total. Compute time increased ~5x.")
    classes = np.arange(0,n_dipoles,1)
    slen = Xs.shape[0]//n_dipoles # trials per dipole (simulation)
    plen = Xp.shape[0]//n_dipoles # trials per dipole (phantom)
    # Every pair of classes without repeat or overlap, ordered (0,1), (0,2), ..., (9,10)
    pairs = list(combinations(classes,2))
    Xst = np.zeros((slen*len(pairs),Xs.shape[1]))
    yst = np.zeros((slen*len(pairs)))
    Xpt = np.zeros((plen*len(pairs),Xp.shape[1]))
    ypt = np.zeros((plen*len(pairs)))

    for i,(c1,c2) in enumerate(pairs): # i is the new class label
        yst[i*slen:(i+1)*slen] = i
        ypt[i*plen:(i+1)*plen] = i

        # Assign sum of two dipoles as new values
        Xst[i*slen:(i+1)*slen] = Xs[c1*slen:(c1+1)*slen,:] + Xs[c2*slen:(c2+1)*slen,:]
        Xpt[i*plen:(i+1)*plen] = Xp[c1*plen:(c1+1)*plen,:] + Xp[c2*plen:(c2+1)*plen,:]
else:
    if choose == 1:
        print("Normal choose 1 dipole trials.")
    else:
        print("Incorrect choose variable. Defaulting to choose 1 state.")
    # Choose-1 datasets (also the fallback), so Xst/yst/Xpt/ypt are always defined
    Xst = np.copy(Xs)
    yst = np.copy(ys)
    Xpt = np.copy(Xp)
    ypt = np.copy(yp)

In [None]:
### [Settings] BIAS OPTIONS ###
# Intended for testing correct device and source assignments, and SEPIO performance sanity check
# Each option is off by default; set its flag to True to apply it to Xst/Xpt.
bias_limit_classes = False
bias_reduce_channels = False
bias_power_shift = False
dev1_nch = 64 # channels 0-63 belong to device 1, 64-127 to device 2

### BIAS: Manually limit classes to a set
# Warning: low dipole count (2 or 3) can cause errors in SSPOC solving
xslen = Xs.shape[0]//11
xplen = Xp.shape[0]//11

if bias_limit_classes:
    keep = np.array([0,1,2,3,4]) # Class indices to keep (relabelled 0..len(keep)-1)
    xs2 = np.zeros((keep.shape[0]*xslen,128))
    xp2 = np.zeros((keep.shape[0]*xplen,128))
    ys2 = np.zeros((keep.shape[0]*xslen))
    yp2 = np.zeros((keep.shape[0]*xplen))
    for i,d in enumerate(keep):
        xs2[i*xslen:(i+1)*xslen,:] = Xst[d*xslen:(d+1)*xslen,:]
        ys2[i*xslen:(i+1)*xslen] = i
        xp2[i*xplen:(i+1)*xplen,:] = Xpt[d*xplen:(d+1)*xplen,:]
        yp2[i*xplen:(i+1)*xplen] = i
    Xst = np.copy(xs2)
    Xpt = np.copy(xp2)
    yst = np.copy(ys2)
    ypt = np.copy(yp2)


### BIAS: Manually reduce channels with data
if bias_reduce_channels:
    Ni = 63 # Number of leading channels (starting at channel 0, device 1) to zero out
    Mi = 0 # Number of trailing channels (ending at channel 127, device 2) to zero out
    if Ni != 0:
        Xst[:,:Ni] = 0.
        Xpt[:,:Ni] = 0.
    if Mi != 0:
        Xst[:,-Mi:] = 0.
        Xpt[:,-Mi:] = 0.

### BIAS: Device signal power shift
# Warning: Reducing power too low may cause errors (zeros in SSPOC don't behave)
if bias_power_shift:
    # decimal to multiply for all data per device
    dev1scale = 1.
    dev2scale = .1
    Xst[:,:dev1_nch] *= dev1scale
    Xpt[:,:dev1_nch] *= dev1scale
    Xst[:,dev1_nch:] *= dev2scale
    Xpt[:,dev1_nch:] *= dev2scale

In [None]:
### Process Simulation: Full dataset (all channels, no sensor restriction)
scoef, saccs, saccs2 = MC_train(X=Xst, y=yst, Xt=Xst, yt=yst, cut=False, sensors=None, randomN=0, nbasis=nbasis)

In [None]:
### Process Simulation: Top-N restricted
# Rank channels by |SEPIO coefficient|, averaged across classes for multi-class fits.
# Checking ndim (rather than size == 128) works for any channel count.
if scoef.ndim == 1:
    s_score = np.abs(scoef)
else:
    s_score = np.mean(np.abs(scoef),axis=1)
s_rank = np.argsort(s_score) # ascending: least to most important
sorder = np.flip(s_rank) # most to least important
sordert = s_rank[-N:] # top N
sorderb = s_rank[:N] # bottom N
scoeft10,saccst10,saccs2t10 = MC_train(cut=True,sensors=sordert,X=Xst,y=yst,Xt=Xst,yt=yst,nbasis=nbasis,randomN=0)

In [None]:
### Process Simulation: Bottom-N restricted (least important N channels only)
scoefb10, saccsb10, saccs2b10 = MC_train(X=Xst, y=yst, Xt=Xst, yt=yst, cut=True, sensors=sorderb, randomN=0, nbasis=nbasis)

In [None]:
### Process Simulation: Random N (fresh random channel subset each cycle)
scoefr10, saccsr10, saccs2r10 = MC_train(X=Xst, y=yst, Xt=Xst, yt=yst, cut=False, sensors=None, randomN=N, nbasis=nbasis)

In [None]:
# Print orders: split the chosen channels by device (0-63 device 1, 64-127 device 2)
print('Simulation')
for set_name, chosen in (('Top-N', sordert), ('Bottom-N', sorderb)):
    on_dev1 = chosen[chosen<64]
    on_dev2 = chosen[chosen>63]-64 # renumbered from 0 within device 2
    print(f'{set_name}\n', f'{np.round(on_dev1.shape[0]*100/N)}% device 1\n', 'Dev1:', np.sort(on_dev1), '\n', 'Dev2:', np.sort(on_dev2))

In [None]:
# Test plots: lines are test accuracy, dots are training accuracy
x_all = np.arange(1,128,4)
x_n = np.arange(1,N,4)
plt.plot(x_all, saccs, 'b', label="Total")
plt.plot(x_all, saccs2, 'b.')
for fmt, name, test_acc, train_acc in (('g', f"Top-{N}", saccst10, saccs2t10),
                                        ('r', f"Random-{N}", saccsr10, saccs2r10),
                                        ('y', f"Bottom-{N}", saccsb10, saccs2b10)):
    plt.plot(x_n, test_acc, fmt, label=name)
    plt.plot(x_n, train_acc, fmt + '.')
plt.legend(loc='lower right')
plt.ylim([0,1])
plt.title("Simulation")
plt.show()

In [None]:
### Process Phantom: Total dataset (all channels, no sensor restriction)
pcoef, paccs, paccs2 = MC_train(X=Xpt, y=ypt, Xt=Xpt, yt=ypt, cut=False, sensors=None, randomN=0, nbasis=nbasis)

In [None]:
### Process Phantom: Top-N restricted
# Rank channels by |SEPIO coefficient| of the phantom fit (pcoef), averaged
# across classes for multi-class fits.
if pcoef.ndim == 1:
    p_score = np.abs(pcoef)
else:
    p_score = np.mean(np.abs(pcoef),axis=1)
p_rank = np.argsort(p_score) # ascending: least to most important
porder = np.flip(p_rank) # most to least important
pordert = p_rank[-N:] # top N
porderb = p_rank[:N] # bottom N

pcoeft10,paccst10,paccs2t10 = MC_train(cut=True,sensors=pordert,X=Xpt,y=ypt,Xt=Xpt,yt=ypt,nbasis=nbasis,randomN=0)

In [None]:
### Process Phantom: Bottom-N restricted (least important N channels only)
pcoefb10, paccsb10, paccs2b10 = MC_train(X=Xpt, y=ypt, Xt=Xpt, yt=ypt, cut=True, sensors=porderb, randomN=0, nbasis=nbasis)

In [None]:
### Process Phantom: Random N (fresh random channel subset each cycle)
pcoefr10, paccsr10, paccs2r10 = MC_train(X=Xpt, y=ypt, Xt=Xpt, yt=ypt, cut=False, sensors=None, randomN=N, nbasis=nbasis)

In [None]:
# Print orders: split the chosen channels by device (0-63 device 1, 64-127 device 2)
print('Phantom')
for set_name, chosen in (('Top-N', pordert), ('Bottom-N', porderb)):
    on_dev1 = chosen[chosen<64]
    on_dev2 = chosen[chosen>63]-64 # renumbered from 0 within device 2
    print(f'{set_name}\n', f'{np.round(on_dev1.shape[0]*100/N)}% device 1\n', 'Dev1:', np.sort(on_dev1), '\n', 'Dev2:', np.sort(on_dev2))

In [None]:
# Test plots: lines are test accuracy, dots are training accuracy
x_all = np.arange(1,128,4)
x_n = np.arange(1,N,4)
plt.plot(x_all, paccs, 'b', label="Total")
plt.plot(x_all, paccs2, 'b.')
for fmt, name, test_acc, train_acc in (('g', f"Top-{N}", paccst10, paccs2t10),
                                        ('r', f"Random-{N}", paccsr10, paccs2r10),
                                        ('y', f"Bottom-{N}", paccsb10, paccs2b10)):
    plt.plot(x_n, test_acc, fmt, label=name)
    plt.plot(x_n, train_acc, fmt + '.')
plt.legend(loc='lower right')
plt.title("Phantom")
plt.ylim([0,1])
plt.show()

### Search Basis Optimization

In [None]:
### Settings ###
# Optional search for the best basis selection and number of modes;
# the best choice may vary by dataset.

run_basis=False

if run_basis: # only override the main SEPIO settings when the search is enabled
    MCcount = 2
    testsplit = 0.2
    l1 = 0.001
    noise = 4.1*4 # uV; standard of 4.1 for DiSc
    N = 16 # Top and random sensor N
    nbasis_test = np.arange(2,42,2) # numbers of basis modes to try

In [None]:
# Process N-mode search; may take some time
# Results use `b_*` names so the main SEPIO results (scoef, saccs, sordert, ...)
# used by the figures below are not overwritten by this search.
if run_basis:
    n_cycles = nbasis_test.shape[0]
    s_basis_accs = np.zeros((4,n_cycles)) # rows: full, Top-N, Random-N, Bottom-N
    p_basis_accs = np.zeros((4,n_cycles))
    # Full-dataset entry at the same sensor count as the last Top/Random/Bottom-N
    # entry (sensor counts are 1, 5, 9, ...; N assumed a multiple of 4)
    acc_idx = (N//4)-1

    for i,n in enumerate(nbasis_test):
        print(f'Basis cycle {i+1} of {n_cycles}. N={n}')
        # Simulation first, then phantom: full, Top-N, Random-N, Bottom-N
        for X_b, y_b, accs_out in ((Xst, yst, s_basis_accs), (Xpt, ypt, p_basis_accs)):
            b_coef,b_accs,_ = MC_train(cut=False,sensors=None,X=X_b,y=y_b,Xt=X_b,yt=y_b,nbasis=n,randomN=0)
            # Rank channels by |coefficient| (class-averaged for multi-class fits)
            b_score = np.abs(b_coef) if b_coef.ndim == 1 else np.mean(np.abs(b_coef),axis=1)
            b_rank = np.argsort(b_score)
            _,b_accst,_ = MC_train(cut=True,sensors=b_rank[-N:],X=X_b,y=y_b,Xt=X_b,yt=y_b,nbasis=n,randomN=0)
            _,b_accsr,_ = MC_train(cut=False,sensors=None,X=X_b,y=y_b,Xt=X_b,yt=y_b,nbasis=n,randomN=N)
            _,b_accsb,_ = MC_train(cut=True,sensors=b_rank[:N],X=X_b,y=y_b,Xt=X_b,yt=y_b,nbasis=n,randomN=0)

            # Save values to array
            accs_out[:,i] = [b_accs[acc_idx], b_accst[-1], b_accsr[-1], b_accsb[-1]]

In [None]:
# Plot basis optimization results
if run_basis:
    fig, ax = plt.subplots(1,1)

    plt.plot(nbasis_test,s_basis_accs[0],label='Sim Full')
    plt.plot(nbasis_test,s_basis_accs[1],label=f'Sim Top {N}')
    plt.plot(nbasis_test,s_basis_accs[2],label=f'Sim Rand {N}')
    plt.plot(nbasis_test,s_basis_accs[3],label=f'Sim Bottom {N}')
    plt.plot(nbasis_test,p_basis_accs[0],label='Phan Full')
    plt.plot(nbasis_test,p_basis_accs[1],label=f'Phan Top {N}')
    plt.plot(nbasis_test,p_basis_accs[2],label=f'Phan Rand {N}')
    plt.plot(nbasis_test,p_basis_accs[3],label=f'Phan Bottom {N}')
    plt.legend(loc='lower left')
    plt.xlabel('Number of Modes')
    plt.ylabel(f'Accuracy @ {N} sensors')
    plt.title("Gaussian")

    # Save to `output` like the other figures in this notebook. `figdir` is not
    # assigned until a later cell, and a dedicated variable keeps the global
    # `file` list of data file names intact.
    basis_fig_path = path.join(output,'Phantom-FigX-basis_opt.pdf')
    fig.savefig(basis_fig_path,transparent=True,dpi=600)

## Produce Figures

In [None]:
### Figure Directory ###
# Defaults to the notebook output folder instead of a machine-specific absolute path;
# point this at another directory if figures should be saved elsewhere.
figdir = output

### Figure 9-1

In [None]:
### Figure 3A ###
# MC_train scores accuracy at 1, 5, 9, ... sensors (step 4), so each point is
# plotted at the sensor count it was measured with.
x0 = np.arange(1,128,4) # full-dataset curves
x = np.arange(1,N,4) # Top-N / Bottom-N curves
ymin = 0.
xmax = 32
legend_loc = 'lower right'

fig, ax = plt.subplots(1,1)

# Simulation (lines = test accuracy, markers = training accuracy)
plt.plot(x0,saccs,color='#800000',label="Simulation")
plt.plot(x0,saccs2,color='#800000',marker='.',linestyle='')
plt.plot(x,saccst10,color='#FF0000',label=f"Simulation Top-{N}")
#plt.plot(x,saccsr10,color='#FF4500',label=f"Simulation Random-{N}")
plt.plot(x,saccsb10,color='#FFD700',label=f"Simulation Bottom-{N}")
plt.plot(x,saccs2t10,color='#FF0000',marker='^',linestyle='')
#plt.plot(x,saccs2r10,color='#FF4500',marker='>',linestyle='')
plt.plot(x,saccs2b10,color='#FFD700',marker='v',linestyle='')

# Phantom (lines = test accuracy, markers = training accuracy)
plt.plot(x0,paccs,color='#000080',label="Phantom")
plt.plot(x0,paccs2,color='#000080',marker='.',linestyle='')
plt.plot(x,paccst10,color='#0000FF',label=f"Phantom Top-{N}")
#plt.plot(x,paccsr10,color='#008080',label=f"Phantom Random-{N}")
plt.plot(x,paccsb10,color='#00FA9A',label=f"Phantom Bottom-{N}")
plt.plot(x,paccs2t10,color='#0000FF',marker='^',linestyle='')
#plt.plot(x,paccs2r10,color='#008080',marker='>',linestyle='')
plt.plot(x,paccs2b10,color='#00FA9A',marker='v',linestyle='')

# Display settings
plt.ylim([ymin,1.0])
plt.xlim([0,xmax])
plt.title('Classification')
plt.ylabel('Accuracy')
plt.xlabel('Number of Sensors')
plt.legend(loc=legend_loc)

fig.tight_layout()

# Save (before plt.show)
fig.savefig(path.join(output,'9_phantom_fig1.png'),transparent=True,dpi=600)
plt.show()

In [None]:
### Figure 9A - Split per dataset ###
# MC_train scores accuracy at 1, 5, 9, ... sensors (step 4), so each point is
# plotted at the sensor count it was measured with.
x0 = np.arange(1,128,4) # full-dataset curves
x = np.arange(1,N,4) # Top/Random/Bottom-N curves
ymin = 0.2
xmax = 64
legend_loc = 'lower right'

fig, ax = plt.subplots(1,2)

# Simulation (lines = test accuracy, markers = training accuracy)
ax[0].plot(x0,saccs,color='#800000',label="Simulation")
ax[0].plot(x0,saccs2,color='#800000',marker='.',linestyle='')
ax[0].plot(x,saccst10,color='#FF0000',label=f"Simulation Top-{N}")
ax[0].plot(x,saccsr10,color='#FF4500',label=f"Simulation Random-{N}")
ax[0].plot(x,saccsb10,color='#FFD700',label=f"Simulation Bottom-{N}")
ax[0].plot(x,saccs2t10,color='#FF0000',marker='^',linestyle='')
ax[0].plot(x,saccs2r10,color='#FF4500',marker='>',linestyle='')
ax[0].plot(x,saccs2b10,color='#FFD700',marker='v',linestyle='')

# Phantom (lines = test accuracy, markers = training accuracy)
ax[1].plot(x0,paccs,color='#000080',label="Phantom")
ax[1].plot(x0,paccs2,color='#000080',marker='.',linestyle='')
ax[1].plot(x,paccst10,color='#0000FF',label=f"Phantom Top-{N}")
ax[1].plot(x,paccsr10,color='#008080',label=f"Phantom Random-{N}")
ax[1].plot(x,paccsb10,color='#00FA9A',label=f"Phantom Bottom-{N}")
ax[1].plot(x,paccs2t10,color='#0000FF',marker='^',linestyle='')
ax[1].plot(x,paccs2r10,color='#008080',marker='>',linestyle='')
ax[1].plot(x,paccs2b10,color='#00FA9A',marker='v',linestyle='')

# Display settings (shared by both panels)
for panel, panel_title in ((ax[0],'Simulation Classification'),(ax[1],'Phantom Classification')):
    panel.set_ylim([ymin,1.0])
    panel.set_xlim([0,xmax])
    panel.set_title(panel_title)
    panel.set_ylabel('Accuracy')
    panel.set_xlabel('Number of Sensors')
    panel.legend(loc=legend_loc)

fig.tight_layout()
plt.show()

### Bonus: Per-channel signal RMS of the SEPIO top- and bottom-ranked sensors, per dataset

In [None]:
# Display the distribution of RMS per channel for the top-N and bottom-N ranked sensors
fig, ax = plt.subplots(2,1)

# sorder/porder run from most to least important, so the first N are the top
# sensors and the last N the bottom ones; only those columns are indexed.
st_rms = np.sqrt(np.mean(Xst[:, sorder[:N]]**2, axis=0))
sb_rms = np.sqrt(np.mean(Xst[:, sorder[-N:]]**2, axis=0))

pt_rms = np.sqrt(np.mean(Xpt[:, porder[:N]]**2, axis=0))
pb_rms = np.sqrt(np.mean(Xpt[:, porder[-N:]]**2, axis=0))

for axis, title, top_rms, bottom_rms in ((ax[0], 'Simulation', st_rms, sb_rms),
                                          (ax[1], 'Phantom', pt_rms, pb_rms)):
    axis.violinplot(top_rms, positions=[1], showmeans=True, showextrema=False)
    axis.violinplot(bottom_rms, positions=[2], showmeans=True, showextrema=False)
    # +/- one standard deviation around the mean
    axis.vlines([1], np.mean(top_rms)-np.std(top_rms), np.mean(top_rms)+np.std(top_rms))
    axis.vlines([2], np.mean(bottom_rms)-np.std(bottom_rms), np.mean(bottom_rms)+np.std(bottom_rms), color='orange')
    axis.set_title(title)
    axis.xaxis.set_ticks([1,2], [f'Top {N}', f'Bottom {N}'])
    axis.set_ylabel('Distribution of RMS\n(uV; per channel)')

fig.tight_layout()

# Save
fig.savefig(path.join(output, '9_phantom-fig2.png'), transparent=True, dpi=600)
plt.show()