In [1]:
import os

import numpy as np
import importlib
import gymnasium as gym
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import patches
import matplotlib.animation as manimation

import gym_env
import utils
from utils import create_mapping, get_transition_matrix, create_mapping_nb, get_full_maze_values

### Helper Functions

In [2]:
# Define the environment
def gen_nhb_exp():
    envstep=[]
    for s in range(6):
        # actions 0=left, 1=right
        envstep.append([[0,0], [0,0]])  # [s', done]
    envstep = np.array(envstep)
    # State 0 -> 1, 2
    envstep[0,0] = [1,0]
    envstep[0,1] = [2,0]
    # State 1 -> 3, 4
    envstep[1,0] = [3, 1]
    envstep[1,1] = [4, 1]
    # State 2 -> 4, 5
    envstep[2,0] = [4, 1]
    envstep[2,1] = [5, 1]
    
    return envstep

def policy_reval(agent):
    """
    The New environment is the same as the old one except we 
    
    Args:
    agent (LinearRL class) : The LinearRL agent

    Returns:
    V_new (array) : New value of each state
    """
    r_new = agent.r
    expr_new = np.exp(r_new[agent.terminals] / agent._lambda)
    Z_new = np.zeros(len(r_new))

    Z_new[~agent.terminals] = agent.DR[~agent.terminals][:,~agent.terminals] @ agent.P @ expr_new
    Z_new[agent.terminals] = expr_new
    V_new = np.round(np.log(Z_new), 2)

    return V_new, Z_new

def decision_policy(agent, Z):
    """
    Performs matrix version of equation 6 from the LinearRL paper

    Args:
    agent (LinearRL class) : The LinearRL agent

    Returns:
    pii (array) : The decision policy
    """
    G = np.dot(agent.T, Z)

    expv_tiled = np.tile(Z, (len(Z), 1))
    G = G.reshape(-1, 1)
    
    zg = expv_tiled / G
    pii = agent.T * zg

    return pii

def plot_decision_prob(probs_train, probs_test):
    """
    Plots the decision probability of going towards a terminal state

    Args:
    probs_train (array) : Probability of heading towards each terminal state before policy revaluation
    probs_test (array) : probability of heading towrads each terminal state after policy revaluation
    """
    color_palette = sns.color_palette("colorblind")
    colors = [color_palette[3], color_palette[2]]
    print(colors)
    bar_positions_training = np.arange(len(probs_train)) * 0.4
    bar_positions_test = np.arange(len(probs_train)) * 0.4 + 1
    # bar_positions_training = np.array([0, 0.4])  # Bar positions for training (s1 and s2)
    # bar_positions_test = np.array([1, 1.4])  # Bar positions for test (s1 and s2)

    plt.bar(bar_positions_training, probs_train, width=0.3, color=colors, edgecolor='black')
    plt.bar(bar_positions_test, probs_test, width=0.3, color=colors, edgecolor='black')

    handles = [plt.Rectangle((0,0),1,1, facecolor=colors[i], edgecolor='black') for i in range(len(probs_train))]
    plt.legend(handles, [f'State {i+1}' for i in range(len(probs_train))], title='States', loc='upper right')
    
    plt.ylabel('Probabilities')
    plt.xticks([0.2, 1.2], ['Training', 'Test'])
    plt.rcParams['font.family'] = 'serif'

    plt.show()

### LinearRL

