In [13]:
import torch

### LNovelD

In [2]:
from model.modules.lnoveld import LNovelD

# Build an LNovelD module. The printed repr below shows its NovelD networks take
# inputs of width 2 and the target network outputs width 5; which positional
# argument maps to which size is not visible here — TODO confirm against LNovelD's signature.
m = LNovelD(2, 2, 5)

In [3]:
# Display the module structure (the saved output below is truncated).
m

LNovelD(
  (obs_noveld): NovelD(
    (target): MLPNetwork(
      (mlp): Sequential(
        (0): Linear(in_features=2, out_features=64, bias=True)
        (1): ReLU()
        (2): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
        )
        (3): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
        )
        (4): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
        )
        (5): Linear(in_features=64, out_features=5, bias=True)
      )
    )
    (predictor): MLPNetwork(
      (mlp): Sequential(
        (0): Linear(in_features=2, out_features=64, bias=True)
        (1): ReLU()
        (2): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
        )
        (3): Sequential(
          (0): Linear(in_features=64, out_features=64, bias=True)
          (1): ReLU()
        )
    

In [5]:
# A single observation as a (1, 2) float32 batch. `torch.tensor` is the recommended
# constructor for tensors built from data; the legacy `torch.Tensor(...)` is ambiguous
# (e.g. `torch.Tensor(3)` allocates an uninitialised tensor of size 3).
obs = torch.tensor([[10.0, 10.0]])
# Forward pass on the observation. NOTE(review): the second argument is passed as None —
# presumably an optional secondary input; confirm against LNovelD's forward signature.
# The "target / pred / nov / ..." lines in the output are printed from inside the module
# (this cell has no print), i.e. leftover debug prints.
m(obs, None)


target tensor([[-1.8384,  0.9695, -3.8054, -2.1500,  3.6722]])
pred tensor([[ 2.2023, -4.3319, -0.7381, -3.4516,  0.7183]],
       grad_fn=<AddmmBackward0>)
nov tensor([8.0162])
last nov tensor([0.8016])
comp tensor([-3.2065])
7.615405559539795


7.615405559539795

In [6]:
# Inspect the NovelD per-episode state counter (appears to map observation tuples to counts).
# NOTE(review): the saved output contains (1.0, 1.0), which no visible cell passes to `m`
# (execution count In [4] is missing) — it comes from a deleted or out-of-order cell.
# Re-run from a fresh kernel to get a reproducible output.
m.obs_noveld.episode_states_count

{(1.0, 1.0): 1, (10.0, 10.0): 1}

### Encoder

In [1]:
import torch
from model.modules.lm import GRUEncoder, OneHotEncoder

# Smoke test for GRUEncoder on a few hand-written sentences.
# Seed first so the randomly initialised weights (and the saved output) are reproducible.
torch.manual_seed(0)

gru = GRUEncoder(
    10,  # matches the width of each encoding in the output below
    OneHotEncoder(['South','Not','Located','West','Object','Landmark','North','Center','East']))

# Only the inner GRU's parameters are optimised.
# NOTE(review): lr=0.1 is large — presumably chosen so one step visibly changes the encodings.
opt = torch.optim.Adam(gru.gru.parameters(), lr=0.1)

# Every sentence is wrapped in <SOS> ... <EOS> tokens before encoding.
sentences = [
    ["<SOS>"] + s + ["<EOS>"]
    for s in [
        ["Located", "South"],
        ["Located", "Center", "Object", "East"],
        ["Located", "South"],
        ["Located", "East", "Object", "South", "East"],
        ["Located", "South", "West"],
    ]
]

# Encode once and display that result, rather than calling gru(sentences) a second
# time just for display (which only builds a redundant autograd graph).
encs = gru(sentences)
encs


tensor([[[ 0.0602, -0.1766, -0.0372,  0.1486,  0.3524,  0.0492, -0.0064,
           0.1933, -0.3469, -0.0596],
         [ 0.1362, -0.3001,  0.0173,  0.2752,  0.3259, -0.0171,  0.0233,
           0.0974, -0.2904, -0.0622],
         [ 0.0602, -0.1766, -0.0372,  0.1486,  0.3524,  0.0492, -0.0064,
           0.1933, -0.3469, -0.0596],
         [ 0.1281, -0.2525,  0.0050,  0.2414,  0.3356,  0.0387,  0.0285,
           0.1468, -0.3164, -0.0694],
         [ 0.1491, -0.2382,  0.0952,  0.1408,  0.3274,  0.1251, -0.0039,
           0.1219, -0.3603,  0.0718]]], grad_fn=<CopySlices>)

In [11]:
# One optimisation step on a dummy objective (the mean of all encodings).
# The encodings are recomputed here because backward() frees the autograd graph:
# calling backward on a previously computed `encs` a second time raises, so the
# original cell could not be re-run.
encs = gru(sentences)
opt.zero_grad()
encs.mean().backward()
opt.step()

In [12]:
# Re-encode after the optimisation step; the values differ from the first encoding.
gru(sentences)

tensor([[[-0.4974, -0.3791, -0.1940, -0.3039, -0.1669, -0.2489, -0.5031,
          -0.1984, -0.5563, -0.0602],
         [-0.4415, -0.3436, -0.3280, -0.3862, -0.2037, -0.2943, -0.6240,
          -0.2357, -0.5833,  0.0067],
         [-0.4974, -0.3791, -0.1940, -0.3039, -0.1669, -0.2489, -0.5031,
          -0.1984, -0.5563, -0.0602],
         [-0.4591, -0.3452, -0.3098, -0.3782, -0.2319, -0.2676, -0.6360,
          -0.1869, -0.5940,  0.0045],
         [-0.5177, -0.3911, -0.2654, -0.3371, -0.1657, -0.3111, -0.5693,
          -0.2637, -0.5415, -0.0435]]], grad_fn=<CopySlices>)

### Decoder

In [3]:
import torch
from model.modules.lm import GRUDecoder, OneHotEncoder

word_encoder = OneHotEncoder(
    ['South','Not','Located','West','Object','Landmark','North','Center','East'])

# Single source of truth for the decoder size: the first GRUDecoder argument and the
# last dim of the hidden state fed to forward_step must agree (both 10 in the output below).
# Assumed to be the GRU hidden size — TODO confirm against GRUDecoder's signature.
hidden_size = 10
dec = GRUDecoder(hidden_size, word_encoder)

# One decoding step from the <SOS> token with an all-ones hidden state.
# Shape (1, 1, hidden_size); presumably (num_layers, batch, hidden) — TODO confirm.
last_hidden = torch.ones((1, 1, hidden_size))
# Add two leading singleton dims to the <SOS> encoding (width 11 in the output below).
last_word = torch.Tensor(word_encoder.SOS_ENC).unsqueeze(0).unsqueeze(0)

dec.forward_step(last_word, last_hidden)

(tensor([[[-2.3761, -2.1622, -2.5260, -2.9682, -2.4728, -2.2131, -2.8585,
           -2.2424, -2.4509, -1.8872, -2.7231]]], grad_fn=<LogSoftmaxBackward0>),
 tensor([[[0.4348, 0.5635, 0.5135, 0.6628, 0.8795, 0.6910, 0.2694, 0.4919,
           0.5127, 0.3439]]], grad_fn=<StackBackward0>))

In [4]:
# Dummy decoder inputs: a batch of 2 all-ones context vectors of width 10
# (presumably (batch, hidden) — TODO confirm against GRUDecoder.forward).
context = torch.ones(2, 10)

# Two one-hot target sequences over an 11-symbol alphabet (same width as the decoder's
# log-prob output). Each sequence is given by its hot indices and materialised by
# selecting the matching rows of an 11x11 identity matrix.
targets = [
    torch.eye(11)[token_ids]
    for token_ids in ([4, 5, 10, 1], [4, 10, 1])
]

# Decode from the context vectors.
# NOTE(review): `targets` is built above but not passed to `dec`. The output appears to
# decode greedily: each step's input one-hot is the arg-max of the previous step's
# log-probs. Confirm whether GRUDecoder's forward accepts targets (e.g. for teacher forcing).
# The "Batch # / Token #" lines in the output are debug prints from inside the decoder.
dec(context)

Batch # 0
Token # 0
torch.Size([1, 1, 11])
tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
torch.Size([1, 1, 10])
tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]])
Token # 1
torch.Size([1, 1, 11])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]])
torch.Size([1, 1, 10])
tensor([[[0.4348, 0.5635, 0.5135, 0.6628, 0.8795, 0.6910, 0.2694, 0.4919,
          0.5127, 0.3439]]], grad_fn=<StackBackward0>)
