In [1]:
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import MultiplicativeLR
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from tqdm import tqdm

from modeling_lstm_seq2seq import LSTMSeq2Seq, Config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training hyperparameters, gathered in one cell so they can be tuned together.
NUM_EPOCHS = 10        # number of train + validation passes run by the final loop
LEARNING_RATE = 1e-3   # step size handed to the Adam optimizer
BATCH_SIZE = 128       # sentence pairs per batch in both DataLoaders

In [3]:
# Fetch the three WMT16 German-English splits through one shared call.
# The generator is consumed in order, so the splits are loaded as
# train, validation, test.
train_ds, val_ds, test_ds = (
    load_dataset('wmt16', 'de-en', split=split_name)
    for split_name in ('train', 'validation', 'test')
)

Found cached dataset wmt16 (C:/Users/gener/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)
Found cached dataset wmt16 (C:/Users/gener/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)
Found cached dataset wmt16 (C:/Users/gener/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227)


In [4]:
# Hyperparameters handed to the project-local LSTMSeq2Seq model.
# NOTE(review): the meaning of each field is defined in modeling_lstm_seq2seq,
# which is not visible from this notebook; the remarks below are assumptions
# to confirm against that module.
config = Config(
    input_size=32000,  # presumably the source (German) vocabulary size - TODO confirm; ger_vocab built later holds ids up to 50001
    embedding_size=1000,
    hidden_size=1000,
    num_layers=4,
    vocab_size=50000,  # presumably the target (English) vocabulary size - TODO confirm; eng_vocab built later holds ids up to 50001
    dropout=0.2,
    device="cuda",  # read back as config.device by the final training loop
    max_length=1024,
)

In [5]:
# Build the model and place it on the configured device up front: train() and
# evaluate() move every batch to config.device, so the parameters have to live
# on that device as well. `.to()` is a no-op if the constructor already did it.
model = LSTMSeq2Seq(config).to(config.device)

# Target positions whose id is 0 contribute nothing to the loss.
# NOTE(review): id 0 is assigned to '<unk>' when the vocabularies are built and
# no dedicated padding id exists - confirm that ignoring index 0 is intended.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Adam over all model parameters, created after the device move.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [6]:
# spaCy-backed tokenizers obtained through torchtext: one for the German text
# (used as the source side of each pair) and one for the English text (target).
de_tokenizer = get_tokenizer('spacy', language='de')
en_tokenizer = get_tokenizer('spacy', language='en')



In [7]:
def tokenize_de(text):
    """Split a German string into a list of tokens using `de_tokenizer`."""
    return list(de_tokenizer(text))

def tokenize_en(text):
    """Split an English string into a list of tokens using `en_tokenizer`."""
    return list(en_tokenizer(text))

In [8]:
# Read the 'translation' column once and share it between both passes rather
# than fetching it from the dataset twice. Each entry is a dict with an 'en'
# and a 'de' string.
train_pairs = train_ds['translation']
tokenized_train_eng = [tokenize_en(pair['en']) for pair in tqdm(train_pairs)]
tokenized_train_ger = [tokenize_de(pair['de']) for pair in tqdm(train_pairs)]
del train_pairs  # only the token lists are needed from here on

100%|██████████| 4548885/4548885 [04:49<00:00, 15718.71it/s]
100%|██████████| 4548885/4548885 [13:06<00:00, 5781.03it/s] 


In [9]:
# Read the 'translation' column once and share it between both passes rather
# than fetching it from the dataset twice. Each entry is a dict with an 'en'
# and a 'de' string.
val_pairs = val_ds['translation']
tokenized_val_eng = [tokenize_en(pair['en']) for pair in tqdm(val_pairs)]
tokenized_val_ger = [tokenize_de(pair['de']) for pair in tqdm(val_pairs)]
del val_pairs  # only the token lists are needed from here on

100%|██████████| 2169/2169 [00:02<00:00, 904.36it/s] 
100%|██████████| 2169/2169 [00:00<00:00, 7856.64it/s]


In [10]:
# eng_vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_train_eng, max_tokens = 50000)
# ger_vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_train_ger, max_tokens = 50000)

# Build the vocabulary by hand instead of with torchtext: this cell counts how often each token occurs in the training samples; later cells keep only the 50000 most common words.

