In [22]:
from environments import reinf_deval
from agents import qlearn, dynaq, po_qlearn, po_dynaq
from pomdp import *
from mdp import *
from plotnine import *
from scipy.stats import ttest_ind
import numpy as np
import pandas as pd

n_rats = 20 # number of 'stats rats' per group
env = reinf_deval() # define environment (go/no-go reversal learning)

In [36]:
# POMDP that approximates the reinforcer devaluation task
iti_prob = 0.25
noise_factor = 0.001

A = np.zeros((6, 6, 2)) # transition array
A[0, 0, :] = iti_prob # ITI -> ITI
A[0, [1, 2], :] = (1 - iti_prob)/2 # ITI -> sound 1 or sound 2
A[1, 3, 1] = 1 # sound 1, go to well -> food 1
A[2, 4, 1] = 1 # sound 2, go to well -> food 2
A[[1, 2], 0, 0] = 1 # trial ends after sitting there
A[[3, 4], 0, 0] = 1 # trial ends after sitting there
A[[3, 4], 0, 1] = 1 # trial ends after eating
#A[5, 0, :] = 1 # trial ends after nausea
# define observation array
B = np.zeros((6, 6, 2)) # observation matrix
B[0, 0, :] = 1 # ITI
B[1, 1, :] = 1 # sound 1
B[2, 2, :] = 1 # sound 2
B[3, 3, :] = 1 # food 1
B[4, 4, :] = 1 # food 2
B[5, 5, :] = 1 # nausea
# define reward array
R = np.zeros((6, 6, 2))
R[3, 5, 1] = 1 # food 1, eat -> ITI
R[4, 5, 1] = 1 # food 2, eat -> ITI

# add noise/make any transition possible
for i in range(6):
    for j in range(2):
        numerator = A[i, :, j] + noise_factor
        A[i, :, j] = numerator/np.sum(numerator)
        print(np.round(A[i, :, j], 2))
        
# create POMDP using A, B and R defined above
deval_pomdp = pomdp(A = A, B = B, R = R)
# create MDP using A and R defined above
deval_mdp = mdp(A = A, R = R)

[0.05 0.47 0.47 0.   0.   0.  ]
[0.05 0.47 0.47 0.   0.   0.  ]
[1. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]
[1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0.]
[0.17 0.17 0.17 0.17 0.17 0.17]
[0.17 0.17 0.17 0.17 0.17 0.17]


In [34]:
# Simulate regular q-learning.
model = qlearn(env = env, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5)
model.learn(1)
print(np.round(model.q, 2))

[[0.12 0.07]
 [0.04 0.53]
 [0.03 0.53]
 [0.03 1.06]
 [0.04 1.06]
 [0.   0.  ]]


In [38]:
# Simulate Dyna-Q.
model = dynaq(env = env, mdp = deval_mdp, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5, n_dyna = 1)
model.learn(1)
print(np.round(model.q, 2))

[[0.05 0.04]
 [0.02 0.12]
 [0.03 0.21]
 [0.02 0.29]
 [0.02 0.35]
 [0.12 0.06]]


In [8]:
# Simulate POMDP q-learning.
model = po_qlearn(env = env, pomdp = deval_pomdp, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5)
model.learn(1)
print(np.round(model.q, 2))

[[0.13 0.07]
 [0.02 0.53]
 [0.04 0.53]
 [0.02 1.06]
 [0.03 1.06]
 [0.   0.  ]]


In [19]:
# Simulate POMDP Dyna-Q.
model = po_dynaq(env = env, pomdp = deval_pomdp, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5, n_dyna = 3)
model.learn(1)
print(np.round(model.q, 2))

[[0.04 0.05]
 [0.02 0.12]
 [0.02 0.1 ]
 [0.02 0.24]
 [0.02 0.19]
 [0.   0.  ]]


In [10]:
# Compare regular and POMDP q-learning.
rwd_ctrl0 = []
rwd_lesion0 = []
rwd_ctrl1 = []
rwd_lesion1 = []
for i in range(n_rats):
    model = po_qlearn(env = env, pomdp = rev_pomdp, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5)
    model.learn(1)
    rwd_keep = model.rwd_list.loc[model.obs_list.isin([1, 2])]
    rwd_ctrl0 += [np.mean(rwd_keep[0:250])]
    rwd_ctrl1 += [np.mean(rwd_keep[250:])]
    model = qlearn(env = env, learning_rate = 0.1, epsilon = 0.05, gamma = 0.5)
    model.learn(1)
    rwd_keep = model.rwd_list.loc[model.obs_list.isin([1, 2])]
    rwd_lesion0 += [np.mean(rwd_keep[0:250])]
    rwd_lesion1 += [np.mean(rwd_keep[250:])]
df = pd.DataFrame({'rwd_initial' : rwd_ctrl0 + rwd_lesion0, 'rwd_reversal' : rwd_ctrl1 + rwd_lesion1, 'group' : n_rats*['control'] + n_rats*['lesion']})

print('initial learning')
print(np.round(ttest_ind(rwd_ctrl0, rwd_lesion0).statistic, 2))
print(np.round(ttest_ind(rwd_ctrl0, rwd_lesion0).pvalue, 4))
p0 = (ggplot(df, aes('group', 'rwd_initial'))
 + geom_boxplot())
p0.draw()
print()
print('reversal')
print(np.round(ttest_ind(rwd_ctrl1, rwd_lesion1).statistic, 2))
print(np.round(ttest_ind(rwd_ctrl1, rwd_lesion1).pvalue, 4))
p1 = (ggplot(df, aes('group', 'rwd_reversal'))
 + geom_boxplot())
p1.draw()

NameError: name 'rev_pomdp' is not defined