# Equations

$ \delta = [R_{t+1} + \gamma Q(s_{t+1}, a_{t+1})] - Q(s_t,a_t) $

$ Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha\delta $

## Working Memory With SARSA

In [None]:
#from plotly.graph_objs import Scatter, Layout
import matplotlib.pyplot as plt
import plotly
import numpy as np
import random
import hrr
import math
from plotly.graph_objs import Scatter, Layout, Surface
plotly.offline.init_notebook_mode(connected=True)

In [None]:
def log_transform(error):
    """Return the log-modulus transformation of ``error``.

    The result is ``sign(error) * log2(|error| + 1)``: the magnitude is
    compressed logarithmically while the sign of the input is preserved.

    input:    float value (the error)
    returns:  float, the transformed error
    """
    sign = math.copysign(1.0, error)
    # adding 1 before taking the log maps a zero error to zero
    compressed = math.log(math.fabs(error) + 1, 2)
    return sign * compressed

In [None]:
def argmax(arr_2d,restrict):
    """Return ``[row, col]`` of the largest entry within the allowed rows.

    inputs:   arr_2d   -- 2D numpy array
              restrict -- collection of integer row indices that may be
                          searched; rows not listed are skipped
    returns:  two-element list ``[row, col]`` locating the maximum value
              among the allowed rows.  On ties the first cell in row-major
              order wins.  If no row of the array is allowed (or the array
              has no columns) ``[0, 0]`` is returned as a fallback.

    The running maximum is seeded from the first *allowed* cell rather than
    from ``arr_2d[0,0]``; seeding from ``arr_2d[0,0]`` would let a value in
    a forbidden row 0 win and make the function return a row the caller
    excluded.
    """
    best_index = None
    best_value = None
    for row in range(arr_2d.shape[0]):
        if row not in restrict:
            continue
        for col in range(arr_2d.shape[1]):
            candidate = arr_2d[row,col]
            # strict '>' keeps the earliest cell when several share the maximum
            if best_index is None or candidate > best_value:
                best_value = candidate
                best_index = [row,col]
    if best_index is None:
        # nothing was searchable; keep the historical default answer
        return [0,0]
    return best_index

# Problem Description:
    This algorithm solves a two-level hierarchical working-memory problem in which the outer context has 3 possible values (0, 1, 2) and the inner context has 3 possible values (0, 1, 2).  Every 10 episodes the outer context is switched and then held in place until the next switch.  A reward is given at state 0 when the outer context is 0 and the inner context is 0.  A reward is also given at state nstates/2 (integer division) when the outer context is 1 and the inner context is 1.
    
    Cue: 0 denotes red, 1 denotes green, 2 denotes seeing nothing
    Outer Context: 0 denotes blue, 1 denotes purple, 2 denotes nothing in wm
    Inner Context: 0 denotes red, 1 denotes green, 2 denotes nothing in wm
    Actions: 0 denotes left, 1 denotes right

