In [1]:
import matplotlib.pyplot as plt
import plotly
import numpy as np
import random
import hrr
import math
from plotly.graph_objs import Scatter, Layout, Surface
plotly.offline.init_notebook_mode(connected=True)

In [2]:
def log_transform(error):
    """Squash an error term onto a signed base-2 logarithmic scale.

    Computes ``sign(error) * log2(|error| + 1)``: zero maps to zero, the sign
    of the input is kept, and large magnitudes grow only logarithmically.
    """
    sign = math.copysign(1.0, error)
    magnitude = math.log(math.fabs(error) + 1, 2)
    return sign * magnitude

In [3]:
def softmax(arr,t=1.0):
    """Return the softmax (Boltzmann) distribution over every entry of ``arr``.

    Args:
        arr: array-like of scores of any shape. Normalisation is over all
            entries together, not per row.
        t: temperature the scores are divided by before exponentiation.
            Smaller positive values make the distribution more peaked.

    Returns:
        Float ndarray with the same shape as ``arr`` whose entries sum to 1.
        An empty input yields an empty array.
    """
    scaled = np.asarray(arr, dtype=float) / t
    if scaled.size == 0:
        # np.max has no identity on empty input; nothing to normalise anyway.
        return scaled
    # Subtracting the maximum leaves the distribution mathematically unchanged
    # (the common factor cancels in the ratio) but keeps np.exp from
    # overflowing to inf, which previously turned the result into nan for
    # large scores or small temperatures.
    e = np.exp(scaled - np.max(scaled))
    return e / np.sum(e)

In [4]:
def argmax(arr_3d,outer,inner):
    """Locate the largest entry of ``arr_3d`` within selected rows and columns.

    Only first-axis indices contained in ``outer`` and second-axis indices
    contained in ``inner`` are searched; the whole third axis is scanned.
    The search starts from ``arr_3d[outer[0], inner[0], 0]`` and a candidate
    replaces the current best only when strictly greater, so ties keep the
    earliest position in scan order.

    Returns:
        ``[row, col, x]`` index list of the winning entry.
    """
    best = [outer[0], inner[0], 0]
    best_value = arr_3d[outer[0], inner[0], 0]

    # Visit permitted indices in ascending order, as a plain range scan would.
    rows = [r for r in range(arr_3d.shape[0]) if r in outer]
    cols = [c for c in range(arr_3d.shape[1]) if c in inner]

    for r in rows:
        for c in cols:
            for d in range(arr_3d.shape[2]):
                candidate = arr_3d[r, c, d]
                if candidate > best_value:
                    best_value = candidate
                    best = [r, c, d]
    return best

In [5]:
def context_check(outer,inner):
    """Name the trial type for a (cue, probe) index pair.

    ``outer`` 0/1 selects the letter 'A'/'B' and ``inner`` 0/1 selects 'X'/'Y',
    giving one of 'AX', 'AY', 'BX', 'BY'. Any other value for either argument
    yields ``None``.
    """
    if outer == 0:
        cue_letter = 'A'
    elif outer == 1:
        cue_letter = 'B'
    else:
        return None

    if inner == 0:
        return cue_letter + 'X'
    if inner == 1:
        return cue_letter + 'Y'
    return None

In [6]:
def performance(outer,inner,action,arr_2d):
    """Update the running accuracy table in place for one completed trial.

    ``arr_2d`` has one row per trial type in the order AX, BX, AY, BY and the
    columns ``[count, numcorrect, performance]``. The row for the trial type
    named by ``context_check(outer, inner)`` gets its count incremented, its
    correct tally incremented when ``action`` is the rewarded one (0 for AX,
    1 for every other type), and its performance set to numcorrect / count.
    Nothing is changed when the pair is not a recognised trial type.
    """
    # trial type -> (row in arr_2d, action counted as correct)
    scoring = {'AX': (0, 0), 'BX': (1, 1), 'AY': (2, 1), 'BY': (3, 1)}

    trial_type = context_check(outer, inner)
    if trial_type not in scoring:
        return

    row, correct_action = scoring[trial_type]
    arr_2d[row,0] += 1
    if action == correct_action:
        arr_2d[row,1] += 1
    arr_2d[row,2] = arr_2d[row,1]/arr_2d[row,0]

In [7]:
def _print_performance(trial,perf_arr):
    """Print the count / accuracy table held in ``perf_arr`` for one report."""
    print('Trial:',trial,end='\n\n')
    print(format('','>10s'),format('count','>12s'),format('performance','>20s'))
    # Row order matches the layout written by performance(): AX, BX, AY, BY.
    for row,label in enumerate(('AX |','BX |','AY |','BY |')):
        print(format(label,'<10s'),format(perf_arr[row,0],'>12.1f'),format(perf_arr[row,2],'>20.2%'))
    print(end='\n\n')


