In [1]:
path_dev_X = "europarl/dev.en"
path_dev_Y = "europarl/dev.et"
path_train_X = "europarl/train.en"
path_train_Y = "europarl/train.et"

In [2]:
import allennlp

# Read instances

In [3]:
from allennlp.data import DatasetReader

In [4]:
from typing import Dict

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

class MonolingualDatasetReader(DatasetReader):
    def __init__(self, lazy: bool = False, max_sent_len = 50) -> None:
        super().__init__(lazy)
        self._sentence_tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter())
        self._sentence_token_indexers = {"tokens": SingleIdTokenIndexer()}
        self._sentence_add_start_token = True
        self._max_sent_len = max_sent_len

    @overrides
    def _read(self, file_path):
        with open(cached_path(file_path), "r") as data_file:
            print("Reading instances from lines in file at: %s", file_path)
            for line_num, line in enumerate(data_file):
                line = line.strip("\n")

                if not line:
                    continue
                
                line = line.lower()
                tokenized_sentence = self._sentence_tokenizer.tokenize(line)
                if len(tokenized_sentence) > self._max_sent_len:
                    continue

                yield self.text_to_instance(tokenized_sentence)

    @overrides
    def text_to_instance(self, tokenized_sentence) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        if self._sentence_add_start_token:
            tokenized_sentence.insert(0, Token(START_SYMBOL))
        tokenized_sentence.append(Token(END_SYMBOL))
        sentence_field = TextField(tokenized_sentence, self._sentence_token_indexers)
        
        return Instance({'sentence': sentence_field})

In [5]:
mono_dataset_reader = MonolingualDatasetReader(lazy=False) 

In [6]:
en_instances_train = mono_dataset_reader.read(path_train_X)
en_instances_dev = mono_dataset_reader.read(path_dev_X)

et_instances_train = mono_dataset_reader.read(path_train_Y)
et_instances_dev = mono_dataset_reader.read(path_dev_Y)

3036it [00:00, 15171.04it/s]

Reading instances from lines in file at: %s europarl/train.en


9573it [00:00, 11584.88it/s]
500it [00:00, 22049.75it/s]
1959it [00:00, 10755.50it/s]

Reading instances from lines in file at: %s europarl/dev.en
Reading instances from lines in file at: %s europarl/train.et


9920it [00:00, 14122.22it/s]
500it [00:00, 20935.09it/s]

Reading instances from lines in file at: %s europarl/dev.et





# Create vocabs

In [7]:
from allennlp.data import Vocabulary

In [8]:
en_vocab = Vocabulary.from_instances(instances=en_instances_train + en_instances_dev, max_vocab_size=20000)
et_vocab = Vocabulary.from_instances(instances=et_instances_train + et_instances_dev, max_vocab_size=20000)

100%|██████████| 10073/10073 [00:00<00:00, 37082.52it/s]
100%|██████████| 10420/10420 [00:00<00:00, 47731.88it/s]


# Create iterators

In [9]:
#from allennlp.data.iterators import BucketIterator
from allennlp.data.iterators import BasicIterator

In [10]:
# en_iterator_creator = BucketIterator(sorting_keys = [("sentence", "num_tokens")], batch_size=32, max_instances_in_memory=None)
# et_iterator_creator = BucketIterator(sorting_keys = [("sentence", "num_tokens")], batch_size=32, max_instances_in_memory=None)

en_iterator_creator = BasicIterator(batch_size=32, max_instances_in_memory=None)
et_iterator_creator = BasicIterator(batch_size=32, max_instances_in_memory=None)

In [11]:
en_iterator_creator.index_with(en_vocab)
et_iterator_creator.index_with(et_vocab)

In [12]:
en_batch_iterator_train = en_iterator_creator(instances=en_instances_train, num_epochs=None, shuffle=False)
en_batch_iterator_dev = en_iterator_creator(instances=en_instances_dev, num_epochs=None, shuffle=False)

et_batch_iterator_train = et_iterator_creator(instances=et_instances_train, num_epochs=None, shuffle=False)
et_batch_iterator_dev = et_iterator_creator(instances=et_instances_dev, num_epochs=None, shuffle=False)

# Create models 

## Create generators

