In [None]:
!pip install torchdata

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdata
  Downloading torchdata-0.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m39.4 MB/s[0m eta [36m0:00:00[0m
Collecting portalocker>=2.0.0
  Downloading portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Collecting urllib3>=1.25
  Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.6/140.6 KB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: urllib3, portalocker, torchdata
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
Successfully installed portalocker-2.7.0 torchdata-0.5.1 urllib3-1.26.14


In [None]:
import math
import random

import matplotlib.pyplot as plt
import numpy
import torch
import torchdata
import torchvision

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("using", device)

# Seed every RNG the notebook draws from: torch (weight init, dropout), numpy, and Python's
# `random` (which `batch_sampling` uses to shuffle the batch order).
SEED = 777
random.seed(SEED)
numpy.random.seed(SEED)
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed(SEED)

using cuda


In [None]:
norm_eps = 1.0e-5  # epsilon shared by every LayerNorm in the model

In [None]:
class multiheadattention(torch.nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected from ``d_embed`` to ``d_model`` features, split into ``h`` heads of
    ``d_model // h`` features each, attended, merged back, and projected to ``d_embed``.

    ``mask`` (optional) must broadcast to (batch, h, query_len, key_len). It follows the
    convention of the mask builders in this notebook: True / nonzero means the position may be
    attended to, and False / 0 means it is blocked.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``h``.
    """
    def __init__(self, h, d_embed, d_model, dr_rate = 0.1):
        super(multiheadattention, self).__init__()
        # Without this check, the head-splitting `view` in forward can silently succeed with a
        # wrong sequence length when d_model % h != 0.
        if d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads h ({h})")
        self.d_model = d_model
        self.h = h
        # Device placement is left to the caller (e.g. `model.to(device)`), as for every other layer.
        self.q_fc = torch.nn.Linear(d_embed, d_model)
        self.k_fc = torch.nn.Linear(d_embed, d_model)
        self.v_fc = torch.nn.Linear(d_embed, d_model)
        self.out_fc = torch.nn.Linear(d_model, d_embed)
        self.dropout = torch.nn.Dropout(p=dr_rate)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value`` (batch-first, last dim ``d_embed``).

        Returns a tensor with the batch and query length of ``query`` and last dim ``d_embed``.
        """
        n_batch = query.size(0)
        d_k = self.d_model // self.h

        # (batch, seq, d_model) -> (batch, h, seq, d_k)
        query_t = self.q_fc(query).view(n_batch, -1, self.h, d_k).transpose(1, 2)
        key_t = self.k_fc(key).view(n_batch, -1, self.h, d_k).transpose(1, 2)
        value_t = self.v_fc(value).view(n_batch, -1, self.h, d_k).transpose(1, 2)

        score = torch.matmul(query_t, key_t.transpose(-1, -2)) / math.sqrt(d_k)
        if mask is not None:
            # Block the positions where the mask is False/0. Filling where it is True would hide
            # exactly the tokens that are supposed to stay visible.
            score = score.masked_fill(mask == 0, -1e9)
        prob = torch.nn.functional.softmax(score, dim=-1)
        prob = self.dropout(prob)
        out = torch.matmul(prob, value_t)

        # (batch, h, query_len, d_k) -> (batch, query_len, d_model)
        out = out.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.out_fc(out)

In [None]:
class PositionWiseFeedForardLayer(torch.nn.Module):
    """Position-wise feed-forward network: Linear(d_embed, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_embed).

    Applied independently to every position along the last dimension.
    """
    def __init__(self, d_embed, d_ff, dr_rate = 0.1):
        super(PositionWiseFeedForardLayer, self).__init__()
        # No per-layer `.to(device)`: the enclosing model is moved as a whole by its caller.
        self.fc1 = torch.nn.Linear(d_embed, d_ff)
        self.fc2 = torch.nn.Linear(d_ff, d_embed)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=dr_rate)

    def forward(self, x):
        """Map ``x`` (last dim d_embed) to a tensor of the same shape."""
        return self.fc2(self.dropout(self.relu(self.fc1(x))))

In [None]:
class encoder_block(torch.nn.Module):
    """One pre-norm Transformer encoder layer.

    Applies self-attention and then a position-wise feed-forward network. Each is wrapped as
    ``x + dropout(sublayer(LayerNorm(x)))``.
    """
    def __init__(self, h, d_model, d_embed, d_ff, dr_rate = 0.1):
        super(encoder_block, self).__init__()
        # Forward dr_rate so a single argument controls every dropout in the block; otherwise
        # the sublayers fall back to their own default rate.
        self.attention = multiheadattention(d_model = d_model, d_embed=d_embed, h=h, dr_rate=dr_rate)
        self.feed_forward = PositionWiseFeedForardLayer(d_embed=d_embed, d_ff=d_ff, dr_rate=dr_rate)
        self.norm1 = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.norm2 = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.dropout = torch.nn.Dropout(p=dr_rate)

    def forward(self, x, mask):
        """Encode ``x`` (last dim d_embed); ``mask`` is passed to self-attention (True = may attend)."""
        residual = x
        out = self.norm1(residual)
        out = self.attention(out, out, out, mask)
        out = self.dropout(out)
        out = out + residual
        residual = out
        out = self.norm2(residual)
        out = self.feed_forward(out)
        out = self.dropout(out)
        out = out + residual
        return out

In [None]:
class encoder(torch.nn.Module):
    """Stack of ``n_layer`` encoder blocks followed by a final LayerNorm over d_embed."""
    def __init__(self, h, d_model, d_embed, d_ff, n_layer):
        super(encoder, self).__init__()
        blocks = []
        for _ in range(n_layer):
            blocks.append(encoder_block(h, d_model, d_embed, d_ff))
        self.layers = torch.nn.ModuleList(blocks)
        self.norm = torch.nn.LayerNorm(d_embed, eps = norm_eps)

    def forward(self, x, mask):
        """Run ``x`` through every block with the same ``mask``, then normalize."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [None]:
class decoder_block(torch.nn.Module):
    """One pre-norm Transformer decoder layer.

    Applies three sublayers, each wrapped as ``x + dropout(sublayer(LayerNorm(x)))``:
    1. masked self-attention over the target,
    2. cross-attention from the target onto the encoder output,
    3. a position-wise feed-forward network.
    """
    def __init__(self, h, d_model, d_embed, d_ff, dr_rate = 0.1):
        super(decoder_block, self).__init__()
        # Forward dr_rate so a single argument controls every dropout in the block; otherwise
        # the sublayers fall back to their own default rate.
        self.self_attention = multiheadattention(d_model = d_model, d_embed=d_embed, h=h, dr_rate=dr_rate)
        self.cross_attention = multiheadattention(d_model = d_model, d_embed=d_embed, h=h, dr_rate=dr_rate)
        self.feed_forward = PositionWiseFeedForardLayer(d_embed=d_embed, d_ff=d_ff, dr_rate=dr_rate)
        self.norm1 = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.norm2 = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.norm3 = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.dropout = torch.nn.Dropout(p=dr_rate)

    def forward(self, z, o, tgt_mask, src_tgt_mask):
        """Decode one layer.

        Args:
            z: target-side hidden states (last dim d_embed).
            o: encoder output, used as keys/values of the cross-attention.
            tgt_mask: self-attention mask over the target (True = may attend).
            src_tgt_mask: cross-attention mask, target queries vs. source keys.
        """
        residual = z
        out = self.norm1(residual)
        out = self.self_attention(out, out, out, tgt_mask)
        out = self.dropout(out)
        out = out + residual
        residual = out
        out = self.norm2(residual)
        out = self.cross_attention(out, o, o, src_tgt_mask)
        out = self.dropout(out)
        out = out + residual
        residual = out
        out = self.norm3(residual)
        out = self.feed_forward(out)
        out = self.dropout(out)
        out = out + residual
        return out

In [None]:
class decoder(torch.nn.Module):
    """Stack of ``n_layer`` decoder blocks followed by a final LayerNorm over d_embed."""
    def __init__(self, h, d_model, d_embed, d_ff, n_layer):
        super(decoder, self).__init__()
        self.norm = torch.nn.LayerNorm(d_embed, eps = norm_eps)
        self.layers = torch.nn.ModuleList(
            decoder_block(h, d_model, d_embed, d_ff) for _ in range(n_layer)
        )

    def forward(self, z, o, tgt_mask, src_tgt_mask):
        """Run the target states ``z`` through every block against encoder output ``o``, then normalize."""
        hidden = z
        for block in self.layers:
            hidden = block(hidden, o, tgt_mask, src_tgt_mask)
        return self.norm(hidden)

In [None]:
class TokenEmbedding(torch.nn.Module):
    """Token-id lookup table whose outputs are scaled by sqrt(d_embed)."""
    def __init__(self, d_embed, vocab_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, d_embed)
        self.d_embed = d_embed

    def forward(self, x):
        """Embed integer ids ``x``; the output gains a trailing dimension of size d_embed."""
        scale = math.sqrt(self.d_embed)
        vectors = self.embedding(x)
        return vectors * scale

In [None]:
class PositionalEncoding(torch.nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first input of shape (batch, seq, d_embed)."""
    def __init__(self, d_embed, max_len = 256):
        super(PositionalEncoding, self).__init__()
        self.max_len = max_len
        position = torch.arange(0, max_len).float().unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed))
        encoding = torch.zeros(max_len, d_embed)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_embed there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_embed // 2])
        # A buffer (not a parameter) so `model.to(...)` moves it along with the weights. It is
        # non-persistent because it is fully determined by d_embed and max_len.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.encoding.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.encoding.size(1)}")
        pos_embed = self.encoding[:, :seq_len, :]
        return x + pos_embed

In [None]:
class TransformerEmbedding(torch.nn.Module):
    """Token embedding followed by additive sinusoidal positional encoding."""
    def __init__(self, d_embed, vocab_size, max_len):
        super(TransformerEmbedding, self).__init__()
        self.embed = TokenEmbedding(d_embed=d_embed, vocab_size=vocab_size)
        self.positional = PositionalEncoding(d_embed=d_embed, max_len=max_len)

    def forward(self, x):
        """Embed integer ids ``x`` and add their position encodings."""
        token_vectors = self.embed(x)
        return self.positional(token_vectors)

In [None]:
class Transformer(torch.nn.Module):
    """Encoder-decoder Transformer for translation.

    ``forward(x, z)`` takes source ids ``x`` and decoder-input target ids ``z``, both (batch, seq)
    integer tensors. It returns log-probabilities over the target vocabulary with shape
    (batch, tgt_seq, tgt_vocab_size).

    All masks built here use True = may attend and False = blocked.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, max_len, d_embed, n_layer, d_model, h, d_ff, pad_idx = 3):
        super(Transformer, self).__init__()
        # Token id treated as padding when building masks; must equal the vocabulary's <pad> id.
        self.pad_idx = pad_idx
        self.src_embed = TransformerEmbedding(d_embed=d_embed, vocab_size=src_vocab_size, max_len=max_len)
        self.tgt_embed = TransformerEmbedding(d_embed=d_embed, vocab_size=tgt_vocab_size, max_len=max_len)
        # The decoder emits d_embed features (attention projects back to d_embed and the stack
        # ends in LayerNorm(d_embed)), so the vocabulary projection must consume d_embed.
        self.generator = torch.nn.Linear(d_embed, tgt_vocab_size)
        self.encoder = encoder(h, d_model, d_embed, d_ff, n_layer)
        self.decoder = decoder(h, d_model, d_embed, d_ff, n_layer)

    def forward(self, x, z):
        src_mask = self.make_src_mask(x)
        tgt_mask = self.make_tgt_mask(z)
        src_tgt_mask = self.make_src_tgt_mask(x, z)
        c = self.encoder(self.src_embed(x), src_mask)
        y = self.decoder(self.tgt_embed(z), c, tgt_mask, src_tgt_mask)
        y = torch.nn.functional.log_softmax(self.generator(y), dim=-1)
        return y

    def make_pad_mask(self, query, key, pad_idx = 3):
        """Bool mask (batch, 1, query_len, key_len) that is True where neither the query token
        nor the key token equals ``pad_idx``."""
        query_seq_len, key_seq_len = query.size(1), key.size(1)
        key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, query_seq_len, 1)
        query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, key_seq_len)
        return key_mask & query_mask

    def make_subsequent_mask(self, query, key):
        """Causal bool mask (query_len, key_len), True on and below the diagonal, so position i
        may attend to positions 0..i. It is built on ``query``'s device."""
        query_seq_len, key_seq_len = query.size(1), key.size(1)
        return torch.tril(torch.ones(query_seq_len, key_seq_len, device=query.device)).bool()

    def make_src_mask(self, src):
        return self.make_pad_mask(src, src, self.pad_idx)

    def make_tgt_mask(self, tgt):
        # Combine padding with causality so a target position sees neither pads nor the future.
        pad_mask = self.make_pad_mask(tgt, tgt, self.pad_idx)
        seq_mask = self.make_subsequent_mask(tgt, tgt)
        return pad_mask & seq_mask

    def make_src_tgt_mask(self, src, tgt):
        # Target tokens are the queries and source tokens are the keys.
        return self.make_pad_mask(tgt, src, self.pad_idx)

In [None]:
"""
def build_model(src_vocab_size, tgt_vocab_size, device="cuda", max_len=256, d_embed=512, n_layer=6, d_model=512, h=8, d_ff=2048):
    
    #attention = multiheadattention(d_model = d_model, h=h, q_fc=torch.nn.Linear(d_embed, d_model).to(device), k_fc=torch.nn.Linear(d_embed, d_model).to(device), v_fc=torch.nn.Linear(d_embed, d_model).to(device), out_fc=torch.nn.Linear(d_model, d_embed).to(device))
    #position_ff = PositionWiseFeedForardLayer(fc1=torch.nn.Linear(d_embed, d_ff).to(device), fc2=torch.nn.Linear(d_ff, d_embed).to(device))
    #encoder_blk = encoder_block(attention=copy.deepcopy(attention), feed_forward=copy.deepcopy(position_ff))
    #decoder_blk = decoder_block(self_attention=copy.deepcopy(attention), cross_attention=copy.deepcopy(attention), feed_forward=copy.deepcopy(position_ff))
    #ecd = encoder(encoder_block=encoder_blk, n_layer=n_layer)
    #dcd = decoder(decoder_block=decoder_blk, n_layer=n_layer)
    
    model = Transformer(src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, max_len=256, d_embed=512, n_layer=6, d_model=512, h=8, d_ff=2048, generator=generator).to(device)
    model.device = device

    return model
"""

'\ndef build_model(src_vocab_size, tgt_vocab_size, device="cuda", max_len=256, d_embed=512, n_layer=6, d_model=512, h=8, d_ff=2048):\n    \n    #attention = multiheadattention(d_model = d_model, h=h, q_fc=torch.nn.Linear(d_embed, d_model).to(device), k_fc=torch.nn.Linear(d_embed, d_model).to(device), v_fc=torch.nn.Linear(d_embed, d_model).to(device), out_fc=torch.nn.Linear(d_model, d_embed).to(device))\n    #position_ff = PositionWiseFeedForardLayer(fc1=torch.nn.Linear(d_embed, d_ff).to(device), fc2=torch.nn.Linear(d_ff, d_embed).to(device))\n    #encoder_blk = encoder_block(attention=copy.deepcopy(attention), feed_forward=copy.deepcopy(position_ff))\n    #decoder_blk = decoder_block(self_attention=copy.deepcopy(attention), cross_attention=copy.deepcopy(attention), feed_forward=copy.deepcopy(position_ff))\n    #ecd = encoder(encoder_block=encoder_blk, n_layer=n_layer)\n    #dcd = decoder(decoder_block=decoder_blk, n_layer=n_layer)\n    \n    model = Transformer(src_vocab_size=src_vocab

In [None]:
import torchdata
import random
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence

multi_train, multi_valid, multi_test = Multi30k(language_pair=('en', 'de'))

# Sanity check: print the first five (English, German) training pairs.
for idx, (src_text, tgt_text) in enumerate(multi_train):
    if idx == 5:
        break
    print("index", idx)
    print(src_text)
    print(tgt_text)

index 0
Two young, White males are outside near many bushes.
Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
index 1
Several men in hard hats are operating a giant pulley system.
Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.
index 2
A little girl climbing into a wooden playhouse.
Ein kleines Mädchen klettert in ein Spielhaus aus Holz.
index 3
A man in a blue shirt is standing on a ladder cleaning a window.
Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.
index 4
Two men are at the stove preparing food.
Zwei Männer stehen am Herd und bereiten Essen zu.


In [None]:
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

2023-02-03 02:59:26.008350: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting en-core-web-sm==3.4.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m90.1 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
2023-02-03 02:59:39.860514: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from differ

In [None]:
en_tokenizer = get_tokenizer(tokenizer='spacy', language='en_core_web_sm')
de_tokenizer = get_tokenizer(tokenizer='spacy', language='de_core_news_sm')

# Vocabularies come from the training split only; tokens seen fewer than twice are left out.
# The specials are listed in the order whose ids (0..3) the `language` class hard-codes
# (assumes torchtext's default special_first=True).
SPECIAL_TOKENS = ["<unk>", "<sos>", "<eos>", "<pad>"]
en_vocab = build_vocab_from_iterator(
    map(en_tokenizer, [src_text for src_text, _ in multi_train]),
    min_freq=2,
    specials=list(SPECIAL_TOKENS),
)
de_vocab = build_vocab_from_iterator(
    map(de_tokenizer, [tgt_text for _, tgt_text in multi_train]),
    min_freq=2,
    specials=list(SPECIAL_TOKENS),
)

en_token2id, de_token2id = en_vocab.get_stoi(), de_vocab.get_stoi()
en_id2token, de_id2token = en_vocab.get_itos(), de_vocab.get_itos()

en_vocab_size = len(en_token2id)
de_vocab_size = len(de_token2id)

print("English vocab size :", en_vocab_size)
print("Deutsch vocab size :", de_vocab_size)

English vocab size : 6191
Deutsch vocab size : 8014


In [None]:
class language:
    """Encode and decode sentences for one (source, target) language pair.

    The special-token ids are class constants. They must match the positions of "<unk>",
    "<sos>", "<eos>" and "<pad>" in the vocabularies passed in.
    """
    unk_token_id = 0
    sos_token_id = 1
    eos_token_id = 2
    pad_token_id = 3

    def __init__(self, src_tokenizer, tgt_tokenizer, src_token2id, tgt_token2id, src_id2token, tgt_id2token):
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        self.src_token2id = src_token2id
        self.tgt_token2id = tgt_token2id

        self.src_id2token = src_id2token
        self.tgt_id2token = tgt_id2token

    def src_encode(self, src_text):
        """Tokenize ``src_text`` and map tokens to ids (unknown -> unk_token_id); adds no <sos>/<eos>."""
        return [self.src_token2id.get(token, self.unk_token_id) for token in self.src_tokenizer(src_text)]

    def tgt_encode(self, tgt_text):
        """Tokenize ``tgt_text`` and map tokens to ids, wrapped as [<sos>, ..., <eos>]."""
        body = [self.tgt_token2id.get(token, self.unk_token_id) for token in self.tgt_tokenizer(tgt_text)]
        return [self.sos_token_id] + body + [self.eos_token_id]

    def src_decode(self, src_token):
        """Join source ids (ints or 0-d tensors) back into space-separated text, skipping <pad>."""
        words = [self.src_id2token[int(t)] for t in src_token if int(t) != self.pad_token_id]
        return " ".join(words)

    def tgt_decode(self, tgt_token):
        """Turn target ids (ints or 0-d tensors) back into space-separated text.

        Drops a leading <sos> if present, stops at the first <eos>, and skips <pad>. This works for
        padded ground-truth rows and for id sequences that have no leading <sos>.
        """
        ids = [int(t) for t in tgt_token]
        if ids and ids[0] == self.sos_token_id:
            ids = ids[1:]
        if self.eos_token_id in ids:
            ids = ids[:ids.index(self.eos_token_id)]
        return " ".join(self.tgt_id2token[i] for i in ids if i != self.pad_token_id)

In [None]:
class MultiDataset(torch.utils.data.Dataset):
    """Map-style dataset of pre-encoded (source ids, target ids) sentence pairs.

    Every pair is encoded once at construction; pairs with an empty side are dropped.
    """
    def __init__(self, data, language):
        self.data = data
        self.language = language
        self.sentences = self.preprocess()

    def preprocess(self):
        """Encode each (source text, target text) pair of ``self.data`` via ``self.language``."""
        encoded = []
        for src_text, tgt_text in self.data:
            if len(src_text) == 0 or len(tgt_text) == 0:
                continue
            encoded.append((self.language.src_encode(src_text), self.language.tgt_encode(tgt_text)))
        return encoded

    def __getitem__(self, idx):
        return self.sentences[idx]

    def __len__(self):
        return len(self.sentences)

In [None]:
# en -> de preprocessing: English is the source side, German the target side.
language_preprocess = language(
    src_tokenizer=en_tokenizer, tgt_tokenizer=de_tokenizer,
    src_token2id=en_token2id, tgt_token2id=de_token2id,
    src_id2token=en_id2token, tgt_id2token=de_id2token,
)

In [None]:
# Encode the train and validation splits once up front (the test split is not encoded).
multi_train_dataset = MultiDataset(data=multi_train, language=language_preprocess)
multi_val_dataset = MultiDataset(data=multi_valid, language=language_preprocess)
#multi_test_dataset = MultiDataset(multi_test, language_preprocess)

In [None]:
for split_dataset in (multi_train_dataset, multi_val_dataset):
    print(len(split_dataset))
#print(len(multi_test_dataset))

29000
1014


In [None]:
def collate_fn(batch_samples):
    """Pad a list of (src ids, tgt ids) pairs into two batch-first tensors.

    Each side is right-padded with the <pad> id to the longest sequence of that side in the batch.
    """
    pad_id = language_preprocess.pad_token_id
    src_tensors, tgt_tensors = [], []
    for src, tgt in batch_samples:
        src_tensors.append(torch.tensor(src))
        tgt_tensors.append(torch.tensor(tgt))
    src_sentences = pad_sequence(src_tensors, batch_first=True, padding_value=pad_id)
    tgt_sentences = pad_sequence(tgt_tensors, batch_first=True, padding_value=pad_id)
    return src_sentences, tgt_sentences

In [None]:
def batch_sampling(sequence_lengths, batch_size):
    """Group sample indices into batches of similar source length, in random batch order.

    Args:
        sequence_lengths: iterable of (source_len, target_len) pairs, one per sample.
        batch_size: maximum number of indices per batch (the last batch may be smaller).

    Returns:
        list[list[int]]: batches of sample indices, shuffled with Python's ``random``.
    """
    source_lengths = [src_len for src_len, _tgt_len in sequence_lengths]
    # Stable sort: samples with equal source length keep their original relative order.
    order = sorted(range(len(source_lengths)), key=source_lengths.__getitem__)
    batches = [order[start:start + batch_size] for start in range(0, len(order), batch_size)]
    random.shuffle(batches)
    return batches

In [None]:
batch_size = 100
seq_lengths = [(len(src), len(tgt)) for src, tgt in multi_train_dataset]
# NOTE(review): batch_sampler is a fixed list, so every epoch reuses the same batches in the
# same order; reshuffling per epoch would need a custom Sampler.
batch_sampler = batch_sampling(seq_lengths, batch_size)
train_loader = torch.utils.data.DataLoader(
    multi_train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)

In [None]:
def train_epoch(model, data_loader, optimizer, criterion):
    """Run one teacher-forced training pass over ``data_loader``.

    Each ``(src, tgt)`` batch is fed as ``model(src, tgt[:, :-1])``. The loss is computed against
    ``tgt[:, 1:]``, gradients are clipped to a global norm of 1.0, and the optimizer steps once.

    Args:
        model: seq2seq module returning per-position scores over the target vocabulary.
        data_loader: iterable of ``(src, tgt)`` integer tensors.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(N, vocab)`` predictions and ``(N,)`` targets.

    Returns:
        float: mean of the per-batch losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    # Send batches to wherever the model's weights live rather than relying on a global.
    model_device = next(model.parameters()).device
    loss_total = 0.0
    n_batches = 0
    for src, tgt in data_loader:
        src = src.to(model_device)
        tgt = tgt.to(model_device)
        tgt_x = tgt[:, :-1]  # decoder input: everything but the last token
        tgt_y = tgt[:, 1:]   # targets: the next token at every position

        optimizer.zero_grad()

        output = model(src, tgt_x)
        y_hat = output.reshape(-1, output.shape[-1])
        y_gt = tgt_y.reshape(-1)
        loss = criterion(y_hat, y_gt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_total += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return loss_total / n_batches

In [None]:
def eval(model, data_loader, criterion, print_first_batch, n_examples = 10):
    """Compute the mean per-batch teacher-forced loss of ``model`` over ``data_loader``, without gradients.

    NOTE: the name shadows Python's built-in ``eval``; it is kept because other cells call it.

    Args:
        model: seq2seq module called as ``model(src, tgt[:, :-1])``.
        data_loader: iterable of ``(src, tgt)`` integer tensors, ``tgt`` starting with <sos>.
        criterion: loss taking ``(N, vocab)`` predictions and ``(N,)`` targets.
        print_first_batch: if True, print decoded source / ground truth / argmax prediction
            for sentences of the first batch.
        n_examples: how many sentences of that batch to print (capped at the batch size).

    Returns:
        float: mean of the per-batch losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    # Send batches to wherever the model's weights live rather than relying on a global.
    model_device = next(model.parameters()).device
    loss_total = 0.0
    n_batches = 0
    with torch.no_grad():
        for src, tgt in data_loader:
            src = src.to(model_device)
            tgt = tgt.to(model_device)
            tgt_x = tgt[:, :-1]  # decoder input: everything but the last token
            tgt_y = tgt[:, 1:]   # targets: the next token at every position

            output = model(src, tgt_x)
            y_hat = output.reshape(-1, output.shape[-1])
            y_gt = tgt_y.reshape(-1)
            loss = criterion(y_hat, y_gt)

            if print_first_batch:
                sos = torch.full((1,), language_preprocess.sos_token_id, dtype=torch.long, device=model_device)
                for i in range(min(n_examples, src.size(0))):
                    source = language_preprocess.src_decode(src[i, :])
                    target = language_preprocess.tgt_decode(tgt[i, :])
                    # Predictions line up with tgt[:, 1:] (no leading <sos>); prepend one so
                    # tgt_decode sees the same layout as the ground-truth row.
                    predicted_ids = torch.cat([sos, torch.argmax(output[i, :], dim=1)])
                    predict = language_preprocess.tgt_decode(predicted_ids)
                    print(i, "en :" ,source)
                    print(i, "gt :" ,target)
                    print(i, "de :", predict)
                print_first_batch = False

            loss_total += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return loss_total / n_batches

In [None]:
def train(model, data_loader, optimizer, scheduler, criterion, epoch, val_loader = None, eval_every = 10):
    """Train for ``epoch`` epochs, stepping ``scheduler`` once after each epoch.

    Every ``eval_every`` epochs (starting with epoch 0), the model is evaluated and a few
    decoded examples are printed. Evaluation uses ``val_loader`` when given. Otherwise it falls
    back to ``data_loader`` (training data), so that loss is not a held-out estimate.
    A falsy ``eval_every`` disables periodic evaluation.
    """
    eval_loader = data_loader if val_loader is None else val_loader
    for i in range(epoch):
        train_loss = train_epoch(model, data_loader, optimizer, criterion)
        print("EPOCH[" + str(i) + "] Train Loss : " + str(train_loss))
        # NOTE(review): stepped once per epoch. A per-batch schedule (e.g. OneCycleLR sized with
        # steps_per_epoch) therefore only advances `epoch` of its steps; confirm intent.
        scheduler.step()
        if eval_every and i % eval_every == 0:
            print("Eval")
            eval_loss = eval(model, eval_loader, criterion, True)
            print("Eval Loss : " + str(eval_loss))

In [None]:
# Hyperparameters.
learning_rate = 0.001
max_lr = 0.1
epoch = 100
steps_per_epoch = int(len(multi_train_dataset) / batch_size)

model = Transformer(
    src_vocab_size=en_vocab_size,
    tgt_vocab_size=de_vocab_size,
    max_len=256,
    d_embed=512,
    n_layer=6,
    d_model=512,
    h=8,
    d_ff=2048,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
# NOTE(review): OneCycleLR drives the optimizer's lr itself (starting from max_lr / div_factor),
# so `learning_rate` above is effectively unused. The schedule is also sized for
# steps_per_epoch * epoch steps, but `train` calls scheduler.step() once per epoch, so only the
# first `epoch` steps of its warm-up ever run. Confirm which schedule is intended.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = max_lr, steps_per_epoch=steps_per_epoch, epochs = epoch)
# The model already returns log-probabilities. log_softmax applied again inside CrossEntropyLoss
# leaves them unchanged, so this is equivalent to NLLLoss on the model output.
criterion = torch.nn.CrossEntropyLoss(ignore_index=language_preprocess.pad_token_id)

train(model, train_loader, optimizer, scheduler, criterion, epoch)

EPOCH[0] Train Loss : 2.3937683561752583
Eval
0 en : A woman in a blue top and hat is stretching or performing yoga on the beach .
0 de : Frau mit Hut in einem Jungen Frauen macht am Strand trinkt in umher . <eos> <eos> <eos> <eos> <eos>
1 en : A young man in odd clothes leaning against a brick wall with a garden in it .
1 de : junger Mann in <unk> Kleidung , an einer Bank , die einen Baumstamm verkauft . <eos> <eos> <eos> <eos>
2 en : Two men are looking toward the ground while one is wearing gloves and holding a tool .
2 de : Männer bereiten zum Kleidung , die weißen ihnen trägt Jacke und hält ein Bild . <eos> <eos> <eos> <eos>
3 en : A young child with a gray shirt and pacifier is standing next to a toy horse .
3 de : kleines Kind mit einem grauen Hemd und <unk> steht neben einem Pool . <eos> <eos> <eos> <eos> <eos> <eos>
4 en : Three blond girls and a dark - haired girl try to sell decorated <unk> and rocks .
4 de : arbeiten Mädchen und ein großes Mädchen schläft , <unk> Armen und 

In [None]:
# Save only the weights (state_dict) instead of pickling the whole module.
# - The file holds plain tensors, and loading it needs no arbitrary unpickling.
# - Newer PyTorch versions default torch.load to weights_only=True, which rejects pickled modules.
# - Loading into the existing `model` keeps `optimizer` bound to the live parameters.
torch.save(model.state_dict(), "model.pth")
model.load_state_dict(torch.load("model.pth", map_location=device))

In [None]:

# Build a length-bucketed loader for the validation split, then report the loss on the
# validation data and on the training data.
seq_lengths = [(len(src), len(tgt)) for src, tgt in multi_val_dataset]
batch_sampler = batch_sampling(seq_lengths, batch_size)
val_loader = torch.utils.data.DataLoader(
    multi_val_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)

for loader in (val_loader, train_loader):
    eval_loss = eval(model, loader, criterion, True)
    print(eval_loss)

0 en : A barefoot boy with a blue and white striped towel is standing on the beach . <pad> <pad>
0 gt : Ein barfüßiger Junge mit einem blau-weiß gestreiften Handtuch steht am Strand . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
0 de : Junge Junge mit einem blau-weiß gestreiften T-Shirt steht am Strand . <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
1 en : A uniformed man in the Army is training a German Shepherd using an arm guard . <pad> <pad>
1 gt : Ein uniformierter Mann von der Armee trainiert einen <unk> Schäferhund mit einem <unk> . <eos> <pad> <pad> <pad> <pad> <pad> <pad>
1 de : Mann Mann , der trainiert trainiert einen <unk> stehen mit einem <unk> . <eos> <eos> <eos> <eos> <eos> <eos> <eos>
2 en : A young woman in a pink shirt attempting to rope a calf at the rodeo . <pad> <pad>
2 gt : Eine junge Frau in einem pinkfarbenen Shirt versucht bei einem Rodeo , ein Kalb einzufangen . <eos> <pad> <pad> <pad> <pad>
2 de : Frau Frau in einem pinkfarbenen versucht vers