#### Brain Network Model fitting using RNN
     Zheng Wang
     
     Two types of models: a linear model and the Wong-Wang-Deco model
     
     

In [9]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import scipy.io

In [10]:
class DecoRNN():
    """Wong-Wang-Deco excitatory/inhibitory mean-field model with a BOLD readout.

    Every node carries six state variables, in the column order given by
    ``variables_name``: ``E`` (excitatory population), ``I`` (inhibitory
    population) and the hemodynamic variables ``x``, ``f``, ``v``, ``q``.
    :meth:`dfun` returns the time derivative of the whole state matrix.
    """

    # Template with the initial values of all model parameters.  Each instance
    # works on its own copy (see __init__), so replacing entries on one
    # instance -- e.g. with TensorFlow tensors -- never leaks into the class
    # template or into other instances.
    param = {

        # Parameters for the integration
        "ROI_num": 96,          # number of neural nodes
        "sigma": 0.02,          # standard deviation of the Gaussian noise

        # Parameters for the ODEs
        # Excitatory population
        "WE": 1.,               # scale of the external input
        "tauE": 100.,           # decay time
        "gamma_E": 0.641/1000., # other dynamic parameter (?)

        # Inhibitory population
        "WI": 0.7,              # scale of the external input
        "tauI": 10.,            # decay time
        "gamma_I": 1./1000.,    # other dynamic parameter (?)

        # External input
        "I0": 0.5,  # 0.32,     # external input
        "Ie": 0.,               # external stimulation

        # Coupling parameters
        "g": 100.,              # global coupling (from all nodes E_j to single node E_i)
        "gEE": .1,              # local self excitatory feedback (from E_i to E_i)
        "gIE": .1,              # local inhibitory coupling (from I_i to E_i)
        "gEI": 0.1,             # local excitatory coupling (from E_i to I_i)
        "lamb": 0.,             # scale of global coupling on I_i (compared to E_i)

        # Transfer-function parameters (a, b, d) of each population
        "aE": 310,
        "bE": 125,
        "dE": 0.16,
        "aI": 615,
        "bI": 177,
        "dI": 0.087,

        # Output (BOLD signal)
        "alpha": 0.32,
        "rho": 0.34,
        "k1": 2.38,
        "k2": 2.0,
        "k3": 0.48,  # adjust this number from 0.48 for BOLD to fluctuate around zero
        "V": .02,
        "E0": 0.34,
        "tau_s": 0.65,
        "tau_f": 0.41,
        "tau_0": 0.98

    }  # initial values of all model parameters

    def __init__(self, num_nodes):
        """Create a model with ``num_nodes`` coupled nodes.

        Args:
            num_nodes: number of nodes (ROIs); stored as ``param['ROI_num']``.
        """
        # Per-instance copy of the class template, so that mutating
        # ``self.param`` does not modify ``DecoRNN.param``.
        self.param = dict(type(self).param)
        self.param['ROI_num'] = num_nodes
        # Coupling matrix used by dfun; all zeros until a caller replaces it.
        self.L = tf.zeros((num_nodes, num_nodes))
        self.num_states = 6
        self.num_states_noise = 2
        self.variables_name = ['E', 'I', 'x', 'f', 'v', 'q']

    @staticmethod
    def _h_tf(a, b, d, z):
        """Rectified transfer function |a z - b| / |1 - exp(-d (a z - b))|.

        The small constants in numerator and denominator keep the ratio
        finite at ``a*z == b`` (it tends to ``1/d`` there).
        """
        return (0.00001+tf.abs(a*z-b))/(0.00001*d+tf.abs(1.0000 -tf.exp(-d*(a*z-b))))

    def dfun(self, X):
        """Return dX/dt for the state matrix ``X``.

        Args:
            X: tensor with one row per node and six columns ordered as
                ``variables_name`` (E, I, x, f, v, q).

        Returns:
            Tensor of the same shape holding the time derivative of each
            state variable.
        """
        p = self.param
        E = X[:, 0:1]
        I = X[:, 1:2]
        x = X[:, 2:3]
        f = X[:, 3:4]
        v = X[:, 4:5]
        q = X[:, 5:6]

        # Rectified input currents; only the excitatory current receives the
        # global (network) coupling through self.L.
        IE = tf.nn.relu(p['WE']*p['I0'] + p['gEE']*E + p['g']*tf.matmul(self.L, E)
                        - p['gIE']*(I) + p['Ie'])
        II = tf.nn.relu(p['WI']*p['I0'] + p['gEI']*E - I)

        # Population dynamics: exponential decay plus transfer-function drive.
        dE = -E/p['tauE'] + (1.0 - E)*p['gamma_E']*self._h_tf(p['aE'], p['bE'], p['dE'], IE)
        dI = -I/p['tauI'] + p['gamma_I']*self._h_tf(p['aI'], p['bI'], p['dI'], II)

        # Hemodynamic (balloon) equations driven by E - I.
        dx = E - I - 1.0/p['tau_s']*x - 1.0/p['tau_f']*(f - 1.0)
        df = x
        dv = (f - v**(1./p['alpha']))/p['tau_0']
        dq = (f*(1.-(1.-p['rho'])**(1./f))/p['rho'] - q*(v)**(1./p['alpha'])/v)/p['tau_0']

        return tf.concat([dE, dI, dx, df, dv, dq], axis=1)
    
    
    
    

