In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import pandas as pd
import pickle
import spam.kalisphera

# Classes
## SyntheticVTA

In [5]:
class SyntheticVTA:
    """ Represent the synthetic stimulation used for the parameters optimization.
    This class contains the amplitude sequences to be assessed, and methods needed to generate synthetic VTA."""

    # Class constants:
    # Coefficients of the stim. amplitude to VTA radius relation
    # --- I(r) = a*r² + b*r + c ---
    COEF_A = 0.3595
    COEF_B = 0.4279
    COEF_C = 0

    # Coefficient for displacement of center of mass
    # --- c(r) = c_0 + v/|v| * gamma * r ---
    COEF_GAMMA = 0.1

    def __init__(self, psm, c0, is_directional=False):
        self.shape = psm.shape
        self.voxel_size = psm.transform[1, 1]
        self.amplitudes = np.linspace(1, 8, 15)
        self.index = -1
        self.endFlag = False
        self.c0 = c0

        #if is_directional:
            #self.center = c0 +

        # Compute unit direction vector

    def set_amp_sequence(self, amplitudes):
        self.amplitudes = amplitudes

    def next(self) -> bool:
        self.index += 1

        if self.index >= len(self.amplitudes):
            self.index = -1
            self.endFlag = True

        return self.endFlag

    def get_vta(self, disp=False):
        # Assign parameters
        vta_image = np.zeros(self.shape, dtype="<f8")
        vta_amp = self.amplitudes[self.index]
        vta_radius = self.current2radius(vta_amp)/self.voxel_size

        # Convert back from numpy float to native float (to avoid error in kalisphera)
        vta_radius = vta_radius.item()

        # Creates the image with synthetic VTA
        spam.kalisphera.makeSphere(vta_image, self.center, vta_radius)

        if disp:
            plt.figure()
            plt.imshow(vta_image[vta_image.shape[0] // 2], vmin=0, vmax=1, cmap='Greys_r')
            plt.show()

        return vta_image, vta_amp

    def directional_center(self):
        return self.center

    @classmethod
    def radius2current(cls, radius) -> float:
        # Function to obtain the current from the radius (in MNI space) of a VTA
        current = cls.COEF_A*np.square(radius) + cls.COEF_B*radius + cls.COEF_C

        return current.item()

    @classmethod
    def current2radius(cls, current) -> float:
        # Function to obtain the radius (in MNI space) of a VTA from the current
        radius = np.sqrt(current/cls.COEF_A + np.square(cls.COEF_B/(2*cls.COEF_A)) - cls.COEF_C/cls.COEF_A) - (cls.COEF_B/(2*cls.COEF_A))

        return radius.item()


## PSM

In [6]:
class PSM:
    """ Represent a Probabilistic Stimulation Map.
    This class is used to handle the PSM and clinical improvement estimation (linear model) based on the map."""

    def __init__(self, filename):
        # Input argument filename shall not have extension!

        # Read nifti file
        nii = nib.load(filename + '.nii.gz')

        # Read pickle file for additional coefficient (intercept and stim. amplitude)
        pkl_file = open(filename + '.pkl', 'rb')
        dict_coef = pickle.load(pkl_file)
        pkl_file.close()

        self.coef_intercept = dict_coef['intercept']
        self.coef_amplitude = dict_coef['amplitude']
        self.coef_image = nii.get_fdata()
        self.transform = nii.affine
        self.shape = self.coef_image.shape

    def estimate(self, VTA):
        return np.sum(np.multiply(VTA, self.coef_image), axis=None)


### LogRegPSM (PSM subclass)

In [7]:
class LogRegPSM(PSM):
    """ Represent a Probabilistic Stimulation Map with logistic regression.
    This subclass extend the PSM class for logistic regression model map.
    It includes the overide of the estimate method adapted with the sigmoid link function."""

    def estimate(self, VTA_image, VTA_amp):
        bX_img = np.sum(np.multiply(VTA_image, self.coef_image), axis=None)
        bX_amp = VTA_amp * self.coef_amplitude
        b0 = self.coef_intercept

        return self.sigmoid(bX_img + bX_amp + b0)

    @staticmethod
    def sigmoid(z):
        return 1/(1+np.exp(-z))

## Lead

In [8]:
class Lead:
    OMNIDIR_CONTACTS = [0, 16, 17, 7, 8, 18, 19, 15]
    EFFECT_THRESHOLD = 0.99

    def __init__(self, df: pd.DataFrame, leadID: float, psm: PSM):

        # Set property
        self.psm = psm
        self.leadID = leadID

        # Filter table to keep only the correct lead
        df_contacts = df.loc[df.leadID == leadID]
        df_contacts.reset_index(drop=True, inplace=True)
        df_levels = df_contacts.loc[df_contacts.contactID.isin(self.OMNIDIR_CONTACTS)]
        df_levels.reset_index(drop=True, inplace=True)

        # Keep the dataframes as property for best level/contact selection
        self.df_contacts = df_contacts
        self.df_levels = df_levels

        # Get the rows of ALL minima
        row_best_contact = df_contacts.loc[df_contacts.amplitude == np.min(df_contacts.amplitude)]
        row_best_level = df_levels.loc[df_levels.amplitude == np.min(df_levels.amplitude)]

        # Get the contact ID
        gt_best_contact = row_best_contact['contactID'].to_numpy()
        gt_best_level = row_best_level['contactID'].to_numpy()

        print('            GT (level) =', gt_best_level)

        # Define ground truth as being the smallest effect threshold of the lead
        self.groundtruth = {'bestContact':  gt_best_contact,
                            'bestLevel':    gt_best_level}

    @classmethod
    def set_effect_threshold(cls, threshold):
        cls.EFFECT_THRESHOLD = threshold

    def find_best_level(self):
        # Init variable
        all_therapeutic_windows = []

        for index, level in self.df_levels.iterrows():

            print('   Processing level #', level.contactID)
            therapeutic_window = self.find_therapeutic_window(level)
            all_therapeutic_windows.append(therapeutic_window)

            # Feedback to user
            print('      Therapeutic window: ', therapeutic_window)

        return self.df_levels['contactID'][np.argmin(all_therapeutic_windows)]


    def find_best_contact(self, psm):

        pass # To be implemented


    def find_therapeutic_window(self, level):

        # Get the center coordinates in a single array
        center_mni = np.array([level.contactCoord_1, level.contactCoord_2, level.contactCoord_3])
        center_mni = np.expand_dims(center_mni, axis=1)

        # Transform the center in the voxel space
        center = homogeneous_transform(center_mni, np.linalg.inv(self.psm.transform), isvector=0)

        # Create synthetic VTA object
        stim = SyntheticVTA(self.psm, center)

        # Init empty list to keep track of scores
        scores = []
        amplitudes = []

        while not stim.next():
            # Get the synthetic VTA
            vta_image, vta_amp = stim.get_vta()

            # Compute estimation
            estimated_score = self.psm.estimate(vta_image, vta_amp)

            # Feedback to user
            print('   - (', vta_amp, ' mA ) Relative improvement = ', estimated_score)

            # Check if clinical improvement is higher than defined effect threshold
            if estimated_score > self.EFFECT_THRESHOLD:
                return vta_amp

        # Return none if effect threshold was not reached
        return np.inf



# Functions

In [9]:
def homogeneous_transform(coord, transform, isvector=0):
    # Apply transform to the points array p0
    # [p1 1]^T = M x [p0 1]^T

    if isvector:
        coord = np.concatenate((coord, np.zeros((1, coord.shape[1]))), axis=0)
    else:
        coord = np.concatenate((coord, np.ones((1, coord.shape[1]))), axis=0)

    transformed_coord = transform @ coord

    return transformed_coord[0:3, :].T

In [10]:
def remove_empty_leads(df):
    # For best level selection, some leads may contains only directional data and make the pipeline fail. Filter those lead avoid any issue.
    omnidir_contact = [0, 16, 17, 7, 8, 18, 19, 15]

    for leadID in np.unique(df.leadID):
        df_contacts = df.loc[df.leadID == leadID]
        df_contacts.reset_index(drop=True, inplace=True)
        df_levels = df_contacts.loc[df_contacts.contactID.isin(omnidir_contact)]
        df_levels.reset_index(drop=True, inplace=True)

        if len(df_levels) == 0:
            df = df.drop(np.where(df.leadID == leadID)[0])
            df.reset_index(drop=True, inplace=True)

    return df

# Load data

In [11]:
df = pd.read_csv('../../03_Data/01_Tables/bernTableElectrodeLoc.csv')

# Filter table if needed (in this case: only right hemisphere leads are used)
df = df.loc[np.logical_not(np.mod(2*df.leadID, 2).astype(bool))]
df.reset_index(drop=True, inplace=True)

# Load map
psm = LogRegPSM('../../03_Data/06_Maps/map_v2')

# Test leads

In [12]:
def get_correct_ratio(effect_thres):

    correct = 0
    total = 0


    df_level = remove_empty_leads(df)
    Lead.set_effect_threshold(effect_thres)

    for leadID in np.unique(df_level.leadID):

        print('Processing lead #', leadID)
        lead = Lead(df, leadID, psm)
        best_level_pred = lead.find_best_level()
        print('         Best level:', best_level_pred)

        if best_level_pred in lead.groundtruth['bestLevel']:
            print('         (CORRECT)')
            correct += 1

        else:
            print('         (FALSE) --> GT = ', lead.groundtruth['bestLevel'])

        total += 1

    return correct/total

In [13]:
effect_thresholds = np.linspace(0, 1, 10)
ratios = []
for effect_threshold in effect_thresholds:
    ratios.append(get_correct_ratio(effect_threshold))

plt.figure()

plt.plot(effect_thresholds, ratios)
plt.xlabel("effect threshold")
plt.ylabel("ratio of correct guess")
plt.show()

Processing lead # 1.0
            GT (level) = [7]
   Processing level # 0


AttributeError: 'SyntheticVTA' object has no attribute 'center'