# Imports and Disclosure

In [None]:
"""
This code is based off of the Samsung Labs implementation: https://github.com/SamsungLabs/tqc_pytorch
More information about TQC can be found here: https://arxiv.org/abs/2005.04269

An enhancement called D2RL was also implemented following this paper: https://arxiv.org/abs/2010.09163
The code associated with this paper: https://github.com/pairlab/d2rl

An ERE buffer is implemented to improve sample efficiency, for which the paper can be found here: https://arxiv.org/abs/1906.04009

Finally, key ideas and concepts for the auto-regressive transformer design stem from this paper: https://arxiv.org/abs/2202.09481
"""

# Device string passed to every layer and tensor constructor in this file.
# NOTE(review): hard-coded to CUDA, so this presumably fails on a machine
# without a GPU — confirm before running elsewhere.
DEVICE = 'cuda'

import numpy as np
import torch
import gym
import copy
import math

from torch.nn import Module, Linear, Transformer, MSELoss, CrossEntropyLoss
from torch.distributions import Distribution, Normal
from torch.nn.functional import gelu, logsigmoid
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
from torch import sigmoid
import matplotlib.pyplot as plt
from IPython import display as disp

# Functions

In [None]:

# Step cap for one episode on the real gym environment. train_on_environement
# also compares its max_timesteps argument against this value to decide
# whether it is running on the real environment or on the dreamer.
MAX_TIMESTEPS = 2000

# Abstracted training loop shared by the real gym environment and the dreamer.
# The gym environment always uses MAX_TIMESTEPS as its max_timesteps value,
# while the dreamer environment may run for fewer timesteps.
def train_on_environement(actor, env, class_to_take_gradient_step, replay_buffer, max_timesteps, state, batch_size, total_num_steps, sequence_length):
    """Run one episode on `env`, storing transitions and training the agent.

    The episode is treated as a real-environment episode when `max_timesteps`
    equals MAX_TIMESTEPS; any other value selects the dreamer branch.

    Args:
        actor: object exposing `select_action(state)`.
        env: object exposing `step(action)` that returns a 4-tuple
            `(next_state, reward, done, info)`.
        class_to_take_gradient_step: object exposing
            `take_gradient_step(replay_buffer, t, batch_size)`.
        replay_buffer: object exposing
            `add(state, action, next_state, reward, done)`.
        max_timesteps: maximum number of steps taken in this episode.
        state: initial observation.
        batch_size: forwarded to the gradient step; training only starts once
            the local step counter has reached this value.
        total_num_steps: running step count; it is incremented on a local copy
            only, so the caller's variable is not changed.
        sequence_length: number of leading transitions recorded in the
            returned start sequence.

    Returns:
        `(episode_timesteps, ep_reward, input_buffer)`. `ep_reward` sums the
        rewards exactly as returned by `env.step` (before the rescaling applied
        for storage). `input_buffer` holds one row per recorded transition laid
        out as state, action, next_state, reward, done; it is replaced by None
        when `sequence_length` is not positive.
    """
    episode_timesteps = 0
    ep_reward = 0

    # save start sequence for the dreamer model
    # 54 columns — matches the row slices used below (24 state, 4 action,
    # 24 next state, reward, done).
    input_buffer = torch.empty((0, 54), device=DEVICE)
    for t in range(max_timesteps): 
        total_num_steps += 1
        action = actor.select_action(state)
        next_state, reward, done, _ = env.step(action)
        episode_timesteps += 1
        ep_reward += reward

        # change reward of real environment
        # do not store memories if simulation ends without
        # reaching the sequence length
        if max_timesteps == MAX_TIMESTEPS:
            # Real environment: a reward of exactly -100 is softened to -10 and
            # every other reward is doubled before being stored.
            if reward == -100.0:
                reward = - 10.0
            else:
                reward *= 2
            replay_buffer.add(state, action, next_state, reward, done)
        else:
            # Dreamer: transitions are held back until the rollout has lasted
            # `sequence_length` steps, then the held-back rows are flushed.
            # NOTE(review): the transition taken at t == sequence_length itself
            # is not added to the replay buffer — confirm this is intended.
            if t == sequence_length:
                for row in input_buffer.cpu().numpy():
                    replay_buffer.add(row[:24], row[24:28], row[28:52], row[52], row[53])
            elif t > sequence_length:
                replay_buffer.add(state, action, next_state, reward, done)
        # Record the first `sequence_length` transitions (with the stored,
        # i.e. possibly rescaled, reward) as the start sequence.
        if t < sequence_length:
            input_buffer = torch.cat([input_buffer, torch.tensor(np.concatenate((state, action, next_state, np.array([reward]), np.array([done])), axis=0), device=DEVICE).unsqueeze(0)], axis=0)        
    
        state = next_state
        
        if total_num_steps >= batch_size:
            # train the agent using experiences from the real environment
            class_to_take_gradient_step.take_gradient_step(replay_buffer, t, batch_size)
    
        if done:
            break
    if sequence_length > 0:
        return episode_timesteps, ep_reward, input_buffer
    return episode_timesteps, ep_reward, None

