In [468]:
import gymnasium as gym
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchrl.data import PrioritizedReplayBuffer, ListStorage
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import optuna
import wandb
from tqdm import tqdm
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from collections import deque
from concurrent.futures import ThreadPoolExecutor

FrozenLake environment

In [469]:
# Environment setup
env = gym.make('FrozenLake-v1', is_slippery=False)
n_actions = env.action_space.n
n_states = env.observation_space.n

In [470]:
def convert_layout_to_tensor(map_layouts):
    nrows, ncols = len(map_layouts[0]), len(map_layouts[0][0])
    num_statuses = 4
    batch_size = len(map_layouts)

    # Initialize a tensor for the batch of layouts
    layout_tensor = torch.zeros((batch_size, nrows, ncols, num_statuses), device='cpu', dtype=torch.float)

    layout_to_val = {b'F': 0, b'H': 1, b'S': 0, b'G': 3,
                     'F': 0, 'H': 1, 'S': 0, 'G': 3}

    # Precompute all indices for the batch
    all_indices = [
        layout_to_val[item]
        for map_layout in map_layouts
        for row in map_layout
        for item in row
    ]
    indices_tensor = torch.tensor(all_indices, device='cpu').view(batch_size, nrows, ncols)

    # Update the tensor using advanced indexing
    layout_tensor.scatter_(3, indices_tensor.unsqueeze(3), 1)

    return layout_tensor


def update_start_positions(tensor_layout:Tensor, positions):
    nrows, ncols, _ = tensor_layout.size()[1:4]

    # Convert positions to a PyTorch tensor if it's not already one
    if not isinstance(positions, torch.Tensor):
        positions = torch.tensor(positions, dtype=torch.long, device=tensor_layout.device)

    # Calculate rows and columns for all positions
    rows = positions // ncols
    cols = positions % ncols

    # Reset the cells to [0, 0, 0, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols] = torch.tensor([0, 0, 0, 0], dtype=tensor_layout.dtype, device=tensor_layout.device)

    # Set the start cells to [0, 0, 1, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols, 2] = 1

    return tensor_layout

In [471]:
tensor = convert_layout_to_tensor(generate_random_map(4))
print(tensor)
updated_tensor = update_start_positions(tensor, [[0]])
print(updated_tensor) # dimensions: batch_size x nrows x ncols x num_statuses


tensor([[[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])
tensor([[[[0., 0., 1., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])


Implementation of DeamerV3 model
![DeamerV3 model](model_names.png)


In [472]:
# Sequence model that learns the dynamics of the latent space
# takes as input h t-1, z t-1, a t-1 and outputs ht
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(GRUModel, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # GRU Layer
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        # Forward propagate GRU
        out, _ = self.gru(x, h0)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        return out


In [473]:
class WorldModel(nn.Module):
    def __init__(self, input_size, xt_size, gru_hidden_size, num_layers, zt_dim, num_categories, reward_size, continue_size):
        super(WorldModel, self).__init__()

        self.gru_hidden_size = gru_hidden_size
        self.num_layers = num_layers
        self.num_categories = num_categories
        
        # GRU Layer
        self.gru = nn.GRU(input_size, gru_hidden_size, num_layers, batch_first=True)
        

        # Dynamics predictor
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, zt_dim * num_categories)
        )
        
        # Reward predictor
        self.reward_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, reward_size)
        )
        
        # Continue signal predictor
        self.continue_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, continue_size)
        )
        
        # Decoder predictor
        self.decoder_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, xt_size)
        )
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(gru_hidden_size + xt_size, 512),
            nn.LayerNorm(512),
            nn.SiLU(),
            nn.Linear(512, zt_dim * num_categories)  # Output logits for each category in each latent dimension
        )


    def forward(self, ztat, xt, ht_1):
        batch_size = ztat.size(0)
        if ht_1 is None:
            # Initialize hidden state with zeros if not provided
            ht_1 = torch.zeros(self.num_layers, ztat.size(0), self.gru_hidden_size).to(xt.device)

        # Forward propagate GRU
        _ , ht = self.gru(ztat, ht_1)
       
        zt = self.encode(xt, ht)
        
        zt_hat = self.dynamics_predictor(ht)

        htzt = torch.cat((ht.view(batch_size, -1), zt.view(batch_size, -1)), dim=1)
        rt_hat = self.reward_predictor(htzt)
        ct_hat = self.continue_predictor(htzt)
        xt_hat = self.decoder_predictor(htzt)
    
        return ht, zt_hat, rt_hat, ct_hat, xt_hat, zt
    
    def encode(self, xt, ht, temperature=1.0):
        batch_size = xt.size(0)
        
        ht = ht.view(batch_size, -1)# it's not batch first
        encoded = self.encoder(torch.cat((xt, ht), dim=1))
        softmax = nn.functional.gumbel_softmax(encoded, tau=temperature, hard=True)
        return softmax
    
    

