In [196]:
import gymnasium as gym
import torch
from torch import Tensor
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchrl.data import PrioritizedReplayBuffer, ListStorage
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import optuna
import wandb
from tqdm import tqdm
from gymnasium.envs.toy_text.frozen_lake import generate_random_map
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from collections import deque
from concurrent.futures import ThreadPoolExecutor

FrozenLake environment

In [197]:
debug = False

In [198]:
# Environment setup
env = gym.make('FrozenLake-v1', is_slippery=False)
n_actions = env.action_space.n
n_states = env.observation_space.n

In [199]:
def convert_layout_to_tensor(map_layouts):
    nrows, ncols = len(map_layouts[0]), len(map_layouts[0][0])
    num_statuses = 4
    batch_size = len(map_layouts)

    # Initialize a tensor for the batch of layouts
    layout_tensor = torch.zeros((batch_size, nrows, ncols, num_statuses), device='cpu', dtype=torch.float)

    layout_to_val = {b'F': 0, b'H': 1, b'S': 0, b'G': 3,
                     'F': 0, 'H': 1, 'S': 0, 'G': 3}

    # Precompute all indices for the batch
    all_indices = [
        layout_to_val[item]
        for map_layout in map_layouts
        for row in map_layout
        for item in row
    ]
    indices_tensor = torch.tensor(all_indices, device='cpu').view(batch_size, nrows, ncols)

    # Update the tensor using advanced indexing
    layout_tensor.scatter_(3, indices_tensor.unsqueeze(3), 1)

    return layout_tensor



def update_start_positions(tensor_layout:Tensor, positions):
    nrows, ncols, _ = tensor_layout.size()[1:4]

    # Convert positions to a PyTorch tensor if it's not already one
    if not isinstance(positions, torch.Tensor):
        positions = torch.tensor(positions, dtype=torch.long, device=tensor_layout.device)

    # Calculate rows and columns for all positions
    rows = positions // ncols
    cols = positions % ncols

    # Reset the cells to [0, 0, 0, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols] = torch.tensor([0, 0, 0, 0], dtype=tensor_layout.dtype, device=tensor_layout.device)

    # Set the start cells to [0, 0, 1, 0]
    tensor_layout[torch.arange(positions.size(0)), rows, cols, 2] = 1

    return tensor_layout


def convert_tensor_to_layout(tensor_layout):
    # Define the mapping from tensor indices back to layout characters
    index_to_layout = {0: b'F', 1: b'H', 2: b'S', 3: b'G'}
    
    # Get the size of the tensor
    batch_size, nrows, ncols, _ = tensor_layout.size()
    
    # Initialize the layout list
    layouts = []
    
    # Iterate over each layout in the batch
    for b in range(batch_size):
        # Initialize the layout for the current batch
        layout = []
        
        # Iterate over each row
        for i in range(nrows):
            # Initialize the row
            row = []
            
            # Iterate over each column
            for j in range(ncols):
                # Find the index of the maximum value in the last dimension (status)
                status_index = tensor_layout[b, i, j].argmax().item()
                
                # Convert the index to a layout character and append to the row
                row.append(index_to_layout[status_index])
            
            # Append the row (converted to a numpy array) to the layout
            layout.append(row)
        
        # Append the layout (converted to a numpy array) to the layouts list
        layouts.append(layout)
    
    # Convert the layouts list to a numpy array and return
    return layouts


In [200]:
tensor = convert_layout_to_tensor(generate_random_map(4))
print(tensor)
updated_tensor = update_start_positions(tensor, [[0]])
print(updated_tensor) # dimensions: batch_size x nrows x ncols x num_statuses