# test loop for agent on environment
def simulate_on_environement(actor, env, max_timesteps, state):
    """Roll the actor out on `env` without storing or training anything.

    Args:
        actor: object exposing `select_action(state)`.
        env: object exposing `step(action)` returning
            `(next_state, reward, done, info)`.
        max_timesteps: upper bound on the number of steps taken.
        state: initial observation.

    Returns:
        `(steps_taken, total_reward)` for the rollout.
    """
    steps_taken = 0
    total_reward = 0
    observation = state
    for _ in range(max_timesteps):
        chosen_action = actor.select_action(observation)
        observation, step_reward, finished, _info = env.step(chosen_action)
        steps_taken += 1
        total_reward += step_reward
        # The range() bound already stops the loop after the final step, so
        # only the environment's own termination needs an explicit exit.
        if finished:
            break
    return steps_taken, total_reward

# generate value array from replay buffer
def gen_values_from_replay_buffer(replay_buffer, ptr):
    """Stack the buffer rows in `[ptr, replay_buffer.ptr)` into one 2-D array.

    Columns are ordered state, action, reward, done, next_state, where done is
    recovered as `1 - not_done`. The slice is empty whenever
    `replay_buffer.ptr` is not greater than `ptr`.
    """
    fresh = slice(ptr, replay_buffer.ptr)
    done_flags = 1. - replay_buffer.not_done[fresh]
    columns = (
        replay_buffer.state[fresh],
        replay_buffer.action[fresh],
        replay_buffer.reward[fresh],
        done_flags,
        replay_buffer.next_state[fresh],
    )
    return np.concatenate(columns, axis=1)

# function to create sequences with a given window size and step size
def create_sequences(values, window_size, step_size):
    """Cut `values` into overlapping windows along its first axis.

    Args:
        values: 2-D array of shape (n_memories, n_features).
        window_size: number of consecutive rows in each window.
        step_size: row offset between the starts of consecutive windows.

    Returns:
        A float array of shape (n_sequences, window_size, n_features) where
        window `i` is `values[i * step_size : i * step_size + window_size]`.
        Only complete windows are produced; when `values` has fewer than
        `window_size` rows the result has zero windows.
    """
    n_memories, n_features = values.shape[0], values.shape[1]
    # Number of complete windows. The earlier ceil/floor expression
    # over-counted for some window/step combinations (e.g. 7 rows, window 5,
    # step 3), which then failed when a short trailing slice was assigned.
    if n_memories < window_size:
        n_sequences = 0
    else:
        n_sequences = (n_memories - window_size) // step_size + 1
    sequences = np.zeros((n_sequences, window_size, n_features))
    for i in range(n_sequences):
        start = i * step_size
        sequences[i] = values[start:start + window_size]
    return sequences

# function to split train and test data
def generate_train_and_test_sequences(replay_buffer, train_set, test_set, train_split, window_size, step_size, ptr):
    """Window the newest buffer rows and append them to the train/test sets.

    Args:
        replay_buffer: buffer read by `gen_values_from_replay_buffer`.
        train_set: existing training windows, or None on the first call.
        test_set: existing test windows, or None on the first call.
        train_split: fraction of the new windows assigned to the training set.
        window_size: rows per window.
        step_size: row offset between consecutive windows.
        ptr: buffer index of the first row not yet consumed.

    Returns:
        `(train_set, test_set)` with the new windows appended after a random
        shuffle. The inputs are returned unchanged when the new rows cannot
        form a single complete window.
    """
    values = gen_values_from_replay_buffer(replay_buffer, ptr)
    try:
        memory_sequences = create_sequences(values, window_size, step_size)
    except ValueError:
        # Too few rows to build a window (negative/mismatched shapes). Only
        # this case is tolerated; any other error is a real bug and surfaces.
        return train_set, test_set
    if memory_sequences.shape[0] == 0:
        return train_set, test_set

    indices = np.arange(memory_sequences.shape[0])
    np.random.shuffle(indices)
    split = int(train_split * memory_sequences.shape[0])
    new_train = memory_sequences[indices[:split], :]
    new_test = memory_sequences[indices[split:], :]
    if train_set is None:
        return new_train, new_test
    return (
        np.concatenate((train_set, new_train), axis=0),
        np.concatenate((test_set, new_test), axis=0),
    )


# Calculate the number of training steps for the agent on the dreamer.
def calc_dreamer_iterations(dreamer_performance_score, score_threshold, avg_reward):
    """Return how many dreamer episodes to run for the given score.

    A score at or above `score_threshold` yields 0. Below it the count is
    `int(10 * (1 - score / threshold) ** 2)`, so it grows as the score moves
    further under the threshold. `avg_reward` is accepted but not used.
    """
    if not dreamer_performance_score >= score_threshold:
        shortfall = 1 - dreamer_performance_score / score_threshold
        return int(10 * shortfall ** 2)
    return 0