In [77]:
from torch.nn.functional import gumbel_softmax

In [211]:
def get_next_batch_mask(batch_iterator, embedding):
    sampled_batch = batch_iterator.__next__()
    sampled_indeces = sampled_batch['sentence']['tokens']
    embedded_tokens = embedding.forward(sampled_indeces)
    mask = sampled_indeces != 0
    
    return embedded_tokens, mask
    

In [272]:
from typing import Dict

import numpy
from overrides import overrides

import torch
from torch.nn.modules.rnn import LSTMCell
from torch.nn.modules.linear import Linear
import torch.nn.functional as F

from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder
from allennlp.modules.attention import DotProductAttention
from allennlp.modules.token_embedders import Embedding, TokenEmbedder
from allennlp.models.model import Model
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits, weighted_sum, get_final_encoder_states


class VanillaRnn2Rnn(Model):
    """
    Returns predicted indeces  
    """
    def __init__(self,
                 source_vocab: Vocabulary,
                 target_vocab: Vocabulary,
                 source_embedding: Embedding,
                 target_embedding: Embedding,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int
                 ) -> None:
        
        super(VanillaRnn2Rnn, self).__init__(source_vocab)
        
        self._source_vocab = source_vocab
        self._target_vocab = target_vocab
        
        self._source_embedder = source_embedding
        self._target_embedder = target_embedding
        
        self._encoder = encoder
        
        self._max_decoding_steps = max_decoding_steps
         
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._target_vocab.get_token_index(START_SYMBOL, "tokens")
        self._end_index = self._target_vocab.get_token_index(END_SYMBOL, "tokens")
        
        num_classes = self._target_vocab.get_vocab_size("tokens")
        
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder. Also, if
        # we're using attention with ``DotProductSimilarity``, this is needed.
        self._decoder_output_dim = self._encoder.get_output_dim()
        
        target_embedding_dim = self._target_embedder.get_output_dim()    
            
        self._decoder_attention = DotProductAttention()
        # The output of attention, a weighted average over encoder outputs, will be
        # concatenated to the input vector of the decoder at each time step.
        self._decoder_input_dim = self._encoder.get_output_dim() + target_embedding_dim

        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
        
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    @overrides
    def forward(self,  # type: ignore
                embedded_input: torch.FloatTensor, source_mask: torch.LongTensor) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.
        Parameters
        ----------
        sentence : torch.LongTensor
           Tensor of padded batch of indexed source strings
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        #print(sentence)
        #tokens_ids = sentence["tokens"]
        #embedded_input = self._source_embedder(tokens_ids)
        batch_size, _, _ = embedded_input.size()
        #source_mask = get_text_field_mask(sentence)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        final_encoder_output = get_final_encoder_states(encoder_outputs, source_mask, True)  # (batch_size, encoder_output_dim)
        
        num_decoding_steps = self._max_decoding_steps
        
        decoder_hidden = final_encoder_output
        
        decoder_context = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim)
        
        last_predictions = None
        step_logits = []
        step_probabilities = []
        step_predictions = []
        step_predictions_softmax = []
        
        step_embedded_outputs = []
        for timestep in range(num_decoding_steps):
            if timestep == 0:
                # For the first timestep, when we do not have targets, we input start symbols.
                # (batch_size,)
                input_indices = source_mask.new_full((batch_size,), fill_value=self._start_index)
            else:
                input_indices = last_predictions # TODO: SHOULD PASS VECTORS DIRECTLY HERE MAYBE. OPTINALLY REQ_GRAD_FALSE/TRUE. THINK ABOUT IT
            decoder_input = self._prepare_decode_step_input(input_indices, decoder_hidden,
                                                            encoder_outputs, source_mask)
            decoder_hidden, decoder_context = self._decoder_cell(decoder_input,
                                                                 (decoder_hidden, decoder_context))
            
            # (batch_size, num_classes)
            output_projections = self._output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
            
            # GUMBEL SOFTMAX
            
            projection_type = 'gumbel'
            if projection_type == 'gumbel':
                token_types_weights = gumbel_softmax(logits=output_projections, hard=True, tau=0.0000000001)
            elif projection_type == 'softmax':
                token_types_weights = F.softmax(output_projections, dim=-1)
            elif projection_type == 'direct':
                raise NotImplementedError("direct vector prediction is not supportet at the moment")
            else:
                raise NotImplementedError("wrong projection type value; possible are [gumbel | softmax | direct]")
            
                # predicted_classes = token_types_weights.nonzero()[:,1] #TODO MAX AND ETC
            softmax_classes = torch.argmax(output_projections, dim=-1).long()
            predicted_classes = torch.argmax(token_types_weights, dim=-1).long()
            # (batch_size, 1)
            last_predictions = predicted_classes # should be differentiable
            
            step_predictions.append(predicted_classes.unsqueeze(1))
            step_predictions_softmax.append(softmax_classes.unsqueeze(1))

            # TODO: COMPUTE MASK BASED ON ARGMAX(LOGITS) -> INDECES => padding/end? -> MASK
            # WE SHOULD USE ARGMAX TO FIND MASK BECAUSE IF VECTOR THAT CONTRIBUTED THE MOST IS END VECTOR WE TREAT IT 
            # THE WHOLE CONTRIBUTION AS END SYMBOL
            
            # PASS THIS MASK TOGETHER WITH RESULT
            # POSSIBLY BY CREATING THE TEXT FIELD WITH NAME "mask" (used in get_text_field_mask)
            # THIS WAY CHANGE EMBEDDINGS LAYERS WITH TEXTFILELD EMBEDDER
            # OR JUST PASS IT ASS SEPARATE LONGTESOR
            
            embedded_output_tokens = token_types_weights.matmul(self._target_embedder.weight)
            step_embedded_outputs.append(embedded_output_tokens.unsqueeze(1))

        # step_logits is a list containing tensors of shape (batch_size, 1, num_classes)
        # This is (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        all_predictions = torch.cat(step_predictions, 1)
        all_predictions_softmax = torch.cat(step_predictions_softmax, 1)
        all_embedded_output_tokens = torch.cat(step_embedded_outputs, 1)
        
        # finilize predictions and compute mask
        predicted_tokens_softmax, predicted_indices_softmax = self._finize_predictions(all_predictions_softmax)
        predicted_tokens, predicted_indices = self._finize_predictions(all_predictions)
        
        #self._get_mask_from_indices(predicted_indices_softmax)
        mask = self._get_mask_from_indices(predicted_indices)

        # trim embedded output tokens to match the mask
        maxlen = fake_et["mask"].size()[1]
        all_embedded_output_tokens = all_embedded_output_tokens[:, :maxlen, :] 
        
        output_dict = {"logits": logits,
                       "predicted_tokens": predicted_tokens, # this can be gumbel or softamx predictions
                       "predicted_tokens_softmax": predicted_tokens_softmax, # this is softmax predictions
                       "embedded_output_tokens": all_embedded_output_tokens,
                       "mask": mask}
        
        #self._target_vocab. all_predictions
        
        return output_dict

    def _prepare_decode_step_input(self,
                                   input_indices: torch.LongTensor,
                                   decoder_hidden_state: torch.LongTensor = None,
                                   encoder_outputs: torch.LongTensor = None,
                                   encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.
        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatentation of the embedding and an attended
        average of the encoder inputs.
        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output of from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        input_indices = input_indices.long() # TODO: SHOULD I PASS EMBEDDED INPUT DICRECTLY TO MAKE IT DIFFERENTIABLE???
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices) # this should be sperate func that work with different
        # forms of in
        
        # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
        # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
        # complain.
        encoder_outputs_mask = encoder_outputs_mask.float()
        # (batch_size, input_sequence_length)
        input_weights = self._decoder_attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
        # (batch_size, encoder_output_dim)
        attended_input = weighted_sum(encoder_outputs, input_weights)
        # (batch_size, encoder_output_dim + target_embedding_dim)
        return torch.cat((attended_input, embedded_input), -1)
    
    def _finize_predictions(self, predicted_indices):
        "strips ids till the first end symbol "
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        all_predicted_indices = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self._target_vocab.get_token_from_index(x, namespace="tokens")
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
            all_predicted_indices.append(indices)
        
        return all_predicted_tokens, all_predicted_indices

    def _get_mask_from_indices(self, all_predicted_indices):
        bs = len(all_predicted_indices)
        lens = [len(l) for l in all_predicted_indices]
        maxlen = max(lens)
        print('Generating mask for fake batch. Len of batch =', maxlen)
        
        arr = numpy.zeros((bs,maxlen),int)
        mask = numpy.arange(maxlen) < numpy.array(lens)[:,None] # key line
        mask = torch.from_numpy(mask.astype(int)).long()
        if torch.cuda.is_available(): # optionaly move mask to GPU
            mask = mask.cuda()
        return mask
        
    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        all_predicted_indices = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self._target_vocab.get_token_from_index(x, namespace="tokens")
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
            all_predicted_indices.append(indices)
        output_dict["predicted_tokens"] = all_predicted_tokens
        
        output_dict["predicted_indices"] = all_predicted_indices  
        
        return output_dict

