In [1]:
import matplotlib.pyplot as plt
import plotly
import numpy as np
import random
import hrr
import math
from plotly.graph_objs import Scatter, Layout, Surface
plotly.offline.init_notebook_mode(connected=True)

In [2]:
def log_transform(error):
    return math.copysign(1.0,error)*math.log(math.fabs(error)+1,2)

In [3]:
def softmax(arr,t=1.0):
    w = np.array(arr)
    e = np.exp(w / t)
    dist = e / np.sum(e)
    return dist

In [4]:
def argmax(arr_2d,restrict):
    max_row = restrict[0]
    max_col = 0
    #max_value = arr_2d[0,0]
    max_value = arr_2d[restrict[0],0]
    for row in range(arr_2d.shape[0]):
        if row not in restrict:
            continue
        for col in range(arr_2d.shape[1]):
            if arr_2d[row,col] > max_value:
                max_value = arr_2d[row,col]
                max_row,max_col = row,col
    return list((max_row,max_col))

In [5]:
def argmax1(arr_3d,outer,inner):
    max_row = outer[0]
    max_col = inner[0]
    max_x = 0
    #max_value = arr_2d[0,0]
    max_value = arr_3d[outer[0],inner[0],0]
    for row in range(arr_3d.shape[0]):
        if row not in outer:
            continue
        for col in range(arr_3d.shape[1]):
            if col not in inner:
                continue
            for x in range(arr_3d.shape[2]):
                if arr_3d[row,col,x] > max_value:
                    max_value = arr_3d[row,col,x]
                    max_row,max_col,max_x = row,col,x
    return list((max_row,max_col,max_x))

In [6]:
def context_check(outer,inner):
    if outer == 0:
        if inner == 0:
            return 'AX'
        elif inner == 1:
            return 'AY'
    elif outer == 1:
        if inner == 0:
            return 'BX'
        elif inner == 1:
            return 'BY'

In [7]:
def performance(outer,inner,action,cont,arr_2d):
    # arr_2d[count,numcorrect,performance]
    if context_check(outer,inner)=='AX':
        #count1+=1
        arr_2d[0,0]+=1
        if action == 0 and cont == 0 or action == 1 and cont == 1:
            #nc1 += 1
            arr_2d[0,1]+=1
        #AX_perf = nc1/count1
        arr_2d[0,2] = arr_2d[0,1]/arr_2d[0,0]
    elif context_check(outer,inner)=='BX':
        #count2+=1
        arr_2d[1,0]+=1
        if action == 1:
            #nc2 += 1
            arr_2d[1,1]+=1
        #BX_perf = nc2/count2
        arr_2d[1,2] = arr_2d[1,1]/arr_2d[1,0]
    elif context_check(outer,inner)=='AY':
        #count3+=1
        arr_2d[2,0]+=1
        if action == 1:
            #nc3 += 1
            arr_2d[2,1]+=1
        #AY_perf = nc3/count3
        arr_2d[2,2] = arr_2d[2,1]/arr_2d[2,0]
    elif context_check(outer,inner)=='BY':
        #count4+=1
        arr_2d[3,0]+=1
        if action == 1 and cont == 0 or action == 0 and cont == 1:
            #nc4 += 1
            arr_2d[3,1]+=1
        #BY_perf = nc4/count4
        arr_2d[3,2] = arr_2d[3,1]/arr_2d[3,0]