# dreamer agent: a transformer model that predicts next state, reward and done so it can stand in for the environment
class DreamerAgent(Module):
    """Transformer world model that predicts (next_state, reward, done).

    Training rows are laid out as [state, action, reward, done, next_state]:
    the first `state_dim + action_dim + 2` columns of the context rows feed
    the encoder, the first `state_dim + action_dim` columns of the final row
    form the decoder query, and the remaining columns of the final row are
    the prediction targets.

    `step()` gives the model the same call shape as a gym environment. It
    requires `self.states`, `self.actions`, `self.rewards` and `self.dones`
    to be assigned by the caller before the first call.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, seq_len, num_layers, num_heads, dropout_prob, lr=0.001):
        """Build the encoder/decoder projections, transformer and output heads.

        Args:
            state_dim: size of an observation vector.
            action_dim: size of an action vector.
            hidden_dim: transformer model width.
            seq_len: context length kept by `step()`.
            num_layers: number of encoder layers and of decoder layers.
            num_heads: number of attention heads.
            dropout_prob: transformer dropout probability.
            lr: accepted for interface compatibility; the optimizer below
                keeps its fixed learning rate of 3e-4.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.input_dim = state_dim + action_dim + 2
        self.target_dim = self.state_dim + self.action_dim
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

        self.mse_loss = MSELoss()
        # The done head already ends in a sigmoid, so binary cross-entropy on
        # probabilities is the matching loss. CrossEntropyLoss over a single
        # output column is identically zero (softmax of one logit is 1), which
        # left the done head untrained.
        self.bce_loss = torch.nn.BCELoss()

        self.input_fc = Linear(self.input_dim, hidden_dim, device=DEVICE)
        self.target_fc = Linear(self.target_dim, hidden_dim, device=DEVICE)
        # Keyword arguments are required here: positionally, torch's
        # Transformer takes (d_model, nhead, num_encoder_layers, ...), so
        # passing (hidden_dim, num_layers, num_heads) swapped heads and layers.
        self.transformer = Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout_prob,
            activation=gelu,
            batch_first=True,
            device=DEVICE,
        )
        self.output_next_state = Linear(hidden_dim + self.target_dim, state_dim, device=DEVICE)
        self.output_reward = Linear(hidden_dim + self.target_dim, 1, device=DEVICE)
        self.output_done = Linear(hidden_dim + self.target_dim, 1, device=DEVICE)
        self.optimizer = Adam(self.parameters(), lr=3e-4)

    def loss_fn(self, output_next_state, output_reward, output_done, ground_truth):
        """Sum the next-state MSE, reward MSE and done BCE.

        `ground_truth` holds [reward, done, next_state] in its last dimension;
        only the last position of each prediction is scored.
        """
        reward, done, next_state = torch.split(ground_truth, [1, 1, self.state_dim], dim=-1)
        loss = self.mse_loss(output_next_state[:, -1], next_state)
        loss += self.mse_loss(output_reward[:, -1], reward)
        loss += self.bce_loss(output_done[:, -1], done)
        return loss

    def forward(self, input_tensor):
        """Predict from a batch of full windows (batch, sequence, features).

        The last row of each window supplies the decoder query (its state and
        action); all earlier rows supply the encoder context.

        Returns:
            `(next_state, reward, done_probability)`, each with a length-1
            sequence dimension.
        """
        # separate the input and target tensors
        target = input_tensor[:, -1, :self.target_dim].unsqueeze(1)
        encoded_target = self.target_fc(target)
        encoded_input = self.input_fc(input_tensor[:, :-1, :self.input_dim])

        # pass these into the transformer
        encoded_output = self.transformer(encoded_input, encoded_target)

        # decode: the raw query is concatenated back onto the transformer
        # output so every head also sees the un-encoded state/action
        decoder_features = torch.cat([encoded_output, target], axis=2)
        output_next_state = self.output_next_state(decoder_features)
        output_reward = self.output_reward(decoder_features)
        output_done = sigmoid(self.output_done(decoder_features))
        return output_next_state, output_reward, output_done

    def predict(self, input_tensor, target_tensor):
        """Predict from an explicit context and query.

        As used by `step()`, both arguments are unbatched 2-D tensors:
        `input_tensor` is (sequence, input_dim) and `target_tensor` is
        (1, target_dim).
        """
        encoded_target = self.target_fc(target_tensor)
        encoded_input = self.input_fc(input_tensor)

        # pass these into the transformer
        encoded_output = self.transformer(encoded_input, encoded_target)

        # decode the densely connected output
        decoder_features = torch.cat([encoded_output, target_tensor], axis=1)
        output_next_state = self.output_next_state(decoder_features)
        output_reward = self.output_reward(decoder_features)
        output_done = sigmoid(self.output_done(decoder_features))
        return output_next_state, output_reward, output_done

    # transformer training loop
    # sequences shape: (batch, sequence, features)
    def train_dreamer(self, sequences, epochs, batch_size=256):
        """Fit the model on `sequences` for `epochs` passes, printing the mean loss per epoch."""
        inputs = torch.tensor(sequences, dtype=torch.float, device=DEVICE)

        train_dataset = TensorDataset(inputs)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        self.transformer.train()

        for epoch in range(epochs):
            running_loss = 0.0
            for (input_batch,) in train_dataloader:
                self.optimizer.zero_grad()
                output_next_state, output_reward, output_done = self.forward(input_batch)
                loss = self.loss_fn(output_next_state, output_reward, output_done, input_batch[:, -1, self.target_dim:])
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, running_loss / len(train_dataloader)))

    # transformer testing loop
    def test_dreamer(self, sequences, batch_size=64):
        """Return (and print) the mean batch loss on `sequences` without updating weights."""
        inputs = torch.tensor(sequences, dtype=torch.float, device=DEVICE)

        test_dataset = TensorDataset(inputs)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

        self.transformer.eval()

        with torch.no_grad():
            running_loss = 0.0
            for (input_batch,) in test_dataloader:
                output_next_state, output_reward, output_done = self.forward(input_batch)
                loss = self.loss_fn(output_next_state, output_reward, output_done, input_batch[:, -1, self.target_dim:])
                running_loss += loss.item()
            print('Test Loss: {:.4f}'.format(running_loss / len(test_dataloader)))
        return running_loss / len(test_dataloader)

    def step(self, action):
        """Advance the simulated environment by one action.

        Returns:
            `(next_state, reward, done, None)`, the same 4-tuple shape that the
            training loop unpacks from a gym environment.
        """
        self.actions = torch.cat([self.actions, torch.tensor(np.array([action]), device=DEVICE)], axis=0)
        # Context: every past transition; query: the current state and the
        # action that was just chosen.
        input_sequence = torch.cat([self.states[:-1], self.actions[:-1], self.rewards, self.dones], axis=1).to(torch.float32)
        target = torch.cat([self.states[-1], self.actions[-1]], axis=0).unsqueeze(0).to(torch.float32)
        with torch.no_grad():
            next_state, reward, done = self.predict(input_sequence, target)
            done = (done >= 0.6).float() # bias towards not done to avoid false positives

            self.states = torch.cat([self.states, next_state], axis=0)
            self.rewards = torch.cat([self.rewards, reward], axis=0)
            self.dones = torch.cat([self.dones, done], axis=0)

            # Slide the context window so it never exceeds seq_len states.
            if self.states.shape[0] > self.seq_len:
                self.states = self.states[1:]
                self.rewards = self.rewards[1:]
                self.dones = self.dones[1:]
            if self.actions.shape[0] > self.seq_len - 1:
                self.actions = self.actions[1:]

        return next_state.squeeze(0).cpu().numpy(), reward.cpu().item(), done.cpu().item(), None


