# Redspots Environment

## Environment Description

A simple gridworld with white (safe) spots, green (rewarding) spots, and red (undesirable/dangerous) spots.

## Goal

Have an agent map the environment and infer the best way to avoid red spots and get to the green spot.

# Setup

## Imports

In [1]:
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.cm as cm

import seaborn as sns

import pymdp
from pymdp import utils

## Auxiliary Functions

In [2]:
from auxilaryfunctions import plot_grid, add_noise, plot_likelihood, plot_beliefs

## Parameters

### Grid

In [3]:
# Grid size as [number of x positions, number of y positions].
# NOTE(review): the dense A/B tensors built below grow with the square of the
# number of cells; confirm a 200x200 grid is really intended.
grid_dims = [200, 200]

# Every (x, y) cell, x-major, so the flat index of (x, y) is x * grid_dims[1] + y.
grid_locations = [(x, y) for x in range(grid_dims[0]) for y in range(grid_dims[1])]

In [4]:
# Red (dangerous) cells as (x, y) tuples, four around each of four centres.
# Each spot is listed exactly once.
redspots = [
    (73, 23), (67, 23), (73, 17), (67, 17),
    (88, 53), (82, 53), (82, 47), (88, 47),
    (73, 78), (67, 78), (73, 71), (67, 71),
    (23, 73), (17, 73), (23, 67), (17, 67),
]

In [5]:
# Agent's starting cell and the goal (green) cell, both as (x, y) coordinates.
agent_pos, goal_location = (0, 0), (6, 4)

## Generative Model

### States & Observations

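s1 = current location \
s2 = attribute of current location (SAFE / DANGER / REWARDING) \
s3-s6 = attributes of the right, left, up and down neighbouring cells (SAFE / DANGER / REWARDING / N/A)

o1 = observed current location \
o2 = color of current location (WHITE / RED / GREEN) \
o3-o6 = colors of the right, left, up and down neighbouring cells (WHITE / RED / GREEN / N/A)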

In [6]:
# Hidden-state factors. s1 (location) is grid_locations, defined above.

# s2: attribute of the agent's current cell.
current_attribute = ['SAFE', 'DANGER', 'REWARDING']

# s3-s6: attributes of the right, left, up and down neighbouring cells.
# 'N/A' presumably stands for a neighbour with no attribute (e.g. off-grid).
right_attribute = ['SAFE', 'DANGER', 'REWARDING', 'N/A']
left_attribute = ['SAFE', 'DANGER', 'REWARDING', 'N/A']
up_attribute = ['SAFE', 'DANGER', 'REWARDING', 'N/A']
down_attribute = ['SAFE', 'DANGER', 'REWARDING', 'N/A']

# Number of values each hidden-state factor can take, in factor order s1..s6.
num_states = [
    len(factor_values)
    for factor_values in (grid_locations, current_attribute, right_attribute,
                          left_attribute, up_attribute, down_attribute)
]

In [7]:
# Observation modalities. o1 (observed location) reuses grid_locations.

# o2: colour of the current cell; index order lines up with current_attribute.
current_color = ['WHITE', 'RED', 'GREEN']

# o3-o6: colour seen in the right, left, up and down neighbouring cells.
right_color = ['WHITE', 'RED', 'GREEN', 'N/A']
left_color = ['WHITE', 'RED', 'GREEN', 'N/A']
up_color = ['WHITE', 'RED', 'GREEN', 'N/A']
down_color = ['WHITE', 'RED', 'GREEN', 'N/A']

# Number of possible outcomes per observation modality, in modality order o1..o6.
num_obs = [
    len(outcomes)
    for outcomes in (grid_locations, current_color, right_color,
                     left_color, up_color, down_color)
]

# Generative Model

## Rule-based Matrix

In [8]:
# Per-location belief about the current-cell attribute; columns follow
# current_attribute (SAFE, DANGER, REWARDING). Every location starts uniform,
# so each row already sums to 1 and needs no further normalisation.
rule_matrix = np.full((num_states[0], num_states[1]), 1.0 / num_states[1])

rule_matrix

array([[0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333],
       ...,
       [0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333]])

In [9]:
# Rule-matrix columns are ordered WHITE/SAFE, RED/DANGER, GREEN/REWARDING.

def update_rule_matrix(rule_matrix, beliefs):
    """Write a one-hot attribute row for the most likely location.

    Args:
        rule_matrix: per-location attribute table with 3 columns; the row of
            the most likely location is overwritten in place.
        beliefs: per-factor posteriors; beliefs[0] is over locations and
            beliefs[1] over the current-cell attribute.

    Returns:
        The same ``rule_matrix`` object, for convenient reassignment.
    """
    one_hot_rows = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    location_idx = np.argmax(beliefs[0])
    attribute_idx = np.argmax(beliefs[1])
    # Any attribute index past DANGER is treated as REWARDING.
    rule_matrix[location_idx] = one_hot_rows[min(attribute_idx, 2)]
    return rule_matrix

## A Matrix

In [10]:
# One likelihood array per observation modality, shaped
# [outcomes of that modality, size of each hidden-state factor ...].
A_shapes = [[n_outcomes, *num_states] for n_outcomes in num_obs]

A = utils.obj_array_zeros(A_shapes)
A.shape

(6,)

: 

### Location Observation Modality A[0]

In [None]:
# A[0]: location modality. The observed location equals the true location
# whatever the cell attributes (s2-s6) are, so A[0] is an identity over
# (observed location, true location), broadcast across the attribute axes.
# NOTE(review): A[0] has num_states[0]**2 * 3 * 4**4 entries; for a 200x200
# grid that is far beyond available memory -- confirm the intended grid size.
A[0] = np.zeros(A_shapes[0])
A[0][:] = np.eye(num_states[0])[:, :, None, None, None, None, None]

# Verify the shape and normalization
print("A[0] shape:", A[0].shape)
print("Column sums:", np.allclose(A[0].sum(axis=0), 1.0))  # Should be True

In [None]:
# Extract a 2D slice by fixing all other dimensions to 0
plot_likelihood(A[0][:,:,0,0,0,0,0], "Location observation likelihood matrix")

### Color observation modality: A[1]

In [401]:
# Map each current-cell attribute to its index: {'SAFE': 0, 'DANGER': 1, 'REWARDING': 2}
safety_level_to_index = {state: i for i, state in enumerate(current_attribute)}

# P(colour | attribute) for the current-cell colour modality A[1].
# Each list is ordered like current_color: ['WHITE', 'RED', 'GREEN'].
probabilities = {
    "SAFE": [1, 0, 0],       # always seen as WHITE
    "DANGER": [0, 1, 0],     # always seen as RED
    "REWARDING": [0, 0, 1],  # always seen as GREEN
}

In [402]:

# Populate A[1]: the colour of the current cell depends only on its attribute
# (factor s2); location and neighbour attributes have no effect. Each
# attribute's colour distribution is broadcast over all other axes at once
# rather than via a nested Python loop over every state combination.
for attr_idx, attr_label in enumerate(current_attribute):
    colour_dist = np.asarray(probabilities[attr_label], dtype=float)
    # A[1][:, :, attr_idx] has axes (colour, location, right, left, up, down);
    # reshape the distribution so it varies only along the colour axis.
    A[1][:, :, attr_idx] = colour_dist.reshape(-1, 1, 1, 1, 1, 1)


In [None]:
# Plot P(colour | location) for each current-cell attribute, holding the
# neighbour-attribute factors at index 0.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for safety_idx, ax in enumerate(axes):
    sns.heatmap(A[1][:, :, safety_idx, 0, 0, 0, 0], ax=ax, cmap='gray',
                cbar=False, vmin=0.0, vmax=1.0)
    ax.set_title(f'Safety State {safety_idx}')
    ax.set_xlabel('Location States')

axes[0].set_ylabel('Color Observations')

plt.tight_layout()
plt.show()

### Adjacent-Location Color Observation Modalities - A[2] through A[5]

In [404]:
# Populate A[2]..A[5]: colour seen in the right/left/up/down neighbour.
# Modality i is driven by the matching hidden-state factor i (the right/left/
# up/down attribute), not by the current-cell attribute:
# SAFE->WHITE, DANGER->RED, REWARDING->GREEN, N/A->N/A.
for i in range(2, 6):
    # Identity between this modality's outcomes and factor i's values, reshaped
    # so the factor axis lands on array axis i + 1 (axis 0 = outcome,
    # axis 1 = location, axes 2..6 = attribute factors s2..s6), then broadcast
    # over every other factor.
    mapping_shape = [1] * A[i].ndim
    mapping_shape[0] = num_obs[i]
    mapping_shape[i + 1] = num_states[i]
    A[i][:] = np.eye(num_obs[i], num_states[i]).reshape(mapping_shape)


In [None]:
# For each neighbour modality A[2]..A[5], plot P(colour | location) at each
# current-cell attribute value, holding the remaining factors at index 0.
# NOTE(review): this slices along the current-cell attribute axis (axis 2);
# to see how A[i] depends on its own neighbour-attribute factor, slice along
# axis i + 1 instead.
for i in range(2, 6):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for safety_idx, ax in enumerate(axes):
        sns.heatmap(A[i][:, :, safety_idx, 0, 0, 0, 0], ax=ax, cmap='gray',
                    cbar=False, vmin=0.0, vmax=1.0)
        ax.set_title(f'Safety State {safety_idx}')
        ax.set_xlabel('Location States')
    axes[0].set_ylabel('Color Observations')
    plt.tight_layout()
    plt.show()

### Add Noise

In [None]:
# Pass every modality's likelihood through add_noise, one modality at a time.
# NOTE(review): with noise_level=0 this presumably leaves A unchanged -- confirm
# against add_noise's implementation.
for modality_idx, likelihood in enumerate(A):
    A[modality_idx] = add_noise(likelihood, noise_level=0)

### Plot Each Matrix

## B Matrix

### Define Shape

In [None]:
# Controls per hidden-state factor: 5 moves for location, a single (no-op)
# control for every attribute factor.
num_controls = [5, 1, 1, 1, 1, 1]

# B[f] has shape [next state, current state, control] for each factor f.
B_f_shapes = []
for factor_idx, factor_size in enumerate(num_states):
    B_f_shapes.append([factor_size, factor_size, num_controls[factor_idx]])

B = utils.obj_array_zeros(B_f_shapes)
B_f_shapes

### B[0] - Control Factor - Location Transitions

In [None]:
grid_dims

In [409]:
# Location actions; list position is the action id used in B[0].
actions = "UP DOWN LEFT RIGHT STAY".split()

In [410]:
# B[0]: deterministic location transitions for each action. Moves that would
# leave the grid are clamped, so the agent stays on the boundary.
# State indices come from a dict built once; grid_locations.index(...) is a
# linear scan, which made this double loop quadratic in the number of cells.
loc_to_state = {loc: idx for idx, loc in enumerate(grid_locations)}
max_x, max_y = grid_dims[0] - 1, grid_dims[1] - 1

for action_id, action_label in enumerate(actions):
    for curr_state, (x, y) in enumerate(grid_locations):
        next_x, next_y = x, y  # STAY keeps the position
        if action_label == "UP":
            next_y = max(0, y - 1)       # Move up (decrease y)
        elif action_label == "DOWN":
            next_y = min(max_y, y + 1)   # Move down (increase y)
        elif action_label == "LEFT":
            next_x = max(0, x - 1)       # Move left (decrease x)
        elif action_label == "RIGHT":
            next_x = min(max_x, x + 1)   # Move right (increase x)

        # Set transition probability to 1.0
        B[0][loc_to_state[(next_x, next_y)], curr_state, action_id] = 1.0

### B[1] - Non-Control Factor - Identity Matrix

In [411]:
# B[1]: the current-cell attribute never changes on its own, so its single
# (uncontrolled) action is an identity sized to the factor.
B[1][:, :, 0] = np.eye(num_states[1])

### B[2] and beyond

In [412]:
# B[2]..B[5]: neighbour attributes are also static. Each factor gets its own
# identity sized to num_states[i] (4 values including 'N/A'), shaped
# [next, current, 1 control]. B[1] is 3x3 and cannot be reused here.
for i in range(2, 6):
    B[i] = np.expand_dims(np.eye(num_states[i]), axis=-1)

### Normalization

In [413]:
# Make every column of B[0] a probability distribution, action by action.
for a_idx, a_name in enumerate(actions):
    transitions = B[0][..., a_idx]  # view into B[0]; edits below write through
    # Columns with no outgoing mass fall back to a self-transition.
    empty_cols = np.flatnonzero(transitions.sum(axis=0) == 0)
    transitions[empty_cols, empty_cols] = 1.0
    B[0][..., a_idx] = transitions / transitions.sum(axis=0, keepdims=True)

# Verify normalization
for a_idx, a_name in enumerate(actions):
    assert np.allclose(B[0][..., a_idx].sum(axis=0), 1.0), f"Action {a_name} not normalized"

In [None]:
print(B[1])

## C Vectors (prior preferences)

### Initialize

In [None]:
C = utils.obj_array_zeros(num_obs)  # one zero-initialised preference vector per observation modality
print(C.shape)  # (number of modalities,)

### C[0] - Preference for location observations

In [None]:
goal_location

In [None]:
# Location preferences: baseline 1 for every cell plus 1 extra at the goal
# cell (these are shaped by distance and softmaxed in the following cells).
C[0] = np.ones(len(grid_locations))
C[0][grid_locations.index(goal_location)] += 1

print(C[0])

In [None]:
# Shape the location preference: subtract 0.1 per unit of Euclidean distance
# to the goal, so cells nearer the goal are preferred. Computed for all cells
# in one vectorised pass (no per-cell printing, which emitted one line per cell).
offsets = np.asarray(grid_locations) - np.asarray(goal_location)
distances = np.sqrt((offsets ** 2).sum(axis=1))
C[0] -= distances * 0.1

In [None]:
print(C[0])

In [420]:
from pymdp.maths import softmax

In [421]:
C[0] = softmax(C[0])

In [None]:
plot_beliefs(C[0])

### C[1] - Preference for color observations

In [423]:
# white, red, green <- order it's encoded in

In [None]:
# Preferences over the current-cell colour, ordered ['WHITE', 'RED', 'GREEN']:
# slightly avoid white, strongly avoid red, seek green.
C[1] = np.zeros((num_obs[1],))
C[1][:3] = [-0.1, -1, 1.1]

print(C[1])


### C[2] and beyond - preference for adjacent color observations

In [None]:
# Ten-times-weaker preferences over each neighbour's colour, ordered
# ['WHITE', 'RED', 'GREEN', 'N/A']; 'N/A' keeps a neutral 0.
# NOTE(review): 0 for 'N/A' is above WHITE's -0.01, so unobserved/off-grid
# neighbours are slightly preferred over safe ones -- confirm this is intended.
for i in range(2, 6):
    C[i] = np.zeros((num_obs[i],))
    C[i][:3] = [-0.01, -0.1, .11]
    print(C[i])


## D Vectors: Prior beliefs about hidden states

### Initialize

In [None]:
# Shape
num_states

In [427]:
D = utils.obj_array_uniform(num_states)

### D[0] - Belief About Current Location

In [None]:
# Prior over the agent's location: certainty that it starts at agent_pos.
D[0] = np.zeros(num_states[0])  # one entry per grid cell: shape (num_states[0],)
D[0][grid_locations.index(agent_pos)] = 1.0  # One-hot encoding for location

print("D[0] shape (Location prior):", D[0].shape)  # (num_states[0],)
D[0]

In [429]:
# # Initialize uniform distribution over locations
# D[0] = np.ones(num_states[0]) / num_states[0]  # Create normalized uniform distribution over all locations

# D[0]

### D[1] - Belief About Attribute of Current Location

In [None]:
# Uniform prior over the current-cell attribute (SAFE, DANGER, REWARDING).
D[1] = np.ones(num_states[1]) / num_states[1]
D[1].shape

In [None]:
plot_beliefs(D[0])

In [None]:
plot_beliefs(D[1])

### D[2] and beyond - Belief About Attributes of Adjacent Locations

In [None]:
# Uniform prior over each neighbour attribute (SAFE, DANGER, REWARDING, N/A),
# sized from num_states[i] (4 values) -- D[1] has only 3 entries.
for i in range(2, 6):
    D[i] = np.ones(num_states[i]) / num_states[i]
    plot_beliefs(D[i])

# Generative Process

In [434]:
def update_vision(current_location, grid_dims, distance):
    """
    List every grid cell inside the square window of half-width ``distance``
    centred on ``current_location``, clipped to the grid. The agent's own cell
    is included.

    Args:
        current_location (tuple): Agent position as (x, y).
        grid_dims (list): Grid size as [width, height].
        distance (int): Vision radius in cells.

    Returns:
        list: (x, y) tuples, row by row (y in the outer loop, x in the inner).
    """
    cx, cy = current_location
    xs = range(max(0, cx - distance), min(grid_dims[0], cx + distance + 1))
    ys = range(max(0, cy - distance), min(grid_dims[1], cy + distance + 1))
    return [(x, y) for y in ys for x in xs]

In [435]:
X, Y = 0, 0

In [436]:
class GridWorldEnv:
    """Generative process for the red-spots gridworld.

    The agent moves on a grid bounded by the module-level ``grid_dims``. After
    each move it sees every cell in a square window around it (from
    ``update_vision``). Each visible cell is classified as red (listed in
    ``redspots``), green (the goal) or white. Landing on a red cell costs 5
    reward; landing on the goal gives 20.

    When nothing of a colour is visible, the observation attributes hold
    sentinels: ``red_obs`` and ``white_obs`` are ``['Null']`` and
    ``green_obs`` is ``'Null'``.
    """

    # Red spots used when the caller does not pass any.
    _DEFAULT_REDSPOTS = ((1, 2), (3, 2), (4, 4), (6, 1))

    def __init__(self, starting_loc=(0, 0), redspots=None, goal=(6,4), vision_distance=6):
        """
        Args:
            starting_loc (tuple): Initial (x, y) position; reset() returns here.
            redspots (list | None): (x, y) red cells. ``None`` selects a fresh
                copy of the default spots, so instances never share a mutable
                default list.
            goal (tuple): (x, y) of the green goal cell.
            vision_distance (int): Half-width of the square vision window.
        """
        # Initialize coordinates
        self.x, self.y = starting_loc
        self.init_loc = starting_loc
        self.current_location = (self.x, self.y)
        self.loc_obs = self.current_location

        self.goal = goal

        self.redspots = list(self._DEFAULT_REDSPOTS) if redspots is None else redspots

        self.vision_distance = vision_distance
        self.vision = []

        self.red_obs = ['Null']
        self.green_obs = 'Null'
        self.white_obs = ['Null']

        self.agent_reward = 0

        print(f"Starting location is {self.current_location} | Red spot locations are {self.redspots} | Goal is {self.goal}")

    @staticmethod
    def _record(observed, spot):
        """Return ``observed`` with ``spot`` added, replacing the ['Null'] sentinel."""
        if 'Null' in observed:
            return [spot]
        observed.append(spot)
        return observed

    def step(self, action_label):
        """Apply one action, then observe the surroundings and update reward.

        Args:
            action_label: "UP", "DOWN", "LEFT" or "RIGHT"; any other label
                (e.g. "STAY") leaves the position unchanged.

        Returns:
            tuple: (loc_obs, green_obs, white_obs, red_obs, agent_reward), where
            agent_reward is the running total since the last reset.
        """
        if action_label == "UP":
            self.y = max(0, self.y - 1)  # Move up (decrease y)
        elif action_label == "DOWN":
            self.y = min(grid_dims[1] - 1, self.y + 1)  # Move down (increase y)
        elif action_label == "LEFT":
            self.x = max(0, self.x - 1)  # Move left (decrease x)
        elif action_label == "RIGHT":
            self.x = min(grid_dims[0] - 1, self.x + 1)  # Move right (increase x)

        # Update current_location tuple after movement
        self.current_location = (self.x, self.y)
        print(f"self.current_location: {self.current_location}")

        # Update vision with current coordinates
        self.vision = update_vision(self.current_location, grid_dims, self.vision_distance)

        self.loc_obs = self.current_location

        # Reset observations at each step
        self.red_obs = ['Null']
        self.white_obs = ['Null']
        self.green_obs = 'Null'

        # Classify every visible cell by colour
        for spot in self.vision:
            if spot in self.redspots:
                self.red_obs = self._record(self.red_obs, spot)
            elif spot == self.goal:
                self.green_obs = spot
            else:
                self.white_obs = self._record(self.white_obs, spot)

        # Reward depends on the cell the agent landed on. The vision window
        # already contains the agent's own cell, so it is only recorded again
        # if it is somehow missing (avoids duplicate entries).
        if self.current_location in self.redspots:
            self.agent_reward -= 5
            if self.current_location not in self.red_obs:
                self.red_obs = self._record(self.red_obs, self.current_location)
        elif self.current_location == self.goal:
            self.agent_reward += 20
            self.green_obs = self.current_location
        elif self.current_location not in self.white_obs:
            self.white_obs = self._record(self.white_obs, self.current_location)

        return self.loc_obs, self.green_obs, self.white_obs, self.red_obs, self.agent_reward

    def reset(self):
        """Return to the starting location, clear observations and zero the reward.

        Returns:
            tuple: (loc_obs, green_obs, white_obs, red_obs, agent_reward) with
            all colour observations set to their 'Null' sentinels.
        """
        self.x, self.y = self.init_loc
        self.current_location = (self.x, self.y)
        print(f'Re-initialized location to {self.current_location}')
        self.loc_obs = self.current_location
        self.green_obs, self.white_obs, self.red_obs, self.agent_reward = 'Null', ['Null'], ['Null'], 0

        return self.loc_obs, self.green_obs, self.white_obs, self.red_obs, self.agent_reward

# Active Inference

In [437]:
from pymdp.agent import Agent

In [None]:
agent_pos, redspots, goal_location

In [None]:
# Active-inference agent built from the generative model (A, B, C, D),
# planning over 6-step policies.
# NOTE(review): with 5 location actions, enumerating every 6-step sequence
# would mean 5**6 = 15,625 policies -- confirm this is tractable.
my_agent = Agent(A=A, B=B, C=C, D=D, policy_len=6)

# Generative process using the same start cell, red spots and goal as the model.
my_env = GridWorldEnv(starting_loc = agent_pos, redspots=redspots, goal = goal_location)

In [None]:
loc_obs, green_obs, white_obs, red_obs, agent_reward = my_env.reset()
loc_obs, green_obs, white_obs, red_obs, agent_reward

In [None]:
loc_obs, green_obs, white_obs, red_obs, agent_reward = my_env.step('STAY')

In [None]:
loc_obs, green_obs, white_obs, red_obs, agent_reward

## Create Observation

In [None]:
loc_obs

In [None]:
grid_locations

In [445]:
def create_color_observation(position, red_obs, green_obs, white_obs):
    """Encode the colour seen at ``position`` as a colour-modality index.

    Returns 1 (RED) if ``position`` is in ``red_obs``, otherwise 2 (GREEN)
    if it equals ``green_obs``, otherwise 0 (WHITE) if it is in ``white_obs``,
    and 3 (N/A) when it was not observed at all (including off-grid cells).
    The ``['Null']`` sentinel lists are treated as empty.
    """
    if red_obs != ['Null'] and position in red_obs:
        return 1
    if green_obs == position:
        return 2
    if white_obs != ['Null'] and position in white_obs:
        return 0
    return 3

In [446]:
def create_observation(position, red_obs, green_obs, white_obs):
    """Assemble the six-modality observation list fed to the agent.

    Returns [index of ``position`` in grid_locations, colour of ``position``,
    colour of its right (x + 1), left (x - 1), up (y - 1) and down (y + 1)
    neighbours], with colours encoded by ``create_color_observation``.
    Raises ValueError if ``position`` is not in ``grid_locations``.
    """
    location_idx = grid_locations.index(position)
    x, y = position[0], position[1]
    cells = (position, (x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1))
    colours = [create_color_observation(cell, red_obs, green_obs, white_obs) for cell in cells]
    return [location_idx] + colours

## Loop

In [None]:
my_agent.qs[0], my_agent.qs[1]

In [None]:
obs = create_observation(loc_obs, red_obs, green_obs, white_obs)
obs

In [None]:
# Perception-action loop: for T steps, turn the environment output into a
# pymdp observation, infer states, choose an action and step the environment.
history_of_locs = [loc_obs]  # visited locations, starting from the current one

T = 15  # number of time steps

for t in range(T):

    # [location index, current colour, right, left, up, down colours]
    obs = create_observation(loc_obs, red_obs, green_obs, white_obs)

    # generate observations
    print(f"Observation: {obs}")

    # belief posterior
    qs = my_agent.infer_states(obs) # posterior over each hidden-state factor given obs 

    # plot belief posterior
    plot_beliefs(qs[0])
    plot_beliefs(qs[1])

    # use belief posterior to update rule matrix (one-hot the most likely
    # attribute into the row of the most likely location)
    rule_matrix = update_rule_matrix(rule_matrix, qs)
    # NOTE(review): prints the row of the *observed* location (obs[0]); this
    # matches the updated row only when qs[0] peaks at the observed location.
    print(f"Rule Matrix at loc {loc_obs}: {rule_matrix[obs[0]]}")

    # rule-based update on A (currently disabled)
    # A = rule_based_update_A(rule_matrix, A)

    # plot updated A
    # plot_A_1(A)


    # policy selection
    my_agent.infer_policies()
    
    chosen_action_id = my_agent.sample_action()

    # entry 0 is the action for the location factor, i.e. an index into actions
    movement_id = int(chosen_action_id[0])

    choice_action = actions[movement_id]

    print(f'Action at time {t}: {choice_action}')

    
    # advance the generative process and collect its new observations
    loc_obs, green_obs, white_obs, red_obs, agent_reward = my_env.step(choice_action)
    
    print(agent_reward, loc_obs, green_obs, white_obs, red_obs)


    history_of_locs.append(loc_obs)

    print(f'Grid location at time {t}: {loc_obs}')

    print(f'Reward at time {t}: {agent_reward}')