#     @staticmethod
#     def _get_loss(logits: torch.LongTensor,
#                   targets: torch.LongTensor,
#                   target_mask: torch.LongTensor,
#                   label_smoothing) -> torch.LongTensor:
#         """
#         Takes logits (unnormalized outputs from the decoder) of size (batch_size,
#         num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
#         and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
#         entropy loss while taking the mask into account.
#         The length of ``targets`` is expected to be greater than that of ``logits`` because the
#         decoder does not need to compute the output corresponding to the last timestep of
#         ``targets``. This method aligns the inputs appropriately to compute the loss.
#         During training, we want the logit corresponding to timestep i to be similar to the target
#         token from timestep i + 1. That is, the targets should be shifted by one timestep for
#         appropriate comparison.  Consider a single example where the target has 3 words, and
#         padding is to 7 tokens.
#            The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
#            and the mask would be                     1   1   1   1   1   0   0
#            and let the logits be                     l1  l2  l3  l4  l5  l6
#         We actually need to compare:
#            the sequence           w1  w2  w3  <E> <P> <P>
#            with masks             1   1   1   1   0   0
#            against                l1  l2  l3  l4  l5  l6
#            (where the input was)  <S> w1  w2  w3  <E> <P>
#         """
#         relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
#         relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
#         loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask,                                                           label_smoothing = label_smoothing)
#         return loss