def count_tokens(tokenized_sentences):
    """Count how often each token occurs across a collection of token lists.

    Args:
        tokenized_sentences: iterable of token sequences, one per sentence.

    Returns:
        collections.Counter mapping token -> number of occurrences. Keys keep
        the order in which tokens were first seen, like a plain dict.
    """
    counts = Counter()
    for sentence in tqdm(tokenized_sentences):
        counts.update(sentence)
    return counts


# Raw frequency tables; later cells sort them and turn them into id maps.
eng_vocab = count_tokens(tokenized_train_eng)
ger_vocab = count_tokens(tokenized_train_ger)

100%|██████████| 4548885/4548885 [01:39<00:00, 45571.21it/s]
100%|██████████| 4548885/4548885 [01:30<00:00, 50240.54it/s]


In [11]:
# Re-order both frequency tables from most to least frequent. The sort is
# stable, so equally frequent words keep their existing relative order.
eng_vocab = dict(sorted(eng_vocab.items(), key=lambda pair: pair[1], reverse=True))
ger_vocab = dict(sorted(ger_vocab.items(), key=lambda pair: pair[1], reverse=True))

In [12]:
# Keep the first 50000 words of each (already frequency-sorted) table and
# number them from 2 upward, leaving ids 0 and 1 free for the special tokens.
# NOTE(review): the resulting ids therefore run from 0 to 50001 (up to 50002
# entries), while Config was created with vocab_size=50000 and
# input_size=32000 - confirm the model's embedding/output sizes cover this.
eng_vocab = {word: idx for idx, word in enumerate(list(eng_vocab)[:50000], start=2)}
ger_vocab = {word: idx for idx, word in enumerate(list(ger_vocab)[:50000], start=2)}

# Special tokens are written last, so they take ids 0 and 1 even if the same
# strings also occur as corpus words.
eng_vocab.update({'<unk>': 0, '<eos>': 1})
ger_vocab.update({'<unk>': 0, '<eos>': 1})

In [13]:
def encode_eng(text):
    """Map a sequence of English tokens to vocabulary ids and append `<eos>`.

    Args:
        text: iterable of tokens (as produced by `tokenize_en`).

    Returns:
        list[int]: one id per token - the `<unk>` id for tokens missing from
        `eng_vocab` - followed by the `<eos>` id.
    """
    unk_id = eng_vocab['<unk>']
    # dict.get covers out-of-vocabulary tokens without a bare `except`, which
    # would also hide unrelated errors (e.g. KeyboardInterrupt, bad token types).
    encoded = [eng_vocab.get(token, unk_id) for token in text]
    encoded.append(eng_vocab['<eos>'])
    return encoded

def encode_ger(text):
    """Map a sequence of German tokens to vocabulary ids and append `<eos>`.

    Args:
        text: iterable of tokens (as produced by `tokenize_de`).

    Returns:
        list[int]: one id per token - the `<unk>` id for tokens missing from
        `ger_vocab` - followed by the `<eos>` id.
    """
    unk_id = ger_vocab['<unk>']
    # dict.get covers out-of-vocabulary tokens without a bare `except`, which
    # would also hide unrelated errors (e.g. KeyboardInterrupt, bad token types).
    encoded = [ger_vocab.get(token, unk_id) for token in text]
    encoded.append(ger_vocab['<eos>'])
    return encoded

In [14]:
# Turn every tokenized training sentence into a list of vocabulary ids
# (each ending with the <eos> id), with one progress bar per language.
encoded_train_eng = list(map(encode_eng, tqdm(tokenized_train_eng)))
encoded_train_ger = list(map(encode_ger, tqdm(tokenized_train_ger)))

100%|██████████| 4548885/4548885 [13:33<00:00, 5589.26it/s]  
100%|██████████| 4548885/4548885 [07:08<00:00, 10626.79it/s] 


In [15]:
# Turn every tokenized validation sentence into a list of vocabulary ids
# (each ending with the <eos> id), with one progress bar per language.
encoded_val_eng = list(map(encode_eng, tqdm(tokenized_val_eng)))
encoded_val_ger = list(map(encode_ger, tqdm(tokenized_val_ger)))

