Approximate run duration: 5 hours

In [2]:
# Mount Google Drive so the project files under MyDrive are visible to this
# Colab runtime (if it is already mounted, Colab just prints a notice).
import google.colab.drive as drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [3]:
# Work from the project root: the cells below use paths relative to it
# ('code/', '../data/HCP/'). os.chdir is used instead of the bare `cd`
# automagic, which only works when IPython's automagic is enabled.
import os

PROJECT_DIR = '/content/drive/MyDrive/Work/Sciencing/Google_CoLab_projects/WWD_Model/WWD_pytorch/wwd-model-fitting/final/'
os.chdir(PROJECT_DIR)
print(os.getcwd())

/content/drive/MyDrive/Work/Sciencing/Google_CoLab_projects/WWD_Model/WWD_pytorch/wwd-model-fitting/final


In [4]:
# Make the project-local modules in code/ (e.g. rww_pytorch_model) importable.
# Store an absolute path so the entry keeps resolving even if the working
# directory changes later, and skip the append when the cell is re-run so
# sys.path does not accumulate duplicates.
import os
import sys

CODE_DIR = os.path.abspath('code')
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)

In [5]:
import torch
from torch.nn.parameter import Parameter

In [6]:

from rww_pytorch_model import RNNWWD
from rww_pytorch_model import Model_fitting
import matplotlib.pyplot as plt # for plotting
import numpy as np # for numerical operations
import pandas as pd # for data manipulation
import seaborn as sns # for plotting 
import time # for timer
import os

In [8]:
import os

# Per-subject inputs (SC weights, BOLD time series) live in base_dir; fitted
# outputs are written back into the same folder via out_dir.
out_dir = '../data/HCP/'
base_dir = '../data/HCP/'

# Subject IDs come from SC files named 'weights_<sub>.txt'. The fitting cells
# rebuild exactly that name from the ID, so strip the prefix/suffix instead of
# slicing fixed character offsets (which assumed 6-character IDs and accepted
# any file extension).
SC_PREFIX = 'weights_'
SC_SUFFIX = '.txt'
subs = sorted(
    fname[len(SC_PREFIX):-len(SC_SUFFIX)]
    for fname in os.listdir(base_dir)
    if fname.startswith(SC_PREFIX) and fname.endswith(SC_SUFFIX)
)


In [9]:
def plot_sim_states_outputs(ts, output):
    """
    Plot simulated vs. reference BOLD and the simulated model states.

    Draws a 5 x 2 grid: simulated and reference FC (Pearson correlation
    matrices), simulated and reference BOLD traces, and the simulated
    E, I, x, f, v, q state traces, then shows the figure.

    Parameters
    ----------
    ts: array with datapoint X node_size
        reference BOLD (time-major: it is transposed before computing FC
        and plotted as-is).
    output: dict
        simulation results (e.g. from Model_fitting.test()); must contain the
        keys 'simBOLD', 'E', 'I', 'x', 'f', 'v', 'q'. Each array is
        presumably node_size X datapoint: rows of 'simBOLD' are correlated
        to form simFC and every array is transposed for the line plots.
    """
    ts_sim = output['simBOLD']
    fig, ax = plt.subplots(5, 2, figsize=(12, 8))

    # FC panels: pin the diverging 'bwr' colormap to the full correlation
    # range so white means r = 0 and both panels share one colour scale.
    im_sim = ax[0, 0].imshow(np.corrcoef(ts_sim), cmap='bwr', vmin=-1, vmax=1)
    ax[0, 0].set_title('simFC')
    fig.colorbar(im_sim, ax=ax[0, 0])
    im_emp = ax[0, 1].imshow(np.corrcoef(ts.T), cmap='bwr', vmin=-1, vmax=1)
    ax[0, 1].set_title('expFC')
    fig.colorbar(im_emp, ax=ax[0, 1])

    ax[1, 0].plot(ts_sim.T)
    ax[1, 0].set_title('simBOLD')
    ax[1, 1].plot(ts)
    ax[1, 1].set_title('expBOLD')

    # Remaining rows: the simulated state variables, two per row.
    state_panels = [
        ('E', (2, 0)), ('I', (2, 1)),
        ('x', (3, 0)), ('f', (3, 1)),
        ('v', (4, 0)), ('q', (4, 1)),
    ]
    for key, (row, col) in state_panels:
        ax[row, col].plot(output[key].T)
        ax[row, col].set_title('sim ' + key)

    # Keep the 10 panel titles from overlapping in the compact figure.
    fig.tight_layout()
    plt.show()
    
