In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# FOR .py FILE!

**Disclosures**

The combination of TQC + D2RL + ERE + Dreamer was inspired by the following implementation, which gave good results on BipedalWalker and was used as the starting point for this notebook: https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

However, I made some important changes from this:
* The ERE buffer anneals 'eta', which determines how much emphasis is placed on recent experiences. In the original implementation, eta annealed over the course of each episode and then reset for the next one. In this implementation, eta instead anneals over the whole training process and never resets, which I believe is more in line with the original paper.
* Attempt to introduce intrinsic rewards for exploring where dreamer is uncertain.



Paper references:
* TQC - https://bayesgroup.github.io/tqc/	
* ERE - https://arxiv.org/abs/1906.04009 
* D2RL - https://sites.google.com/view/d2rl/home
* Transformer Dreamer - https://arxiv.org/abs/2202.09481
* Intrinsic rewards - https://arxiv.org/abs/1908.06976
* Adjusting imagination horizon - https://arxiv.org/abs/2009.09593

# **Dependencies and imports**

This can take a minute...

In [2]:
# https://github.com/robert-lieck/rldurham/tree/main - rldurham source

# !pip install swig
# !pip install --upgrade rldurham

# !pip install wandb

In [3]:
import math
import random
import copy
import time
import os

import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import torch.distributions as D
from torch.distributions.transforms import TanhTransform

import rldurham as rld

import wandb
# os.environ["WANDB_API_KEY"] = "x" # TODO remove
# wandb.login()



# **RL agent**

In [4]:
MAX_TIMESTEPS = 2000 # [DONT CHANGE]
SOURCE_ID = '' # mac, ncc, colab -> for personal id in wandb


# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU otherwise.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
DEVICE

device(type='mps')

## helper fns

In [5]:
# reward adjustment
def adjust_reward(reward):
    """
    Reshape a BipedalWalker reward before it is stored for training.

    The environment's -100 penalty is softened to -10 (intended to make
    learning more stable); every other reward is doubled to encourage
    forward motion.
    """
    took_penalty = reward == -100.0
    if took_penalty:
        return -10.0
    reward *= 2
    return reward

In [6]:
# HELPERs
# for https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

def train_on_environment(actor, env, grad_step_class, replay_buffer, max_timesteps,
    state, batch_size, total_steps, sequence_length, init_rand_steps=None, is_dreamer=False,
    use_intrinsic_reward=False, last_dreamer_test_loss=0.0, intrinsic_r_scale=0.0,
    ):
    '''
    Run a single episode (real or dreamed) and train the agent as it goes.

    Shared by the real environment and the dreamer "environment": both only need
    to expose `step(action)` (returning either a 5-tuple
    `(next_state, reward, terminated, truncated, info)` or a 4-tuple
    `(next_state, reward, done, info)`) and, for the warm-up phase,
    `action_space.sample()`.

    Args:
        actor: policy exposing `select_action(state)`.
        env: environment to step (see above).
        grad_step_class: object exposing
            `take_gradient_step(replay_buffer, total_steps, batch_size)`.
        replay_buffer: object exposing `add(state, action, next_state, reward, done)`.
        max_timesteps: maximum number of steps to take in this episode.
        state: initial observation as a numpy array (cast to float32 here).
        batch_size: batch size forwarded to the gradient step; also the minimum
            `total_steps` before any training happens.
        total_steps: global step counter at the start of the episode. It is only
            incremented locally; callers must advance their own counter using
            the returned `ep_timesteps`.
        sequence_length: number of leading real-env transitions to record for the
            dreamer's starting sequence.
        init_rand_steps: while `total_steps` is below this, actions are sampled
            randomly from `env.action_space`.
            NOTE(review): when this is None no gradient step is ever taken
            (see the training condition below) -- confirm this is intended for
            callers that omit it.
        is_dreamer: True when `env` is the dreamer model; rewards are then stored
            unmodified and no start sequence is recorded.
        use_intrinsic_reward: add an intrinsic bonus to real-env rewards.
        last_dreamer_test_loss: most recent dreamer test loss; the bonus is
            `intrinsic_r_scale * last_dreamer_test_loss` when it is positive.
        intrinsic_r_scale: scale factor for the intrinsic bonus.

    Returns:
        (ep_timesteps, ep_reward, final_input_buffer, info) where `ep_reward` is
        the sum of the *extrinsic* (unadjusted) rewards, `final_input_buffer` is a
        float32 tensor on DEVICE with one row per recorded transition (or None for
        dreamer runs / when nothing was recorded), and `info` is the info object of
        the last step (an empty dict if no step was taken).
    '''
    ep_timesteps = 0
    ep_reward = 0
    # Defined up-front so the return below is valid even if the loop never runs
    # (max_timesteps <= 0); previously this raised UnboundLocalError.
    info = {}

    # save start sequence for the dreamer model
    input_buffer = [] # store as python list to avoid cat-ing constantly

    # initial state float32 for mps
    state = state.astype(np.float32)

    for t in range(max_timesteps):
        total_steps += 1
        ep_timesteps += 1

        # select action and step the env
        if init_rand_steps is not None and total_steps < init_rand_steps: # do random steps at start for warm-up
            action = env.action_space.sample()
        else: # otherwise use actor policy
            action = actor.select_action(state)

        step = env.step(action)

        # handle different return formats
        if len(step) == 5:
            next_state, reward, term, tr, info = step
            done = term or tr
        else:
            next_state, reward, done, info = step

        # float32 before processing/storing
        next_state = next_state.astype(np.float32)
        reward = float(reward) # extrinsic reward
        done = float(done)

        ep_reward += reward # sum extrinsic reward

        r_final = reward
        if not is_dreamer: # only adjust reward & add intrinsic for real env
            r_final = adjust_reward(reward) # adjust reward for bipedal walker
            if use_intrinsic_reward and last_dreamer_test_loss > 0:
                r_intrinsic = intrinsic_r_scale * last_dreamer_test_loss
                r_final += r_intrinsic
                wandb.log({"Intrinsic Reward Added": r_intrinsic}, commit=False) # log to wandb (commit = false lets us log things later in the ep)

        replay_buffer.add(state, action, next_state, r_final, done) # add to replay buffer for both real and dreamer

        if not is_dreamer and t < sequence_length: # store in input buffer for dreamers first seq
            # Row layout: (state, action, next_state, reward, done).
            # NOTE(review): this column order differs from the
            # (state, action, reward, done, next_state) order produced by
            # retrieve_from_replay_buffer -- confirm the consumer expects it.
            trans = np.concatenate((
                state,
                action.astype(np.float32),
                next_state,
                np.array([reward], dtype=np.float32), # NOTE - store original dreamer reward, not adjusted r_final
                np.array([done], dtype=np.float32)
            ), axis=0)
            input_buffer.append(trans)

        state = next_state

        # train step if buffer has enough samples for a batch AND we're past warmup
        if total_steps >= batch_size and init_rand_steps is not None and total_steps >= init_rand_steps:
            grad_step_class.take_gradient_step(replay_buffer, total_steps, batch_size)

        if done: # break if finished
            break

    final_input_buffer = None
    if not is_dreamer and len(input_buffer) > 0: # return if any steps were collected
        try:
            # np.array raises ValueError if the collected rows have different lengths
            np_buffer = np.array(input_buffer)
            final_input_buffer = torch.tensor(np_buffer, dtype=torch.float32).to(DEVICE)
        except ValueError as e:
            print(f"Warning: Error converting input_buffer to tensor: {e}")
            print(f"Input buffer contents (first few elements): {input_buffer[:3]}")

    # only return extrinsic to eval based on that
    return ep_timesteps, ep_reward, final_input_buffer, info


In [7]:
# https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

