In [6]:
# from modeling_lstm_seq2seq import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import load_dataset
from torchtext.data.utils import get_tokenizer
import torchtext

In [7]:
# import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [8]:
NUM_EPOCHS = 1
LEARNING_RATE = 0.001
BATCH_SIZE = 32
MAX_LENGTH = 128

LSTM_LOCAL_PATH = "local_lstm"
LSTM_GLOBAL_PATH = "global_lstm"
LSTM_PATH = "classic_lstm"

In [9]:
train_ds = load_dataset('wmt16', 'de-en', split='train[:1%]')
val_ds = load_dataset('wmt16', 'de-en', split='validation')
test_ds = load_dataset('wmt16', 'de-en', split='test')

Downloading builder script:   0%|          | 0.00/2.81k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/18.6k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.89k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/41.4k [00:00<?, ?B/s]

Downloading and preparing dataset wmt16/de-en to /root/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/658M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/919M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/75.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/38.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split:   0%|          | 0/4548885 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2169 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2999 [00:00<?, ? examples/s]

Dataset wmt16 downloaded and prepared to /root/.cache/huggingface/datasets/wmt16/de-en/1.0.0/746749a11d25c02058042da7502d973ff410e73457f3d305fc1177dc0e8c4227. Subsequent calls will reuse this data.




In [10]:
new = train_ds[:len(train_ds)//2]

In [11]:
train_ds = train_ds.train_test_split(test_size = 0.5)["train"]

In [12]:
de_tokenizer = get_tokenizer('spacy', language='de')
en_tokenizer = get_tokenizer('spacy', language='en')



In [13]:
def tokenize_de(text):
    return [tok for tok in de_tokenizer(text)]

def tokenize_en(text):
    return [tok for tok in en_tokenizer(text)]

In [14]:
tokenized_train_eng = [tokenize_en(text['en']) for text in tqdm(train_ds['translation'])]
tokenized_train_ger = [tokenize_de(text['de']) for text in tqdm(train_ds['translation'])]

100%|██████████| 22744/22744 [00:02<00:00, 8989.16it/s] 
100%|██████████| 22744/22744 [00:04<00:00, 5520.67it/s]


In [15]:
tokenized_val_eng = [tokenize_en(text['en']) for text in tqdm(val_ds['translation'])]
tokenized_val_ger = [tokenize_de(text['de']) for text in tqdm(val_ds['translation'])]

100%|██████████| 2169/2169 [00:00<00:00, 5198.62it/s]
100%|██████████| 2169/2169 [00:00<00:00, 3848.49it/s]


In [16]:
# create a vocabulary of the training samples for only the top 50000 most common words without torchtext

eng_vocab = {}
ger_vocab = {}

for text in tqdm(tokenized_train_eng):
    for word in text:
        if word in eng_vocab:
            eng_vocab[word] += 1
        else:
            eng_vocab[word] = 1

for text in tqdm(tokenized_train_ger):
    for word in text:
        if word in ger_vocab:
            ger_vocab[word] += 1
        else:
            ger_vocab[word] = 1

100%|██████████| 22744/22744 [00:00<00:00, 65684.64it/s]
100%|██████████| 22744/22744 [00:00<00:00, 70932.38it/s]


In [17]:
eng_vocab = {k: v for k, v in sorted(eng_vocab.items(), key=lambda item: item[1], reverse=True)}
ger_vocab = {k: v for k, v in sorted(ger_vocab.items(), key=lambda item: item[1], reverse=True)}

In [18]:
eng_vocab = dict(list(eng_vocab.items())[:50000])
ger_vocab = dict(list(ger_vocab.items())[:50000])

eng_vocab = {k: i+2 for i, k in enumerate(eng_vocab.keys())}
ger_vocab = {k: i+2 for i, k in enumerate(ger_vocab.keys())}

eng_vocab['<unk>'] = 0
ger_vocab['<unk>'] = 0

eng_vocab['<eos>'] = 1
ger_vocab['<eos>'] = 1

eng_vocab['<pad>'] = 0
ger_vocab['<pad>'] = 0

In [19]:
### pad the sequences to the same length

def pad_seq(seq, max_length):
    if len(seq)>max_length:
        return seq[:max_length-1] + ['<eos>']
    elif len(seq) == max_length:
        return seq
    seq += ["<pad>" for i in range(max_length - len(seq))]
    return seq

tokenized_train_eng = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_train_eng)]
tokenized_train_ger = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_train_ger)]

tokenized_val_eng = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_val_eng)]
tokenized_val_ger = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_val_ger)]