def plot_fit_parameters(output):
    """
    Plot the trajectories of the fitted parameters recorded during training.

    Parameters
    ----------
    output: dict
        training results (as returned by Model_fitting.train() in the cells
        below); must contain the keys 'g', 'gEE', 'gIE', 'gEI',
        'sigma_state', 'sigma_bold', 'gmean', 'gvar', 'cA', 'cB', 'cC'.
        Each value is plotted against its index (presumably one entry per
        training step -- confirm).
    """
    # (output key, panel title, (row, col)) for each trace in the 6 x 2 grid.
    # Raw strings keep the LaTeX backslashes from being read as escapes.
    panels = [
        ('g', 'g', (0, 0)),
        ('gEE', 'gEE', (0, 1)),
        ('gIE', 'gIE', (1, 0)),
        ('gEI', 'gEI', (1, 1)),
        ('sigma_state', r'$\sigma_{state}$', (2, 0)),
        ('sigma_bold', r'$\sigma_{BOLD}$', (2, 1)),
        ('gmean', 'post mean: g', (3, 0)),
        ('gvar', 'post var: g', (3, 1)),
        ('cA', 'post poly:A', (4, 0)),
        ('cB', 'post poly:B', (4, 1)),
        ('cC', 'post poly:C', (5, 0)),
    ]
    fig, ax = plt.subplots(6, 2, figsize=(12, 8))
    for key, title, (row, col) in panels:
        ax[row, col].plot(output[key])
        ax[row, col].set_title(title)

    # 11 traces on 12 axes: hide the unused bottom-right panel.
    ax[5, 1].axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
start_time = time.time()

# Fitting settings shared by every subject (loop-invariant, so set once).
node_size = 83      # number of regions -- presumably must match the SC / BOLD size
num_epoches = 20    # training epochs per subject
batch_size = 20
step_size = 0.05    # step size passed to RNNWWD
input_size = 2
tr = 0.75           # repetition time passed to RNNWWD
n_test = 120        # argument passed to F.test()
n_avg = 100         # trailing training values averaged into the saved parameters
max_subjects = 40

# Slicing (rather than range(40) + subs[i]) avoids an IndexError when fewer
# than max_subjects subjects are available.
for i, sub in enumerate(subs[:max_subjects]):
    print(i, sub)
    sc_file = base_dir + 'weights_' + sub + '.txt'
    ts_file = base_dir + sub + '_rfMRI_REST1_LR_hpc200_clean__l2k8_sc33_ts.pkl'
    if not (os.path.isfile(sc_file) and os.path.isfile(ts_file)):
        print('  skipping', sub, '- missing SC or BOLD file')
        continue

    # Symmetrize the SC matrix, log-compress it, and scale it to unit
    # Frobenius norm.
    sc_raw = np.loadtxt(sc_file)
    SC = (sc_raw + sc_raw.T) * 0.5
    log_sc = np.log1p(SC)
    sc = log_sc / np.linalg.norm(log_sc)

    # Empirical BOLD from a pickled DataFrame (pickle: only load trusted
    # files), presumably time points x regions, since it is transposed for
    # FC. Scale it by its global maximum.
    ts = pd.read_pickle(ts_file).values
    ts = ts / np.max(ts)

    # Build the WWD model and fit it to this subject's BOLD.
    model = RNNWWD(input_size, node_size, batch_size, step_size, tr, sc, False,
                   g_mean_ini=100, g_std_ini=.1, gEE_mean_ini=2.5, gEE_std_ini=.1)
    F = Model_fitting(model, ts, num_epoches)
    output_train = F.train()
    output_test = F.test(n_test)

    plot_fit_parameters(output_train)
    plot_sim_states_outputs(ts, output_test)

    np.savetxt(out_dir + 'bold_test_4p_' + sub + '.txt', output_test['simBOLD'])
    np.savetxt(out_dir + 'bold_train_4p_' + sub + '.txt', output_train['simBOLD'])

    # Summarize each parameter by the mean of its last n_avg recorded values.
    # The file stores them in the order g, gEE, gIE, gEI.
    g = output_train['g'][-n_avg:].mean()
    gEE = output_train['gEE'][-n_avg:].mean()
    gEI = output_train['gEI'][-n_avg:].mean()
    gIE = output_train['gIE'][-n_avg:].mean()
    np.savetxt(out_dir + 'parameters_4p_' + sub + '.txt', np.array([g, gEE, gIE, gEI]))

end_time = time.time()
print('running time is  {0} \'s'.format(end_time - start_time))

In [None]:
# Refit every subject, using as the target the BOLD simulated by the previous
# fitting loop (bold_test_4p_*.txt). The 'idt' output suffix presumably marks
# an identifiability / parameter-recovery check -- confirm.
# Reset the timer: previously this cell reused start_time from the cell above,
# so the reported time included that loop (and a fresh kernel hit NameError).
start_time = time.time()

# Fitting settings shared by every subject (loop-invariant, so set once).
node_size = 83
num_epoches = 20
batch_size = 20
step_size = 0.05
input_size = 2
tr = 0.75
n_time = 1200       # keep only the first 1200 time points of the simulated BOLD
n_test = 120        # argument passed to F.test()
n_avg = 100         # trailing training values averaged into the saved parameters
max_subjects = 40

