In [1]:

import pickle
import gzip
import time

import numpy as np
import scipy as sp
import scipy.io as sio

import pymc3 as pm
import theano
import theano.tensor as tt
import matplotlib.pyplot as plt
import pandas as pd
import random



In [2]:

def get_data(sub_id, data=None):
    """Extract one subject's trials from the RT table.

    Parameters
    ----------
    sub_id : int
        Value matched against the 'Subject' column.
    data : pandas.DataFrame, optional
        Table with 'Subject', 'trialindex', 'RT' and 'Type' columns. When None
        (the default) the notebook-level ``mydata`` from the "read RT data"
        cell is used, looked up at call time, so existing calls keep working.

    Returns
    -------
    (rt, trialindex, conds) : tuple of lists
        RT values, trial indices and condition flags ('Type') of the subject,
        in table row order.
    """
    if data is None:
        data = mydata  # global defined by the "read RT data" cell
    subdata = data[data['Subject'] == sub_id]
    trialindex = list(subdata['trialindex'])
    rt = list(subdata['RT'])
    conds = list(subdata['Type'])
    print('The trial length of sub%02d is %d' % (sub_id, len(trialindex)))
    return rt, trialindex, conds

In [3]:

def do_sampling_random_switchpoint(sub_id, rt, trialindex, conds, fakesp,
                                   trials_per_session=60, draws=40000,
                                   burn=20000, thin=5, random_seed=None):
    """Fit the two-phase (switchpoint) model at a pseudo (fake) switchpoint.

    The model is the same as the given-switchpoint model; only the switchpoint
    session comes from the fake/control value instead of the verbal measure.

    Parameters
    ----------
    sub_id : int
        Subject id; only used in the progress message.
    rt : array-like of float
        Observed RTs (model_construct passes log10-transformed RTs).
    trialindex : array-like of number
        Trial number of each observation, aligned with ``rt``.
    conds : array-like of bool
        Per-trial condition flag: True -> mean is mu_new - benefit,
        False -> mean is mu_new.
    fakesp : int
        Switchpoint session. Trials with index > (fakesp - 1) * trials_per_session,
        i.e. in or after that session, use the "after" benefit.
    trials_per_session : int, default 60
        Number of trials per session.
    draws, burn, thin : int, defaults 40000, 20000, 5
        MCMC draws per chain, number of leading draws discarded as burn-in,
        and thinning step applied to the remainder.
    random_seed : int or None, default None
        Forwarded to ``pm.sample`` for reproducible runs.

    Returns
    -------
    (trace, model)
        ``trace[burn::thin]`` and the PyMC3 model.
    """
    # use empirical mean/sd (ignoring condition or time point) to centre and scale the priors
    mu_obs = np.mean(rt)
    sd_obs = np.std(rt)
    # True for trials in or after the switchpoint session
    after_switch = np.asarray(trialindex) > (fakesp - 1) * trials_per_session

    model = pm.Model()
    with model:
        # mu_new is the mean RT of random (new) stimuli, shared by both phases
        mu_new = pm.Normal('mu_new', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        # NOTE(review): the *_old_benefit priors are centred on mu_obs (the mean RT)
        # although they enter as a difference (mu_new - benefit); a prior centred
        # on 0 may have been intended — TODO confirm.
        mu_before_old_benefit = pm.Normal('mu_before_old_benefit', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        mu_after_old_benefit = pm.Normal('mu_after_old_benefit', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        # switch(cond, ift, iff): elementwise "if cond then ift else iff"
        mu_before = pm.math.switch(conds, mu_new - mu_before_old_benefit, mu_new)
        mu_after = pm.math.switch(conds, mu_new - mu_after_old_benefit, mu_new)
        print('The fake switchpoint of sub%02d is %d' % (sub_id, fakesp))

        mu = pm.math.switch(after_switch, mu_after, mu_before)
        sigma = pm.HalfNormal('sigma', sd=sd_obs*2, testval=sd_obs*2)
        pm.Normal('rt_modelled', mu=mu, sd=sigma, observed=rt)

        step = pm.Metropolis()
        trace = pm.sample(draws, step=step, start=model.test_point, chains=4, cores=4,
                          random_seed=random_seed)

    # drop the burn-in draws, then keep every `thin`-th remaining draw
    return trace[burn::thin], model

In [4]:
def do_sampling_given_switchpoint(sub_id, rt, trialindex, conds, sp,
                                  trials_per_session=60, draws=40000,
                                  burn=20000, thin=5, random_seed=None):
    """Fit the two-phase (switchpoint) model at the verbally reported switchpoint.

    Parameters
    ----------
    sub_id : int
        Subject id; only used in the progress message.
    rt : array-like of float
        Observed RTs (model_construct passes log10-transformed RTs).
    trialindex : array-like of number
        Trial number of each observation, aligned with ``rt``.
    conds : array-like of bool
        Per-trial condition flag: True -> mean is mu_new - benefit,
        False -> mean is mu_new.
    sp : int
        Switchpoint session from the verbal measure. Trials with
        index > (sp - 1) * trials_per_session, i.e. in or after that session,
        use the "after" benefit.
    trials_per_session : int, default 60
        Number of trials per session.
    draws, burn, thin : int, defaults 40000, 20000, 5
        MCMC draws per chain, number of leading draws discarded as burn-in,
        and thinning step applied to the remainder.
    random_seed : int or None, default None
        Forwarded to ``pm.sample`` for reproducible runs.

    Returns
    -------
    (trace, model)
        ``trace[burn::thin]`` and the PyMC3 model.
    """
    # use empirical mean/sd (ignoring condition or time point) to centre and scale the priors
    mu_obs = np.mean(rt)
    sd_obs = np.std(rt)
    # True for trials in or after the switchpoint session
    after_switch = np.asarray(trialindex) > (sp - 1) * trials_per_session

    model = pm.Model()
    with model:
        # mu_new is the mean RT of random (new) stimuli, shared by both phases
        mu_new = pm.Normal('mu_new', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        # NOTE(review): the *_old_benefit priors are centred on mu_obs (the mean RT)
        # although they enter as a difference (mu_new - benefit); a prior centred
        # on 0 may have been intended — TODO confirm.
        mu_before_old_benefit = pm.Normal('mu_before_old_benefit', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        mu_after_old_benefit = pm.Normal('mu_after_old_benefit', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        # switch(cond, ift, iff): elementwise "if cond then ift else iff"
        mu_before = pm.math.switch(conds, mu_new - mu_before_old_benefit, mu_new)
        mu_after = pm.math.switch(conds, mu_new - mu_after_old_benefit, mu_new)
        print('The true switchpoint of sub%02d is %d' % (sub_id, sp))

        mu = pm.math.switch(after_switch, mu_after, mu_before)
        sigma = pm.HalfNormal('sigma', sd=sd_obs*2, testval=sd_obs*2)
        pm.Normal('rt_modelled', mu=mu, sd=sigma, observed=rt)

        # MCMC sampling
        step = pm.Metropolis()
        trace = pm.sample(draws, step=step, start=model.test_point, chains=4, cores=4,
                          random_seed=random_seed)

    # drop the burn-in draws, then keep every `thin`-th remaining draw
    return trace[burn::thin], model

In [5]:
def do_sampling_noswitchpoint(sub_id, rt, trialindex, conds,
                              draws=40000, burn=20000, thin=5, random_seed=None):
    """Fit the baseline model: one constant old-stimulus benefit, no switchpoint.

    Parameters
    ----------
    sub_id, trialindex :
        Not used by this model; kept so the signature matches the switchpoint samplers.
    rt : array-like of float
        Observed RTs (model_construct passes log10-transformed RTs).
    conds : array-like of bool
        Per-trial condition flag: True -> mean is mu_new - mu_old_benefit,
        False -> mean is mu_new.
    draws, burn, thin : int, defaults 40000, 20000, 5
        MCMC draws per chain, number of leading draws discarded as burn-in,
        and thinning step applied to the remainder.
    random_seed : int or None, default None
        Forwarded to ``pm.sample`` for reproducible runs.

    Returns
    -------
    (trace, model)
        ``trace[burn::thin]`` and the PyMC3 model.
    """
    # use empirical mean/sd (ignoring condition or time point) to centre and scale the priors
    mu_obs = np.mean(rt)
    sd_obs = np.std(rt)

    model = pm.Model()
    with model:
        mu_new = pm.Normal('mu_new', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        # NOTE(review): benefit prior is centred on mu_obs rather than 0 — TODO confirm intended.
        mu_old_benefit = pm.Normal('mu_old_benefit', mu=mu_obs, sd=sd_obs*2, testval=mu_obs)
        sigma = pm.HalfNormal('sigma', sd=sd_obs*2, testval=sd_obs*2)

        mu = pm.math.switch(conds, mu_new - mu_old_benefit, mu_new)
        pm.Normal('rt_modelled', mu=mu, sd=sigma, observed=rt)

        step = pm.Metropolis()
        trace = pm.sample(draws, step=step, start=model.test_point, chains=4, cores=4,
                          random_seed=random_seed)

    # drop the burn-in draws, then keep every `thin`-th remaining draw
    return trace[burn::thin], model

In [6]:
def model_construct(sub_id, model_type, sp,
                    filepath='E:/transition-upload/Python/Bayesian model/experiment2'):
    """Fit one model variant for one subject and save its plots, trace and WAIC.

    Parameters
    ----------
    sub_id : int
        Subject id passed to get_data and used in output file names.
    model_type : {'nosp', 'givensp', 'randomsp'}
        'nosp' -> do_sampling_noswitchpoint,
        'givensp' -> do_sampling_given_switchpoint at ``sp``,
        'randomsp' -> do_sampling_random_switchpoint at ``sp``.
    sp : int
        Switchpoint session for the switchpoint models; ignored for 'nosp'.
    filepath : str
        Output folder. A 'tracedata' sub-folder is created there if missing.

    Returns
    -------
    (trace, model)
        As returned by the selected sampler.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the three supported values; checked
        before any data is loaded, plotted or sampled.
    """
    import os

    if model_type not in ('nosp', 'givensp', 'randomsp'):
        raise ValueError("model_type must be 'nosp', 'givensp' or 'randomsp', got %r" % (model_type,))

    rt, trialindex, conds = get_data(sub_id)
    # the models are fitted on log10(RT)
    logrt = np.log10(rt)

    # Raw-data scatter on its own figure, closed after saving so it does not
    # linger as the current figure while the diagnostic plots are made.
    fig, ax = plt.subplots()
    ax.scatter(trialindex, logrt)
    ax.set(xlabel='trialindex', ylabel='log10(RT)', title='sub{:02d}'.format(sub_id))
    fig.savefig(filepath + '/scatter_sub{:02d}.png'.format(sub_id))
    plt.close(fig)

    print("Now is fitting %s model for sub%02d......" % (model_type, sub_id))
    if model_type == 'nosp':
        trace, model = do_sampling_noswitchpoint(sub_id, logrt, trialindex, conds)
    elif model_type == 'randomsp':
        trace, model = do_sampling_random_switchpoint(sub_id, logrt, trialindex, conds, sp)
    else:  # 'givensp' (validated above)
        trace, model = do_sampling_given_switchpoint(sub_id, logrt, trialindex, conds, sp)

    with model:
        pm.traceplot(trace)
        plt.savefig(filepath + '/{}_trace_sub{:02d}.png'.format(model_type, sub_id))
        plt.close('all')

        pm.plot_posterior(trace)
        plt.savefig(filepath + '/{}_posterior_sub{:02d}.png'.format(model_type, sub_id))
        plt.close('all')

        # export the trace together with its model; create the folder on first use
        tracedir = filepath + '/tracedata'
        os.makedirs(tracedir, exist_ok=True)
        with gzip.open(tracedir + '/{}_trace_sub{:02d}.pkl.gz'.format(model_type, sub_id), 'wb') as f:
            pickle.dump((trace, model), f)
        waic = pm.waic(trace, scale='deviance')
    print("The WAIC of %s model is %f" % (model_type, waic.waic))
    print("--------------------------------------------------------")
    return trace, model


In [7]:
def run(sub_id, sp, fakesp):
    """Fit the no/given/random switchpoint models for one subject and compare them.

    Writes ``summary_sub<sub_id>.xlsx`` (one pm.summary sheet per model) and
    ``cmp_waic_sub<sub_id>.csv`` (WAIC comparison on the deviance scale) to the
    experiment2 folder.

    Parameters
    ----------
    sub_id : int
        Subject id.
    sp : int
        Switchpoint session from the verbal measure, used by the 'givensp' fit
        (also passed to the 'nosp' fit, which ignores it).
    fakesp : int
        Pseudo switchpoint session used by the 'randomsp' fit.

    Returns
    -------
    pandas.DataFrame
        The table returned by pm.compare (also written to the CSV file).
    """
    filepath = 'E:/transition-upload/Python/Bayesian model/experiment2'
    tracenp, modelnp = model_construct(sub_id, 'nosp', sp)
    tracegp, modelgp = model_construct(sub_id, 'givensp', sp)
    tracerp, modelrp = model_construct(sub_id, 'randomsp', fakesp)

    # sheet names use the same labels as the WAIC comparison below
    fits = [
        ('noswitchpoint', tracenp, modelnp),
        ('givenswitchpoint', tracegp, modelgp),
        ('randomswitchpoint', tracerp, modelrp),
    ]
    with pd.ExcelWriter(filepath + '/summary_sub' + str(sub_id) + '.xlsx') as writer:
        for sheet_name, trace, model in fits:
            # pm.summary is evaluated inside the model context the trace belongs to
            with model:
                pm.summary(trace).to_excel(writer, sheet_name=sheet_name)

    # WAIC comparison (deviance scale)
    df_comp_WAIC = pm.compare(
        {'randomswitchpoint': tracerp, 'noswitchpoint': tracenp, 'givenswitchpoint': tracegp},
        ic='waic', scale='deviance')
    df_comp_WAIC.to_csv(filepath + '/cmp_waic_sub' + str(sub_id) + '.csv')
    return df_comp_WAIC
    

In [9]:
#read real switchpoint (true_transition) and pseudo switchpoint (fake_transition) per subject
tpdata = pd.read_csv('E:/transition-upload/Python/Bayesian model/experiment2/exp2_tp.csv')
# Cast with astype: attribute assignment such as `tpdata.some_name = ...` does not
# create or update a column for a new name (pandas only emits a UserWarning),
# so it cannot be used to convert true_transition.
tpdata = tpdata.astype({'subject': int, 'true_transition': int, 'fake_transition': int})
# subject id -> switchpoint session from the verbal measure
tpdict = tpdata.set_index('subject')['true_transition'].to_dict()
print(tpdict)
#read pseudo switchpoint: subject id -> fake switchpoint session
ftpdict = tpdata.set_index('subject')['fake_transition'].to_dict()
print(ftpdict)

{2: 6, 5: 7, 7: 4, 8: 4, 9: 6, 11: 4, 14: 6, 15: 5, 18: 4, 20: 4, 21: 3, 25: 5, 27: 4, 29: 7, 30: 5, 32: 4, 33: 6, 34: 4}
{2: 3, 5: 6, 7: 3, 8: 6, 9: 5, 11: 6, 14: 5, 15: 7, 18: 5, 20: 3, 21: 5, 25: 3, 27: 5, 29: 3, 30: 3, 32: 6, 33: 4, 34: 7}


  after removing the cwd from sys.path.


In [10]:
#read RT data
# Drop every row that has a missing value. This is meant to remove the error
# trials (ACC=0); it assumes their RT field is empty in the export — TODO confirm.
mydata = pd.read_csv('E:/transition-upload/Python/Bayesian model/experiment2/exp2_expdata.csv').dropna(axis=0, how='any')
print(mydata)


       Subject  trialindex  ACC     RT  Type
0            1         1.0    1  817.0  True
1            1         2.0    1  468.0  True
2            1         3.0    1  448.0  True
3            1         4.0    1  481.0  True
4            1         5.0    1  528.0  True
...        ...         ...  ...    ...   ...
15655       34       536.0    1  400.0  True
15656       34       537.0    1  336.0  True
15657       34       538.0    1  388.0  True
15658       34       539.0    1  425.0  True
15659       34       540.0    1  297.0  True

[15182 rows x 5 columns]


In [13]:
#run a single subject; only subjects present in the switchpoint table can be fitted
subid = 1
if subid not in tpdict:
    print("The subject is not in the list")
else:
    run(subid, sp=tpdict[subid], fakesp=ftpdict[subid])

The subject is not in the list