def TD(ntrials,lrate,gamma,td_lambda,temp,decay,time):
    """Train a linear TD learner over HRR features on a two-step cue/probe task.

    Each trial draws a (cue, probe) pair (AX 70%, AY/BX/BY 10% each), chooses
    an outer working-memory slot at the cue step, then jointly chooses an
    inner working-memory slot and an action at the probe step. Both choices
    are greedy with an epsilon = 0.01 chance of a uniformly random pick. The
    weight vector is updated twice per trial using eligibility traces, and an
    accuracy table is printed every 1000 trials and once more at the end.

    Args:
        ntrials: number of trials to run.
        lrate: learning rate applied to each weight update.
        gamma: discount applied to the probe-step value in the cue-step error.
        td_lambda: decay factor applied to the eligibility trace between steps.
        temp: softmax temperature. The choice is the argmax of the softmax
            output, so for temp > 0 it picks the same entry as the argmax of
            the raw values.
        decay: base of the working-memory decay; the chosen outer slot vector
            is passed through ``hrr.pow(vec, decay**time)``.
        time: exponent applied to ``decay`` (number of decay time steps).

    Returns:
        None. Progress is reported through ``print``.
    """
    n = 1000 # length of every HRR vector
    nactions = 2 # number of actions
    nwm_o = 2 # number of outer wm slots
    nwm_i = 2 # number of inner wm slots
    nsig_o = 2 # number of cue signals
    nsig_i = 2 # number of probe signals

    # reward[cue,probe,action]: action 0 is rewarded on AX (0,0);
    # action 1 is rewarded on AY (0,1), BX (1,0) and BY (1,1).
    reward = np.zeros((nsig_o+1,nsig_i+1,nactions))
    reward[0,0,0] = 1
    reward[0,1,1] = 1
    reward[1,0,1] = 1
    reward[1,1,1] = 1

    # hrr for actions
    # NOTE(review): hrr.hrrs(n, k) presumably returns k random length-n vectors
    # stacked as rows -- confirm against the hrr module.
    actions = hrr.hrrs(n,nactions)

    # identity vector, appended below as the extra "nothing" row of each set
    hrr_i = np.zeros(n)
    hrr_i[0] = 1

    # np.vstack replaces np.row_stack, a deprecated alias of the same function.
    # cue outer
    sig_outer = hrr.hrrs(n,nsig_o)
    sig_outer = np.vstack((sig_outer,hrr_i))

    # probe inner
    sig_inner = hrr.hrrs(n,nsig_i)
    sig_inner = np.vstack((sig_inner,hrr_i))

    # outer working memory
    wm_outer = hrr.hrrs(n,nwm_o)
    wm_outer = np.vstack((wm_outer,hrr_i))

    # inner working memory
    wm_inner = hrr.hrrs(n,nwm_i)
    wm_inner = np.vstack((wm_inner,hrr_i))

    # precomputed
    # NOTE(review): `external` / `s_a` are never read below; kept because the
    # hrr calls are opaque from this notebook -- remove once confirmed pure.
    external = hrr.oconvolve(sig_inner,sig_outer)
    s_a = hrr.oconvolve(external,actions)
    s_a = np.reshape(s_a,(nsig_o+1,nsig_i+1,nactions,n))

    # weight vector and bias
    W = hrr.hrr(n)
    bias = 1

    # epsilon for e-soft policy
    epsilon = .01

    # array that keeps track of AX-CPT performance since the last report
    perf_arr = np.zeros((4,3))

    # (cue, probe) pairs in the order AX, AY, BX, BY
    choices = [(0,0),(0,1),(1,0),(1,1)]

    for trial in range(ntrials):
        # eligibility trace starts fresh on every trial
        eligibility = np.zeros(n)

        # 70% AX trials, 10% each of AY, BX, BY
        index = np.random.choice([0,1,2,3],p=[.7,.1,.1,.1])
        cue,probe = choices[index]

        # context used for the reward lookup and the performance table
        outer = cue
        inner = probe

        ###### cue step: choose an outer working-memory slot ######
        cue_outerwm = hrr.convolve(sig_outer[cue],wm_outer)
        values = np.dot(cue_outerwm,W) + bias
        sm_prob = softmax(values,temp)
        wm_o = np.unravel_index(np.argmax(sm_prob),sm_prob.shape)

        # epsilon soft: occasionally pick a slot uniformly at random
        if random.random() < epsilon:
            wm_o = random.randrange(nwm_o+1)

        # NOTE(review): trace1 is never read below; kept for the same reason
        # as `external` above.
        trace1 = hrr.convolve(sig_outer[cue],wm_outer[wm_o])
        wm1 = wm_outer[wm_o] # selected memory slot

        # decay the chosen working-memory vector (on a copy of the slot)
        wm_outer_decayed = np.array(wm1)
        wm_outer_decayed = hrr.pow(wm_outer_decayed,decay**time)

        pvalue = values[wm_o] # value of the cue-step choice
        eligibility = cue_outerwm[wm_o] + td_lambda*eligibility

        ###### probe step: choose an inner slot and an action jointly ######
        # convolve chosen outer wm with matrix of inner wm choices
        wm1_wm2 = hrr.convolve(wm_outer_decayed,wm_inner)
        probe_outerinnerwm_a = hrr.convolve(sig_inner[probe],hrr.oconvolve(wm1_wm2,actions))
        # First axis enumerates inner slots, so it is sized by nwm_i (this was
        # nwm_o, which only worked because both counts are 2).
        probe_outerinnerwm_a = np.reshape(probe_outerinnerwm_a,(nwm_i+1,nactions,n))
        values = np.dot(probe_outerinnerwm_a,W) + bias
        sm_prob = softmax(values,temp)
        wm_i = np.unravel_index(np.argmax(sm_prob),sm_prob.shape)
        current_memory = wm_i[0]
        action = wm_i[1]

        # epsilon soft: occasionally randomise both the action and the slot
        if random.random() < epsilon:
            action = random.randrange(0,nactions)
            current_memory = random.randrange(nwm_i+1)

        ###### learning ######
        r = reward[outer,inner,action]
        value = values[current_memory,action]

        # cue-step update: reward plus discounted probe-step value vs. pvalue
        error = (r + gamma*value) - pvalue
        W += lrate*log_transform(error)*eligibility

        # probe-step update: reward vs. the probe-step value.
        # `value` is the estimate computed before the update just above; it is
        # not recomputed from the modified W.
        eligibility = probe_outerinnerwm_a[current_memory,action,:] + td_lambda*eligibility
        error = r - value
        W += lrate*log_transform(error)*eligibility

        performance(outer,inner,action,perf_arr)

        # report every 1000 trials, then start a new measurement window
        if trial%1000==0:
            _print_performance(trial,perf_arr)
            perf_arr = np.zeros((4,3))

    # Final report for the trials since the last periodic one. Skipped when no
    # trial ran, since `trial` is unbound in that case.
    if ntrials > 0:
        _print_performance(trial,perf_arr)
         