In [11]:
class Cost_fun():
    """Losses comparing functional connectivity (FC) of simulated and empirical series.

    ``logits_series`` (simulated) and ``labels_series`` (empirical) are lists
    with one tensor per time step.  Stacking them along axis 1 gives matrices
    with ``batch_size`` rows (one per node) and one column per time step.  FC
    is the Pearson correlation matrix between rows; only its strict upper
    triangle enters the losses.
    """

    def __init__(self, logits_series, labels_series, batch_size):
        """Store the two series after checking that they line up.

        Args:
            logits_series: simulated values, one tensor per time step.
            labels_series: empirical values, one tensor per time step.
            batch_size: number of nodes (rows after stacking).

        Raises:
            ValueError: if the two lists have different lengths.
        """
        if len(logits_series) != len(labels_series):
            raise ValueError(
                'logits_series and labels_series differ in length: %d vs %d'
                % (len(logits_series), len(labels_series)))
        # Check every pair, not just the first one.  ``not a == b`` is kept on
        # purpose: in TF1, TensorShape.__ne__ can raise for unknown-rank shapes.
        for logit, label in zip(logits_series, labels_series):
            if not logit.shape == label.shape:
                print('not matching')
                break

        self.logits_series = logits_series
        self.labels_series = labels_series
        self.batch_size = batch_size

    def _fc_matrix(self, series):
        """Return the Pearson correlation matrix between the rows of ``series``.

        The number of time steps is taken from ``series`` itself, so no
        external sequence length is needed.
        """
        # Remove each row's temporal mean (broadcast over the columns).
        centered = series - tf.reshape(tf.reduce_mean(series, 1), [self.batch_size, 1])
        cov = tf.matmul(centered, tf.transpose(centered))
        # D^-1/2 * cov * D^-1/2 with D = diag(cov) turns covariance into correlation.
        inv_std = tf.diag(tf.reciprocal(tf.sqrt(tf.diag_part(cov))))
        return tf.matmul(tf.matmul(inv_std, cov), inv_std)

    def _upper_triangles(self):
        """Return ``(FC_tri_v, FC_sim_tri_v)``: strict upper-triangle FC entries.

        The first element comes from the empirical series, the second from the
        simulated series; both are 1-D and use the same element order.
        """
        FC_T = self._fc_matrix(tf.stack(self.labels_series, axis=1))
        FC_sim_T = self._fc_matrix(tf.stack(self.logits_series, axis=1))
        # Ones strictly above the diagonal, zeros elsewhere -> boolean mask.
        ones_tri = tf.matrix_band_part(
            tf.ones_like(FC_T) - tf.diag(tf.ones((self.batch_size,))), 0, -1)
        mask = tf.greater(ones_tri, tf.zeros_like(FC_T))
        return tf.boolean_mask(FC_T, mask), tf.boolean_mask(FC_sim_T, mask)

    def cost_r(self):
        """Return the RMS difference between simulated and empirical FC entries.

        NOTE(review): despite the name, this is not a correlation-based loss;
        it compares the raw upper-triangle FC values.
        """
        FC_tri_v, FC_sim_tri_v = self._upper_triangles()
        diff = FC_sim_tri_v - FC_tri_v
        return tf.sqrt(tf.reduce_mean(tf.multiply(diff, diff)))

    def cost_dist(self):
        """Return the RMS difference between mean-removed FC entries.

        Each upper-triangle FC vector has its own mean subtracted first, so a
        constant offset between simulated and empirical FC is not penalized.
        """
        FC_tri_v, FC_sim_tri_v = self._upper_triangles()
        FC_v = FC_tri_v - tf.reduce_mean(FC_tri_v)
        FC_sim_v = FC_sim_tri_v - tf.reduce_mean(FC_sim_tri_v)
        diff = FC_sim_v - FC_v
        return tf.sqrt(tf.reduce_mean(tf.multiply(diff, diff)))