# generate value array from replay buffer
def retrieve_from_replay_buffer(replay_buffer, ptr):
    '''
    Collect the transitions written to the buffer from row `ptr` up to (but not
    including) the buffer's current write position `replay_buffer.ptr`.

    Returns a 2-D array with one row per transition and the columns
    (state, action, reward, done, next_state), where done = 1 - not_done.

    The buffer is circular, so once the write position has wrapped around
    (`ptr > replay_buffer.ptr`) the new rows are `ptr..end` followed by
    `0..replay_buffer.ptr`. Previously this case produced an empty slice and the
    transitions were silently dropped. `ptr == replay_buffer.ptr` is treated as
    "nothing new" and yields zero rows.
    '''
    end = replay_buffer.ptr
    if ptr <= end:
        rows = slice(ptr, end)
    else:
        # wrapped: tail of the storage followed by the freshly overwritten head
        capacity = replay_buffer.state.shape[0]
        rows = np.concatenate((np.arange(ptr, capacity), np.arange(0, end)))

    return np.concatenate((
            replay_buffer.state[rows],
            replay_buffer.action[rows],
            replay_buffer.reward[rows],
            1. - replay_buffer.not_done[rows],
            replay_buffer.next_state[rows],
        ),
        axis = 1 # along the columns (so each row is a memory)
    )

# function to create sequences with a given window size and step size
def create_sequences(memories, window_size, step_size):
    '''
    Slice a 2-D array of memories (one row per time step) into overlapping
    windows.

    Window k covers rows [k * step_size, k * step_size + window_size). The
    result has shape (n_windows, window_size, n_features) and numpy's default
    float dtype.
    '''
    total_rows = memories.shape[0] # time steps currently available

    # how many full windows fit when starting a new one every step_size rows
    count = math.floor((total_rows - window_size) / step_size) + 1

    windows = np.zeros((count, window_size, memories.shape[1]))
    for row, begin in enumerate(range(0, count * step_size, step_size)):
        windows[row, :] = memories[begin:begin + window_size, :]
    return windows

def gen_test_train_seq(replay_buffer, train_set, test_set, train_split, window_size, step_size, ptr):
    '''
    Turn the transitions added to the replay buffer since `ptr` into sequences
    and distribute them randomly between the train and test sets.

    Args:
        replay_buffer: buffer to read the new transitions from.
        train_set, test_set: existing sequence arrays to extend, or None on the
            first call.
        train_split: fraction of the new sequences that goes to the train set.
        window_size: length of each sequence.
        step_size: offset between the starts of consecutive sequences.
        ptr: buffer row at which the new transitions start.

    Returns:
        (train_set, test_set). If there are fewer than `window_size` new
        transitions, no complete sequence can be formed and the inputs are
        returned unchanged (possibly None).
    '''
    memories = retrieve_from_replay_buffer(replay_buffer, ptr)

    # Explicit "not enough data yet" check. This replaces a bare `except:` that
    # also hid unrelated failures (e.g. a zero step_size or KeyboardInterrupt).
    if memories.shape[0] < window_size:
        return train_set, test_set

    memory_sequences = create_sequences(memories, window_size, step_size)
    n_sequences = memory_sequences.shape[0]

    # shuffle the sequences & split
    indices = np.arange(n_sequences)
    np.random.shuffle(indices)

    split = int(train_split * n_sequences) # get split point
    train_indices = indices[:split]
    test_indices = indices[split:]

    new_train = memory_sequences[train_indices, :]
    new_test = memory_sequences[test_indices, :]

    if train_set is None: # if this is first train/test set, create the sets
        return new_train, new_test
    # else just add to existing sets
    return np.concatenate((train_set, new_train), axis=0), np.concatenate((test_set, new_test), axis=0)


In [8]:
def get_dreamer_eps(dreamer_avg_loss, loss_threshold, max_dreamer_it=10):
    '''
    Get the number of dreamed episodes the agent should run on the dreamer model.

    The dreamer is only used for training if it is a sufficiently accurate model
    of the environment: a loss at or above `loss_threshold` gives 0 episodes,
    and the count grows quadratically towards `max_dreamer_it` as the loss
    approaches 0.

    Args:
        dreamer_avg_loss: average dreamer test loss. A NaN loss is treated as
            "not accurate enough" and gives 0 (previously it raised ValueError
            when converting NaN to int).
        loss_threshold: loss at or above which the dreamer is not used
            (between 0 and 1).
        max_dreamer_it: number of episodes a perfect (zero-loss) dreamer gets.

    Disclaimer - calculation from https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb
    '''
    # `not (a < b)` rather than `a >= b` so that NaN also lands in this branch
    if not dreamer_avg_loss < loss_threshold:
        return 0

    norm_score = dreamer_avg_loss / loss_threshold # normalise score relative to threshold
    inv_score = 1 - norm_score # invert so that closer to 1 ==> dreamer score much better than threshold
    sq_score = inv_score ** 2 # square so that num iter increases quadratically as accuracy improves
    return int(max_dreamer_it * sq_score) # scale so that the max is reached when the dreamer is very accurate


## replay buffer

