## Add Embedding layers to Language Modules

In [18]:
import numpy as np
import torch

from torch import nn

from src.algo.language.lm import GRUDecoder, init_rnn_params
from src.algo.language.obs import ObservationEncoder


class OneHotEncoder:
    """
    Class managing the vocabulary and its one-hot encodings
    """
    def __init__(self, vocab, max_message_len=10):
        """
        Inputs:
            :param vocab (list): List of tokens that can appear in the language
        """
        self.tokens = ["<SOS>", "<EOS>"] + vocab
        self.enc_dim = len(self.tokens)
        self.token_encodings = np.eye(self.enc_dim)
        self.max_message_len = max_message_len + 1

        self.SOS_ENC = self.token_encodings[0]
        self.EOS_ENC = self.token_encodings[1]

        self.SOS_ID = 0
        self.EOS_ID = 1

    def index2token(self, index):
        """
        Returns the token corresponding to the given index in the vocabulary
        Inputs:
            :param index (int)
        Outputs:
            :param token (str)
        """
        if type(index) in [list, np.ndarray]:
            return [self.tokens[i] for i in index]
        else:
            return self.tokens[index]

    def enc2token(self, encoding):
        """
        Returns the token corresponding to the given one-hot encoding.
        Inputs:
            :param encoding (numpy.array): One-hot encoding.
        Outputs:
            :param token (str): Corresponding token.
        """
        if len(encoding.shape) == 1:
            return self.tokens[np.argmax(encoding)]
        elif len(encoding.shape) == 2:
            return [self.tokens[np.argmax(enconding[i])] for i in range(encoding.shape[0])]
        else:
            raise NotImplementedError("Wrong index type")

    def get_onehots(self, sentence):
        """
        Transforms a sentence into a list of corresponding one-hot encodings
        Inputs:
            :param sentence (list): Input sentence, made of a list of tokens
        Outputs:
            :param onehots (list): List of one-hot encodings
        """
        onehots = [
            self.token_encodings[self.tokens.index(t)] 
            for t in sentence
        ]
        return onehots

    def get_ids(self, sentence):
        ids = [
            self.tokens.index(t) 
            for t in sentence]
        return ids
    
    def encode_batch(self, sentences, pad=False):
        enc = []
        for s in sentences:
            enc_s = self.get_ids(s)
            
            enc_s.append(self.EOS_ID)
            
            if pad:
                enc_s.extend([-1] * (self.max_message_len - len(enc_s)))
            print(len(enc_s))
            enc.append(enc_s)
        
        if pad:
        
            enc = np.array(enc)
        
        return enc

    def ids_to_onehots(self, ids_batch):
        if type(ids_batch) is list:
            onehots = [
                self.token_encodings[ids]
                for ids in ids_batch]
            return onehots
        elif type(ids_batch) is np.ndarray:
            return self.token_encodings[ids]

    def decode_batch(self, token_batch):
        """
        Decode batch of encoded sentences
        Inputs:
            :param token_batch (list): List of encoded sentences.
        Outputs:
            :param decoded_batch (list): List of sentences.
        """
        decoded_batch = []
        for enc_sentence in token_batch:
            sentence = []
            for token in enc_sentence:
                if type(token) is list:
                    sentence.append(self.enc2token(token))
                else:
                    if token == 1:
                        break
                    sentence.append(self.index2token(token))
            decoded_batch.append(sentence)
        return decoded_batch

