# Equations

$\delta = \left[R_{t+1} + \gamma\, Q(s_{t+1}, a_{t+1})\right] - Q(s_t, a_t)$

$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha\,\delta$

## Working Memory With SARSA

In [1]:
#from plotly.graph_objs import Scatter, Layout
import matplotlib.pyplot as plt
import plotly
import numpy as np
import random
import hrr
from plotly.graph_objs import Scatter, Layout, Surface
plotly.offline.init_notebook_mode(connected=True)

In [8]:
def _epsilon_greedy(values, epsilon, nactions):
    """Return the greedy action index for ``values``, or a uniformly random
    action with probability ``epsilon``.

    The greedy argmax is taken first and a single ``random.random()`` draw
    decides whether to explore, so the RNG is consumed in the same order as
    the original inline code.
    """
    action = values.argmax()
    if random.random() < epsilon:
        action = random.randrange(0, nactions)
    return action


def TD(nstates, nepisodes, lrate, gamma, td_lambda, *,
       n=2048, nsteps=100, epsilon=0.05, bias=1):
    """Train a linear SARSA(lambda) agent on a ring of states with a colour cue
    held in working memory, then plot the learned action values.

    Each episode starts in a random state with a random colour cue (context).
    The cue is copied into working memory. The agent moves left/right around
    a ring of ``nstates`` states until it reaches the goal for that context,
    or until ``nsteps`` steps have elapsed. Q(s, a) is ``W . phi + bias``,
    where ``phi`` is the HRR convolution of external cue, working-memory slot,
    state and action.

    Args:
        nstates: number of states on the ring.
        nepisodes: number of training episodes.
        lrate: learning rate (alpha).
        gamma: discount factor.
        td_lambda: eligibility-trace decay (lambda).
        n: HRR vector dimensionality (keyword-only, default 2048).
        nsteps: maximum steps per episode (keyword-only, default 100).
        epsilon: exploration probability (keyword-only, default 0.05).
        bias: constant added to every Q estimate (keyword-only, default 1).

    Returns:
        None. The value curves are shown with ``plotly.offline.iplot``.
    """
    nactions = 2   # action 0 steps left (-1), action 1 steps right (+1)
    nslots = 2     # number of contexts / colour cues / working-memory slots

    # Goal for red (context 0) is state 0; for green (context 1) it is the middle.
    goal = [0, nstates // 2]
    reward = np.zeros((nslots, nstates))
    for ctx in range(nslots):
        reward[ctx, goal[ctx]] = 1

    # Random HRR vectors for each state, action, external cue and WM slot.
    states = hrr.hrrs(n, nstates)
    actions = hrr.hrrs(n, nactions)
    external = hrr.hrrs(n, nslots)
    wm_slots = hrr.hrrs(n, nslots)

    # Precompute every external/wm/state/action feature vector once.
    stateactions = hrr.oconvolve(actions, states)
    s_a_wm = hrr.oconvolve(stateactions, wm_slots)
    s_s_a_wm = hrr.oconvolve(s_a_wm, external)
    # NOTE(review): this axis order (external, wm, state, action, dim) assumes
    # hrr.oconvolve places its second argument's index outermost; confirm in hrr.
    s_s_a_wm = np.reshape(s_s_a_wm, (nslots, nslots, nstates, nactions, n))

    # Linear weight vector over HRR features.
    W = hrr.hrr(n)
    moves = np.array([-1, 1])

    for _ in range(nepisodes):
        state = random.randrange(0, nstates)
        # External colour cue; working memory is set equal to it.
        color = random.randrange(0, nslots)
        values = np.dot(s_s_a_wm[color, color, state, :, :], W) + bias
        action = _epsilon_greedy(values, epsilon, nactions)

        eligibility = np.zeros(n)
        for _ in range(nsteps):
            r = reward[color, state]
            # Accumulating trace: decay the old trace, add the current feature.
            eligibility = (s_s_a_wm[color, color, state, action, :]
                           + td_lambda * gamma * eligibility)

            if state == goal[color]:
                # Terminal step: the target is the reward alone (no bootstrap).
                W += lrate * (r - values[action]) * eligibility
                break

            prev_value = values[action]
            state = (state + moves[action]) % nstates   # wrap around the ring

            values = np.dot(s_s_a_wm[color, color, state, :, :], W) + bias
            action = _epsilon_greedy(values, epsilon, nactions)

            # SARSA TD error: bootstrap on the action actually chosen next.
            error = (r + gamma * values[action]) - prev_value
            W += lrate * error * eligibility

    # (context, action, label) for each plotted Q-value curve.
    curves = [
        (0, 0, 'left and red'),
        (0, 1, 'right and red'),
        (1, 0, 'left and green'),
        (1, 1, 'right and green'),
    ]
    xs = list(range(nstates))
    plotly.offline.iplot([
        dict(x=xs,
             y=list(np.dot(s_s_a_wm[ctx, ctx, :, act, :], W) + bias),
             type='scatter',
             name=label)
        for ctx, act, label in curves
    ])

In [11]:
# Train for 50000 episodes on a 50-state ring (alpha=0.1, gamma=0.9, lambda=0.5) and plot Q.
TD(nstates=50, nepisodes=50000, lrate=0.1, gamma=0.9, td_lambda=0.5)

# Testing Stuff Below