In [16]:
class Model_graph():
    """Unroll a neural mass model into a TensorFlow 1.x graph and set up fitting.

    For each of ``truncated_backprop_length`` output samples, the model ``f``
    is advanced ``hidden_num`` Euler steps of size ``step_size``.  After every
    step, scaled noise columns taken from ``batchX_placeholder`` are added to
    the first two state variables (E and I).  A BOLD-like readout of the
    ``v`` and ``q`` states (columns 4 and 5) is compared with
    ``batchY_placeholder`` through ``Cost_fun.cost_dist`` and minimized with
    Adam.

    Building the graph calls ``tf.reset_default_graph()`` and replaces
    ``f.L`` and the ``f.param`` entries listed in ``fit_param`` with tensors.
    """

    # Names of the f.param entries that are fitted (turned into tf.Variables).
    fit_param = ['g', 'gEE', 'gIE', 'gEI', 'sigma']

    @staticmethod
    def _smooth_normalize_ct(x, center):
        """Softly squash ``x`` into (0.0001, 2*center - 0.0001) with a scaled tanh around ``center``."""
        return center + (center - 0.0001)*tf.tanh((x - center)/(center - 0.0001))

    def _euler_step(self, state, noises):
        """Advance ``state`` by one noisy Euler step and apply the state constraints.

        Args:
            state: state tensor with one row per node and ``num_states`` columns.
            noises: list of noise tensors; ``noises[0]`` and ``noises[1]`` are
                added to the E and I columns.

        Returns:
            The new state tensor (same shape as ``state``).
        """
        state = state + self.dt*self.f.dfun(state)
        cols = tf.unstack(state, axis=1)
        noise_scale = tf.sqrt(self.dt)*self.f.param['sigma']
        # E, I and x are bounded with tanh; f, v, q are kept positive around 1.
        cols[0] = tf.tanh(cols[0] + noise_scale*noises[0])
        cols[1] = tf.tanh(cols[1] + noise_scale*noises[1])
        cols[2] = tf.tanh(cols[2])
        for k in (3, 4, 5):
            cols[k] = self._smooth_normalize_ct(cols[k], 1.0)
        return tf.stack(cols, axis=1)

    def __init__(self, step_size, Tr, truncated_backprop_length, f,  fit_conecctiongains,
                 learning_rate=.01):
        """Build the unrolled graph, the loss and the training op.

        Args:
            step_size: Euler integration step.
            Tr: interval between output samples (same unit as ``step_size``).
            truncated_backprop_length: number of output samples per unrolled graph.
            f: model object exposing ``param``, ``dfun``, ``num_states``,
                ``num_states_noise`` and ``L`` (e.g. ``DecoRNN``).
            fit_conecctiongains: if truthy, also fit the symmetric
                connection-gain matrix ``Ws`` applied to the adjacency.
            learning_rate: Adam learning rate (default 0.01).

        Raises:
            ValueError: if ``Tr`` is smaller than ``step_size``.
        """
        tf.reset_default_graph()

        self.fit_cg = fit_conecctiongains
        self.dt = step_size
        self.truncated_backprop_length = truncated_backprop_length
        self.Tr = Tr
        # Euler sub-steps per output sample.  The tiny tolerance protects the
        # truncation against float round-off (e.g. 0.7 / 0.1 == 6.999999999999999).
        self.hidden_num = int(Tr/step_size + 1e-9)
        if self.hidden_num < 1:
            raise ValueError('Tr (%r) must be at least step_size (%r)' % (Tr, step_size))
        self.num_states = f.num_states
        self.num_states_noise = f.num_states_noise
        self.batch_size = f.param['ROI_num']
        # Noise columns for every sub-step and noise channel, plus 1 + T extra
        # columns (the readout noise is taken from the T columns before the last).
        self.num_inputs = 1+self.truncated_backprop_length*(1+self.num_states_noise*self.hidden_num)

        self.adjacency = tf.placeholder(tf.float32, [self.batch_size, self.batch_size])
        self.batchX_placeholder = tf.placeholder(tf.float32, [self.batch_size, self.num_inputs])
        self.batchY_placeholder = tf.placeholder(tf.float32, [self.batch_size, self.truncated_backprop_length])

        # Laplacian-style coupling: off-diagonal adjacency minus the row sums.
        f.L = -tf.diag(tf.reduce_sum(self.adjacency, axis=1)) + self.adjacency
        self.f = f

        # Initial values, index-aligned with fit_param.
        self.init_paras = [80, 0.5, 0.5, 1.0, 0.02]

        variables_ls = []
        for idx, name in enumerate(self.fit_param):
            var = tf.Variable(self.init_paras[idx], dtype=tf.float32)
            self.f.param[name] = var
            variables_ls.append(var)
        self.variables_ls = tuple(variables_ls)
        self.Ws = tf.Variable(0.05+np.zeros((self.f.param['ROI_num'], self.f.param['ROI_num'])),
                              dtype=tf.float32)

        if self.fit_cg:
            # Symmetrized, exponentiated gains masked by the adjacency, then
            # normalized to unit Frobenius norm.
            W_n = (self.Ws + tf.transpose(self.Ws))/2.0
            W_n = tf.exp(W_n)*self.adjacency
            W_s = W_n/tf.norm(W_n)
            self.f.L = -tf.diag(tf.reduce_sum(W_s, axis=1)) + W_s

        # Keep fitted parameters strictly positive (sigma has a larger floor).
        for para in self.fit_param:
            if para != 'sigma':
                self.f.param[para] = 0.001+tf.nn.relu(self.f.param[para])
            else:
                self.f.param[para] = 0.01+tf.nn.relu(self.f.param[para])

        self.init_state = tf.placeholder(tf.float32, [self.batch_size, self.num_states])

        # Unpack columns
        inputs_series = tf.unstack(self.batchX_placeholder, axis=1)
        labels_series = tf.unstack(self.batchY_placeholder, axis=1)
        print(len(inputs_series))

        T = self.truncated_backprop_length
        H = self.hidden_num
        current_state = self.init_state
        states_series = []
        for j in range(T):
            for i in range(H):
                # Column of noise channel s for output sample j, sub-step i.
                noises = [inputs_series[s*T*H + j*H + i]
                          for s in range(self.num_states_noise)]
                current_state = self._euler_step(current_state, noises)
            states_series.append(current_state)

        # Readout (measurement) noise: the T columns before the last one.
        noise2 = inputs_series[-T-1:-1]

        p = self.f.param
        logits_series = [50/0.34*0.02*(p['k1']*(1.0 - state[:, 5])
                                       + p['k2']*(1.0 - state[:, 5]/state[:, 4])
                                       + p['k3']*(1.0 - state[:, 4]))
                         + 0.02*noise
                         for state, noise in zip(states_series, noise2)]

        print(logits_series[0].shape)
        print(labels_series[0].shape)
        self.logits_series = logits_series
        self.states_series = states_series
        self.current_state = current_state
        self.cost = Cost_fun(logits_series, labels_series, self.f.param['ROI_num'])

        total_loss = self.cost.cost_dist()
        self.train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
        self.total_loss = total_loss
    
       
   
        