In [383]:
class LinearRL_NHB:
    def __init__(self, alpha=0.25, beta=10, gamma=0.904, _lambda=10, epsilon=0.4, num_steps=25000, policy="softmax", imp_samp=True, exp_type="policy_reval"):
        # Hard code start and end locations as well as size
        self.start_loc = 0
        self.target_locs = [3,4,5]
        self.size = 6
        self.agent_loc = self.start_loc
        self.exp_type = exp_type

        # Construct the transition probability matrix and envstep functions
        self.T = self.construct_T()
        self.envstep = gen_nhb_exp()
        
        # Get terminal states
        self.terminals = np.diag(self.T) == 1
        # Calculate P = T_{NT}
        self.P = self.T[~self.terminals][:,self.terminals]

        # Set reward
        self.reward_nt = -1   # Non-terminal state reward (set to 0 for SR)
        self.r = np.full(len(self.T), self.reward_nt)
        # Reward of terminal states depends on if we are replicating reward revaluation or policy revaluation
        if self.exp_type == "policy_reval":
            self.r_term_1 = [0, 15, 30]
            self.r_term_2 = [45, 15, 30]
        elif self.exp_type == "reward_reval":
            self.r_term_1 = [15, 0, 30]
            self.r_term_2 = [45, 0, 30]
        else:
            print("Incorrect experiment type (exp_type)")
            return(0)
        self.r[self.terminals] = self.r_term_1

        # Precalculate exp(r) for use with LinearRL equations
        self.expr_t = np.exp(self.r[self.terminals] / _lambda)
        self.expr_nt = np.exp(self.reward_nt / _lambda)

        # Params
        self.alpha = alpha
        self.beta = beta
        self.gamma = self.expr_nt
        self._lambda = _lambda
        self.epsilon = epsilon
        self.num_steps = num_steps
        self.policy = policy
        self.imp_samp = imp_samp

        # Model
        self.DR = self.get_DR()
        self.Z = np.full(self.size, 0.01)

        self.V = np.zeros(self.size)
        self.one_hot = np.eye(self.size)

    def construct_T(self):
        """
        Manually construt the simple two-step task
        """
        T = np.zeros((6, 6))
        T[0,1] = 0.1
        T[0,2] = 0.9
        T[1,3] = 0.1
        T[1,4] = 0.9
        T[2,4] = 0.2
        T[2,5] = 0.8
        T[3:6, 3:6] = np.eye(3)

        P = np.zeros((6, 6))
        P[0, 1:3] = 0.5
        P[1, 3:5] = 0.5
        P[2, 4:6] = 0.5
        P[3:6, 3:6] = np.eye(3)

        return P
    
    def update_term(self):
        """
        Update the terminal state values (experiment dependent)
        """
        self.r[self.terminals] = self.r_term_2

    def get_DR(self):
        """
        Returns the DR initialization based on what decision policy we are using, values are filled with 0.01 if using softmax to avoid div by zero
        """
        if self.policy == "softmax":
            DR = np.full((self.size, self.size), 0.001)
            np.fill_diagonal(DR, 1)
            DR[np.where(self.terminals)[0], np.where(self.terminals)[0]] = (1/(1-self.gamma))
        else:
            DR = np.eye(self.size)

        return DR

    def update_Z(self):
        self.Z[~self.terminals] = self.DR[~self.terminals][:,~self.terminals] @ self.P @ self.expr_t
        self.Z[self.terminals] = self.expr_t

    def update_V(self):
        self.V = np.round(np.log(self.Z), 2)
    
    def get_successor_states(self, state):
        """
        Manually define the successor states based on which state we are in
        """
        return np.where(self.T[state, :] != 0)[0]

    def importance_sampling(self, state, s_prob):
        """
        Performs importance sampling P(x'|x)/u(x'|x). P(.) is the default policy, u(.) is the decision policy
        """
        successor_states = self.get_successor_states(state)
        p = 1/len(successor_states)
        w = p/s_prob
        # print(f"state: {state}, s_prob: {s_prob}")
                
        return w

    def select_action(self, state):
        """
        Action selection based on our policy
        Options are: [random, softmax, egreedy, test]
        """
        if self.policy == "random":
            successor_states = self.get_successor_states(state)
            return np.random.choice(successor_states)
        
        elif self.policy == "softmax":
            successor_states = self.get_successor_states(state)
            action_probs = np.full(2, 0.0)   # We can hardcode this because every state has 2 actions

            v_sum = sum(np.exp((np.log(self.Z[s] + 1e-20)) / self.beta) for s in successor_states)

            # if we don't have enough info, random action
            if v_sum == 0:
                return  np.random.choice([0,1])

            for action in [0,1]:
                new_state, _ = self.envstep[state, action]
                action_probs[action] = np.exp((np.log(self.Z[new_state] + 1e-20)) / self.beta ) / v_sum

            # print(f"state: {state} | action_probs: {action_probs}")
            action = np.random.choice([0,1], p=action_probs)
            s_prob = action_probs[action]

            return action, s_prob

    def get_D_inv(self):
        """
        Calculates the DR directly using matrix inversion, used for testing
        """
        I = np.eye(self.size)
        D_inv = np.linalg.inv(I-self.gamma*self.T)

        return D_inv

    def learn(self):
        """
        Agent explores the maze according to its decision policy and and updates its DR as it goes
        """
        # print(f"Decision Policy: {self.policy}, Number of Iterations: {self.num_steps}, lr={self.alpha}, temperature={self.beta}, importance sampling={self.imp_samp}")
        # Iterate through number of steps
        for i in range(self.num_steps):
            # Agent gets some knowledge of terminal state values
            if i == 2:
                self.Z[self.terminals] = self.expr_t
            # Current state
            state = self.agent_loc

            # Choose action
            if self.policy == "softmax":
                action, s_prob = self.select_action(state)
            else:
                action = self.select_action(state, self.policy)
        
            # Take action
            next_state, done = self.envstep[state, action]
            
            # Importance sampling
            if self.imp_samp:
                w = self.importance_sampling(state, s_prob)
                w = 1 if np.isnan(w) or w == 0 else w
            else:
                w = 1
            
            # Update default representation
            target = self.one_hot[state] + self.gamma * self.DR[next_state]
            self.DR[state] = (1 - self.alpha) * self.DR[state] + self.alpha * target * w

            # Update Z-Values
            self.Z[~self.terminals] = self.DR[~self.terminals][:,~self.terminals] @ self.P @ self.expr_t
            
            if done:
                self.agent_loc = self.start_loc
                continue
            
            # Update state
            state = next_state
            self.agent_loc = state

        # Update DR at terminal state
        self.update_Z()
        self.update_V()
        # self.Z[self.terminals] = np.exp(self.r[self.terminals] / self._lambda)
        # self.V = np.round(np.log(self.Z), 2)