In [None]:
class EREReplayBuffer(object):
    '''
    Circular replay buffer that combines Emphasizing Recent Experience (ERE)
    with optional Prioritized Experience Replay (PER).

    ERE implementation - https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb
    ERE paper - https://arxiv.org/abs/1906.04009

    Prioritized Experience Replay (PER) buffer, as in https://arxiv.org/abs/1511.05952.
    Implementation: https://github.com/BY571/Soft-Actor-Critic-and-Extensions/blob/master/SAC_PER.py
        PER parameters are sourced from this implementation.

    Sampling is a two-stage process: ERE first restricts sampling to a window of
    the most recent transitions, then (if `use_per`) PER draws from that window
    in proportion to the stored priorities.

    Rationale:
    - actions from early in training are likely very different from optimal, so
      recent ones are prioritised to avoid constantly learning from poor ones
    - as time goes on they become more similar, so annealing eta allows uniform
      sampling later on

    Args:
        state_dim, action_dim: widths of the state / action storage arrays.
        T: number of steps over which eta anneals from `eta` to 1.
        max_size: capacity of the buffer.
        eta: initial ERE eta.
        cmin: minimum size of the ERE window.
        use_per: whether to sample with PER inside the ERE window.
        beta_1: exponent applied to priorities when forming sampling probabilities.
        beta_2_start, beta_2_frames: importance-sampling exponent start value and
            the number of `sample` calls over which it anneals to 1.
        epsilon: small constant added to priorities so no probability is zero.
        recency_scale: scales the exponent used for the ERE window size.
    '''
    def __init__(self, state_dim, action_dim, T, max_size, eta, cmin, use_per=True,
                beta_1=0.6, beta_2_start=0.4, beta_2_frames=int(1e5), epsilon=1e-6, recency_scale=1):

        self.max_size, self.ptr, self.size, self.rollover = max_size, 0, 0, False
        self.use_per = use_per # use PER flag

        # ERE params
        self.eta0 = eta
        self.cmin = cmin
        self.T = T
        self.recency_scale = recency_scale # when using PER too, this helps calculate ck

        # PER params
        self.beta_1 = beta_1
        self.beta_2_start = beta_2_start
        self.beta_2_frames = beta_2_frames
        self.epsilon = epsilon
        self.frame = 1 # for beta calculation

        # storage
        self.reward = np.empty((max_size, 1), dtype=np.float32) # float32 for mps
        self.state = np.empty((max_size, state_dim), dtype=np.float32)
        self.action = np.empty((max_size, action_dim), dtype=np.float32)
        self.not_done = np.empty((max_size, 1), dtype=np.float32)
        self.next_state = np.empty((max_size, state_dim), dtype=np.float32)
        self.priorities = np.zeros((max_size, 1), dtype=np.float32) # for PER
        self.max_prio = 1.0 # max priority seen so far

    def eta_anneal(self, t):
        '''Linearly anneal eta from eta0 towards 1 (capped) over T steps, reducing the emphasis on recent experiences.'''
        return min(1.0, self.eta0 + (1 - self.eta0) * t / self.T)

    def beta_2_by_frame(self, frame_idx):
        """
        Linearly increases beta from beta_2_start to 1 over beta_2_frames frames.

        see 3.4 ANNEALING THE BIAS (Paper: PER)
        """
        return min(1.0, self.beta_2_start + frame_idx * (1.0 - self.beta_2_start) / self.beta_2_frames)

    def add(self, state, action, next_state, reward, done):
        '''Store one transition at the write position, overwriting the oldest entry once the buffer is full.'''
        self.state[self.ptr] = state.astype(np.float32)
        self.action[self.ptr] = action.astype(np.float32)
        self.next_state[self.ptr] = next_state.astype(np.float32)
        self.reward[self.ptr] = float(reward)
        self.not_done[self.ptr] = 1. - float(done)
        # new transitions get the max priority seen so far (1.0 for the very first one)
        self.priorities[self.ptr] = self.max_prio if self.size > 0 else 1.0

        self.ptr += 1
        self.ptr %= self.max_size

        # increase size counter until full, then start overwriting (rollover)
        if self.max_size > self.size + 1:
            self.size += 1
        else:
            self.size = self.max_size
            self.rollover = True

    def _recent_window(self, t):
        '''Return the buffer indices of the ERE window (the most recent transitions) at step t.'''
        eta = self.eta_anneal(t)

        c_calc = self.size * eta ** (t / self.T * self.recency_scale)
        ck = int(max(self.cmin, c_calc)) # at least cmin samples
        ck = min(ck, self.size) # limit to buffer size

        if not self.rollover:
            # Buffer not full yet, window is [size - ck, size)
            return np.arange(max(0, self.size - ck), self.size)

        # Buffer has rolled over: the window ends just before ptr and may wrap around
        start_idx = (self.ptr - ck + self.max_size) % self.max_size
        if start_idx < self.ptr:
            return np.arange(start_idx, self.ptr)
        return np.concatenate((np.arange(start_idx, self.max_size), np.arange(0, self.ptr)))

    def sample(self, batch_size, t):
        '''
        Draw `batch_size` transitions (with replacement) from the ERE window at step t.

        Returns:
            (state, action, next_state, reward, not_done, samples, weights):
            the first five are float32 tensors on DEVICE, `samples` holds the
            sampled buffer indices (numpy) and `weights` the float32 importance
            weights (all ones when PER is disabled).

        Raises:
            ValueError: if the window is empty (e.g. nothing has been added yet).
        '''
        indices = self._recent_window(t)

        # the window size actually used is the number of indices determined above
        ck = len(indices)
        if ck <= 0:
            raise ValueError("Cannot sample from replay buffer: the ERE window is empty "
                             "(no transitions stored yet, or cmin is 0).")

        if self.use_per:
            # calc P = p^a/sum(p^a)
            prios = self.priorities[indices].flatten() + self.epsilon # add epsilon to avoid zero probs
            probs = prios ** self.beta_1
            probs_sum = probs.sum()
            if probs_sum <= 0: # avoid division by zero if all prios = zero
                P = np.ones(ck, dtype=np.float32) / ck # uniform sample
            else:
                P = probs / probs_sum

            # positions inside the window, drawn according to P
            rel_indices = np.random.choice(ck, batch_size, p=P, replace=True)
            samples = indices[rel_indices]

            beta_2 = self.beta_2_by_frame(self.frame)
            self.frame += 1 # increment for annealing

            # Compute importance-sampling weights w = (N * P)^(-beta_2), normalised by the max
            weights = (ck * P[rel_indices]) ** (-beta_2)
            weights /= weights.max()
            weights = np.array(weights, dtype=np.float32)
        else:
            # sample uniformly from the window like in basic ERE; weights of 1
            rel_indices = np.random.choice(ck, batch_size, replace=True)
            samples = indices[rel_indices]
            weights = np.ones(batch_size, dtype=np.float32)

        r = torch.tensor(self.reward[samples], dtype=torch.float32).to(DEVICE)
        s = torch.tensor(self.state[samples], dtype=torch.float32).to(DEVICE)
        ns = torch.tensor(self.next_state[samples], dtype=torch.float32).to(DEVICE)
        a = torch.tensor(self.action[samples], dtype=torch.float32).to(DEVICE)
        nd = torch.tensor(self.not_done[samples], dtype=torch.float32).to(DEVICE)

        return s, a, ns, r, nd, samples, weights

    def update_priorities(self, batch_indices, batch_priorities):
        '''Overwrite the priorities at `batch_indices` with |batch_priorities| + epsilon and track the running max.'''
        priorities = np.abs(batch_priorities) + self.epsilon # ensure > 0

        self.priorities[batch_indices] = priorities.reshape(-1, 1)
        self.max_prio = max(self.max_prio, np.max(priorities)) # update max prio


## agent

**dreamer**

