In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchtext.vocab import GloVe

import NMT

# Loading pre-trained GloVe vectors into an nn.Embedding layer, following:
# https://tanmay17061.medium.com/load-pre-trained-glove-embeddings-in-torch-nn-embedding-layer-in-under-2-minutes-f5af8f57416a

# This cell runs before the parameters cell, so the GloVe dimension has to be
# defined here (the original read `dim` / `embed_dim` before any assignment).
# Keep it equal to `dim` in the parameters cell.
embed_dim = 100

# Plain word -> float32 vector lookup parsed from the raw GloVe text file,
# which is opened relative to the current working directory.
embeddings_dict = {}
# Explicit UTF-8 so the read does not depend on the platform default encoding.
with open(f"glove.6B.{embed_dim}d.txt", 'r', encoding="utf-8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], "float32")
        embeddings_dict[word] = vector

global_vectors = GloVe(name='6B', dim=embed_dim)
# The original re-read torchtext's cache file with torch.load and used element 2
# as the weight matrix. The same data is available on the loaded object, so use
# its public attributes instead of depending on the cache file's path/layout.
# `glove_weights` keeps the tuple shape in case other cells index into it.
# NOTE(review): assumes the cache tuple is (itos, stoi, vectors, dim) — confirm.
glove_weights = (
    global_vectors.itos,
    global_vectors.stoi,
    global_vectors.vectors,
    global_vectors.dim,
)
# Frozen embedding initialised from the GloVe matrix.
# NOTE(review): rows are indexed by GloVe vocabulary ids, while the tokenizers
# used later in this notebook have their own vocabularies — confirm how
# emb_layer is meant to be fed, and that NMT.Constants.PAD is the right row.
emb_layer = nn.Embedding.from_pretrained(
    global_vectors.vectors,
    freeze=True,
    padding_idx=NMT.Constants.PAD,
)

In [None]:
# --- Model / data hyperparameters -------------------------------------------
batch_size = 32
dim = 100          # passed as rnn_size to the encoder and decoder
lstm_layers = 2    # passed as num_layers to the encoder and decoder
dropout = 0.2

# --- Training hyperparameters -----------------------------------------------
optimizer = "SGD"
learning_rate = 1
epochs = 20
# Possibly superseded by optimizer schedulers (open question, not yet set):
#   max_grad_norm
#   learning_rate_decay
#   start_decay_at

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# https://huggingface.co/dbmdz/bert-base-german-cased

In [None]:
# English (source) side: uncased DistilBERT checkpoint.
tokenizer_en = AutoTokenizer.from_pretrained(
    "distilbert-base-uncased",
)
# German (target) side: dbmdz cased German BERT checkpoint.
tokenizer_de = AutoTokenizer.from_pretrained(
    "dbmdz/bert-base-german-cased",
)

In [None]:
from torch.utils.data import DataLoader
def encode_trans(examples):
  """Tokenize one group of translation pairs into token-id lists.

  Args:
    examples: mapping whose "translation" entry is iterated as a sequence of
      mappings, each providing an 'en' and a 'de' string.

  Returns:
    dict with "input" (ids from ``tokenizer_en``) and "target" (ids from
    ``tokenizer_de``). Each side is padded to its own longest sequence and
    truncated to at most 40 tokens.
  """
  # Possible extension: filter out short sentences so no padding is needed.
  sentence_pairs = [(pair['en'], pair['de']) for pair in examples["translation"]]
  english = [en_text for en_text, _ in sentence_pairs]
  german = [de_text for _, de_text in sentence_pairs]

  # Both tokenizers are called with the same padding/truncation settings.
  tokenize_options = dict(padding='longest', truncation=True, max_length=40)
  encoded_source = tokenizer_en(english, **tokenize_options)
  encoded_target = tokenizer_de(german, **tokenize_options)
  return {'input': encoded_source["input_ids"], "target": encoded_target["input_ids"]}

from torch.utils.data import DataLoader
def collate_custom(batch):
  """Convert the first element of ``batch`` into a pair of int64 tensors.

  Only ``batch[0]`` is read; any further elements are ignored.

  Args:
    batch: sequence whose first element maps "input" and "target" to nested
      lists of token ids.

  Returns:
    Tuple ``(input_tensor, target_tensor)``, both with dtype ``torch.long``.
  """
  example = batch[0]
  return tuple(
    torch.tensor(example[field], dtype=torch.long)
    for field in ("input", "target")
  )



In [None]:
# Stream the WMT16 de-en training split instead of loading it all at once.
# NOTE(review): trust_remote_code=True allows the dataset's loading script to
# run locally — confirm it is still required for this dataset.
dataset_stream = load_dataset("wmt16", "de-en", streaming=True, split="train", trust_remote_code=True)
# Group examples with the notebook-wide `batch_size` setting (previously a
# hard-coded 32 that ignored the parameters cell).
dataset_batched = dataset_stream.batch(batch_size=batch_size)
# Tokenize each group; the raw "translation" column is dropped afterwards.
dataset_m = dataset_batched.map(encode_trans, remove_columns="translation")
# Every streamed item is already a whole batch, so the DataLoader must hand
# exactly one item to collate_custom (which only reads batch[0]). batch_size=1
# is the DataLoader default; it is spelled out because the collate depends on it.
train_dataloader = DataLoader(dataset_m, batch_size=1, collate_fn=collate_custom)

In [None]:
from NMT import Models

# Bidirectional encoder sized from the notebook hyperparameters.
encoder = Models.Encoder(
    num_layers=lstm_layers,
    bidirectional=True,
    dropout=dropout,
    rnn_size=dim,
)
# Unidirectional decoder with the same depth, dropout and rnn_size.
decoder = Models.Decoder(
    num_layers=lstm_layers,
    bidirectional=False,
    dropout=dropout,
    rnn_size=dim,
)
# Wrap both halves in the NMT model container.
model = Models.NMTModel(encoder, decoder)

In [None]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: iterable yielding ``(input_tensor, target_tensor)`` pairs.
            It does not need to support ``len()``, so streaming / iterable
            datasets work.
        encoder: callable taking the input tensor and returning
            ``(encoder_outputs, encoder_hidden)``.
        decoder: callable taking ``(encoder_outputs, encoder_hidden,
            target_tensor)`` and returning three values; only the first
            (the decoder outputs) is used.
        encoder_optimizer: optimizer stepping the encoder parameters.
        decoder_optimizer: optimizer stepping the decoder parameters.
        criterion: loss called with the decoder outputs flattened to 2-D
            (last dimension kept) and the targets flattened to 1-D.

    Returns:
        float: sum of per-batch losses divided by the number of batches
        processed, or ``0.0`` when the dataloader yields nothing.
    """
    total_loss = 0.0
    # Count batches while iterating: len(dataloader) raises TypeError for
    # loaders backed by an iterable dataset without a length.
    num_batches = 0

    for input_tensor, target_tensor in dataloader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        # Collapse all leading dimensions so each row is scored against one
        # target id.
        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Avoid ZeroDivisionError when nothing was yielded.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches