In [7]:
import torch as th
import torch.nn.functional as F
import numpy as np
import laserhockey.hockey_env as h_env

from own_env import OwnEnv
from action_selection import ActionSelection
from utils import ACTIONS_T

In [2]:
class ResidualBlock(th.nn.Module):
    def __init__(self, inp_out_dim, hidden_dim=128):
        super(ResidualBlock, self).__init__()

        self.inp_out_dim = inp_out_dim

        self.model = th.nn.Sequential(
            th.nn.Linear(inp_out_dim, hidden_dim),
            th.nn.ReLU(),
            th.nn.Linear(hidden_dim, inp_out_dim)
        )

    def forward(self, x):
        out = self.model(x)

        return out + x


class Network(th.nn.Module):
    def __init__(self, obs_dim=18, action_dim=4):
        super(Network, self).__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # Dynamic network for predicting next observation and reward
        self.model_repres = th.nn.Sequential(
            th.nn.Linear(obs_dim, 128),
            # th.nn.BatchNorm1d(128),
            th.nn.ReLU(),
            th.nn.Linear(128, 64),
            th.nn.Tanh()
        )

        self.model_dynamic_linear = th.nn.Sequential(
            th.nn.Linear(64 + 2 * action_dim, 64)
        )
        self.model_dynamic_res_block = ResidualBlock(inp_out_dim=64, hidden_dim=128)

        self.model_reward = th.nn.Sequential(
            th.nn.Linear(64 + 2 * action_dim, 128),
            th.nn.ReLU(),
            th.nn.Linear(128, 3)
        )

        self.model_state_value = th.nn.Sequential(
            th.nn.Linear(64, 128),
            th.nn.ReLU(),
            th.nn.Linear(128, 1+ACTIONS_T.shape[0]),
            th.nn.Tanh()
        )

        self.observation_mean = th.FloatTensor([[-2.07, 0, 0, 0, 0, 0, 2.07, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
        self.observation_std = th.FloatTensor([[1.57, 2.91, 1.04, 4, 4, 6, 1.57, 2.91, 1.04, 4, 4, 6, 3.7, 3, 12, 12, 15, 15]])

    def forward(self, x):
        raise NotImplementedError('This function should not be used!')

    def initial_inference(self, obs):
        obs_in = (obs - self.observation_mean) / self.observation_std

        latent_state = self.model_repres(obs_in)
        out = self.model_state_value(latent_state)
        state_values, policy_logits = out[..., 0], out[..., 1:]

        return latent_state, state_values, policy_logits

    def recurrent_inference(self, latent_state, action_1, action_2):
        next_latent_state = self.forward_dynamic(latent_state, action_1, action_2)

        out = self.model_state_value(next_latent_state)
        next_state_values, policy_logits = out[..., 0], out[..., 1:]

        net_inp = th.concat([latent_state, action_1, action_2], dim=1)
        rewards_logits = self.model_reward(net_inp)
        rewards = th.FloatTensor([-1, 0, 1])[rewards_logits.argmax(dim=-1)]

        return next_latent_state, rewards, rewards_logits, next_state_values, policy_logits

    def forward_dynamic(self, latent_state, action_1, action_2):
        net_inp = th.concat([latent_state, action_1, action_2], dim=1)
        state = self.model_dynamic_linear(net_inp)
        state = F.relu(state + latent_state)
        state = self.model_dynamic_res_block(state)

        return F.tanh(state)


In [446]:
env = OwnEnv()
opponent = h_env.BasicOpponent(weak=True)
net = Network().eval()
net.load_state_dict(th.load('checkpoints/agent_1250000.pth'))
action_selection = ActionSelection(0.95, 5)

In [490]:
N_BEST = 10
N_TIMES = 2
gamma = 0.95

n_actions = ACTIONS_T.shape[0]

class Node():
    def __init__(self, latent_state, state_value, policy):
        self.state_value = state_value
        self.latent_state = latent_state
        self.R = th.zeros(n_actions, dtype=th.float)
        self.S = n_actions * [None]
        self.N = th.zeros(n_actions, dtype=th.float)
        self.Q = th.zeros(n_actions, dtype=th.float)
        self.P = policy        

class TreeBuilder():
    def __init__(self, n_simulations):
        self.n_simulations = n_simulations

    def build_tree(self, obs, net):
        obs_in = th.from_numpy(obs)[None].float()
        latent_state, state_value, policy_logits = net.initial_inference(obs_in)
        policy = th.softmax(policy_logits[0], dim=0)
        
        root_node = Node(latent_state, state_value[0], policy)

        for _ in range(self.n_simulations):
            l_nodes, l_action_indices = self.selection(root_node)
            last_node, last_action_idx = l_nodes[-1], l_action_indices[-1]
            self.expansion(net, last_node, last_action_idx)
            self.backup()

        raise NotImplementedError()
        return root_node

    def selection(self, root_node):
        l_nodes = [root_node]
        l_action_indices = []
        
        while True:
            node = l_nodes[-1]
            n_sum = th.maximum(node.N.sum(), th.FloatTensor([1]))
            
            ucb = node.Q + 1.25 * node.P * th.sqrt(n_sum) / (1 + node.N)
            action_idx = ucb.argmax()
            l_action_indices.append(action_idx)

            if node.S[action_idx] is None:
                break
            else:
                l_nodes.append(node.S[action_idx])
        
        return l_nodes, l_action_indices

    def expansion(self, net, node, action_idx):
        action_1 = ACTIONS_T[action_idx][None].tile(n_actions, 1)
        action_2 = ACTIONS_T
        latent_inp = node.latent_state.tile(n_actions, 1)
        
        next_latent_state, rewards, _, next_state_values, policy_logits = net.recurrent_inference(latent_inp, action_1, action_2)
        
        min_idx = next_state_values.argmin()
        latent_state = next_latent_state[min_idx][None]
        value = next_state_values[min_idx]
        policy = th.softmax(policy_logits[min_idx], dim=0)

        child_node = Node(latent_state, value, policy)
        
        node.R[action_idx] = rewards[min_idx]
        node.S[action_idx] = child_node

        return value

        
    def backup(self):
        raise NotImplementedError()
        

tree_builder = TreeBuilder(50)
def exploitation_v4(net, obs):
    net.eval()

    with th.no_grad():
        root_node = tree_builder.build_tree(obs, net)
        
        return max_action_idx, policy_distr

In [491]:
%%time
obs1, _ = env.reset()
obs2 = env.obs_agent_two()
l_obs = [obs1]

im = 0
while True:
    # env.render()
    # max_action_idx, policy_distr = action_selection.exploitation_v2(net, obs1)
    # max_action_idx, policy_distr = exploitation_v3(net, obs1)
    max_action_idx, policy_distr = exploitation_v4(net, obs1)
    action1 = ACTIONS_T[max_action_idx]
    action2 = opponent.act(obs2)

    obs1, rew, done, trunc, info = env.step(np.hstack([action1, action2]))
    obs2 = env.obs_agent_two()
    l_obs.append(obs1)

    im += 1

    if trunc:
        break

print(im)

torch.Size([25])


NotImplementedError: 

In [430]:
obs = l_obs[0]

max_action_idx, policy_distr, latent_states, search_depths, value_prefixes, action_indices, state_values = exploitation_v2(net, obs)

print(latent_states.sum(dim=1))
# print(state_values)

tensor([10.4211, 10.5620, 10.9890, 11.1034, 10.6201, 10.0202, 10.1682, 10.2735,
        10.3625, 10.5093, 11.1403, 10.9367, 10.9626, 10.4226,  9.9553, 10.8211,
        11.1367, 10.4718, 10.7605, 10.9001, 10.9816, 10.4303,  9.9579, 10.1462,
        10.4361], grad_fn=<SumBackward1>)


In [431]:
obs_in = th.from_numpy(obs)[None].float()
latent_states, _, _ = net.initial_inference(obs_in)
# for j in range(25):
#     latent_out, rewards, _, next_state_values, _ = net.recurrent_inference(latent_states, ACTIONS_T[1][None], ACTIONS_T[16][None])


#     print(latent_out.sum()==)
t = th.empty((25, 25))
for i in range(n_actions):
    min_value = th.FloatTensor([1e9])
    min_idx = -1
    min_latent_state = 0
    for j in range(n_actions):
        latent_out, rewards, _, next_state_values, _ = net.recurrent_inference(latent_states, ACTIONS_T[i][None], ACTIONS_T[j][None])

        t[i,j] = next_state_values*gamma
        if t[i,j]<min_value:
            min_value=t[i,j]
            min_idx = j
            min_latent_state = latent_out
    print(min_value, min_idx, min_latent_state.sum()) 

tensor(0.2055, grad_fn=<AsStridedBackward0>) 16 tensor(10.4211, grad_fn=<SumBackward0>)
tensor(0.2244, grad_fn=<AsStridedBackward0>) 16 tensor(10.5620, grad_fn=<SumBackward0>)
tensor(0.2241, grad_fn=<AsStridedBackward0>) 16 tensor(10.9890, grad_fn=<SumBackward0>)
tensor(0.2063, grad_fn=<AsStridedBackward0>) 16 tensor(11.1034, grad_fn=<SumBackward0>)
tensor(0.2018, grad_fn=<AsStridedBackward0>) 16 tensor(10.6201, grad_fn=<SumBackward0>)
tensor(0.2015, grad_fn=<AsStridedBackward0>) 16 tensor(10.0202, grad_fn=<SumBackward0>)
tensor(0.1716, grad_fn=<AsStridedBackward0>) 2 tensor(10.1682, grad_fn=<SumBackward0>)
tensor(0.1989, grad_fn=<AsStridedBackward0>) 16 tensor(10.2735, grad_fn=<SumBackward0>)
tensor(0.2047, grad_fn=<AsStridedBackward0>) 8 tensor(10.3625, grad_fn=<SumBackward0>)
tensor(0.1875, grad_fn=<AsStridedBackward0>) 16 tensor(10.5093, grad_fn=<SumBackward0>)
tensor(0.1850, grad_fn=<AsStridedBackward0>) 9 tensor(11.1403, grad_fn=<SumBackward0>)
tensor(0.1850, grad_fn=<AsStridedBa

In [445]:
env.close()