In [474]:
def symlog(x):
    """Symmetric log transformation"""
    return torch.sign(x) * torch.log(torch.abs(x) + 1)

def symexp(x):
    """Inverse of symmetric log transformation"""
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

class SymlogLoss(nn.Module):
    def __init__(self):
        super(SymlogLoss, self).__init__()

    def forward(self, input, target):
        # Apply symlog transformation to the target

        transformed_target = symlog(target)
        # Compute MSE Loss between input and transformed target
        return F.mse_loss(input, transformed_target)

In [475]:
# Define the Actor Network
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        action_probs = torch.softmax(self.fc2(x), dim=-1)
        return action_probs

# Define the Critic Network
class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        value = self.fc2(x)
        return value

Training loop

In [476]:
device = "mps"

In [477]:
latent_size = 32 # aka zt size
num_categories = 4 # number of categories in latent space

ht_size = 32
xt_size = 4*4*4
reward_size = 1
continue_size = 1
worldModel = WorldModel(latent_size * num_categories + n_actions, xt_size, ht_size, 1, latent_size, num_categories, reward_size, continue_size).to(device)
worldModel_optimizer = optim.Adam(worldModel.parameters(), lr=1e-4, eps=1e-8)

actor = Actor(ht_size + latent_size*num_categories, n_actions).to(device)
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-5, eps=1e-5)
critic = Critic(ht_size).to(device)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-5, eps=1e-5)