In [17]:

import argparse
import time
import os
from six.moves import cPickle


"""parser = argparse.ArgumentParser(
                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Data and model checkpoints directories
parser.add_argument('data_dir=', type=str, default='/brunhild/mcintosh_lab/jwang/results/PPMI/',
                    help='data directory containing all subjects')
parser.add_argument('save_dir=', type=str, default='save',
                    help='directory to store checkpointed models')
parser.add_argument('log_dir=', type=str, default='logs',
                    help='directory to store tensorboard logs')
parser.add_argument('save_every=', type=int, default=1000,
                    help='Save frequency. Number of passes between checkpoints of the model.')
parser.add_argument('init_from=', type=str, default=None,
                    help=continue training from saved model at this path (usually "save").
                        Path must contain files saved by previous training process:
                        'config.pkl'        : configuration;
                        
                        'checkpoint'        : paths to model file(s) (created by tf).
                                              Note: this file contains absolute paths, be careful when moving files around;
                        'model.ckpt-*'      : file(s) with model definition (created by tf)
                         Model params must be the same between multiple runs (model, rnn_size, num_layers and seq_length))

# Optimization
parser.add_argument('truncated_backprop_length=', type=int, default=15,
                    help='RNN sequence length. Number of timesteps to unroll for.')
parser.add_argument('batch_size=', type=int, default=96,
                    help=node number)

parser.add_argument('Tr=', type=int, default=.05,
                    help='Tr fmri')
parser.add_argument('step_size=', type=int, default=.001,
                    help='Tr fmri')

parser.add_argument('step_size=', type=int, default=.001,
                    help='Integration step')
args = parser.parse_args()

import tensorflow as tf
"""







