# Communication policy

Load data

In [1]:
import json

def load_pairs(data_path):
    """Load (observation, sentence) pairs for both agents from a JSON file.

    The file must contain a JSON object. Every top-level key that starts
    with "Step" maps to a record holding an "Agent_0" and an "Agent_1"
    entry, each with an "Observation" and a "Sentence" field. Top-level
    keys that do not start with "Step" are ignored.

    Args:
        data_path: Path of the JSON file to read.

    Returns:
        A list of dicts with the keys "observation" and "sentence". Each
        step contributes two entries (Agent_0 first, then Agent_1), in the
        order the steps appear in the file. Each sentence has its first and
        last elements removed.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a step record lacks an expected agent or field.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    pairs = []
    for step, s_data in data.items():
        if not step.startswith("Step"):
            continue
        for agent in ("Agent_0", "Agent_1"):
            agent_data = s_data[agent]
            pairs.append({
                "observation": agent_data["Observation"],
                # Drop the first and last elements — presumably start/end
                # markers; confirm against the data generator.
                "sentence": agent_data["Sentence"][1:-1],
            })
    return pairs

# Location of the generated-sentences dataset, and the number of leading pairs
# used for training; every pair after that index is held out for testing.
DATA_PATH = "test_data/Sentences_Generated_P1.json"
TRAIN_SIZE = 80000

data_pairs = load_pairs(DATA_PATH)

# NOTE: if the file yields TRAIN_SIZE pairs or fewer, test_data is empty.
train_data = data_pairs[:TRAIN_SIZE]
test_data = data_pairs[TRAIN_SIZE:]

In [4]:
import torch

from model.modules.obs import ObservationEncoder
from model.modules.lm import GRUEncoder, GRUDecoder, OneHotEncoder
from model.modules.comm import CommunicationPolicy

# Build the communication policy plus the encoders/decoder it is exercised with.
# NOTE(review): 32 looks like the shared context/hidden size and 17 the length
# of one observation vector — confirm against the model.modules signatures.
cp = CommunicationPolicy(32, 32)

word_encoder = OneHotEncoder(['South','Not','Located','West','Object','Landmark','North','Center','East'])
lang_enc = GRUEncoder(32, word_encoder)
obs_enc = ObservationEncoder(17, 32)
dec = GRUDecoder(32, word_encoder)

# Batch of one, built from the first training pair. torch.tensor with an
# explicit dtype replaces the legacy torch.Tensor(data) constructor.
obs = torch.tensor([train_data[0]["observation"]], dtype=torch.float32)
sent = [train_data[0]["sentence"]]

# Encode the observation and the sentence into the two context vectors that
# are later passed to the communication policy.
int_context = obs_enc(obs)
ext_context = lang_enc(sent)
print(int_context)
print(ext_context)

tensor([[ 0.5278,  0.2009, -0.1439, -0.1794, -0.6431,  0.0351, -0.1024, -0.1001,
          0.6527,  0.6745,  0.0047, -0.3461, -0.2081, -0.3148,  0.1714, -0.3583,
          0.6270, -0.7164, -0.2638, -0.5845,  0.3001,  0.4664,  0.1918, -0.1113,
         -0.2141,  0.3801,  0.2028,  0.0408,  0.8720, -0.1324, -0.6387, -0.2342]],
       grad_fn=<AddmmBackward0>)
tensor([[[ 0.1246, -0.0066,  0.0979, -0.1150,  0.1686, -0.0200, -0.1921,
          -0.0318,  0.1319, -0.0009,  0.2391,  0.1782,  0.1255, -0.0446,
          -0.1537, -0.0456,  0.1508,  0.1353, -0.0618,  0.2609,  0.0446,
          -0.1089, -0.0155, -0.1736, -0.2275,  0.0869, -0.1603,  0.0746,
          -0.1047, -0.2703,  0.0377,  0.2223]]], grad_fn=<AddBackward0>)


In [5]:
# Run the communication policy on the two context vectors; the call returns a
# pair that is unpacked into `output` and `hidden`.
# NOTE(review): this cell has no print/display of its own, so the tensors
# shown beneath it presumably come from debug prints inside
# CommunicationPolicy — confirm and remove them there.
output, hidden = cp(int_context, ext_context)

tensor([[[ 0.5278,  0.2009, -0.1439, -0.1794, -0.6431,  0.0351, -0.1024,
          -0.1001,  0.6527,  0.6745,  0.0047, -0.3461, -0.2081, -0.3148,
           0.1714, -0.3583,  0.6270, -0.7164, -0.2638, -0.5845,  0.3001,
           0.4664,  0.1918, -0.1113, -0.2141,  0.3801,  0.2028,  0.0408,
           0.8720, -0.1324, -0.6387, -0.2342,  0.1246, -0.0066,  0.0979,
          -0.1150,  0.1686, -0.0200, -0.1921, -0.0318,  0.1319, -0.0009,
           0.2391,  0.1782,  0.1255, -0.0446, -0.1537, -0.0456,  0.1508,
           0.1353, -0.0618,  0.2609,  0.0446, -0.1089, -0.0155, -0.1736,
          -0.2275,  0.0869, -0.1603,  0.0746, -0.1047, -0.2703,  0.0377,
           0.2223]]], grad_fn=<CatBackward0>)
tensor([[[ 0.2708,  0.1085,  0.0503,  0.3166, -0.0694,  0.0567, -0.0667,
          -0.1321, -0.1112,  0.1799,  0.1097,  0.0664,  0.0462, -0.0452,
          -0.0383, -0.1069, -0.0066, -0.1147, -0.2578, -0.1731, -0.1166,
           0.3007,  0.1309, -0.0035,  0.1634,  0.0251, -0.0161, -0.2419,
     

In [6]:
# Display the `output` value returned by the communication policy call.
output

tensor([[ 0.0460,  0.0168,  0.0324,  0.0726, -0.1170, -0.0970, -0.0134, -0.1694,
          0.0113,  0.1130,  0.0469,  0.0866,  0.0182, -0.1436,  0.0485,  0.0856,
         -0.2451,  0.0567,  0.0902, -0.1024, -0.0987, -0.1128,  0.0418, -0.1827,
          0.1020,  0.2304,  0.1658,  0.0539, -0.2380, -0.1792, -0.0792, -0.1636]],
       grad_fn=<AddmmBackward0>)