# Latent cause model of associative learning
## Python implementation

First, simulate some data. `X` is a matrix in which rows are trials and columns are features.<br/>
The first column is the US (unconditioned stimulus); the second to last columns are the CSs (conditioned stimuli, i.e. the features).<br/>
All entries have to be binary (in the present implementation).

```python
import numpy as np

# Simulated design: rows are trials, columns are features.
# Column 0 is the US; columns 1..n_features-1 are the CSs. All values are 0 or 1.
n_trial = 30
n_features = 4
X = np.random.randint(0, 2, (n_trial, n_features))
print(X)
```

```
[[0 1 0 0]
 [0 0 1 1]
 [1 1 0 1]
 [0 0 0 0]
 [0 0 1 0]
 [1 1 1 1]
 [0 1 0 1]
 [0 1 0 0]
 [0 0 0 0]
 [0 0 0 0]
 [0 1 1 0]
 [0 0 0 1]
 [0 0 0 0]
 [0 0 0 0]
 [1 0 0 1]
 [1 1 0 0]
 [0 0 1 1]
 [0 1 1 1]
 [1 0 0 1]
 [0 0 1 1]
 [0 0 1 0]
 [0 1 1 0]
 [0 0 1 0]
 [1 1 1 0]
 [1 0 1 0]
 [1 1 1 0]
 [0 0 0 0]
 [1 0 0 1]
 [0 0 1 1]
 [1 1 1 1]]
```
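Random data is enough to exercise the code, but real experiments usually follow a structured design. Below is a minimal sketch of building `X` for a hypothetical acquisition-then-extinction protocol, keeping the same column layout (US in column 0, CSs after it). The trial counts and the name `X_design` are illustrative assumptions, not part of the original example.

```python
import numpy as np

# Hypothetical design: 10 acquisition trials (CS A paired with the US),
# then 10 extinction trials (CS A alone).
# Column 0 is the US; columns 1..3 are CS features, matching the layout above.
n_acq, n_ext, n_cs = 10, 10, 3

acq = np.zeros((n_acq, 1 + n_cs), dtype=int)
acq[:, 0] = 1   # US present
acq[:, 1] = 1   # CS A present

ext = np.zeros((n_ext, 1 + n_cs), dtype=int)
ext[:, 1] = 1   # CS A present, US absent

X_design = np.vstack([acq, ext])

# The present implementation requires binary inputs.
assert set(np.unique(X_design)) <= {0, 1}
print(X_design.shape)  # (20, 4)
```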


`CR` is a vector of conditioned responses, one per trial. These could be ratings, skin conductance responses (SCR), pupil dilation, etc.

```python
# Placeholder responses: one standard-normal value per trial, shape (n_trial, 1).
CR = np.random.randn(n_trial, 1)
```
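Real responses usually arrive as a flat list or 1-D array. A minimal sketch of reshaping them, assuming `LCM_gridsearch` expects the same `(n_trial, 1)` column shape used above; the helper `as_cr_column` is hypothetical and not part of `LCM.py`.

```python
import numpy as np

def as_cr_column(responses, n_trial):
    """Return responses as an (n_trial, 1) float column, checking the length.

    Hypothetical helper: the (n_trial, 1) shape mirrors the
    np.random.randn(n_trial, 1) call above.
    """
    cr = np.asarray(responses, dtype=float).reshape(-1, 1)
    if cr.shape[0] != n_trial:
        raise ValueError(f"expected {n_trial} responses, got {cr.shape[0]}")
    return cr

# Example: a flat list of ratings, one per trial.
ratings = [3, 5, 4, 1]
CR_example = as_cr_column(ratings, n_trial=4)
print(CR_example.shape)  # (4, 1)
```

Checking the length up front catches a mismatch between `CR` and the rows of `X` before the (comparatively slow) grid search starts.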

A few options can be passed to the algorithm itself.<br/>
Among them are the number of particles and the maximum number of potential causes:

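```python
# Options for the algorithm itself.
opts = {
    'n_particles': 100,  # number of particles
    'max_cause': 100     # maximum number of potential causes
}
```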

Now import the `LCM_gridsearch` class from `LCM.py`:

```python
from LCM import LCM_gridsearch
```

Initialize an instance with the conditioned responses, the stimulus matrix and the options.<br/>
Additional options control the fitting procedure rather than the algorithm itself, e.g. the resolution of the grid search.

```python
p = LCM_gridsearch(CR, X, opts=opts)
```

Loop over candidate values of alpha, running the algorithm once for each value.

```python
p.loop_alpha()
```
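It can help to see what alpha does. Assuming `LCM.py` follows the standard latent cause model, alpha is the concentration parameter of a Chinese restaurant process (CRP) prior over causes: an existing cause is chosen in proportion to how many trials it has already explained, a new cause in proportion to alpha, and no more than `max_cause` causes are allowed. The hypothetical `crp_prior` below is a sketch of that prior under these assumptions, not the code in `LCM.py`.

```python
import numpy as np

def crp_prior(counts, alpha, max_cause):
    """Prior probability of each latent cause on the next trial (sketch).

    counts[k] is how many previous trials were assigned to cause k,
    for the causes used so far. Existing causes get weight proportional
    to their count; the first unused cause gets weight alpha.
    Causes beyond max_cause are not allowed.
    """
    counts = np.asarray(counts, dtype=float)
    weights = np.zeros(max_cause)
    n_used = min(len(counts), max_cause)
    weights[:n_used] = counts[:n_used]
    if n_used < max_cause:
        weights[n_used] = alpha   # weight of opening a new cause
    total = weights.sum()
    if total == 0:                # first trial with alpha = 0: use cause 0
        weights[0] = 1.0
        total = 1.0
    return weights / total

print(crp_prior([3, 1], alpha=1.0, max_cause=5))  # weights 0.6, 0.2, 0.2, 0, 0
print(crp_prior([3, 1], alpha=0.0, max_cause=5))  # weights 0.75, 0.25, 0, 0, 0
```

Under this sketch, alpha = 0 never opens a cause after the first, so every trial is attributed to a single cause; larger alpha makes new causes more likely.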

Now for inference: compute the posterior over the alpha grid, the expected alpha under that posterior, and the log Bayes factor (BF) of the full alpha range against alpha = 0.<br/>
Note that the procedure as implemented assumes a uniform prior over alpha.

```python
out = p.inference()
```

```python
out
```

```
{'P': array([[0.0230141 ],
        [0.02125046],
        [0.02140812],
        [0.02125143],
        [0.02083616],
        [0.02109846],
        [0.02111743],
        [0.02066899],
        [0.02092292],
        [0.02057353],
        [0.02060556],
        [0.02015785],
        [0.02046423],
        [0.02027908],
        [0.02013509],
        [0.01982113],
        [0.01988606],
        [0.01988796],
        [0.01993458],
        [0.01992929],
        [0.01998149],
        [0.02014752],
        [0.01978369],
        [0.01990613],
        [0.01987613],
        [0.01986429],
        [0.01991852],
        [0.01959481],
        [0.01988867],
        [0.01960729],
        [0.01944533],
        [0.01971451],
        [0.01957561],
        [0.01959807],
        [0.01951657],
        [0.01952057],
        [0.01953591],
        [0.01971155],
        [0.01939751],
        [0.0191906 ],
        [0.019345  ],
        [0.01926342],
        [0.01933403],
        [0.01944397],
        [0.01955943],
```
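The three quantities described for `p.inference()` can be sketched as follows, assuming the grid search yields a log-likelihood of the CRs at each alpha value and that the first grid value is alpha = 0. `grid_inference`, the grid, and the toy log-likelihood are hypothetical; this illustrates the arithmetic under a uniform prior, not the internals of `LCM.py`.

```python
import numpy as np

def grid_inference(alphas, log_lik):
    """Posterior summaries over an alpha grid under a uniform prior (sketch).

    alphas:  1-D grid of alpha values; alphas[0] is assumed to be 0.
    log_lik: log-likelihood of the CRs at each grid value.
    """
    alphas = np.asarray(alphas, dtype=float)
    log_lik = np.asarray(log_lik, dtype=float)
    # Numerically stable log(sum_i exp(log_lik_i)).
    m = log_lik.max()
    log_evidence = m + np.log(np.exp(log_lik - m).sum())
    # Posterior over the grid; the uniform prior cancels in the normalization.
    P = np.exp(log_lik - log_evidence)
    # Expected alpha under the posterior.
    alpha_mean = float(alphas @ P)
    # Log marginal likelihood under a uniform prior over the grid, minus
    # the log-likelihood at alpha = 0.
    log_bf = float(log_evidence - np.log(len(alphas)) - log_lik[0])
    return P, alpha_mean, log_bf

alphas = np.linspace(0, 10, 50)             # illustrative grid
toy_log_lik = -0.5 * (alphas - 2.0) ** 2    # toy values, not real model output
P, alpha_mean, log_bf = grid_inference(alphas, toy_log_lik)
print(round(P.sum(), 6), alpha_mean, log_bf)
```

A positive `log_bf` favors letting alpha vary over the grid rather than fixing alpha = 0; with a flat log-likelihood it is exactly 0.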
     

That's pretty much it.