In [1]:
import math
import os
import random
import threading
import time
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Tuple, Optional

import gymnasium as gym
import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import cv2
# from tensorboardX import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter

In [2]:
game_name = 'MsPacman-v5'
obs_shape = (1, 84, 84)

lr = 1e-4
eps = 1e-3
grad_norm = 40
batch_size = 64
learning_starts = 50000
save_interval = 500
target_net_update_interval = 2000
gamma = 0.997
prio_exponent = 0.9
importance_sampling_exponent = 0.6

training_steps = 100000
buffer_capacity = 1000000
max_episode_steps = 27000
actor_update_interval = 400
block_length = 400 

num_actors = 8
base_eps = 0.4
alpha = 7
log_interval = 10

# sequence setting
burn_in_steps = 0
learning_steps = 80
forward_steps = 5
seq_len = burn_in_steps + learning_steps + forward_steps

# network setting
hidden_dim = 512

render = False
save_plot = True
test_epsilon = 0.001

In [3]:
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """start the game with no-op actions to provide random starting positions
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = np.random.randint(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, terminated, truncated, _ = self.env.step(self.noop_action)
            if terminated or truncated:
                obs = self.env.reset(**kwargs)
        return obs, {}

    def step(self, action):
        return self.env.step(action)



class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, width=84, height=84):
        """
        Warp frames to 84x84 as done in the Nature paper and later work.
        """
        super().__init__(env)
        self._width = width
        self._height = height

        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(1, self._height, self._width),
            dtype=np.uint8,
        )

    def observation(self, obs):

        obs = cv2.resize(
            obs, (self._width, self._height), interpolation=cv2.INTER_AREA
        )

        obs = np.expand_dims(obs, 0)

        return obs


def create_env(env_name=game_name, noop_start=True, render=False):

    env = gym.make(f'ALE/{env_name}', obs_type='grayscale', frameskip=4, repeat_action_probability=0, full_action_space=False, render_mode='human' if render else 'rgb_array')

    env = WarpFrame(env)
    if noop_start:
        env = NoopResetEnv(env)

    return env

In [4]:
class PriorityTree:
    def __init__(self, capacity, prio_exponent, is_exponent):
        self.num_layers = 1
        while capacity > 2**(self.num_layers - 1):
            self.num_layers += 1
        
        self.ptree = np.zeros(2**self.num_layers - 1, dtype=np.float64)
        
        self.prio_exponent = prio_exponent
        self.is_exponent = is_exponent
    
    def update(self, idxes: np.ndarray, td_error: np.ndarray) -> None:
        priorities = td_error ** self.prio_exponent
        
        idxes = idxes + 2**(self.num_layers - 1) - 1
        self.ptree[idxes] = priorities
        
        for _ in range(self.num_layers - 1):
            idxes = (idxes - 1) // 2
            idxes = np.unique(idxes)
            self.ptree[idxes] = self.ptree[2 * idxes + 1] + self.ptree[2 * idxes + 2]
    
    def sample(self, num_samples: int) -> tuple[np.ndarray, np.ndarray]:
        p_sum = self.ptree[0]
        interval = p_sum / num_samples
        
        prefixsums = np.arange(0, p_sum, interval, dtype=np.float64) + np.random.uniform(0, interval, num_samples)
        
        idxes = np.zeros(num_samples, dtype=np.int64)
        for _ in range(self.num_layers-1):
            nodes = self.ptree[2 * idxes + 1]
            idxes = np.where(prefixsums < nodes, 2 * idxes + 1, 2 * idxes + 2)
            prefixsums = np.where(idxes%2 == 0, prefixsums - self.ptree[idxes-1], prefixsums)
        
        priorities = self.ptree[idxes]
        min_p = np.min(priorities)
        is_weigths = np.power(priorities/min_p, -self.is_exponent)
        
        idxes -= 2**(self.num_layers - 1) - 1
        
        return idxes, is_weigths

In [5]:
@dataclass
class Block:
    obs: np.array
    last_action: np.array
    last_reward: np.array
    action: np.array
    n_step_reward: np.array
    gamma: np.array
    num_sequences: int
    burn_in_steps: np.array
    learning_steps: np.array
    forward_steps: np.array

In [6]:
class ReplayBuffer:
    def __init__(self, sample_queue_list, batch_queue, priority_queue, alpha=prio_exponent, beta=importance_sampling_exponent):
        self.buffer_capacity = buffer_capacity
        self.sequence_len = learning_steps
        self.num_sequences = buffer_capacity // self.sequence_len
        self.block_len = block_length
        self.num_blocks = self.buffer_capacity // self.block_len
        self.seq_pre_block = self.block_len // self.sequence_len
        
        self.block_ptr = 0
        
        self.priority_tree = PriorityTree(self.num_sequences, alpha, beta)
        
        self.batch_size = batch_size
        
        self.env_steps = 0
        
        self.num_episodes = 0
        self.episode_reward = 0
        
        self.training_steps = 0
        self.last_training_steps = 0
        self.sum_loss = 0
        
        self.lock = threading.Lock()
        
        self.size = 0
        self.last_size = 0
        
        self.buffer = [None] * self.num_blocks
        
        self.sample_queue_list = sample_queue_list
        self.batch_queue = batch_queue
        self.priority_queue = priority_queue
        
        
    def __len__(self):
        return self.size
    
    def run(self):
        background_thread = threading.Thread(target=self.add_data, daemon=True)
        background_thread.start()
        
        background_thread = threading.Thread(target=self.prepare_data, daemon=True)
        background_thread.start()
        
        background_thread = threading.Thread(target=self.update_data, daemon=True)
        background_thread.start()

        writer_thread = threading.Thread(target=self.write_stats, daemon=True)
        writer_thread.start()
        
        while True:
            print(f'Buffer size: {self.size}')
            print(f'Buffer update speed: {(self.size - self.last_size) / log_interval}/s')
            self.last_size = self.size
            print(f'Number of environment steps: {self.env_steps}')
            if self.num_episodes != 0:
                print(f'Average episode return: {self.episode_reward / self.num_episodes:.4f}')
                self.episode_reward = 0
                self.num_episodes = 0
            print(f'Number of training steps: {self.training_steps}')
            print(f'Training speed: {(self.training_steps - self.last_training_steps) / log_interval}/s')
            if self.training_steps != self.last_training_steps:
                print(f'Loss: {self.sum_loss/(self.training_steps-self.last_training_steps):.4f}')
                self.last_training_steps = self.training_steps
                self.sum_loss = 0
            self.last_env_steps = self.env_steps
            print()
            
            if self.training_steps == training_steps:
                break
            else:
                time.sleep(log_interval)
    
    def write_stats(self):
        os.makedirs(f'logs/t2d2/{game_name}', exist_ok=True)
        while True:
            with open(f'logs/t2d2/{game_name}/buffer_size.txt', 'a') as f:
                f.write(f'{self.size} {self.env_steps}\n')
            
            with open(f'logs/t2d2/{game_name}/buffer_update_speed.txt', 'a') as f:
                f.write(f'{(self.size - self.last_size) / log_interval} {self.env_steps}\n')
            
            with open(f'logs/t2d2/{game_name}/environment_steps.txt', 'a') as f:
                f.write(f'{self.env_steps}\n')

            if self.num_episodes != 0:
                with open(f'logs/t2d2/{game_name}/episode_return.txt', 'a') as f:
                    f.write(f'{self.episode_reward / self.num_episodes} {self.env_steps}\n')

            with open(f'logs/t2d2/{game_name}/training_steps.txt', 'a') as f:
                f.write(f'{self.training_steps} {self.env_steps}\n')

            with open(f'logs/t2d2/{game_name}/training_speed.txt', 'a') as f:
                f.write(f'{(self.training_steps - self.last_training_steps) / log_interval} {self.env_steps}\n')

            if self.training_steps != self.last_training_steps:
                with open(f'logs/t2d2/{game_name}/training_loss.txt', 'a') as f:
                    f.write(f'{self.sum_loss/(self.training_steps-self.last_training_steps)} {self.env_steps}\n')

            time.sleep(log_interval)

    def prepare_data(self):
        while self.size < learning_starts:
            time.sleep(1)
        
        while True:
            if not self.batch_queue.full():
                data = self.sample_batch()
                self.batch_queue.put(data)
            else:
                time.sleep(0.1)
    
    def add_data(self):
        while True:
            for sample_queue in self.sample_queue_list:
                if not sample_queue.empty():
                    data = sample_queue.get_nowait()
                    self.add(*data)
    
    def update_data(self):
        while True:
            if not self.priority_queue.empty():
                data = self.priority_queue.get_nowait()
                self.update_priorities(*data)
            else:
                time.sleep(0.1)
    
    def add(self, block: Block, priority: np.array, episode_reward: float):
        with self.lock:
            idxes = np.arange(self.block_ptr * self.seq_pre_block, (self.block_ptr + 1) * self.seq_pre_block, dtype=np.int64)
            
            self.priority_tree.update(idxes, priority)
            
            if self.buffer[self.block_ptr] is not None:
                self.size -= np.sum(self.buffer[self.block_ptr].learning_steps).item()

            self.size += np.sum(block.learning_steps).item()
            
            self.buffer[self.block_ptr] = block
            
            self.env_steps += np.sum(block.learning_steps, dtype=np.int32)
            
            self.block_ptr = (self.block_ptr + 1) % self.num_blocks
            
            if episode_reward:
                self.episode_reward += episode_reward
                self.num_episodes += 1
    
    def sample_batch(self):
        batch_obs, batch_last_action, batch_last_reward, batch_action, batch_reward, batch_gamma = [], [], [], [], [], []
        burn_in_steps, learning_steps, forward_steps = [], [], []

        with self.lock:

            idxes, is_weights = self.priority_tree.sample(self.batch_size)

            block_idxes = idxes // self.seq_pre_block
            sequence_idxes = idxes % self.seq_pre_block


            for block_idx, sequence_idx  in zip(block_idxes, sequence_idxes):

                block = self.buffer[block_idx]

                assert sequence_idx < block.num_sequences

                burn_in_step = block.burn_in_steps[sequence_idx]
                learning_step = block.learning_steps[sequence_idx]
                forward_step = block.forward_steps[sequence_idx]
                
                start_idx = block.burn_in_steps[0] + np.sum(block.learning_steps[:sequence_idx])

                obs = block.obs[start_idx-burn_in_step:start_idx+learning_step+forward_step]
                last_action = block.last_action[start_idx-burn_in_step:start_idx+learning_step+forward_step]
                last_reward = block.last_reward[start_idx-burn_in_step:start_idx+learning_step+forward_step]
                obs, last_action, last_reward = torch.from_numpy(obs), torch.from_numpy(last_action), torch.from_numpy(last_reward)
                
                start_idx = np.sum(block.learning_steps[:sequence_idx])
                end_idx = start_idx + block.learning_steps[sequence_idx]
                action = block.action[start_idx:end_idx]
                reward = block.n_step_reward[start_idx:end_idx]
                gamma = block.gamma[start_idx:end_idx]
                
                batch_obs.append(obs)
                batch_last_action.append(last_action)
                batch_last_reward.append(last_reward)
                batch_action.append(action)
                batch_reward.append(reward)
                batch_gamma.append(gamma)

                burn_in_steps.append(burn_in_step)
                learning_steps.append(learning_step)
                forward_steps.append(forward_step)

            batch_obs = pad_sequence(batch_obs, batch_first=True)
            batch_last_action = pad_sequence(batch_last_action, batch_first=True)
            batch_last_reward = pad_sequence(batch_last_reward, batch_first=True)

            is_weights = np.repeat(is_weights, learning_steps)


            data = (
                batch_obs,
                batch_last_action,
                batch_last_reward,

                torch.from_numpy(np.concatenate(batch_action)).unsqueeze(1),
                torch.from_numpy(np.concatenate(batch_reward)),
                torch.from_numpy(np.concatenate(batch_gamma)),

                torch.ByteTensor(burn_in_steps),
                torch.ByteTensor(learning_steps),
                torch.ByteTensor(forward_steps),

                idxes,
                torch.from_numpy(is_weights.astype(np.float32)),
                self.block_ptr,

                self.env_steps
            )

        return data

    def update_priorities(self, idxes: np.ndarray, td_errors: np.ndarray, old_ptr: int, loss: float):
        """Update priorities of sampled transitions"""
        with self.lock:

            # discard the idxes that already been replaced by new data in replay buffer during training
            if self.block_ptr > old_ptr:
                # range from [old_ptr, self.seq_ptr)
                mask = (idxes < old_ptr*self.seq_pre_block) | (idxes >= self.block_ptr*self.seq_pre_block)
                idxes = idxes[mask]
                td_errors = td_errors[mask]
            elif self.block_ptr < old_ptr:
                # range from [0, self.seq_ptr) & [old_ptr, self,capacity)
                mask = (idxes < old_ptr*self.seq_pre_block) & (idxes >= self.block_ptr*self.seq_pre_block)
                idxes = idxes[mask]
                td_errors = td_errors[mask]

            self.priority_tree.update(idxes, td_errors)

        self.training_steps += 1
        self.sum_loss += loss
            

In [7]:
@dataclass
class AgentState:
    obs: torch.tensor
    action_dim: int
    last_action: torch.tensor = field(init=False)
    last_reward: torch.tensor = torch.zeros((1, 1), dtype=torch.float32)

    def __post_init__(self):
        self.last_action = torch.zeros((1, self.action_dim), dtype=torch.float32)

    def update(self, obs, last_action, last_reward):
        new_obs = torch.from_numpy(obs).unsqueeze(0)
        self.obs = torch.cat([self.obs, new_obs], dim=0)
        new_action = torch.tensor([[1 if i == last_action else 0 for i in range(self.action_dim)]],
                                        dtype=torch.float32)
        self.last_action = torch.cat([self.last_action, new_action], dim=0)
        new_reward = torch.tensor([[last_reward]], dtype=torch.float32)
        self.last_reward = torch.cat([self.last_reward, new_reward], dim=0)

In [8]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)

        if d_model % 2 == 0:
            pe[:, 0, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 0, 1::2] = torch.cos(position * div_term)[:, :-1]
        self.register_buffer('pe', pe)

    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        Arguments:
            x: torch.tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [9]:
class Network(nn.Module):
    def __init__(self, observation_dim, action_dim, hidden_dim=hidden_dim):
        super().__init__()

        self.action_dim = action_dim
        self.obs_shape = observation_dim
        self.max_forward_steps = 5

        self.feature = nn.Sequential(
            nn.Conv2d(1, 32, 8, 4),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3136, hidden_dim - action_dim - 1),
            nn.ReLU(True),
        )

        self.dim_model = hidden_dim

        self.positional_encoding = PositionalEncoding(d_model=self.dim_model)

        transformer_layer = nn.TransformerEncoderLayer(d_model=self.dim_model, nhead=4, dim_feedforward=hidden_dim)
        self.recurrent = nn.TransformerEncoder(transformer_layer, num_layers=1, enable_nested_tensor=False)

        self.advantage = nn.Sequential(
            nn.Linear(self.dim_model, self.dim_model),
            nn.ReLU(True),
            nn.Linear(self.dim_model, action_dim)
        )

        self.value = nn.Sequential(
            nn.Linear(self.dim_model, self.dim_model),
            nn.ReLU(True),
            nn.Linear(self.dim_model, 1)
        )

    @staticmethod
    def create_square_subsequent_mask(size):
        mask = (torch.triu(torch.ones(size, size)) == 0).transpose(0, 1)
        return mask.bool()

    @staticmethod
    def create_padding_mask(seq_lengths, max_len=None):
        if max_len is None:
            max_len = seq_lengths.max()
        batch_size = seq_lengths.size(0)
        mask = torch.arange(max_len).expand(batch_size, max_len) >= seq_lengths.unsqueeze(1)
        return mask.bool()

    def forward(self, state: AgentState):
        latent = self.feature(state.obs / 255)
        recurrent_input = torch.cat([latent, state.last_action, state.last_reward], dim=1) * math.sqrt(self.dim_model)

        recurrent_input = recurrent_input.unsqueeze(1)
        
        recurrent_input = self.positional_encoding(recurrent_input)

        mask = self.create_square_subsequent_mask(state.obs.size(0))
        mask = mask.to(recurrent_input.device)

        recurrent_output = self.recurrent(recurrent_input, mask=mask)
        
        hidden = recurrent_output[-1]
        
        adv = self.advantage(hidden)
        val = self.value(hidden)
        q_value = val + adv - adv.mean(1, keepdim=True)
        
        return q_value

    def calculate_q_(self, obs, last_action, last_reward, burn_in_steps, learning_steps, forward_steps):
        # obs shape: (batch_size, seq_len, obs_shape)
        batch_size, max_seq_length, *_ = obs.size()
        seq_len = burn_in_steps + learning_steps + forward_steps

        obs = obs.reshape(-1, *self.obs_shape)
        latent = self.feature(obs)
        latent = latent.reshape(batch_size, max_seq_length, -1)

        recurrent_input = torch.cat((latent, last_action, last_reward), dim=2) * math.sqrt(self.dim_model)

        recurrent_input = recurrent_input.transpose(0, 1) # (seq_len, batch_size, obs_shape)
        recurrent_input = self.positional_encoding(recurrent_input)

        mask = self.create_square_subsequent_mask(max_seq_length)
        mask = mask.to(recurrent_input.device)

        padding_mask = self.create_padding_mask(seq_len, max_seq_length)
        padding_mask = padding_mask.to(recurrent_input.device)

        recurrent_output = self.recurrent(recurrent_input, mask=mask, src_key_padding_mask=padding_mask)
        recurrent_output = recurrent_output.transpose(0, 1) # (batch_size, seq_len, dim_model)

        seq_start_idx = burn_in_steps + self.max_forward_steps
        forward_pad_steps = torch.minimum(self.max_forward_steps - forward_steps, learning_steps)

        hidden = []
        for hidden_seq, start_idx, end_idx, padding_length in zip(recurrent_output, seq_start_idx, seq_len, forward_pad_steps):
            hidden.append(hidden_seq[start_idx:end_idx])
            if padding_length > 0:
                hidden.append(hidden_seq[end_idx-1:end_idx].repeat(padding_length, 1))

        hidden = torch.cat(hidden)

        assert hidden.size(0) == torch.sum(learning_steps), f'{hidden.size(0)} != {torch.sum(learning_steps)}'

        adv = self.advantage(hidden)
        val = self.value(hidden)
        q_value = val + adv - adv.mean(1, keepdim=True)

        return q_value

    def calculate_q(self, obs, last_action, last_reward, burn_in_steps, learning_steps):
        # obs shape: (batch_size, seq_len, obs_shape)
        batch_size, max_seq_length, *_ = obs.size()
        
        obs = obs.reshape(-1, *self.obs_shape)
        latent = self.feature(obs)
        latent = latent.reshape(batch_size, max_seq_length, -1)

        recurrent_input = torch.cat((latent, last_action, last_reward), dim=2) * math.sqrt(self.dim_model)

        seq_len = burn_in_steps + learning_steps

        recurrent_input = recurrent_input.transpose(0, 1) # (seq_len, batch_size, obs_shape)
        recurrent_input = self.positional_encoding(recurrent_input)

        mask = self.create_square_subsequent_mask(max_seq_length)
        mask = mask.to(recurrent_input.device)

        padding_mask = self.create_padding_mask(seq_len, max_seq_length)
        padding_mask = padding_mask.to(recurrent_input.device)

        recurrent_output = self.recurrent(recurrent_input, mask=mask, src_key_padding_mask=padding_mask)
        recurrent_output = recurrent_output.transpose(0, 1) # (batch_size, seq_len, dim_model)

        hidden = torch.cat([output[burn_in:burn_in+learning] for output, burn_in, learning in zip(recurrent_output, burn_in_steps, learning_steps)], dim=0)

        adv = self.advantage(hidden)
        val = self.value(hidden)

        q_value = val + adv - adv.mean(1, keepdim=True)

        return q_value

In [10]:
def calculate_mixed_td_errors(td_error, learning_steps):
    
    start_idx = 0
    mixed_td_errors = np.empty(learning_steps.shape, dtype=td_error.dtype)
    for i, steps in enumerate(learning_steps):
        mixed_td_errors[i] = 0.9*td_error[start_idx:start_idx+steps].max() + 0.1*td_error[start_idx:start_idx+steps].mean()
        start_idx += steps
    
    return mixed_td_errors

class Learner:
    def __init__(self, batch_queue, priority_queue, model, grad_norm: int = grad_norm,
                lr: float = lr, eps:float = eps, game_name: str = game_name,
                target_net_update_interval: int = target_net_update_interval, save_interval: int = save_interval):

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.online_net = deepcopy(model)
        self.online_net.to(self.device)
        self.online_net.train()
        self.target_net = deepcopy(self.online_net)
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=lr, eps=eps)
        self.loss_fn = nn.MSELoss(reduction='none')
        self.grad_norm = grad_norm
        self.batch_queue = batch_queue
        self.priority_queue = priority_queue
        self.num_updates = 0
        self.done = False

        self.target_net_update_interval = target_net_update_interval
        self.save_interval = save_interval

        self.batched_data = []

        self.shared_model = model

        self.game_name = game_name

    def store_weights(self):
        self.shared_model.load_state_dict(self.online_net.state_dict())

    def prepare_data(self):

        while True:
            if not self.batch_queue.empty() and len(self.batched_data) < 4:
                data = self.batch_queue.get_nowait()
                self.batched_data.append(data)
            else:
                time.sleep(0.1)

    def run(self):
        background_thread = threading.Thread(target=self.prepare_data, daemon=True)
        background_thread.start()
        time.sleep(2)

        start_time = time.time()
        while self.num_updates < training_steps:
            
            while not self.batched_data:
                time.sleep(1)
            data = self.batched_data.pop(0)

            batch_obs, batch_last_action, batch_last_reward, batch_action, batch_n_step_reward, batch_n_step_gamma, burn_in_steps, learning_steps, forward_steps, idxes, is_weights, old_ptr, env_steps = data
            batch_obs, batch_last_action, batch_last_reward = batch_obs.to(self.device), batch_last_action.to(self.device), batch_last_reward.to(self.device)
            batch_action = batch_action.to(self.device)
            batch_n_step_reward, batch_n_step_gamma = batch_n_step_reward.to(self.device), batch_n_step_gamma.to(self.device)
            is_weights = is_weights.to(self.device)

            batch_obs, batch_last_action = batch_obs.float(), batch_last_action.float()
            batch_action = batch_action.long()
            burn_in_steps, learning_steps, forward_steps = burn_in_steps, learning_steps, forward_steps

            batch_obs = batch_obs / 255

            # double q learning
            with torch.no_grad():
                batch_action_ = self.online_net.calculate_q_(batch_obs, batch_last_action, batch_last_reward, burn_in_steps, learning_steps, forward_steps).argmax(1).unsqueeze(1)
                batch_q_ = self.target_net.calculate_q_(batch_obs, batch_last_action, batch_last_reward, burn_in_steps, learning_steps, forward_steps).gather(1, batch_action_).squeeze(1)
            
            target_q = self.value_rescale(batch_n_step_reward + batch_n_step_gamma * self.inverse_value_rescale(batch_q_))
            # target_q = batch_n_step_reward + batch_n_step_gamma * batch_q_

            batch_q = self.online_net.calculate_q(batch_obs, batch_last_action, batch_last_reward, burn_in_steps, learning_steps).gather(1, batch_action).squeeze(1)
            
            loss = (is_weights * self.loss_fn(batch_q, target_q)).mean()

            
            td_errors = (target_q-batch_q).detach().clone().squeeze().abs().cpu().float().numpy()

            priorities = calculate_mixed_td_errors(td_errors, learning_steps.numpy())

            # automatic mixed precision training
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.online_net.parameters(), self.grad_norm)
            self.optimizer.step()

            self.num_updates += 1

            self.priority_queue.put((idxes, priorities, old_ptr, loss.item()))

            # store new weights in shared memory
            if self.num_updates % 4 == 0:
                self.store_weights()

            # update target net
            if self.num_updates % self.target_net_update_interval == 0:
                self.target_net.load_state_dict(self.online_net.state_dict())
            
            # save model 
            if self.num_updates % self.save_interval == 0:
                os.makedirs(f'models/t2d2/{game_name}', exist_ok=True)
                torch.save((self.online_net.state_dict(), self.num_updates, env_steps, (time.time()-start_time)/60), os.path.join(f'models/t2d2/{game_name}', '{}.pth'.format(self.num_updates)))

    @staticmethod
    def value_rescale(value, eps=1e-3):
        return value.sign()*((value.abs()+1).sqrt()-1) + eps*value

    @staticmethod
    def inverse_value_rescale(value, eps=1e-3):
        temp = ((1 + 4*eps*(value.abs()+1+eps)).sqrt() - 1) / (2*eps)
        return value.sign() * (temp.square() - 1)

In [11]:
class LocalBuffer:
    def __init__(self, action_dim: int, forward_steps: int = forward_steps,
                burn_in_steps = burn_in_steps, learning_steps: int = learning_steps, 
                gamma: float = gamma, hidden_dim: int = hidden_dim, block_length: int = block_length):
        
        self.action_dim = action_dim
        self.gamma = gamma
        self.hidden_dim = hidden_dim
        self.forward_steps = forward_steps
        self.learning_steps = learning_steps
        self.burn_in_steps = burn_in_steps
        self.block_length = block_length
        self.curr_burn_in_steps = 0
        
    def __len__(self):
        return self.size
    
    def reset(self, init_obs: np.ndarray):
        self.obs_buffer = [init_obs]
        self.last_action_buffer = [np.array([1 if i == 0 else 0 for i in range(self.action_dim)], dtype=bool)]
        self.last_reward_buffer = [[0]]
        self.action_buffer = []
        self.reward_buffer = []
        self.qval_buffer = []
        self.curr_burn_in_steps = 0
        self.size = 0
        self.sum_reward = 0
        self.done = False

    def add(self, action: int, reward: float, next_obs: np.ndarray, q_value: np.ndarray):
        self.action_buffer.append(action)
        self.reward_buffer.append(reward)
        self.obs_buffer.append(next_obs)
        self.last_action_buffer.append(np.array([1 if i == action else 0 for i in range(self.action_dim)], dtype=bool))
        self.last_reward_buffer.append([reward])
        self.qval_buffer.append(q_value)
        self.sum_reward += reward
        self.size += 1
    
    def finish(self, last_qval: np.ndarray = None) -> Tuple:
        assert self.size <= self.block_length
        # assert len(self.last_action_buffer) == self.curr_burn_in_steps + self.size + 1

        num_sequences = math.ceil(self.size/self.learning_steps)

        max_forward_steps = min(self.size, self.forward_steps)
        n_step_gamma = [self.gamma**self.forward_steps] * (self.size-max_forward_steps)

        # last_qval is none means episode done 
        if last_qval is not None:
            self.qval_buffer.append(last_qval)
            n_step_gamma.extend([self.gamma**i for i in reversed(range(1, max_forward_steps+1))])
        else:
            self.done = True
            self.qval_buffer.append(np.zeros_like(self.qval_buffer[0]))
            n_step_gamma.extend([0 for _ in range(max_forward_steps)]) # set gamma to 0 so don't need 'done'

        n_step_gamma = np.array(n_step_gamma, dtype=np.float32)

        obs = np.stack(self.obs_buffer)
        last_action = np.stack(self.last_action_buffer)
        last_reward = np.array(self.last_reward_buffer, dtype=np.float32)


        actions = np.array(self.action_buffer, dtype=np.uint8)

        qval_buffer = np.concatenate(self.qval_buffer)
        reward_buffer = self.reward_buffer + [0 for _ in range(self.forward_steps-1)]
        n_step_reward = np.convolve(reward_buffer, 
                                    [self.gamma**(self.forward_steps-1-i) for i in range(self.forward_steps)],
                                    'valid').astype(np.float32)

        burn_in_steps = np.array([min(i*self.learning_steps+self.curr_burn_in_steps, self.burn_in_steps) for i in range(num_sequences)], dtype=np.uint8)
        learning_steps = np.array([min(self.learning_steps, self.size-i*self.learning_steps) for i in range(num_sequences)], dtype=np.uint8)
        forward_steps = np.array([min(self.forward_steps, self.size+1-np.sum(learning_steps[:i+1])) for i in range(num_sequences)], dtype=np.uint8)
        assert forward_steps[-1] == 1 and burn_in_steps[0] == self.curr_burn_in_steps
        # assert last_action.shape[0] == self.curr_burn_in_steps + np.sum(learning_steps) + 1

        max_qval = np.max(qval_buffer[max_forward_steps:self.size+1], axis=1)
        max_qval = np.pad(max_qval, (0, max_forward_steps-1), 'edge')
        target_qval = qval_buffer[np.arange(self.size), actions]

        td_errors = np.abs(n_step_reward + n_step_gamma * max_qval - target_qval, dtype=np.float32)
        priorities = np.zeros(self.block_length//self.learning_steps, dtype=np.float32)

        priorities[:num_sequences] = calculate_mixed_td_errors(td_errors, learning_steps)

        # save burn in information for next block
        self.obs_buffer = self.obs_buffer[-self.burn_in_steps-1:]
        self.last_action_buffer = self.last_action_buffer[-self.burn_in_steps-1:]
        self.last_reward_buffer = self.last_reward_buffer[-self.burn_in_steps-1:]
        self.action_buffer.clear()
        self.reward_buffer.clear()
        self.qval_buffer.clear()
        self.curr_burn_in_steps = len(self.obs_buffer)-1
        self.size = 0
        
        block = Block(obs, last_action, last_reward, actions, n_step_reward, n_step_gamma, num_sequences, burn_in_steps, learning_steps, forward_steps)
        return [block, priorities, self.sum_reward if self.done else None]

In [12]:
class Actor:
    def __init__(self, epsilon: float, model, sample_queue,
                max_episode_steps: int = max_episode_steps, block_length: int = block_length):

        self.env = create_env(game_name)
        self.action_dim = self.env.action_space.n
        self.model = Network(self.env.observation_space.shape[0], self.action_dim)
        self.model.eval()
        self.local_buffer = LocalBuffer(self.action_dim)

        self.epsilon = epsilon
        self.shared_model = model
        self.sample_queue = sample_queue
        self.max_episode_steps = max_episode_steps
        self.block_length = block_length

    def run(self):
        
        actor_steps = 0

        while True:

            done = False
            agent_state = self.reset()
            episode_steps = 0

            while not done and episode_steps < self.max_episode_steps:

                with torch.no_grad():
                    q_value = self.model(agent_state)

                if random.random() < self.epsilon:
                    action = self.env.action_space.sample()
                else:
                    action = torch.argmax(q_value, 1).item()

                # apply action in env
                next_obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                agent_state.update(next_obs, action, reward)

                episode_steps += 1
                actor_steps += 1

                self.local_buffer.add(action, reward, next_obs, q_value.numpy())

                if done:
                    block = self.local_buffer.finish()
                    self.sample_queue.put(block)

                elif len(self.local_buffer) == self.block_length or episode_steps == self.max_episode_steps:
                    with torch.no_grad():
                        q_value = self.model(agent_state)

                    block = self.local_buffer.finish(q_value.numpy())

                    if self.epsilon > 0.01:
                        block[2] = None
                    self.sample_queue.put(block)

                if actor_steps % actor_update_interval == 0:
                    self.update_weights()

                
    def update_weights(self):
        self.model.load_state_dict(self.shared_model.state_dict())
    
    def reset(self):
        obs, _ = self.env.reset()
        self.local_buffer.reset(obs)

        state = AgentState(torch.from_numpy(obs).unsqueeze(0), self.action_dim)

        return state

In [13]:
def get_epsilon(actor_id: int, base_eps: float = base_eps, alpha: float = alpha, num_actors: int = num_actors):
    exponent = 1 + actor_id / (num_actors-1) * alpha
    return base_eps**exponent

In [14]:
assert False

AssertionError: 

In [None]:
env = create_env(game_name)
action_dim = env.action_space.n
observation_dim = env.observation_space.shape
queue = mp.Queue()
actor = Actor(epsilon=get_epsilon(0), model=Network(observation_dim, action_dim), sample_queue=queue)
actor.run()

In [None]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.set_num_threads(1)

env = create_env(game_name)
n_observations = env.observation_space.shape
n_actions = env.action_space.n
env.close()

model = Network(n_observations, n_actions)
model.share_memory()
sample_queue_list = [mp.Queue() for _ in range(num_actors)]
batch_queue = mp.Queue(num_actors)
priority_queue = mp.Queue(num_actors)

buffer = ReplayBuffer(sample_queue_list, batch_queue, priority_queue)
learner = Learner(batch_queue, priority_queue, model)
actors = [Actor(get_epsilon(i), model, sample_queue_list[i]) for i in range(num_actors)]

actor_procs = [mp.Process(target=actor.run) for actor in actors]
for proc in actor_procs:
    proc.start()

buffer_proc = mp.Process(target=buffer.run)
buffer_proc.start()

learner.run()

buffer_proc.join()

for proc in actor_procs:
    proc.terminate()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer size: 0
Buffer update speed: 0.0/s
Number of environment steps: 0
Number of training steps: 0
Training speed: 0.0/s

Buffer s

KeyboardInterrupt: 

### 3h training

# Test model

In [15]:
checkpoint = torch.load('models/t2d2/MsPacman-v5/16000.pth')
checkpoint

(OrderedDict([('feature.0.weight',
               tensor([[[[ 0.0082,  0.0660, -0.1215,  ...,  0.0489,  0.0313,  0.1322],
                         [-0.0274,  0.0168, -0.0717,  ..., -0.0691, -0.0402,  0.0209],
                         [ 0.0284,  0.0415, -0.1347,  ...,  0.1284, -0.0006,  0.1166],
                         ...,
                         [-0.0352, -0.1052, -0.0262,  ..., -0.0391,  0.1156,  0.1200],
                         [-0.1224, -0.0339,  0.0801,  ...,  0.0500,  0.1214, -0.0241],
                         [ 0.0052, -0.1307, -0.1171,  ...,  0.0854, -0.0217,  0.0925]]],
               
               
                       [[[ 0.0918,  0.0142,  0.0535,  ...,  0.1188, -0.1044, -0.0421],
                         [ 0.0801,  0.1379,  0.1628,  ..., -0.0797,  0.0218, -0.0527],
                         [-0.0931,  0.1349,  0.1347,  ..., -0.0218, -0.0348, -0.0394],
                         ...,
                         [-0.1170, -0.0084, -0.1125,  ..., -0.0310, -0.0975,  0.0292],
 

In [16]:
env = create_env(game_name, render=True)
n_observations = env.observation_space.shape
n_actions = env.action_space.n
model = Network(n_observations, n_actions)
model.load_state_dict(checkpoint[0])
obs, _ = env.reset()
state = AgentState(torch.from_numpy(obs).unsqueeze(0), n_actions)
done = False
total_reward = 0
while not done:
    with torch.no_grad():
        q_value = model(state)
    action = torch.argmax(q_value, 1).item()
    obs, reward, terminated, truncated, _, = env.step(action)
    done = terminated or truncated
    total_reward += reward
    state.update(obs, action, reward)
env.close()
print(total_reward)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


420.0