### Successor Representation

In [7]:
class SR_NHB:
    def __init__(self, alpha=0.1, beta=1, gamma=0.904, num_steps=25000, policy="random", exp_type="policy_reval"):
        # Hard code start and end locations as well as size
        self.start_loc = 0
        self.target_locs = [3,4,5]
        self.size = 6
        self.agent_loc = self.start_loc
        self.exp_type = exp_type

        # Construct the transition probability matrix and envstep functions
        self.T = self.construct_T()
        self.envstep = gen_nhb_exp()
        
        # Get terminal states
        self.terminals = np.diag(self.T) == 1

        # Set reward
        self.reward_nt = -1   # Non-terminal state reward (set to 0 for SR)
        self.r = np.full(len(self.T), self.reward_nt)
        # Reward of terminal states depends on if we are replicating reward revaluation or policy revaluation
        if self.exp_type == "policy_reval":
            self.r_term_1 = [0, 15, 30]
            self.r_term_2 = [45, 15, 30]
        elif self.exp_type == "reward_reval":
            self.r_term_1 = [15, 0, 30]
            self.r_term_2 = [45, 0, 30]
        else:
            print("Incorrect experiment type (exp_type)")
            return(0)
        self.r[self.terminals] = self.r_term_1

        # Params
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.num_steps = num_steps
        self.policy = policy

        # Model
        self.SR = np.eye(self.size)
        self.V = np.zeros(self.size)
        self.one_hot = np.eye(self.size)

    def construct_T(self):
        """
        Manually construt the simple two-step task
        """
        T = np.zeros((6, 6))
        T[0,1] = 0.3
        T[0,2] = 0.7
        T[1,3] = 0.1
        T[1,4] = 0.9
        T[2,4] = 0.2
        T[2,5] = 0.8
        T[3:6, 3:6] = np.eye(3)

        P = np.zeros((6, 6))
        P[0, 1:3] = 0.5
        P[1, 3:5] = 0.5
        P[2, 4:6] = 0.5
        P[3:6, 3:6] = np.eye(3)

        return T
    
    def update_term(self):
        """
        Update the terminal state values (experiment dependent)
        """
        self.r[self.terminals] = self.r_term_2

    def update_V(self):
        self.V = self.SR @ self.r
    
    def get_successor_states(self, state):
        """
        Manually define the successor states based on which state we are in
        """
        return np.where(self.T[state, :] != 0)[0]

    def select_action(self, state):
        """
        Action selection based on our policy
        Options are: [random, softmax, egreedy, test]
        """
        if self.policy == "random":
            successor_states = self.get_successor_states(state)
            return np.random.choice(successor_states)
        
        elif self.policy == "softmax":
            successor_states = self.get_successor_states(state)
            action_probs = np.full(2, 0.0)   # We can hardcode this because every state has 2 actions

            V = self.V[successor_states]
            exp_V = np.exp(V / self.beta)
            action_probs = exp_V / exp_V.sum()
    
            action = np.random.choice([0,1], p=action_probs)

            return action

    def learn(self):
        """
        Agent explores the maze according to its decision policy and and updates its DR as it goes
        """
        # Iterate through number of steps
        for i in range(self.num_steps):
            # Current state
            state = self.agent_loc

            # Choose action
            if self.policy == "softmax":
                action = self.select_action(state)
        
            # Take action
            next_state, done = self.envstep[state, action]
            
            # Update default representation
            target = self.one_hot[state] + self.gamma * self.SR[next_state]
            self.SR[state] = (1 - self.alpha) * self.SR[state] + self.alpha * target

            # Update Values
            self.update_V()

            if done:
                self.agent_loc = self.start_loc
                continue
            
            # Update state
            state = next_state
            self.agent_loc = state

        # Update DR at terminal state
        self.update_V()