In [None]:
# Arguments: number of trials, learning rate, discount factor, lambda,
# temperature, decay factor, decay time steps.
TD(ntrials=11000,lrate=.01,gamma=.9,td_lambda=.8,temp=.1,decay=.5,time=0)

Trial: 0

                  count          performance
AX |                0.0                0.00%
BX |                0.0                0.00%
AY |                1.0                0.00%
BY |                0.0                0.00%


Trial: 1000

                  count          performance
AX |              680.0               97.06%
BX |              115.0               67.83%
AY |              103.0               98.06%
BY |              102.0               89.22%


Trial: 2000

                  count          performance
AX |              686.0               98.83%
BX |               92.0               95.65%
AY |              114.0               97.37%
BY |              108.0               98.15%


Trial: 3000

                  count          performance
AX |              716.0               98.60%
BX |               99.0               96.97%
AY |               93.0              100.00%
BY |               92.0               96.74%


Trial: 4000

                  count       

In [None]:
# Same run with lambda lowered to .5 and 5000 decay time steps.
TD(ntrials=11000,lrate=.01,gamma=.9,td_lambda=.5,temp=.1,decay=.5,time=5000)

Trial: 0

                  count          performance
AX |                0.0                0.00%
BX |                1.0              100.00%
AY |                0.0                0.00%
BY |                0.0                0.00%


Trial: 1000

                  count          performance
AX |              695.0               95.83%
BX |               91.0                2.20%
AY |              105.0               95.24%
BY |              109.0               97.25%


Trial: 2000

                  count          performance
AX |              688.0               99.42%
BX |               88.0                2.27%
AY |              120.0               97.50%
BY |              104.0              100.00%


Trial: 3000

                  count          performance
AX |              692.0               98.99%
BX |              109.0                0.92%
AY |               96.0               98.96%
BY |              103.0              100.00%


Trial: 4000

                  count       

In [15]:
# NOTE(review): no earlier cell visible in this notebook assigns `x`, so on a
# fresh kernel (Restart & Run All) this cell would raise NameError. Define `x`
# in a preceding cell or drop this scratch cell.
x

array([ 0.55259584,  0.03874157,  0.380548  ,  0.24951881,  0.01340751,
       -0.35237558, -0.26257127,  0.34096257,  0.31601992, -0.27684736])

In [127]:
# Scratch 4x3 table of zeros, the same shape as the performance table in TD.
x = np.zeros(shape=(4, 3))

In [128]:
# Show the current value of `x`.
x

array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])

In [146]:
# Scratch check of the column layout for a count / performance table.
count = 12
perf = .45
print('{:>20s} {:>12s} {:>20s}'.format('', 'count', 'performance'))
print('\n'.join(
    '{:<20s} {:>12d} {:>20%}'.format(label, count, perf)
    for label in ('AX |', 'AY |', 'BX |', 'BY |')
))

                            count          performance
AX |                           12           45.000000%
AY |                           12           45.000000%
BX |                           12           45.000000%
BY |                           12           45.000000%