In [11]:
class DreamerAgent(nn.Module):
    '''
    Dreamer agent (learned world model).
    Uses a transformer to predict the next state, reward and done signal from a
    history of (state, action, reward, done) rows plus the current (state, action).

    Training/testing sequences have shape (batch, sequence, features) with the
    feature columns ordered (state, action, reward, done, next_state).

    `step` lets the model be used like an environment. It relies on the rolling
    history tensors `self.states`, `self.actions`, `self.rewards` and
    `self.dones`, which are NOT created in this class and must be set by the
    caller before the first `step`.

    Acknowledgements:
    - implementation based on https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb
    - (from implementation designer) key ideas and concepts for the auto-regressive transformer design stem from this paper: https://arxiv.org/abs/2202.09481
    '''
    def __init__(self, state_dim, action_dim, hidden_dim, seq_len, num_layers, num_heads, dropout_prob, lr, weight_decay):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.seq_len = seq_len
        self.input_dim = state_dim + action_dim + 2 # (state, action, reward, done)
        self.target_dim = self.state_dim + self.action_dim # (state, action)
        self.hidden_dim = hidden_dim

        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCEWithLogitsLoss()

        # linear embeddings of the history rows and of the current (state, action)
        self.input_fc = nn.Linear(self.input_dim, hidden_dim).to(DEVICE)
        self.target_fc = nn.Linear(self.target_dim, hidden_dim).to(DEVICE)

        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout_prob,
            activation=F.gelu,
            batch_first=True
        ).to(DEVICE)

        # output heads see the transformer output concatenated with the raw (state, action)
        self.output_next_state = nn.Linear(hidden_dim + self.target_dim, state_dim).to(DEVICE)
        self.output_reward = nn.Linear(hidden_dim + self.target_dim, 1).to(DEVICE)
        self.output_done = nn.Linear(hidden_dim + self.target_dim, 1).to(DEVICE)
        self.optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

    def loss_fn(self, output_next_state, output_reward, output_done, ground_truth):
        '''
        Sum of MSE (next state), MSE (reward) and BCE-with-logits (done) between
        the last predicted step and `ground_truth`, whose last dimension is laid
        out as (reward, done, next_state).
        '''
        reward, done, next_state = torch.split(ground_truth, [1, 1, self.state_dim], dim=-1)
        loss = self.mse_loss(output_next_state[:, -1], next_state)
        loss += self.mse_loss(output_reward[:, -1], reward)
        loss += self.bce_loss(output_done[:, -1], done.float())
        return loss

    def forward(self, input_tensor):
        '''
        Batched forward pass used for training/testing.

        The last row of each sequence supplies the current (state, action)
        target; all earlier rows supply the (state, action, reward, done)
        history. Returns (next_state, reward, done_logits); `done` is a raw
        logit here because `self.bce_loss` (BCEWithLogitsLoss) applies the sigmoid.
        '''
        # separate the input and target tensors
        target = input_tensor[:, -1, :self.target_dim].unsqueeze(1)
        encoded_target = self.target_fc(target)
        encoded_input = self.input_fc(input_tensor[:, :-1, :self.input_dim])

        # pass these into the transformer
        encoded_output = self.transformer(encoded_input, encoded_target)

        # decode from the transformer output joined with the raw target
        decoder_in = torch.cat([encoded_output, target], axis=2)
        output_next_state = self.output_next_state(decoder_in)
        output_reward = self.output_reward(decoder_in)
        output_done = self.output_done(decoder_in) # logits; no sigmoid (see docstring)
        return output_next_state, output_reward, output_done

    def predict(self, input_tensor, target_tensor):
        '''
        Unbatched inference pass used by `step`.

        Unlike `forward`, the done output is passed through a sigmoid so it can
        be thresholded as a probability.
        '''
        encoded_target = self.target_fc(target_tensor)
        encoded_input = self.input_fc(input_tensor)

        # pass these into the transformer
        encoded_output = self.transformer(encoded_input, encoded_target)

        # decode from the transformer output joined with the raw target
        decoder_in = torch.cat([encoded_output, target_tensor], axis=1)
        output_next_state = self.output_next_state(decoder_in)
        output_reward = self.output_reward(decoder_in)
        output_done = torch.sigmoid(self.output_done(decoder_in))
        return output_next_state, output_reward, output_done

    # transformer training loop
    # sequences shape: (batch, sequence, features)
    def train_dreamer(self, sequences, epochs, batch_size=256):
        '''
        Train on `sequences` for `epochs` epochs, printing the mean loss per epoch.

        If `sequences` is None or empty nothing is trained (previously this
        crashed while building the loader or dividing by zero batches).
        '''
        print("Training Dreamer...")
        if sequences is None or len(sequences) == 0:
            print("No training sequences available; skipping Dreamer training.")
            return

        inputs = torch.tensor(sequences, dtype=torch.float).to(DEVICE)

        train_dataset = TensorDataset(inputs)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        self.transformer.train()

        for epoch in range(epochs):
            running_loss = 0.0
            for (input_batch,) in train_dataloader:
                input_batch = input_batch.to(DEVICE)
                self.optimizer.zero_grad()
                output_next_state, output_reward, output_done = self.forward(input_batch)
                loss = self.loss_fn(output_next_state, output_reward, output_done, input_batch[:, -1, self.target_dim:])
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, running_loss / len(train_dataloader)))

    # transformer testing loop
    def test_dreamer(self, sequences, batch_size=64):
        '''
        Evaluate on `sequences` (no gradients, transformer in eval mode) and
        return the mean loss per batch.

        Returns NaN when `sequences` is None or empty, since no loss can be
        measured (previously this raised ZeroDivisionError).
        '''
        print("Testing Dreamer...")
        if sequences is None or len(sequences) == 0:
            print("No test sequences available; test loss is undefined.")
            return float('nan')

        inputs = torch.tensor(sequences, dtype=torch.float).to(DEVICE)

        test_dataset = TensorDataset(inputs)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

        self.transformer.eval()

        with torch.no_grad():
            running_loss = 0.0
            for (input_batch,) in test_dataloader:
                input_batch = input_batch.to(DEVICE)
                output_next_state, output_reward, output_done = self.forward(input_batch)
                loss = self.loss_fn(output_next_state, output_reward, output_done, input_batch[:, -1, self.target_dim:])
                running_loss += loss.item()
            print('Test Loss: {:.4f}'.format(running_loss / len(test_dataloader)))
        return running_loss / len(test_dataloader)

    def step(self, action):
        '''
        Advance the dreamed environment by one step.

        Appends `action` to the rolling history, predicts the next state, reward
        and done flag from it, appends the predictions, and trims the history to
        `seq_len` states (`seq_len - 1` actions).

        Requires `self.states`, `self.actions`, `self.rewards` and `self.dones`
        to be 2-D tensors set by the caller, with `self.states` holding exactly
        one more row than each of the other three (needed for the concatenation
        below).

        Returns:
            (next_state, reward, done, None) as (numpy array, float, float, None).
        '''
        # Predictions must not be perturbed by dropout, which would still be
        # active if the last call on this model was train_dreamer.
        self.transformer.eval()

        act = torch.tensor(np.array([action]), dtype=torch.float32).to(DEVICE)
        self.actions = torch.cat([self.actions, act], axis=0)

        # history rows: (state, action, reward, done) for all but the latest state
        input_sequence = torch.cat([self.states[:-1], self.actions[:-1], self.rewards, self.dones], axis=1).to(torch.float32)
        # target: the latest state together with the action just taken
        target = torch.cat([self.states[-1], self.actions[-1]], axis=0).unsqueeze(0).to(torch.float32)

        with torch.no_grad():
            next_state, reward, done = self.predict(input_sequence, target)
            next_state = next_state.to(torch.float32)
            reward = reward.to(torch.float32)
            done = (done >= 0.6).to(torch.float32) # bias towards not done to avoid false positives

            self.states = torch.cat([self.states, next_state], axis=0)
            self.rewards = torch.cat([self.rewards, reward], axis=0)
            self.dones = torch.cat([self.dones, done], axis=0)

            # trim sequences
            if self.states.shape[0] > self.seq_len:
                self.states = self.states[1:]
                self.rewards = self.rewards[1:]
                self.dones = self.dones[1:]
            if self.actions.shape[0] > self.seq_len - 1: # action seq shorter
                self.actions = self.actions[1:]

        return next_state.squeeze(0).cpu().numpy(), reward.cpu().item(), done.cpu().item(), None


**actor-critic**

In [None]:
# LOSS
# https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

def quantile_huber_loss(quantiles, samples, sum_over_quantiles=False):
    '''
    Quantile Huber loss from TQC (see paper p3).
    Specific implementation: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/utils.py#L8

    The Huber loss is less sensitive to outliers than MSE: it is quadratic for
    absolute errors up to 1 and linear beyond that.

    Shapes:
        samples:        (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles)
        quantiles:      (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1)
        pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles)

    Returns a scalar (mean over everything), or, when `sum_over_quantiles` is
    True, a (batch_size, n_critics) tensor obtained by summing over quantiles
    and averaging over target quantiles.
    '''
    # every target sample against every predicted quantile
    pairwise_delta = samples.unsqueeze(1) - quantiles.unsqueeze(3)
    magnitude = torch.abs(pairwise_delta)

    # Huber loss L_k(delta) with threshold k = 1.0
    linear_part = magnitude - 0.5
    quadratic_part = pairwise_delta ** 2 * 0.5
    huber = torch.where(magnitude > 1., linear_part, quadratic_part)

    # quantile midpoints tau, shaped to broadcast against pairwise_delta
    num_quantiles = quantiles.shape[2]
    taus = (torch.arange(num_quantiles, device=quantiles.device, dtype=torch.float) + 0.5) / num_quantiles
    taus = taus.view(1, 1, -1, 1)

    # |tau - I(delta < 0)| * L_k(delta)
    is_negative = (pairwise_delta < 0).float()
    weighted = torch.abs(taus - is_negative) * huber

    if not sum_over_quantiles:
        return weighted.mean()
    # sum over quantiles, then average over target quantiles -> (batch_size, n_critics)
    return weighted.sum(dim=2).mean(dim=2)

In [13]:
# TODO - remove this
# Standard MLP for actor (without D2RL)
class StandardActorMLP(nn.Module):
    '''
    Plain feed-forward actor network (no D2RL skip connections).

    A stack of GELU-activated linear layers followed by two linear heads that
    output the mean and log-std terms.
    '''
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim] + hidden_dims
        self.layers = nn.ModuleList()
        # consecutive width pairs give each layer's (in, out) sizes
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out).to(DEVICE))

        head_in = hidden_dims[-1]
        self.last_layer_mean_linear = nn.Linear(head_in, output_dim).to(DEVICE)
        self.last_layer_log_std_linear = nn.Linear(head_in, output_dim).to(DEVICE)

    def forward(self, input_):
        '''Return (mean_linear, log_std_linear) for the given input batch.'''
        hidden = input_
        for dense in self.layers:
            hidden = F.gelu(dense(hidden))

        return self.last_layer_mean_linear(hidden), self.last_layer_log_std_linear(hidden)
        