## Core LRL

In [311]:
def core_lrl(T, c, M=None, D=None, lambd=1):
    # reward vector across all states
    r = -c

    # terminal states
    terminals = np.diag(T) == 1

    # computing M (if not given)
    if M is None:
        ## M-Inv
        L = np.diag(np.exp(c / lambd)) - T
        L = L[~terminals][:, ~terminals]
        M = np.linalg.inv(L)

    # P = T_{NT}
    P = T[~terminals][:, terminals]
    expr = np.exp(r[terminals] / lambd)

    expv_N = M @ P @ expr
    expv = np.zeros(len(r))
    expv[~terminals] = expv_N
    expv[terminals] = np.exp(r[terminals] / lambd)

    # A matrix formulation of equation 6 of manuscript
    G = T @ expv
    # Transpose expv and perform element-wise division
    expv_tiled = np.tile(expv, (len(expv), 1))
    G = G.reshape(-1, 1)
    zg = expv_tiled / G
    pii = T * zg

    return pii, expv, M


In [344]:
T = np.zeros((6, 6))
T[0, 1:3] = 0.5
T[1, 3:5] = 0.5
T[2, 4:6] = 0.5
T[3:6, 3:6] = np.eye(3)

# T = np.zeros((6, 6))
# T[0,1] = 0.1
# T[0,2] = 0.9
# T[1,3] = 0.1
# T[1,4] = 0.9
# T[2,4] = 0.1
# T[2,5] = 0.9
# T[3:6, 3:6] = np.eye(3)

lambd = 100

c1 = -np.array([0, 0, 0, 0, 15, 30]) / lambd
U1, expv1, MNN = core_lrl(T, c1)

c2 = np.copy(c1)
c2[3] = -45 / lambd
U2, expv2, _ = core_lrl(T,c2,MNN)

In [345]:
print(MNN)

[[1.  0.5 0.5]
 [0.  1.  0. ]
 [0.  0.  1. ]]


In [346]:
U1[0]

array([0.        , 0.46257015, 0.53742985, 0.        , 0.        ,
       0.        ])

In [347]:
U2[0]

array([0.        , 0.52083747, 0.47916253, 0.        , 0.        ,
       0.        ])

## DR-Inv Policy Revaluation

In [227]:
# Load agent and solve for the DR using inverse equations
agent = LinearRL_NHB()
D_inv = agent.get_D_inv()
# # Use inverse DR to get value function
agent.DR = D_inv
agent.update_Z()
agent.update_V()
print("Value before terminal change: ")
print(agent.V)
# # Update the terminal states and resolve for value function
agent.update_term()
V_new, _ = policy_reval(agent)
print("\nValue after terminal change: ")
print(V_new)


Value before terminal change: 
[1.92 1.01 2.51 0.   1.5  3.  ]

Value after terminal change: 
[3.29 3.86 2.51 4.5  1.5  3.  ]


## DR-TD Policy Revaluation