def quantile_huber_loss(quantiles, samples, sum_over_quantiles = False):
    """Quantile-weighted Huber loss between predicted quantiles and samples.

    Args:
        quantiles: predicted quantiles; indexed as three dimensions with the
            quantile axis at position 2.
        samples: target samples; indexed as two dimensions.
        sum_over_quantiles: when True, sum over the quantile axis before
            averaging; otherwise average over every element.

    Returns:
        A scalar tensor.
    """
    # Every target sample is compared against every predicted quantile.
    pairwise_diff = samples[:, None, None, :] - quantiles.unsqueeze(3)
    magnitude = pairwise_diff.abs()

    # Huber with threshold 1: quadratic near zero, linear in the tails.
    quadratic = pairwise_diff ** 2 * 0.5
    linear = magnitude - 0.5
    huber = torch.where(magnitude > 1, linear, quadratic)

    # Mid-point quantile levels (i + 0.5) / n for the quantile axis.
    num_q = quantiles.shape[2]
    taus = (torch.arange(num_q, device=quantiles.device, dtype=torch.float) + 0.5) / num_q
    undershoot = (pairwise_diff < 0).float()
    weighted = torch.abs(taus.view(1, 1, -1, 1) - undershoot) * huber

    if sum_over_quantiles:
        return weighted.sum(dim=-2).mean()
    return weighted.mean()


#MLP for critic that implements D2RL architecture 
class Mlp_for_Critic(Module):
    """Critic MLP with D2RL-style skip connections.

    The raw input is concatenated onto the activation of every hidden layer,
    so each layer after the first (and the output layer) receives
    `previous_width + input_size` features.

    Args:
        input_size: width of the network input.
        hidden_sizes: widths of the hidden layers, in order.
        output_size: width of the output layer.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.list_of_layers = []
        # Width fed to the next layer. Derived from input_size and the actual
        # previous layer width instead of a hard-coded 28 and hidden_sizes[0],
        # so other input sizes and uneven hidden widths work too.
        in_features = input_size
        for i, next_size in enumerate(hidden_sizes):
            lay = Linear(in_features, next_size, device=DEVICE)
            # add_module registers the layer under the same 'layer{i}' names
            # as before, keeping state_dict keys stable.
            self.add_module(f'layer{i}', lay)
            self.list_of_layers.append(lay)
            in_features = next_size + input_size
        self.last_layer = Linear(in_features, output_size, device=DEVICE)

    def forward(self, input_):
        """Return the output-layer activations for `input_` (batch, input_size)."""
        curr = input_
        for lay in self.list_of_layers:
            hidden = gelu(lay(curr))
            # Skip connection: re-attach the raw input after every layer.
            curr = torch.cat([hidden, input_], dim=1)
        return self.last_layer(curr)


#MLP for actor that implements D2RL architecture
class Mlp_for_Actor(Module):
    """Actor MLP with D2RL-style skip connections and two output heads.

    The raw input is concatenated onto the activation of every hidden layer.
    Two linear heads on top of the final concatenation produce the mean and
    the log standard deviation.

    Args:
        input_size: width of the network input.
        hidden_sizes: widths of the hidden layers, in order.
        output_size: width of each output head.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.list_of_layers = []
        # Width fed to the next layer. Derived from input_size and the actual
        # previous layer width instead of a hard-coded 24 and hidden_sizes[0],
        # so other input sizes and uneven hidden widths work too.
        in_features = input_size
        for i, next_size in enumerate(hidden_sizes):
            lay = Linear(in_features, next_size, device=DEVICE)
            # add_module registers the layer under the same 'layer{i}' names
            # as before, keeping state_dict keys stable.
            self.add_module(f'layer{i}', lay)
            self.list_of_layers.append(lay)
            in_features = next_size + input_size

        self.last_layer_mean_linear = Linear(in_features, output_size, device=DEVICE)
        self.last_layer_log_std_linear = Linear(in_features, output_size, device=DEVICE)

    def forward(self, input_):
        """Return `(mean, log_std)` for `input_` (batch, input_size)."""
        curr = input_
        for layer in self.list_of_layers:
            hidden = gelu(layer(curr))
            # Skip connection: re-attach the raw input after every layer.
            curr = torch.cat([hidden, input_], dim=1)

        mean_linear = self.last_layer_mean_linear(curr)
        log_std_linear = self.last_layer_log_std_linear(curr)
        return mean_linear, log_std_linear


class EREReplayBuffer(object):
    """Ring replay buffer that samples with Emphasizing Recent Experience.

    For the k-th element of a mini-batch, the index is drawn uniformly from
    the most recent `c_k` transitions, where `c_k` shrinks geometrically with
    k and is never allowed to fall below `cmin` (or the buffer size, while
    the buffer holds fewer than `cmin` transitions).

    Args:
        state_dim: width of a state vector.
        action_dim: width of an action vector.
        T: horizon used by `eta_anneal` to move eta from its start value to 1.
        max_size: capacity of the ring buffer.
        eta: starting value of the decay base.
        cmin: smallest recent-window size used for sampling.
    """

    def __init__(self, state_dim, action_dim, T, max_size=int(1e6), eta=0.996, cmin=5000):
        self.max_size, self.ptr, self.size, self.rollover = max_size, 0, 0, False
        self.eta0 = eta
        self.cmin = cmin
        self.c_list = []
        self.index = []
        self.T = T

        self.reward = np.empty((max_size, 1))
        self.state = np.empty((max_size, state_dim))
        self.action = np.empty((max_size, action_dim))
        self.not_done = np.empty((max_size, 1))
        self.next_state = np.empty((max_size, state_dim))

    def sample(self, batch_size, t):
        """Draw `batch_size` transitions; returns tensors (s, a, ns, r, not_done) on DEVICE."""
        # eta value for current timestep
        eta = self.eta_anneal(t)

        index = np.array([self._get_index(eta, k, batch_size) for k in range(batch_size)])

        r = torch.tensor(self.reward[index], dtype=torch.float, device=DEVICE)
        s = torch.tensor(self.state[index], dtype=torch.float, device=DEVICE)
        ns = torch.tensor(self.next_state[index], dtype=torch.float, device=DEVICE)
        a = torch.tensor(self.action[index], dtype=torch.float, device=DEVICE)
        nd = torch.tensor(self.not_done[index], dtype=torch.float, device=DEVICE)

        return s, a, ns, r, nd

    def _get_index(self, eta, k, batch_size):
        """Draw one buffer index for the k-th element of a mini-batch."""
        c_calc = self.size * eta ** (k * 1000 / batch_size)
        # Clamp the recent window to [min(cmin, size), size]. Previously a
        # window below cmin fell back to the whole buffer, which made later
        # batch elements uniform over everything instead of the newest cmin
        # transitions.
        lower_bound = min(self.cmin, self.size)
        ck = min(max(c_calc, lower_bound), self.size)

        # `newest` is one past the most recent write position, expressed so
        # that the window [newest - ck, newest) never has a negative start.
        newest = self.ptr + self.size if self.rollover else self.size
        # randint needs integer bounds; truncate the fractional window start.
        drawn = np.random.randint(int(newest - ck), newest)
        return drawn % self.size if self.rollover else drawn

    def eta_anneal(self, t):
        """Linearly interpolate eta from `eta0` at t=0 to 1 at t=T."""
        return self.eta0 + (1 - self.eta0) * t / self.T

    def add(self, state, action, next_state, reward, done):
        """Write one transition at the current pointer and advance it."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr += 1
        self.ptr %= self.max_size

        if self.max_size > self.size + 1:
            self.size = self.size + 1
        else:
            # Buffer is full from here on; the pointer wraps and overwrites
            # the oldest entries.
            self.size = self.max_size
            self.rollover = True


#Actor
class Actor(Module):
    """Tanh-squashed Gaussian policy built on `Mlp_for_Actor`.

    In training mode `forward` samples an action and returns its
    log-probability; in eval mode it returns the deterministic action
    `tanh(mean)` with a log-probability of None.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.mlp = Mlp_for_Actor(state_dim, [512, 512], action_dim)

    def forward(self, obs):
        """Return `(action, log_prob)` for a batch of observations."""
        mean, log_std = self.mlp(obs)
        # `self.training` is always a bool, so two branches cover every case;
        # the former third branch could only leave `action` unassigned.
        if not self.training:
            return torch.tanh(mean), None

        # Clamp before exponentiating to keep the std in a numerically safe range.
        std = torch.exp(log_std.clamp(-20, 2))
        tanh_dist = TanhNormal(mean, std)
        action, pre_tanh = tanh_dist.random_sample()
        log_prob = tanh_dist.log_probability(pre_tanh).sum(dim=1, keepdim=True)
        return action, log_prob

    def select_action(self, obs):
        """Return the action for a single observation as a numpy array."""
        obs = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        # Acting never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            action, _ = self.forward(obs)
        return np.array(action[0].cpu())


