In [4]:
import numpy as np
import torch

from src.lmc.modules.lang_learner import LanguageLearner
from src.envs.ma_gym.envs.predator_prey.predator_prey import PredatorPrey
from src.envs.parsers.predator_prey import PredatorPrey_Parser

# --- Environment configuration ---
magym_env_size = 8
magym_n_agents = 4
magym_n_preys = 2
episode_length = 100

# --- Language-learner dimensions ---
context_dim = 16
lang_hidden_dim = 32

# Square predator-prey grid, plus the parser constructed with the same size.
env = PredatorPrey(
    n_agents=magym_n_agents,
    grid_shape=(magym_env_size,) * 2,
    n_preys=magym_n_preys,
    max_steps=episode_length,
)
parser = PredatorPrey_Parser(magym_env_size)

# Keep handles on the environment's agent count and its spaces.
n_agents = env.n_agents
obs_space = env.observation_space
shared_obs_space = env.shared_observation_space
act_space = env.action_space

# The learner's input size is read from the first entry of the observation
# space; its vocabulary comes from the parser.
lang_learner = LanguageLearner(
    obs_space[0].shape[0], context_dim, lang_hidden_dim, parser.vocab
)

# Restore saved parameters from the checkpoint, mapping storages onto the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
save_dict = torch.load(
    "../../models/magym_PredPrey/mappo_perfectcomm_8x8/run10/model_ep.pt",
    map_location="cpu",
)
lang_learner.load_params(save_dict)

In [5]:
# Reset the environment and keep the observations it returns.
# NOTE(review): no seed is set in the visible cells, so the values displayed
# here may differ between runs -- confirm whether reset() is randomized.
obs = env.reset()
obs

[[0.0,
  0.42857142857142855,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.2857142857142857,
  0.2857142857142857,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.2857142857142857,
  0.7142857142857143,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 [0.14285714285714285,
  0.8571428571428571,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0]]

In [8]:
# Encode the observations into context vectors, then decode those contexts
# into sentences.
# torch.tensor with an explicit dtype replaces the legacy torch.Tensor
# constructor, so the input dtype no longer depends on the global default
# tensor type.
contexts = lang_learner.encode_observations(
    torch.tensor(obs, dtype=torch.float32)
)
captions = lang_learner.generate_sentences(contexts)
captions

[['Prey', 'Located', 'North', 'East'],
 ['Prey', 'Located', 'North', 'East'],
 ['Prey', 'Located', 'North', 'East'],
 ['Prey', 'Located', 'North', 'East']]

In [16]:
# Ask the parser for its reference ("perfect") messages for the current
# observations. The observations are wrapped in a one-element list, so the
# array passed to the parser has a leading dimension of size one.
# numpy is imported with the other imports at the top of the notebook.
parsed_obs = parser.get_perfect_messages(np.array([obs]))
parsed_obs

[[[], [], [], []]]

In [19]:
# Encode the first entry of the parser output into context vectors, then
# decode those contexts back into sentences for comparison.
lang_contexts = lang_learner.encode_sentences(parsed_obs[0])
lang_captions = lang_learner.generate_sentences(lang_contexts)
lang_captions

[['Center'], ['Center'], ['Center'], ['Center']]

In [20]:
# Display the context vectors returned by encode_sentences.
lang_contexts

tensor([[-0.3625,  1.4918, -1.7716, -0.5429,  2.1982,  1.0193,  0.4199, -0.1911,
         -2.0710,  0.6141, -0.6164, -0.3448, -0.7480, -0.4211,  0.1539,  0.3899],
        [-0.3625,  1.4918, -1.7716, -0.5429,  2.1982,  1.0193,  0.4199, -0.1911,
         -2.0710,  0.6141, -0.6164, -0.3448, -0.7480, -0.4211,  0.1539,  0.3899],
        [-0.3625,  1.4918, -1.7716, -0.5429,  2.1982,  1.0193,  0.4199, -0.1911,
         -2.0710,  0.6141, -0.6164, -0.3448, -0.7480, -0.4211,  0.1539,  0.3899],
        [-0.3625,  1.4918, -1.7716, -0.5429,  2.1982,  1.0193,  0.4199, -0.1911,
         -2.0710,  0.6141, -0.6164, -0.3448, -0.7480, -0.4211,  0.1539,  0.3899]],
       grad_fn=<SqueezeBackward1>)

In [21]:
# Display the context vectors returned by encode_observations.
contexts

tensor([[ 3.5465, -2.2015,  3.3764, -0.6979, -0.9155,  0.3608, -1.8657,  1.2274,
         -2.4666,  1.3235,  2.0682, -2.6199,  1.6878, -2.1415, -0.9029, -0.7794],
        [ 3.4802, -2.1770,  3.4070, -0.6994, -0.8836,  0.3601, -1.8812,  1.1857,
         -2.4349,  1.3641,  2.0760, -2.6584,  1.7461, -2.1806, -0.8869, -0.8033],
        [ 3.4802, -2.2248,  3.4232, -0.7183, -0.8630,  0.3766, -1.8462,  1.1511,
         -2.4588,  1.3690,  2.0836, -2.6205,  1.7285, -2.1801, -0.8783, -0.8056],
        [ 3.5441, -2.2246,  3.4070, -0.7067, -0.8883,  0.3736, -1.8685,  1.1570,
         -2.4829,  1.3378,  2.0745, -2.5702,  1.6991, -2.1426, -0.9283, -0.7732]],
       grad_fn=<NativeLayerNormBackward0>)