### With importance sampling

In [125]:
agent_with_imp = LinearRL_NHB(_lambda=4, alpha=0.25, beta=4, num_steps=500, policy="softmax", imp_samp=True)
D_inv = agent_with_imp.get_D_inv()
agent_with_imp.learn()

In [126]:
# print("Value before terminal change: ")
# agent_with_imp.update_V()
# print(agent_with_imp.V)
# # Update the terminal states and resolve for value function
# agent_with_imp.update_term()
# agent_with_imp.learn_term(n=5)
# # V_new = policy_reval(agent_with_imp)
# print("\nValue after terminal change: ")
# # print(V_new)
# print(agent_with_imp.V)

### Without importance sampling

In [127]:
agent_no_imp = LinearRL_NHB(_lambda=4, alpha=0.25, beta=4, num_steps=50, policy="softmax", imp_samp=False)
agent_no_imp.learn()

In [128]:
# print("Value before terminal change: ")
# print(agent_no_imp.V)
# # Update the terminal states and resolve for value function
# agent_no_imp.update_term()
# agent_no_imp.learn_term(n=5)
# # V_new = policy_reval(agent_no_imp)
# print("\nValue after terminal change: ")
# # print(V_new)
# print(agent_no_imp.V)

## Average DR

In [384]:
num_iterations = 500

DR_avg_with_imp = np.zeros((6, 6))
DR_avg_no_imp = np.zeros((6,6))
SR_avg = np.zeros((6,6))
# Construct a T for SR that prefers going to the right

for i in range(num_iterations):
    # Define agents
    agent_with_imp = LinearRL_NHB(_lambda=10, alpha=0.25, beta=1, num_steps=250, policy="softmax", imp_samp=True, exp_type="policy_reval")
    agent_no_imp = LinearRL_NHB(_lambda=10, alpha=0.25, beta=1, num_steps=250, policy="softmax", imp_samp=False, exp_type="policy_reval")
    agent_SR = SR_NHB(alpha=0.25, beta=10, num_steps=250, policy="softmax")

    # Have the agents learn the environment
    agent_with_imp.learn()
    agent_no_imp.learn()
    agent_SR.learn()

    # Add to the averages
    DR_avg_with_imp += agent_with_imp.DR
    DR_avg_no_imp += agent_no_imp.DR
    SR_avg += agent_SR.SR

# Take average
DR_avg_with_imp /= num_iterations
DR_avg_no_imp /= num_iterations
SR_avg /= num_iterations

In [385]:
np.set_printoptions(suppress=True)
print("DR-Inv")
print(np.round(agent_with_imp.get_D_inv(), 3))
print("\n DR with Importance Sampling")
print(np.round(DR_avg_with_imp, 3))
print("\n DR without Importance Sampling")
print(np.round(DR_avg_no_imp, 3))
print("\n SR")
print(np.round(SR_avg, 4))

DR-Inv
[[ 1.     0.452  0.452  2.151  4.302  2.151]
 [ 0.     1.     0.     4.754  4.754  0.   ]
 [ 0.     0.     1.     0.     4.754  4.754]
 [ 0.     0.     0.    10.508  0.     0.   ]
 [ 0.     0.     0.     0.    10.508  0.   ]
 [ 0.     0.     0.     0.     0.    10.508]]

 DR with Importance Sampling
[[ 0.989  0.422  0.447  1.873  4.164  2.198]
 [ 0.001  0.986  0.001  4.561  4.78   0.001]
 [ 0.001  0.001  0.971  0.001  4.39   4.836]
 [ 0.001  0.001  0.001 10.508  0.001  0.001]
 [ 0.001  0.001  0.001  0.001 10.508  0.001]
 [ 0.001  0.001  0.001  0.001  0.001 10.508]]

 DR without Importance Sampling
[[ 1.001  0.169  0.737  0.292  2.605  5.703]
 [ 0.001  1.001  0.001  1.662  7.824  0.001]
 [ 0.001  0.001  1.001  0.001  1.769  7.74 ]
 [ 0.001  0.001  0.001 10.508  0.001  0.001]
 [ 0.001  0.001  0.001  0.001 10.508  0.001]
 [ 0.001  0.001  0.001  0.001  0.001 10.508]]

 SR