#Critic
class Critic(Module):
    """Ensemble of `n_nets` quantile critics sharing one (state, action) input.

    Each member is an `Mlp_for_Critic` with two hidden layers of 256 units
    that outputs `n_quantiles` values.
    """

    def __init__(self, state_dim, action_dim, n_quantiles, n_nets):
        super().__init__()
        self.n_quantiles = n_quantiles
        self.list_of_mlp = []

        joint_dim = state_dim + action_dim
        for idx in range(n_nets):
            member = Mlp_for_Critic(joint_dim, [256, 256], n_quantiles)
            self.list_of_mlp.append(member)
            # Register under 'net{idx}' so the member's parameters are tracked.
            self.add_module(f'net{idx}', member)

    def forward(self, state, action):
        """Return the members' outputs stacked along dimension 1."""
        joint = torch.cat((state, action), dim=1)
        per_member = [member(joint) for member in self.list_of_mlp]
        return torch.stack(per_member, dim=1)



class TanhNormal(Distribution):
    """Distribution of `tanh(u)` where `u ~ Normal(normal_mean, normal_std)`.

    Args:
        normal_mean: mean of the underlying normal.
        normal_std: standard deviation of the underlying normal.
    """

    def __init__(self, normal_mean, normal_std):
        super().__init__()
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        # The noise source follows the device and dtype of the parameters
        # rather than a global device, so it always matches them.
        self.stand_normal = Normal(torch.zeros_like(self.normal_mean), torch.ones_like(self.normal_std))

    @staticmethod
    def logsigmoid(tensor):
        """Return log(sigmoid(tensor)).

        Kept for backward compatibility. It delegates to torch's
        implementation because `log(1 / (1 + exp(-x)))` overflows to -inf for
        large negative inputs.
        """
        return torch.nn.functional.logsigmoid(tensor)

    def log_probability(self, pre_tanh):
        """Return the element-wise log-density of `tanh(pre_tanh)`.

        Uses the change of variables `log N(u) - log(1 - tanh(u)^2)` with
        `log(1 - tanh(u)^2) = 2 log 2 + logsigmoid(2u) + logsigmoid(-2u)`,
        which avoids evaluating `1 - tanh(u)^2` directly.
        """
        jacobian_term = 2 * np.log(2) + logsigmoid(2 * pre_tanh) + logsigmoid(-2 * pre_tanh)
        return self.normal.log_prob(pre_tanh) - jacobian_term

    def random_sample(self):
        """Return `(tanh(u), u)` for a reparameterised draw `u = mean + std * eps`."""
        pretanh = self.normal_mean + self.normal_std * self.stand_normal.sample()
        return torch.tanh(pretanh), pretanh