100%|██████████| 22744/22744 [00:00<00:00, 159231.73it/s]
100%|██████████| 22744/22744 [00:00<00:00, 137194.01it/s]
100%|██████████| 2169/2169 [00:00<00:00, 131473.57it/s]
100%|██████████| 2169/2169 [00:00<00:00, 142490.45it/s]


In [20]:
def encode_eng(text):
    encoded = []
    for token in text:
        try:
            encoded.append(eng_vocab[token])
        except:
            encoded.append(eng_vocab['<unk>'])
    return encoded

def encode_ger(text):
    encoded = []
    for token in text:
        try:
            encoded.append(ger_vocab[token])
        except:
            encoded.append(ger_vocab['<unk>'])
    return encoded

In [21]:
encoded_train_eng = [encode_eng(text) for text in tqdm(tokenized_train_eng)]
encoded_train_ger = [encode_ger(text) for text in tqdm(tokenized_train_ger)]

100%|██████████| 22744/22744 [00:00<00:00, 53528.57it/s]
100%|██████████| 22744/22744 [00:00<00:00, 49663.94it/s]


In [22]:
encoded_val_eng = [encode_eng(text) for text in tqdm(tokenized_val_eng)]
encoded_val_ger = [encode_ger(text) for text in tqdm(tokenized_val_ger)]

100%|██████████| 2169/2169 [00:00<00:00, 55206.63it/s]
100%|██████████| 2169/2169 [00:00<00:00, 50435.45it/s]


In [23]:
tokenized_train_dataloader = DataLoader(list(zip(encoded_train_eng, encoded_train_ger)), batch_size=BATCH_SIZE, shuffle=True)
tokenized_val_dataloader = DataLoader(list(zip(encoded_val_eng, encoded_val_ger)), batch_size=BATCH_SIZE, shuffle=True)

In [24]:
config = Config(
    input_size=len(ger_vocab),
    embedding_size=1000,
    hidden_size=1000,
    num_layers=4,
    vocab_size=len(ger_vocab),
    dropout=0.2,
    device="cuda",
    max_length=MAX_LENGTH,
)

In [25]:
model = LSTMSeq2Seq(config, attention = True, alignment = "global", scoring_function = "general")
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [26]:
model.to(config.device)

LSTMSeq2Seq(
  (encoder): LSTMEncoder(
    (embedding): Embedding(35353, 1000)
    (lstm): LSTM(1000, 1000, num_layers=4, batch_first=True, dropout=0.2)
  )
  (decoder): LSTMDecoder(
    (embedding): Embedding(35353, 1000)
    (lstm): LSTM(1000, 1000, num_layers=4, batch_first=True, dropout=0.2)
    (attention): Attention(
      (W): Linear(in_features=1000, out_features=1000, bias=False)
    )
  )
  (lm_head): Linear(in_features=2000, out_features=35353, bias=True)
)

In [29]:
import evaluate
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [30]:
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    epoch_loss = 0
    for batch in tqdm(dataloader):
        src = torch.stack(batch[0]).to(torch.int64).to(device)
        trg = torch.stack(batch[1]).to(torch.int64).to(device)
        src = src.transpose(0, 1)
        trg = trg.transpose(0, 1)

        optimizer.zero_grad()
        
        output = model(src, trg)
        # output = output[1:].view(-1, output.shape[2])
        # trg = trg[1:].view(-1)
        # output = output.max(dim = 2)

        output = output.permute(0, 2, 1)

        loss = criterion(output, trg)

        for name, param in model.named_parameters():
            param.retain_grad()
            # print(name, param.grad)

        loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        epoch_loss += loss.item()
    return model, optimizer, epoch_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            src = torch.stack(batch[0]).to(torch.int64).to(device)
            trg = torch.stack(batch[1]).to(torch.int64).to(device)
            output = model(src, trg, 0)
            # output = output[1:].view(-1, output.shape[2])
            # trg = trg[1:].view(-1)
            output = output.permute(0, 2, 1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(dataloader)

In [31]:
for epoch in range(NUM_EPOCHS):
    model, optimizer, train_loss = train(model, tokenized_train_dataloader, optimizer, criterion, config.device)
    val_loss = evaluate(model, tokenized_val_dataloader, criterion, config.device)
    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {val_loss:.3f}')

100%|██████████| 711/711 [39:28<00:00,  3.33s/it]
100%|██████████| 68/68 [00:32<00:00,  2.11it/s]

Epoch: 01
	Train Loss: 10.410
	 Val. Loss: 10.406





In [33]:
import pickle

In [34]:
torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    },
    LSTM_GLOBAL_PATH + "/model.pt"
)
with open(LSTM_GLOBAL_PATH + "/eng_vocab.pkl", "wb+") as f:
    pickle.dump(eng_vocab, f)

