In [None]:
import numpy as np
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.gridworld import GridworldEnv

def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """Iteratively evaluate a fixed policy on a tabular environment.

    Sweeps over all states, updating the value table in place, until the
    largest change within one sweep is smaller than ``theta``.

    Args:
        policy: Indexable by state; ``policy[s]`` yields one probability per action.
        env: Object exposing ``nS`` (number of states) and ``P[s][a]``, an
            iterable of ``(prob, next_state, reward, done)`` tuples.
        discount_factor: Multiplier applied to the value of the successor state.
        theta: Convergence threshold on the per-sweep maximum value change.

    Returns:
        A NumPy array of length ``env.nS`` holding the estimated state values.
    """
    values = np.zeros(env.nS)
    converged = False
    while not converged:
        largest_change = 0
        for state in range(env.nS):
            # Expected one-step backup under the policy, accumulated term by term.
            backup = 0
            for action, pi_sa in enumerate(policy[state]):
                for p_trans, successor, rew, _done in env.P[state][action]:
                    backup += pi_sa * p_trans * (rew + discount_factor * values[successor])
            largest_change = max(largest_change, np.abs(backup - values[state]))
            # In-place update: later states in this sweep already see the new value.
            values[state] = backup
        converged = largest_change < theta
    return np.array(values)

