In [1]:
import torch as th
import numpy as np
import laserhockey.hockey_env as h_env

from network import Network
from own_env import OwnEnv
from action_selection import ActionSelection
from utils import ACTIONS_T

In [2]:
env = OwnEnv()
opponent = h_env.BasicOpponent(weak=True)
action_selection = ActionSelection(0.975, 5)
net = Network().eval()
net.load_state_dict(th.load('checkpoints/agent_2500.pth'))

<All keys matched successfully>

In [3]:
n_actions = len(ACTIONS_T)
gamma = 0.975
n_best=10
n_times = 2

def exploitation_v3(net, obs):
        obs_in = th.from_numpy(obs)[None].float()
        latent_states, state_values, policy_logits = net.initial_inference(obs_in)
        search_depths = th.FloatTensor([0])
        value_prefixes = th.FloatTensor([0])
        action_indices = th.arange(n_actions)

        for i in range(n_times):
            indices = th.argsort(state_values, descending=True)
            latent_states = latent_states[indices]  # [3, 64]
            state_values = state_values[indices]  # [3]
            search_depths = search_depths[indices]
            value_prefixes = value_prefixes[indices]
            if i != 0:
                action_indices = action_indices[indices]

            latent_best = latent_states[:n_best, None, None].tile(1, n_actions, n_actions,
                                                                  1)  # [n_best, n_actions, n_actions, obs_dim]
            actions1 = ACTIONS_T[None, :, None].tile(latent_best.shape[0], 1, n_actions,
                                                     1)  # [n_best, n_actions, n_actions, act_dim]
            actions2 = ACTIONS_T[None, None, :].tile(latent_best.shape[0], n_actions, 1,
                                                     1)  # [n_best, n_actions, n_actions, act_dim]
            search_depths_best = search_depths[:n_best, None, None].tile(1, n_actions,
                                                                         n_actions)  # [n_best, n_actions, n_actions]
            value_prefixes_best = value_prefixes[:n_best, None, None].tile(1, n_actions,
                                                                           n_actions)  # [n_best, n_actions, n_actions]

            latent_out, rewards, _, next_state_values, _ = net.recurrent_inference(latent_best.flatten(end_dim=2),
                                                                                   actions1.flatten(end_dim=2),
                                                                                   actions2.flatten(end_dim=2))

            latent_out = latent_out.reshape(latent_best.shape)  # [n_best, n_actions, n_actions, obs_dim]
            rewards = rewards.reshape(latent_best.shape[:3])  # [n_best, n_actions, n_actions]
            next_state_values = next_state_values.reshape(latent_best.shape[:3])  # [n_best, n_actions, n_actions]
            value_prefixes_best = value_prefixes_best + rewards * (
                        gamma ** search_depths_best)  # [n_best, n_actions, n_actions]
            state_values_best = value_prefixes_best + next_state_values * (
                        gamma ** (search_depths_best + 1))  # [n_best, n_actions, n_actions]

            n_mult = latent_out.shape[0] * latent_out.shape[1]
            state_values_append, min_indices = state_values_best.min(dim=2)  # [3, 25]
            latent_append = latent_out.flatten(end_dim=1)[th.arange(n_mult), min_indices.flatten()].reshape(
                (latent_out.shape[0], latent_out.shape[1], -1))
            value_prefixes_append = value_prefixes_best.flatten(end_dim=1)[
                th.arange(n_mult), min_indices.flatten()].reshape((latent_out.shape[0], latent_out.shape[1]))
            search_depths_append = search_depths[:n_best, None].tile(1, n_actions) + 1

            state_values = th.concat([state_values[n_best:], state_values_append.flatten()])
            value_prefixes = th.concat([value_prefixes[n_best:], value_prefixes_append.flatten()])
            latent_states = th.concat([latent_states[n_best:], latent_append.flatten(end_dim=1)])
            search_depths = th.concat([search_depths[n_best:], search_depths_append.flatten()])

            if i != 0:
                action_indices = th.concat(
                    [action_indices[n_best:], action_indices[:n_best, None].tile(1, n_actions).flatten()])

        idx = th.argmax(state_values)
        max_action_idx = action_indices[idx].item()

        policy_distr = state_values + 1.01
        policy_distr /= policy_distr.sum()
        policy_distr = policy_distr ** 4

        return max_action_idx, policy_distr

In [70]:
import torch as th

from utils import ACTIONS_T

class Node():
    def __init__(self, Q, latent_state):
        self.N = th.zeros(len(ACTIONS_T), dtype=th.float)
        self.Q = Q
        self.R = th.zeros(len(ACTIONS_T), dtype=th.float)
        self.S = [None] * len(ACTIONS_T)

        self.latent_state = latent_state