### Create embeddings

In [80]:
from allennlp.common import Params

In [81]:
en_emb_params = Params({"vocab_namespace": "tokens",
                      "embedding_dim": 300,
                      "pretrained_file": None,
                      "trainable": True
                      })

et_emb_params =  en_emb_params.duplicate()

en_embedding = Embedding.from_params(vocab=en_vocab, params=en_emb_params)
et_embedding = Embedding.from_params(vocab=et_vocab, params=et_emb_params)

In [273]:
seq2seq_lstm_params_en = Params({
    "type": "lstm",
    "num_layers": 1,
    "bidirectional": True,
    "input_size": 300,
    "hidden_size": 600
})

seq2seq_lstm_params_et = seq2seq_lstm_params_en.duplicate()

et2en_model = VanillaRnn2Rnn(source_vocab = et_vocab,
                            target_vocab = en_vocab,
                            source_embedding = et_embedding,
                            target_embedding=en_embedding,
                            encoder=Seq2SeqEncoder.from_params(params=seq2seq_lstm_params_et),
                            max_decoding_steps=50)


en2et_model = VanillaRnn2Rnn(source_vocab = en_vocab,
                            target_vocab = et_vocab,
                            source_embedding = en_embedding,
                            target_embedding=et_embedding,
                            encoder=Seq2SeqEncoder.from_params(params=seq2seq_lstm_params_en),
                            max_decoding_steps=50)



In [274]:
embedded_input_et, input_mask_et  = get_next_batch_mask(et_batch_iterator_dev, et_embedding)

In [275]:
fake_et = en2et_model.forward(embedded_input=embedded_input_et, source_mask=input_mask_et)

Generating mask for fake batch. Len of batch = 50


In [279]:
fake_en = et2en_model.forward(fake_et["embedded_output_tokens"], fake_et["mask"])

Generating mask for fake batch. Len of batch = 50


In [282]:
def print_graph(g, level=0):
    if g == None: return
    print('*'*level*4, g)
    for subg in g.next_functions:
        print_graph(subg[0], level+1)