def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Starts from the uniform random policy and repeats until a full sweep
    leaves the greedy action of every state unchanged.

    Args:
        env: Object exposing ``nS``, ``nA`` and ``P[s][a]`` (iterable of
            ``(prob, next_state, reward, done)`` tuples).
        policy_eval_fn: Callable ``(policy, env, discount_factor) -> V`` used
            for the evaluation step.
        discount_factor: Multiplier applied to successor-state values.

    Returns:
        Tuple ``(policy, V)``: ``policy`` is an ``[nS, nA]`` one-hot matrix and
        ``V`` is the value table returned by the last evaluation.
    """

    def one_step_lookahead(state, V):
        # Expected return of each action taken from `state`, given values V.
        q = np.zeros(env.nA)
        for action in range(env.nA):
            for p_trans, successor, rew, _done in env.P[state][action]:
                q[action] += p_trans * (rew + discount_factor * V[successor])
        return q

    policy = np.ones([env.nS, env.nA]) / env.nA
    one_hot = np.eye(env.nA)

    while True:
        V = policy_eval_fn(policy, env, discount_factor)
        changed = False

        for state in range(env.nS):
            previous_action = np.argmax(policy[state])
            greedy_action = np.argmax(one_step_lookahead(state, V))
            if previous_action != greedy_action:
                changed = True
            # Row assignment copies the one-hot vector into the policy table.
            policy[state] = one_hot[greedy_action]

        if not changed:
            return policy, V

In [None]:
import numpy as np
import sys
from lib.envs.gridworld import GridworldEnv

# Module-level Gridworld environment instance created when this cell runs.
env = GridworldEnv()

def value_iteration(env, theta=0.0001, discount_factor=1.0):
    """Compute a greedy policy and its value table by value iteration.

    Args:
        env: Object exposing ``nS``, ``nA`` and ``P[s][a]`` (iterable of
            ``(prob, next_state, reward, done)`` tuples).
        theta: Stop once the largest value change in a sweep is below this.
        discount_factor: Multiplier applied to successor-state values.

    Returns:
        Tuple ``(policy, V)``: ``policy`` is an ``[nS, nA]`` one-hot matrix
        selecting the greedy action per state, ``V`` the converged values.
    """

    def one_step_lookahead(state, V):
        # Expected return of each action taken from `state`, given values V.
        q = np.zeros(env.nA)
        for action in range(env.nA):
            for p_trans, successor, rew, _done in env.P[state][action]:
                q[action] += p_trans * (rew + discount_factor * V[successor])
        return q

    V = np.zeros(env.nS)
    converged = False
    while not converged:
        largest_change = 0
        for state in range(env.nS):
            best_value = np.max(one_step_lookahead(state, V))
            largest_change = max(largest_change, np.abs(best_value - V[state]))
            # In-place update: later states in this sweep see the new value.
            V[state] = best_value
        converged = largest_change < theta

    # Extract the deterministic greedy policy from the converged values.
    policy = np.zeros([env.nS, env.nA])
    for state in range(env.nS):
        greedy_action = np.argmax(one_step_lookahead(state, V))
        policy[state, greedy_action] = 1.0

    return policy, V

In [None]:
import gym
import matplotlib
import numpy as np
import sys
from collections import defaultdict
from lib.envs.blackjack import BlackjackEnv
from lib import plotting

# Apply the ggplot matplotlib style to figures drawn after this point.
matplotlib.style.use('ggplot')
# Module-level Blackjack environment instance created when this cell runs.
env = BlackjackEnv()

def mc_prediction(policy, env, num_episodes, discount_factor=1.0):
    """First-visit Monte Carlo estimate of the state-value function.

    Args:
        policy: Callable mapping an observation to an action.
        env: Object with ``reset() -> state`` and
            ``step(action) -> (next_state, reward, done, info)``.
        num_episodes: Number of episodes to sample.
        discount_factor: Per-step discount applied to rewards.

    Returns:
        ``defaultdict(float)`` mapping ``tuple(state)`` to the average return
        observed from the first visit to that state in each episode.
    """
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)
    V = defaultdict(float)

    for episode_no in range(1, num_episodes + 1):
        # Progress line, rewritten in place every 1000 episodes.
        if episode_no % 1000 == 0:
            print("\rEpisode {}/{}.".format(episode_no, num_episodes), end="")
            sys.stdout.flush()

        # Roll out one episode, capped at 100 steps.
        trajectory = []
        observation = env.reset()
        for _ in range(100):
            chosen = policy(observation)
            successor, reward, done, _info = env.step(chosen)
            trajectory.append((observation, chosen, reward))
            if done:
                break
            observation = successor

        visited = {tuple(step[0]) for step in trajectory}
        for visited_state in visited:
            # Index of the first step at which this state was observed.
            first_idx = next(
                idx for idx, step in enumerate(trajectory) if step[0] == visited_state
            )
            # Discounted return from that first visit to the end of the episode.
            G = sum(
                [step[2] * (discount_factor ** k)
                 for k, step in enumerate(trajectory[first_idx:])]
            )
            returns_sum[visited_state] += G
            returns_count[visited_state] += 1.0
            V[visited_state] = returns_sum[visited_state] / returns_count[visited_state]

    return V

def sample_policy(observation):
    """Fixed policy: return action 0 when the score is 20 or more, else action 1.

    Args:
        observation: A 3-item sequence ``(score, dealer_score, usable_ace)``;
            only ``score`` is used.

    Returns:
        ``0`` if ``score >= 20``, otherwise ``1``.
    """
    player_score, _dealer_score, _usable_ace = observation
    if player_score >= 20:
        return 0
    return 1

In [None]:
import gym
import matplotlib
import numpy as np
import sys
from collections import defaultdict
if "../" not in sys.path:
  sys.path.append("../")
from lib.envs.blackjack import BlackjackEnv
from lib import plotting

# Apply the ggplot matplotlib style to figures drawn after this point.
matplotlib.style.use('ggplot')
# Module-level Blackjack environment; passed to mc_control_epsilon_greedy at the end of this cell.
env = BlackjackEnv()

def make_epsilon_greedy_policy(Q, epsilon, nA):
    """Build an epsilon-greedy policy backed by the action-value table ``Q``.

    Args:
        Q: Mapping from observation to an array of ``nA`` action values. The
            policy reads it at call time, so later updates to ``Q`` are seen.
        epsilon: Probability mass spread uniformly over all actions.
        nA: Number of actions.

    Returns:
        A function ``observation -> probabilities`` returning a length-``nA``
        array in which the greedy action gets an extra ``1 - epsilon``.
    """
    def policy_fn(observation):
        action_probs = np.ones(nA, dtype=float) * epsilon / nA
        greedy_action = np.argmax(Q[observation])
        action_probs[greedy_action] += (1.0 - epsilon)
        return action_probs

    return policy_fn

def mc_control_epsilon_greedy(env, num_episodes, discount_factor=1.0, epsilon=0.1):
    """On-policy first-visit Monte Carlo control with an epsilon-greedy policy.

    Args:
        env: Object with ``reset()``, ``step(action)`` and ``action_space.n``.
        num_episodes: Number of episodes to sample.
        discount_factor: Per-step discount applied to rewards.
        epsilon: Exploration probability of the epsilon-greedy policy.

    Returns:
        Tuple ``(Q, policy)``: ``Q`` maps ``tuple(state)`` to an array of action
        values; ``policy`` maps an observation to action probabilities and
        reads ``Q`` at call time.
    """
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # The policy closes over Q, so it improves as Q is updated below.
    policy = make_epsilon_greedy_policy(Q, epsilon, env.action_space.n)

    for episode_no in range(1, num_episodes + 1):
        # Progress line, rewritten in place every 1000 episodes.
        if episode_no % 1000 == 0:
            print("\rEpisode {}/{}.".format(episode_no, num_episodes), end="")
            sys.stdout.flush()

        # Roll out one episode, capped at 100 steps.
        trajectory = []
        observation = env.reset()
        for _ in range(100):
            action_probs = policy(observation)
            chosen = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            successor, reward, done, _info = env.step(chosen)
            trajectory.append((observation, chosen, reward))
            if done:
                break
            observation = successor

        visited_pairs = {(tuple(step[0]), step[1]) for step in trajectory}
        for visited_state, visited_action in visited_pairs:
            pair = (visited_state, visited_action)
            # Index of the first step at which this (state, action) pair occurred.
            first_idx = next(
                idx for idx, step in enumerate(trajectory)
                if step[0] == visited_state and step[1] == visited_action
            )
            # Discounted return from that first visit to the end of the episode.
            G = sum(
                [step[2] * (discount_factor ** k)
                 for k, step in enumerate(trajectory[first_idx:])]
            )
            returns_sum[pair] += G
            returns_count[pair] += 1.0
            Q[visited_state][visited_action] = returns_sum[pair] / returns_count[pair]

    return Q, policy

# Runs 500,000 training episodes as soon as this cell executes; rebinds the
# module-level names Q and policy to the learned table and its policy.
Q, policy = mc_control_epsilon_greedy(env, num_episodes=500000, epsilon=0.1)

In [None]:
import gym
import matplotlib
import numpy as np
import sys
from collections import defaultdict
if "../" not in sys.path:
  sys.path.append("../")
from lib.envs.blackjack import BlackjackEnv
from lib import plotting

# Apply the ggplot matplotlib style to figures drawn after this point.
matplotlib.style.use('ggplot')
# Module-level Blackjack environment; passed to mc_control_importance_sampling at the end of this cell.
env = BlackjackEnv()

def create_random_policy(nA):
    """Build a policy that assigns equal probability to each of ``nA`` actions.

    Args:
        nA: Number of actions.

    Returns:
        A function ``observation -> probabilities``. The same array object is
        returned on every call, whatever the observation.
    """
    uniform_probs = np.ones(nA, dtype=float) / nA

    def policy_fn(observation):
        return uniform_probs

    return policy_fn

def create_greedy_policy(Q):
    """Build a deterministic greedy policy backed by the action-value table ``Q``.

    Args:
        Q: Mapping from state to an array of action values. It is read at call
            time, so later updates to ``Q`` change the policy.

    Returns:
        A function ``state -> probabilities`` returning a one-hot float array
        with 1.0 at the index of the highest action value.
    """
    def policy_fn(state):
        action_values = Q[state]
        one_hot = np.zeros_like(action_values, dtype=float)
        one_hot[np.argmax(action_values)] = 1.0
        return one_hot

    return policy_fn

def mc_control_importance_sampling(env, num_episodes, behavior_policy, discount_factor=1.0):
    """Off-policy Monte Carlo control using weighted importance sampling.

    Episodes are generated by ``behavior_policy``; the greedy target policy is
    learned from them through incremental weighted-average updates of ``Q``.

    Args:
        env: Object with ``reset()``, ``step(action)`` and ``action_space.n``.
        num_episodes: Number of episodes to sample.
        behavior_policy: Callable mapping a state to action probabilities.
        discount_factor: Per-step discount applied to rewards.

    Returns:
        Tuple ``(Q, target_policy)``: ``Q`` maps state to an array of action
        values; ``target_policy`` is the greedy policy over ``Q``.
    """
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # Cumulative importance weights, one per (state, action).
    C = defaultdict(lambda: np.zeros(env.action_space.n))
    target_policy = create_greedy_policy(Q)

    for episode_no in range(1, num_episodes + 1):
        # Progress line, rewritten in place every 1000 episodes.
        if episode_no % 1000 == 0:
            print("\rEpisode {}/{}.".format(episode_no, num_episodes), end="")
            sys.stdout.flush()

        # Roll out one episode with the behavior policy, capped at 100 steps.
        trajectory = []
        observation = env.reset()
        for _ in range(100):
            action_probs = behavior_policy(observation)
            chosen = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            successor, step_reward, done, _info = env.step(chosen)
            trajectory.append((observation, chosen, step_reward))
            if done:
                break
            observation = successor

        # Walk the episode backwards, accumulating the return and the weight.
        G = 0.0
        W = 1.0
        for state, action, reward in reversed(trajectory):
            G = discount_factor * G + reward
            C[state][action] += W
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])
            # Once the behavior action differs from the greedy target action,
            # earlier steps have zero importance weight, so stop.
            greedy_action = np.argmax(target_policy(state))
            if action != greedy_action:
                break
            W = W * 1. / behavior_policy(state)[action]

    return Q, target_policy

# Uniform behavior policy over the environment's actions.
random_policy = create_random_policy(env.action_space.n)
# Runs 500,000 training episodes as soon as this cell executes; rebinds the
# module-level names Q and policy.
Q, policy = mc_control_importance_sampling(env, num_episodes=500000, behavior_policy=random_policy)

In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt

# Grid dimensions used to size the Q-table and the printed value table.
BOARD_ROWS = 5
BOARD_COLS = 5
# (row, col) cell used as the default position of a new State.
START = (0, 0)
# Terminal cell; State.getReward returns 1 here.
WIN_STATE = (4, 4)
# Terminal cells; State.getReward returns -5 on any of these.
HOLE_STATE = [(1,0),(3,1),(4,2),(1,3)]

class State:
    """A position on the grid, with its reward, terminal check and move rule.

    Attributes are documented inline in ``__init__``. Positions are
    ``(row, col)`` tuples; row 0 is the top of the board.
    """

    def __init__(self, state=START):
        # Current (row, col) position.
        self.state = state
        # Terminal flag; only updated when isEndFunc() is called.
        self.isEnd = False

    def getReward(self):
        """Return the reward of the current cell: -5 on a hole, 1 on the win cell, else -1."""
        if self.state in HOLE_STATE:
            return -5
        if self.state == WIN_STATE:
            return 1
        return -1

    def isEndFunc(self):
        """Set ``isEnd`` to True when the current cell is the win cell or a hole.

        The flag is never reset to False by this method.
        """
        if self.state == WIN_STATE or self.state in HOLE_STATE:
            self.isEnd = True

    def nxtPosition(self, action):
        """Return the position reached by ``action`` from the current cell.

        Args:
            action: 0 = up, 1 = down, 2 = left; any other value moves right.

        Returns:
            The neighbouring ``(row, col)`` if it lies on the board, otherwise
            the current position unchanged.
        """
        row, col = self.state[0], self.state[1]
        if action == 0:
            nxtState = (row - 1, col)  # up
        elif action == 1:
            nxtState = (row + 1, col)  # down
        elif action == 2:
            nxtState = (row, col - 1)  # left
        else:
            nxtState = (row, col + 1)  # right

        # Bounds come from the board constants rather than a literal 4, so the
        # move rule stays correct if the board is resized.
        if 0 <= nxtState[0] < BOARD_ROWS and 0 <= nxtState[1] < BOARD_COLS:
            return nxtState
        return self.state

class Agent:
    """Tabular epsilon-greedy Q-learning agent for the grid world in ``State``.

    Q-values are stored in dictionaries keyed by ``(row, col, action)``.
    Updates are written to ``new_Q`` and copied into ``Q`` after every step.
    """

    def __init__(self):
        self.states = []
        # Action codes understood by State.nxtPosition: 0=up, 1=down, 2=left, 3=right.
        self.actions = [0,1,2,3]
        self.State = State()
        self.alpha = 0.5    # learning rate
        self.gamma = 0.9    # discount factor
        self.epsilon = 0.1  # exploration probability
        self.isEnd = self.State.isEnd
        # Total reward of each finished episode, in order.
        self.plot_reward = []
        self.Q = {}
        self.new_Q = {}
        # Reward accumulated in the episode currently being played.
        self.rewards = 0
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                for k in range(len(self.actions)):
                    self.Q[(i, j, k)] = 0
                    self.new_Q[(i, j, k)] = 0

    def Action(self):
        """Pick an epsilon-greedy action for the current state.

        Returns:
            Tuple ``(position, action)`` where ``position`` is the cell that
            ``action`` leads to according to ``State.nxtPosition``.
        """
        rnd = random.random()
        # -inf instead of a magic lower bound, so the greedy search always
        # selects an action whatever the scale of the Q-values.
        mx_nxt_reward = float("-inf")
        action = None

        if rnd > self.epsilon:
            i, j = self.State.state
            for k in self.actions:
                nxt_reward = self.Q[(i, j, k)]
                # ">=" keeps the last of several equally good actions.
                if nxt_reward >= mx_nxt_reward:
                    action = k
                    mx_nxt_reward = nxt_reward
        else:
            action = np.random.choice(self.actions)

        position = self.State.nxtPosition(action)
        return position, action

    def Q_Learning(self, episodes):
        """Run Q-learning until ``episodes`` episodes have terminated.

        On a terminal state every action value of that cell is set to the
        cell's reward, the episode total is recorded in ``plot_reward`` and the
        agent restarts from a fresh ``State()``. Otherwise one epsilon-greedy
        step is taken and the Q-value of the (state, action) just used is
        updated. The reward used is that of the cell being left.
        """
        x = 0
        while x < episodes:
            if self.isEnd:
                reward = self.State.getReward()
                self.rewards += reward
                self.plot_reward.append(self.rewards)
                i, j = self.State.state
                for a in self.actions:
                    self.new_Q[(i, j, a)] = round(reward, 3)
                self.State = State()
                self.isEnd = self.State.isEnd
                self.rewards = 0
                x += 1
            else:
                # -inf sentinel: the update is a true maximum over next actions.
                mx_nxt_value = float("-inf")
                next_state, action = self.Action()
                i, j = self.State.state
                reward = self.State.getReward()
                self.rewards += reward
                for a in self.actions:
                    nxtStateAction = (next_state[0], next_state[1], a)
                    q_value = (1-self.alpha)*self.Q[(i,j,action)] + self.alpha*(reward + self.gamma*self.Q[nxtStateAction])
                    if q_value >= mx_nxt_value:
                        mx_nxt_value = q_value
                self.State = State(state=next_state)
                self.State.isEndFunc()
                self.isEnd = self.State.isEnd
                self.new_Q[(i, j, action)] = round(mx_nxt_value, 3)
            # Publish the pending updates so the next step reads them.
            self.Q = self.new_Q.copy()

    def plot(self, episodes):
        """Plot the total reward of each finished episode. ``episodes`` is unused."""
        plt.plot(self.plot_reward)
        plt.show()

    def showValues(self):
        """Print the board as a table of the maximum Q-value of each cell."""
        for i in range(0, BOARD_ROWS):
            print('-----------------------------------------------')
            out = '| '
            for j in range(0, BOARD_COLS):
                mx_nxt_value = float("-inf")
                for a in self.actions:
                    nxt_value = self.Q[(i, j, a)]
                    if nxt_value >= mx_nxt_value:
                        mx_nxt_value = nxt_value
                out += str(mx_nxt_value).ljust(6) + ' | '
            print(out)
        print('-----------------------------------------------')

# Entry point: train the grid-world agent for 10,000 episodes, then plot the
# per-episode reward totals and print the learned value table.
if __name__ == "__main__":
    ag = Agent()
    episodes = 10000
    ag.Q_Learning(episodes)
    ag.plot(episodes)
    ag.showValues()

In [None]:
import numpy as np
import random

class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy action rule.

    The Q-table is a ``(num_states, num_actions)`` array initialised to zero.
    """

    def __init__(self, num_states, num_actions, learning_rate=0.1, discount_factor=0.9, exploration_rate=0.1):
        """Store the hyper-parameters and allocate the zero-filled Q-table."""
        self.num_states = num_states
        self.num_actions = num_actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = np.zeros((num_states, num_actions))

    def choose_action(self, state):
        """Return a random action with probability ``exploration_rate``, else the greedy one."""
        explore = random.random() < self.exploration_rate
        if explore:
            return random.randint(0, self.num_actions - 1)
        return np.argmax(self.q_table[state])

    def update_q_table(self, state, action, reward, next_state):
        """Apply one Q-learning update for the transition ``(state, action, reward, next_state)``."""
        next_row = self.q_table[next_state]
        greedy_next = np.argmax(next_row)
        # Bootstrapped target: reward plus discounted best next-state value.
        target = reward + self.discount_factor * next_row[greedy_next]
        error = target - self.q_table[state][action]
        self.q_table[state][action] += self.learning_rate * error