class Brainmodelfitting( ):
    """Train a ``Model_graph`` on one subject's data and report the FC fit."""

    def __init__(self, model_graph, data_set, echo_step, num_epochs, out_dir, data_dir=None):
        """Store the model and the training configuration.

        Args:
            model_graph: a built ``Model_graph``.
            data_set: dataset label (stored only).
            echo_step: number of leading samples of TS to skip.
            num_epochs: number of passes over the subject's time series.
            out_dir: prefix for all output files (checkpoint, text results).
            data_dir: optional data directory (stored only).  Previously this
                attribute was read from a module-level global.
        """
        self.model = model_graph
        self.echo_step = echo_step
        self.Tr = model_graph.Tr
        self.step_size = model_graph.dt
        self.truncated_backprop_length = model_graph.truncated_backprop_length
        self.batch_size = model_graph.batch_size
        self.data_dir = data_dir
        self.out_dir = out_dir
        self.num_epochs = num_epochs
        self.data_set = data_set

    def plot(self, batch_idx, y_array, E_array, I_array, x_array, v_array, f_array, q_array, params_list, loss_list, batchX, batchY):
        """Refresh the 4x3 diagnostic figure (loss, parameters, simulated signals).

        ``batch_idx``, ``batchX`` and ``batchY`` are accepted for interface
        compatibility but not plotted.
        """
        params = np.array(params_list)
        # (subplot position, data) in the original drawing order.
        panels = [
            (1, loss_list),
            (2, params[:, 0]),
            (4, params[:, 5:]),
            (3, y_array.T),
            (5, E_array.T),
            (6, I_array.T),
            (7, x_array.T),
            (8, v_array.T),
            (9, f_array.T),
            (10, q_array.T),
            (11, params[:, 1:2]),
            (12, params[:, 3]),
        ]
        for position, data in panels:
            plt.subplot(4, 3, position)
            plt.cla()
            plt.plot(data)

        plt.draw()
        plt.pause(0.0001)

    def train(self, subID, SC, TS):
        """Fit the model to one subject and return the simulated-vs-empirical FC correlation.

        Args:
            subID: subject identifier, used as a prefix of the output file names.
            SC: adjacency matrix fed to ``model.adjacency``.
            TS: empirical series with samples along axis 0 and nodes along
                axis 1 (``TS.T`` must have ``batch_size`` rows).

        Returns:
            Pearson correlation between the strict lower triangles of the
            simulated FC (first 10 samples dropped) and the empirical FC.

        Raises:
            ValueError: if TS is too short for one batch or ``num_epochs < 1``.

        Side effects: writes a checkpoint, ``paramsList.txt`` and
        ``sim_fitting_bold.txt`` with prefix ``out_dir + subID`` and shows
        matplotlib figures.
        """
        T = self.truncated_backprop_length
        TS_len = TS.shape[0] - self.echo_step
        num_batches = TS_len // T
        if num_batches < 1:
            raise ValueError('TS has %d usable samples, fewer than one batch of %d'
                             % (TS_len, T))
        if self.num_epochs < 1:
            raise ValueError('num_epochs must be at least 1, got %r' % (self.num_epochs,))
        n_cols = num_batches*T
        # Targets as (nodes, samples), shifted by echo_step and scaled by max(TS).
        y_in = TS.T[:, self.echo_step:self.echo_step + TS_len]/np.max(TS)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables())

            plt.ion()
            plt.figure()
            plt.show()

            loss_list = []
            params_list = []
            fetches = [self.model.variables_ls, self.model.Ws, self.model.current_state,
                       self.model.total_loss, self.model.train_step,
                       self.model.logits_series, self.model.states_series]

            for _ in range(self.num_epochs):
                shape = (self.batch_size, n_cols)
                y_array = np.zeros(shape)
                E_array = np.zeros(shape)
                I_array = np.zeros(shape)
                x_array = np.zeros(shape)
                v_array = np.zeros(shape)
                f_array = np.zeros(shape)
                q_array = np.zeros(shape)
                # Random initial state; f, v, q are offset towards their resting value.
                initi_con = .45*(np.random.uniform(0, 1, [self.batch_size, self.model.num_states])
                                 + np.array([[0, 0, 0, .5, .5, .5]]))
                _current_state = initi_con

                for batch_idx in range(num_batches):
                    start_idx = batch_idx*T
                    end_idx = start_idx + T

                    batchX = np.random.randn(self.batch_size, self.model.num_inputs)
                    batchY = y_in[:, start_idx:end_idx]

                    # The final state of this batch seeds the next one.
                    (_variables_s, _Wg, _current_state, _total_loss, _,
                     _logits_series, _states_series) = sess.run(
                        fetches,
                        feed_dict={
                            self.model.batchX_placeholder: batchX,
                            self.model.batchY_placeholder: batchY,
                            self.model.init_state: _current_state,
                            self.model.adjacency: SC
                        })

                    loss_list.append(_total_loss)
                    params_new = list(_variables_s)
                    if self.model.fit_cg:
                        params_new.extend(_Wg.ravel())
                    params_list.append(np.array(params_new))

                    cols = slice(start_idx, end_idx)
                    y_array[:, cols] = np.stack(_logits_series, axis=1)
                    # (nodes, num_states, T): one state matrix per output sample.
                    states = np.stack(_states_series, axis=2)
                    E_array[:, cols] = states[:, 0, :]
                    I_array[:, cols] = states[:, 1, :]
                    x_array[:, cols] = states[:, 2, :]
                    f_array[:, cols] = states[:, 3, :]
                    v_array[:, cols] = states[:, 4, :]
                    q_array[:, cols] = states[:, 5, :]

            saver.save(sess, self.out_dir+subID+'model.checkpoint')

            self.plot(batch_idx, y_array, E_array, I_array, x_array, v_array, f_array, q_array,
                      params_list, loss_list, batchX, batchY)
            plt.ioff()
            plt.show()
            np.savetxt(self.out_dir+subID+'paramsList.txt', np.array(params_list))
            np.savetxt(self.out_dir+subID+'sim_fitting_bold.txt', y_array.T)
            # Drop the first 10 simulated samples (initial transient) before FC.
            FC_sim = np.corrcoef(y_array[:, 10:])

            FC = np.corrcoef(TS.T)
            lower = np.tril_indices(self.batch_size, -1)
            corr_simfit = np.corrcoef(FC[lower], FC_sim[lower])[0, 1]
            print(corr_simfit)

            fig, ax = plt.subplots(1, 3, figsize=(20, 4))
            ax[0].plot(TS)
            img1 = ax[1].imshow(FC_sim - np.diag(np.diag(FC_sim)), cmap='bwr')
            plt.colorbar(img1, ax=ax[1], fraction=0.046, pad=0.04)
            img2 = ax[2].imshow(FC, cmap='bwr')
            plt.colorbar(img2, ax=ax[2], fraction=0.046, pad=0.04)
            plt.show()

            return corr_simfit
        
        
                       
    
    
    
   
                
   
    

