# Equations

$ \delta = [R_{t+1} + \gamma Q(s_{t+1}, a_{t+1})] - Q(s_t,a_t) $

$ Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha\delta $

## Working Memory With SARSA

In [1]:
#from plotly.graph_objs import Scatter, Layout
import matplotlib.pyplot as plt
import plotly
import numpy as np
import random
import hrr
import math
from plotly.graph_objs import Scatter, Layout, Surface
plotly.offline.init_notebook_mode(connected=True)

In [2]:
def log_transform(error):
    """Compress *error* as sign(error) * log2(|error| + 1).

    The sign of the input is kept (including the sign of zero) and the
    magnitude is squashed logarithmically, so an input of 0 maps to 0.
    """
    magnitude = math.log(math.fabs(error) + 1, 2)
    sign = math.copysign(1.0, error)
    return sign * magnitude

In [3]:
def TD(nstates,nepisodes,lrate,gamma,td_lambda):
    """Train a linear SARSA learner with eligibility traces on a ring of states.

    The agent walks left/right on a ring of ``nstates`` states. Each episode a
    colour cue (0 = red, 1 = green) is drawn at random; the goal is state 0
    for red and state ``nstates//2`` for green, and only the goal state of
    the cued colour carries a reward of 1. The cue is visible on the first
    step only; afterwards the external input is the identity vector, and the
    agent chooses what to hold in working memory by comparing Q-values.

    Q-values are ``dot(feature, W) + bias`` where ``feature`` is a precomputed
    vector indexed as ``[external cue, wm slot, state, action]``.

    Args:
        nstates: number of states on the ring.
        nepisodes: number of training episodes.
        lrate: learning rate applied to the (log-transformed) TD error.
        gamma: discount factor used in the TD target.
        td_lambda: decay applied to the eligibility trace on every step
            (the trace is decayed by ``td_lambda`` only; ``gamma`` is not
            folded in).

    Returns:
        None. Value curves are drawn with ``plotly.offline.iplot`` every
        1000 episodes and once more after training.
    """
    # Length of every HRR vector (2048, 4096 and 8192 were also tried).
    n = 16000
    nactions = 2   # 0 = left (state - 1), 1 = right (state + 1)
    nslots = 2
    ncolors = 2
    # Index of the identity-vector row appended to `external` and `wm_slots`:
    # "no cue shown" / "nothing held in working memory".
    empty = nslots

    # goal for red is at 0, green at middle
    goal = [0, nstates//2]

    # reward matrix for each context: 1 at that context's goal, 0 elsewhere
    reward = np.zeros((nslots, nstates))
    for x in range(nslots):
        reward[x, goal[x]] = 1

    # NOTE(review): hrr.hrrs(n, k) presumably returns k random HRR vectors of
    # length n and hrr.oconvolve an outer (all-pairs) circular convolution --
    # confirm against the hrr module. Call order is kept as in the original.
    states = hrr.hrrs(n, nstates)
    actions = hrr.hrrs(n, nactions)

    # identity vector
    hrr_i = np.zeros(n)
    hrr_i[0] = 1

    # external colour cues plus the identity row
    external = np.vstack((hrr.hrrs(n, nslots), hrr_i))

    # working-memory contents plus the identity row
    wm_slots = np.vstack((hrr.hrrs(n, nslots), hrr_i))

    # precomputed feature vector for every (cue, wm, state, action) combination
    stateactions = hrr.oconvolve(actions, states)
    s_a_wm = hrr.oconvolve(stateactions, wm_slots)
    s_s_a_wm = hrr.oconvolve(s_a_wm, external)
    s_s_a_wm = np.reshape(s_s_a_wm, (nslots+1, nslots+1, nstates, nactions, n))

    # weight vector (updated in place below) and constant bias
    W = hrr.hrr(n)
    bias = 1

    epsilon = 0.01   # exploration probability
    nsteps = 100     # maximum steps per episode

    def q_curve(cue_idx, wm_idx, action_idx):
        # Q-value of one action in every state, for a fixed cue / wm pair.
        return (np.dot(s_s_a_wm[cue_idx, wm_idx, :, action_idx, :], W) + bias).tolist()

    def plot_values():
        # Plot left/right values for the two matching cue / wm combinations.
        xs = list(range(nstates))
        plotly.offline.iplot([
            dict(x=xs, y=q_curve(0, 0, 0), type='scatter', name='left and cue_red and wm_red'),
            dict(x=xs, y=q_curve(0, 0, 1), type='scatter', name='right and cue_red and wm_red'),
            dict(x=xs, y=q_curve(1, 1, 0), type='scatter', name='left and cue_green and wm_green'),
            dict(x=xs, y=q_curve(1, 1, 1), type='scatter', name='right and cue_green and wm_green'),
        ])

    for episode in range(nepisodes):
        state = random.randrange(0, nstates)

        # the colour drawn for this episode is shown as the external cue and
        # is also what working memory starts out holding
        color = random.randrange(0, ncolors)
        cue = color
        current_wm = color

        values = np.dot(s_s_a_wm[color, current_wm, state, :, :], W) + bias
        action = values.argmax()
        if random.random() < epsilon:
            action = random.randrange(0, nactions)

        eligibility = np.zeros(n)

        for step in range(nsteps):
            r = reward[color, state]
            if state == goal[color]:
                # terminal update: there is no successor value at the goal
                eligibility = s_s_a_wm[cue, current_wm, state, action, :] + td_lambda*eligibility
                error = r - values[action]
                W += lrate*log_transform(error)*eligibility
                break

            pvalues = values
            paction = action
            previous_wm = current_wm
            psignal = cue

            eligibility = s_s_a_wm[cue, current_wm, state, action, :] + td_lambda*eligibility

            # move one step left (action 0) or right (action 1) around the ring
            state = (state + (1 if action else -1)) % nstates

            values = np.dot(s_s_a_wm[cue, current_wm, state, :, :], W) + bias
            action = values.argmax()

            # Choose the next working-memory content among: nothing, what was
            # already held, and whatever the external cue was showing.
            # The original built np.array([a, b], c): the third value was
            # passed as numpy's positional `dtype` argument, so it was never
            # a candidate, and the argmax *position* was then used directly
            # as a slot index. Here all three are compared and the winning
            # position is mapped back to its slot index.
            # NOTE(review): three-way choice inferred from the original
            # expression -- confirm this is the intended gating rule.
            wm_values = np.dot(s_s_a_wm[cue, :, state, action, :], W) + bias
            candidates = [empty, previous_wm, psignal]
            current_wm = candidates[int(np.argmax(wm_values[candidates]))]

            if random.random() < epsilon:
                action = random.randrange(0, nactions)

            # SARSA error: reward seen before moving plus the discounted value
            # of the action just chosen, minus the previous estimate
            error = (r + gamma*values[action]) - pvalues[paction]
            W += lrate*log_transform(error)*eligibility

            # turn signal off: from now on the external input is the identity row
            cue = empty

        if episode%1000==0:
            plot_values()

    plot_values()

In [5]:
# Run the experiment: 50-state ring, 10000 episodes.
TD(nstates=50, nepisodes=10000, lrate=0.1, gamma=0.9, td_lambda=0.5)

# Testing Stuff Below