class GRUEncoder(nn.Module):
    """
    Class for a language encoder using a Gated Recurrent Unit network
    """
    def __init__(self, context_dim, hidden_dim, embed_dim, word_encoder, 
                 n_layers=1, device='cpu', do_embed=True):
        """
        Inputs:
            :param context_dim (int): Dimension of the context vectors (output
                of the model).
            :param hidden_dim (int): Dimension of the hidden state of the GRU
                newtork.
            :param word_encoder (OneHotEncoder): Word encoder, associating 
                tokens with one-hot encodings
            :param n_layers (int): number of layers in the GRU (default: 1)
            :param device (str): CUDA device
        """
        super(GRUEncoder, self).__init__()
        self.device = device
        self.word_encoder = word_encoder
        self.context_dim = context_dim
        self.hidden_dim = hidden_dim
        self.do_embed = do_embed
        
        self.embed_layer = nn.Embedding(self.word_encoder.enc_dim, embed_dim, padding_idx=-1)
        
        if not self.do_embed:
            embed_dim = self.word_encoder.enc_dim
            
        self.gru = nn.GRU(
            embed_dim, 
            self.hidden_dim, 
            n_layers,
            batch_first=True)
        init_rnn_params(self.gru)
        
        self.out = nn.Linear(self.hidden_dim, context_dim)
        self.norm = nn.LayerNorm(context_dim)
        
    def embed_sentences(self, sent_batch):
        # Get one-hot encodings
        enc_sent_batch = self.word_encoder.encode_batch(sent_batch)
        
        # Embed
        if self.do_embed:
            enc_ids_batch = [s.argmax(-1) for s in enc_sent_batch]
            return [self.embed_layer(s) for s in enc_ids_batch]
        else:
            return enc_sent_batch

    def forward(self, sent_batch):
        """
        Transforms sentences into embeddings
        Inputs:
            :param sentence_batch (list(list(str))): Batch of sentences.
        Outputs:
            :param unsorted_hstates (torch.Tensor): Final hidden states
                corresponding to each given sentence, dim=(1, batch_size, 
                context_dim)
        """
        # Get one-hot encodings
        enc_sent_batch = self.word_encoder.encode_batch(sent_batch)

        # Get order of sententes sorted by length decreasing
        ids = sorted(
            range(len(enc_sent_batch)), 
            key=lambda x: len(enc_sent_batch[x]), 
            reverse=True)

        # Sort the sentences by length
        sorted_list = [enc_sent_batch[i] for i in ids]

        # Embed
        if self.do_embed:
            # enc_ids_batch = [s.argmax(-1) for s in sorted_list]
            model_input = [
                self.embed_layer(torch.from_numpy(
                    np.array(s)).to(self.device)) 
                for s in sorted_list]
        else:
            model_input = [
                torch.Tensor(self.word_encoder.ids_to_onehots(s))
                for s in sorted_list]

        # Pad sentences
        padded = nn.utils.rnn.pad_sequence(
            model_input, batch_first=True)

        # Pack padded sentences (to not care about padded tokens)
        lens = [len(s) for s in sorted_list]
        packed = nn.utils.rnn.pack_padded_sequence(
            padded, lens, batch_first=True).to(self.device)

        # Initial hidden state
        hidden = torch.zeros(1, len(enc_sent_batch), self.hidden_dim, 
                        device=self.device)
        
        # Pass sentences into GRU model
        _, hidden_states = self.gru(packed, hidden)

        # Re-order hidden states
        unsorted_hstates = torch.zeros_like(hidden_states).to(self.device)
        unsorted_hstates[0,ids,:] = hidden_states[0,:,:]

        return self.norm(self.out(unsorted_hstates))

    def get_params(self):
        return {'gru': self.gru.state_dict(),
                'out': self.out.state_dict()}

In [11]:
import json
import time
import random
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn, optim

import matplotlib.pyplot as plt

def load_pairs(data_path):
    with open(data_path, "r") as f:
        data = json.load(f)
    pairs = []
    for step, s_data in data.items():
        if not step.startswith("Step"):
            continue
        pairs.append({
            "observation": s_data["Agent_0"]["Observation"],
            "sentence": s_data["Agent_0"]["Sentence"][1:-1]
        })
        pairs.append({
            "observation": s_data["Agent_1"]["Observation"],
            "sentence": s_data["Agent_1"]["Sentence"][1:-1]
        })
    return pairs

def init_training_objects(voc, context_dim, obs_dim, embed_dim, lr, do_embed=True):
    word_encoder = OneHotEncoder(voc)

    lang_enc = GRUEncoder(context_dim, 32, embed_dim, word_encoder, do_embed=do_embed)
    dec = GRUDecoder(context_dim, embed_dim, word_encoder)

    obs_enc = ObservationEncoder(obs_dim, context_dim)

    cross_ent_l = nn.CrossEntropyLoss()
    nll_l = nn.NLLLoss()
    opt = optim.Adam(list(lang_enc.parameters()) + list(dec.parameters()) + list(obs_enc.parameters()), lr=lr)
    
    return word_encoder, lang_enc, obs_enc, dec, cross_ent_l, nll_l, opt

def sample_batch(data, batch_size):
    batch = random.sample(data, batch_size)
    obs_batch = []
    sent_batch = []
    for pair in batch:
        obs_batch.append(pair["observation"])
        sent_batch.append(pair["sentence"])
    return obs_batch, sent_batch

def get_losses(obs_batch, sent_batch, obs_enc, lang_enc, dec, temp, cross_ent_loss, nll_loss, obs_learn_capt):
    # Encode observations
    obs_tensor = torch.Tensor(np.array(obs_batch))
    context_batch = obs_enc(obs_tensor)

    # Encode sentence
    lang_context_batch = lang_enc(sent_batch)
    lang_context_batch = lang_context_batch.squeeze()

    # Compute similarity
    norm_context_batch = context_batch / context_batch.norm(dim=1, keepdim=True)
    lang_context_batch = lang_context_batch / lang_context_batch.norm(dim=1, keepdim=True)
    sim = norm_context_batch @ lang_context_batch.t() * temp
    mean_sim = sim.diag().mean()

    # Compute loss
    labels = torch.arange(len(obs_batch))
    loss_o = cross_ent_loss(sim, labels)
    loss_l = cross_ent_loss(sim.t(), labels)
    clip_loss = (loss_o + loss_l) / 2
    
    # Decoding
    encoded_targets = word_encoder.encode_batch(sent_batch)
    if not obs_learn_capt:
        context_batch = context_batch.detach()
    decoder_outputs, _ = dec(context_batch, encoded_targets)

    # Compute loss
    dec_loss = 0
    for d_o, e_t in zip(decoder_outputs, encoded_targets):
        e_t = torch.argmax(e_t, dim=1)
        dec_loss += nll_loss(d_o[:e_t.size(0)], e_t)
    
    return clip_loss, dec_loss, mean_sim