[[1.     0.1765 0.7275 0.0286 0.2481 0.5398]
 [0.     1.     0.     0.1947 0.7048 0.    ]
 [0.     0.     1.    

In [394]:
agent_no_imp.r

array([-1, -1, -1, 45, 15, 30])

In [395]:
print(SR_avg @ agent_no_imp.r)
z = SR_avg @ agent_no_imp.r
pii = decision_policy(agent_SR, z)
print(pii[0])

[19.30070163 18.33294824 23.72929769 45.         15.         30.        ]
[0.         0.24874657 0.75125343 0.         0.         0.        ]


In [391]:
z = DR_avg_with_imp @ agent_with_imp.T @ np.exp(agent_no_imp.r/15)
print(z)
pii = decision_policy(agent_with_imp, z)
print(pii[0])

[ 73.17917906 115.85442012  52.60259556 211.0929877   28.60947345
  77.68684915]
[0.         0.68773877 0.31226123 0.         0.         0.        ]


In [392]:
z = DR_avg_no_imp @ agent_no_imp.T @ np.exp(agent_no_imp.r/15)
print(z)
pii = decision_policy(agent_no_imp, z)
print(pii[0])

[ 61.68462934  66.06602669  67.08709493 211.0929877   28.60947345
  77.68684915]
[0.         0.49616581 0.50383419 0.         0.         0.        ]


In [390]:
agent_no_imp.update_term()
print(agent_no_imp.r)

[-1 -1 -1 45 15 30]


In [279]:
agent_with_imp.DR = DR_avg_with_imp
agent_with_imp.update_Z()
agent_with_imp.update_V()
pii_old = decision_policy(agent_with_imp, agent_with_imp.Z)
print(agent_with_imp.Z)
agent_with_imp.update_term()
V_new, Z_new = policy_reval(agent_with_imp)
pii_new = decision_policy(agent_with_imp, Z_new)
print(Z_new)

[ 6.73394269  2.66796797 12.09168045  1.          4.48168907 20.08553692]
[25.91311927 44.31865082 12.48377843 90.0171313   4.48168907 20.08553692]


In [134]:
print(pii_old[0])
print(pii_new[0])

[0.         0.17019523 0.82980477 0.         0.         0.        ]
[0.         0.76692429 0.23307571 0.         0.         0.        ]


In [135]:
# plot_decision_prob(probs_train=[0.37420031, 0.62579969], probs_test=[0.87453662, 0.12546338])

In [136]:
agent_no_imp.DR = DR_avg_no_imp
agent_with_imp.update_Z()
agent_no_imp.update_V()
pii_old = decision_policy(agent_no_imp, agent_no_imp.Z)
print(agent_no_imp.Z)
agent_no_imp.update_term()
V_new, Z_new = policy_reval(agent_no_imp)
pii_new = decision_policy(agent_no_imp, Z_new)
print(Z_new)

[11.23768252  2.86368067 12.41955991  1.          4.48168907 20.08553692]
[11.62140615 47.43227685 12.82229007 90.0171313   4.48168907 20.08553692]


In [137]:
print(pii_old[0])
print(pii_new[0])

[0.         0.18737392 0.81262608 0.         0.         0.        ]
[0.         0.78719804 0.21280196 0.         0.         0.        ]


In [105]:
# plot_decision_prob(probs_train=[0.23538717, 0.76461283], probs_test=[0.78205788, 0.21794212])

## SR-TD

In [8]:
agent_SR = SR_NHB(alpha=0.25, beta=10, num_steps=250, policy="softmax")
agent_SR.learn()

In [9]:
agent_SR.V

array([21.01222449, 12.32834614, 25.92754206,  0.        , 15.        ,
       30.        ])

In [10]:
agent_SR.r

array([-1, -1, -1,  0, 15, 30])

In [11]:
print(agent_SR.SR)

[[1.         0.08067442 0.82332558 0.00183083 0.10285329 0.7124475 ]
 [0.         1.         0.         0.01476329 0.88855641 0.        ]
 [0.         0.         1.         0.         0.01283053 0.89116947]
 [0.         0.         0.         1.         0.         0.        ]
 [0.         0.         0.         0.         1.         0.        ]
 [0.         0.         0.         0.         0.         1.        ]]