100%|██████████| 2169/2169 [00:00<00:00, 5389.40it/s]
100%|██████████| 2169/2169 [00:00<00:00, 490560.55it/s]


In [17]:
# Padding id used when batching. It equals the loss's ignore_index (0), so
# padded target positions are not scored.
# NOTE(review): 0 is also the '<unk>' id in both vocabularies - consider a
# dedicated '<pad>' id.
PAD_IDX = 0


def pad_collate(batch):
    """Collate (source_ids, target_ids) pairs into two padded LongTensors.

    The default collate function cannot batch id lists of different lengths,
    so each side is padded with PAD_IDX up to the longest sequence in the batch.

    Args:
        batch: list of (source_ids, target_ids) tuples, each a list of ints.

    Returns:
        tuple: (src, trg) tensors of shape (max_len, batch_size).
        NOTE(review): the time-major layout is assumed because train() and
        evaluate() slice the first axis (`output[1:]`, `trg[1:]`) - confirm
        against the model's expected input layout.
    """
    src_seqs = [torch.tensor(src_ids, dtype=torch.long) for src_ids, _ in batch]
    trg_seqs = [torch.tensor(trg_ids, dtype=torch.long) for _, trg_ids in batch]
    src = pad_sequence(src_seqs, padding_value=PAD_IDX)
    trg = pad_sequence(trg_seqs, padding_value=PAD_IDX)
    return src, trg


# Each dataset item is a (German ids, English ids) pair.
tokenized_train_dataloader = DataLoader(list(zip(encoded_train_ger, encoded_train_eng)), batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_collate)
# Validation batches are not shuffled, which keeps the reported loss
# comparable from epoch to epoch.
tokenized_val_dataloader = DataLoader(list(zip(encoded_val_ger, encoded_val_eng)), batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_collate)

In [16]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimization pass over `dataloader` and return the mean batch loss.

    Args:
        model: module called as `model(src, trg)`.
        dataloader: sized iterable of batches; `batch[0]` is the source tensor
            and `batch[1]` the target tensor.
        optimizer: optimizer whose gradients are reset and stepped once per batch.
        criterion: loss called as `criterion(predictions, targets)`.
        device: device the batch tensors are moved to.

    Returns:
        float: sum of the per-batch losses divided by `len(dataloader)`.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(dataloader):
        src, trg = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        output = model(src, trg)

        # Drop the first step along axis 0, then flatten so every remaining
        # position becomes one row of scores paired with one target id.
        num_classes = output.shape[2]
        flat_output = output[1:].view(-1, num_classes)
        flat_trg = trg[1:].view(-1)

        batch_loss = criterion(flat_output, flat_trg)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    """Compute the mean batch loss over `dataloader` without updating the model.

    Args:
        model: module called as `model(src, trg, 0)`; the third positional
            argument is presumably a teacher-forcing ratio - TODO confirm
            against the model's forward signature.
        dataloader: sized iterable of batches; `batch[0]` is the source tensor
            and `batch[1]` the target tensor.
        criterion: loss called as `criterion(predictions, targets)`.
        device: device the batch tensors are moved to.

    Returns:
        float: sum of the per-batch losses divided by `len(dataloader)`.
    """
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            src, trg = batch[0].to(device), batch[1].to(device)
            output = model(src, trg, 0)

            # Drop the first step along axis 0, then flatten so every remaining
            # position becomes one row of scores paired with one target id.
            num_classes = output.shape[2]
            flat_output = output[1:].view(-1, num_classes)
            flat_trg = trg[1:].view(-1)

            running_loss += criterion(flat_output, flat_trg).item()
    return running_loss / len(dataloader)

In [None]:
# Main loop: one training pass and one validation pass per epoch, followed by
# a three-line loss report.
for epoch in range(NUM_EPOCHS):
    train_loss = train(
        model=model,
        dataloader=tokenized_train_dataloader,
        optimizer=optimizer,
        criterion=criterion,
        device=config.device,
    )
    val_loss = evaluate(
        model=model,
        dataloader=tokenized_val_dataloader,
        criterion=criterion,
        device=config.device,
    )
    print(
        f'Epoch: {epoch+1:02}',
        f'\tTrain Loss: {train_loss:.3f}',
        f'\t Val. Loss: {val_loss:.3f}',
        sep='\n',
    )

: 

: 