In [34]:
import gymnasium as gym
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pytorchrl.data import PrioritizedReplayBuffer
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import optuna
import wandb
from tqdm import tqdm
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from collections import deque
from concurrent.futures import ThreadPoolExecutor

FrozenLake environment

In [35]:
# Environment setup
env = gym.make('FrozenLake-v1', is_slippery=False)
n_actions = env.action_space.n
n_states = env.observation_space.n

In [36]:
def convert_layout_to_tensor(map_layouts):
    nrows, ncols = len(map_layouts[0]), len(map_layouts[0][0])
    num_statuses = 4
    batch_size = len(map_layouts)

    # Initialize a tensor for the batch of layouts
    layout_tensor = torch.zeros((batch_size, nrows, ncols, num_statuses), device='cpu', dtype=torch.float)

    layout_to_val = {b'F': 0, b'H': 1, b'S': 0, b'G': 3,
                     'F': 0, 'H': 1, 'S': 0, 'G': 3}

    # Precompute all indices for the batch
    all_indices = [
        layout_to_val[item]
        for map_layout in map_layouts
        for row in map_layout
        for item in row
    ]
    indices_tensor = torch.tensor(all_indices, device='cpu').view(batch_size, nrows, ncols)

    # Update the tensor using advanced indexing
    layout_tensor.scatter_(3, indices_tensor.unsqueeze(3), 1)

    return layout_tensor


def update_start_positions(tensor_layout:Tensor, positions):
    nrows, ncols, _ = tensor_layout.size()[1:4]

    # Convert positions to a PyTorch tensor if it's not already one
    if not isinstance(positions, torch.Tensor):
        positions = torch.tensor(positions, dtype=torch.long, device=tensor_layout.device)

    # Calculate rows and columns for all positions
    rows = positions // ncols
    cols = positions % ncols

    # Reset the cells to [0, 0, 0, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols] = torch.tensor([0, 0, 0, 0], dtype=tensor_layout.dtype, device=tensor_layout.device)

    # Set the start cells to [0, 0, 1, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols, 2] = 1

    return tensor_layout

In [37]:
tensor = convert_layout_to_tensor(generate_random_map(4))
print(tensor)
updated_tensor = update_start_positions(tensor, [[0]])
print(updated_tensor) # dimensions: batch_size x nrows x ncols x num_statuses


tensor([[[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])
tensor([[[[0., 0., 1., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])


Implementation of DeamerV3 model
![DeamerV3 model](model_names.png)


In [38]:
# Sequence model that learns the dynamics of the latent space
# takes as input h t-1, z t-1, a t-1 and outputs ht
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # GRU Layer
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        # Forward propagate GRU
        out, _ = self.gru(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        return out


In [39]:
class WorldModel(nn.Module):
    def __init__(self, input_size, xt_size, gru_hidden_size, num_layers, zt_dim, num_categories, reward_size, continue_size):
        super(WorldModel, self).__init__()

        self.gru_hidden_size = gru_hidden_size
        self.num_layers = num_layers
        self.num_categories = num_categories
        
        # GRU Layer
        self.gru = nn.GRU(input_size, gru_hidden_size, num_layers, batch_first=True)
        

        # Dynamics predictor
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, zt_dim * num_categories)
        )
        
        # Reward predictor
        self.reward_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, reward_size)
        )
        
        # Continue signal predictor
        self.continue_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, continue_size)
        )
        
        # Decoder predictor
        self.decoder_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, xt_size)
        )
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(gru_hidden_size + xt_size, 512),
            nn.LayerNorm(512),
            nn.SiLU(),
            nn.Linear(512, zt_dim * num_categories)  # Output logits for each category in each latent dimension
        )


    def forward(self, ztat, xt, ht_1):
        if ht_1 is None:
            # Initialize hidden state with zeros if not provided
            ht_1 = torch.zeros(self.num_layers, ztat.size(0), self.gru_hidden_size).to(xt.device)

        # Forward propagate GRU
        _ , ht = self.gru(ztat, ht_1)
       
        zt = self.encode(xt, ht)
        
        zt_hat = self.dynamics_predictor(ht)

        htzt = torch.cat((ht.view(1, -1), zt.view(1, -1)), dim=1)
        rt_hat = self.reward_predictor(htzt)
        ct_hat = self.continue_predictor(htzt)
        xt_hat = self.decoder_predictor(htzt)
    
        return ht, zt_hat, rt_hat, ct_hat, xt_hat, zt
    
    def encode(self, xt, ht, temperature=1.0):
        # Encode input and apply Gumbel Softmax
        # ht has is 1,1,32 need to be 1,32
        ht = ht.view(-1, self.gru_hidden_size)
        encoded = self.encoder(torch.cat((xt, ht), dim=1))
        encoded = encoded.view(-1, self.num_categories)
        softmax = nn.functional.gumbel_softmax(encoded, tau=temperature, hard=True)
        return softmax
    
    