# Standard MLP for critic (without D2RL)
class StandardCriticMLP(nn.Module):
    '''
    Plain feed-forward critic network (no D2RL skip connections).

    A stack of GELU-activated linear layers followed by a single linear
    output layer.
    '''
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim] + hidden_dims
        self.layers = nn.ModuleList()
        # consecutive width pairs give each layer's (in, out) sizes
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out).to(DEVICE))

        self.last_layer = nn.Linear(hidden_dims[-1], output_dim).to(DEVICE)

    def forward(self, input_):
        '''Return the output-layer activations for the given input batch.'''
        hidden = input_
        for dense in self.layers:
            hidden = F.gelu(dense(hidden))
        return self.last_layer(hidden)


# https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

# MLP for actor that implements D2RL architecture
class ActorMLP(nn.Module):
  '''
  Actor network using the D2RL architecture: the original input is
  concatenated onto the output of every hidden layer, so each subsequent
  layer (and both output heads) sees the raw input again.

  input size = state dim, output size = action dim
  '''
  def __init__(self, input_dim, hidden_dims, output_dim):
    super().__init__()
    self.layer_list = nn.ModuleList()
    self.input_dim = input_dim

    fan_in = input_dim # the first layer only sees the raw input
    for width in hidden_dims:
      self.layer_list.append(nn.Linear(fan_in, width).to(DEVICE))
      fan_in = width + self.input_dim # next layer: previous output + raw input

    # the heads take the last hidden output + raw input
    self.last_layer_mean_linear = nn.Linear(fan_in, output_dim).to(DEVICE)
    self.last_layer_log_std_linear = nn.Linear(fan_in, output_dim).to(DEVICE)

  def forward(self, input_):
    '''Return (mean_linear, log_std_linear) for a 2-D input batch.'''
    features = input_
    for dense in self.layer_list:
      # re-attach the raw input after every hidden layer (dense skip connection)
      features = torch.cat([F.gelu(dense(features)), input_], dim=1)

    return self.last_layer_mean_linear(features), self.last_layer_log_std_linear(features)

# MLP for critic that implements D2RL architecture 
class CriticMLP(nn.Module):
  '''
  Critic network using the D2RL architecture: the original input is
  concatenated onto the output of every hidden layer, so each subsequent
  layer (and the output layer) sees the raw input again.

  input size = state dim + action dim, output size = n_quantiles
  '''
  def __init__(self, input_dim, hidden_dims, output_dim):
    super().__init__()
    self.layer_list = nn.ModuleList()
    self.input_dim = input_dim

    fan_in = input_dim # the first layer only sees the raw input
    for width in hidden_dims:
      self.layer_list.append(nn.Linear(fan_in, width).to(DEVICE))
      fan_in = width + self.input_dim # next layer: previous output + raw input

    self.last_layer = nn.Linear(fan_in, output_dim).to(DEVICE)

  def forward(self, input_):
    '''Return the output-layer activations for a 2-D input batch.'''
    features = input_
    for dense in self.layer_list:
      # re-attach the raw input after every hidden layer (dense skip connection)
      features = torch.cat([F.gelu(dense(features)), input_], dim=1)

    return self.last_layer(features)


In [None]:
# https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

class GradientStep(object):
  '''
  Performs one optimisation step of the critic ensemble, the actor and the
  entropy temperature (alpha), and soft-updates the target critic.

  see (D2RL) https://github.com/pairlab/d2rl/blob/main/sac/sac.py
  and (TQC) https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/tqc/tqc.py
  '''
  def __init__(self,*,
    actor, critic, critic_target, discount, tau,
    actor_lr, critic_lr, alpha_lr,
    n_quantiles, n_mini_critics, top_quantiles_to_drop_per_net, target_entropy,
    use_per=True
    ):
    '''
    Store the networks and create one Adam optimiser each for the actor, the
    critic and log(alpha).

    Args:
      actor: policy; called as ``actor(state)`` and must return ``(action, log_pi)``.
      critic: online critic; called as ``critic(state, action)``.
      critic_target: target critic, soft-updated towards ``critic`` with rate ``tau``.
      discount: discount factor applied to the bootstrapped target.
      tau: soft-update coefficient for the target critic.
      actor_lr, critic_lr, alpha_lr: learning rates of the three optimisers.
      n_quantiles: quantiles per mini-critic.
      n_mini_critics: number of mini-critics in the ensemble.
      top_quantiles_to_drop_per_net: quantiles truncated per mini-critic; the
        total dropped is this value times ``n_mini_critics``.
      target_entropy: entropy target used by the temperature loss.
      use_per: if True, priorities are written back to the replay buffer.
    '''

    self.actor = actor
    self.critic = critic
    self.critic_target = critic_target
    self.use_per = use_per

    self.log_alpha = nn.Parameter(torch.zeros(1).to(DEVICE)) # log alpha is learned
    self.quantiles_total = n_quantiles * n_mini_critics

    self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
    self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
    self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    self.discount, self.tau = discount, tau
    self.top_quantiles_to_drop = top_quantiles_to_drop_per_net * n_mini_critics # total number of quantiles to drop
    self.target_entropy = target_entropy

  def take_gradient_step(self, replay_buffer, total_steps, batch_size=256):
    '''
    Sample a batch and update critic, target critic, actor and alpha (in that order).

    Args:
      replay_buffer: object exposing ``sample(batch_size, total_steps)`` that
        returns ``(state, action, next_state, reward, not_done, indices, weights)``
        and, when priorities are used, ``update_priorities(indices, priorities)``.
      total_steps: forwarded unchanged to ``replay_buffer.sample``.
      batch_size: number of transitions requested from the buffer.
    '''
    # Sample batch from replay buffer
    state, action, next_state, reward, not_done, indices, weights = replay_buffer.sample(batch_size, total_steps)

    # Importance weights. A buffer without priorities may hand back None (the
    # code below already tolerates indices=None); treat that as uniform weights.
    if weights is None:
      weights = torch.ones(1, 1, device=DEVICE) # broadcasts against the loss
    else:
      # as_tensor accepts arrays, lists and tensors without a copy-construct warning
      weights = torch.as_tensor(weights, dtype=torch.float32).to(DEVICE).unsqueeze(1) # add dim for broadcast
    alpha = torch.exp(self.log_alpha) # entropy temperature coefficient

    with torch.no_grad():
      # sample new action from actor on next state
      new_next_action, next_log_pi = self.actor(next_state)

      # Compute and cut quantiles at the next state
      next_z = self.critic_target(next_state, new_next_action)

      # Sort and drop top k quantiles to control overestimation (TQC).
      # Flatten using the batch dimension actually returned by the buffer, so a
      # batch whose size differs from `batch_size` is never silently mis-shaped.
      sorted_z, _ = torch.sort(next_z.reshape(next_z.shape[0], -1))
      n_keep = self.quantiles_total - self.top_quantiles_to_drop
      sorted_z_part = sorted_z[:, :n_keep] # estimated truncated Q-val dist for next state

      # td error + entropy term
      target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

    # Get current quantile estimates using action from the replay buffer
    cur_z = self.critic(state, action)
    # NOTE(review): quantile_huber_loss is defined elsewhere in the file; this
    # code assumes it returns one loss per (sample, critic) so that dim 1 is the
    # critic axis averaged below - confirm against its implementation.
    per_critic_loss = quantile_huber_loss(cur_z, target.unsqueeze(1), sum_over_quantiles=True)
    critic_loss = (per_critic_loss * weights).mean() # importance-weighted loss

    new_action, log_pi = self.actor(state)
    # detach the variable from the graph so we don't change it with other losses
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean() # as in D2RL implementation for auto entropy tuning

    # Optimise critic
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # update PER prios
    if self.use_per and indices is not None:
      avg_critic_loss = per_critic_loss.mean(1) # average over critics
      new_prios = avg_critic_loss.detach().cpu().numpy()
      replay_buffer.update_priorities(indices, new_prios)

    # Soft update target networks
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
      target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    # Compute actor (π) loss
    # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
    actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()
    # ^ mean(2) is over quantiles, mean(1) is over critic ensemble

    # Optimise the actor
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Optimise the entropy coefficient.
    # zero_grad() here also clears the gradient that actor_loss.backward()
    # left on log_alpha through `alpha`.
    self.alpha_optimizer.zero_grad()
    alpha_loss.backward()
    self.alpha_optimizer.step()