In [281]:
fake_en['embedded_output_tokens'].grad_fn.

<AliasBackward at 0x7f171f9632b0>

In [271]:
fake_et.keys()

dict_keys(['logits', 'predicted_tokens', 'predicted_tokens_softmax', 'embedded_output_tokens', 'mask'])

In [237]:
fake_et

{'logits': tensor([[[ 0.0041, -0.0117, -0.0087,  ..., -0.0034,  0.0028,  0.0197],
          [ 0.0012, -0.0140, -0.0054,  ..., -0.0054,  0.0010,  0.0201],
          [ 0.0003, -0.0145, -0.0037,  ..., -0.0069, -0.0003,  0.0203],
          ...,
          [-0.0014, -0.0157, -0.0020,  ..., -0.0087, -0.0010,  0.0178],
          [-0.0010, -0.0148, -0.0018,  ..., -0.0093, -0.0010,  0.0179],
          [-0.0012, -0.0157, -0.0008,  ..., -0.0100, -0.0012,  0.0185]],
 
         [[ 0.0039, -0.0115, -0.0085,  ..., -0.0037,  0.0032,  0.0196],
          [ 0.0013, -0.0131, -0.0057,  ..., -0.0060,  0.0010,  0.0203],
          [ 0.0005, -0.0147, -0.0046,  ..., -0.0066,  0.0001,  0.0205],
          ...,
          [-0.0017, -0.0160, -0.0002,  ..., -0.0089, -0.0002,  0.0179],
          [-0.0018, -0.0156, -0.0012,  ..., -0.0089, -0.0000,  0.0185],
          [-0.0016, -0.0159, -0.0004,  ..., -0.0096, -0.0004,  0.0190]],
 
         [[ 0.0040, -0.0117, -0.0087,  ..., -0.0034,  0.0029,  0.0196],
          [ 0.0018

In [84]:
#en2et_model.decode(en2et_model.forward(**batch1))

## Create discriminators

In [85]:
from allennlp.modules import Seq2VecEncoder

In [86]:
class Seq2Binary(Model):
    """
    Logistic regression on sentence.
    """
    def __init__(self,
                vocab: Vocabulary,
                embedding: Embedding,
                seq2vec_encoder: Seq2VecEncoder):
    
        super(Seq2Binary, self).__init__(vocab)
        
        self._embedding = embedding
        self._encoder = seq2vec_encoder
        self._projection_layer = Linear(self._encoder.get_output_dim(), 1)
        
        
    @overrides
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.
        Parameters
        ----------
        sentence : torch.LongTensor
           Tensor of padded batch of indexed source strings
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        #print(sentence)
        tokens_ids = sentence["tokens"]
        embedded_input = self._embedding(tokens_ids)
        batch_size, _, _ = embedded_input.size()
        source_mask = get_text_field_mask(sentence)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        logits = self._projection_layer(final_encoder_output)
        probs = torch.sigmoid(logits)
        
        return {"probs": probs}

In [87]:
classifier_params_en = Params({
                        "type": "lstm",
                        "bidirectional": True,
                        "num_layers": 1,
                        "input_size": 300,
                        "hidden_size": 600
                        })

classifier_params_et = classifier_params_en.duplicate()

en_classifier = Seq2Binary(vocab=en_vocab, embedding=en_embedding, seq2vec_encoder=Seq2VecEncoder.from_params(classifier_params_en))
et_classifier = Seq2Binary(vocab=et_vocab, embedding=et_embedding, seq2vec_encoder=Seq2VecEncoder.from_params(classifier_params_et))

In [88]:
#et_classifier(**batch1)

In [31]:
fake_et = en2et_model.decode(en2et_model.forward(**batch1))["predicted_indices"]

In [28]:
fake_et = prepare_fake_batch(fake_et)

In [None]:
embedded_input_et

In [243]:
fake_et["embedded_output_tokens"].size(), fake_et["mask"].size()

(torch.Size([32, 50, 300]), torch.Size([32, 50]))

In [415]:
a = et_classifier(**fake_et)["probs"]

TypeError: forward() got an unexpected keyword argument 'logits'

In [107]:
len(batch1["sentence"]["tokens"])

32