class Gradient_Step(object):
  """Holds the optimizers and performs one critic/actor/temperature update.

  All constructor arguments are keyword-only. The entropy temperature is
  stored as `log_alpha`, a learnable one-element tensor initialised to zero.
  The actor, critic and temperature each get their own Adam optimizer with a
  learning rate of 3e-4.
  """
  def __init__(
    self,
    *,
    actor,
    critic,
    critic_target,
    discount,
    tau,
    top_quantiles_to_drop,
    target_entropy,
    quantiles_total
  ):
    self.actor = actor
    self.critic = critic
    self.critic_target = critic_target
    self.log_alpha = torch.zeros((1,), requires_grad=True, device=DEVICE)
    self.quantiles_total = quantiles_total
    self.actor_optimizer = Adam(self.actor.parameters(), lr=3e-4)
    
    self.alpha_optimizer = Adam([self.log_alpha], lr=3e-4)
    self.critic_optimizer = Adam(self.critic.parameters(), lr=3e-4)
    self.discount, self.tau, self.top_quantiles_to_drop, self.target_entropy  = discount, tau, top_quantiles_to_drop,target_entropy


  def take_gradient_step(self, replay_buffer, t, batch_size=256):
    """Sample a batch and update the critic, target critic, actor and alpha.

    Args:
      replay_buffer: object whose `sample(batch_size, t)` returns
        (state, action, next_state, reward, not_done).
      t: forwarded unchanged to `replay_buffer.sample`.
      batch_size: number of transitions to sample; also used to reshape the
        target critic's output.
    """
    # Sample replay buffer
    state, action, next_state, reward, not_done = replay_buffer.sample(batch_size, t)
    alpha = torch.exp(self.log_alpha) #entropy temperature coefficient

    with torch.no_grad():
      # Action by the current actor for the sampled state
      new_next_action, next_log_pi = self.actor(next_state)

      # Compute and cut quantiles at the next state
      next_z = self.critic_target(next_state, new_next_action)  
      
      # Sort and drop top k quantiles to control overestimation.
      # The reshape pools the quantiles of all critic nets into one row per
      # sample before sorting.
      sorted_z, _ = torch.sort(next_z.reshape(batch_size, -1))
      sorted_z_part = sorted_z[:, :self.quantiles_total-self.top_quantiles_to_drop]

      # td error + entropy term
      target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)
    
    # Get current Quantile estimates using action from the replay buffer
    cur_z = self.critic(state, action)
    critic_loss = quantile_huber_loss(cur_z, target)


    # The actor is evaluated before the critic is stepped; new_action and
    # log_pi are reused for the actor loss further down.
    new_action, log_pi = self.actor(state)
    # detach the variable from the graph so we don't change it with other losses
    alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()

    # Optimise critic
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # Update target networks
    # Soft update: target <- tau * online + (1 - tau) * target.
    for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
      target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    
    # Compute actor loss
    # Uses the just-updated critic, averaged over quantiles and then over nets.
    actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()
    
    # Optimise the actor
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Optimise the entropy coefficient
    # zero_grad first: actor_loss.backward() above also wrote a gradient into
    # log_alpha through `alpha`.
    self.alpha_optimizer.zero_grad()
    alpha_loss.backward()
    self.alpha_optimizer.step()


# Hyperparameters

In [None]:
# Reporting cadence.
plot_interval = 10 # update the plot every N episodes
video_every = 25 # videos can take a very long time to render so only do it every N episodes

# agent hyperparameters
seed = 42
n_quantiles = 25                    # quantiles predicted by each critic net
top_quantiles_to_drop_per_net = 2   # multiplied by n_nets to get the total dropped from the pooled target
n_nets = 5                          # number of critic nets in the ensemble
batch_size = 256
discount = 0.98
tau = 0.005                         # soft-update rate for the target critic

# dreamer hyperparameters
batch_size_dreamer = 512
hidden_dim = 256
num_layers = 16
num_heads = 4
dropout_prob = 0.1
window_size = 40               # transformer context window size
step_size = 1                  # how many timesteps to skip between each context window
train_split = 0.80             # train/validation split
score_threshold = 0.8          # quality threshold for using the dreamer model
dreamer_train_epochs = 15      # how many epochs to train the dreamer model for
dreamer_train_frequency = 10   # how often to train the dreamer model
episode_threshold = 50         # how many episodes to run before training the dreamer model
max_size = int(5e4)            # maximum size of the training set for the dreamer model

record = True
save_model = True

