In [1]:
import os

import torch
import torch.backends.mps

from decoder import AttnDecoderRNN
from encoder import EncoderRNN
from train import *
from evaluate import evaluate
from dataloader import *

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'using device: {device}')

using device: cuda


In [4]:
# Checkpoint to restore. The architecture settings below must match the saved
# weights, because load_state_dict() is strict by default.
models_dir = "models"
model_name = "4_layers_512_hidden_1e-4_lr"
plots_dir = "plots"

teacher_forcing_ratio = 1
lr = 1e-4
hidden_size = 512
n_encoder_layers = 4
n_decoder_layers = 4
dropout = 0.1

# Number of corpus lines to read, and the line offset to start reading from.
start_from_sample = 0
n_samples = 10

# Training-loop logging/checkpoint cadence; the evaluation cells below do not read these.
print_every = 100
plot_every = 1000
save_every = 1000

input_lang, output_lang, pairs = prepare_data('data/train.en', 'data/train.de', n_samples, start_from_sample)
print(f'number of pairs: {len(pairs)}')

encoder = EncoderRNN(input_lang.n_words, hidden_size, num_layers=n_encoder_layers).to(device)
attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, num_layers=n_decoder_layers, dropout_p=dropout,
                              max_length=MAX_LENGTH).to(device)

checkpoint_dir = os.path.join(models_dir, model_name)
# map_location lets a checkpoint saved on one device load onto the current one.
encoder_load_result = encoder.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, "encoder.pt"), map_location=device))
decoder_load_result = attn_decoder.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, "decoder.pt"), map_location=device))
print(f'encoder: {encoder_load_result}')
print(f'decoder: {decoder_load_result}')

# Inference only: eval() disables dropout (dropout_p above) so decoding is
# deterministic. The trailing ';' suppresses the module repr in the cell output.
encoder.eval()
attn_decoder.eval();

Reading lines...
Normalizing...
Read 10 sentence pairs
Trimmed to 6 sentence pairs
Counting words...
Counted words:
en 50002
de 50002
number of pairs: 6


<All keys matched successfully>

In [13]:
# Translate a sample sentence. The second return value is bound to a real name
# rather than `_`: assigning `_` in IPython stops it (and `__`, `___`) from
# tracking the last cell outputs.
# NOTE(review): presumably the decoder attention weights — confirm in evaluate().
decoded_words, attentions = evaluate(
    encoder, attn_decoder, "you and me", input_lang, output_lang,
    max_length=MAX_LENGTH, device=device,
)

In [14]:
# Show the decoded tokens as one space-separated string.
" ".join(decoded_words)

'<unk> <unk> <unk> <unk> <unk> <EOS>'

In [7]:
# Re-read a much larger slice of the corpus to measure vocabulary coverage.
# NOTE(review): this rebinds input_lang / output_lang / pairs, which the loaded
# model and later cells read — confirm the vocabularies match the checkpoint's.
from dataloader import *

start_from_sample = 0
n_samples = 100_000
input_lang, output_lang, pairs = prepare_data(
    'data/train.en', 'data/train.de', n_samples, start_from_sample
)

Reading lines...
Normalizing...
Read 100000 sentence pairs
Trimmed to 18381 sentence pairs
Counting words...
Counted words:
en 50002
de 50002


In [12]:
def find_all_unk_pairs(pairs, src_lang=None, tgt_lang=None):
    """Count sentence pairs that contain at least one out-of-vocabulary word.

    A pair is counted once, however many unknown words it contains. Sentences
    are tokenized by splitting on single spaces.

    Args:
        pairs: iterable whose items expose the source sentence at ``[0]`` and
            the target sentence at ``[1]`` (strings).
        src_lang: vocabulary for ``pair[0]``: any object with a ``word2index``
            container. Defaults to the notebook-global ``input_lang``.
        tgt_lang: vocabulary for ``pair[1]``. Defaults to the notebook-global
            ``output_lang``.

    Returns:
        int: number of pairs with one or more unknown words.
    """
    # Resolve the vocabularies once, up front. The global fallback keeps
    # existing one-argument calls working.
    src_vocab = (input_lang if src_lang is None else src_lang).word2index
    tgt_vocab = (output_lang if tgt_lang is None else tgt_lang).word2index

    def has_unk(sentence, vocab):
        return any(word not in vocab for word in sentence.split(' '))

    # `or` short-circuits: the target side is only checked when the source is clean.
    return sum(1 for pair in pairs
               if has_unk(pair[0], src_vocab) or has_unk(pair[1], tgt_vocab))


def find_all_unks(pairs, src_lang=None, tgt_lang=None):
    """Count every out-of-vocabulary word occurrence across all pairs.

    Repeated unknown words are counted each time they occur. Sentences are
    tokenized by splitting on single spaces.

    Args:
        pairs: iterable whose items expose the source sentence at ``[0]`` and
            the target sentence at ``[1]`` (strings).
        src_lang: vocabulary for ``pair[0]``: any object with a ``word2index``
            container. Defaults to the notebook-global ``input_lang``.
        tgt_lang: vocabulary for ``pair[1]``. Defaults to the notebook-global
            ``output_lang``.

    Returns:
        int: total number of unknown word tokens on both sides.
    """
    # Resolve the vocabularies once, up front. The global fallback keeps
    # existing one-argument calls working.
    src_vocab = (input_lang if src_lang is None else src_lang).word2index
    tgt_vocab = (output_lang if tgt_lang is None else tgt_lang).word2index

    # Summing booleans counts the True (unknown) cases.
    return sum(
        sum(word not in src_vocab for word in pair[0].split(' '))
        + sum(word not in tgt_vocab for word in pair[1].split(' '))
        for pair in pairs
    )

In [13]:
# Corpus size and out-of-vocabulary coverage for the currently loaded pairs.
n_pairs_total = len(pairs)
n_words_total = sum(len(pair[0].split(" ")) + len(pair[1].split(" ")) for pair in pairs)
print(f'number of pairs: {n_pairs_total}')
print(f'number of words: {n_words_total}')
print(f'number of unk pairs: {find_all_unk_pairs(pairs)}')
print(f'number of unk words: {find_all_unks(pairs)}')

number of pairs: 18381
number of words: 397916
number of unk pairs: 17609
number of unk words: 66606


In [24]:
import torch

# Seed a dedicated generator so the shuffle (and its displayed output) is
# reproducible on Restart & Run All without touching the global RNG state.
SHUFFLE_SEED = 0
shuffle_rng = torch.Generator().manual_seed(SHUFFLE_SEED)

t = torch.tensor([1, 2, 3])
# randperm(n) is a random ordering of 0..n-1; indexing t with it shuffles t.
x = t[torch.randperm(t.size(0), generator=shuffle_rng)]

In [25]:
# Display the original tensor next to its shuffled copy.
(t, x)

(tensor([1, 2, 3]), tensor([2, 1, 3]))