tensor([[[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[0., 1., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])
tensor([[[[0., 0., 1., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[0., 1., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]]],


        [[[1., 0., 0., 0.]],

         [[0., 1., 0., 0.]],

         [[1., 0., 0., 0.]],

         [[0., 0., 0., 1.]]]])


Implementation of DeamerV3 model
![DeamerV3 model](model_names.png)


In [201]:
class WorldModel(nn.Module):
    def __init__(self, input_size, xt_size, gru_hidden_size, num_layers, zt_dim, num_categories, reward_size, continue_size):
        super(WorldModel, self).__init__()

        self.gru_hidden_size = gru_hidden_size
        self.num_layers = num_layers
        self.num_categories = num_categories
        
        # GRU Layer
        self.gru = nn.GRU(input_size, gru_hidden_size, num_layers, batch_first=True, bias=False)
        # Input normalization layer for GRU
        self.input_norm = nn.LayerNorm(input_size)
        #self.input_norm = nn.BatchNorm1d(input_size)
        #self.input_norm = nn.Identity()
        # Hidden state normalization layer for GRU
        self.hidden_norm = nn.LayerNorm(gru_hidden_size)

        # Dynamics predictor
        self.dynamics_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, zt_dim * num_categories)
        )
        
        # Reward predictor
        self.reward_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, reward_size)
        )
        
        # Continue signal predictor
        self.continue_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, continue_size)
        )
        
        # Decoder predictor
        self.decoder_predictor = nn.Sequential(
            nn.Linear(gru_hidden_size + zt_dim * num_categories, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Linear(256, xt_size)
        )
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(gru_hidden_size + xt_size, 512),
            nn.LayerNorm(512),
            nn.SiLU(),
            nn.Linear(512, zt_dim * num_categories)  # Output logits for each category in each latent dimension
        )


    def forward(self, ztat, xt, ht_1):
        batch_size = ztat.size(0)
        if ht_1 is None:
            # Initialize hidden state with zeros if not provided
            ht_1 = torch.zeros(self.num_layers, ztat.size(0), self.gru_hidden_size).to(xt.device)

        # Normalize GRU input
        normalized_ztat = self.input_norm(ztat)
        
        # Forward propagate GRU
        _, ht = self.gru(normalized_ztat, ht_1)
        # Normalize GRU output (hidden state)
        ht = self.hidden_norm(ht)
       
        zt = self.encode(xt, ht)
        
        zt_hat = self.dynamics_predictor(ht)
        # clamp logits to avoid numerical instability
        zt_hat = torch.clamp(zt_hat, -10, 10)

        htzt = torch.cat((ht.view(batch_size, -1), zt.view(batch_size, -1)), dim=1)
        rt_hat = self.reward_predictor(htzt)
        ct_hat = self.continue_predictor(htzt)
        xt_hat = self.decoder_predictor(htzt)
    
        return ht, zt_hat, rt_hat, ct_hat, xt_hat, zt
    
    def encode(self, xt, ht, temperature=1.0):
        batch_size = xt.size(0)
        ht = ht.view(batch_size, -1)  # it's not batch first
        logits = self.encoder(torch.cat((xt, ht), dim=1))
        # clamp logits to avoid numerical instability
        logits = torch.clamp(logits, -10, 10)
        
        return logits


    
    

In [202]:
def symlog(x):
    """Symmetric log transformation"""
    return torch.sign(x) * torch.log(torch.abs(x) + 1)

def symexp(x):
    """Inverse of symmetric log transformation"""
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

class SymlogLoss(nn.Module):
    def __init__(self):
        super(SymlogLoss, self).__init__()

    def forward(self, input, target):
        # Apply symlog transformation to the target

        transformed_target = symlog(target)
        # Compute MSE Loss between input and transformed target
        return F.mse_loss(input, transformed_target)

In [203]:
# Define the Actor Network
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        action_probs = torch.softmax(self.fc2(x), dim=-1)
        return action_probs

# Define the Critic Network
class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.ln1 = nn.LayerNorm(128)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(128, 1)

    def forward(self, state):
        x = self.ln1(self.fc1(state))
        x = self.silu(x)
        value = self.fc2(x)
        return value

Training loop

In [204]:
device = "mps"

In [205]:
latent_size = 128 # aka zt size
num_categories = 4 # number of categories in latent space
num_layers = 1

ht_size = 32
xt_size = 4*4*4
reward_size = 1
continue_size = 1
worldModel = WorldModel(latent_size * num_categories + n_actions, xt_size, ht_size, num_layers, latent_size, num_categories, reward_size, continue_size).to(device)
worldModel_optimizer = optim.Adam(worldModel.parameters(), lr=1e-4, eps=1e-8, weight_decay=1e-5)

actor = Actor(ht_size + latent_size*num_categories, n_actions).to(device)
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-5, eps=1e-5, weight_decay=1e-5)
critic = Critic(ht_size + latent_size*num_categories ).to(device)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-5, eps=1e-5, weight_decay=1e-5)