In [478]:
def update_world_model(worldModel, worldModel_optimizer, replay_buffer, batch_size=32, steps=3, logging=False):
    loss_fn = SymlogLoss()
    experiences, info = replay_buffer.sample(batch_size, return_info=True)
    # Unpack experiences

    states = experiences[0]

    
    # get envs from states
    envs = []
    for state in states:
        env = gym.make('FrozenLake-v1', is_slippery=False)
        env.reset()
        env.state = state
        envs.append(env)
    
    # hidden states are not batch first
    hidden_states = torch.zeros((1, batch_size, ht_size), device=device, dtype=torch.float)
    layout_tensor = convert_layout_to_tensor([env.desc for env in envs]).to(device)
    
    for step in range(steps):
        # step through envs
        #TODO: maybe not random steps
        actions = torch.tensor([env.action_space.sample() for env in envs], device=device)
        next_states = []
        rewards = []
        dones = []
        for env, action in zip(envs, actions):
            #TODO: if done should reset
            next_state, reward, done, _, _ = env.step(action.item())
            rewards.append(reward)
            dones.append(done)
            next_states.append(next_state)
        
        dones = torch.tensor(dones, device=device, dtype=torch.float)
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        
        next_states_tensor = update_start_positions(layout_tensor, next_states).view(batch_size, -1).to(device)    
        with torch.no_grad():
            # Encode the state
            states_tensor = update_start_positions(layout_tensor, states)
            states_tensor = states_tensor.view(batch_size, -1).to(device)
            

            zts = worldModel.encode(states_tensor, hidden_states)
    
        # Integrated model loss
        action_one_hot = F.one_hot(actions.detach(), n_actions).view(batch_size, -1).to(device)
    

        integratedModel_input = torch.cat((zts, action_one_hot.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(batch_size, 1, -1)
        
            
        ht, zt_hat, rt_hat, ct_hat, xt_hat, zt = worldModel(integratedModel_input, states_tensor, hidden_states)
        xt = next_states_tensor
        rt = rewards.view(batch_size, 1)
        ct = (1 - dones).view(batch_size, 1)
        prediction_loss = (loss_fn(xt_hat, xt) 
                                + loss_fn(rt_hat, rt) 
                                + loss_fn(ct_hat, ct))
        
        zt = zt.view(batch_size, 1, -1)
        
        threshold = 1
        dynamic_loss = torch.max(torch.tensor(threshold), kl_loss(zt_hat, zt.detach()))
        representation_loss = torch.max(torch.tensor(threshold), kl_loss(zt, zt_hat.detach()))
        
        worldModel_loss = Bpred * prediction_loss + Bdyn * dynamic_loss + Brep * representation_loss
        
        worldModel_optimizer.zero_grad()
        worldModel_loss.backward()
        torch.nn.utils.clip_grad_norm_(worldModel.parameters(), 1000)
        worldModel_optimizer.step()
        
        states = next_states
        hidden_states = ht.detach()
        
        if logging:
            grad_norm = sum([torch.norm(param.grad) for param in worldModel.parameters()])
            wandb.log({"worldModel_loss": worldModel_loss.item(),
                       "worldModel_grads": grad_norm,
                       "representation_loss": representation_loss.item(),
                       "dynamic_loss": dynamic_loss.item(),
                       "prediction_loss": prediction_loss.item()})


In [479]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x2ca809810>

In [480]:
def update_actor_critic(actor, critic, actor_optimizer, critic_optimizer, worldModel, actor_critics_replay_buffer, batch_size=32, logging=False):
    #sample from buffer
    experiences, info = actor_critics_replay_buffer.sample(batch_size, return_info=True)
    # Unpack experiences
    ht, zt, state_tensors, rewards, dones = experiences # actor states is zt and ht
    

    dones = dones.float().to(device).view(batch_size, 1).detach()
    rewards = rewards.float().to(device).view(batch_size, 1).detach()
    
    ht = ht.view(batch_size, 1, -1).to(device)

    
    # Critic loss
    imagination_horizon = 3 # TODO: 15 in paper
    gamma = 0.997
    return_lambda = 0.95
    critic_estimation = critic(ht.detach()).view(batch_size, 1)
    value = critic_estimation.clone()
    lookahead_reward = rewards + (gamma * (1 - return_lambda) * critic_estimation * (1 - dones))
    with torch.no_grad():
        imagined_dones = dones.clone()
        im_actor_states = torch.cat((ht.view(batch_size, -1), zt.view(batch_size,-1)), dim=1)
        imagined_xt = state_tensors.view(batch_size, -1)
        imagined_ht = ht.view(1, batch_size, -1) # not batch first
        imagined_zt = zt
        for t in range(1, imagination_horizon+1):
            if imagined_dones.all():
                break
            
            imagined_action_probs = actor(im_actor_states)
            imagined_action = torch.multinomial(imagined_action_probs, num_samples=1)
            
            
            im_action_one_hot = F.one_hot(imagined_action, n_actions).view(batch_size, -1).to(device)

            imagined_integratedModel_input = torch.cat((imagined_zt.view(batch_size, -1), im_action_one_hot.detach()), dim=1)
            # add batch because input needs to have batch, seq, input_size
            imagined_integratedModel_input = imagined_integratedModel_input.view(batch_size, 1, -1)
            
            # Now, use the flattened input for your integrated model
            im_next_hidden_state, im_zt_hat, im_rt_hat, im_ct_hat, im_xt_hat , _ = worldModel(imagined_integratedModel_input, imagined_xt, imagined_ht)

            im_next_hidden_state = im_next_hidden_state.view(batch_size, -1)
            
            imagined_dones = torch.logical_or(imagined_dones, im_ct_hat > 0.5)
            critic_estimation = critic(im_next_hidden_state.detach())
            imagined_reward = im_rt_hat + imagined_dones * critic_estimation


            lookahead_reward +=  (gamma ** t) * (return_lambda ** (t-1)) * imagined_reward
            
            
            imagined_ht = im_next_hidden_state.view(1, batch_size, -1)
            imagined_zt = im_zt_hat
            imagined_xt = im_xt_hat
            im_actor_states = torch.cat((imagined_ht.view(batch_size, -1), imagined_zt.view(batch_size,-1)), dim=1)

        # one last critic estimation
        if not imagined_dones.all():
            lookahead_reward += imagined_dones * (gamma ** imagination_horizon+1) * (return_lambda ** imagination_horizon) * critic(im_next_hidden_state.detach())
    
    target_value = lookahead_reward
    critic_loss = F.mse_loss(value, target_value.detach())

    # Actor loss
    advantage = target_value - value.clone().detach()
    # Calculate Actor Loss
    # Assuming actions are in the form of indices of the chosen actions
    action_probs = actor( torch.cat((ht.view(batch_size, -1), zt.view(batch_size,-1)), dim=1).detach() )  # Probabilities of actions from the actor network
    # take actions from probs
    actions = torch.multinomial(action_probs, num_samples=1)
    
    # Gather the probabilities of the actions taken
    gathered_probs = action_probs.gather(1, actions.view(-1, 1))
    
    # Add a small number to probabilities to avoid log(0)
    actor_loss = -torch.log(gathered_probs + 1e-8) * advantage.detach()
    actor_loss = actor_loss.mean()

    # Backpropagation
    critic_optimizer.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 100)
    critic_optimizer.step()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), 100)
    actor_optimizer.step()
    
    if logging:
        actor_grads = sum([torch.norm(param.grad) for param in actor.parameters()])
        critic_grads = sum([torch.norm(param.grad) for param in critic.parameters()])
        actor_weights_norm = sum([torch.norm(param) for param in actor.parameters()])
        critic_weights_norm = sum([torch.norm(param) for param in critic.parameters()])
        wandb.log({"critic_loss": critic_loss.item(),
                    "critic_grads": critic_grads,
                    "critic_weights_norm": critic_weights_norm,
                    "actor_loss": actor_loss.item(),
                    "actor_grads": actor_grads,
                    "actor_weights_norm": actor_weights_norm})