In [40]:
def symlog(x):
    """Symmetric log transformation"""
    return torch.sign(x) * torch.log(torch.abs(x) + 1)

def symexp(x):
    """Inverse of symmetric log transformation"""
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

class SymlogLoss(nn.Module):
    def __init__(self):
        super(SymlogLoss, self).__init__()

    def forward(self, input, target):
        # Apply symlog transformation to the target
        transformed_target = symlog(target)
        # Compute MSE Loss between input and transformed target
        return F.mse_loss(input, transformed_target)

In [41]:
# Define the Actor Network
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        action_probs = torch.softmax(self.fc2(x), dim=-1)
        return action_probs

# Define the Critic Network
class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        value = self.fc2(x)
        return value

Training loop

In [42]:
device = "mps"

In [43]:
latent_size = 32 # aka zt size
num_categories = 4 # number of categories in latent space

ht_size = 32
xt_size = 4*4*4
reward_size = 1
continue_size = 1
worldModel = WorldModel(latent_size * num_categories + n_actions, xt_size, ht_size, 1, latent_size, num_categories, reward_size, continue_size).to(device)
worldModel_optimizer = optim.Adam(worldModel.parameters(), lr=1e-4, eps=1e-8)

actor = Actor(ht_size + latent_size*num_categories, n_actions).to(device)
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-5, eps=1e-5)
critic = Critic(ht_size).to(device)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-5, eps=1e-5)

In [44]:
# Training loop
num_episodes = 1000
gamma = 0.997
Bpred = 1
Bdyn = 0.5
Brep = 0.1
loss_fn = SymlogLoss()
kl_loss = nn.KLDivLoss(reduction='batchmean')

logging = False

if logging:
    wandb.init(project="Dreamerv3 Reproduction",
               name="Zt in actor state",
               reinit=True,
               config={"num_episodes": num_episodes,
                       "gamma": gamma,
                       "latent_size": latent_size,
                       "num_categories": num_categories,
                       "ht_size": ht_size,
                       "reward_size": reward_size,
                       "continue_size": continue_size,
                       "worldModel_optimizer": worldModel_optimizer,
                       "actor_optimizer": actor_optimizer,
                       "critic_optimizer": critic_optimizer})