In [206]:
def update_world_model(worldModel, worldModel_optimizer, replay_buffer, actor, batch_size=32, steps=15, logging=False):
    actor.eval()
    experiences, info = replay_buffer.sample(batch_size, return_info=True)
    # Unpack experiences
    
    experience_losses = torch.zeros((batch_size, 1), device=device)
    
    states = experiences[0]
    states = states.int()
    
    # get envs from states
    envs = []
    for state in states:
        env = gym.make('FrozenLake-v1', is_slippery=False)
        env.reset()
        env.state = state
        envs.append(env)
    
    first_prediction_loss = None
    # hidden states are not batch first
    hidden_states = torch.zeros((1, batch_size, ht_size), device=device, dtype=torch.float)
    layout_tensor = convert_layout_to_tensor([env.desc for env in envs]).to(device)


    dones_bool = [False] * batch_size
    for step in range(steps):
        # reset done envs
        for i, done in enumerate(dones_bool):
            if done:
                envs[i].reset()
                states[i] = envs[i].state
                hidden_states[0][i] = torch.zeros((1, ht_size), device=device, dtype=torch.float)
                dones_bool[i] = False
        
        with torch.no_grad():
            # Encode the state
            states_tensor = update_start_positions(layout_tensor.clone(), states)
            states_tensor = states_tensor.view(batch_size, -1).to(device)
            zts = worldModel.encode(states_tensor, hidden_states)
            actor_state = torch.cat((hidden_states.view(batch_size, -1), zts.view(batch_size,-1)), dim=1)
            action_probs = actor(actor_state)
            actions = torch.multinomial(action_probs, num_samples=1)

        next_states = []
        rewards = []
        dones = []
        for env, action in zip(envs, actions):
            #TODO: if done should reset
            next_state, reward, done, _, _ = env.step(action.item())
            # try non sparse reward:
            reward = reward + (next_state%4 + next_state//4)/(16 * 3) 
            rewards.append(reward)
            dones.append(done)
            next_states.append(next_state)
        
        dones_bool = torch.tensor(dones, device=device, dtype=torch.bool)
        dones = torch.tensor(dones, device=device, dtype=torch.float)
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        # print number of rewards equal to 1
        #count = (rewards >= 1).sum()
        #if count >0:
        #    print(f"Number of rewards equal to 1: {count}, in step {step}")
         
        with torch.no_grad():
            # Encode the state
            next_states_tensor = update_start_positions(layout_tensor.clone(), next_states).view(batch_size, -1).to(device)   

            zts = worldModel.encode(states_tensor, hidden_states)
    
        # Integrated model loss
        action_one_hot = F.one_hot(actions.detach(), n_actions).view(batch_size, -1).to(device)
    

        integratedModel_input = torch.cat((zts, action_one_hot.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(batch_size, 1, -1)
        
   
        ht, zt_hat, rt_hat, ct_hat, xt_hat, zt = worldModel(integratedModel_input, states_tensor, hidden_states)
        xt = next_states_tensor
        rt = rewards.view(batch_size, 1)
        ct = (1 - dones).view(batch_size, 1)
        #loss_fn = SymlogLoss()
        # Define the loss functions with reduction set to 'none' to keep individual losses
        xt_loss_fn = nn.MSELoss(reduction='none')
        rt_loss_fn = nn.MSELoss(reduction='none')
        ct_loss_fn = nn.MSELoss(reduction='none')
        
        # Compute the losses for each element in the batch (without reducing them)
        xt_loss = xt_loss_fn(xt_hat, xt)
        rt_loss = rt_loss_fn(rt_hat, rt)
        ct_loss = ct_loss_fn(ct_hat, ct)
        
        # You can now use xt_loss, rt_loss, and ct_loss to update buffer weights
        # Each *_loss tensor contains the individual losses for each example in the batch
        
        # If you want a single loss value for backward pass, you can reduce the losses manually
        prediction_loss = (xt_loss.mean() + rt_loss.mean() + ct_loss.mean())
        
        zt = zt.view(batch_size, -1)
        zt_hat = zt_hat.view(batch_size, -1)
        threshold = 0.02 # TODO: not following paper here, should be 1 threshold and kl loss
        #kl_loss = nn.KLDivLoss(reduction='batchmean')

        #dynamic_loss = torch.max(torch.tensor(threshold), kl_loss(F.log_softmax(zt_hat, dim=1), zt.detach()))
        #representation_loss = torch.max(torch.tensor(threshold), kl_loss(F.log_softmax(zt, dim=1) , zt_hat.detach()))
        # try MSE
        dynamic_loss = torch.max(torch.tensor(threshold), F.mse_loss(zt_hat, zt.detach()))
        representation_loss = torch.max(torch.tensor(threshold), F.mse_loss(zt, zt_hat.detach()))

        
        worldModel_loss = Bpred * prediction_loss + Bdyn * dynamic_loss + Brep * representation_loss
        
        if debug:
            if torch.isnan(prediction_loss).any(): print("Prediction losses: ", prediction_loss)
            if torch.isnan(dynamic_loss).any(): print("Dynamic losses: ", dynamic_loss)
            if torch.isnan(representation_loss).any(): print("Representation losses: ", representation_loss)
            if torch.isnan(worldModel_loss).any(): print("World model losses: ", worldModel_loss)
            if torch.isnan(xt).any(): print("Xt: ", xt)
            if torch.isnan(xt_hat).any(): print("Xt_hat: ", xt_hat)
            if torch.isnan(rt).any(): print("Rt: ", rt)
            if torch.isnan(rt_hat).any(): print("Rt_hat: ", rt_hat)
            if torch.isnan(ct).any(): print("Ct: ", ct)
            if torch.isnan(ct_hat).any(): print("Ct_hat: ", ct_hat)
            if torch.isnan(zt).any(): print("Zt: ", zt)
            if torch.isnan(zt_hat).any(): print("Zt_hat: ", zt_hat)
            if torch.isnan(ht).any(): print("Ht: ", ht)
            if torch.isnan(hidden_states).any(): print("Ht-1: ", hidden_states)
            if torch.isnan(actions).any(): print("Action: ", actions)
            if torch.isnan(integratedModel_input).any(): print("Model input: ", integratedModel_input)
            if torch.isnan(zts).any(): print("zts: ", zts)
            if torch.isnan(states_tensor).any(): print("States tensor: ", states_tensor)

        
        worldModel_optimizer.zero_grad()
        worldModel_loss.backward()
        torch.nn.utils.clip_grad_norm_(worldModel.parameters(), 10)#TODO: change back to paper values
        worldModel_optimizer.step()
        
        states = next_states
        hidden_states = ht.detach()
        # reset done envs
        for i, done in enumerate(dones):
            if done:
                envs[i].reset()
                states[i] = envs[i].state
                hidden_states[0][i] = torch.zeros((1, ht_size), device=device, dtype=torch.float)
                dones[i] = False


        experience_losses += rt_loss.detach() + ct_loss.detach()

        
        if logging:
            grad_norm = sum([torch.norm(param.grad) for param in worldModel.parameters()])
            weight_norm = sum([torch.norm(param) for param in worldModel.parameters()])
            wandb.log({"worldModel_loss": worldModel_loss.item(),
                       "worldModel_grads": grad_norm,
                       "representation_loss": representation_loss.item(),
                       "dynamic_loss": dynamic_loss.item(),
                       "prediction_loss": prediction_loss.item(),
                       "WorldModel_weight_norm": weight_norm})
            
    experience_losses /= steps#TODO: should be individual steps?
    #print(first_prediction_loss)
    actor.train()
    # Update priorities in the replay buffer
    # Ensure you have the indices of the sampled experiences. This is typically returned by the sample method of the replay buffer.
    indices = info['index']
    #print(f"Indices: {indices}")
    #print("Experience losses: ", experience_losses)
    replay_buffer.update_priority(indices, experience_losses.cpu().numpy())



In [207]:
def update_actor_critic(actor, critic, actor_optimizer, critic_optimizer, worldModel, actor_critics_replay_buffer, batch_size=32, logging=False):
    worldModel.eval()
    
    #sample from buffer
    experiences, info = actor_critics_replay_buffer.sample(batch_size, return_info=True)
    # Unpack experiences
    ht, zt, state_tensors, rewards, dones = experiences # actor states is zt and ht
    

    dones = dones.float().to(device).view(batch_size, 1).detach()
    rewards = rewards.float().to(device).view(batch_size, 1).detach()
    
    ht = ht.view(batch_size, 1, -1).to(device)

    
    # Critic loss
    imagination_horizon = 4 # TODO: 15 in paper
    gamma = 0.997
    return_lambda = 0.95
    original_input_state = torch.cat((ht.view(batch_size, -1), zt.view(batch_size,-1)), dim=1)
    critic_estimation = critic(original_input_state.detach()).view(batch_size, 1)
    value = critic_estimation.clone()
    lookahead_reward = rewards + (gamma * (1 - return_lambda) * critic_estimation * (1 - dones))
    with torch.no_grad():
        imagined_dones = dones.clone()
        im_actor_states = original_input_state
        imagined_xt = state_tensors.view(batch_size, -1)
        imagined_ht = ht.view(1, batch_size, -1) # not batch first
        imagined_zt = zt
        for t in range(0, imagination_horizon):
            if imagined_dones.all():
                break
            
            imagined_action_probs = actor(im_actor_states)
            imagined_action = torch.multinomial(imagined_action_probs, num_samples=1)
            
            
            im_action_one_hot = F.one_hot(imagined_action, n_actions).view(batch_size, -1).to(device)

            imagined_integratedModel_input = torch.cat((imagined_zt.view(batch_size, -1), im_action_one_hot.detach()), dim=1)
            # add batch because input needs to have batch, seq, input_size
            imagined_integratedModel_input = imagined_integratedModel_input.view(batch_size, 1, -1)
            
            # Now, use the flattened input for your integrated model
            im_next_hidden_state, im_zt_hat, im_rt_hat, im_ct_hat, im_xt_hat , _ = worldModel(imagined_integratedModel_input, imagined_xt, imagined_ht)

            im_next_hidden_state = im_next_hidden_state.view(batch_size, -1)
            
            # Create a mask where 0 indicates the state is done, and 1 indicates it's not done
            not_done_mask = 1 - imagined_dones
            
            # zero rewards if previously done
            # Element-wise multiplication between the reward tensor and the not-done mask
            im_rt_hat = im_rt_hat * not_done_mask

            
            imagined_dones = torch.logical_or(imagined_dones, im_ct_hat < 0.5).float()
            im_actor_states_hat = torch.cat((imagined_ht.view(batch_size, -1), imagined_zt.view(batch_size,-1)), dim=1)
            critic_estimation = critic(im_actor_states_hat.detach())
            imagined_reward =  im_rt_hat + (1 - return_lambda) * imagined_dones * critic_estimation

            #print("--- Imagined step ---")
            #print("Imagined reward: ", imagined_reward)
            #print("Imagined done: ", imagined_dones)
            #print("Imagined critic estimation: ", critic_estimation)
            #print("Imagined ct_hat: ", im_ct_hat)
            #print("Imagined rt_hat: ", im_rt_hat)
            #print("Imagined zt_hat NAN: ", torch.isnan(im_zt_hat).any())
            #print("Imagined zt NAN: ", torch.isnan(imagined_zt).any())
            #print("Imagined ht NAN: ", torch.isnan(imagined_ht).any())
            #print("Imagined ht: ", imagined_ht)
            #print("Imagined ht avg: ", imagined_ht.mean())
        
            
            lookahead_reward +=  (gamma ** t) * (return_lambda ** t) * imagined_reward
            
            
            imagined_ht = im_next_hidden_state.view(1, batch_size, -1)
            imagined_zt = im_zt_hat
            imagined_xt = im_xt_hat
            im_actor_states = im_actor_states_hat

        # one last critic estimation
        if not imagined_dones.all():
            lookahead_reward += imagined_dones * (gamma ** imagination_horizon) * (return_lambda ** imagination_horizon-1) * critic(im_actor_states_hat.detach()) #TODO: i think should go to next state, not previous critic estimation.
    
    target_value = lookahead_reward
    critic_loss = F.mse_loss(value, target_value.detach())
    
    # Compute TD-errors for the sampled experiences
    td_errors = torch.abs(target_value.detach() - value.detach()).squeeze()
    # Actor loss
    advantage = target_value - value.clone().detach()
    

    # Calculate Actor Loss
    # Assuming actions are in the form of indices of the chosen actions
    action_probs = actor( torch.cat((ht.view(batch_size, -1), zt.view(batch_size,-1)), dim=1).detach() )  # Probabilities of actions from the actor network
    # take actions from probs
    actions = torch.multinomial(action_probs, num_samples=1)
    
    # Gather the probabilities of the actions taken
    gathered_probs = action_probs.gather(1, actions.view(-1, 1))
    
    # Add a small number to probabilities to avoid log(0)
    actor_loss = -torch.log(gathered_probs + 1e-8) * advantage.detach()
    actor_loss = actor_loss.mean()

    #print("Advantage: ", advantage)
    #print("Target value: ", target_value)
    #print("Value: ", value)
    #print("Actor loss: ", actor_loss)
    #print("Critic loss: ", critic_loss)
    #print("TD errors: ", td_errors)
    #print("Actions: ", actions)
    #print("Action probs: ", action_probs)
    #print("Gathered probs: ", gathered_probs)

    # Backpropagation
    critic_optimizer.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 10)
    critic_optimizer.step()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), 10)
    actor_optimizer.step()
    
    indices = info['index']
    actor_critics_replay_buffer.update_priority(indices, td_errors.cpu().numpy())

    worldModel.train()

    if logging:
        actor_grads = sum([torch.norm(param.grad) for param in actor.parameters()])
        critic_grads = sum([torch.norm(param.grad) for param in critic.parameters()])
        actor_weights_norm = sum([torch.norm(param) for param in actor.parameters()])
        critic_weights_norm = sum([torch.norm(param) for param in critic.parameters()])
        wandb.log({"critic_loss": critic_loss.item(),
                    "critic_grads": critic_grads,
                    "critic_weights_norm": critic_weights_norm,
                    "actor_loss": actor_loss.item(),
                    "actor_grads": actor_grads,
                    "actor_weights_norm": actor_weights_norm})