for i, sub in enumerate(subs[:max_subjects]):
    print(i, sub)
    sc_file = base_dir + 'weights_' + sub + '.txt'
    ts_file = out_dir + 'bold_test_4p_' + sub + '.txt'
    if not (os.path.isfile(sc_file) and os.path.isfile(ts_file)):
        print('  skipping', sub, '- missing SC or simulated BOLD file')
        continue

    # Symmetrize the SC matrix, log-compress it, and scale it to unit
    # Frobenius norm.
    sc_raw = np.loadtxt(sc_file)
    SC = (sc_raw + sc_raw.T) * 0.5
    log_sc = np.log1p(SC)
    sc = log_sc / np.linalg.norm(log_sc)

    # Simulated BOLD is presumably stored one row per region (its rows are
    # correlated into FC; it is transposed to time-major for fitting). Scale
    # by the maximum of the full series, then keep the first n_time columns.
    ts = np.loadtxt(ts_file)
    ts = (ts / np.max(ts))[:, :n_time]
    fc_emp = np.corrcoef(ts)
    print(fc_emp.shape)

    # Build the WWD model and fit it to the simulated BOLD.
    model = RNNWWD(input_size, node_size, batch_size, step_size, tr, sc, False,
                   g_mean_ini=100, g_std_ini=.1, gEE_mean_ini=2.5, gEE_std_ini=.1)
    F = Model_fitting(model, ts.T, num_epoches)
    output_train = F.train()
    output_test = F.test(n_test)

    plot_fit_parameters(output_train)
    plot_sim_states_outputs(ts.T, output_test)

    np.savetxt(out_dir + 'bold_test_idt_4p_' + sub + '.txt', output_test['simBOLD'])
    np.savetxt(out_dir + 'bold_train_idt_4p_' + sub + '.txt', output_train['simBOLD'])

    # Summarize each parameter by the mean of its last n_avg recorded values.
    # The file stores them in the order g, gEE, gIE, gEI.
    g = output_train['g'][-n_avg:].mean()
    gEE = output_train['gEE'][-n_avg:].mean()
    gEI = output_train['gEI'][-n_avg:].mean()
    gIE = output_train['gIE'][-n_avg:].mean()
    np.savetxt(out_dir + 'parameters_idt_4p_' + sub + '.txt', np.array([g, gEE, gIE, gEI]))

end_time = time.time()
print('running time is  {0} \'s'.format(end_time - start_time))

In [None]:
# Collect the per-subject outputs of the first fitting loop into dicts keyed
# by subject ID and save each one. np.save stores a dict as a pickled 0-d
# object array, so reload with np.load(path, allow_pickle=True).item().
HCP_ts_sim = {}
HCP_ts_test = {}
HCP_paras = {}
for sub in subs[:40]:
    test_file = out_dir + 'bold_test_4p_' + sub + '.txt'
    train_file = out_dir + 'bold_train_4p_' + sub + '.txt'
    para_file = out_dir + 'parameters_4p_' + sub + '.txt'
    # The fitting loop skips subjects with missing inputs, so their outputs
    # may not exist; skip them here too instead of crashing.
    if not all(os.path.isfile(f) for f in (test_file, train_file, para_file)):
        print('skipping', sub, '- fitted outputs not found')
        continue
    HCP_ts_test[sub] = np.loadtxt(test_file)
    HCP_ts_sim[sub] = np.loadtxt(train_file)
    HCP_paras[sub] = np.loadtxt(para_file)

np.save(out_dir + 'HCP_ts_sim_4p.npy', HCP_ts_sim)
np.save(out_dir + 'HCP_ts_test_4p.npy', HCP_ts_test)
np.save(out_dir + 'HCP_fitparas_4p.npy', HCP_paras)

In [None]:
# Collect the per-subject outputs of the 'idt' refitting loop into dicts keyed
# by subject ID and save each one. np.save stores a dict as a pickled 0-d
# object array, so reload with np.load(path, allow_pickle=True).item().
HCP_ts_sim = {}
HCP_ts_test = {}
HCP_paras = {}
for sub in subs[:40]:
    test_file = out_dir + 'bold_test_idt_4p_' + sub + '.txt'
    train_file = out_dir + 'bold_train_idt_4p_' + sub + '.txt'
    para_file = out_dir + 'parameters_idt_4p_' + sub + '.txt'
    # The refitting loop skips subjects with missing inputs, so their outputs
    # may not exist; skip them here too instead of crashing.
    if not all(os.path.isfile(f) for f in (test_file, train_file, para_file)):
        print('skipping', sub, '- fitted outputs not found')
        continue
    HCP_ts_test[sub] = np.loadtxt(test_file)
    HCP_ts_sim[sub] = np.loadtxt(train_file)
    HCP_paras[sub] = np.loadtxt(para_file)

np.save(out_dir + 'HCP_ts_sim_4p_idt.npy', HCP_ts_sim)
np.save(out_dir + 'HCP_ts_test_4p_idt.npy', HCP_ts_test)
np.save(out_dir + 'HCP_fitparas_4p_idt.npy', HCP_paras)