In [107]:
import gymnasium as gym
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import optuna
import wandb
from tqdm import tqdm
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from collections import deque
from concurrent.futures import ThreadPoolExecutor

FrozenLake environment

In [108]:
# Environment setup
env = gym.make('FrozenLake-v1', is_slippery=False)
n_actions = env.action_space.n
n_states = env.observation_space.n

In [109]:
def convert_layout_to_tensor(map_layouts):
    nrows, ncols = len(map_layouts[0]), len(map_layouts[0][0])
    num_statuses = 4
    batch_size = len(map_layouts)

    # Initialize a tensor for the batch of layouts
    layout_tensor = torch.zeros((batch_size, nrows, ncols, num_statuses), device='cpu', dtype=torch.float)

    layout_to_val = {b'F': 0, b'H': 1, b'S': 0, b'G': 3,
                     'F': 0, 'H': 1, 'S': 0, 'G': 3}

    # Precompute all indices for the batch
    all_indices = [
        layout_to_val[item]
        for map_layout in map_layouts
        for row in map_layout
        for item in row
    ]
    indices_tensor = torch.tensor(all_indices, device='cpu').view(batch_size, nrows, ncols)

    # Update the tensor using advanced indexing
    layout_tensor.scatter_(3, indices_tensor.unsqueeze(3), 1)

    return layout_tensor


def update_start_positions(tensor_layout:Tensor, positions):
    nrows, ncols, _ = tensor_layout.size()[1:4]

    # Convert positions to a PyTorch tensor if it's not already one
    if not isinstance(positions, torch.Tensor):
        positions = torch.tensor(positions, dtype=torch.long, device=tensor_layout.device)

    # Calculate rows and columns for all positions
    rows = positions // ncols
    cols = positions % ncols

    # Reset the cells to [0, 0, 0, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols] = torch.tensor([0, 0, 0, 0], dtype=tensor_layout.dtype, device=tensor_layout.device)

    # Set the start cells to [0, 0, 1, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols, 2] = 1

    return tensor_layout

In [110]:
tensor = convert_layout_to_tensor(generate_random_map(4))
print(tensor)
updated_tensor = update_start_positions(tensor, [[0]])
print(updated_tensor) # dimensions: batch_size x nrows x ncols x num_statuses