In [208]:
def collect_experience(encoded_state_flat, hidden_state, envs, n_envs):
    """
    Handles interaction with the environment and collects experience.
    """

    hidden_state = hidden_state.to(device)
    
    actor_state = torch.cat((hidden_state.view(n_envs, -1), encoded_state_flat), dim=1)
    
    action_probs = actor(actor_state)
    actions = torch.multinomial(action_probs, 1)
    
    next_states = []
    rewards = []
    dones = []
    for env,action in zip(envs,actions):
        next_state, reward, done, _, _ = env.step(action.item())
        # try non sparse reward:
        reward = reward + (next_state%4 + next_state//4)/(16 * 3)
        next_states.append(next_state)
        rewards.append(reward)
        dones.append(done)
    
    return actions, rewards, next_states, dones


In [209]:
# Training loop
batch_size = 16
training_delta = 100 # number of episodes to train world model before training actor critic
n_steps = batch_size + training_delta + 200
gamma = 0.997
Bpred = 1
Bdyn = 0.5
Brep = 0.1

logging = True

if logging:
    wandb.init(project="Dreamerv3 Reproduction",
               name="Trying non sparse rewards.",
               reinit=True,
               config={"n_steps": n_steps,
                       "gamma": gamma,
                       "latent_size": latent_size,
                       "num_categories": num_categories,
                       "ht_size": ht_size,
                       "reward_size": reward_size,
                       "continue_size": continue_size,
                       "worldModel_optimizer": worldModel_optimizer,
                       "actor_optimizer": actor_optimizer,
                       "critic_optimizer": critic_optimizer})


# Parameters
buffer_size = 10000  # Adjust as needed
alpha_world = 1  # Adjust as needed for prioritization (0 for uniform, 1 for fully prioritized)
alpha_actor_critic = 1

# Initialize buffer
world_model_replay_buffer = PrioritizedReplayBuffer(alpha=alpha_world, storage=ListStorage(buffer_size), beta=0.4)
actor_critics_replay_buffer = PrioritizedReplayBuffer(alpha=alpha_actor_critic, storage=ListStorage(buffer_size), beta=0.4)

env_name = 'FrozenLake-v1'
n_envs = batch_size//2


# Reset environment and episode-specific variables
envs = [gym.make(env_name) for _ in range(n_envs)]
layouts = convert_layout_to_tensor([env.desc for env in envs])
states = [env.reset()[0] for env in envs]
episode_steps_count = [0] * n_envs
dones = [False] * n_envs


hidden_states = torch.zeros((1, n_envs, ht_size), device=device, dtype=torch.float)
max_steps = n_states * n_actions

for step in tqdm(range(n_steps)):
    # --- Reset the environment if done ---
    for i in range(n_envs):
        episode_steps_count[i] += 1
        if episode_steps_count[i] > max_steps or dones[i]:
            states[i] = envs[i].reset()[0]
            episode_steps_count[i] = 0
            dones[i] = False
            hidden_states[0][i] = torch.zeros((1, ht_size), device=device, dtype=torch.float)
            

    # --- Collect Experience ---
    state_tensors = update_start_positions(layouts.clone(), states)
    state_tensors = state_tensors.view(n_envs, -1).to(device)

    with torch.no_grad():
    # Encode the state
        encoded_states = worldModel.encode(state_tensors, hidden_states)

    actions, rewards, next_states, dones = collect_experience(encoded_states, hidden_states, envs, n_envs)
    
    # Add experience to world model buffer
    for state, action, reward, next_state, done in zip(states, actions, rewards, next_states, dones):
        #print(f"S: {state}, A: {action}, R: {reward}, S': {next_state}, D: {done}")
        world_model_replay_buffer.add((state, action.item(), reward, next_state, bool(done)))
    
    
    # --- World Model Update ---
    if len(world_model_replay_buffer) > batch_size:
        update_world_model(worldModel, worldModel_optimizer, world_model_replay_buffer, actor, batch_size=batch_size, logging=logging)
    
    # --- Actor-Critic Update ---
    if step > batch_size + training_delta:
        for hidden_state, encoded_state, state_tensor, reward, done in zip(hidden_states[0], encoded_states, state_tensors, rewards, dones):
            actor_critics_replay_buffer.add((hidden_state, encoded_state, state_tensor, reward, done))
        if len(world_model_replay_buffer) > batch_size:
            update_actor_critic(actor, critic, actor_optimizer, critic_optimizer, worldModel, actor_critics_replay_buffer, batch_size=batch_size, logging=logging)
    
    # --- Logging and State Updates ---
    if logging:
        wandb.log({"sum_rewards": sum(rewards)})
    
    states = next_states
    integratedModel_input = torch.cat((encoded_states, F.one_hot(actions, n_actions).view(n_envs, -1).to(device)), dim=1)
    integratedModel_input = integratedModel_input.view(n_envs, 1, -1)
    hidden_state, _, _, _, _, _ = worldModel(integratedModel_input, state_tensors, hidden_states)




# Close the environment
env.close()

if logging:
    wandb.finish()


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
sum_rewards,▁

0,1
sum_rewards,0.0625


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167378242438039, max=1.0…

100%|██████████| 316/316 [04:44<00:00,  1.11it/s]


VBox(children=(Label(value='0.001 MB of 0.006 MB uploaded\r'), FloatProgress(value=0.16369191188293938, max=1.…

0,1
WorldModel_weight_norm,██▇▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
actor_grads,█▃▂▃▁▅▂▂▂▃▃▁▁▂▂▃▄▁▂▃▂▁▂▅▂▂▁▃▅▁▂▂▅▆▃▂▂▂▂▂
actor_loss,▁▅▇▇▆▆▆▄▅▆▇▆▇▆▇▅▇▆▆▄▆▅▅█▄▅▆▆▆▆▆▆▇█▅▆▆▆▄▅
actor_weights_norm,▇▇███████▇▇▇▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁▂▂▁▁▂▃▄
critic_grads,█▂▃▄▁▂▁▃▂▂▂▁▄▂▃▂▃▁▂▃▂▁▂▅▃▂▁▂▁▁▁▁▂▄▂▁▂▁▃▂
critic_loss,█▁▂▂▁▂▁▂▁▁▁▁▂▁▂▂▂▁▁▂▁▁▁▃▂▁▁▁▃▁▁▁▃▄▁▁▁▁▂▁
critic_weights_norm,▂▅█▇▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▂▂▂▂
dynamic_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
prediction_loss,▇▅▂▄█▄▄▄█▄▁▇▄▇▆▁▃▄▁▇▄▄▆▂▄▄▁▄▃▇▄▅▄▆█▅▄▇▄▅
representation_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
WorldModel_weight_norm,222.57423
actor_grads,0.1721
actor_loss,0.04325
actor_weights_norm,19.42466
critic_grads,1.33402
critic_loss,0.00493
critic_weights_norm,18.74509
dynamic_loss,0.01
prediction_loss,0.17509
representation_loss,0.01


In [210]:
# Test the model
num_episodes = 10
successes = 0
for _ in range(num_episodes):
    actor.eval()
    critic.eval()
    worldModel.eval()
    
    state = env.reset()[0]
    done = False
    episode_rewards = []
    layout_tensor = convert_layout_to_tensor([env.desc])
    hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    
    autoencoder_losses = []
    
    max_steps = 100
    while not done and max_steps > 0:
        state_tensor = update_start_positions(layout_tensor.clone(), [state])
        state_tensor = state_tensor.view(1, -1).to(device)
        with torch.no_grad():
            # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)
        encoded_state_flat = encoded_state.view(1, -1)
        
        actor_state = torch.cat((hidden_state.view(1, -1), encoded_state_flat), dim=1)
        action_probs = actor(actor_state)[0]

        action = torch.argmax(action_probs).item()
        action_one_hot = F.one_hot(torch.tensor(action), n_actions).view(1, -1).to(device)
        next_state, reward, done, _, _ = env.step(action)
        next_state_tensor = update_start_positions(layout_tensor.clone(), [next_state])

        # Store rewards for later use
        episode_rewards.append(reward)
        
        integratedModel_input = torch.cat((encoded_state_flat, action_one_hot.detach()), dim=1)
        # add batch because input needs to have batch, seq, input_size
        integratedModel_input = integratedModel_input.view(1, 1, -1)
        hidden_state, zt_hat, rt_hat, ct_hat, xt_hat, zt = worldModel(integratedModel_input, state_tensor  ,hidden_state)
        
        state = next_state
        max_steps -= 1

    
    if reward >= 1:
        successes += 1

# Close the environment
env.close()
print(f"Success rate: {successes/num_episodes}")
    

Success rate: 0.0


In [211]:
# Test the model
def test_model_with_diagnostics(env, actor, critic, worldModel, device, n_actions, convert_layout_to_tensor, ht_size):
    actor.eval()
    critic.eval()
    worldModel.eval()
    action_space = {0: "LEFT", 1: "DOWN", 2: "RIGHT", 3: "UP"}
    

    with torch.no_grad():
        for state in range(n_states):
            env.reset()
            env.state = state
            layout_tensor = convert_layout_to_tensor([env.desc])
            hidden_state = torch.zeros((1, 1, ht_size), device=device, dtype=torch.float)
    

            state_tensor = update_start_positions(layout_tensor.clone(), [state])
            state_tensor = state_tensor.view(1, -1).to(device)
            
            print(f"State {state}")
            print(f"Layout:")
            layout = convert_tensor_to_layout(state_tensor.view(1, 4, 4, 4))
            for line in layout[0]:
                print(line)
            # Encode the state
            encoded_state = worldModel.encode(state_tensor, hidden_state)
            encoded_state_flat = encoded_state.view(1, -1)
            
            actor_state = torch.cat((hidden_state.view(1, -1), encoded_state_flat), dim=1)
            action_probs = actor(actor_state)[0]
            #print(f"Action probabilities : {np.round(action_probs.cpu().numpy(),2)}")
    
            for action in range(n_actions):
                action_one_hot = F.one_hot(torch.tensor(action), n_actions).view(1, -1).to(device)
                integratedModel_input = torch.cat((encoded_state_flat, action_one_hot), dim=1)
                # add batch because input needs to have batch, seq, input_size
                integratedModel_input = integratedModel_input.view(1, 1, -1)
                hidden_state, zt_hat, rt_hat, ct_hat, xt_hat, zt = worldModel(integratedModel_input, state_tensor, hidden_state)
                
                layout = convert_tensor_to_layout(xt_hat.round().int().view(1, 4, 4, 4))
                print(f"World model predictions for action {action_space[action]}:")
                print(f"rt_hat: {rt_hat.item():.2f}, ct_hat: {ct_hat.item():.2f},")
                for line in layout[0]:
                    print(line)
                #print(f"Xt_hat: {xt_hat.round().int().view(1, 4, 4, 4)}")
                #print(f"Zt_hat: {zt_hat.round().int()}")
                #print(f"Zt: {zt.round().int()}")



# Call the function
test_model_with_diagnostics(env, actor, critic, worldModel, device, n_actions, convert_layout_to_tensor, ht_size)


State 0
Layout:
[b'S', b'F', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
World model predictions for action LEFT:
rt_hat: 0.03, ct_hat: 1.04,
[b'S', b'F', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
World model predictions for action DOWN:
rt_hat: 0.02, ct_hat: 1.06,
[b'S', b'F', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
World model predictions for action RIGHT:
rt_hat: 0.02, ct_hat: 1.04,
[b'S', b'F', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
World model predictions for action UP:
rt_hat: 0.01, ct_hat: 1.04,
[b'S', b'F', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
State 1
Layout:
[b'F', b'S', b'F', b'F']
[b'F', b'H', b'F', b'H']
[b'F', b'F', b'F', b'H']
[b'H', b'F', b'F', b'G']
World model predictions for action LEFT:
rt_hat: 0.02, ct_hat: 1.08,
[b'S', b'F', b'F', b'F']

In [212]:
print(env.desc)

[[b'S' b'F' b'F' b'F']
 [b'F' b'H' b'F' b'H']
 [b'F' b'F' b'F' b'H']
 [b'H' b'F' b'F' b'G']]