In [15]:
# https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

class Actor(nn.Module):
    '''
    Tanh-squashed Gaussian policy.

    In training mode ``forward`` draws a reparameterised sample and returns its
    log-probability; in evaluation mode it returns the deterministic action
    ``tanh(mean)`` and ``None`` for the log-probability.
    '''
    def __init__(self, state_dim, action_dim, hidden_dims=[512, 512], use_d2rl=True):
        '''
        Args:
            state_dim: observation width.
            action_dim: action width.
            hidden_dims: hidden-layer widths passed to the MLP (never mutated here).
            use_d2rl: use ``ActorMLP`` when True, otherwise ``StandardActorMLP``.
        '''
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        if use_d2rl:
            self.mlp = ActorMLP(state_dim, hidden_dims, action_dim)
        else:
            self.mlp = StandardActorMLP(state_dim, hidden_dims, action_dim) # Use standard MLP

    def forward(self, obs):
        '''
        Args:
            obs: batched observations, shape (batch, state_dim).

        Returns:
            ``(action, log_prob)``. ``action`` lies in [-1, 1]. ``log_prob`` has
            shape (batch, 1) in training mode and is ``None`` in evaluation mode.
        '''
        mean, log_std = self.mlp(obs)
        log_std = log_std.clamp(-20, 2) # clamp for stability
        std = torch.exp(log_std)

        base_N_dist = D.Normal(mean, std) # base normal dist
        tanh_transform = TanhTransform(cache_size=1) # cache lets log_prob reuse the pre-tanh sample

        log_prob = None
        if self.training: # i.e. agent.train()
            transformed_dist = D.TransformedDistribution(base_N_dist, tanh_transform) # transformed distribution
            action = transformed_dist.rsample() # samples from base dist & applies transform
            log_prob = transformed_dist.log_prob(action) # log prob of action after transform
            log_prob = log_prob.sum(dim=1, keepdim=True) # sum over action dim
        else: # evaluation mode
            action = torch.tanh(mean)

        return action, log_prob

    def select_action(self, obs):
        '''
        Return the action for a single observation as a numpy array.

        Args:
            obs: one observation (1-D) or a batch (2-D); array-like or tensor.
                Only the action for the first row of a batch is returned.

        Returns:
            numpy array holding the first action of the batch, detached from autograd.
        '''
        if torch.is_tensor(obs):
            # torch.tensor(tensor) emits a copy-construct warning; convert in place instead
            obs = obs.detach().to(dtype=torch.float32, device=DEVICE)
        else:
            obs = torch.tensor(obs, dtype=torch.float32).to(DEVICE)
        if obs.ndim == 1: # add batch dim if missing
            obs = obs.unsqueeze(0)
        # Acting never needs gradients: skip building the autograd graph
        with torch.no_grad():
            act, _ = self.forward(obs)
        return np.array(act[0].detach().cpu())


class Critic(nn.Module): # really a mega-critic from lots of mini-critics
    '''
    TQC critic ensemble: several independent quantile networks that are all
    evaluated on the same (state, action) input.
    '''
    def __init__(self, state_dim, action_dim, n_quantiles, n_nets, hidden_dims=[256, 256], use_d2rl=True):
        '''
        Build ``n_nets`` mini-critics, each mapping a concatenated
        (state, action) vector to ``n_quantiles`` outputs.
        '''
        super().__init__()
        self.n_quantiles = n_quantiles

        joint_dim = state_dim + action_dim
        members = []
        for _ in range(n_nets):
            if use_d2rl:
                members.append(CriticMLP(joint_dim, hidden_dims, n_quantiles))
            else:
                members.append(StandardCriticMLP(joint_dim, hidden_dims, n_quantiles)) # Use standard MLP
        self.critics = nn.ModuleList(members)

    def forward(self, state, action):
        '''
        Returns the quantiles of every mini-critic stacked along dim 1,
        i.e. one slice ``[:, i]`` per mini-critic.
        '''
        joint = torch.cat((state, action), dim=1) # shared input for every mini-critic
        return torch.stack([member(joint) for member in self.critics], dim=1)


# Parameters and Training

**Prepare the environment and wrap it to capture statistics, logs, and videos**

In [None]:
# My hyperparams
#
# Everything a reader might tune lives in this cell. Two entries accept the
# placeholder 'auto' and are resolved to concrete numbers at the bottom of the
# cell, before the dict is handed to wandb.

seed = 42 # DONT CHANGE FOR COURSEWORK

hyperparams = {
    # env/general params
    "max_timesteps": MAX_TIMESTEPS, # per episode [DONT CHANGE]
    "max_episodes": 100,
    "target_score": 300, # stop training when average score over r_list > target_score
    "len_r_list": 100, # length of reward list to average over for target score (stop training when avg > target_score)
    "hardcore": False, # fixed in wandb sweep
    "init_rand_steps": 10000, # number of steps to take with random actions before training (helps exploration)

    # Agent hyperparams (from https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb)
    "use_d2rl": True, # use D2RL architecture for actor/critic TODO remove after ablation
    "n_mini_critics": 5, # each mini-critic is a single mlp, which combine to make one mega-critic
    "n_quantiles": 20, # quantiles per mini critic
    "top_quantiles_to_drop_per_net": 'auto', # per mini critic ('auto' = derived from n_quantiles below)
    "actor_hidden_dims": [512, 512],
    "mini_critic_hidden_dims": [256, 256], # * n_mini_critics
    "batch_size": 256,
    "discount": 0.98, # gamma
    "tau": 0.005,
    "actor_lr": 3.29e-4, # empirically chosen
    "critic_lr": 3.8e-4, # empirically chosen
    "alpha_lr": 3.24e-4, # empirically chosen

    # ERE buffer (see paper for their choices)
    "use_per": False, # also use PER sampling (reported to break on Mac - do not enable there)
    "buffer_size": 100000, # smaller size improves learning early on but is outperformed later on
    "eta0": 0.996, # 0.994 - 0.999 is good (according to paper)
    "annealing_steps": 'auto', # number of steps to anneal eta over (after which sampling is uniform) - 'auto' or None = max estimated steps in training
    "cmin": 5000, # min number of samples to sample from
    "recency_scale": 1, # scale factor for recency

    # dreamer hyperparams (from https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb)
    # info on hyperparam choice: https://arxiv.org/abs/1912.01603
    "use_dreamer": False,
    "intrinsic_reward_scale": 0., # scale factor for dreamer intrinsic reward
    "batch_size_dreamer": 512,
    "hidden_dim": 256,
    "num_layers": 4,
    "num_heads": 8,
    "dreamer_lr": 3e-4,
    "dreamer_weight_decay": 1e-4,
    "dropout_prob": 0.1,
    "window_size": 40,               # transformer context window size
    "step_size": 1,                  # how many timesteps to skip between each context window
    "train_split": 0.8,              # train/validation split
    "loss_threshold": 10,           # use dreamer if loss < loss_threshold
    "imagination_horizon": 15,       # how many timesteps to run the dreamer model for (H in Dreamer paper)
    "dreamer_train_epochs": 10,      # how many epochs to train the dreamer model for
    "dreamer_train_frequency": 10,   # how often to train the dreamer model
    "episode_threshold": 20,         # how many episodes to run before training the dreamer model
    "max_size": 50000,               # maximum size of the training set for the dreamer model
}

# recording/logging
plot_interval = 100 # plot every Nth episode (wandb still plots every ep)
save_fig = True # save figures too

is_recording = True # TODO dont forget this
if hyperparams['hardcore']: # NOTE this doesnt change in wandb sweep
    video_interval = 30 # record every Nth episode
    ep_start_rec = 500 # start recording on this episode
else:
    video_interval = 20
    ep_start_rec = 50