In [481]:
def collect_experience(encoded_state_flat, hidden_state, env):
    """
    Handles interaction with the environment and collects experience.
    """

    hidden_state = hidden_state.to(device)
    
    actor_state = torch.cat((hidden_state.view(1, -1), encoded_state_flat), dim=1)
    
    action_probs = actor(actor_state)[0]  # remove batch dimension
    action = torch.multinomial(action_probs, 1).item()
    
    next_state, reward, done, _, _ = env.step(action)
    
    return state_tensor, action, reward, next_state, done


In [482]:
# Training loop
batch_size = 16
training_delta = 10 # number of episodes to train world model before training actor critic
num_episodes = batch_size + training_delta + 30
gamma = 0.997
Bpred = 1
Bdyn = 0.5
Brep = 0.1
loss_fn = SymlogLoss()
kl_loss = nn.KLDivLoss(reduction='batchmean')

logging = False

if logging:
    wandb.init(project="Dreamerv3 Reproduction",
               name="Trying per + batches",
               reinit=True,
               config={"num_episodes": num_episodes,
                       "gamma": gamma,
                       "latent_size": latent_size,
                       "num_categories": num_categories,
                       "ht_size": ht_size,
                       "reward_size": reward_size,
                       "continue_size": continue_size,
                       "worldModel_optimizer": worldModel_optimizer,
                       "actor_optimizer": actor_optimizer,
                       "critic_optimizer": critic_optimizer})


