<a href="https://colab.research.google.com/github/aykuteken/aykuteken/blob/main/Pain_decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
pip install shap

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting shap
  Downloading shap-0.41.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (575 kB)
[K     |████████████████████████████████| 575 kB 31.6 MB/s 
Collecting slicer==0.0.7
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.41.0 slicer-0.0.7


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 18 15:12:58 2022

@author: aykuteken
"""
##  Pain Decoding using fNIRS and Deep Longitudinal Transfer Learning (Eken, 2022)
## Dataset that was previously collected by 
# Peng K, Yücel MA, Steele SC, Bittner EA, Aasted CM, 
# Hoeft MA, Lee A, George EE, Boas DA, Becerra L and Borsook D (2018) 
# Morphine Attenuates fNIRS Signal Associated With Painful Stimuli in the Medial 
# Frontopolar Cortex (medial BA 10). Front. Hum. Neurosci. 12:394. doi: 10.3389/fnhum.2018.00394

import pandas as pd
import scipy as sp
from scipy import io
from sklearn.model_selection import train_test_split
import json
from os import environ
environ["KERAS_BACKEND"] = "plaidml.keras.backend"
import numpy as np
from numpy import random
import os 
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense,Flatten, Conv1D, MaxPooling1D, Dropout
from keras import Sequential
import matplotlib.pyplot as plt
import shap

# Large, bold plotting defaults so the exported figures stay readable.
plt.rcParams.update({
    'lines.linewidth': 5,
    'lines.linestyle': '-',
    'figure.figsize': [32, 24],
    'axes.labelsize': 25,
    'axes.labelweight': 'bold',
    'font.size': 25,
    'font.weight': 'bold',
    'figure.dpi': 100,
    "legend.loc": 'right',
})

# Seed NumPy's global RNG (`random` here is `numpy.random`), which the
# augmentation code below draws from.
random.seed(42)
# First, import the variables.
# `[field][0][0]` unwraps the 1x1 struct arrays that scipy.io.loadmat returns
# for MATLAB structs (matlab_compatible=True keeps them un-squeezed).

FVs = io.loadmat('/content/Feature_fNIRS_TF.mat',mat_dtype=True,matlab_compatible=True)

Morphine_HbO_pre_vas3=FVs['Morphine_HbO_pre']['vas3'][0][0]
Morphine_HbO_pre_vas7=FVs['Morphine_HbO_pre']['vas7'][0][0]
Morphine_Hb_pre_vas3=FVs['Morphine_Hb_pre']['vas3'][0][0]
Morphine_Hb_pre_vas7=FVs['Morphine_Hb_pre']['vas7'][0][0]
Morphine_pre_subj_vas3=FVs['Morphine_pre']['vas3_subj'][0][0]
Morphine_pre_subj_vas7=FVs['Morphine_pre']['vas7_subj'][0][0]


Placebo_HbO_pre_vas3=FVs['Placebo_HbO_pre']['vas3'][0][0]
Placebo_HbO_pre_vas7=FVs['Placebo_HbO_pre']['vas7'][0][0]
Placebo_Hb_pre_vas3=FVs['Placebo_Hb_pre']['vas3'][0][0]
Placebo_Hb_pre_vas7=FVs['Placebo_Hb_pre']['vas7'][0][0]
Placebo_pre_subj_vas3=FVs['Placebo_pre']['vas3_subj'][0][0]
Placebo_pre_subj_vas7=FVs['Placebo_pre']['vas7_subj'][0][0]

Morphine_HbO_post30_vas3=FVs['Morphine_HbO_post30']['vas3'][0][0]
Morphine_HbO_post30_vas7=FVs['Morphine_HbO_post30']['vas7'][0][0]
Morphine_Hb_post30_vas3=FVs['Morphine_Hb_post30']['vas3'][0][0]
Morphine_Hb_post30_vas7=FVs['Morphine_Hb_post30']['vas7'][0][0]
Morphine_post30_subj_vas3=FVs['Morphine_post30']['vas3_subj'][0][0]
Morphine_post30_subj_vas7=FVs['Morphine_post30']['vas7_subj'][0][0]

Placebo_HbO_post30_vas3=FVs['Placebo_HbO_post30']['vas3'][0][0]
Placebo_HbO_post30_vas7=FVs['Placebo_HbO_post30']['vas7'][0][0]
Placebo_Hb_post30_vas3=FVs['Placebo_Hb_post30']['vas3'][0][0]
Placebo_Hb_post30_vas7=FVs['Placebo_Hb_post30']['vas7'][0][0]
Placebo_post30_subj_vas3=FVs['Placebo_post30']['vas3_subj'][0][0]
Placebo_post30_subj_vas7=FVs['Placebo_post30']['vas7_subj'][0][0]

Morphine_HbO_post60_vas3=FVs['Morphine_HbO_post60']['vas3'][0][0]
Morphine_HbO_post60_vas7=FVs['Morphine_HbO_post60']['vas7'][0][0]
Morphine_Hb_post60_vas3=FVs['Morphine_Hb_post60']['vas3'][0][0]
Morphine_Hb_post60_vas7=FVs['Morphine_Hb_post60']['vas7'][0][0]
Morphine_post60_subj_vas3=FVs['Morphine_post60']['vas3_subj'][0][0]
Morphine_post60_subj_vas7=FVs['Morphine_post60']['vas7_subj'][0][0]

Placebo_HbO_post60_vas3=FVs['Placebo_HbO_post60']['vas3'][0][0]
Placebo_HbO_post60_vas7=FVs['Placebo_HbO_post60']['vas7'][0][0]
Placebo_Hb_post60_vas3=FVs['Placebo_Hb_post60']['vas3'][0][0]
Placebo_Hb_post60_vas7=FVs['Placebo_Hb_post60']['vas7'][0][0]
Placebo_post60_subj_vas3=FVs['Placebo_post60']['vas3_subj'][0][0]
Placebo_post60_subj_vas7=FVs['Placebo_post60']['vas7_subj'][0][0]

Morphine_HbO_post90_vas3=FVs['Morphine_HbO_post90']['vas3'][0][0]
Morphine_HbO_post90_vas7=FVs['Morphine_HbO_post90']['vas7'][0][0]
Morphine_Hb_post90_vas3=FVs['Morphine_Hb_post90']['vas3'][0][0]
Morphine_Hb_post90_vas7=FVs['Morphine_Hb_post90']['vas7'][0][0]
Morphine_post90_subj_vas3=FVs['Morphine_post90']['vas3_subj'][0][0]
Morphine_post90_subj_vas7=FVs['Morphine_post90']['vas7_subj'][0][0]

Placebo_HbO_post90_vas3=FVs['Placebo_HbO_post90']['vas3'][0][0]
Placebo_HbO_post90_vas7=FVs['Placebo_HbO_post90']['vas7'][0][0]
Placebo_Hb_post90_vas3=FVs['Placebo_Hb_post90']['vas3'][0][0]
Placebo_Hb_post90_vas7=FVs['Placebo_Hb_post90']['vas7'][0][0]
Placebo_post90_subj_vas3=FVs['Placebo_post90']['vas3_subj'][0][0]
Placebo_post90_subj_vas7=FVs['Placebo_post90']['vas7_subj'][0][0]



## We will follow two main paths.

## First, we will try to classify the painful condition in each session separately.
## For instance, model knowledge obtained from pre-Drug data will be used for each post-Drug
## session separately (30 min, 60 min, 90 min).

## Then, we will try to classify the painful condition in sessions by following the post-Drug
## trajectory. Model knowledge obtained from pre-Drug data will be used to classify data in the
## 30 min post session, model knowledge obtained from the 30 min post session will be used to
## classify the 60 min post session, and model knowledge obtained from the 60 min post session
## will be used to classify the 90 min post session.
## Because there is not enough data, data augmentation will be performed.

#### ------ DATA AUGMENTATION -----######



def data_split_aug(pre_data, stim_type, fold_type, norm, n_aug=20):
    """Split trials into train/test/validation sets and augment the training set.

    Parameters
    ----------
    pre_data : ndarray, shape (n_trials, A, B)
        Trial features. Every returned set has its last two axes swapped, so
        trials come back as (B, A). NOTE(review): the models in this notebook
        expect (31, 24) trials, so pre_data is presumably (n_trials, 24, 31);
        confirm against the .mat file.
    stim_type : array-like, shape (n_trials,)
        Binary labels (0 = vas3, 1 = vas7 in this notebook).
    fold_type : {'holdout', 'kfold'}
        'holdout' -> stratified 60/20/20 train/validation/test split.
        'kfold'   -> all trials are used for training; the test and
                     validation sets are returned empty.
    norm : int
        When 1, each trial is z-scored along its last axis before splitting.
        The caller's array is not modified.
    n_aug : int, default 20
        Number of augmented copies generated per training trial.

    Returns
    -------
    aug_train_data, test_data, val_data, aug_train_label, test_label, val_label
        The augmented training set holds only the n_aug * n_train perturbed
        copies (noise or linear trend), not the original trials. Its labels
        are ints.

    Raises
    ------
    ValueError
        If fold_type is neither 'holdout' nor 'kfold'.
    """
    from scipy import stats  # ensure the scipy.stats submodule is loaded

    pre_data = np.asarray(pre_data)
    stim_type = np.asarray(stim_type)

    # Z-score normalization of every trial along its last axis. This is
    # vectorised and returns a new array instead of writing into the input.
    if norm == 1:
        pre_data = stats.zscore(pre_data, axis=2)

    if fold_type == 'holdout':
        # 20 % test, then 25 % of the remaining 80 % (20 % overall) for validation.
        train_data, test_data, train_label, test_label = train_test_split(
            pre_data, stim_type, test_size=0.2, random_state=None,
            shuffle=True, stratify=stim_type)
        train_data, val_data, train_label, val_label = train_test_split(
            train_data, train_label, test_size=0.25, random_state=None,
            shuffle=True, stratify=train_label)
    elif fold_type == 'kfold':
        # No held-out data; empty arrays keep the trial shape so the axis
        # swap below works.
        train_data, train_label = pre_data, stim_type
        test_data, val_data = pre_data[:0], pre_data[:0]
        test_label, val_label = stim_type[:0], stim_type[:0]
    else:
        raise ValueError("fold_type must be 'holdout' or 'kfold', got %r" % (fold_type,))

    # Swap the last two axes: (n, A, B) -> (n, B, A). This must be a transpose.
    # reshape would only reinterpret the memory layout and scramble the features.
    test_data = np.transpose(test_data, (0, 2, 1))
    train_data = np.transpose(train_data, (0, 2, 1))
    val_data = np.transpose(val_data, (0, 2, 1))

    s = np.shape(train_data)

    # Augmentation options: 0 = Gaussian noise, 1 = "spike" (currently
    # disabled), 2 = linear trend.
    opt = [0, 2]
    sigmas = [0.05, 0.1, 0.5]
    slopes = [0.01, 0.05, 0.1, 0.2]
    # Ramp 0..s[2]-1 along the last axis, identical for every row.
    ramp = np.tile(np.arange(0, s[2]), (s[1], 1))

    aug_data = []
    aug_labels = []
    for _ in range(n_aug):
        for i in range(len(train_label)):
            sample = train_data[i, :, :]
            sel = random.choice(opt)

            if sel == 0:
                ss = random.choice(sigmas)
                noise = np.random.normal(0, ss, (1, s[1], s[2]))
                new_created_data = sample + noise
            elif sel == 1:
                # Replace one randomly chosen column with +/- 1.5 * its std.
                # This works on a copy so later rounds still see the
                # unmodified trial.
                new_created_data = sample.copy()
                std_data = np.std(sample, axis=0) * 1.5
                sel_ind = random.choice(np.arange(0, s[2]))
                sel_direct = random.choice([0, 1])
                sign = -1 if sel_direct == 0 else 1
                new_created_data[:, sel_ind] = sign * std_data[sel_ind]
            else:  # sel == 2
                new_created_data = sample + random.choice(slopes) * ramp

            aug_data.append(np.squeeze(new_created_data))
            aug_labels.append(train_label[i])

    aug_train_data = np.stack(aug_data, axis=0)
    aug_train_label = np.array(np.stack(aug_labels, axis=0), dtype=int)

    return aug_train_data, test_data, val_data, aug_train_label, test_label, val_label


def Pre_Drug_Model(train_data,val_data,train_label, val_label):
    """Train the 1D-CNN pain classifier on pre-drug trials.

    Parameters
    ----------
    train_data, val_data : ndarray, shape (n_trials, timesteps, channels)
        Training and validation inputs. The network's input shape is taken
        from ``train_data`` (31 x 24 for the data in this notebook).
    train_label, val_label : array-like
        Binary labels (0 = vas3, 1 = vas7); reshaped to float32 columns.

    Returns
    -------
    model : keras Sequential
        The network after the final epoch. NOTE(review): the lowest-val_loss
        checkpoint written to "best_pre_model.h5" is never reloaded here, so
        this is not necessarily the best model seen during training.
    history : keras History
        Per-epoch 'loss' / 'val_loss' / metric values.
    """
    train_label = np.asarray(train_label).astype('float32').reshape((-1, 1))
    val_label = np.asarray(val_label).astype('float32').reshape((-1, 1))

    callbacks = [
        # Save the model whenever val_loss improves.
        keras.callbacks.ModelCheckpoint(
            "best_pre_model.h5", save_best_only=True, monitor="val_loss"
        ),
        # Divide the learning rate by 10 after 10 epochs without improvement.
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.1, patience=10, min_lr=0.00001
        ),
    ]
    drp = 0.2  # dropout rate in front of the output layer
    # Derive (timesteps, channels) from the data instead of hard-coding (31, 24).
    input_shape = tuple(np.shape(train_data)[1:])

    # Three Conv1D/MaxPooling1D stages followed by a dense classification head.
    # PostDrug_HoldOutModel strips the last four layers (the head).
    model = Sequential()
    model.add(Conv1D(32, 3, activation='relu', input_shape=input_shape))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(64, 3, activation='relu'))
    model.add(MaxPooling1D(2))
    model.add(Conv1D(128, 3, activation='relu'))
    model.add(MaxPooling1D(2))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(drp))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss="binary_crossentropy",
                  metrics=["binary_accuracy"])
    history = model.fit(train_data, train_label, epochs=50, batch_size=32,
                        callbacks=callbacks, validation_data=(val_data, val_label),
                        verbose=0, shuffle=True)

    return model, history

def PostDrug_HoldOutModel(model,train_data,val_data, train_label, val_label,data_type, exp_time):
    """Train a new classification head on top of a frozen pre-trained network.

    The last four layers of ``model`` are dropped. The remaining layers are
    reused, frozen and in inference mode, as a feature extractor under a fresh
    Flatten -> Dense(128, sigmoid) -> Dropout(0.5) -> Dense(1, sigmoid) head.
    NOTE(review): assumes ``model`` has the layout built by Pre_Drug_Model,
    whose last four layers are its Flatten/Dense/Dropout/Dense head.

    Parameters
    ----------
    model : keras model
        Pre-trained network. It is set to ``trainable = False`` in place, so
        the caller's model stays frozen after this call. Its layers are shared
        with the returned model.
    train_data, val_data : ndarray, shape (n_trials, timesteps, channels)
        Inputs for the new head. The input shape is taken from ``train_data``.
    train_label, val_label : array-like
        Binary labels; reshaped to float32 columns.
    data_type, exp_time : str
        Concatenated into the saved file name
        ``'Post_Drug_HoldOut_' + data_type + exp_time + '.h5'``.

    Returns
    -------
    model : keras Model
        The transfer model after the final epoch; also saved to disk.
        NOTE(review): the "best_post_model.h5" checkpoint is shared by every
        call and is never reloaded.
    history : keras History
    """
    train_label = np.asarray(train_label).astype('float32').reshape((-1, 1))
    val_label = np.asarray(val_label).astype('float32').reshape((-1, 1))

    # Freeze the pre-trained network so that only the new head is trained.
    model.trainable = False
    print('Before TF')
    model.summary()  # summary() prints by itself and returns None
    base = tf.keras.models.Sequential(model.layers[:-4])

    inputs = keras.Input(shape=tuple(np.shape(train_data)[1:]))
    x = base(inputs, training=False)  # keep the frozen layers in inference mode
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(128, activation='sigmoid')(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)
    model = keras.Model(inputs, outputs)

    print('After TF')
    model.summary()

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            "best_post_model.h5", save_best_only=True, monitor="val_loss"
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.1, patience=20, min_lr=0.00001
        ),
    ]
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss="binary_crossentropy",
                  metrics=["binary_accuracy"])

    history = model.fit(train_data, train_label, epochs=200, batch_size=32,
                        callbacks=callbacks, validation_data=(val_data, val_label),
                        verbose=0, shuffle=True)

    fname = 'Post_Drug_HoldOut_' + data_type + exp_time + '.h5'
    model.save(fname)

    return model, history

# Number of repeated split/train/evaluate iterations.
n = 30

# Per-model accumulators: training loss curves, validation loss curves and
# test accuracies, one entry appended per iteration.
pre_hist_tr, pre_hist_val, pre_acc = [], [], []
post30_morph_hist_tr, post30_morph_hist_val, post30_morph_acc = [], [], []
post60_morph_hist_tr, post60_morph_hist_val, post60_morph_acc = [], [], []
post90_morph_hist_tr, post90_morph_hist_val, post90_morph_acc = [], [], []
post30_place_hist_tr, post30_place_hist_val, post30_place_acc = [], [], []
post60_place_hist_tr, post60_place_hist_val, post60_place_acc = [], [], []
post90_place_hist_tr, post90_place_hist_val, post90_place_acc = [], [], []

# Per-model SHAP results: SHAP values, test labels, fitted explainers and
# the test inputs that were explained.
list_shap_values_pre, list_test_sets_pre = [], []
explainer_pre, shap_test_data_pre = [], []

list_shap_values_post30morphine, list_test_sets_post30morphine = [], []
explainer_post30morphine, shap_test_data_post30morphine = [], []

list_shap_values_post30placebo, list_test_sets_post30placebo = [], []
explainer_post30placebo, shap_test_data_post30placebo = [], []

list_shap_values_post60morphine, list_test_sets_post60morphine = [], []
explainer_post60morphine, shap_test_data_post60morphine = [], []

list_shap_values_post60placebo, list_test_sets_post60placebo = [], []
explainer_post60placebo, shap_test_data_post60placebo = [], []

list_shap_values_post90morphine, list_test_sets_post90morphine = [], []
explainer_post90morphine, shap_test_data_post90morphine = [], []

list_shap_values_post90placebo, list_test_sets_post90placebo = [], []
explainer_post90placebo, shap_test_data_post90placebo = [], []






for i in range(0, n):

    print('##############------------ITERATION '+str(i) +' ---------------##############')

    ## ---- Pre-drug model: morphine and placebo pre sessions pooled ----
    pre_data = np.concatenate([Morphine_HbO_pre_vas3, Placebo_HbO_pre_vas3,
                               Morphine_HbO_pre_vas7, Placebo_HbO_pre_vas7])
    ## vas3 labeled as 0, vas7 labeled as 1
    stim_type = np.concatenate([np.zeros(np.shape(Morphine_HbO_pre_vas3)[0]),
                                np.zeros(np.shape(Placebo_HbO_pre_vas3)[0]),
                                np.ones(np.shape(Morphine_HbO_pre_vas7)[0]),
                                np.ones(np.shape(Placebo_HbO_pre_vas7)[0])])

    (aug_train_pre_data, test_pre_data, val_pre_data,
     aug_train_pre_label, test_pre_label, val_pre_label) = data_split_aug(
        pre_data, stim_type, 'holdout', 1)

    ##-------- Model Development-------- ####
    Pre_Model, pre_model_history = Pre_Drug_Model(
        aug_train_pre_data, val_pre_data, aug_train_pre_label, val_pre_label)

    pre_hist_tr.append(pre_model_history.history['loss'])
    pre_hist_val.append(pre_model_history.history['val_loss'])

    print('\n')
    print('######----Pre Drug Model------########--')
    test_pre_label = np.asarray(test_pre_label).astype('float32').reshape((-1, 1))
    # evaluate() returns [loss, binary_accuracy]. The whole pair is stored and
    # the accuracy column is selected when summarising.
    pre_model_Acc = Pre_Model.evaluate(test_pre_data, test_pre_label)
    pre_acc.append(pre_model_Acc)

    ## SHAP ##
    explainer = shap.DeepExplainer(Pre_Model, aug_train_pre_data)
    shap_values = explainer.shap_values(test_pre_data)
    list_shap_values_pre.append(shap_values)
    list_test_sets_pre.append(test_pre_label)
    explainer_pre.append(explainer)
    shap_test_data_pre.append(test_pre_data)

    print('\n')
    ## Every post-drug session below transfers the frozen convolutional
    ## feature extractor of Pre_Model and trains a new head on that session.

    ## ---- Post 30 min ----
    post30_data_morphine = np.concatenate([Morphine_HbO_post30_vas3, Morphine_HbO_post30_vas7])
    stim_type_morphine = np.concatenate([np.zeros(np.shape(Morphine_HbO_post30_vas3)[0]),
                                         np.ones(np.shape(Morphine_HbO_post30_vas7)[0])])

    post30_data_placebo = np.concatenate([Placebo_HbO_post30_vas3, Placebo_HbO_post30_vas7])
    stim_type_placebo = np.concatenate([np.zeros(np.shape(Placebo_HbO_post30_vas3)[0]),
                                        np.ones(np.shape(Placebo_HbO_post30_vas7)[0])])

    (aug_train_morphine_post30_data, test_morphine_post30_data, val_morphine_post30_data,
     aug_train_morphine_post30_label, test_morphine_post30_label,
     val_morphine_post30_label) = data_split_aug(post30_data_morphine, stim_type_morphine, 'holdout', 1)

    (aug_train_placebo_post30_data, test_placebo_post30_data, val_placebo_post30_data,
     aug_train_placebo_post30_label, test_placebo_post30_label,
     val_placebo_post30_label) = data_split_aug(post30_data_placebo, stim_type_placebo, 'holdout', 1)

    post30_Model_Serial_Morphine, post30_History_Serial_Morphine = PostDrug_HoldOutModel(
        Pre_Model, aug_train_morphine_post30_data, val_morphine_post30_data,
        aug_train_morphine_post30_label, val_morphine_post30_label, 'Morphine', '30')
    post30_Model_Serial_Placebo, post30_History_Serial_Placebo = PostDrug_HoldOutModel(
        Pre_Model, aug_train_placebo_post30_data, val_placebo_post30_data,
        aug_train_placebo_post30_label, val_placebo_post30_label, 'Placebo', '30')

    post30_morph_hist_tr.append(post30_History_Serial_Morphine.history['loss'])
    post30_morph_hist_val.append(post30_History_Serial_Morphine.history['val_loss'])
    post30_place_hist_tr.append(post30_History_Serial_Placebo.history['loss'])
    post30_place_hist_val.append(post30_History_Serial_Placebo.history['val_loss'])

    print('\n')
    print('######----Morphine Post 30 Min TF Network------########')
    test_morphine_post30_label = np.asarray(test_morphine_post30_label).astype('float32').reshape((-1, 1))
    post30_Serial_Morph_Acc = post30_Model_Serial_Morphine.evaluate(
        test_morphine_post30_data, test_morphine_post30_label)[1]
    print('\n')
    print('######----Placebo Post 30 Min TF Network------########')
    test_placebo_post30_label = np.asarray(test_placebo_post30_label).astype('float32').reshape((-1, 1))
    post30_Serial_Placebo_Acc = post30_Model_Serial_Placebo.evaluate(
        test_placebo_post30_data, test_placebo_post30_label)[1]
    print('\n')

    post30_morph_acc.append(post30_Serial_Morph_Acc)
    post30_place_acc.append(post30_Serial_Placebo_Acc)

    ## SHAP post 30 morphine ##
    explainer = shap.DeepExplainer(post30_Model_Serial_Morphine, aug_train_morphine_post30_data)
    shap_values = explainer.shap_values(test_morphine_post30_data)
    list_shap_values_post30morphine.append(shap_values)
    list_test_sets_post30morphine.append(test_morphine_post30_label)
    explainer_post30morphine.append(explainer)
    shap_test_data_post30morphine.append(test_morphine_post30_data)

    ## SHAP post 30 placebo ##
    explainer = shap.DeepExplainer(post30_Model_Serial_Placebo, aug_train_placebo_post30_data)
    shap_values = explainer.shap_values(test_placebo_post30_data)
    list_shap_values_post30placebo.append(shap_values)
    list_test_sets_post30placebo.append(test_placebo_post30_label)
    explainer_post30placebo.append(explainer)
    shap_test_data_post30placebo.append(test_placebo_post30_data)

    ## ---- Post 60 min (also transferred from Pre_Model) ----
    post60_data_morphine = np.concatenate([Morphine_HbO_post60_vas3, Morphine_HbO_post60_vas7])
    stim_type_morphine = np.concatenate([np.zeros(np.shape(Morphine_HbO_post60_vas3)[0]),
                                         np.ones(np.shape(Morphine_HbO_post60_vas7)[0])])

    post60_data_placebo = np.concatenate([Placebo_HbO_post60_vas3, Placebo_HbO_post60_vas7])
    stim_type_placebo = np.concatenate([np.zeros(np.shape(Placebo_HbO_post60_vas3)[0]),
                                        np.ones(np.shape(Placebo_HbO_post60_vas7)[0])])

    (aug_train_morphine_post60_data, test_morphine_post60_data, val_morphine_post60_data,
     aug_train_morphine_post60_label, test_morphine_post60_label,
     val_morphine_post60_label) = data_split_aug(post60_data_morphine, stim_type_morphine, 'holdout', 1)

    (aug_train_placebo_post60_data, test_placebo_post60_data, val_placebo_post60_data,
     aug_train_placebo_post60_label, test_placebo_post60_label,
     val_placebo_post60_label) = data_split_aug(post60_data_placebo, stim_type_placebo, 'holdout', 1)

    post60_Model_Serial_Morphine, post60_History_Serial_Morphine = PostDrug_HoldOutModel(
        Pre_Model, aug_train_morphine_post60_data, val_morphine_post60_data,
        aug_train_morphine_post60_label, val_morphine_post60_label, 'Morphine', '60')
    post60_Model_Serial_Placebo, post60_History_Serial_Placebo = PostDrug_HoldOutModel(
        Pre_Model, aug_train_placebo_post60_data, val_placebo_post60_data,
        aug_train_placebo_post60_label, val_placebo_post60_label, 'Placebo', '60')

    post60_morph_hist_tr.append(post60_History_Serial_Morphine.history['loss'])
    post60_morph_hist_val.append(post60_History_Serial_Morphine.history['val_loss'])
    post60_place_hist_tr.append(post60_History_Serial_Placebo.history['loss'])
    post60_place_hist_val.append(post60_History_Serial_Placebo.history['val_loss'])

    print('\n')
    print('######----Morphine Post 60 Min TF Network------########')
    test_morphine_post60_label = np.asarray(test_morphine_post60_label).astype('float32').reshape((-1, 1))
    post60_Serial_Morph_Acc = post60_Model_Serial_Morphine.evaluate(
        test_morphine_post60_data, test_morphine_post60_label)[1]
    print('\n')
    print('######----Placebo Post 60 Min TF Network------########')
    test_placebo_post60_label = np.asarray(test_placebo_post60_label).astype('float32').reshape((-1, 1))
    post60_Serial_Placebo_Acc = post60_Model_Serial_Placebo.evaluate(
        test_placebo_post60_data, test_placebo_post60_label)[1]
    print('\n')

    post60_morph_acc.append(post60_Serial_Morph_Acc)
    post60_place_acc.append(post60_Serial_Placebo_Acc)

    ## SHAP post 60 morphine ##
    explainer = shap.DeepExplainer(post60_Model_Serial_Morphine, aug_train_morphine_post60_data)
    shap_values = explainer.shap_values(test_morphine_post60_data)
    list_shap_values_post60morphine.append(shap_values)
    list_test_sets_post60morphine.append(test_morphine_post60_label)
    explainer_post60morphine.append(explainer)
    shap_test_data_post60morphine.append(test_morphine_post60_data)

    ## SHAP post 60 placebo ##
    explainer = shap.DeepExplainer(post60_Model_Serial_Placebo, aug_train_placebo_post60_data)
    shap_values = explainer.shap_values(test_placebo_post60_data)
    list_shap_values_post60placebo.append(shap_values)
    list_test_sets_post60placebo.append(test_placebo_post60_label)
    explainer_post60placebo.append(explainer)
    shap_test_data_post60placebo.append(test_placebo_post60_data)

    ## ---- Post 90 min (also transferred from Pre_Model) ----
    post90_data_morphine = np.concatenate([Morphine_HbO_post90_vas3, Morphine_HbO_post90_vas7])
    stim_type_morphine = np.concatenate([np.zeros(np.shape(Morphine_HbO_post90_vas3)[0]),
                                         np.ones(np.shape(Morphine_HbO_post90_vas7)[0])])

    post90_data_placebo = np.concatenate([Placebo_HbO_post90_vas3, Placebo_HbO_post90_vas7])
    stim_type_placebo = np.concatenate([np.zeros(np.shape(Placebo_HbO_post90_vas3)[0]),
                                        np.ones(np.shape(Placebo_HbO_post90_vas7)[0])])

    (aug_train_morphine_post90_data, test_morphine_post90_data, val_morphine_post90_data,
     aug_train_morphine_post90_label, test_morphine_post90_label,
     val_morphine_post90_label) = data_split_aug(post90_data_morphine, stim_type_morphine, 'holdout', 1)

    (aug_train_placebo_post90_data, test_placebo_post90_data, val_placebo_post90_data,
     aug_train_placebo_post90_label, test_placebo_post90_label,
     val_placebo_post90_label) = data_split_aug(post90_data_placebo, stim_type_placebo, 'holdout', 1)

    post90_Model_Serial_Morphine, post90_History_Serial_Morphine = PostDrug_HoldOutModel(
        Pre_Model, aug_train_morphine_post90_data, val_morphine_post90_data,
        aug_train_morphine_post90_label, val_morphine_post90_label, 'Morphine', '90')
    post90_Model_Serial_Placebo, post90_History_Serial_Placebo = PostDrug_HoldOutModel(
        Pre_Model, aug_train_placebo_post90_data, val_placebo_post90_data,
        aug_train_placebo_post90_label, val_placebo_post90_label, 'Placebo', '90')

    # Record the 90-min models' own loss curves.
    post90_morph_hist_tr.append(post90_History_Serial_Morphine.history['loss'])
    post90_morph_hist_val.append(post90_History_Serial_Morphine.history['val_loss'])
    post90_place_hist_tr.append(post90_History_Serial_Placebo.history['loss'])
    post90_place_hist_val.append(post90_History_Serial_Placebo.history['val_loss'])

    print('\n')
    print('######----Morphine Post 90 Min TF Network------########')
    test_morphine_post90_label = np.asarray(test_morphine_post90_label).astype('float32').reshape((-1, 1))
    post90_Serial_Morph_Acc = post90_Model_Serial_Morphine.evaluate(
        test_morphine_post90_data, test_morphine_post90_label)[1]
    print('\n')
    print('######----Placebo Post 90 Min TF Network------########')
    test_placebo_post90_label = np.asarray(test_placebo_post90_label).astype('float32').reshape((-1, 1))
    post90_Serial_Placebo_Acc = post90_Model_Serial_Placebo.evaluate(
        test_placebo_post90_data, test_placebo_post90_label)[1]
    print('\n')

    post90_morph_acc.append(post90_Serial_Morph_Acc)
    post90_place_acc.append(post90_Serial_Placebo_Acc)

    ## SHAP post 90 morphine ##
    explainer = shap.DeepExplainer(post90_Model_Serial_Morphine, aug_train_morphine_post90_data)
    shap_values = explainer.shap_values(test_morphine_post90_data)
    list_shap_values_post90morphine.append(shap_values)
    list_test_sets_post90morphine.append(test_morphine_post90_label)
    explainer_post90morphine.append(explainer)
    shap_test_data_post90morphine.append(test_morphine_post90_data)

    ## SHAP post 90 placebo ##
    explainer = shap.DeepExplainer(post90_Model_Serial_Placebo, aug_train_placebo_post90_data)
    shap_values = explainer.shap_values(test_placebo_post90_data)
    list_shap_values_post90placebo.append(shap_values)
    list_test_sets_post90placebo.append(test_placebo_post90_label)
    explainer_post90placebo.append(explainer)
    shap_test_data_post90placebo.append(test_placebo_post90_data)

def _mean_std(histories):
    """Return the epoch-wise mean and std of the stacked per-iteration loss curves.

    np.vstack requires every iteration's curve to have the same length.
    """
    stacked = np.vstack(histories)
    return np.mean(stacked, axis=0), np.std(stacked, axis=0)


def _plot_loss_band(title, mean_tr, std_tr, mean_val, std_val):
    """Plot mean train/validation loss with a +/-1 std band and return the epoch axis."""
    epochs = np.arange(0, len(mean_tr))
    plt.title(title)
    plt.plot(epochs, mean_tr, label='train')
    plt.fill_between(epochs, mean_tr - std_tr, mean_tr + std_tr, alpha=.2)
    plt.plot(epochs, mean_val, label='validation')
    plt.fill_between(epochs, mean_val - std_val, mean_val + std_val, alpha=.2)
    plt.xlabel('# of Epochs')
    plt.ylabel('Binary Cross Entropy Loss')
    plt.legend()
    plt.show()
    return epochs


#### Pre ####
pre_mean_tr, pre_std_tr = _mean_std(pre_hist_tr)
pre_mean_val, pre_std_val = _mean_std(pre_hist_val)
epc = _plot_loss_band('Pre', pre_mean_tr, pre_std_tr, pre_mean_val, pre_std_val)

#### Post 30 Placebo ####
post30_placebo_mean_tr, post30_placebo_std_tr = _mean_std(post30_place_hist_tr)
post30_placebo_mean_val, post30_placebo_std_val = _mean_std(post30_place_hist_val)
epc = _plot_loss_band('Post 30 Placebo',
                      post30_placebo_mean_tr, post30_placebo_std_tr,
                      post30_placebo_mean_val, post30_placebo_std_val)

#### Post 30 Morphine ####
post30_morphine_mean_tr, post30_morphine_std_tr = _mean_std(post30_morph_hist_tr)
post30_morphine_mean_val, post30_morphine_std_val = _mean_std(post30_morph_hist_val)
epc = _plot_loss_band('Post 30 Morphine',
                      post30_morphine_mean_tr, post30_morphine_std_tr,
                      post30_morphine_mean_val, post30_morphine_std_val)

#### Post 60 Placebo ####
post60_placebo_mean_tr, post60_placebo_std_tr = _mean_std(post60_place_hist_tr)
post60_placebo_mean_val, post60_placebo_std_val = _mean_std(post60_place_hist_val)
epc = _plot_loss_band('Post 60 Placebo',
                      post60_placebo_mean_tr, post60_placebo_std_tr,
                      post60_placebo_mean_val, post60_placebo_std_val)

#### Post 60 Morphine ####
post60_morphine_mean_tr, post60_morphine_std_tr = _mean_std(post60_morph_hist_tr)
post60_morphine_mean_val, post60_morphine_std_val = _mean_std(post60_morph_hist_val)
epc = _plot_loss_band('Post 60 Morphine',
                      post60_morphine_mean_tr, post60_morphine_std_tr,
                      post60_morphine_mean_val, post60_morphine_std_val)

#### Post 90 Placebo ####
post90_placebo_mean_tr, post90_placebo_std_tr = _mean_std(post90_place_hist_tr)
post90_placebo_mean_val, post90_placebo_std_val = _mean_std(post90_place_hist_val)
epc = _plot_loss_band('Post 90 Placebo',
                      post90_placebo_mean_tr, post90_placebo_std_tr,
                      post90_placebo_mean_val, post90_placebo_std_val)

#### Post 90 Morphine ####
post90_morphine_mean_tr, post90_morphine_std_tr = _mean_std(post90_morph_hist_tr)
post90_morphine_mean_val, post90_morphine_std_val = _mean_std(post90_morph_hist_val)
epc = _plot_loss_band('Post 90 Morphine',
                      post90_morphine_mean_tr, post90_morphine_std_tr,
                      post90_morphine_mean_val, post90_morphine_std_val)

# Mean +/- std of the test accuracy over all iterations, for each model.
# pre_acc rows are the [loss, accuracy] pairs returned by evaluate(). The other
# lists already hold accuracies only. Every line reports the std of the same
# series as its mean.
_acc_summary = [
    ('Mean Pre Model Acc:', np.array(pre_acc)[:, 1]),
    ('Mean Post 30 Model Morphine Acc:', post30_morph_acc),
    ('Mean Post 30 Model Placebo Acc:', post30_place_acc),
    ('Mean Post 60 Model Morphine Acc:', post60_morph_acc),
    ('Mean Post 60 Model Placebo Acc:', post60_place_acc),
    ('Mean Post 90 Model Morphine Acc:', post90_morph_acc),
    ('Mean Post 90 Model Placebo Acc:', post90_place_acc),
]
for _title, _accs in _acc_summary:
    print(_title + str(np.mean(_accs)) + '+/-' + str(np.std(_accs)))

def set_axis_style(ax, labels):
    """Replace the x ticks of ``ax`` with one category label per position.

    Ticks are placed at x = 1..len(labels) and labelled in order. The x-limits
    leave a 0.75 margin beyond the first and last tick, ticks point outward
    and are drawn only on the bottom spine, and the x-axis label is cleared.

    Args:
        ax: matplotlib Axes to restyle in place.
        labels: sequence of tick label strings.
    """
    n_labels = len(labels)
    x_axis = ax.xaxis
    x_axis.set_tick_params(direction='out')
    x_axis.set_ticks_position('bottom')
    ax.set_xticks(np.arange(1, n_labels + 1))
    ax.set_xticklabels(labels)
    ax.set_xlim(0.25, n_labels + 0.75)
    ax.set_xlabel('')

def adjacent_values(vals, q1, q3):
    """Return the (lower, upper) whisker ends for a box/violin plot.

    Each whisker extends 1.5 * (q3 - q1) beyond its quartile. It is then
    clipped so it never passes the data extremes and never falls inside the
    box: lower lies in [vals[0], q1], upper lies in [q3, vals[-1]].

    Args:
        vals: sequence whose first and last elements are taken as the data
            extremes (assumes it is sorted ascending).
        q1: first quartile.
        q3: third quartile.

    Returns:
        Tuple ``(lower_adjacent_value, upper_adjacent_value)``.
    """
    reach = (q3 - q1) * 1.5
    lower = np.clip(q1 - reach, vals[0], q1)
    upper = np.clip(q3 + reach, q3, vals[-1])
    return lower, upper
 
fig, axs = plt.subplots(1,1,figsize=(60, 20),sharex=False, sharey=True,constrained_layout=False)
# One category label per violin; the order must match `data` below.
names = [
    "Pre-Drug",
    "Post Drug \n Morphine \n 30 min",
    "Post Drug \n Morphine \n 60 min",
    "Post Drug \n Morphine \n  90 min",
    "Post Drug \n Placebo \n  30 min",
    "Post Drug \n Placebo \n  60 min",
    "Post Drug \n Placebo \n  90 min",
    ]
k=0
fs=60
fs2=60  # font size for the axis label and tick labels
lw=5    # line width for violin outlines and summary lines

# Accuracy samples per condition, in the same order as `names`.
# NOTE(review): column 1 of pre_acc is assumed to hold the accuracy — confirm.
data = [
    np.array(pre_acc)[:, 1],
    post30_morph_acc,
    post60_morph_acc,
    post90_morph_acc,
    post30_place_acc,
    post60_place_acc,
    post90_place_acc,
]

pp=axs.violinplot(data,showmeans=True,showextrema=True,showmedians=True,vert=True, widths=0.5)
for pc in pp['bodies']:
    pc.set_facecolor('red')
    pc.set_edgecolor('red')
    pc.set_linewidth(lw)
# Thick black lines for the mean marker, extrema caps and central bar.
for part in ('cmeans', 'cmaxes', 'cmins', 'cbars'):
    pp[part].set_color('k')
    pp[part].set_linewidth(lw)

axs.set_ylabel('Accuracy',fontsize=fs2,fontweight='bold')
axs.set_title('Transfer Learning Results',fontsize=60,fontweight='bold')
# Install the category ticks first, then style the tick labels so the font
# settings land on the final label objects. (No set_xlabel(names): passing
# the list there only stringified it into the x label, which set_axis_style
# clears anyway.)
set_axis_style(axs, names)
plt.setp(axs.get_xticklabels(), fontsize=fs2, fontweight="bold")
plt.setp(axs.get_yticklabels(), fontsize=fs2, fontweight="bold")
plt.show()

# Column names "Channel 1" .. "Channel 24". They label the feature axis of
# the SHAP summary plots below.
features = ['Channel ' + str(channel) for channel in range(1, 25)]

# ## Averaging shap values for Pre-Model

# Per run i, compute two summaries:
#  - the mean |SHAP| over axis 0 of element [0] of that run's SHAP values
#    (presumably the single model output);
#  - the mean over axis 0 of that run's SHAP test inputs.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    shap_av_val_pre.append(np.abs(list_shap_values_pre[i][0]).mean(axis=0))
    shap_av_test_data_pre.append(shap_test_data_pre[i].mean(axis=0))

# Average the per-run summaries over the n runs. Wrap the inputs in a
# DataFrame so the plot is labelled with the channel names.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)
av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, and
# with show=True it displays the figure itself. So render with show=False,
# add the title via matplotlib, then show.
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Pre Model')
plt.show()
        
        
# ## Averaging shap values for Post30-Morphine-Model

# The *_pre names are reused here; they hold the Post 30 Morphine values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post30morphine[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_morphine_post30_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()

av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Morphine 30 min Model')
plt.show()


# ## Averaging shap values for Post60-Morphine-Model

# The *_pre names are reused here; they hold the Post 60 Morphine values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post60morphine[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_morphine_post60_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()

av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Morphine 60 min Model')
plt.show()

# ## Averaging shap values for Post90-Morphine-Model

# The *_pre names are reused here; they hold the Post 90 Morphine values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post90morphine[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_morphine_post90_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()

av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Morphine 90 min Model')
plt.show()

# ## Averaging shap values for Post30-Placebo-Model

# The *_pre names are reused here; they hold the Post 30 Placebo values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post30placebo[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_placebo_post30_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()

av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Placebo 30 min Model')
plt.show()


# ## Averaging shap values for Post60-Placebo-Model

# The *_pre names are reused here; they hold the Post 60 Placebo values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post60placebo[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_placebo_post60_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()

av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
# The title previously read "Placebo 30 min Model" (copy-paste slip).
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Placebo 60 min Model')
plt.show()

# ## Averaging shap values for Post90-Placebo-Model

# The *_pre names are reused here; they hold the Post 90 Placebo values.
shap_av_val_pre=[]
shap_av_test_data_pre=[]

for i in range(0,n):
    # Mean |SHAP| over axis 0 of element [0] of run i's SHAP values.
    shap_av_val_pre.append(np.abs(list_shap_values_post90placebo[i][0]).mean(axis=0))
    # NOTE(review): this is not indexed by run i, so the same test array is
    # averaged n times — confirm that is intended. For plot_type='bar'
    # presumably only its column names matter.
    shap_av_test_data_pre.append(test_placebo_post90_data.mean(axis=0))

# Average the per-run summaries over the n runs.
av_shap_values=np.mean(shap_av_val_pre,axis=0)
av_shap_test_data=np.mean(shap_av_test_data_pre,axis=0)

shap.initjs()
av_shap_test_data = pd.DataFrame(data=av_shap_test_data, columns = features)
# NOTE: shap.summary_plot does not appear to draw its `title` argument, so
# render with show=False, add the title via matplotlib, then show.
# The title previously read "Placebo 30 min Model" (copy-paste slip).
shap.summary_plot(av_shap_values, features=av_shap_test_data, plot_type='bar', max_display=len(features), show=False)
plt.title('Placebo 90 min Model')
plt.show()



##############------------ITERATION 0 ---------------##############


######----Pre Drug Model------########--


keras is no longer supported, please use tf.keras instead.
Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.




Before TF
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv1d (Conv1D)             (None, 29, 32)            2336      
                                                                 
 max_pooling1d (MaxPooling1D  (None, 14, 32)           0         
 )                                                               
                                                                 
 conv1d_1 (Conv1D)           (None, 12, 64)            6208      
                                                                 
 max_pooling1d_1 (MaxPooling  (None, 6, 64)            0         
 1D)                                                             
                                                                 
 conv1d_2 (Conv1D)           (None, 4, 128)            24704     
                                                                 
 max_pooling1d_2 (MaxPooling  (None, 2, 128)

Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
