# Create environment and load model

In [1]:
import os
import json
import time
import torch

from src.utils.config import get_config
from src.utils.utils import set_seeds, load_args, set_cuda_device
from src.envs.make_env import make_env
from src.lmc.lmc_context import LMC

def create_cfg(args_file):
    class Config:
        pass
    cfg = Config()
    
    with open(args_file) as f:
        [next(f) for i in range(3)]
        args = json.load(f)
    args.pop("cuda_device")
    
    for a in args:
        setattr(cfg, a, args[a])
    
    return cfg


cfg = create_cfg("../../models/magym_PredPrey/FT_ShMloc_shap_reccommpol/run4/args.txt")
cfg.n_parallel_envs = 1

# Create train environment
envs, parser = make_env(cfg, cfg.n_parallel_envs)

# Create model
n_agents = envs.n_agents
obs_space = envs.observation_space
shared_obs_space = envs.shared_observation_space
act_space = envs.action_space
global_state_dim = envs.global_state_dim
model = LMC(
    cfg, 
    n_agents, 
    obs_space, 
    shared_obs_space, 
    act_space,
    parser.vocab, 
    global_state_dim)

# Load params
model.load("../../models/magym_PredPrey/FT_ShMloc_shap_reccommpol/run4/model_ep.pt")
model.prep_rollout()
torch.no_grad()



<torch.autograd.grad_mode.no_grad at 0x7f9e2ab910a0>

In [57]:
print(obs_space)

[Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1.], (27,), float32), Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1.], (27,), float32), Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1.], (27,), float32), Box([0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1.], (27,), float32)]


# Computing perplexity of perfect sentences

In [9]:
perf_sent = [
    ["Prey", "Located", "South"],
    ["Prey", "Located", "North", "Prey", "Located", "South"],
    ["<SOS>", "<SOS>"]
]

ll = model.lang_learner

onehot_sent = ll.word_encoder.encode_batch(perf_sent)
onehot_sent

[tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])]

In [3]:
hidden = torch.zeros((1, 1, 16))
last_tokens = torch.Tensor(ll.word_encoder.SOS_ENC).view(
    1, 1, -1)

In [4]:
output, hidden = ll.decoder.forward_step(last_tokens, hidden)
print(output)
print(output.exp())

tensor([[[-2.8816, -1.8924, -1.7783, -2.2814, -2.7779, -2.3432, -2.2925,
          -2.5420, -2.5612, -2.2344]]], grad_fn=<LogSoftmaxBackward0>)
tensor([[[0.0560, 0.1507, 0.1689, 0.1021, 0.0622, 0.0960, 0.1010, 0.0787,
          0.0772, 0.1071]]], grad_fn=<ExpBackward0>)


In [5]:
output.exp().sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [6]:
last_tokens

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [55]:
decoder = ll.decoder
n_layers = 1
def compute_pp(enc_sent):
    batch_size = len(enc_sent)
    max_sent_len = max([len(s) for s in enc_sent])

    hidden = torch.zeros((n_layers, batch_size, decoder.hidden_dim))
    last_tokens = torch.Tensor(decoder.word_encoder.SOS_ENC).view(
        1, 1, -1).repeat(1, batch_size, 1).to(decoder.device)

    pnorm = torch.ones(batch_size)
    for t_i in range(max_sent_len):
        # RNN pass
        outputs, hidden = decoder.forward_step(last_tokens, hidden)

        # Compute PP
        probs = outputs.exp().squeeze(0)
        for s_i in range(batch_size):
            len_s = enc_sent[s_i].size(0)
            if t_i < len_s:
                token_prob = (probs[s_i] * enc_sent[s_i][t_i]).sum(-1)
                pnorm[s_i] *= (token_prob ** (1 / len_s))

        # Do teacher forcing
        last_tokens = torch.zeros_like(last_tokens).to(decoder.device)
        for s_i in range(batch_size):
            if t_i < enc_sent[s_i].size(0):
                last_tokens[0, s_i] = enc_sent[s_i][t_i]

    pnorm = 1 / pnorm

    return pnorm

compute_pp(onehot_sent)

tensor([ 2.7681,  3.9997, 13.0349], grad_fn=<MulBackward0>)