tensor([[[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[0., 1., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])
tensor([[[[0., 0., 1., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[0., 1., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])


Implementation of DeamerV3 model
![DeamerV3 model](model_names.png)


In [111]:
# Autoencoder that learns the latent space of the data
class CategoricalAutoencoder(nn.Module):
    def __init__(self, input_size, latent_dim, num_categories):
        super(CategoricalAutoencoder, self).__init__()
        self.latent_dim = latent_dim
        self.num_categories = num_categories

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(True),
            nn.Linear(512, latent_dim * num_categories)  # Output logits for each category in each latent dimension
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim * num_categories, 512),
            nn.ReLU(True),
            nn.Linear(512, input_size),
            nn.ReLU(True)
        )
        
    def encode(self, x, temperature=1.0):
        # Encode input and apply Gumbel Softmax
        encoded = self.encoder(x)
        encoded = encoded.view(-1, self.num_categories)
        softmax = nn.functional.gumbel_softmax(encoded, tau=temperature, hard=True)
        return softmax

    def forward(self, x, temperature=1.0):
        batch_size = x.size(0)
        softmax = self.encode(x, temperature)
        softmax = softmax.view(batch_size, -1)

        x = self.decoder(softmax)
        return x, softmax
    

In [112]:
# Sequence model that learns the dynamics of the latent space
# takes as input h t-1, z t-1, a t-1 and outputs ht
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # GRU Layer
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)

        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # Forward propagate GRU
        out, _ = self.gru(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        return out


In [113]:
# Predictor model that predicts the next zt, rt, ct, xt
class PredictorModel(nn.Module):
    def __init__(self, gru_hidden_size, zt_dim, num_categories, reward_size, continue_size):
        super(PredictorModel, self).__init__()
        
        # Shared Layers
        self.shared_layers = nn.Sequential(
            nn.Linear(gru_hidden_size, 512),
            nn.ReLU()
        )

        # Specific predictor for the latent space zt
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(512, 256),  # Reduced dimensionality for specific prediction
            nn.ReLU(),
            nn.Linear(256, zt_dim * num_categories)
        )
        
        # Specific predictor for the reward rt
        self.reward_predictor = nn.Sequential(
            nn.Linear(512, 256),  # Same as above
            nn.ReLU(),
            nn.Linear(256, reward_size)
        )
        
        # Specific predictor for the continue signal ct
        self.continue_predictor = nn.Sequential(
            nn.Linear(512, 256),  # Same as above
            nn.ReLU(),
            nn.Linear(256, continue_size)
        )

    def forward(self, ht):
        # Shared feature extraction
        shared_features = self.shared_layers(ht)

        # Predict the latent space zt using shared features
        z_hat = self.dynamics_predictor(shared_features)
        # Predict the reward rt using shared features
        r_hat = self.reward_predictor(shared_features)
        # Predict the continue signal ct using shared features
        c_hat = self.continue_predictor(shared_features)

        return z_hat, r_hat, c_hat

class IntegratedModel(nn.Module):
    def __init__(self, input_size, gru_hidden_size, num_layers, zt_dim, num_categories, reward_size, continue_size):
        super(IntegratedModel, self).__init__()

        self.gru_hidden_size = gru_hidden_size
        self.num_layers = num_layers

        # GRU Layer
        self.gru = nn.GRU(input_size, gru_hidden_size, num_layers, batch_first=True)

        # Shared Layers
        self.shared_layers = nn.Sequential(
            nn.Linear(gru_hidden_size, 512),
            nn.ReLU()
        )

        # Specific predictor for the latent space zt
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, zt_dim * num_categories)
        )
        
        # Specific predictor for the reward rt
        self.reward_predictor = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, reward_size)
        )
        
        # Specific predictor for the continue signal ct
        self.continue_predictor = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, continue_size)
        )

    def forward(self, x, h0=None):
        if h0 is None:
            # Initialize hidden state with zeros if not provided
            h0 = torch.zeros(self.num_layers, x.size(0), self.gru_hidden_size).to(x.device)

        # Forward propagate GRU
        out, hidden = self.gru(x, h0)

        # Apply shared layers to the last hidden state
        shared_features = self.shared_layers(out[:, -1, :])
        
        zt_hat = self.dynamics_predictor(shared_features)
        rt_hat = self.reward_predictor(shared_features)
        ct_hat = self.continue_predictor(shared_features)
    
        return hidden, zt_hat, rt_hat, ct_hat

In [114]:
# Define the Actor Network
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        action_probs = torch.softmax(self.fc2(x), dim=-1)
        return action_probs

# Define the Critic Network
class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        value = self.fc2(x)
        return value

Training loop

In [115]:
device = "mps"

In [116]:
latent_size = 32 # aka zt size
num_categories = 4 # number of categories in latent space
autoencoder = CategoricalAutoencoder(4*4*4, latent_size, num_categories).to(device)
autoencoder_optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

ht_size = 32
reward_size = 1
continue_size = 1
integratedModel = IntegratedModel(latent_size * num_categories + n_actions, ht_size, 1, latent_size, num_categories, reward_size, continue_size).to(device)
integratedModel_optimizer = optim.Adam(integratedModel.parameters(), lr=0.001)

actor = Actor(ht_size, n_actions).to(device)
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
critic = Critic(ht_size).to(device)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.001)