def train(data, obs_enc, lang_enc, dec, word_encoder, cross_ent_loss, nll_loss, opt, 
          n_iters=80000, batch_size=128, temp=0.07, eval_data=None, eval_evry=1000,
          sample_fn=sample_batch, clip_weight=1.0, capt_weight=1.0, obs_learn_capt=False):
    start = time.time()
    
    clip_train_losses = []
    clip_eval_losses = []
    dec_train_losses = []
    dec_eval_losses = []
    eval_sims = []
    
    for s_i in tqdm(range(n_iters)):
        opt.zero_grad()
        
        # Sample batch
        obs_batch, sent_batch = sample_fn(data, batch_size)
        
        # Compute both losses
        clip_loss, dec_loss, _ = get_losses(obs_batch, sent_batch, obs_enc, lang_enc, dec, temp, 
                                         cross_ent_loss, nll_loss, obs_learn_capt)
        
        # Backprop
        tot_loss = clip_weight * clip_loss + capt_weight * dec_loss
        tot_loss.backward()
        opt.step()
        
        clip_train_losses.append(clip_loss.item())
        dec_train_losses.append(dec_loss.item() / batch_size)
        
        if eval_data is not None and (s_i + 1) % eval_evry == 0:
            with torch.no_grad():
                # Sample batch
                obs_batch, sent_batch = sample_fn(eval_data, batch_size)
                
                # Get both losses
                clip_loss, dec_loss, sim = get_losses(
                    obs_batch, sent_batch, obs_enc, lang_enc, dec, temp, cross_ent_loss, nll_loss, obs_learn_capt)
                clip_eval_losses.append(clip_loss.item())
                dec_eval_losses.append(dec_loss.item() / batch_size)
                eval_sims.append(sim)
    
    return clip_train_losses, clip_eval_losses, dec_train_losses, dec_eval_losses, eval_sims

def plot_curves(curves, titles):
    """
    Inputs:
        curves (list(list(list(float)))): list of list of training curves, each element i of the main list is a list
            of all the training curves to plot in the subplot i.
    """
    nb_subplots = len(curves)
    
    fig, axs = plt.subplots(1, nb_subplots, figsize=(15,6))
    if type(axs) is not np.ndarray:
        axs = [axs]
    for ax, plot, title in zip(axs, curves, titles):
        max_len = max([len(c) for c in plot])
        for c in plot:
            c_len = len(c)
            if c_len == max_len:
                ax.plot(c)
            else:
                inter = max_len / c_len
                ax.plot((np.arange(c_len) + 1) * inter, c)
        ax.set_title(title)

In [12]:
data_pairs = load_pairs("../MALNovelD/test_data/Sentences_Generated_P1.json")

train_data = data_pairs[:80000]
test_data = data_pairs[80000:]

In [13]:
b = sample_batch(train_data, 10)
s = b[1]
s

[['Located', 'South'],
 ['Located', 'Center'],
 ['Located',
  'South',
  'West',
  'Object',
  'Not',
  'Located',
  'South',
  'West',
  'Landmark',
  'Not',
  'Located',
  'South',
  'West'],
 ['Located', 'South'],
 ['Located', 'North', 'East'],
 ['Located', 'Center', 'Landmark', 'South', 'West'],
 ['Located', 'South', 'West'],
 ['Located', 'North'],
 ['Located', 'Center'],
 ['Located', 'North']]

In [19]:
word_encoder, lang_enc, obs_enc, dec, cross_ent_l, nll_l, opt = init_training_objects(
    ['South','Not','Located','West','Object','Landmark','North','Center','East'],
    16, 17, 4, 0.007)
word_encoder.encode_batch(s, pad=True)

[[4, 2, 1, -1, -1, -1, -1, -1, -1, -1, -1], [4, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1], [4, 2, 5, 6, 3, 4, 2, 5, 7, 3, 4, 2, 5, 1], [4, 2, 1, -1, -1, -1, -1, -1, -1, -1, -1], [4, 8, 10, 1, -1, -1, -1, -1, -1, -1, -1], [4, 9, 7, 2, 5, 1, -1, -1, -1, -1, -1], [4, 2, 5, 1, -1, -1, -1, -1, -1, -1, -1], [4, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1], [4, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1], [4, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1]]


ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (10,) + inhomogeneous part.