In [None]:
import numpy as np
import random
import tensorflow as tf
from collections import deque
from tensorflow.keras import layers, models
from scipy.spatial.distance import euclidean
from scipy.special import kl_div
import heapq

class DeepQLearningAgent:
    """DQN agent with a convolutional Q-network, target network and replay buffer.

    Also exposes an optional A* search over ``get_next_state`` towards a goal
    state, plus two placeholder exploration hooks that currently delegate to
    the epsilon-greedy ``choose_action``.
    """

    def __init__(self, state_dim, action_dim, learning_rate=0.001, discount_factor=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995, batch_size=64,
                 replay_buffer_size=10000, target_update_frequency=100):
        """Store hyper-parameters and build the online and target networks.

        Args:
            state_dim: Input shape passed to the first Conv2D layer.
            action_dim: Number of discrete actions (size of the output layer).
            learning_rate: Adam learning rate for the online network.
            discount_factor: Discount applied to the bootstrapped next-state value.
            epsilon_start: Initial exploration probability.
            epsilon_end: Lower bound for the exploration probability.
            epsilon_decay: Multiplicative decay applied by ``decay_epsilon``.
            batch_size: Number of transitions sampled per replay update.
            replay_buffer_size: Maximum number of stored transitions.
            target_update_frequency: The target network is synchronised at the
                start of every episode whose index is a multiple of this.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_frequency = target_update_frequency

        self.replay_buffer = deque(maxlen=replay_buffer_size)
        self.target_network = self._build_network()
        self.q_network = self._build_network()
        self.q_network.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='mse')
        # Both networks are initialised independently; start the target from
        # the online weights so early bootstrapped targets are consistent.
        self.update_target_network()

        self.goal_state = None  # Goal state for A* search

    def _build_network(self):
        """Return an uncompiled Keras CNN mapping a state to ``action_dim`` Q-values."""
        model = models.Sequential([
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=self.state_dim),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(self.action_dim)
        ])
        return model

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.set_weights(self.q_network.get_weights())

    def choose_action(self, state):
        """Epsilon-greedy action for ``state`` (expected with a leading batch axis of 1)."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_dim)
        q_values = self.q_network.predict(state, verbose=0)
        return np.argmax(q_values[0])

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.replay_buffer.append((state, action, reward, next_state, done))

    def experience_replay(self):
        """Fit the online network on one random minibatch of stored transitions.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        Predictions are computed once per network for the whole minibatch
        instead of once per transition.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        minibatch = random.sample(self.replay_buffer, self.batch_size)

        # Stored states carry a leading batch axis; index 0 strips it.
        states = np.array([transition[0][0] for transition in minibatch])
        next_states = np.array([transition[3][0] for transition in minibatch])
        actions = np.array([transition[1] for transition in minibatch], dtype=int)
        rewards = np.array([transition[2] for transition in minibatch], dtype=float)
        dones = np.array([transition[4] for transition in minibatch], dtype=bool)

        targets = np.array(self.q_network.predict(states, verbose=0))
        next_q = self.target_network.predict(next_states, verbose=0)
        bootstrapped = rewards + self.discount_factor * np.amax(next_q, axis=1)
        # Terminal transitions use the raw reward; only the taken action's
        # target changes, so the other outputs contribute zero error.
        targets[np.arange(len(minibatch)), actions] = np.where(dones, rewards, bootstrapped)

        self.q_network.fit(states, targets, epochs=1, verbose=0)

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_end``."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

    def curiosity_driven_exploration(self, state):
        """Placeholder hook; currently identical to ``choose_action``."""
        # Implement curiosity-driven exploration strategy (e.g., intrinsic reward based on prediction error)
        return self.choose_action(state)

    def meta_learning(self, state):
        """Placeholder hook; currently identical to ``choose_action``."""
        # Implement meta-learning strategy (e.g., adapt learning rate, epsilon, etc., based on past experience)
        return self.choose_action(state)

    def set_goal_state(self, goal_state):
        """Set the goal used by ``deep_a_star_search`` and ``heuristic``."""
        self.goal_state = goal_state

    @staticmethod
    def _state_key(state):
        """Return a hashable, comparable key for a state (tuple, list or ndarray)."""
        return tuple(np.asarray(state).ravel().tolist())

    def deep_a_star_search(self, state):
        """A* search from ``state`` to ``goal_state`` using ``get_next_state``.

        Each action costs 1 and ``heuristic`` estimates the remaining cost.

        Args:
            state: Start state; tuples, lists and NumPy arrays are accepted.

        Returns:
            The list of actions leading to the goal (empty if ``state`` already
            is the goal), or ``None`` when no goal is set or no path exists.
        """
        # "is None" so array-valued or falsy goals are handled correctly.
        if self.goal_state is None:
            return None

        goal_key = self._state_key(self.goal_state)
        # Heap entries are (f-value, insertion order, state, path). The
        # insertion counter breaks ties so states are never compared directly.
        open_list = [(0, 0, state, [])]
        pushed = 1
        closed_set = set()

        while open_list:
            _f, _order, current, path = heapq.heappop(open_list)
            current_key = self._state_key(current)
            if current_key == goal_key:
                return path  # Return the path if the goal state is reached
            if current_key in closed_set:
                continue  # A cheaper entry for this state was expanded already.
            closed_set.add(current_key)

            for action in range(self.action_dim):
                next_state = self.get_next_state(current, action)
                # The default get_next_state returns None: no transition known.
                if next_state is None:
                    continue
                if self._state_key(next_state) in closed_set:
                    continue
                g = len(path) + 1  # Cost from start to next state
                h = self.heuristic(next_state)  # Estimated cost from next state to goal
                heapq.heappush(open_list, (g + h, pushed, next_state, path + [action]))
                pushed += 1

        return None  # No path found

    def heuristic(self, state):
        """Euclidean distance between the flattened ``state`` and ``goal_state``."""
        return euclidean(np.ravel(state), np.ravel(self.goal_state))

    def train(self, env, episodes):
        """Train with epsilon-greedy DQN for ``episodes`` episodes.

        The exploration hooks and the A* search are not invoked here: their
        results were never used to pick the action, so calling them only cost
        extra forward passes (and raised once a goal state was set).

        Args:
            env: Object with ``reset()`` and
                ``step(action) -> (next_state, reward, done, info)``.
            episodes: Number of episodes to run.
        """
        for episode in range(episodes):
            # Synchronise once per matching episode, not on each of its steps.
            if episode % self.target_update_frequency == 0:
                self.update_target_network()

            state = env.reset()
            state = np.reshape(state, [1, *self.state_dim])
            done = False
            total_reward = 0
            while not done:
                action = self.choose_action(state)
                next_state, reward, done, _ = env.step(action)
                total_reward += reward
                next_state = np.reshape(next_state, [1, *self.state_dim])
                self.remember(state, action, reward, next_state, done)
                state = next_state
                self.experience_replay()
            self.decay_epsilon()
            print(f"Episode: {episode + 1}, Total Reward: {total_reward}")

    def get_next_state(self, state, action):
        """Environment model hook for A*; returns ``None`` until overridden."""
        # Implement this method based on your environment
        pass

In [None]:
import sys
import os
import numpy as np
import tensorflow as tf
import itertools
import shutil
import threading
import multiprocessing
import time
from inspect import getsourcefile
from gym.wrappers import Monitor
import gym
from lib.atari.state_processor import StateProcessor
from lib.atari import helpers as atari_helpers
from estimators import ValueEstimator, PolicyEstimator
from worker import make_copy_params_op, Worker


# Command-line flags for the A3C training script (TF1 tf.flags API).
tf.flags.DEFINE_string("model_dir", "/tmp/a3c", "Directory to write Tensorboard summaries and videos to.")
tf.flags.DEFINE_string("env", "Breakout-v0", "Name of gym Atari environment, e.g. Breakout-v0")
tf.flags.DEFINE_integer("t_max", 5, "Number of steps before performing an update")
tf.flags.DEFINE_integer("max_global_steps", None, "Stop training after this many steps in the environment. Defaults to running indefinitely.")
tf.flags.DEFINE_integer("eval_every", 300, "Evaluate the policy every N seconds")
tf.flags.DEFINE_boolean("reset", False, "If set, delete the existing model directory and start training from scratch.")
tf.flags.DEFINE_integer("parallelism", None, "Number of threads to run. If not set we run [num_cpu_cores] threads.")

FLAGS = tf.flags.FLAGS

# Add the directory two levels above this source file to sys.path.
# NOTE(review): this runs after the `lib.atari` imports at the top of the cell,
# so it cannot help those imports resolve — confirm the intended order.
current_path = os.path.dirname(os.path.abspath(getsourcefile(lambda:0)))
import_path = os.path.abspath(os.path.join(current_path, "../.."))
if import_path not in sys.path:
  sys.path.append(import_path)

def make_env(wrap=True):
  """Create the gym environment named by ``FLAGS.env``.

  Args:
    wrap: When true, wrap the environment in ``atari_helpers.AtariEnvWrapper``.

  Returns:
    The inner ``.env`` of the environment built by ``gym.envs.make``,
    optionally wrapped.
  """
  # `.env` strips the outermost wrapper that gym.envs.make returns.
  inner_env = gym.envs.make(FLAGS.env).env
  if not wrap:
    return inner_env
  return atari_helpers.AtariEnvWrapper(inner_env)

# Build a throwaway environment only to determine the action set.
env_ = make_env()
# Pong and Breakout are restricted to the first four actions; every other
# environment uses its full action space.
if FLAGS.env == "Pong-v0" or FLAGS.env == "Breakout-v0":
  VALID_ACTIONS = list(range(4))
else:
  VALID_ACTIONS = list(range(env_.action_space.n))
env_.close()

# One worker per CPU core unless --parallelism overrides it.
NUM_WORKERS = multiprocessing.cpu_count()
if FLAGS.parallelism:
  NUM_WORKERS = FLAGS.parallelism

MODEL_DIR = FLAGS.model_dir
CHECKPOINT_DIR = os.path.join(MODEL_DIR, "checkpoints")

# --reset wipes the whole model directory (summaries, videos, checkpoints).
if FLAGS.reset:
  shutil.rmtree(MODEL_DIR, ignore_errors=True)

if not os.path.exists(CHECKPOINT_DIR):
  os.makedirs(CHECKPOINT_DIR)

summary_writer = tf.summary.FileWriter(os.path.join(MODEL_DIR, "train"))

with tf.device("/cpu:0"):
  global_step = tf.Variable(0, name="global_step", trainable=False)
  # Global networks; the value net is built with reuse=True so its "shared"
  # variable scope reuses the variables created by the policy net.
  with tf.variable_scope("global") as vs:
    policy_net = PolicyEstimator(num_outputs=len(VALID_ACTIONS))
    value_net = ValueEstimator(reuse=True)

  # Step counter shared by all workers.
  global_counter = itertools.count()

  workers = []
  for worker_id in range(NUM_WORKERS):
    # Only worker 0 receives the summary writer; the others get None.
    worker_summary_writer = None
    if worker_id == 0:
      worker_summary_writer = summary_writer

    worker = Worker(
      name="worker_{}".format(worker_id),
      env=make_env(),
      policy_net=policy_net,
      value_net=value_net,
      global_counter=global_counter,
      discount_factor = 0.99,
      summary_writer=worker_summary_writer,
      max_global_steps=FLAGS.max_global_steps)
    workers.append(worker)

  saver = tf.train.Saver(keep_checkpoint_every_n_hours=2.0, max_to_keep=10)

  # NOTE(review): PolicyMonitor is not imported in this cell and is defined in
  # a later cell of this notebook; on a fresh kernel this line raises NameError
  # unless that cell has been run first.
  pe = PolicyMonitor(
    env=make_env(wrap=False),
    policy_net=policy_net,
    summary_writer=summary_writer,
    saver=saver)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  coord = tf.train.Coordinator()

  # Resume from the most recent checkpoint when one exists.
  latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
  if latest_checkpoint:
    print("Loading model checkpoint: {}".format(latest_checkpoint))
    saver.restore(sess, latest_checkpoint)

  worker_threads = []
  for worker in workers:
    # `worker=worker` binds the current worker as a default argument so each
    # lambda keeps its own worker instead of the loop's final value.
    worker_fn = lambda worker=worker: worker.run(sess, coord, FLAGS.t_max)
    t = threading.Thread(target=worker_fn)
    t.start()
    worker_threads.append(t)

  # Periodic evaluation runs in its own thread.
  monitor_thread = threading.Thread(target=lambda: pe.continuous_eval(FLAGS.eval_every, sess, coord))
  monitor_thread.start()

  # Only the worker threads are joined here; monitor_thread is not in the list.
  coord.join(worker_threads)
def build_shared_network(X, add_summaries=False):
  """Build the convolutional trunk used by both estimators.

  Two ReLU convolutions ("conv1": 16 filters, kernel 8, stride 4; "conv2":
  32 filters, kernel 4, stride 2) followed by a 256-unit fully connected
  layer ("fc1") on the flattened output.

  Args:
    X: Input tensor fed to the first convolution.
    add_summaries: When true, register activation summaries for all three layers.

  Returns:
    The output tensor of the "fc1" layer.
  """
  contrib_layers = tf.contrib.layers
  conv1 = contrib_layers.conv2d(
    X, 16, 8, 4, activation_fn=tf.nn.relu, scope="conv1")
  conv2 = contrib_layers.conv2d(
    conv1, 32, 4, 2, activation_fn=tf.nn.relu, scope="conv2")
  flattened = contrib_layers.flatten(conv2)
  fc1 = contrib_layers.fully_connected(
    inputs=flattened, num_outputs=256, scope="fc1")

  if add_summaries:
    for activation in (conv1, conv2, fc1):
      contrib_layers.summarize_activation(activation)
  return fc1

class PolicyEstimator():
  """Policy network: shared convolutional trunk plus a softmax action head.

  Args:
    num_outputs: Size of the action head (number of actions).
    reuse: Passed to the "shared" variable scope; when true the trunk reuses
      existing variables and activation summaries are not added.
    trainable: When true, an RMSProp optimizer and train op are created.
  """
  def __init__(self, num_outputs, reuse=False, trainable=True):
    self.num_outputs = num_outputs
    # Batch of uint8 inputs with shape [batch, 84, 84, 4].
    self.states = tf.placeholder(shape=[None, 84, 84, 4], dtype=tf.uint8, name="X")
    # Per-sample scalar that weights the log-probability term of the loss.
    self.targets = tf.placeholder(shape=[None], dtype=tf.float32, name="y")
    # Index of the action taken for each sample.
    self.actions = tf.placeholder(shape=[None], dtype=tf.int32, name="actions")
    # Scale the uint8 input to floats in [0, 1].
    X = tf.to_float(self.states) / 255.0
    batch_size = tf.shape(self.states)[0]
    with tf.variable_scope("shared", reuse=reuse):
      fc1 = build_shared_network(X, add_summaries=(not reuse))
    with tf.variable_scope("policy_net"):
      self.logits = tf.contrib.layers.fully_connected(fc1, num_outputs, activation_fn=None)
      # The 1e-8 offset keeps the tf.log calls below away from an exact zero.
      self.probs = tf.nn.softmax(self.logits) + 1e-8
      self.predictions = {"logits": self.logits, "probs": self.probs}
      self.entropy = -tf.reduce_sum(self.probs * tf.log(self.probs), 1, name="entropy")
      self.entropy_mean = tf.reduce_mean(self.entropy, name="entropy_mean")
      # Flat index of (row i, column actions[i]) in the flattened probs matrix.
      gather_indices = tf.range(batch_size) * tf.shape(self.probs)[1] + self.actions
      self.picked_action_probs = tf.gather(tf.reshape(self.probs, [-1]), gather_indices)
      # Negative of (log-prob of the taken action * target + 0.01 * entropy).
      self.losses = - (tf.log(self.picked_action_probs) * self.targets + 0.01 * self.entropy)
      self.loss = tf.reduce_sum(self.losses, name="loss")
      tf.summary.scalar(self.loss.op.name, self.loss)
      tf.summary.scalar(self.entropy_mean.op.name, self.entropy_mean)
      tf.summary.histogram(self.entropy.op.name, self.entropy)
      if trainable:
        # Positional args are presumably (learning_rate, decay, momentum, epsilon) — confirm against the TF1 signature.
        self.optimizer = tf.train.RMSPropOptimizer(0.00025, 0.99, 0.0, 1e-6)
        self.grads_and_vars = self.optimizer.compute_gradients(self.loss)
        # Drop variables that received no gradient from this loss.
        self.grads_and_vars = [[grad, var] for grad, var in self.grads_and_vars if grad is not None]
        self.train_op = self.optimizer.apply_gradients(self.grads_and_vars, global_step=tf.contrib.framework.get_global_step())

class ValueEstimator():
  """Value network: shared convolutional trunk plus a single-output value head.

  Args:
    reuse: Passed to the "shared" variable scope; when true the trunk reuses
      existing variables and activation summaries are not added.
    trainable: When true, an RMSProp optimizer and train op are created.

  Attributes set on the instance: ``states`` and ``targets`` placeholders,
  ``logits`` (one value per sample), ``losses``/``loss``, ``predictions``,
  ``summaries`` (merged summary op) and, if trainable, ``optimizer``,
  ``grads_and_vars`` and ``train_op``.
  """
  def __init__(self, reuse=False, trainable=True):
    # Batch of uint8 inputs with shape [batch, 84, 84, 4].
    self.states = tf.placeholder(shape=[None, 84, 84, 4], dtype=tf.uint8, name="X")
    # Regression target for each sample.
    self.targets = tf.placeholder(shape=[None], dtype=tf.float32, name="y")
    # Scale the uint8 input to floats in [0, 1].
    X = tf.to_float(self.states) / 255.0
    with tf.variable_scope("shared", reuse=reuse):
      fc1  = build_shared_network(X, add_summaries=(not reuse))

    with tf.variable_scope("value_net"):
      self.logits = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=1, activation_fn=None)
      # Drop the trailing size-1 axis so logits has shape [batch].
      self.logits = tf.squeeze(self.logits, squeeze_dims=[1], name="logits")
      self.losses = tf.squared_difference(self.logits, self.targets)
      self.loss = tf.reduce_sum(self.losses, name="loss")
      self.predictions = {"logits": self.logits}
      prefix = tf.get_variable_scope().name
      # Use the op name (no ":0" output suffix) as the summary name, matching
      # how PolicyEstimator names its loss summary.
      tf.summary.scalar(self.loss.op.name, self.loss)
      tf.summary.scalar("{}/max_value".format(prefix), tf.reduce_max(self.logits))
      tf.summary.scalar("{}/min_value".format(prefix), tf.reduce_min(self.logits))
      tf.summary.scalar("{}/mean_value".format(prefix), tf.reduce_mean(self.logits))
      tf.summary.scalar("{}/reward_max".format(prefix), tf.reduce_max(self.targets))
      tf.summary.scalar("{}/reward_min".format(prefix), tf.reduce_min(self.targets))
      tf.summary.scalar("{}/reward_mean".format(prefix), tf.reduce_mean(self.targets))
      tf.summary.histogram("{}/reward_targets".format(prefix), self.targets)
      tf.summary.histogram("{}/values".format(prefix), self.logits)
      if trainable:
        # Positional args are presumably (learning_rate, decay, momentum, epsilon) — confirm against the TF1 signature.
        self.optimizer = tf.train.RMSPropOptimizer(0.00025, 0.99, 0.0, 1e-6)
        self.grads_and_vars = self.optimizer.compute_gradients(self.loss)
        # Drop variables that received no gradient from this loss.
        self.grads_and_vars = [[grad, var] for grad, var in self.grads_and_vars if grad is not None]
        self.train_op = self.optimizer.apply_gradients(self.grads_and_vars, global_step=tf.contrib.framework.get_global_step())

    # Merge every summary whose name contains the enclosing variable scope.
    # (A previous name-based filter was dead code: it was overwritten by this
    # one before being used.)
    var_scope_name = tf.get_variable_scope().name
    summary_ops = tf.get_collection(tf.GraphKeys.SUMMARIES)
    scoped_summaries = [s for s in summary_ops if var_scope_name in s.name]
    self.summaries = tf.summary.merge(scoped_summaries)

In [None]:
class PolicyMonitor(object):
    """Periodically evaluates a policy network, recording videos and summaries.

    Args:
        env: Environment to evaluate in; it is wrapped in a gym ``Monitor``.
        policy_net: Global policy estimator whose parameters are copied into a
            local evaluation copy before each episode.
        summary_writer: Writer used for evaluation summaries; its log directory
            also determines the video and checkpoint locations.
        saver: Optional saver used to write a checkpoint after each evaluation.
    """

    def __init__(self, env, policy_net, summary_writer, saver=None):
        # Videos are stored in "videos" next to the summary log directory.
        videos_path = os.path.join(summary_writer.get_logdir(), "../videos")
        self.video_dir = os.path.abspath(videos_path)
        # video_callable always returns True: every episode is recorded.
        self.env = Monitor(
            env, directory=self.video_dir, video_callable=lambda x: True, resume=True)
        self.global_policy_net = policy_net
        self.summary_writer = summary_writer
        self.saver = saver
        self.sp = StateProcessor()
        model_path = os.path.join(summary_writer.get_logdir(), "../checkpoints/model")
        self.checkpoint_path = os.path.abspath(model_path)

        try:
            os.makedirs(self.video_dir)
        except FileExistsError:
            pass

        # Local copy of the policy network used only for evaluation.
        with tf.variable_scope("policy_eval"):
            self.policy_net = PolicyEstimator(policy_net.num_outputs)

        # Op that copies the trainable "global" variables into "policy_eval".
        global_vars = tf.contrib.slim.get_variables(
            scope="global", collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        eval_vars = tf.contrib.slim.get_variables(
            scope="policy_eval", collection=tf.GraphKeys.TRAINABLE_VARIABLES)
        self.copy_params_op = make_copy_params_op(global_vars, eval_vars)

    def _policy_net_predict(self, state, sess):
        """Return the action probabilities of the local policy for one state."""
        predictions = sess.run(
            self.policy_net.predictions, {self.policy_net.states: [state]})
        return predictions["probs"][0]

    def eval_once(self, sess):
        """Play one episode with the current global parameters.

        Copies the global parameters, samples actions from the policy until
        the episode ends, writes the reward and length summaries, saves a
        checkpoint when a saver was given, and logs the result.

        Returns:
            Tuple ``(total_reward, episode_length)``.
        """
        with sess.as_default(), sess.graph.as_default():
            global_step, _ = sess.run(
                [tf.contrib.framework.get_global_step(), self.copy_params_op])

            first_frame = self.sp.process(self.env.reset())
            state = atari_helpers.atari_make_initial_state(first_frame)
            total_reward = 0.0
            episode_length = 0
            while True:
                action_probs = self._policy_net_predict(state, sess)
                action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
                frame, reward, done, _ = self.env.step(action)
                state = atari_helpers.atari_make_next_state(state, self.sp.process(frame))
                total_reward += reward
                episode_length += 1
                if done:
                    break

            episode_summary = tf.Summary()
            for tag, value in (("eval/total_reward", total_reward),
                               ("eval/episode_length", episode_length)):
                episode_summary.value.add(simple_value=value, tag=tag)
            self.summary_writer.add_summary(episode_summary, global_step)
            self.summary_writer.flush()

            if self.saver is not None:
                self.saver.save(sess, self.checkpoint_path)

            tf.logging.info("Eval results at step {}: total_reward {}, episode_length {}".format(global_step, total_reward, episode_length))

            return total_reward, episode_length

    def continuous_eval(self, eval_every, sess, coord):
        """Run ``eval_once`` every ``eval_every`` seconds until ``coord`` requests a stop.

        Returns quietly when TensorFlow raises ``CancelledError``.
        """
        try:
            while not coord.should_stop():
                self.eval_once(sess)
                time.sleep(eval_every)
        except tf.errors.CancelledError:
            return

In [None]:
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import dtqn.networks.drqn as drqn
from utils import torch_utils
import numpy as np
from enum import Enum

class SoftAttention(nn.Module):
    """Additive soft attention over the embedding dimension.

    Projects the input and the transposed hidden state, adds them, applies
    ``tanh`` and a final linear layer, and normalises with a softmax over
    dimension 2.
    """

    def __init__(self, embed_size):
        super().__init__()
        # Hidden-state projection (no bias), input projection, output projection.
        self.W = nn.Linear(in_features=embed_size, out_features=embed_size, bias=False)
        self.linear = nn.Linear(in_features=embed_size, out_features=embed_size)
        self.linear2 = nn.Linear(in_features=embed_size, out_features=embed_size)

    def forward(self, x, h):
        """Return attention weights for ``x`` conditioned on hidden state ``h``.

        ``h`` has its first two dimensions swapped before projection, so the
        result must broadcast against the projected ``x``.
        """
        hidden_proj = self.W(h.transpose(1, 0))
        combined = torch.tanh(self.linear(x) + hidden_proj)
        scores = self.linear2(combined)
        return F.softmax(scores, dim=2)

class DARQN(drqn.DRQN):
    """Recurrent Q-network that applies soft attention before every LSTM step.

    ``self.obs_embed``, ``self.lstm`` and ``self.ffn`` used in ``forward`` are
    not defined here; presumably they are created by ``drqn.DRQN`` (TODO
    confirm against the base class).

    NOTE(review): a second, unrelated class named ``DARQN`` is defined in a
    later cell of this notebook; in a shared kernel the later definition
    shadows this one.

    Args:
        input_shape: Forwarded to the base class.
        n_actions: Forwarded to the base class as ``num_actions``.
        embed_per_obs_dim: Forwarded to the base class.
        inner_embed: Forwarded to the base class; also the attention width and
            the last dimension of the zero hidden state.
        is_discrete_env: Forwarded to the base class.
        obs_vocab_size: Forwarded to the base class.
        batch_size: Batch dimension of the stored ``hidden_zeros`` template.
            May be omitted; the initial hidden state is sized from the input
            batch at call time.
        **kwargs: Forwarded to the base class.
    """

    def __init__(
        self,
        input_shape: int,
        n_actions: int,
        embed_per_obs_dim: int,
        inner_embed: int,
        is_discrete_env: bool,
        obs_vocab_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            input_shape=input_shape,
            num_actions=n_actions,
            embed_per_obs_dim=embed_per_obs_dim,
            inner_embed=inner_embed,
            is_discrete_env=is_discrete_env,
            obs_vocab_size=obs_vocab_size,
            **kwargs,
        )
        # ``batch_size`` defaults to None, which torch.zeros rejects; fall back
        # to a batch of one so construction works without it.
        template_batch = 1 if batch_size is None else batch_size
        self.hidden_zeros = nn.Parameter(
            torch.zeros(1, template_batch, inner_embed, dtype=torch.float32),
            requires_grad=False,
        )
        self.attention = SoftAttention(embed_size=inner_embed)
        self.apply(torch_utils.init_weights)

    def forward(
        self,
        x: torch.Tensor,
        hidden_states: Optional[tuple] = None,
        episode_lengths: Optional[int] = None,
    ):
        """Compute Q-values and the updated recurrent state.

        Args:
            x: Observations; embedded with ``self.obs_embed`` and then indexed
                as ``(batch, seq, embed)``.
            hidden_states: Optional ``(h, c)`` tuple. When given, a single
                attention + LSTM call is made over the whole of ``x``. When
                omitted, a zero state is created and the sequence is processed
                one step at a time.
            episode_lengths: Accepted but not used by this implementation.

        Returns:
            ``(q_values, hidden_states)``.
        """
        x = self.obs_embed(x)
        if hidden_states is not None:
            attention = self.attention(x, hidden_states[0])
            lstm_out, hidden_states = self.lstm(attention, hidden_states)
            q_values = self.ffn(lstm_out)
        else:
            # Size the zero state from the actual input batch (same dtype and
            # device as the stored template) instead of the constructor's
            # batch_size, so other batch sizes do not fail in the LSTM.
            batch = x.size(0)
            width = self.hidden_zeros.size(-1)
            hidden_states = (
                self.hidden_zeros.new_zeros(1, batch, width),
                self.hidden_zeros.new_zeros(1, batch, width),
            )
            step_q_values = []
            context_len = x.size(1)
            for i in range(context_len):
                # Attend with the most recent hidden state, one step at a time.
                attention = self.attention(x[:, i : i + 1, :], hidden_states[0])
                lstm_out, hidden_states = self.lstm(attention, hidden_states)
                step_q_values.append(self.ffn(lstm_out))
            q_values = torch.cat(step_q_values, dim=1)

        return q_values, hidden_states

class GRUGate(nn.Module):
    """GRU-style gate that merges a stream ``x`` with a submodule output ``y``.

    Requires the keyword argument ``embed_size``. Only ``w_z`` carries a bias,
    which is initialised to -2.
    """

    def __init__(self, **kwargs):
        super().__init__()
        embed_size = kwargs["embed_size"]
        # (attribute name, has bias) in creation order.
        layer_specs = (
            ("w_r", False),
            ("u_r", False),
            ("w_z", True),
            ("u_z", False),
            ("w_g", False),
            ("u_g", False),
        )
        for attr_name, has_bias in layer_specs:
            setattr(self, attr_name, nn.Linear(embed_size, embed_size, bias=has_bias))
        self.init_bias()

    def init_bias(self):
        """Set the update-gate bias to -2."""
        with torch.no_grad():
            self.w_z.bias.fill_(-2)  # This is the value set by GTrXL paper

    def forward(self, x, y):
        """Return ``(1 - z) * x + z * h`` for update gate ``z`` and candidate ``h``."""
        update = torch.sigmoid(self.w_z(y) + self.u_z(x))
        reset = torch.sigmoid(self.w_r(y) + self.u_r(x))
        candidate = torch.tanh(self.w_g(y) + self.u_g(reset * x))
        return (1.0 - update) * x + update * candidate

class ResGate(nn.Module):
    """Plain residual connection: the two inputs are summed."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored.
        super().__init__()

    def forward(self, x, y):
        """Return ``x + y``."""
        combined = x + y
        return combined

