In [None]:
import numpy as np
from viz import quickplot
import matplotlib.pyplot as plt
import seaborn as sns
import mne
import pickle as pkl
import time
from sim import *
from inverse_solutions import *
from source_covs import *
from util import *
from evaluate import *
# Interactive Qt figure windows (requires a Qt binding in the kernel's environment)
%matplotlib qt
# Directory holding the precomputed assets (leadfield, positions, info, forward solution) read in the next cell
pth_res = 'assets'

## Load Data

In [None]:
# Load the precomputed assets.
# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
with open(pth_res + '/leadfield.pkl', 'rb') as leadfield_file:
    leadfield = pkl.load(leadfield_file)[0]

with open(pth_res + '/pos.pkl', 'rb') as pos_file:
    pos = pkl.load(pos_file)[0]

with open(pth_res + '/info.pkl', 'rb') as info_file:
    info = pkl.load(info_file)

# Forward solution; used below for the per-hemisphere triangle meshes (src[...]['use_tris'])
fwd = mne.read_forward_solution(pth_res + '/fsaverage-fwd.fif', verbose=0)

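## Region growing

Pick a random dipole and grow a region around it by repeatedly adding its neighbors on the source-space triangle mesh.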

In [None]:
numberOfDipoles, numberOfElectrodes = pos.shape
tris_lr = [fwd['src'][0]['use_tris'], fwd['src'][1]['use_tris']]
pos_lr = [pos[:int(numberOfDipoles/2), :], pos[-int(numberOfDipoles/2):, :]]
# determine distance of closest neighbor for each dipole
distanceCrit = calc_dist_crit(pos)
print(f'A neighbor is a voxel in {distanceCrit:.1f} mm vicinity')
# Pick single dipole
pick_idx = np.random.choice(np.arange(numberOfDipoles), size=1)[0]
pick_pos = pos[pick_idx, :]
hem = int(pos[pick_idx, 0] > 0)
# Find its first neighbors on the mesh
neighbors = get_triangle_neighbors(tris_lr)

In [None]:
# Quick visual sanity check: mark every dipole reached after `order` region-growing
# steps from the picked dipole and plot the resulting indicator vector.
# NOTE(review): `get_n_order_indices` is defined in a later cell -- run that one first.
order = 2
region = get_n_order_indices(order, pick_idx, neighbors)
y = np.zeros(numberOfDipoles)
y[region] = 1

quickplot(y, pth_res, backend='mayavi', title=f'Order: {order}')

In [None]:
def calc_dist_vec(singlePosition, allPositions):
    ''' Euclidean distance from one position to every row of `allPositions`.

    Parameters
    ----------
    singlePosition : array-like, shape (n_coords,)
    allPositions : array-like, shape (n_points, n_coords)

    Returns
    -------
    numpy.ndarray, shape (n_points,)
    '''
    diff = np.asarray(allPositions) - np.asarray(singlePosition)
    return np.sqrt(np.sum(np.square(diff), axis=1))


def calc_dist_crit(pos):
    ''' Neighbor distance criterion: the largest nearest-neighbor distance over
    all positions, rounded up to the next integer.

    Every position is therefore guaranteed to have at least one other position
    within the returned distance.

    Parameters
    ----------
    pos : array-like, shape (n_points, n_coords)

    Returns
    -------
    numpy.float64
        ceil(max_i min_{j: d_ij != 0} d_ij)

    Raises
    ------
    ValueError
        If `pos` is empty, or some position has no distinct other position
        (e.g. fewer than two distinct rows), so no nearest neighbor exists.
    '''
    pos = np.asarray(pos)
    numberOfDipoles = pos.shape[0]
    if numberOfDipoles == 0:
        raise ValueError('calc_dist_crit needs at least two distinct positions, got none.')
    distOfClosestNeighbor = np.zeros(numberOfDipoles)
    # Row-by-row loop keeps memory at O(n) instead of building an n x n distance matrix.
    for i in range(numberOfDipoles):
        distVec = calc_dist_vec(pos[i, :], pos)
        # Zero distances are the point itself (and any exact duplicates) -- not neighbors.
        others = distVec[distVec != 0]
        if others.size == 0:
            raise ValueError(
                f'Position {i} has no distinct neighbor; '
                'calc_dist_crit needs at least two distinct positions.')
        distOfClosestNeighbor[i] = others.min()
    return np.ceil(distOfClosestNeighbor.max())

def get_triangle_neighbors(tris_lr):
    ''' Make a list for each dipole with indices of its triangle neighbors.

    Two dipoles are neighbors when they share an edge of some triangle.

    Parameters
    ----------
    tris_lr : list of array-like, each shape (n_triangles, 3)
        Triangle vertex ids per hemisphere (left first, then right). Each
        hemisphere's vertex ids are remapped to consecutive indices in
        ascending id order. For ids that are already 0..n_hem-1 this is a
        no-op. Later hemispheres are offset by the vertex count of the
        preceding ones, so the right hemisphere starts right after the left.
        NOTE(review): this assumes dipoles are ordered by ascending vertex id
        within each hemisphere, left hemisphere first -- confirm against how
        `pos`/the leadfield were built.

    Returns
    -------
    list of list
        neighbors[i] holds the sorted, unique indices adjacent to dipole i
        (never i itself). The input arrays are not modified.
    '''
    local_tris = []
    offset = 0
    for tris in tris_lr:
        tris = np.asarray(tris)
        # Map (possibly non-contiguous) vertex ids to consecutive dipole indices.
        vertex_ids, inverse = np.unique(tris.ravel(), return_inverse=True)
        local_tris.append(inverse.reshape(tris.shape) + offset)
        offset += len(vertex_ids)
    numberOfDipoles = offset
    if numberOfDipoles == 0:
        return []

    all_tris = np.concatenate(local_tris, axis=0)
    # Each triangle contributes all of its edges, in both directions (6 ordered pairs).
    edges = np.concatenate(
        [all_tris[:, [a, b]] for a in range(3) for b in range(3) if a != b], axis=0)
    # Degenerate triangles (repeated vertex) would make a dipole its own neighbor.
    edges = edges[edges[:, 0] != edges[:, 1]]
    # Encode (source, target) as one integer so np.unique both de-duplicates and
    # sorts by source, then by target.
    keys = np.unique(edges[:, 0] * numberOfDipoles + edges[:, 1])
    sources, targets = np.divmod(keys, numberOfDipoles)
    # Split the target list into one group per source dipole (empty groups allowed).
    counts = np.bincount(sources, minlength=numberOfDipoles)
    groups = np.split(targets, np.cumsum(counts)[:-1])
    return [list(group) for group in groups]

def get_n_order_indices(order, pick_idx, neighbors):
    ''' Iteratively performs region growing by selecting neighbors of
    neighbors for <order> iterations.

    Implemented as a breadth-first search. Each index is expanded once, so
    the cost stays linear in the region size rather than exploding with
    `order`. The previous version re-expanded duplicates every iteration.

    Parameters
    ----------
    order : int
        Number of growing iterations (graph hops from `pick_idx`).
    pick_idx : int
        Index of the seed dipole.
    neighbors : list of list of int
        neighbors[i] holds the indices adjacent to dipole i.

    Returns
    -------
    int or list
        `pick_idx` itself when order == 0. Otherwise a list of unique indices
        within `order` hops, starting with `pick_idx`, in breadth-first
        (first-discovery) order.
    '''
    if order == 0:
        return pick_idx

    region = [pick_idx]
    seen = {pick_idx}
    frontier = [pick_idx]
    for _ in range(order):
        next_frontier = []
        for idx in frontier:
            for neighbor in neighbors[idx]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    next_frontier.append(neighbor)
        if not next_frontier:
            break  # the region cannot grow any further
        region.extend(next_frontier)
        frontier = next_frontier
    return region