previous_episode_reward = 0
# Parameters
buffer_size = 10000  # Adjust as needed
alpha = 0.6  # Adjust as needed for prioritization (0 for uniform, 1 for fully prioritized)

# Initialize buffer
world_model_replay_buffer = PrioritizedReplayBuffer(alpha=alpha, storage=ListStorage(buffer_size), beta=0.4)
actor_critics_replay_buffer = PrioritizedReplayBuffer(alpha=alpha, storage=ListStorage(buffer_size), beta=0.4)

for episode in tqdm(range(num_episodes)):
    # Reset environment and episode-specific variables
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc]).to(device)
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    max_steps = n_states * n_actions
    
    while not done and max_steps > 0:
        # --- Collect Experience ---
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)
    
        with torch.no_grad():
        # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)

        state_tensor, action, reward, next_state, done = collect_experience(encoded_state, hidden_state, env)
        episode_rewards.append(reward)
        
        # Add experience to world model buffer
        world_model_replay_buffer.add((state, action, reward, next_state, done))
        
        # --- World Model Update ---
        if len(world_model_replay_buffer) > batch_size:
            update_world_model(worldModel, worldModel_optimizer, world_model_replay_buffer, batch_size=batch_size, logging=logging)
        
        # --- Actor-Critic Update ---
        if episode > batch_size + training_delta:
            experience_actor_critic = (hidden_state, encoded_state, state_tensor, reward, done)
            actor_critics_replay_buffer.add(experience_actor_critic)
            if len(world_model_replay_buffer) > batch_size:
                update_actor_critic(actor, critic, actor_optimizer, critic_optimizer, worldModel, actor_critics_replay_buffer, batch_size=batch_size, logging=logging)
        
        # --- Logging and State Updates ---
        if logging:
            wandb.log({"previous_episode_reward": previous_episode_reward})
        
        state = next_state
        integratedModel_input = torch.cat((encoded_state, F.one_hot(torch.tensor(action), n_actions).view(1, -1).to(device)), dim=1)
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        hidden_state, _, _, _, _, _ = worldModel(integratedModel_input, state_tensor, hidden_state)
        max_steps -= 1
    
    previous_episode_reward = sum(episode_rewards)


# Close the environment
env.close()

if logging:
    wandb.finish()


100%|██████████| 56/56 [02:24<00:00,  2.59s/it]


In [483]:
# Test the model
num_episodes = 100
successes = 0
for _ in range(num_episodes):
    actor.eval()
    critic.eval()
    worldModel.eval()
    
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc])
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    
    autoencoder_losses = []
    
    max_steps = 100
    while not done and max_steps > 0:
        state_tensor = update_start_positions(layout_tensor, [state])
        state_tensor = state_tensor.view(1, -1).to(device)
        with torch.no_grad():
            # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)
        encoded_state_flat = encoded_state.view(1, -1)
        
        actor_state = torch.cat((hidden_state.view(1, -1), encoded_state_flat), dim=1)
        action_probs = actor(actor_state)[0]
        value = critic(hidden_state)

        action = torch.multinomial(action_probs, 1).item()
        action_one_hot = F.one_hot(torch.tensor(action), n_actions).view(1, -1).to(device)
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor, [next_state])

        # Store rewards for later use
        episode_rewards.append(reward)
        
        integratedModel_input = torch.cat((encoded_state_flat, action_one_hot.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        hidden_state, _, _, _, _, _ = worldModel(integratedModel_input, state_tensor  ,hidden_state)
        
        state = next_state
        max_steps -= 1

    
    if reward == 1:
        successes += 1

# Close the environment
env.close()
print(f"Success rate: {successes/num_episodes}")
    

Success rate: 0.0
