In [1]:
import torch

from pathlib import Path

from aevnmt.models import AEVNMT
from aevnmt.train import create_model
from aevnmt.train_utils import load_vocabularies
from aevnmt.hparams import Hyperparameters
from aevnmt.data import create_batch, SOS_TOKEN, EOS_TOKEN, PAD_TOKEN, batch_to_sentences
from aevnmt.components import tile_rnn_hidden, ancestral_sample
from aevnmt.data.textprocessing import Pipeline, Detokenizer, Recaser, WordDesegmenter

### Restore the model

In [2]:
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Fix the RNG so the stochastic sampling cells below are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Load the original hyperparameters.
model_dir = Path("./flickr/models/aevnmt/supervised/")
hparams = Hyperparameters(check_required=False)
hparams_file = model_dir / "hparams"
hparams.update_from_file(hparams_file, override=False)

# Load the vocabularies.
hparams.vocab_prefix = model_dir / "vocab"
hparams.share_vocab = False
vocab_de, vocab_en = load_vocabularies(hparams)

# Restore the model.
# torch.load unpickles the checkpoint -- only load files from a trusted source.
model, _, _, translate = create_model(hparams, vocab_de, vocab_en)
checkpoint_path = model_dir / "model/bleu/de-en.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location=device_name))
model = model.to(device)
# load_state_dict does not change the train/eval mode, and torch modules start in
# training mode (unless create_model changes it). Switch to eval so the
# Dropout(p=0.5) layers are disabled while we sample from the model below.
model.eval()
print(model)

AEVNMT(
  (encoder): RNNEncoder(
    (rnn): LSTM(256, 256, batch_first=True, bidirectional=True)
  )
  (decoder): BahdanauDecoder(
    (rnn): LSTM(768, 256, batch_first=True)
    (dropout_layer): Dropout(p=0.5)
    (pre_output_layer): Linear(in_features=1024, out_features=256, bias=True)
    (attention): BahdanauAttention(
      (key_layer): Linear(in_features=512, out_features=256, bias=False)
      (query_layer): Linear(in_features=256, out_features=256, bias=False)
      (scores_layer): Linear(in_features=256, out_features=1, bias=False)
    )
  )
  (language_model): RNNLM(
    (embedder): Embedding(9726, 256, padding_idx=1)
    (rnn): LSTM(256, 256, batch_first=True)
    (dropout_layer): Dropout(p=0.5)
  )
  (tgt_embedder): Embedding(9726, 256, padding_idx=1)
  (dropout_layer): Dropout(p=0.5)
  (encoder_init_layer): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): Tanh()
  )
  (decoder_init_layer): Sequential(
    (0): Linear(in_features=32, out_fea

### Sample source sentences from the prior (ancestral sampling)

In [3]:
num_samples = 5

# Inference only: no gradients needed, so skip building the autograd graph.
with torch.no_grad():
    # Draw one latent code per sample from the prior p(z).
    prior = model.prior()
    z = prior.sample(sample_shape=[num_samples])

    # Construct the LM initial hidden state and the start-of-sentence inputs.
    hidden_lm = tile_rnn_hidden(model.lm_init_layer(z), model.language_model.rnn)
    x_init = torch.full((num_samples,), vocab_de[SOS_TOKEN],
                        dtype=torch.long, device=z.device)
    x_embed = model.language_model.embedder(x_init)

    # Keep track of model samples.
    x_samples = [x_init.unsqueeze(-1)]  # List of [num_samples, 1] integer tensors.

    # Ancestrally sample num_samples source sentences conditioned on z:
    # each next word is drawn from the LM's categorical distribution.
    for _ in range(hparams.max_decoding_length):
        hidden_lm, logits = model.language_model.step(x_embed, hidden_lm)
        next_word_dist = torch.distributions.categorical.Categorical(logits=logits)
        x = next_word_dist.sample()  # [num_samples, 1], so it can be concatenated on dim -1.
        # Squeeze only the trailing time axis: a bare squeeze() would also drop
        # the batch axis when num_samples == 1.
        x_embed = model.language_model.embedder(x.squeeze(-1))
        x_samples.append(x)

    # Concatenate the samples into [num_samples, 1 + max_decoding_length].
    x_samples = torch.cat(x_samples, dim=-1)

# Convert the id sequences to sentences.
x_samples = batch_to_sentences(x_samples, vocab_de)

# Construct a post-processing pipeline for German.
postprocess = [Detokenizer("de"),
               Recaser("de"),
               WordDesegmenter(separator=hparams.subword_token)]  # Executed in reverse order.
pipeline_de = Pipeline(pre=[], post=postprocess)

# Print the samples.
pp_x_samples = [pipeline_de.post(x) for x in x_samples]
for idx, x in enumerate(pp_x_samples, 1):
    print(f"{idx}: {x}")

1: Zwei hunde rennen in einer wiese während new york.
2: Ein kleiner junge liegt auf einem holzzaun über dem strand.
3: Diese menschen lesen verschiedene pinsame tag etwas in der pfanne.
4: Ein obdachloser mann sitzt auf einem hocker vor einer skulptur und liest seinen fahrrad.
5: Ein mann mit stock, der den kamm in seinem fahrrad fährt.


### Sample translations from the approximate posterior (ancestral sampling)

In [4]:
num_samples = 5
# One lowercased, space-tokenized German source sentence, repeated num_samples times
# so that each copy gets its own posterior sample of z.
x_samples = ["ein kleines blondes mädchen hält ein sandwich ."] * num_samples
x_in, _, seq_mask_x, seq_len_x = create_batch(x_samples, vocab_de, device)

# Inference only: no gradients needed, so skip building the autograd graph.
with torch.no_grad():
    # Infer q(z|x) and draw one latent code per copy of the source.
    qz = model.approximate_posterior(x_in, seq_mask_x, seq_len_x)
    z = qz.sample()

    # Encode the source sentences (the encoder also receives z).
    encoder_outputs, encoder_final = model.encode(x_in, seq_len_x, z)

    # Create the initial hidden state of the TM.
    hidden_tm = model.init_decoder(encoder_outputs, encoder_final, z)

    # Sample target sentences conditional on the source and z
    # (greedy=False: ancestral sampling rather than argmax decoding).
    y_samples = ancestral_sample(model.decoder,
                                 model.tgt_embed,
                                 model.generate,
                                 hidden_tm,
                                 encoder_outputs, encoder_final,
                                 seq_mask_x,
                                 vocab_en[SOS_TOKEN],
                                 vocab_en[EOS_TOKEN],
                                 vocab_en[PAD_TOKEN],
                                 hparams.max_decoding_length,
                                 greedy=False)["sample"]

y_samples = batch_to_sentences(y_samples, vocab_en)

# Construct a post-processing pipeline for English.
postprocess = [Detokenizer("en"),
               Recaser("en"),
               WordDesegmenter(separator=hparams.subword_token)]  # Executed in reverse order.
pipeline_en = Pipeline(pre=[], post=postprocess)

# Print the samples.
pp_y_samples = [pipeline_en.post(y) for y in y_samples]
for idx, y in enumerate(pp_y_samples, 1):
    print(f"{idx}: {y}")

1: Young blond children, holding a sandwich.
2: In pink and little girls holding a pair of sandwich.
3: Shirt to short blond-hair, holding a sandwich
4: A little girl with blond-hair is holding a sandwich.
5: At young blond girl holding a sandwich.