previous_episode_reward = 0
for episode in tqdm(range(num_episodes)):
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc]).to(device)
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    max_steps = 100

    while not done and max_steps > 0:
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)
        with torch.no_grad():
            # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)# could keep it from last output.
        encoded_state_flat = encoded_state.view(1, -1) 

        
        hidden_state = hidden_state.to(device)
        actor_state = torch.cat((hidden_state.view(1, -1), encoded_state_flat), dim=1)

        action_probs = actor(actor_state)[0] # remove batch dimension

        action = torch.multinomial(action_probs, 1).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor, [next_state]).view(1, -1).to(device)

        # Store rewards for later use
        episode_rewards.append(reward)


        action_probs = action_probs.view(1, -1)
        
        # Integrated model loss
        action_one_hot = F.one_hot(torch.tensor(action), n_actions).view(1, -1).to(device)

        integratedModel_input = torch.cat((encoded_state_flat, action_one_hot.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        
        # Now, use the flattened input for your integrated model
        next_hidden_state, zt_hat, rt_hat, ct_hat, xt_hat, zt = worldModel(integratedModel_input, state_tensor, hidden_state)
        xt = next_state_tensor.view(1, -1)
        rt = torch.tensor(reward, device=device, dtype=torch.float).view(1, -1)
        ct = torch.tensor(1 - int(done), device=device, dtype=torch.float).view(1, -1)
        prediction_loss = (loss_fn(xt_hat, xt) 
                                + loss_fn(rt_hat, rt) 
                                + loss_fn(ct_hat, ct))
        
        zt = zt.view(1, 1, -1)
        
        threshold = 1
        dynamic_loss = torch.max(torch.tensor(threshold), kl_loss(zt_hat, zt.detach()))
        representation_loss = torch.max(torch.tensor(threshold), kl_loss(zt, zt_hat.detach()))
        
        worldModel_loss = Bpred * prediction_loss + Bdyn * dynamic_loss + Brep * representation_loss
        
        worldModel_optimizer.zero_grad()
        worldModel_loss.backward()
        torch.nn.utils.clip_grad_norm_(worldModel.parameters(), 1000)
        worldModel_optimizer.step()
        
        #Trying actor critic training after world model has 100 episodes of training
        critic_grads = 0
        actor_grads = 0
        critic_loss = 0
        actor_loss = 0
        if episode > 100:
            # Critic loss
            imagination_horizon = 3 # TODO: 15 in paper
            return_lambda = 0.95
            lookahead_reward = reward + (gamma * (1 - return_lambda) * critic(next_hidden_state.detach()) * (1 - int(done)))
            with torch.no_grad():
                imagined_done = done
                imagined_ht = next_hidden_state
                imagined_zt = worldModel.encode(next_state_tensor, imagined_ht)
                imagined_xt = next_state_tensor
                for t in range(1, imagination_horizon+1):
                    if imagined_done:
                        break
                    
                    im_actor_state = torch.cat((imagined_ht.view(1, -1), imagined_zt.view(1,-1)), dim=1)
                    imagined_action_probs = actor(im_actor_state)[0]
                    imagined_action = torch.multinomial(action_probs, 1).item()
                    
                    im_action_one_hot = F.one_hot(torch.tensor(imagined_action), n_actions).view(1, -1).to(device)
    
                    imagined_integratedModel_input = torch.cat((imagined_zt.view(1, -1), im_action_one_hot.detach()), dim=1)
                    # add batch because input needs to have batch, seq, input_size
                    imagined_integratedModel_input = imagined_integratedModel_input.view(1, 1, -1)
                    
                    # Now, use the flattened input for your integrated model
                    im_next_hidden_state, im_zt_hat, im_rt_hat, im_ct_hat, im_xt_hat , _ = worldModel(imagined_integratedModel_input, imagined_xt, imagined_ht)
                    imagined_reward = (im_rt_hat + im_ct_hat*(critic(im_next_hidden_state.detach())))
                    lookahead_reward +=  (gamma ** t) * (return_lambda ** (t-1)) * imagined_reward
                    
                    imagined_done = imagined_done or im_ct_hat > 0.5
                    imagined_ht = im_next_hidden_state
                    imagined_zt = im_zt_hat
                    imagined_xt = im_xt_hat
                # one last critic estimation
                if not imagined_done:
                    lookahead_reward += (gamma ** imagination_horizon+1) * (return_lambda ** imagination_horizon) * critic(im_next_hidden_state.detach())
            
            #target_value = reward + (gamma * critic(next_hidden_state.detach()) * (1 - int(done)))
            target_value = lookahead_reward
            value = critic(hidden_state)
            critic_loss = F.mse_loss(value, target_value.detach())
    
            # Actor loss
            advantage = target_value - value
            actor_loss = -torch.log(action_probs.squeeze(0)[action] + 1e-8) * advantage.detach()
        
            # Backpropagation
            critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 100)
            critic_optimizer.step()
    
            actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 100)
            actor_optimizer.step()
            
            # get gradients norms
            critic_grads = sum([torch.norm(param.grad) for param in critic.parameters()])
            actor_grads = sum([torch.norm(param.grad) for param in actor.parameters()])
    

        
        if logging and episode%10 == 0:
            worldModel_grads = sum([torch.norm(param.grad) for param in worldModel.parameters()])
        
            # get weights norms
            critic_weights_norm = sum([torch.norm(param) for param in critic.parameters()])
            actor_weights_norm = sum([torch.norm(param) for param in actor.parameters()])
            worldModel_weights_norm = sum([torch.norm(param) for param in worldModel.parameters()])

            wandb.log({"critic_loss": critic_loss.item(),
                        "critic_grads": critic_grads,
                        "critic_weights_norm": critic_weights_norm,
                        "actor_loss": actor_loss.item(),
                        "actor_grads": actor_grads,
                        "actor_weights_norm": actor_weights_norm,
                        "worldModel_loss": worldModel_loss.item(),
                        "worldModel_grads": worldModel_grads,
                        "representation_loss": representation_loss.item(),
                        "dynamic_loss": dynamic_loss.item(),
                        "prediction_loss": prediction_loss.item(),
                        "worldModel_weights_norm": worldModel_weights_norm,
                       "previous_episode_reward": previous_episode_reward})
        
        state = next_state
        hidden_state = next_hidden_state.detach()
        max_steps -= 1
    
    previous_episode_reward = sum(episode_rewards)


# Close the environment
env.close()

if logging:
    wandb.finish()


  2%|▏         | 21/1000 [00:05<04:39,  3.50it/s]


KeyboardInterrupt: 

In [None]:
# Test the model
num_episodes = 100
successes = 0
for _ in range(num_episodes):
    actor.eval()
    critic.eval()
    worldModel.eval()
    
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc])
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    
    autoencoder_losses = []
    
    max_steps = 100
    while not done and max_steps > 0:
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)
        with torch.no_grad():
            # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)
        encoded_state_flat = encoded_state.view(1, -1)
        
        actor_state = torch.cat((hidden_state, encoded_state_flat), dim=1)
        action_probs = actor(actor_state)[0]
        value = critic(hidden_state)

        action = torch.multinomial(action_probs, 1).item()
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor, [next_state])

        # Store rewards for later use
        episode_rewards.append(reward)
        
        integratedModel_input = torch.cat((encoded_state_flat, action_probs.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        hidden_state, _, _, _, _, _ = worldModel(integratedModel_input, state_tensor  ,hidden_state)
        
        state = next_state
        max_steps -= 1

    
    if reward == 1:
        successes += 1

# Close the environment
env.close()
print(f"Success rate: {successes/num_episodes}")
    