with open(LSTM_GLOBAL_PATH + "/ger_vocab.pkl", "wb+") as f:
    pickle.dump(ger_vocab, f)

In [None]:
# model = LSTMSeq2Seq(config, attention = True, alignment = "local-m", scoring_function = "dot")
# ckpt = torch.load(LSTM_LOCAL_PATH + "/model.pt")
# model.load_state_dict(ckpt['model_state_dict'])

In [37]:
model.eval()

LSTMSeq2Seq(
  (encoder): LSTMEncoder(
    (embedding): Embedding(35353, 1000)
    (lstm): LSTM(1000, 1000, num_layers=4, batch_first=True, dropout=0.2)
  )
  (decoder): LSTMDecoder(
    (embedding): Embedding(35353, 1000)
    (lstm): LSTM(1000, 1000, num_layers=4, batch_first=True, dropout=0.2)
    (attention): Attention(
      (W): Linear(in_features=1000, out_features=1000, bias=False)
    )
  )
  (lm_head): Linear(in_features=2000, out_features=35353, bias=True)
)

In [45]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(preds, labels):
    if isinstance(preds, tuple):
        preds = preds[0]

    print(preds, labels)
    
    decoded_preds, decoded_labels = postprocess_text(preds, labels)

    bleu_1 = bleu.compute(predictions=decoded_preds, references=decoded_labels, max_order=1)
    bleu_2 = bleu.compute(predictions=decoded_preds, references=decoded_labels, max_order=2)
    rouge_l = rouge.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rougeL"])
    result = {"bleu_1": bleu_1["bleu"], "bleu_2": bleu_2["bleu"], "rouge_l": rouge_l["rougeL"]}

    return result

In [38]:
tokenized_test_eng = [tokenize_en(text['en']) for text in tqdm(test_ds['translation'])]
tokenized_test_ger = [tokenize_de(text['de']) for text in tqdm(test_ds['translation'])]

tokenized_test_eng = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_test_eng)]
tokenized_test_ger = [pad_seq(text + ['<eos>'], MAX_LENGTH) for text in tqdm(tokenized_test_ger)]

encoded_test_eng = [encode_eng(text) for text in tqdm(tokenized_test_eng)]
encoded_test_ger = [encode_ger(text) for text in tqdm(tokenized_test_ger)]

tokenized_test_dataloader = DataLoader(list(zip(encoded_test_eng, encoded_test_ger)), batch_size=BATCH_SIZE, shuffle=True)

100%|██████████| 2999/2999 [00:00<00:00, 7022.61it/s]
100%|██████████| 2999/2999 [00:00<00:00, 5568.82it/s]
100%|██████████| 2999/2999 [00:00<00:00, 87423.85it/s]
100%|██████████| 2999/2999 [00:00<00:00, 102728.69it/s]
100%|██████████| 2999/2999 [00:00<00:00, 48277.19it/s]
100%|██████████| 2999/2999 [00:00<00:00, 51360.15it/s]


In [39]:
eng_idx2word = {v: k for k, v in eng_vocab.items()}
ger_idx2word = {v: k for k, v in ger_vocab.items()}

In [41]:
bleu_scores = []
rouge_scores = []

decoded_sent = []
trg_sent = []

with torch.no_grad():
    for batch in tqdm(tokenized_test_dataloader):
        src = torch.stack(batch[0]).to(torch.int64).to(config.device)
        trg = torch.stack(batch[1]).to(torch.int64).to(config.device)
        src = src.transpose(0, 1)
        trg = trg.transpose(0, 1)
        output = model(src, trg, 0)
        for i in range(output.shape[0]):
            decoded_tokens = output[i].argmax(dim = 1)
            decoded_sent.append(' '.join([ger_idx2word[i.item()] for i in decoded_tokens]))
            trg_sent.append([' '.join([ger_idx2word[i.item()] for i in trg[i]])])

100%|██████████| 94/94 [00:58<00:00,  1.61it/s]


In [47]:
bleu_1 = bleu.compute(predictions=decoded_sent, references=trg_sent, max_order=1)
bleu_2 = bleu.compute(predictions=decoded_sent, references=trg_sent, max_order=2)
rouge_l = rouge.compute(predictions=decoded_sent, references=trg_sent, rouge_types=["rougeL"])

In [51]:
bleu_1["bleu"], bleu_2["bleu"], rouge_l['rougeL']

(0.006034340265596968, 0.004383791814859717, 0.015717983324340348)