In [8]:
def TD(ntrials,lrate,gamma,td_lambda,temp,decay,time):
    n = 1024
    nactions = 2 # number of actions
    nouter_wm = 2 # number of context wm
    ninner_wm = 2 # number of inner wm slots
    n_context = 2 # number of context signals
    n_cue = 2 # number of cue signals
    n_probe = 2 # number of probe signals
    n_cue_probe = 4
    
    ######## reward functions ##########
    # reward matix, reward given at 0,0,0
    # A-X target context
    # 0 action => right press, 1 action => left press
    reward1 = np.zeros((n_cue+1,n_probe+1,nactions))
    reward1[0,0,0] = 1
    reward1[0,1,1] = 1
    reward1[1,0,1] = 1
    reward1[1,1,1] = 1
    # B-Y target context
    reward2 = np.zeros((n_cue+1,n_probe+1,nactions))
    reward2[0,0,1] = 1
    reward2[0,1,1] = 1
    reward2[1,0,1] = 1
    reward2[1,1,0] = 1
    # press left reward
    reward3 = [0,0.2]
    ####################################
    
    
    ############## hrrs ###############
    # identity vector
    hrr_i = np.zeros(n)
    hrr_i[0] = 1
    
    # L R actions hrr
    actions = hrr.hrrs(n,nactions)
    
    # 1 2 context signal hrr
    context_signal = hrr.hrrs(n,n_context)
    context_signal = np.row_stack((context_signal,hrr_i))
    
    # cue signal hrr # not used
    cue_signal = hrr.hrrs(n,n_cue)
    cue_signal = np.row_stack((cue_signal,hrr_i))
    
    # probe signal hrr # not used
    probe_signal = hrr.hrrs(n,n_probe)
    probe_signal = np.row_stack((probe_signal,hrr_i))
    
    # cue and probe signals hrr
    #cue_probe_signal = hrr.hrrs(n,n_cue_probe)
    #cue_probe_signal = np.row_stack((cue_probe_signal,hrr_i))
    
    # outer working memory hrr
    outer_wm = hrr.hrrs(n,nouter_wm)
    outer_wm = np.row_stack((outer_wm,hrr_i))
    
    # inner working memory hrr
    inner_wm = hrr.hrrs(n,ninner_wm)
    inner_wm = np.row_stack((inner_wm,hrr_i))
    
    # precomputed hrr
    #context_action_outerwm = hrr.oconvolve(actions,hrr.oconvolve(context_signal,outer_wm))
    #context_action_outerwm = np.reshape(context_action_outerwm,(n_context+1,nouter_wm+1,nactions,n))
    
    #signal_action_innerwm = hrr.oconvolve(actions,hrr.oconvolve(cue_probe_signal,inner_wm))
    #signal_action_innerwm = np.reshape(signal_action_innerwm,(n_cue_probe+1,ninner_wm+1,nactions,n))
    ########################################
    context_action_outerwm = hrr.oconvolve(actions,hrr.oconvolve(context_signal,outer_wm))
    signal_innerwm = hrr.oconvolve(cue_signal,hrr.oconvolve(inner_wm,probe_signal))
    computed = hrr.oconvolve(context_action_outerwm,signal_innerwm)
    computed = np.reshape(computed,(n_context+1,n_cue+1,n_probe+1,nouter_wm+1,ninner_wm+1,nactions,n))
    
    # weight vector and bias
    W = hrr.hrr(n)
    bias = 1
    
    # epsilon soft
    epsilon = .01
    
    # temperatue for softmax
    t = temp
    
    nsteps = 100
    context = 2 # init context
    
    perf_arr = np.zeros((4,3))
    
    # lists used for displaying performance graph
    AX_data_pts = []
    BX_data_pts = []
    AY_data_pts = []
    BY_data_pts = []
    
    for trial in range(ntrials):
        
        context = np.random.choice([0,1],p=[.5,.5]) # choose context signal
        
        if context == 0:
            reward = reward1 # use reward function 1
        elif context == 1:
            reward = reward2 # use reward function 2
            
        cue,probe = 2,2 # init with no cue/probe signal
        current_inner_wm = 2 # init with nothing in wm
        current_outer_wm = 2 # init with nothing in wm
        
        values = np.dot(computed[context,cue,probe,:,current_inner_wm,:,:],W) + bias
        #print(values.shape)
        sm_prob = softmax(values,t)
        possible_wm = np.unique(np.array([context]))
        
        wm_action = argmax(sm_prob,possible_wm)
        current_outer_wm = wm_action[0]
        action = wm_action[1]
        #print('ContexWM:',current_outer_wm)
        #print(context,current_outer_wm,action)
        # epsilon goes here
        if random.random() < epsilon:
            current_outer_wm = random.randint(0,nouter_wm)
            action = random.randrange(nactions)
            
        value = values[current_outer_wm,action]
        eligibility = np.zeros(n)
        global_context = context # used for reward function
        
        
        
        for step in range(nsteps):
            r = reward3[action] # # reward function, may not be needed
            
            # absorb reward
            if trial%50 == 0:
                eligibility = computed[context,cue,probe,current_outer_wm,current_inner_wm,action,:] + td_lambda*eligibility
                error = r - value
                W += lrate*log_transform(error)*eligibility
                break
            
            pvalue = value
            paction = action
            pcontext = context
            pcue = cue
            pprobe = probe
            p_outer_wm = current_outer_wm
            p_inner_wm = current_inner_wm
            
            eligibility = computed[context,cue,probe,current_outer_wm,current_inner_wm,action,:] + td_lambda*eligibility
            
            cue = random.randint(0,1) # get cue signal
            global_cue = cue # used for reward function
            probe = 2
            context = 2
            
            values = np.dot(computed[context,cue,probe,current_outer_wm,:,:,:],W) + bias
            sm_prob = softmax(values,t)
            possible_wm = np.unique(np.array([2,cue]))
            wm_action = argmax(sm_prob,possible_wm)
            current_inner_wm = wm_action[0]
            action = wm_action[1]
            # epsilon goes here
            if random.random() < epsilon:
                current_inner_wm = random.randint(0,ninner_wm) 
                action = random.randrange(nactions)
            
            value = values[current_inner_wm,action]
            error = (r+gamma*value)-pvalue
            W += lrate*log_transform(error)*eligibility
            
            ###########################################
            pvalue = value
            paction = action
            pcontext = context
            pcue = cue
            pprobe = probe
            p_outer_wm = current_outer_wm
            p_inner_wm = current_inner_wm
            
            eligibility = computed[context,cue,probe,current_outer_wm,current_inner_wm,action,:] + td_lambda*eligibility
            
            cue = 2
            probe = random.randint(0,1) # get probe signal
            context = 2
            global_probe = probe # used for reward function
            
            values = np.dot(computed[context,cue,probe,current_outer_wm,current_inner_wm,:,:],W) + bias
            sm_prob = softmax(values,t)
            #possible_wm = np.unique(np.array([2,signal]))
            #wm_action = argmax(sm_prob,possible_wm)
            #current_inner_wm = wm_action[0]
            #action = wm_action[1]
            action = np.argmax(sm_prob)
            # epsilon goes here
            if random.random() < epsilon:
                action = random.randrange(nactions)
            
            #value = values[current_inner_wm,action]
            r = reward[global_cue,global_probe,action]
            value = values[action]
            error = (r+gamma*value)-pvalue
            W += lrate*log_transform(error)*eligibility
            #print(global_context,global_cue,global_probe,action,'reward:',r)
            
            
            performance(global_cue,global_probe,action,global_context,perf_arr)
            
            #print(AX_data_pts)
        ########################################
        #print(context,cue,probe,action)
        if trial%1000==0:
            print('Trial:',trial,end='\n\n')
            print(format('','>10s'),format('count','>12s'),format('performance','>20s'))
            print(format('AX |','<10s'),format(perf_arr[0,0],'>12.1f'),format(perf_arr[0,2],'>20.2%'))
            print(format('BX |','<10s'),format(perf_arr[1,0],'>12.1f'),format(perf_arr[1,2],'>20.2%'))
            print(format('AY |','<10s'),format(perf_arr[2,0],'>12.1f'),format(perf_arr[2,2],'>20.2%'))
            print(format('BY |','<10s'),format(perf_arr[3,0],'>12.1f'),format(perf_arr[3,2],'>20.2%'))
            
            AX_data_pts.append(perf_arr[0,2])
            BX_data_pts.append(perf_arr[1,2])
            AY_data_pts.append(perf_arr[2,2])
            BY_data_pts.append(perf_arr[3,2])
            
            perf_arr = np.zeros((4,3))
            print(end='\n\n')
        
    print('Trial:',trial,end='\n\n')
    print(format('','>10s'),format('count','>12s'),format('performance','>20s'))
    print(format('AX |','<10s'),format(perf_arr[0,0],'>12.1f'),format(perf_arr[0,2],'>20.2%'))
    print(format('BX |','<10s'),format(perf_arr[1,0],'>12.1f'),format(perf_arr[1,2],'>20.2%'))
    print(format('AY |','<10s'),format(perf_arr[2,0],'>12.1f'),format(perf_arr[2,2],'>20.2%'))
    print(format('BY |','<10s'),format(perf_arr[3,0],'>12.1f'),format(perf_arr[3,2],'>20.2%'))

    print(end='\n\n')
    
    V1,V2,V3,V4 = AX_data_pts,BX_data_pts,AY_data_pts,BY_data_pts
    
    plotly.offline.iplot([
            dict(x=[x for x in range(len(V1))] , y=V1, type='scatter',name='AX'),
            dict(x=[x for x in range(len(V1))] , y=V2, type='scatter',name='BX'),
            dict(x=[x for x in range(len(V1))] , y=V3, type='scatter',name='AY'),
            dict(x=[x for x in range(len(V1))] , y=V4, type='scatter',name='BY')
    ])

