In [1]:
using CSV
using PyPlot
using DataFrames
using SpecialFunctions:erf
using StatsFuns:logsumexp,logistic

In [None]:
# Load the EM library (from the Daw lab) and the model likelihood functions.

# Template defaults for EM. NOTE(review): the fit cells in this notebook pass their own
# keywords to `em` (`full=true, emtol=1e-3`), so these two globals are not what those fits use.
full = false    # maintain a full covariance matrix (vs. a diagonal one) at the group level
emtol = 1e-2    # stopping condition (relative change) for EM

em_directory = ""   # directory where the EM functions (from the Daw lab) are stored
# Guard the push so re-running this cell does not keep appending duplicate LOAD_PATH entries.
em_directory in LOAD_PATH || push!(LOAD_PATH, em_directory)
using EM

lik_directory = ""  # directory where the likelihood functions (from the modules folder) are stored
lik_directory in LOAD_PATH || push!(LOAD_PATH, lik_directory)
using likFuncs

In [3]:
# Load the trial-level data and derive the columns used by the likelihood functions.

loadpath = ""   # directory holding triframe.csv; the fitted-parameter CSVs are also written here
df = CSV.read(joinpath(loadpath, "triframe.csv"), DataFrame);

# EM groups trials by `df.sub`; rats are the grouping unit here.
# To cluster at the session level instead, use `df.session` for both `df.sub` and `subs`.
df.sub = df.rat
subs = unique(df.rat)
NS = length(subs)

# Shift port labels by +1 (presumably 0-based -> 1-based for indexing in the likelihoods;
# TODO confirm) and store them compactly.
df.port = convert.(Int8, df.port .+ 1)

# Port occupied before each trial, with 0 for a subject's first trial. The shift is done
# per subject so that a rat's first trial does not inherit the previous rat's last port.
# NOTE(review): within a rat the shift still runs across session boundaries; confirm intent.
transform!(groupby(df, :sub), :port => (p -> Int64[0; p[1:end-1]]) => :currentport)

df.lrchoice = df.lrchoice .+ 1
df

Unnamed: 0_level_0,Column1,port,rwd,nom_rwd_a,nom_rwd_b,tri,block,nom_rwd_c,lenAC
Unnamed: 0_level_1,Int64,Int8,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,0,2,1.0,50.0,10.0,0.0,1.0,90.0,19.0
2,1,3,1.0,50.0,10.0,1.0,1.0,90.0,19.0
3,2,2,0.0,50.0,10.0,2.0,1.0,90.0,19.0
4,3,1,0.0,50.0,10.0,3.0,1.0,90.0,19.0
5,4,3,1.0,50.0,10.0,4.0,1.0,90.0,19.0
6,5,2,0.0,50.0,10.0,5.0,1.0,90.0,19.0
7,6,1,0.0,50.0,10.0,6.0,1.0,90.0,19.0
8,7,2,0.0,50.0,10.0,7.0,1.0,90.0,19.0
9,8,3,0.0,50.0,10.0,8.0,1.0,90.0,19.0
10,9,2,1.0,50.0,10.0,9.0,1.0,90.0,19.0


# Port Q learner optimization

In [7]:
# Fit the 3-port Q-learner likelihood (`qlik3port`) hierarchically with EM, grouping by rat.
df.sub = df.rat
subs = unique(df.rat)   # clustering on rats (not sessions)
NS = length(subs)
X = ones(NS)            # group-level covariates: a column of ones (presumably intercept-only)

# Initial group-level means (1x4 row) and variances, one per parameter, in order:
# beta, alpha (fit on an unbounded scale, transformed below), ccwBias, dist_bias.
betas = zeros(1, 4);
sigma = ones(4);
@time (betasa, sigmaa, xa, la, ha) = em(df, subs, X, betas, sigma, qlik3port;
                                        maxiter=2000, emtol=1e-3, full=true);

println("model qlik 3 port, aggregate log marginal likelihood: ", lml(xa, la, ha))

# Per-subject parameter estimates (presumably one row per entry of `subs`), named directly.
q3portparams = DataFrame(xa, [:beta, :alpha, :ccwBias, :dist_bias])
# Map alpha from the unbounded fitting scale into (0, 1) with the standard normal CDF.
q3portparams[!, :alpha] = 0.5 .+ 0.5 .* erf.(q3portparams[!, :alpha] ./ sqrt(2))
# joinpath works whether or not `loadpath` ends with a path separator.
CSV.write(joinpath(loadpath, "tri_q3port_params.csv"), q3portparams)
q3portparams