In [None]:
def TD(nstates,nepisodes,lrate,gamma,td_lambda):
    """Train a TD(lambda)/SARSA-style agent on the two-level working-memory task.

    Every (cue, outer context, inner working memory, state, action)
    combination is represented by a precomputed vector of length ``n`` built
    with the ``hrr`` module, and its value is estimated as the dot product
    of that vector with the weight vector ``W`` plus ``bias``.  States form
    a ring: the two actions step to ``state-1`` or ``state+1`` modulo
    ``nstates``.

    Args:
        nstates: number of states on the ring; the two goal states are
            ``0`` and ``nstates//2``.
        nepisodes: number of training episodes to run.
        lrate: step size applied to every update of ``W``.
        gamma: discount applied to the next value in the TD error.
        td_lambda: factor by which the eligibility trace is decayed before
            the current feature vector is added to it.

    Returns:
        None.  Value curves are drawn with ``plotly.offline.iplot`` every
        1000 episodes and once more after training finishes.
    """
    # length of every vector (other sizes that were tried are left commented)
    #n = 2048
    #n = 4096
    #n = 8192
    #n = 16000
    n = 32000
    #n = 64000
    #nstates = 50
    nactions = 2
    nslots = 2
    # NOTE(review): ncolors is never read in this function
    ncolors = 2
    n_outer = 2
    #goal for red is at 0, green at middle
    goal = [0,nstates//2]
    # indexed below as reward[outer_context, color, state]
    reward = np.zeros((n_outer+1,nslots+1,nstates))

    # reward matrix for each inner and outer context
    ## outer context of 0 and cue 0 has reward at state 0
    ## outer context of 1 and cue 1 has reward at state nstates//2
    for x in range(nslots):
        reward[x,x,goal[x]] = 1

    # basic actions are left and right
    # presumably hrr.hrrs(n, k) returns k vectors of length n as rows -- TODO confirm
    states = hrr.hrrs(n,nstates)
    actions = hrr.hrrs(n,nactions)

    # identity vector
    hrr_i = np.zeros(n)
    hrr_i[0] = 1

    # external color (seeing the color)
    # the identity vector is appended as the last row (index 2 = "nothing")
    # NOTE(review): np.row_stack is reported deprecated in NumPy 2.0 in
    # favour of np.vstack -- confirm against the installed NumPy version
    external = hrr.hrrs(n,nslots)
    external = np.row_stack((external,hrr_i))

    # Outer WorkingMemory
    outer_wm_slots = hrr.hrrs(n,nslots)
    outer_wm_slots = np.row_stack((outer_wm_slots,hrr_i))

    # Inner WorkingMemory
    inner_wm_slots = hrr.hrrs(n,nslots)
    inner_wm_slots = np.row_stack((inner_wm_slots,hrr_i))

    # precomputed state/action/working_memory combinations
    # presumably hrr.oconvolve forms every pairwise convolution of its two
    # arguments -- TODO confirm against the hrr module
    stateactions = hrr.oconvolve(actions,states)
    s_a_wm = hrr.oconvolve(stateactions,outer_wm_slots)
    s_s_a_wm = hrr.oconvolve(s_a_wm,inner_wm_slots)
    s_s_s_a_wm = hrr.oconvolve(s_s_a_wm,external)
    # the training loop indexes this as
    # [cue, outer context, inner wm, state, action, :]
    s_s_s_a_wm = np.reshape(s_s_s_a_wm,(nslots+1,nslots+1,nslots+1,nstates,nactions,n))

    # weight vector
    W = hrr.hrr(n)
    # constant added to every dot product when a value is computed
    bias = 1

    #lrate = 0.1
    eligibility = np.zeros(n)
    #gamma = 0.9
    #td_lambda = 0.5
    # probability of taking an exploratory (random) choice
    epsilon = 0.01
    #nepisodes = 10000
    # maximum number of steps per episode
    nsteps = 100

    # set working memory to nothing initially(inner working memory).
    ## initialized outside of loop so that current_wm is not flushed
    ## between episodes.
    current_wm = 2

    # set context to nothing initially(outer working memory)
    outer_context = 2
    ## keeps track of context layer: row i holds the (outer context, color)
    ## pair that is rewarded while the outer context is i; the "nothing"
    ## context (row 2) holds None and therefore never matches.
    ## The None entries make this an object-dtype array.
    context_matrix = np.array([[0,0],[1,1],[None,None]])
    # starts at -1 so the first episode presents cue 0
    color_signal = -1
    for episode in range(nepisodes):
        # every episode starts at a random position on the ring
        state = random.randrange(0,nstates)

        # cue to signal context: advances by one each episode and wraps
        # around the number of rows in external
        color_signal = (color_signal+1)%len(external)
        # set external cue
        cue = color_signal
        # color keeps the episode's cue after cue itself is switched off below
        color = cue

        ## change reward and switch context
        ## (episode 0 moves the initial "nothing" context 2 to context 0)
        if episode%10==0:
            outer_context = (outer_context+1)%len(outer_wm_slots)

        # one value per action for the starting configuration
        values = np.dot(s_s_s_a_wm[color,outer_context,current_wm,state,:,:],W) + bias

        action = values.argmax()
        # returns index (row,col) of max value
        #color_action = np.unravel_index(values.argmax(), values.shape)
        #current_wm = color_action[0]
        #action = color_action[1]
        if random.random() < epsilon:
            action = random.randrange(0,nactions)
        # from here on values is the scalar value of the chosen action
        values = values[action]
        # the trace is cleared at the start of every episode
        eligibility = np.zeros(n)

        # turn signal off

        for step in range(nsteps):
            r = reward[outer_context,color,state]
            ## use context matrix to determine if agent should get a reward:
            ## true only when the outer context is 0 or 1, the episode's
            ## color equals that context, and the agent stands on the goal
            ## state for that color
            if ((context_matrix[outer_context,0]==outer_context and context_matrix[outer_context,1]==color) 
            and (color < 2 and state == goal[color])):
                # terminal update: the target is the reward alone (no
                # bootstrapped next value), then the episode ends
                eligibility = s_s_s_a_wm[cue,outer_context,current_wm,state,action,:] + td_lambda*eligibility
                #error = r - values[action]
                error = r - values
                #W += lrate*error*eligibility
                W += lrate*log_transform(error)*eligibility
                #print(cue,current_wm)
                #print('episode:',episode)
                #print('outer:',outer_context,'color:',color,'state:',state)
                #print('reward:',r)
                break

            ## updates: remember the current configuration before moving
            ## NOTE(review): pstate, paction and p_outer_wm are assigned
            ## but never read afterwards
            pstate = state
            pvalues = values
            paction = action
            previous_wm = current_wm
            p_outer_wm = outer_context
            psignal = cue
            #####

            # decay the trace and add the features of the configuration
            # that is about to be left
            eligibility = s_s_s_a_wm[cue,outer_context,current_wm,state,action,:] + td_lambda*eligibility

            # action 0 steps to state-1, action 1 to state+1, wrapping
            # around the ring
            state = ((state+np.array([-1,1]))%nstates)[action]

            ## turn off cue
            cue = 2

            # values for every (inner wm, action) pair in the new state
            values = np.dot(s_s_s_a_wm[cue,outer_context,:,state,:,:],W) + bias 
            #action = values.argmax()
            # working memory may hold nothing (2), keep its previous
            # content, or store the cue that was just presented
            possible_wm = np.unique(np.array([2,previous_wm,psignal]))

            #wm_action = np.unravel_index(values.argmax(), values.shape)
            # joint greedy choice of working-memory content and action,
            # restricted to the rows listed in possible_wm
            wm_action = argmax(values,possible_wm)

            current_wm = wm_action[0]
            action = wm_action[1]
            #values = values[current_wm,action]

            #wm_values = np.dot(s_s_a_wm[cue,:,state,action,:],W) + bias
            #print(wm_values)
            #current_wm = np.unique(np.array([wm_values[2],wm_values[previous_wm]],wm_values[psignal])).argmax()

            # NOTE(review): the exploratory branch draws the working-memory
            # content from all slots, not only from possible_wm -- confirm
            # that this is intended
            if random.random() < epsilon:
                action = random.randrange(0,nactions)
                current_wm = random.randrange(0,nslots+1)

            # scalar value of the configuration actually selected
            values = values[current_wm,action]
            #error = (r+gamma*values[action])-pvalues[paction]
            ## calculate error (observation - expectation)
            error = (r+gamma*values)-pvalues

            ## update the weight vector
            W += lrate*log_transform(error)*eligibility

        # progress plot: value of each state for both actions in the two
        # rewarded configurations
        # NOTE(review): V1..V4 are computed twice in a row with identical
        # expressions; the second group repeats the first
        if episode%1000==0:
            V1 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,0,:]))
            V2 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,1,:]))
            V3 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,1,:,0,:]))
            V4 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,1,:,1,:]))
            V1 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,0,:]))
            V2 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,1,:]))
            V3 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,1,:,0,:]))
            V4 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,1,:,1,:]))
            plotly.offline.iplot([
            dict(x=[x for x in range(len(V1))] , y=V1, type='scatter',name='left and cue_0 and wm_0 and outer_0'),
            dict(x=[x for x in range(len(V1))] , y=V2, type='scatter',name='right and cue_0 and wm_0 and outer_0' ),
            dict(x=[x for x in range(len(V1))] , y=V3, type='scatter',name='left and cue_1 and wm_1 and outer_1'),
            dict(x=[x for x in range(len(V1))] , y=V4, type='scatter',name='right and cue_1 and wm_1 and outer_1'),
            ])


    # final plot: per-state values for a selection of configurations
    # NOTE(review): the training loop indexes s_s_s_a_wm as
    # [cue, outer context, inner wm, ...]; under that order several trace
    # names below do not match the slice they plot (e.g. [1,1,0] is
    # cue 1 / outer 1 / wm 0 but is labelled 'wm_1 and outer_0', and
    # [1,2,0] is labelled 'wm_2') -- confirm which of label or index is
    # intended
    V1 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,0,:]))
    V2 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,0,0,:,1,:]))
    V3 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,0,:,0,:]))
    V4 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,1,0,:,1,:]))
    V5 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,2,0,:,0,:]))
    V6 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,2,0,:,1,:]))
    V7 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,2,0,:,0,:]))
    V8 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,2,0,:,1,:]))
    V9 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,1,0,:,0,:]))
    V10 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[0,1,0,:,1,:]))
    V11 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,0,0,:,0,:]))
    V12 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[1,0,0,:,1,:]))
    V13 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[2,2,0,:,0,:]))
    V14 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[2,2,0,:,1,:]))
    V15 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[2,1,0,:,0,:]))
    V16 = list(map(lambda x: np.dot(x,W)+bias, s_s_s_a_wm[2,1,0,:,1,:]))

    plotly.offline.iplot([
    dict(x=[x for x in range(len(V1))] , y=V1, type='scatter',name='left and cue_0 and wm_0 and outer_0'),
    dict(x=[x for x in range(len(V1))] , y=V2, type='scatter',name='right and cue_0 and wm_0 and outer_0'),
    dict(x=[x for x in range(len(V1))] , y=V3, type='scatter',name='left and cue_1 and wm_1 and outer_0'),
    dict(x=[x for x in range(len(V1))] , y=V4, type='scatter',name='right and cue_1 and wm_1 and outer_0'),

    dict(x=[x for x in range(len(V1))] , y=V5, type='scatter',name='left and cue_1 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V6, type='scatter',name='right and cue_1 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V7, type='scatter',name='left and cue_0 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V8, type='scatter',name='right and cue_0 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V9, type='scatter',name='left and cue_0 and wm_1'),
    dict(x=[x for x in range(len(V1))] , y=V10, type='scatter',name='right and cue_0 and wm_1'),
    dict(x=[x for x in range(len(V1))] , y=V11, type='scatter',name='left and cue_1 and wm_0'),
    dict(x=[x for x in range(len(V1))] , y=V12, type='scatter',name='right and cue_1 and wm_0'),
    dict(x=[x for x in range(len(V1))] , y=V13, type='scatter',name='left and cue_2 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V14, type='scatter',name='right and cue_2 and wm_2'),
    dict(x=[x for x in range(len(V1))] , y=V15, type='scatter',name='left and cue_2 and wm_1'),
    dict(x=[x for x in range(len(V1))] , y=V16, type='scatter',name='right and cue_2 and wm_1')
    ])
    

In [None]:
# Run the experiment: 10 states, 50000 episodes, learning rate 0.1,
# discount 0.9 and eligibility-trace decay 0.5.
TD(nstates=10, nepisodes=50000, lrate=0.1, gamma=0.9, td_lambda=0.5)

# Testing Stuff Below