In [18]:
# Input data root, output directory (used as a file-name prefix) and the
# dataset label used to build result paths and keys.
data_dir, out_dir, data_set = (
    '/brunhild/mcintosh_lab/jwang/ModelFitting/TestData/',
    '/brunhild/mcintosh_lab/jwang/ModelFitting/test/',
    'TestData',
)
#model_name='Linear'


In [19]:
# Subject groups found under data_dir (narrowed by a later cell).
groups = list(os.listdir(data_dir))

# Model / integration configuration.
batch_size = 96                  # number of nodes (ROIs)
step_size = 0.05                 # Euler integration step
Tr = 2.5                         # interval between fitted output samples
truncated_backprop_length = 15   # output samples per unrolled graph

num_epochs = 300
echo_step = 0

# The node count is defined once (batch_size) instead of repeating 96.
f = DecoRNN(batch_size)

#graph_lin = Model_graph(step_size, Tr, truncated_backprop_length, f, paras_lin, cost_name)
graph_deco = Model_graph(step_size, Tr, truncated_backprop_length, f, True)
#F = Brainmodelfitting(graph_lin, data_set,echo_step, num_epochs, out_dir)
F = Brainmodelfitting(graph_deco, data_set, echo_step, num_epochs, out_dir)


1516
(96,)
(96,)