In [9]:
TD(10000,.4,.9,.8,.1,0,0)
# (num trials, learning rate, discount factor, lambda, temperature, decay factor, decay time steps)

Trial: 0

                  count          performance
AX |                0.0                0.00%
BX |                0.0                0.00%
AY |                0.0                0.00%
BY |                0.0                0.00%


Trial: 1000

                  count          performance
AX |            24384.0               94.87%
BX |            24659.0               94.46%
AY |            24538.0               91.66%
BY |            24419.0               90.09%


Trial: 2000

                  count          performance
AX |            24259.0               93.30%
BX |            24361.0               95.30%
AY |            24578.0               97.67%
BY |            24802.0               92.68%


Trial: 3000

                  count          performance
AX |            24397.0               93.22%
BX |            24554.0               95.85%
AY |            24608.0               94.18%
BY |            24441.0               89.89%


Trial: 4000

                  count       

In [40]:
x = np.array([1,2,3])
print(x[None,1].shape)

(1,)


In [64]:
np.argmax(np.take(x,[0,2]))
#mask = [True,True,False]

1

In [40]:
np.where(x>2,2,1)

array([1, 1, 2])

In [41]:
x[mask]

array([1, 2])

In [146]:
count = 12
perf = .45
print(format('','>20s'),format('count','>12s'),format('performance','>20s'))
print(format('AX |','<20s'),format(count,'>12d'),format(perf,'>20%'))
print(format('AY |','<20s'),format(count,'>12d'),format(perf,'>20%'))
print(format('BX |','<20s'),format(count,'>12d'),format(perf,'>20%'))
print(format('BY |','<20s'),format(count,'>12d'),format(perf,'>20%'))

                            count          performance
AX |                           12           45.000000%
AY |                           12           45.000000%
BX |                           12           45.000000%
BY |                           12           45.000000%


In [91]:
reward(4,5)

False False False False False False