class PosEnum(Enum):
    """Names of the position-encoding variants; values are the string forms."""

    # Presumably paired with PositionEncoding.make_learned_position_encoding — TODO confirm.
    LEARNED = "learned"
    # Presumably paired with PositionEncoding.make_sinusoidal_position_encoding — TODO confirm.
    SIN = "sin"
    # Presumably paired with PositionEncoding.make_empty_position_encoding — TODO confirm.
    NONE = "none"

class PositionEncoding(nn.Module):
    """Module that stores a position-encoding tensor and returns it unchanged.

    Use the static ``make_*`` factories to build the sinusoidal, learned or
    all-zero variants.
    """

    def __init__(self, position_encoding: nn.Parameter):
        """Store ``position_encoding`` (the factories pass an ``nn.Parameter``)."""
        super().__init__()
        self.position_encoding = position_encoding

    def forward(self):
        """Return the stored encoding; takes no input."""
        return self.position_encoding

    # The return annotations below are strings: the class name is not bound
    # yet while the class body executes, and this cell does not enable
    # postponed annotation evaluation.
    @staticmethod
    def make_sinusoidal_position_encoding(
        context_len: int, embed_dim: int
    ) -> "PositionEncoding":
        """Build a frozen ``(1, context_len, embed_dim)`` sin/cos encoding.

        Even feature indices hold sines and odd indices hold cosines of
        ``position * 10000 ** (-2k / embed_dim)``.
        """
        position = torch.arange(context_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-np.log(10000.0) / embed_dim)
        )
        angles = position * div_term
        pos_encoding = torch.zeros(1, context_len, embed_dim)
        pos_encoding[0, :, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer cosine slot than sine slots;
        # trimming keeps the shapes aligned (no-op when embed_dim is even).
        pos_encoding[0, :, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        return PositionEncoding(nn.Parameter(pos_encoding, requires_grad=False))

    @staticmethod
    def make_learned_position_encoding(
        context_len: int, embed_dim: int
    ) -> "PositionEncoding":
        """Build a trainable ``(1, context_len, embed_dim)`` encoding initialised to zero."""
        return PositionEncoding(
            nn.Parameter(torch.zeros(1, context_len, embed_dim), requires_grad=True)
        )

    @staticmethod
    def make_empty_position_encoding(
        context_len: int, embed_dim: int
    ) -> "PositionEncoding":
        """Build a frozen all-zero ``(1, context_len, embed_dim)`` encoding."""
        return PositionEncoding(
            nn.Parameter(torch.zeros(1, context_len, embed_dim), requires_grad=False)
        )

class ObservationEmbeddingRepresentation(nn.Module):
    """Applies an embedding network to every step of a batched sequence.

    The first two input axes are merged before the wrapped network is called
    and restored afterwards. Use the static ``make_*`` factories to build the
    discrete, action, continuous or image variants.
    """

    def __init__(
        self,
        observation_embedding: nn.Module,
    ):
        super().__init__()
        self.observation_embedding = observation_embedding

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Embed ``obs`` and return a ``(batch, seq, embed)`` tensor.

        Args:
            obs: Tensor whose first two axes are batch and sequence.
        """
        batch, seq = obs.size(0), obs.size(1)
        # Merge batch and sequence so the embedding sees one sample per row.
        obs = torch.flatten(obs, start_dim=0, end_dim=1)
        obs_embed = self.observation_embedding(obs)
        obs_embed = obs_embed.reshape(batch, seq, obs_embed.size(-1))
        return obs_embed

    # Return annotations are strings: the class name is not bound yet while
    # the class body executes.
    @staticmethod
    def make_discrete_representation(
        vocab_sizes: int, obs_dim: int, embed_per_obs_dim: int, outer_embed_size: int
    ) -> "ObservationEmbeddingRepresentation":
        """Token embedding per observation feature, flattened and projected.

        Raises:
            AssertionError: If ``vocab_sizes`` or ``embed_per_obs_dim`` is not
                positive.
        """
        assert (
            vocab_sizes > 0
        ), "Discrete environments need to have a vocab size for the token embeddings"
        # The message promises "at least 1", so 1 itself must be accepted.
        assert (
            embed_per_obs_dim >= 1
        ), "Each observation feature needs at least 1 embed dim"

        embedding = nn.Sequential(
            nn.Embedding(vocab_sizes, embed_per_obs_dim),
            nn.Flatten(start_dim=-2),
            nn.Linear(embed_per_obs_dim * obs_dim, outer_embed_size),
        )
        return ObservationEmbeddingRepresentation(observation_embedding=embedding)

    @staticmethod
    def make_action_representation(
        num_actions: int,
        action_dim: int,
    ) -> "ObservationEmbeddingRepresentation":
        """Embedding lookup for action indices, with the last two axes flattened."""
        embed = nn.Sequential(
            nn.Embedding(num_actions, action_dim), nn.Flatten(start_dim=-2)
        )
        return ObservationEmbeddingRepresentation(observation_embedding=embed)

    @staticmethod
    def make_continuous_representation(
        obs_dim: int, outer_embed_size: int
    ) -> "ObservationEmbeddingRepresentation":
        """Single linear projection from ``obs_dim`` to ``outer_embed_size``."""
        embedding = nn.Linear(obs_dim, outer_embed_size)
        return ObservationEmbeddingRepresentation(observation_embedding=embedding)

    @staticmethod
    def make_image_representation(
        obs_dim: tuple, outer_embed_size: int
    ) -> "ObservationEmbeddingRepresentation":
        """Five-layer CNN followed by a linear projection.

        Args:
            obs_dim: ``(channels, height, width)``. Any other length is treated
                as single-channel, with the last two entries taken as height
                and width.
            outer_embed_size: Output width of the final linear layer.
        """
        if len(obs_dim) == 3:
            num_channels = obs_dim[0]
        else:
            num_channels = 1
        # Index from the end so a 2-tuple (height, width) does not raise
        # IndexError; identical to [1], [2] for a 3-tuple.
        # NOTE(review): Conv2d still needs a channel axis on the input at call
        # time for the single-channel case — confirm what callers pass.
        height, width = obs_dim[-2], obs_dim[-1]

        kernels = [3, 3, 3, 3, 3]
        paddings = [1, 1, 1, 1, 1]
        strides = [2, 1, 2, 1, 2]
        flattened_size = compute_flattened_size(
            height, width, kernels, paddings, strides
        )
        embedding = nn.Sequential(
            nn.Conv2d(
                num_channels,
                64,
                kernel_size=kernels[0],
                padding=paddings[0],
                stride=strides[0],
            ),
            nn.ReLU(True),
            nn.Conv2d(
                64, 64, kernel_size=kernels[1], padding=paddings[1], stride=strides[1]
            ),
            nn.ReLU(True),
            nn.Conv2d(
                64,
                64,
                kernel_size=kernels[2],
                padding=paddings[2],
                stride=strides[2],
            ),
            nn.ReLU(True),
            nn.Conv2d(
                64, 128, kernel_size=kernels[3], padding=paddings[3], stride=strides[3]
            ),
            nn.ReLU(True),
            # Take the last layer's geometry from the same lists used to size
            # the flattened output, so the two cannot drift apart.
            nn.Conv2d(
                128,
                128,
                kernel_size=kernels[4],
                padding=paddings[4],
                stride=strides[4],
            ),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(128 * flattened_size, outer_embed_size),
        )
        return ObservationEmbeddingRepresentation(observation_embedding=embedding)

def compute_flattened_size(
    height: int, width: int, kernels: list, paddings: list, strides: list
) -> int:
    """Return ``height * width`` after applying a stack of convolutions.

    Layer ``i`` uses ``kernels[i]``, ``paddings[i]`` and ``strides[i]``; both
    spatial sizes are updated with ``update_size`` for each layer.
    """
    for layer, kernel in enumerate(kernels):
        pad = paddings[layer]
        stride = strides[layer]
        height = update_size(height, kernel, pad, stride)
        width = update_size(width, kernel, pad, stride)
    area = height * width
    return int(area)

def update_size(component: int, kernel: int, padding: int, stride: int) -> int:
    """Return the output length of one spatial axis after a convolution.

    Computes ``floor((component - kernel + 2 * padding) / stride) + 1``.
    Integer floor division is used, so no ``math`` import is needed (none
    exists in this cell) and the result is exact for integer inputs.
    """
    return int((component - kernel + 2 * padding) // stride) + 1

class ActionEmbeddingRepresentation(nn.Module):
    """Embeds action indices and flattens the last two axes of the result."""

    def __init__(self, num_actions: int, action_dim: int):
        super().__init__()
        lookup = nn.Embedding(num_actions, action_dim)
        merge_last_two = nn.Flatten(start_dim=-2)
        self.embedding = nn.Sequential(lookup, merge_last_two)

    def forward(self, action: torch.Tensor):
        """Return the flattened embedding of ``action``."""
        embedded = self.embedding(action)
        return embedded

class TransformerLayer(nn.Module):
    """Transformer block with causal self-attention and post-submodule layer norm.

    Args:
        num_heads:  Number of heads to use for MultiHeadAttention.
        embed_size: The dimensionality of the layer.
        history_len:The maximum number of observations to take in.
        dropout:    Dropout percentage.
        attn_gate:  The combine layer after the attention submodule.
        mlp_gate:  The combine layer after the feedforward submodule.
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        history_len: int,
        dropout: float,
        attn_gate,
        mlp_gate,
    ):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(embed_size)
        self.layernorm2 = nn.LayerNorm(embed_size)

        self.attention = nn.MultiheadAttention(
            embed_dim=embed_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
            nn.Dropout(dropout),
        )
        self.attn_gate = attn_gate
        self.mlp_gate = mlp_gate
        # Attention weights of the most recent forward pass.
        self.alpha = None
        # Additive causal mask: -inf strictly above the diagonal, 0 elsewhere,
        # stored as a frozen parameter.
        self.attn_mask = nn.Parameter(
            torch.triu(torch.ones(history_len, history_len), diagonal=1),
            requires_grad=False,
        )
        self.attn_mask[self.attn_mask.bool()] = -float("inf")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention, gate, norm, feed-forward, gate, norm.

        Args:
            x: Batch-first input; ``x.size(1)`` is the sequence length and
                must not exceed ``history_len``.

        Returns:
            Tensor produced by the second layer norm.
        """
        attention, self.alpha = self.attention(
            x,
            x,
            x,
            # Crop the mask to the actual sequence length.
            attn_mask=self.attn_mask[: x.size(1), : x.size(1)],
            average_attn_weights=True,  # Only affects self.alpha for visualizations
        )
        x = self.attn_gate(x, F.relu(attention))
        x = self.layernorm1(x)
        ffn = self.ffn(x)
        x = self.mlp_gate(x, F.relu(ffn))
        x = self.layernorm2(x)
        return x

class TransformerIdentityLayer(TransformerLayer):
    """Variant of ``TransformerLayer`` that normalises submodule inputs.

    Layer norm is applied to what goes into the attention and feed-forward
    submodules; the gated outputs themselves are not normalised.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply norm, attention, gate, then norm, feed-forward, gate."""
        normed = self.layernorm1(x)
        seq_len = normed.size(1)
        attn_out, self.alpha = self.attention(
            normed,
            normed,
            normed,
            attn_mask=self.attn_mask[:seq_len, :seq_len],
            average_attn_weights=True,  # Only affects self.alpha for visualizations
        )
        x = self.attn_gate(x, F.relu(attn_out))
        ffn_out = self.ffn(self.layernorm2(x))
        x = self.mlp_gate(x, F.relu(ffn_out))
        return x

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Aggregator(nn.Module):
    """Collapses a batched tensor either by flattening or by averaging.

    Args:
        input_dim: Input width of the linear layer created in ``"mean"`` mode.
        output_dim: Output width of that linear layer.
        aggregator_type: ``"flatten"`` flattens everything after dim 0;
            ``"mean"`` averages over dim 1.

    Raises:
        ValueError: If ``aggregator_type`` is neither ``"flatten"`` nor
            ``"mean"``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        aggregator_type: str = "flatten",
    ):
        super().__init__()
        # Reject unknown modes up front; otherwise forward() would silently
        # return None.
        if aggregator_type not in ("flatten", "mean"):
            raise ValueError(
                f"Unsupported aggregator_type {aggregator_type!r}; expected 'flatten' or 'mean'"
            )
        self.aggregator_type = aggregator_type

        if self.aggregator_type == "flatten":
            self.aggregator = nn.Flatten(start_dim=1)
        else:
            # NOTE(review): this layer is created but never applied in
            # forward(); kept so existing state dicts still load. Confirm
            # whether a projection after the mean was intended.
            self.aggregator = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the flattened input or its mean over dim 1, per the mode."""
        if self.aggregator_type == "flatten":
            return self.aggregator(x)
        if self.aggregator_type == "mean":
            return torch.mean(x, dim=1)
        # Only reachable if the attribute was changed after construction.
        raise ValueError(f"Unsupported aggregator_type {self.aggregator_type!r}")

class ConcatenationModule(nn.Module):
    """Joins any number of tensors along one fixed dimension."""

    def __init__(self, dim: int = 1):
        super().__init__()
        # Dimension handed to torch.cat on every call.
        self.dim = dim

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        """Return all positional inputs concatenated along ``self.dim``."""
        joined = torch.cat(inputs, dim=self.dim)
        return joined

class QValueHead(nn.Module):
    """Linear layer mapping features to one value per action."""

    def __init__(self, input_dim: int, num_actions: int):
        super().__init__()
        self.output_dim = num_actions
        self.linear = nn.Linear(in_features=input_dim, out_features=self.output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of ``x``."""
        q_values = self.linear(x)
        return q_values