# Build and seed the environment and every random number generator in use.
env = gym.make('BipedalWalker-v3')
env.seed(seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
if record:
    # Record a video every `video_every` episodes, starting from episode 50.
    env = gym.wrappers.Monitor(env, "./video", video_callable=lambda ep_id: ep_id % video_every == 0 and ep_id >= 50, force=True)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# max_episodes = 100
max_timesteps = MAX_TIMESTEPS

# Model initialisation and Main loop

In [None]:
# Initialise everything
replay_buffer = EREReplayBuffer(state_dim, action_dim, MAX_TIMESTEPS)
actor = Actor(state_dim, action_dim).to(DEVICE)

critic = Critic(state_dim, action_dim, n_quantiles, n_nets).to(DEVICE)
critic_target = copy.deepcopy(critic)

dreamer = DreamerAgent(state_dim, action_dim, hidden_dim, window_size, num_layers, num_heads, dropout_prob).to(DEVICE)

top_quantiles_to_drop = top_quantiles_to_drop_per_net * n_nets

class_to_take_gradient_step = Gradient_Step(actor=actor,critic=critic,critic_target=critic_target,top_quantiles_to_drop=top_quantiles_to_drop,discount=discount,tau=tau,target_entropy=-np.prod(env.action_space.shape).item(), quantiles_total = n_quantiles * n_nets)
actor.train()
state = env.reset()

episode_timesteps = 0
episode = 1
total_num_steps = 0
ep_reward = 0
memory_ptr = 0
train_set, test_set = None, None
reward_list = []
reward_avg_list = []
plot_data = []
episode_timesteps_dreamer = ep_reward_dreamer = dreamer_iterations = 0
current_size = 0

# The context manager guarantees the log is flushed and closed even if
# training is interrupted by an exception.
with open("agent-log.txt", "w+") as log_f:
    while True:
        # Dreamer statistics describe the current episode only; reset them so
        # an episode without dreamer rollouts does not report stale numbers.
        episode_timesteps_dreamer = ep_reward_dreamer = dreamer_iterations = 0

        # sample from true environment
        episode_timesteps, ep_reward, input_buffer = train_on_environement(actor, env, class_to_take_gradient_step, replay_buffer, max_timesteps, state, batch_size, total_num_steps, window_size)
        total_num_steps += episode_timesteps
        state = env.reset()

        # extend the train/test sets with windows cut from this episode
        train_set, test_set = generate_train_and_test_sequences(replay_buffer, train_set, test_set, train_split, window_size, step_size, memory_ptr)

        # The dreamer needs a full start sequence, so episodes shorter than
        # window_size are skipped.
        if episode >= episode_threshold and input_buffer.shape[0] == window_size:

            # train and assess the dreamer every train_frequency
            if episode % dreamer_train_frequency == 0:
                dreamer.train_dreamer(train_set, dreamer_train_epochs, batch_size_dreamer)

            # truncate the training set to control train time performance
            # (the size check must look at the set that is being truncated)
            if train_set.shape[0] > max_size:
                train_set = train_set[-max_size:]

            # Evaluate the dreamer's performance
            dreamer_effectiveness_score = dreamer.test_dreamer(test_set, batch_size_dreamer)
            dreamer_iterations = calc_dreamer_iterations(dreamer_effectiveness_score, score_threshold, np.array(reward_avg_list).mean())

            print('Size of sequences: ', train_set.shape[0], test_set.shape[0])

            if dreamer_iterations > 0:
                print(f'Dreamer active for {dreamer_iterations} iterations')
                # Cap dreamer rollouts just below the average real episode length.
                dreamer_max_timesteps = math.ceil(total_num_steps / episode) - 1
                # train the agent on the dreamer if the dreamer is good enough to accurately depict the environment
                for dep in range(dreamer_iterations):
                    print('Dreamer episode: ', dep+1)

                    # initialise dreamer states with the input sequence;
                    # input_buffer rows are laid out as
                    # state, action, next_state, reward, done, so reward and
                    # done are the last two columns.
                    dreamer.states = input_buffer[:, :state_dim]
                    dreamer.actions = input_buffer[:-1, state_dim:state_dim+action_dim]
                    dreamer.rewards = input_buffer[:-1, -2].unsqueeze(1)
                    dreamer.dones = input_buffer[:-1, -1].unsqueeze(1)

                    # sample from dreamer environment: one rollout per
                    # iteration, so the averages printed below divide by the
                    # number of rollouts that actually ran
                    _td, _rd, _ = train_on_environement(actor, dreamer, class_to_take_gradient_step, replay_buffer, dreamer_max_timesteps, dreamer.states[-1].cpu().numpy(), batch_size, total_num_steps, window_size)
                    episode_timesteps_dreamer += _td
                    ep_reward_dreamer += _rd
        # Rows written from here on (including dreamer rollouts above) start
        # after this pointer at the next window-cutting call.
        memory_ptr = replay_buffer.ptr

        # save results and reset variables
        # NOTE dreamer rewards are based on values stored in the replay buffer,
        # which are modified to have fall penalty -100 -> -10 and all other rewards
        # scaled by a factor of 2
        print(f"Episode Num: {episode} Episode T: {episode_timesteps} Reward: {ep_reward:.3f} Dreamer Eps: {dreamer_iterations} Dreamer Avg Timesteps: {episode_timesteps_dreamer/dreamer_iterations if dreamer_iterations > 0 else 0} Dreamer Avg Reward: {ep_reward_dreamer/dreamer_iterations if dreamer_iterations > 0 else 0}")
        log_f.write('episode: {}, reward: {}\n'.format(episode, ep_reward))
        log_f.flush()
        reward_list.append(ep_reward)
        reward_avg_list.append(ep_reward)
        ep_reward = 0
        episode += 1

        # print reward data every so often
        if episode % plot_interval == 0:
            plot_data.append([episode, np.array(reward_list).mean(), np.array(reward_list).std()])
            reward_list = []
            plt.plot([x[0] for x in plot_data], [x[1] for x in plot_data], '-', color='tab:grey')
            plt.fill_between([x[0] for x in plot_data], [x[1]-x[2] for x in plot_data], [x[1]+x[2] for x in plot_data], alpha=0.2, color='tab:grey')
            plt.xlabel('Episode number')
            plt.ylabel('Episode reward')
            plt.show()
            disp.clear_output(wait=True)

        # break condition: evaluated over a sliding window of 100 episodes
        if len(reward_avg_list) == 100:
            print(f'Current progress: {np.array(reward_avg_list).mean()}/300')
            if np.array(reward_avg_list).mean() >= 300:
                print('Completed environment!')
                break
            if episode == 1000:
                print('Exceeded training time')
                break
            reward_avg_list = reward_avg_list[-99:]

if save_model:
    torch.save(actor.state_dict(), './model.pth')
env.close()

with open("plot.txt", "w") as file:
    for episode, mean, std in plot_data:
        file.write(str(episode) + ',' + str(mean) + ',' + str(std) + "\n")