iter: 80
betas: [1.53 -1.44 -0.18 -0.22]
sigma: [0.51 -0.3 -0.01 -0.02; -0.3 0.22 -0.03 0.02; -0.01 -0.03 0.04 -0.01; -0.02 0.02 -0.01 0.0]
free energy: -9221.486207
change: [0.000172, -4.7e-5, -2.0e-5, -2.0e-6, 0.000978, -0.000651, -0.000696, -0.000458, 0.00025, -7.0e-5, 0.000148, 1.6e-5, -3.8e-5, 1.5e-5]
max: 0.000978
749.574049 seconds (10.06 G allocations: 746.791 GiB, 14.03% gc time, 0.80% compilation time)
model qlik 3 port, aggregate log marginal likelihood: 9186.834345169731


Unnamed: 0_level_0,beta,alpha,ccwBias,dist_bias
Unnamed: 0_level_1,Float64,Float64,Float64,Float64
1,2.75739,0.0134133,-0.184902,-0.272958
2,2.54216,0.0324303,-0.422433,-0.217442
3,1.27591,0.0865239,-0.104254,-0.226532
4,0.589599,0.26513,-0.371423,-0.166685
5,1.39343,0.0850419,-0.224862,-0.176956
6,1.65382,0.0421338,0.111805,-0.321699
7,1.72843,0.0302564,0.0924233,-0.265701
8,1.62607,0.112733,-0.421039,-0.205079
9,1.03866,0.130566,-0.263752,-0.155836
10,0.645986,0.147959,-0.0412668,-0.20733


# Hybrid model optimization

In [8]:
# Fit the hybrid Q-learner likelihood (`qlikh`) hierarchically with EM, grouping by rat.
df.sub = df.rat
subs = unique(df.rat)   # clustering on rats (not sessions)
NS = length(subs)
X = ones(NS)            # group-level covariates: a column of ones (presumably intercept-only)

# Initial group-level means (1x5 row) and variances, one per parameter, in order:
# beta, lrmf, lrmb (both fit on an unbounded scale, transformed below), ccwBias, bdist.
# lrmf/lrmb are presumably model-free / model-based learning rates; confirm in likFuncs.
betas = zeros(1, 5);
sigma = ones(5);
# Note: this assignment overwrites the `betas`/`sigma` globals with the fitted values.
@time (betas, sigma, x, l, h) = em(df, subs, X, betas, sigma, qlikh;
                                   maxiter=2000, emtol=1e-3, full=true);

println("model qlik hybrid, aggregate log marginal likelihood: ", lml(x, l, h))

# Per-subject parameter estimates (presumably one row per entry of `subs`), named directly.
hyb_params = DataFrame(x, [:beta, :lrmf, :lrmb, :ccwBias, :bdist])
# Map both rates from the unbounded fitting scale into (0, 1) with the standard normal CDF.
for col in (:lrmf, :lrmb)
    hyb_params[!, col] = 0.5 .+ 0.5 .* erf.(hyb_params[!, col] ./ sqrt(2))
end
# joinpath works whether or not `loadpath` ends with a path separator.
CSV.write(joinpath(loadpath, "tri_hybrid_params.csv"), hyb_params)
hyb_params


iter: 46
betas: [1.4 -1.45 -1.34 0.17 -0.21]
sigma: [0.36 -0.13 -0.42 0.01 -0.01; -0.13 0.15 0.15 0.05 0.02; -0.42 0.15 0.52 -0.02 0.01; 0.01 0.05 -0.02 0.03 0.01; -0.01 0.02 0.01 0.01 0.0]
free energy: -9317.374832
change: [1.2e-5, -1.1e-5, -9.1e-5, 2.0e-5, -8.0e-6, 0.000359, -0.000125, -0.000168, 0.000724, -0.000288, 7.3e-5, 0.000725, 0.000691, 0.000808, 0.000161, -0.001, 0.000427, 1.0e-5, 3.9e-5, 5.6e-5]
max: 0.001
367.894470 seconds (6.50 G allocations: 635.620 GiB, 27.89% gc time, 0.25% compilation time)
model qlik hybrid, aggregate log marginal likelihood: 9282.330766018516


Unnamed: 0_level_0,beta,lrmf,lrmb,ccwBias,bdist
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64
1,2.73954,0.0259815,0.00158738,0.202778,-0.256058
2,1.72726,0.0908449,0.0417556,0.326662,-0.198691
3,1.34247,0.0619876,0.0976946,0.10487,-0.216905
4,0.577919,0.221628,0.370405,0.360945,-0.154487
5,1.35113,0.0902732,0.0810844,0.208459,-0.165415
6,1.24824,0.0309009,0.166294,-0.108722,-0.305598
7,1.5195,0.0238876,0.0693425,-0.0939253,-0.266176
8,1.65516,0.132151,0.0572829,0.416976,-0.193925
9,1.08695,0.122339,0.124939,0.262102,-0.135358
10,0.727534,0.0746126,0.286481,0.0259464,-0.205288