In [130]:
# Training loop
num_episodes = 1000
gamma = 0.99
for episode in tqdm(range(num_episodes)):
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc]).to(device)
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)

    while not done:
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)

        
        hidden_state = hidden_state.to(device)
        action_probs = actor(hidden_state)[0] # remove batch dimension
        value = critic(hidden_state)

        action = torch.multinomial(action_probs, 1).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor, [next_state]).view(1, -1).to(device)

        # Store rewards for later use
        episode_rewards.append(reward)
        with torch.no_grad():
            # Encode the state
            encoded_state = autoencoder.encode(state_tensor)
            encoded_next_state = autoencoder.encode(next_state_tensor)

        # Adjusting shapes
        encoded_state_expanded = encoded_state.view(1, -1) 
        action_probs = action_probs.view(1, -1)
        
        # Concatenating tensors
        # Note: All tensors now have a batch dimension of 1
        # Integrated model loss
        integratedModel_input = torch.cat((encoded_state_expanded, action_probs.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        
        # Now, use the flattened input for your integrated model
        next_hidden_state, zt_hat, rt_hat, ct_hat = integratedModel(integratedModel_input, hidden_state)
        zt = encoded_next_state.view(1, -1)
        rt = torch.tensor(reward, device=device, dtype=torch.float).view(1, -1)
        ct = torch.tensor(1 - int(done), device=device, dtype=torch.float).view(1, -1)
        integratedModel_loss = (F.mse_loss(zt_hat, zt) 
                                + F.mse_loss(rt_hat, rt) 
                                + F.mse_loss(ct_hat, ct))

        # Critic loss
        target_value = reward + (gamma * critic(next_hidden_state.detach()) * (1 - int(done)))
        critic_loss = F.mse_loss(value, target_value.detach())

        # Actor loss
        advantage = target_value - value
        actor_loss = -torch.log(action_probs.squeeze(0)[action]) * advantage.detach()

        # Autoencoder loss
        autoencoder_state, _ = autoencoder(state_tensor.detach())
        autoencoder_loss = F.mse_loss(autoencoder_state, state_tensor)
        
    
        # Backpropagation
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()
        
        autoencoder_optimizer.zero_grad()
        autoencoder_loss.backward()
        autoencoder_optimizer.step()
        
        integratedModel_optimizer.zero_grad()
        integratedModel_loss.backward()
        integratedModel_optimizer.step()

        state = next_state
        hidden_state = next_hidden_state.detach()



# Close the environment
env.close()


100%|██████████| 1000/1000 [05:24<00:00,  3.08it/s]


In [141]:
# Test the model
num_episodes = 100
successes = 0
for _ in range(num_episodes):
    actor.eval()
    critic.eval()
    autoencoder.eval()
    integratedModel.eval()
    
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc])
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    
    autoencoder_losses = []
    
    while not done:
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)
        with torch.no_grad():
            # Encode the state
            encoded_state = autoencoder.encode(state_tensor)
        encoded_state_expanded = encoded_state.view(1, -1)
        
        action_probs = actor(hidden_state)[0]
        value = critic(hidden_state)

        action = torch.multinomial(action_probs, 1).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor, [next_state])

        # Store rewards for later use
        episode_rewards.append(reward)
        
        integratedModel_input = torch.cat((encoded_state_expanded, action_probs.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        hidden_state, _, _, _ = integratedModel(integratedModel_input, hidden_state)
        
        state = next_state
        
        # Test autoencoder loss
        autoencoder_state, _ = autoencoder(state_tensor.detach())
        autoencoder_loss = F.mse_loss(autoencoder_state, state_tensor)
        autoencoder_losses.append(autoencoder_loss.item())
    
    if reward == 1:
        successes += 1
    print(f"Average autoencoder loss: {np.mean(autoencoder_losses)}")
    print(f"Exemple state: {state_tensor}")
    print(f"Exemple autoencoder state: {autoencoder_state}")
# Close the environment
env.close()
print(f"Success rate: {successes/num_episodes}")
    

Average autoencoder loss: 0.04757815150215345
Exemple state: tensor([[0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0.,
         0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 1.]], device='mps:0')
Exemple autoencoder state: tensor([[0.0000, 0.0000, 1.0000, 0.0000, 0.1983, 0.0000, 0.8016, 0.0000, 0.4790,
         0.0000, 0.0000, 0.0000, 0.6825, 0.0000, 0.3177, 0.0000, 0.0000, 0.0000,
         0.2699, 0.0000, 0.0000, 0.9436, 0.0000, 0.0000, 0.9012, 0.0000, 0.0000,
         0.0000, 0.0000, 0.9711, 0.0000, 0.0000, 0.9601, 0.0000, 0.0398, 0.0000,
         0.9835, 0.0000, 0.0000, 0.0000, 0.9934, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 0.0000, 0.0000, 0.9930, 0.0000, 0.0000, 0.9989, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.0000]], device='mps:0', 

KeyboardInterrupt: 