In [5]:
from models import S2SPL
from dataset import en_de_dataset
from config import args

In [6]:

# Initialize the training split; the model's vocabularies were built from it,
# and later cells sample sentence pairs from it via `trainset.random_pairs()`.
trainset = en_de_dataset(split='train')

# Checkpoint to evaluate. The directory name encodes the run's hyperparameters
# (GRU encoder/decoder, 1 attention head, 1 layer each, EMB_DIM 300,
# unidirectional encoder); the file name records epoch 55 / valid_loss 1.45.
CHECKPOINT_PATH = (
    'best/BATCH_SIZE400 LEARNING_RATE0.0001 EPOCHS100 ENCODER_TYPEGRU DECODER_TYPEGRU '
    'ATTENTION_HEAD1 ENCODER_LAYERS1 DECODER_LAYERS1 EMB_DIM300 BIDIRECTIONFalse'
    '/-epoch=55-valid_loss=1.45.ckpt'
)

# `load_from_checkpoint` is a classmethod: it constructs a *new* S2SPL from the
# hyperparameters stored in the checkpoint and returns it, so it is called on the
# class. Building an instance first and calling the method on it would discard
# that instance (and Lightning 2.x raises an error for instance calls).
# NOTE(review): this relies on the checkpoint having saved the constructor
# arguments (e.g. via `save_hyperparameters`) — if not, pass them here as
# keyword arguments (vocab_size_encoder=..., input_vocab=trainset.input_vocab, ...).
model = S2SPL.load_from_checkpoint(CHECKPOINT_PATH)

# Switch to inference mode; as the cell's last expression this also displays the
# module architecture.
model.eval()

Input vocab size: 9666
Target vocab size: 17790


S2SPL(
  (encoder): RNNEncoder(
    (emb): Embedding(9666, 300)
    (rnn): GRU(300, 300)
  )
  (decoder): RNNDecoder(
    (emb): Embedding(17790, 300)
    (rnn): GRU(300, 300)
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=300, out_features=300, bias=True)
    )
    (decode): Linear(in_features=300, out_features=17790, bias=True)
  )
  (cross_entrophy): CrossEntropyLoss()
)

In [14]:
# Sanity check: translate a fixed English sentence with the loaded model.
SAMPLE_SENTENCE = 'A man walking on the street'
model.translation(SAMPLE_SENTENCE)

['man', 'walking', 'on', 'the', 'street']
[28, 64, 33, 38, 59]


['mann',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'straße',
 'straße',
 'straße',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der',
 'der']

In [36]:
# Draw a random (source, reference) sentence pair from the training split, show
# both sentences, then display the model's translation of the source sentence so
# it can be compared with the reference by eye.
input_pair, target_pair = trainset.random_pairs()
sampled_pair = (input_pair, target_pair)
print(sampled_pair)
model.translation(sampled_pair[0])

('Construction workers building a frame for a structure.', 'Bauarbeiter bauen einen Rahmen für eine Konstruktion.')
['construction', 'workers', 'building', 'frame', 'for', 'structure']
[408, 580, 369, 707, 297, 132]


['bauarbeiter',
 'bauen',
 'bauen',
 'bier',
 'einem',
 'einem',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'markt',
 'reifenschaukel',
 'reifenschaukel',
 'reifenschaukel']