class SoftmaxHead(nn.Module):
    """Normalises its input with a softmax over the last dimension."""

    def __init__(self, input_dim: int):
        # ``input_dim`` is accepted but not used; the module has no parameters.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(x)`` taken over dim -1."""
        probabilities = F.softmax(x, dim=-1)
        return probabilities

class DARQN(nn.Module):
    """Pipeline that chains injected sub-modules into a softmax output.

    Observations are embedded and aggregated, concatenated with the embedded
    actions along dim 1, passed through the transformer and the Q-value head,
    and finally through the softmax head.

    NOTE(review): an earlier cell of this notebook defines another class with
    the same name (deriving from ``drqn.DRQN``); in a shared kernel this
    definition replaces it.
    """

    def __init__(
        self,
        observation_representation: nn.Module,
        action_representation: nn.Module,
        aggregator: nn.Module,
        transformer: nn.Module,
        q_head: nn.Module,
        softmax_head: nn.Module,
    ):
        super().__init__()
        # Embedding stages.
        self.observation_representation = observation_representation
        self.action_representation = action_representation
        # Sequence processing stages.
        self.aggregator = aggregator
        self.transformer = transformer
        # Output stages.
        self.q_head = q_head
        self.softmax_head = softmax_head

    def forward(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return the softmax head's output for the given inputs.

        ``hidden_states`` is accepted but not used by this implementation.
        """
        obs_features = self.observation_representation(observations)
        action_features = self.action_representation(actions)
        pooled_obs = self.aggregator(obs_features)
        # Join observation and action features along dim 1.
        joined = torch.cat((pooled_obs, action_features), dim=1)
        q_values = self.q_head(self.transformer(joined))
        return self.softmax_head(q_values)
# Example wiring: build each component and assemble a DARQN instance.
observation_representation = ObservationEmbeddingRepresentation.make_image_representation(
    obs_dim=(3, 84, 84), outer_embed_size=512
)
# ActionEmbeddingRepresentation has no make_action_representation factory
# (that one lives on ObservationEmbeddingRepresentation), so construct it
# directly.
action_representation = ActionEmbeddingRepresentation(num_actions=10, action_dim=32)
aggregator = Aggregator(input_dim=512 + 32, output_dim=512, aggregator_type="mean")
transformer = TransformerIdentityLayer(
    num_heads=8, embed_size=512, history_len=4, dropout=0.1, attn_gate=ResGate(), mlp_gate=ResGate()
)
q_head = QValueHead(input_dim=512, num_actions=10)
softmax_head = SoftmaxHead(input_dim=10)

darqn = DARQN(
    observation_representation=observation_representation,
    action_representation=action_representation,
    aggregator=aggregator,
    transformer=transformer,
    q_head=q_head,
    softmax_head=softmax_head,
)

In [None]:
from typing import List, Dict, Any, Tuple, Optional
from collections import namedtuple
import torch.nn.functional as F
import torch
import numpy as np
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('dt')
class DTPolicy(Policy):

    # Default configuration entries for this policy.
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dt',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether use priority(priority sample, IS weight, update priority)
        priority=False,
        # (int) Default observation shape.
        obs_shape=4,
        # (int) Default action shape.
        action_shape=2,
        rtg_scale=1000,  # normalize returns to go
        max_eval_ep_len=1000,  # max len of one episode
        batch_size=64,  # training batch size
        wt_decay=1e-4,  # decay weight in optimizer
        warmup_steps=10000,  # steps for learning rate warmup
        context_len=20,  # length of transformer input
        learning_rate=1e-4,
    )

    def default_model_settings(self) -> Tuple[str, List[str]]:
        """Return the default model name and a one-element list with its dotted module path."""
        model_name = 'dt'
        module_paths = ['ding.model.template.dt']
        return model_name, module_paths

    def init_learning(self) -> None:
        """Prepare the state used by ``forward_learning``.

        Copies hyper-parameters from ``self._cfg`` onto the instance, selects
        the environment mode from the presence of ``'state_mean'`` in the
        config, and builds the optimizer and a linear warm-up LR scheduler.
        """
        self.rtg_scale = self._cfg.rtg_scale  # normalize returns to go
        self.rtg_target = self._cfg.rtg_target  # max target reward_to_go
        self.max_eval_ep_len = self._cfg.max_eval_ep_len  # max len of one episode

        lr = self._cfg.learning_rate  # learning rate
        wt_decay = self._cfg.wt_decay  # weight decay
        warmup_steps = self._cfg.warmup_steps  # warmup steps for lr scheduler

        self.clip_grad_norm_p = self._cfg.clip_grad_norm_p
        self.context_len = self._cfg.model.context_len  # K in decision transformer

        # Same config keys that init_evaluation reads.
        self.state_dim = self._cfg.model.state_dim
        self.act_dim = self._cfg.model.act_dim

        self._learn_model = self._model
        # A config without 'state_mean' selects the Atari code path.
        self._atari_env = 'state_mean' not in self._cfg
        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg

        if self._atari_env:
            # The model supplies its own optimizer on the Atari path.
            self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
        else:
            self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)

        # Linear warm-up: the LR factor grows to 1 over ``warmup_steps`` steps.
        self._scheduler = torch.optim.lr_scheduler.LambdaLR(
            self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
        )

        self.max_env_score = -1.0

    def forward_learning(self, data: List[torch.Tensor]) -> Dict[str, Any]:
        """Run one optimisation step on a batch and report the loss.

        Args:
            data: Sequence unpacked as ``(timesteps, states, actions,
                returns_to_go, traj_mask)``.

        Returns:
            Dict with ``'cur_lr'``, ``'action_loss'`` and ``'total_loss'``
            (the last two hold the same value).
        """
        self._learn_model.train()

        timesteps, states, actions, returns_to_go, traj_mask = data

        # A 2-dim returns-to-go tensor gets a trailing axis so it is 3-dim.
        if returns_to_go.dim() == 2:
            returns_to_go = returns_to_go.unsqueeze(-1)

        if self._basic_discrete_env:
            actions = actions.to(torch.long).squeeze(-1)
        action_target = actions.clone().detach().to(self._device)

        model_inputs = dict(timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go)
        if self._atari_env:
            _, action_preds, _ = self._learn_model.forward(**model_inputs, tar=1)
            action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
        else:
            _, action_preds, _ = self._learn_model.forward(**model_inputs)

            # only consider non padded elements
            valid = traj_mask.view(-1, ) > 0
            action_preds = action_preds.view(-1, self.act_dim)[valid]

            if self._cfg.model.continuous:
                action_target = action_target.view(-1, self.act_dim)[valid]
                action_loss = F.mse_loss(action_preds, action_target)
            else:
                action_target = action_target.view(-1)[valid]
                action_loss = F.cross_entropy(action_preds, action_target)

        self._optimizer.zero_grad()
        action_loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p)
        self._optimizer.step()
        self._scheduler.step()

        loss_value = action_loss.detach().cpu().item()
        current_lr = self._optimizer.state_dict()['param_groups'][0]['lr']
        return {
            'cur_lr': current_lr,
            'action_loss': loss_value,
            'total_loss': loss_value,
        }

    def init_evaluation(self) -> None:
        """Allocate the per-environment buffers used during evaluation.

        Copies settings from ``self._cfg`` onto the instance, resets the step
        counters and creates zero-filled action, state and reward-to-go
        buffers of length ``max_eval_ep_len`` for every evaluator environment.
        """
        cfg = self._cfg
        self._eval_model = self._model
        # init data
        self._device = torch.device(self._device)
        self.rtg_scale = cfg.rtg_scale  # normalize returns to go
        self.rtg_target = cfg.rtg_target  # max target reward_to_go
        self.state_dim = cfg.model.state_dim
        self.act_dim = cfg.model.act_dim
        self.eval_batch_size = cfg.evaluator_env_num
        self.max_eval_ep_len = cfg.max_eval_ep_len
        self.context_len = cfg.model.context_len  # K in decision transformer

        num_envs = self.eval_batch_size
        horizon = self.max_eval_ep_len
        device = self._device

        self.t = [0] * num_envs
        # Continuous actions are float vectors; discrete ones a single long index.
        if cfg.model.continuous:
            action_width, action_dtype = self.act_dim, torch.float32
        else:
            action_width, action_dtype = 1, torch.long
        self.actions = torch.zeros((num_envs, horizon, action_width), dtype=action_dtype, device=device)

        # A config without 'state_mean' selects the Atari code path.
        has_state_stats = 'state_mean' in cfg
        self._atari_env = not has_state_stats
        self._basic_discrete_env = not cfg.model.continuous and has_state_stats
        if self._atari_env:
            self.states = torch.zeros(
                (num_envs, horizon) + tuple(self.state_dim), dtype=torch.float32, device=device
            )
            # The Atari path keeps the target return unscaled.
            self.running_rtg = [self.rtg_target] * num_envs
        else:
            scaled_target = self.rtg_target / self.rtg_scale
            self.running_rtg = [scaled_target] * num_envs
            self.states = torch.zeros((num_envs, horizon, self.state_dim), dtype=torch.float32, device=device)
            # Statistics used to normalise observations in forward_evaluation.
            self.state_mean = torch.from_numpy(np.array(cfg.state_mean)).to(device)
            self.state_std = torch.from_numpy(np.array(cfg.state_std)).to(device)
        self.timesteps = torch.arange(horizon).repeat(num_envs, 1).to(device)
        self.rewards_to_go = torch.zeros((num_envs, horizon, 1), dtype=torch.float32, device=device)

    def forward_evaluation(self, data: Dict[int, Any]) -> Dict[int, Any]:
        """Choose one action for every environment id in ``data``.

        Each new observation and reward is written into the per-environment
        history, a window of ``context_len`` steps is passed to the model, and
        the prediction at the last window position becomes the action.

        Args:
            data: Maps an environment id to a dict with ``'obs'`` and
                ``'reward'`` tensors.

        Returns:
            Dict mapping each environment id to its decollated
            ``{'action': ...}`` entry.
        """
        # save and forward
        data_id = list(data.keys())

        self._eval_model.eval()
        with torch.no_grad():
            if self._atari_env:
                states = torch.zeros(
                    (
                        self.eval_batch_size,
                        self.context_len,
                    ) + tuple(self.state_dim),
                    dtype=torch.float32,
                    device=self._device
                )
                timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device)
            else:
                states = torch.zeros(
                    (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device
                )
                timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device)
            if not self._cfg.model.continuous:
                actions = torch.zeros(
                    (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device
                )
            else:
                actions = torch.zeros(
                    (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device
                )
            rewards_to_go = torch.zeros(
                (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
            )
            for i in data_id:
                # Record the newest observation (normalised on the non-Atari path).
                if self._atari_env:
                    self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
                else:
                    self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
                self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
                self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

                # Window of history fed to the model. The boundary is strict:
                # at t == context_len the leading window [0, context_len) no
                # longer contains index t, so the sliding window must be used
                # or the observation just written would be left out.
                if self.t[i] < self.context_len:
                    window = slice(0, self.context_len)
                else:
                    window = slice(self.t[i] - self.context_len + 1, self.t[i] + 1)

                if self._atari_env:
                    # The Atari path passes a single, capped timestep per environment.
                    timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                        (1, 1), dtype=torch.int64
                    ).to(self._device)
                else:
                    timesteps[i] = self.timesteps[i, window]
                states[i] = self.states[i, window]
                actions[i] = self.actions[i, window]
                rewards_to_go[i] = self.rewards_to_go[i, window]
            if self._basic_discrete_env:
                actions = actions.squeeze(-1)
            _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
            del timesteps, states, actions, rewards_to_go

            # NOTE(review): the last window position is read even while
            # t < context_len - 1, when that position still holds zero
            # padding — confirm this is what the model expects.
            logits = act_preds[:, -1, :]
            if not self._cfg.model.continuous:
                if self._atari_env:
                    # Atari path: sample from the predicted distribution.
                    probs = F.softmax(logits, dim=-1)
                    act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
                    for i in data_id:
                        act[i] = torch.multinomial(probs[i], num_samples=1)
                else:
                    act = torch.argmax(logits, axis=1).unsqueeze(1)
            else:
                act = logits
            for i in data_id:
                self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
                self.t[i] += 1

        if self._cuda:
            act = to_device(act, 'cpu')
        output = {'action': act}
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def reset_evaluation(self, data_id: Optional[List[int]] = None) -> None:
        """Clear evaluation history for all environments or for selected ones.

        Args:
            data_id: Environment ids to reset. ``None`` re-creates every
                buffer; otherwise only the listed rows are zeroed and their
                step counters set back to 0.
        """
        if data_id is None:
            num_envs = self.eval_batch_size
            self.t = [0] * num_envs
            self.timesteps = torch.arange(self.max_eval_ep_len).repeat(num_envs, 1).to(self._device)
            if self._cfg.model.continuous:
                self.actions = torch.zeros(
                    (num_envs, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                )
            else:
                self.actions = torch.zeros(
                    (num_envs, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
                )
            if self._atari_env:
                self.states = torch.zeros(
                    (num_envs, self.max_eval_ep_len) + tuple(self.state_dim),
                    dtype=torch.float32,
                    device=self._device
                )
                # The Atari path keeps the target return unscaled.
                self.running_rtg = [self.rtg_target] * num_envs
            else:
                self.states = torch.zeros(
                    (num_envs, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
                )
                scaled_target = self.rtg_target / self.rtg_scale
                self.running_rtg = [scaled_target] * num_envs

            self.rewards_to_go = torch.zeros(
                (num_envs, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
            )
            return

        for env_id in data_id:
            self.t[env_id] = 0
            if self._cfg.model.continuous:
                self.actions[env_id] = torch.zeros(
                    (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                )
            else:
                self.actions[env_id] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device)
            if self._atari_env:
                self.states[env_id] = torch.zeros(
                    (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
                )
                self.running_rtg[env_id] = self.rtg_target
            else:
                self.states[env_id] = torch.zeros(
                    (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
                )
                self.running_rtg[env_id] = self.rtg_target / self.rtg_scale
                # Timesteps are refreshed only on the non-Atari path.
                self.timesteps[env_id] = torch.arange(self.max_eval_ep_len).to(self._device)
            self.rewards_to_go[env_id] = torch.zeros(
                (self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
            )