class TreeBuilder():

    def __init__(self, gamma, n_simulations):
        self.gamma = gamma
        self.n_simulations = n_simulations
        self.n_actions = len(ACTIONS_T)

    def build_tree(self, obs, net, evaluation=True):
        inp_obs = th.from_numpy(obs)[None].float()

        latent_state, _, q_values = net.initial_inference(inp_obs)
        root_node = Node(Q=q_values[0], latent_state=latent_state)

        if not evaluation:
            root_node.Q = root_node.Q + (th.rand(self.n_actions)*2-1)*0.1 # add random noise in [-0.25, 0.25]

        for i in range(self.n_simulations):
            l_nodes, l_action_indices = self.selection(root_node)
            last_node, last_action_idx = l_nodes[-1], l_action_indices[-1]
            value = self.expansion(net, last_node, last_action_idx)
            self.backup(l_nodes, l_action_indices, value)
            # print(root_node.Q)

        return root_node

    def selection(self, root_node):
        l_nodes = [root_node]
        l_action_indices = []

        while True:
            node = l_nodes[-1]

            Q_ = (node.Q + 1)/2
            n_sum = th.maximum(th.FloatTensor([1]), node.N.sum())
            ucb = Q_ + n_sum**0.8 / (1 + node.N) * 0.15

            action_idx = ucb.argmax()
            l_action_indices.append(action_idx)

            if node.S[action_idx] is None:
                break
            else:
                l_nodes.append(node.S[action_idx])

        return l_nodes, l_action_indices

    def expansion(self, net, node, action_idx):
        action_1 = ACTIONS_T[action_idx][None].tile(self.n_actions, 1) # [1, 4]
        action_2 = ACTIONS_T                                        # [25, 4]
        latent_inp = node.latent_state.tile(self.n_actions, 1)  # [1, 64]
        next_latent_state, rewards, rewards_logits, next_state_values, q_values = net.recurrent_inference(latent_inp, action_1, action_2)

        state_values = rewards + self.gamma * next_state_values
        min_idx = state_values.argmin()
        next_latent_state = next_latent_state[min_idx][None]
        reward = rewards[min_idx]
        value = next_state_values[min_idx]

        child_node = Node(q_values[0], next_latent_state)
        node.S[action_idx] = child_node
        node.R[action_idx] = reward

        return value

    def backup(self, l_nodes, l_action_indices, value):
        G = value
        for node, action_idx in zip(reversed(l_nodes), reversed(l_action_indices)):
            G = node.R[action_idx] + self.gamma * G

            node.Q[action_idx] = (node.N[action_idx] * node.Q[action_idx] + G) / (node.N[action_idx] + 1)
            node.Q[action_idx] = node.Q[action_idx].clip(-1, 1)
            node.N[action_idx] += 1

In [5]:
tree_builder = TreeBuilder(0.99, 15)
def exploitation_tree(net, obs):
    net.eval()

    with th.no_grad():
        root_node = tree_builder.build_tree(obs, net, evaluation=True)

        action_idx_exploitation = root_node.Q.argmax()
        policy_distr = root_node.Q
        # print(root_node.Q, (root_node.N==15).any())
        sample_distr = policy_distr**4

        return action_idx_exploitation, policy_distr, sample_distr

In [6]:
obs1, _ = env.reset()
obs2 = env.obs_agent_two()
l_obs = [obs1]

while True:
    env.render()
    max_action_idx, policy_distr,_ = exploitation_tree(net, obs1)
    action1 = ACTIONS_T[max_action_idx]
    action2 = opponent.act(obs2)

    obs1, rew, done, trunc, info = env.step(np.hstack([action1, action2]))
    obs2 = env.obs_agent_two()
    l_obs.append(obs1)

    if trunc:
        break

In [82]:
tree_builder = TreeBuilder(0.99, 5)
root_node = tree_builder.build_tree(l_obs[2], net, evaluation=True)

print(root_node.N)
print(root_node.Q)
print((root_node.N!=0)*root_node.Q)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 5., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.])
tensor([-0.9810, -0.9747, -0.9858, -0.8817, -0.8749, -0.8916, -0.8944, -0.8231,
        -0.9881, -0.9603, -0.9732, -0.8391, -0.8974, -0.8857, -0.7623, -0.0373,
        -0.9550, -0.9709, -0.9668, -0.8289, -0.8865, -0.8856, -0.8970, -0.8203,
        -0.9755], grad_fn=<AsStridedBackward0>)
tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0373,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000], grad_fn=<MulBackward0>)


In [77]:
tree_builder = TreeBuilder(0.99, 0)
root_node = tree_builder.build_tree(l_obs[2], net, evaluation=True)

print(root_node.N)
print(root_node.Q)
print((root_node.N!=0)*root_node.Q)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.])
tensor([-0.9810, -0.9747, -0.9858, -0.8817, -0.8749, -0.8916, -0.8944, -0.8231,
        -0.9881, -0.9603, -0.9732, -0.8391, -0.8974, -0.8857, -0.7623, -0.3115,
        -0.9550, -0.9709, -0.9668, -0.8289, -0.8865, -0.8856, -0.8970, -0.8203,
        -0.9755], grad_fn=<SelectBackward0>)
tensor([-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.,
        -0.], grad_fn=<MulBackward0>)


In [8]:
# env.close()