Token # 2
torch.Size([1, 1, 11])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]])
torch.Size([1, 1, 10])
tensor([[[0.2828, 0.1417, 0.1974, 0.1851, 0.6780, 0.4170, 0.1458, 0.2004,
          0.2384, 0.1327]]], grad_fn=<StackBackward0>)
Token # 3
torch.Size([1, 1, 11])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]])
torch.Size([1, 1, 10])
tensor([[[ 0.2431, -0.0297,  0.0631, -0.0626,  0.4887,  0.2018,  0.0724,
           0.0462,  0.0697,  0.0931]]], grad_fn=<StackBackward0>)
Token # 4
torch.Size([1, 1, 11])
tensor([[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]]]

([tensor([[[[-2.3761, -2.1622, -2.5260, -2.9682, -2.4728, -2.2131, -2.8585,
             -2.2424, -2.4509, -1.8872, -2.7231]]],
  
  
          [[[-2.3894, -2.2810, -2.2947, -2.7736, -2.6768, -2.1555, -2.7184,
             -2.4017, -2.3613, -2.0369, -2.5638]]],
  
  
          [[[-2.4044, -2.3551, -2.2080, -2.6710, -2.7801, -2.1544, -2.6412,
             -2.4887, -2.3061, -2.1284, -2.4682]]],
  
  
          [[[-2.4075, -2.3935, -2.1798, -2.6119, -2.8209, -2.1808, -2.5909,
             -2.5350, -2.2695, -2.1859, -2.4109]]],
  
  
          [[[-2.4415, -2.4617, -2.1842, -2.7214, -2.6913, -2.1454, -2.4708,
             -2.4078, -2.3018, -2.2820, -2.4321]]],
  
  
          [[[-2.4830, -2.5452, -2.1819, -2.6799, -2.7788, -2.0848, -2.4010,
             -2.4065, -2.3098, -2.3142, -2.3934]]],
  
  
          [[[-2.4895, -2.5917, -2.1836, -2.6582, -2.8194, -2.0641, -2.3460,
             -2.4049, -2.3153, -2.3508, -2.3771]]],
  
  
          [[[-2.4848, -2.6166, -2.1885, -2.6471, -2.8381, -2.0

### Train Encoder-Decoder

Load sentences

In [9]:
import json

def load_sentences(data_path, agent_keys=("Agent_0", "Agent_1")):
    """Collect the per-agent sentences from a generated-sentences JSON file.

    Only top-level keys starting with "Step" are read. For each such step, the
    "Sentence" list of every agent in ``agent_keys`` is appended in that order,
    with its first and last element dropped.

    Args:
        data_path: Path to the JSON file.
        agent_keys: Agent entries to read from each step, in output order.
            Defaults to the two agents the notebook originally hard-coded.

    Returns:
        list: One trimmed sentence per (step, agent), in file order.

    Raises:
        KeyError: If a "Step" entry lacks one of ``agent_keys`` or its "Sentence" field.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default text encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    sentences = []
    for step, s_data in data.items():
        if not step.startswith("Step"):
            continue
        for agent in agent_keys:
            # [1:-1] drops the first and last tokens — presumably <SOS>/<EOS>; confirm against the data.
            sentences.append(s_data[agent]["Sentence"][1:-1])
    return sentences
# Load the generated sentences (path relative to the notebook's working directory).
# NOTE(review): this rebinds `sentences`, which the Encoder section above also uses.
sentences = load_sentences("test_data/Sentences_Generated_P1.json")

Initialise Encoder and Decoder

In [None]:
import torch
from model.modules.lm import GRUEncoder, GRUDecoder, OneHotEncoder

word_encoder = OneHotEncoder(
    ['South','Not','Located','West','Object','Landmark','North','Center','East'])