In [46]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from IPython import display
from collections import Counter
import optuna
import wandb
from tqdm import tqdm
from gymnasium.envs.toy_text.frozen_lake import generate_random_map

In [47]:
# Environment setup
env = gym.make('FrozenLake-v1', is_slippery=False)
n_actions = env.action_space.n
n_states = env.observation_space.n

In [48]:
# Neural network model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(n_states, n_actions)

    def forward(self, x):
        return self.fc(x)
    
class ConvNet(nn.Module):
    # TODO: normalization layers, maybe layer norm because batch size of 1 ?
    def __init__(self, input_size, n_actions):
        super(ConvNet, self).__init__()
        # Assuming input_size is the flattened size of the 2D state
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # Flatten the output for the fully connected layer
        self.fc = nn.Linear(32 * input_size, n_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [49]:
device = torch.device("mps")

In [50]:
def preprocess_state(position, map_layout):
    nrows, ncols = 4, 4  # Assuming a 4x4 map
    state_matrix = np.zeros((nrows, ncols))
    
    # Decode map layout
    layout_to_val = {b'F': 0, b'H': 1, b'S': 0, b'G': 3}  # Start 'S' also considered safe '0'
    for i in range(nrows):
        for j in range(ncols):
            state_matrix[i][j] = layout_to_val[map_layout[i][j]]

    # Convert position to 2D coordinates and update in state matrix
    row, col = divmod(position, ncols)
    state_matrix[row][col] = 2  # Marking the current position

    # Normalize
    normalized_state = state_matrix / 3.0 - 0.5
    return torch.tensor(normalized_state, dtype=torch.float).unsqueeze(0)  # Adds batch dimension


In [51]:
print(env.desc)

[[b'S' b'F' b'F' b'F']
 [b'F' b'H' b'F' b'H']
 [b'F' b'F' b'F' b'H']
 [b'H' b'F' b'F' b'G']]


In [52]:
def find_goal_position(env_desc):
    """
    Find the position of the goal ('G') in the environment description.
    """
    for row in range(len(env_desc)):
        for col in range(len(env_desc[row])):
            if env_desc[row][col] == b'G':
                return row, col
    return None

def manhattan_distance(point_a, point_b):
    """
    Calculate the Manhattan distance between two points.
    """
    return abs(point_a[0] - point_b[0]) + abs(point_a[1] - point_b[1])

def calculate_intermediate_reward(current_state, next_state, env, visited_states, forward_step_reward=0, visited_step_reward=0):
    """
    Calculate intermediate reward based on movement towards the goal.
    """
    if next_state in visited_states:
        return visited_step_reward  # Negative reward for visiting a state that has already been visited
    
    #goal_position = find_goal_position(env.desc)
    # Convert state to 2D coordinates
    #ncols = len(env.desc[0])
    #current_row, current_col = divmod(current_state, ncols)
    #next_row, next_col = divmod(next_state, ncols)

    #current_distance = manhattan_distance((current_row, current_col), goal_position)
    #next_distance = manhattan_distance((next_row, next_col), goal_position)

    return forward_step_reward if next_state > current_state else 0
    #return forward_step_reward if next_distance < current_distance else 0


def create_vectorized_environments(env_name, n_envs, random_map, is_slippery):
    if random_map:
        envs = [gym.make(env_name, desc=generate_random_map(size=4), is_slippery=is_slippery) for _ in range(n_envs)]
    else:
        envs = [gym.make(env_name, is_slippery=is_slippery) for _ in range(n_envs)]
    return envs

# Training Function
def train_model(model, optimizer, loss_fn, gamma, epsilon_start, epsilon_decay, num_episodes, device, n_states, random_map=False, is_slippery=False, hole_reward=-1, forward_step_reward=0, visited_step_reward=0):
    n_envs = 8
    max_steps = 4 * 4 * 4 # 4x the number of states seems like a reasonable upper bound
    epsilon = epsilon_start

    # Training loop
    # Initialize lists to store metrics
    losses = []
    weight_norms = []
    bias_norms = []
    grad_norms = []
    total_rewards = []
    epsilon_values = []
    episode_lengths = []
    successful_episodes = 0
    success_rate = []

    plot_update_frequency = int(num_episodes * (5/100))  # update plots every 5% of episodes
    if plot_update_frequency == 0:
        plot_update_frequency = 1


    # Training loop
    for episode in tqdm(range(num_episodes)):
        model.train()
        # Create vectorized environments
        envs = create_vectorized_environments('FrozenLake-v1', n_envs, random_map, is_slippery)


        states = [env.reset()[0] for env in envs]
        dones = [False] * n_envs
        episode_rewards = [0] * n_envs
        visited_states = [set() for _ in envs]
        episode_lengths = [0] * n_envs

        step_count = 0
        ongoing_indices = list(range(n_envs))  # Initialize with all environment indices

        # TODO: fix loss bug when len is 1
        while len(ongoing_indices) > 1 and step_count < max_steps:
            step_count += 1
            # Filter out completed environments for action selection and state update
            ongoing_envs = [envs[i] for i in ongoing_indices]
            ongoing_states = [states[i] for i in ongoing_indices]


            # Preprocess all states and convert them into tensors
            state_tensors = [preprocess_state(state, env.desc).to(device) for state, env in zip(ongoing_states, ongoing_envs)] 
            # Combine all state tensors into a single batch
            state_tensor_batch = torch.stack(state_tensors)
            
            # Compute Q-values for the entire batch
            q_values_batch = model(state_tensor_batch)
            
            # Iterate over each environment to select actions
            actions = []
            for i, (q_values, env) in enumerate(zip(q_values_batch, ongoing_envs)):
                visited_states[ongoing_indices[i]] .add(states[ongoing_indices[i]])
            
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    action = torch.argmax(q_values).item()
            
                actions.append(action)


            next_states, rewards, next_dones = [], [], []
            for i, (env, action) in enumerate(zip(ongoing_envs, actions)):
                next_state, reward, done, _, _ = env.step(action)
                # Custom reward logic
                if done and reward == 0:  # Agent fell into a hole
                    reward = hole_reward
                else:
                    # Additional logic to calculate reward for moving towards the goal
                    reward += calculate_intermediate_reward(states[ongoing_indices[i]], next_state, env, visited_states, forward_step_reward, visited_step_reward)

                next_states.append(next_state)
                rewards.append(reward)
                next_dones.append(done)
                episode_rewards[ongoing_indices[i]] += reward
                episode_lengths[ongoing_indices[i]]  += 1
            
            
            next_states_tensors = [preprocess_state(state, env.desc).to(device) for state, env in zip(next_states, ongoing_envs)]
            next_states_batch = torch.stack(next_states_tensors)
                        # Create batches for rewards, actions, and dones
            reward_batch = torch.tensor(rewards, device=device)
            action_batch = torch.tensor(actions, device=device)
            done_batch = torch.tensor(next_dones, device=device)
            
            # Compute Q-values for the entire batch
            next_q_values_batch = model(next_states_batch)
            max_next_state_q_values = torch.max(next_q_values_batch, dim=1)[0]
            # Compute the target Q-values
            target_q_values = reward_batch + gamma * max_next_state_q_values * (1 - done_batch.float())
            #print(q_values_batch.shape) # (8, 4)
            #print(action_batch.shape) # (8,)
            #print(target_q_values.shape) # (8,)
            # Correct way to get predicted Q values for the taken actions
            predicted_q_values = q_values_batch.gather(1, action_batch.unsqueeze(-1)).squeeze()
            # Compute the loss
            loss = loss_fn(predicted_q_values, target_q_values)
            
            # Rest of your optimization logic
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()

            # Update states of ongoing indices
            for i, (next_state, done) in enumerate(zip(next_states, next_dones)):
                if not done:
                    states[ongoing_indices[i]] = next_state

            # Calculate and store loss
            losses.append(loss.item())

            # Calculate and store weight and bias norms
            weight_norm = sum(torch.norm(param)**2 for param in model.parameters() if param.dim() > 1)
            bias_norm = sum(torch.norm(param)**2 for param in model.parameters() if param.dim() == 1)
            weight_norms.append(weight_norm.item())
            bias_norms.append(bias_norm.item())

            # Calculate and store gradient norms
            grad_norm = sum(torch.norm(param.grad)**2 for param in model.parameters() if param.grad is not None)
            grad_norms.append(grad_norm.item())

            new_ongoing_indices = []
            for i, (done, reward) in enumerate(zip(next_dones, rewards)):
                if done :
                    if reward > 0:
                        successful_episodes += 1
                else:
                    new_ongoing_indices.append(ongoing_indices[i])
            ongoing_indices = new_ongoing_indices
        # Close all environments at the end of the episode
        for env in envs:
            env.close()
        
        # Store total reward and epsilon value for the episode
        total_rewards.append(episode_rewards)
        epsilon_values.append(epsilon)
        success_rate.append(successful_episodes / (episode + 1))  # Calculate success rate

        # Decay epsilon
        if epsilon > 0.01:
            epsilon *= epsilon_decay

        if episode % plot_update_frequency == 0 or episode == num_episodes - 1:
            # log metrics to wandb  
            eval_success_rate = evaluate_model(model, 100, device, n_states, n_envs, is_slippery, random_map)
            wandb.log({"loss": loss.item(),
                       "weight_norm": weight_norm.item(),
                       "bias_norm": bias_norm.item(),
                       "grad_norm": grad_norm.item(),
                       "epsilon": epsilon,
                       "eval_success_rate": eval_success_rate,
                       "episode": episode,
                       "average_episode_length": sum(episode_lengths)/len(episode_lengths),
                       "learning_rate": learning_rate})
                


    return model

In [53]:
def evaluate_model(model, num_eval_episodes, device, n_states, batch_size=10, is_slippery=False, random_map=False):
    #TODO: give other metrics such as average episode length, average reward, etc.
    model.eval()
    with torch.no_grad():
        successful_episodes = 0
        total_evaluated = 0
    
        while total_evaluated < num_eval_episodes:
            envs = []
            states = []
            dones = []
            for _ in range(min(batch_size, num_eval_episodes - total_evaluated)):
                if random_map:
                    env = gym.make('FrozenLake-v1', desc=generate_random_map(size=4), is_slippery=is_slippery)
                else:
                    env = gym.make('FrozenLake-v1', is_slippery=is_slippery)
                envs.append(env)
                states.append(env.reset()[0])
                dones.append(False)
    
            max_steps = len(envs[0].desc) * len(envs[0].desc[0]) * 4
            step_count = 0
    
            while step_count < max_steps and not all(dones):
                step_count += 1
                state_tensors = [preprocess_state(state, env.desc).to(device) for state, env in zip(states, envs)]
                state_batch = torch.stack(state_tensors)
                q_values = model(state_batch)
                actions = torch.argmax(q_values, dim=1).cpu().numpy()
    
                for i, env in enumerate(envs):
                    if not dones[i]:
                        next_state, reward, done, _, _ = env.step(actions[i])
                        states[i] = next_state
                        dones[i] = done
    
                        if done and reward > 0:
                            successful_episodes += 1
    
            total_evaluated += len(envs)
    
        success_rate = successful_episodes / num_eval_episodes
    return success_rate


In [54]:
# Optuna Objective Function
def objective(trial):
    wandb.init(project="frozenlake_slipperry_optuna_reward_search_convnet_random_map",
               name=f"trial_{trial.number}",
               config=trial.params,
               reinit=True)
    # find rewards
    hole_reward = trial.suggest_float("hole_reward", -1, 0)
    forward_step_reward = trial.suggest_float("forward_step_reward", 0, 1)
    visited_step_reward = trial.suggest_float("visited_step_reward", -1, 0)
    
    # Hyperparameters
    learning_rate = 0.0001
    gamma = 0.99
    epsilon = 0.8
    epsilon_decay = 0.999
    num_episodes = 2500

    n_actions = env.action_space.n
    n_states = env.observation_space.n
    model = ConvNet(n_states, n_actions).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss().to(device)
    random_map = True
    is_slippery = False

    trained_model = train_model(model, optimizer, loss_fn, gamma, epsilon, epsilon_decay, num_episodes, device, n_states, random_map=random_map, is_slippery=is_slippery,
                                hole_reward=hole_reward, forward_step_reward=forward_step_reward, visited_step_reward=visited_step_reward)
    
    num_eval_episodes = 50 
    success_rate = evaluate_model(trained_model, num_eval_episodes, device, n_states, is_slippery, random_map)


    wandb.finish()
    return success_rate


study

In [55]:
"""
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)

print("Best trial:")
trial = study.best_trial
print(f" Value: {trial.value}")
print(" Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")
"""

'\nstudy = optuna.create_study(direction=\'maximize\')\nstudy.optimize(objective, n_trials=50)\n\nprint("Best trial:")\ntrial = study.best_trial\nprint(f" Value: {trial.value}")\nprint(" Params: ")\nfor key, value in trial.params.items():\n    print(f"    {key}: {value}")\n'

In [56]:
# Hyperparameters
learning_rate = 0.0001
gamma = 0.99
epsilon = 1
epsilon_decay = 0.995

In [57]:
random_map = True
is_slippery = False

In [58]:
# Train the model
num_episodes = 1000
n_actions = env.action_space.n
n_states = env.observation_space.n
model = ConvNet(n_states, n_actions).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()
wandb.init(project="frozenlake_slipperry_batches_convnet_random_map",
           name=f"First batch attemps",
           reinit=True)
model = train_model(model, optimizer, loss_fn, gamma, epsilon, epsilon_decay, num_episodes, device, n_states, 
                    random_map=random_map, is_slippery=is_slippery,
                    forward_step_reward=0.05,
                    visited_step_reward=0,
                    hole_reward=0)# maybe try 0 and not -0.05 again and run for longer so that it learns to avoid holes
wandb.finish()

100%|██████████| 1000/1000 [06:46<00:00,  2.46it/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
average_episode_length,▁▂▁▁▂▁▂▁▂▂▁▂▃▂▁▄▁▁█▁▁
bias_norm,▇███▇▇▆▅▄▄▄▃▂▂▂▂▂▂▂▁▁
episode,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
epsilon,█▆▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁
eval_success_rate,▁▄▃▂▃▂▃▃▄▅▇▅▅▆█▇▅▆▁▇▄
grad_norm,▁▂██▆█▃▂▄▃▄▃█▅▃▁▃▃▁█▃
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▁▁▅▄▃▃▂▃▂▁▂▂▄▂▂▁▂▁▁█▁
weight_norm,▁▁▂▃▄▄▅▅▆▆▆▆▆▇▇▇▇▇███

0,1
average_episode_length,5.75
bias_norm,0.64172
episode,999.0
epsilon,0.00999
eval_success_rate,0.39
grad_norm,1.39275
learning_rate,0.0001
loss,0.02263
weight_norm,22.57725


In [59]:
#TODO: maybe try to give bigger rewards the closer we get to the goal
#TODO: try batch normalization
#TODO: try 1 hot encoding the state instead of the current encoding
#TODO: have datapoints for loss every step, but only evaluate model every few steps

In [63]:
# Evaluate the model
num_eval_episodes = 100
success_rate = evaluate_model(model, num_eval_episodes, device, n_states, batch_size=10, is_slippery=is_slippery, random_map=random_map)
print("Success Rate: ", success_rate)

Success Rate:  0.48


In [61]:
def visualize_policy_on_random_maps(model, device, n_maps=5, n_states=16, is_slippery=False):
    # Create a color map for the grid with your specified colors


    for map_idx in range(n_maps):
        # Create environment with a random map
        env = gym.make('FrozenLake-v1', desc=generate_random_map(size=4), is_slippery=is_slippery)

        # Convert env.desc to a string format
        desc = env.desc.astype(str)

        # Create a color map for the grid
        numerical_color_map = {'S': 2, 'F': 1, 'H': 0, 'G': 3}
        numerical_grid_colors = np.vectorize(numerical_color_map.get)(desc)


        # Extracting the policy from the model for the current map
        policy = np.array([torch.argmax(model(preprocess_state(s, env.desc).to(device))).item() for s in range(env.observation_space.n)])

        # Mapping actions to symbols for visualization
        action_symbols = {0: '←', 1: '↓', 2: '→', 3: '↑'}
        policy_symbols = np.vectorize(action_symbols.get)(policy)

        # Reshape for grid visualization
        policy_grid = policy_symbols.reshape(env.desc.shape)

        # Plotting the policy grid
        plt.figure(figsize=(5, 5))
        # Create a ListedColormap for custom colors
        cmap = ListedColormap(['black', 'lightblue', 'lightblue', 'yellow'])  # Order: H, F, S, G
        plt.imshow(numerical_grid_colors, cmap=cmap, interpolation='nearest')

        for i in range(desc.shape[0]):
            for j in range(desc.shape[1]):
                arrow = policy_grid[i, j]
                arrow_color = 'white' if desc[i, j] in ['H', 'G'] else 'black'
                
                        # Checking if the arrow points to a hole
                if desc[i, j] != 'H':
                    if (arrow == '←' and j > 0 and desc[i, j-1] == 'H') or \
                       (arrow == '→' and j < desc.shape[1] - 1 and desc[i, j+1] == 'H') or \
                       (arrow == '↑' and i > 0 and desc[i-1, j] == 'H') or \
                       (arrow == '↓' and i < desc.shape[0] - 1 and desc[i+1, j] == 'H'):
                        arrow_color = 'red'
                
                plt.text(j, i, policy_grid[i, j], ha='center', va='center', fontsize=20, color=arrow_color)
        plt.title(f'Policy Visualization for Map {map_idx+1}')
        plt.show()

visualize_policy_on_random_maps(model, device, n_maps=5, is_slippery=False)

RuntimeError: linear(): input and weight.T shapes cannot be multiplied (32x16 and 512x4)