# Resolve the 'auto' placeholders. None is accepted as a synonym for 'auto' on
# annealing_steps, so an unset value never reaches the replay buffer.
if hyperparams['annealing_steps'] in ('auto', None):
    hyperparams['annealing_steps'] = hyperparams['max_episodes']*hyperparams['max_timesteps'] # max est number of steps in training
if hyperparams['top_quantiles_to_drop_per_net'] == 'auto':
    # floor(n_quantiles / 12.5): keeps the drop ratio of M=25, d=2 from the TQC paper
    hyperparams['top_quantiles_to_drop_per_net'] = int(hyperparams['n_quantiles'] // 12.5)

In [17]:
## SETUP
# Start (or resume) the wandb run and expose its config object.

import datetime

# Checkpoint hooks: set these when resuming from a checkpoint
wandb_run_id = None
start_episode = 1 # Default start episode

# Run name = source id + launch time (+ hardcore marker)
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
suffix = '_hardcore' if hyperparams['hardcore'] else ''
run_name = f"{SOURCE_ID}_{timestamp}{suffix}"

wandb_init_kwargs = dict(
    project="RL-Coursework-Walker2d",
    config=hyperparams,
    name=run_name,
    save_code=True, # optionally saves code to wandb
)
if wandb_run_id:
    # Or resume="must" if you require resuming
    wandb_init_kwargs.update(id=wandb_run_id, resume="allow")

wandb.init(**wandb_init_kwargs)

# From here on hyperparameters are read through wandb.config
config = wandb.config


[34m[1mwandb[0m: Currently logged in as: [33mtheo-farrell99[0m ([33mtheo-farrell99-durham-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
# Create the Walker environment, wrap it in the rldurham Recorder and query
# its action/observation dimensions.

# make env
# only attempt hardcore when your agent has solved the non-hardcore version
env = rld.make("rldurham/Walker", render_mode="rgb_array", hardcore=config.hardcore)

# get statistics, logs, and videos
video_prefix = "nchw73-agent-hardcore-video" if config.hardcore else "nchw73-agent-video"
video_prefix = f"{timestamp}_" + video_prefix # mark video with date/time so it can be matched to the run (remove later!)
env = rld.Recorder(
    env,
    smoothing=10,                       # track rolling averages (useful for plotting)
    video=True,                         # enable recording videos
    video_folder="videos",              # folder for videos
    video_prefix=video_prefix,          # prefix for videos
    logs=True,                          # keep logs
)

rld.check_device() # training on CPU recommended

env.video = False # switch video recording off (only switch on every x episodes as this is slow)

# environment info (act_dim and state_dim are used to size the networks and the replay buffer)
discrete_act, discrete_obs, act_dim, state_dim = rld.env_info(env, print_out=True)

# render start image
# NOTE: this is the last expression of the cell, so its return value
# (initial observation, info dict) is shown as the cell output.
env.reset(seed=seed)
# rld.render(env)

The device is: cpu (as recommended)
actions are continuous with 4 dimensions/#actions
observations are continuous with 24 dimensions/#observations
maximum timesteps is: None


  logger.warn(


(array([-1.3781445e-03, -2.0011244e-05,  1.4475726e-03, -1.5999975e-02,
         9.8153561e-02, -1.2663029e-03,  8.5725147e-01,  2.0965904e-03,
         1.0000000e+00,  3.9205227e-02, -1.2662878e-03,  8.5045713e-01,
         1.0122511e-03,  1.0000000e+00,  4.3961507e-01,  4.4460756e-01,
         4.6016780e-01,  4.8821869e-01,  5.3265011e-01,  6.0082245e-01,
         7.0722014e-01,  8.8352221e-01,  1.0000000e+00,  1.0000000e+00],
       dtype=float32),
 {})

**Training**

In [None]:
# Build everything the training loop needs: seeded RNGs, replay buffer,
# actor / critic / target critic, optional dreamer, the gradient-step helper
# and the loop's bookkeeping variables.
# NOTE: the networks are created after seeding, so the order of the statements
# below affects the random initial weights - keep it stable for reproducibility.

# in the submission please use seed_everything with seed 42 for verification
seed, state, info = rld.seed_everything(seed, env)

# track statistics for plotting
tracker = rld.InfoTracker()

replay_buffer = EREReplayBuffer(
    state_dim, act_dim, config.annealing_steps, config.buffer_size, config.eta0, config.cmin,
    use_per=config.use_per, recency_scale=config.recency_scale,
    )

actor = Actor(
    state_dim, act_dim, config.actor_hidden_dims, use_d2rl=config.use_d2rl
    ).to(DEVICE)

critic = Critic(
    state_dim, act_dim, config.n_quantiles, config.n_mini_critics,
    config.mini_critic_hidden_dims, use_d2rl=config.use_d2rl
    ).to(DEVICE)
critic_target = copy.deepcopy(critic) # target starts as an exact copy of the online critic

dreamer = None # init as none
if config.use_dreamer:
    dreamer = DreamerAgent(
        state_dim, act_dim, config.hidden_dim, 
        config.window_size, config.num_layers, config.num_heads, config.dropout_prob,
        config.dreamer_lr, config.dreamer_weight_decay,
        ).to(DEVICE)

target_entropy = -np.prod(env.action_space.shape).item() # target entropy heuristic = −dim(A)
grad_step_class = GradientStep(
    actor=actor, critic=critic, critic_target=critic_target,
    discount=config.discount, tau=config.tau,
    actor_lr=config.actor_lr, critic_lr=config.critic_lr, alpha_lr=config.alpha_lr,
    n_quantiles=config.n_quantiles, n_mini_critics=config.n_mini_critics,
    top_quantiles_to_drop_per_net=config.top_quantiles_to_drop_per_net,
    target_entropy=target_entropy, use_per=config.use_per,
    )

# Training mode: Actor.forward samples stochastically and returns log-probs only when self.training is True
actor.train()

total_steps = 0 # real-environment steps taken so far
memory_ptr = 0 # marks boundary between new and old data in the replay buffer
train_set, test_set = None, None # dreamer datasets, filled by the training loop
recent_rewards= [] # rolling window of episode rewards for the stopping criterion
ep_timesteps_dreamer = ep_reward_dreamer = dreamer_eps = 0
last_dreamer_test_loss = 0.0 # track dreamer test loss
completed_env = False

episode = start_episode # =1 by default or whatever was in a checkpoint


Seed set to 42


In [None]:
# training based on https://github.com/ArijusLengvenis/bipedal-walker-dreamer/blob/main/dreamer_bipedal_walker_code.ipynb

# Each iteration of the loop below:
#   1. runs one episode in the real environment via train_on_environment
#      (which receives the GradientStep helper, so presumably also updates the agent),
#   2. if the dreamer is enabled, refreshes its train/test sets, trains/evaluates it
#      and runs imagined rollouts when its test loss is low enough,
#   3. plots / logs statistics,
#   4. stops early once the rolling mean reward reaches config.target_score.

# Start from None so a value left behind by an earlier run in the same kernel
# is never logged as this run's dreamer loss.
dreamer_avg_loss = None

# training loop
while episode <= config.max_episodes: # index from 1
    # recording statistics and video can be switched on and off (video recording is slow!)
    if is_recording and episode >= ep_start_rec:
        env.info = episode % video_interval == 0   # track every x episodes (usually tracking every episode is fine)
        env.video = episode % video_interval == 0  # record videos every x episodes (set BEFORE calling reset!)

    # reset for new episode
    state, info = env.reset()
    state = state.astype(np.float32) # float32 for mps

    # sample real env
    ep_timesteps, ep_reward, input_buffer, info = train_on_environment(
        actor, env, grad_step_class, replay_buffer, config.max_timesteps, state,
        config.batch_size, total_steps, config.window_size, config.init_rand_steps,
        is_dreamer=False,
        # use intrinsic reward if dreamer active
        use_intrinsic_reward=(config.use_dreamer and episode > config.episode_threshold and total_steps >= config.init_rand_steps),
        last_dreamer_test_loss=last_dreamer_test_loss,
        intrinsic_r_scale=config.intrinsic_reward_scale,
        )

    total_steps += ep_timesteps
    tracker.track(info) # track statistics for plotting

    # update train/test sets for dreamer (add new eps and trim to window size)
    train_set, test_set = gen_test_train_seq(
        replay_buffer, train_set, test_set, config.train_split, config.window_size, config.step_size, memory_ptr)

    # train dreamer after ep_thresh and if we have train/test sets
    if config.use_dreamer and episode >= config.episode_threshold and train_set is not None:
        # train and assess the dreamer every train_frequency
        if episode % config.dreamer_train_frequency == 0:
            dreamer.train_dreamer(train_set, config.dreamer_train_epochs, config.batch_size_dreamer)

        # truncate the training set to control train time performance
        if train_set.shape[0] > config.max_size:
            train_set = train_set[-config.max_size:]

        if test_set.shape[0] > config.max_size//4: # cap test set to 25% train set
            test_set = test_set[-(config.max_size//4):]

        print(f'Size of train set: {train_set.shape[0]}, test set: {test_set.shape[0]}')

        # evaluate the dreamer's performance & decide num of training steps for dreamer
        dreamer_avg_loss = dreamer.test_dreamer(test_set, config.batch_size_dreamer)
        dreamer_eps = get_dreamer_eps(dreamer_avg_loss, config.loss_threshold)
        last_dreamer_test_loss = dreamer_avg_loss

        # train on dreamer if its accurate enough
        H = config.imagination_horizon # H = imagination horizon
        if dreamer_eps > 0:
            # `state` is still the initial observation of this episode (from env.reset above)
            if state is None or not isinstance(state, np.ndarray): # TODO
                print("Warning: Invalid initial state for Dreamer imagination. Skipping.")
                dreamer_eps = 0
            else:
                dreamer_start = state.astype(np.float32)
                dreamer.reset(dreamer_start) # reset dreamer with the current state

                # scale H based on dreamer loss: 1 if loss = 0 and 0 if loss >= thresh
                H_scale_factor = max(0., 1. - (dreamer_avg_loss / config.loss_threshold))
                H = int(H * H_scale_factor ** 2) # quadratic scaling to encourage better models

                print(f'Dreamer active for {dreamer_eps} iterations, with {H} timesteps per ep.')
                ep_timesteps_dreamer = ep_reward_dreamer = 0
                for dep in range(dreamer_eps):
                    # Every imagined rollout restarts from the same real state.
                    # TODO alternatively could resample from replay buffer, or seed the
                    # dreamer's states/actions/rewards/dones from input_buffer
                    dreamer.reset(dreamer_start)

                    # sample from dreamer environment (ignore input buffer and info)
                    _td, _rd, _, _ = train_on_environment(
                        actor=actor, env=dreamer, grad_step_class=grad_step_class,
                        replay_buffer=replay_buffer, max_timesteps=H,
                        state=dreamer.current_state, # initial state for dreamer TODO
                        batch_size=config.batch_size, total_steps=total_steps,
                        sequence_length=config.window_size, init_rand_steps=None, # No random steps during imagination
                        is_dreamer=True, use_intrinsic_reward=False
                        )

                    ep_timesteps_dreamer += _td
                    ep_reward_dreamer += _rd

    memory_ptr = replay_buffer.ptr # update the memory pointer to the current position in the buffer

    print(f"Ep: {episode} | Timesteps: {ep_timesteps} | Reward: {ep_reward:.3f} | Total Steps: {total_steps:.2g} | Dreamer: {dreamer_eps > 0}")
    if config.use_dreamer and dreamer_eps > 0:
        print(f"\t Dreamer Eps: {dreamer_eps} | Dreamer Avg Reward: {ep_reward_dreamer/dreamer_eps:.3f} | Dreamer Avg Timesteps: {ep_timesteps_dreamer/dreamer_eps:.3g}")

    # Rolling mean reward and completion flag. Computed BEFORE plotting so the
    # final plot is also produced when training stops early on the target score.
    recent_rewards.append(ep_reward)
    current_avg = np.array(recent_rewards).mean()
    window_full = len(recent_rewards) >= config.len_r_list
    completed_env = bool(window_full and current_avg >= config.target_score)

    # plot tracked statistics (on real env)
    if episode % plot_interval == 0 or completed_env or episode == config.max_episodes: # always save last plot
        # save as well (show=False returns it but doesnt display)
        if save_fig and info: # skip when info is empty ({})
            fig, _ = tracker.plot(show=False, r_mean_=True, r_std_=True, r_sum=dict(linestyle=':', marker='x'))
            wandb.log({"Coursework Plot": wandb.Image(fig)}, step=episode) # Log the figure directly to W&B
            # fig.savefig(f'./tracker_{run_name}.png', bbox_inches = 'tight')
            plt.close(fig) # Close the figure to free memory
        # show by default
        tracker.plot(r_mean_=True, r_std_=True, r_sum=dict(linestyle=':', marker='x'))

    # Log metrics to W&B
    log_dict = {
        "Total Steps": total_steps,
        "Episode Timesteps": ep_timesteps,
    }

    # log tracked statistics
    if 'recorder' in info: # if we have info available
        log_dict["TrackedInfo/r_mean_"] = info['recorder']['r_mean_'] # r is extrinsic only
        log_dict["TrackedInfo/r_std_"] = info['recorder']['r_std_']
        log_dict["TrackedInfo/r_sum"] = info['recorder']['r_sum']

    if config.use_dreamer:
        log_dict["Dreamer Average Loss"] = dreamer_avg_loss # None until the dreamer has been evaluated
        log_dict["Dreamer Episodes Run"] = dreamer_eps
        if dreamer_eps > 0:
            log_dict["Dreamer Average Reward"] = ep_reward_dreamer / dreamer_eps
            log_dict["Dreamer Average Timesteps"] = ep_timesteps_dreamer / dreamer_eps

    log_dict["R_mean_100"] = current_avg

    wandb.log(log_dict, step=episode) # log metrics each ep

    # env completion (break) condition - stop if we consistently meet target score
    if window_full:
        print(f'Current progress: {current_avg:.3f} / {config.target_score}')
        if completed_env: # quit when we've got good enough performance
            print(f"Completed environment in {episode} episodes!")
            break
        recent_rewards = recent_rewards[-config.len_r_list+1:] # discard oldest on list (keep most recent 99)

    episode += 1

# don't forget to close environment (e.g. triggers last video save)
env.close()

# write log file
if config.hardcore:
    filename = f"[{run_name}]_" + "nchw73-agent-hardcore-log.txt"
else:
    filename = f"[{run_name}]_" + "nchw73-agent-log.txt"
env.write_log(folder="logs", file=filename)

wandb.finish() # finish wandb run


Ep: 1 | Timesteps: 69 | Reward: -116.031 | Total Steps: 69 | Dreamer: False
Ep: 2 | Timesteps: 130 | Reward: -100.340 | Total Steps: 2e+02 | Dreamer: False
Ep: 3 | Timesteps: 2000 | Reward: -61.315 | Total Steps: 2.2e+03 | Dreamer: False
Ep: 4 | Timesteps: 95 | Reward: -97.463 | Total Steps: 2.3e+03 | Dreamer: False
Ep: 5 | Timesteps: 2000 | Reward: -74.790 | Total Steps: 4.3e+03 | Dreamer: False
Ep: 6 | Timesteps: 2000 | Reward: -70.298 | Total Steps: 6.3e+03 | Dreamer: False
Ep: 7 | Timesteps: 2000 | Reward: -66.159 | Total Steps: 8.3e+03 | Dreamer: False
Ep: 8 | Timesteps: 110 | Reward: -98.922 | Total Steps: 8.4e+03 | Dreamer: False
torch.Size([256, 256, 95])
weights shape: torch.Size([256, 1])
per critic loss shape: torch.Size([256, 256, 95])


KeyboardInterrupt: 