In [20]:
groups = ["CON_MCI"]  # restrict the fitting run to this single subject group

In [None]:
# Fit the model to every digit-named subject directory of each group that
# provides both an SC and a TS file.  para_corr maps data_set + subject ID to
# the list of collected fit results (currently the FC fit correlation).
para_corr = {}
# Loop-invariant output directory; makedirs also creates missing parents.
file_path = '/brunhild/mcintosh_lab/jwang/ModelFitting/' + data_set + '/'
if not os.path.exists(file_path):
    os.makedirs(file_path)

for grp in groups:
    print(grp)
    grp_dir = data_dir + grp + '/'
    subs = [sub for sub in os.listdir(grp_dir) if sub.isdigit()]
    for i, subID in enumerate(subs):
        print(i, subID)
        SC_file = grp_dir + subID + '/preprocess/connectivity/SC/SC.txt'
        TS_file = grp_dir + subID + '/preprocess/connectivity/FC/TS.txt'
        if not (os.path.isfile(SC_file) and os.path.isfile(TS_file)):
            continue

        key = data_set + subID
        para_corr[key] = []
        SC = np.loadtxt(SC_file)
        TS = np.loadtxt(TS_file)
        # NOTE(review): if TS is (samples, ROIs), as F.train's indexing
        # suggests, this subtracts the mean across ROIs at each time point
        # (a global-signal removal), not each ROI's temporal mean -- confirm
        # this is intended.
        TS_dmean = (TS.T - TS.T.mean(axis=0)).T
        SC = (SC + SC.T)*0.5
        # Disabled experiment: keep only SC entries more than 2 std above the
        # row mean, per hemisphere block (split at batch_size//2).
        # SC1=SC[:batch_size//2,:batch_size//2].copy()
        # SC2=SC[batch_size//2:batch_size,batch_size//2:batch_size].copy()
        # SC3=SC[:batch_size//2,batch_size//2:batch_size].copy()
        # mask1 = (SC1-SC1.mean(axis=1)< 2.*SC1.std(axis=1))
        # SC1[mask1]=0
        # SC2[(SC2-SC2.mean(axis=1) <  2*SC2.std(axis=1)) ]=0
        # SC3[(SC3-SC3.mean(axis=1)<  2*SC3.std(axis=1)) ]=0
        # SC[:batch_size//2,:batch_size//2] = SC1
        # SC[batch_size//2:batch_size,batch_size//2:batch_size] = SC2
        # SC[:batch_size//2,batch_size//2:batch_size] = 1*SC3
        # SC[batch_size//2:batch_size,:batch_size//2] = 1*SC3.T

        # Log-compressed SC, normalized to unit Frobenius norm.
        log_SC = np.log1p(SC)
        Wo = log_SC/np.linalg.norm(log_SC)

        corr_fit = F.train(subID, Wo, TS_dmean)
        para_corr[key].append(corr_fit)
        # Disabled follow-up steps (F.test is not defined in the visible code):
        # corr_sim =F.test(subID, Wo, TS_dmean)
        # para_corr[data_set+subID].append(corr_sim)
        # params = np.loadtxt(out_dir+subID+'paramsList.txt')
        # Theta = list(params[-10:,:].mean(axis = 0))
        # para_corr[data_set+subID].extend(Theta)
                

CON_MCI
0 4032


<Figure size 432x288 with 0 Axes>

In [12]:
# Notebook inspection cell: half of batch_size.
batch_size//2

48

In [10]:
# Notebook inspection cell: the DecoRNN model instance.
f


<__main__.DecoRNN at 0x7f128b5de8d0>

In [13]:
# Notebook inspection cell: once Model_graph has been built, f.param['sigma']
# holds a tensor rather than a float.
f.param['sigma']

<tf.Tensor 'add_5:0' shape=() dtype=float32>

In [12]:
# Notebook inspection cell: the coupling matrix that Model_graph assigned to f (a tensor).
f.L

<tf.Tensor 'add:0' shape=(96, 96) dtype=float32>

In [29]:
# Notebook inspection cell: the tf.Variables of the fitted scalar parameters.
graph_deco.variables_ls

[<tf.Variable 'Variable:0' shape=() dtype=float32_ref>,
 <tf.Variable 'Variable_1:0' shape=() dtype=float32_ref>,
 <tf.Variable 'Variable_2:0' shape=() dtype=float32_ref>,
 <tf.Variable 'Variable_3:0' shape=() dtype=float32_ref>,
 <tf.Variable 'Variable_4:0' shape=() dtype=float32_ref>]