Нейронная сеть для перевода описаний изображений с английского языка на русский. В основе её архитектуры лежит модель трансформера. Как и стандартные модели машинного перевода (seq2seq), нейросеть-переводчик состоит из кодировщика и декодера. Кодировщик получает тензор, соответствующий описанию изображения на английском языке, и не сжимает всё исходное предложение в один контекстный вектор, а создаёт последовательность контекстных векторов, каждый из которых учитывает все токены входной последовательности во всех позициях. Затем декодер на основе этих векторов генерирует итоговое предложение на русском языке.

# Импорт библиотек

In [None]:
!pip install -U spacy

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
!python -m spacy download ru_core_news_sm

In [None]:
import math
import random
import time
from collections import Counter

import spacy
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

spacy_eng = spacy.load("en_core_web_sm")
spacy_rus = spacy.load("ru_core_news_sm")

In [None]:
SEED = 1234

# Seed Python's and PyTorch's RNGs, then ask cuDNN for deterministic kernels
# so that repeated runs of the notebook are reproducible.
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Select the first GPU only when CUDA is actually available:
# torch.cuda.set_device raises on CPU-only machines.
if device.type == 'cuda':
    torch.cuda.set_device(0)

# Подготовка данных

## Словари

In [None]:
class VocabularyEN:
    """English (source-side) vocabulary read from pre-built mapping files.

    Unlike ``VocabularyRU`` it is not built from data here: ``load_vocab``
    fills ``itos`` (index -> token) and ``stoi`` (token -> index) from two
    ``|``-separated text files.
    """

    def __init__(self):
        # Both maps start empty and are filled by load_vocab().
        self.itos = dict()  # index -> token
        self.stoi = dict()  # token -> index

    def __len__(self):
        return len(self.itos)

    def load_vocab(self, itos_path='itos_en.txt', stoi_path='stoi_en.txt', encoding='utf-8'):
        """Populate ``itos`` and ``stoi`` from the mapping files.

        Args:
            itos_path: file with one ``index|token`` pair per line.
            stoi_path: file with one ``token|index`` pair per line.
            encoding: text encoding of both files.

        Only the line break is stripped, and the index is split off on its
        own side (first ``|`` for itos, last ``|`` for stoi), so tokens that
        themselves contain ``|`` are parsed correctly. Blank lines are skipped.
        """
        with open(itos_path, encoding=encoding) as itos_file:
            for line in itos_file:
                line = line.rstrip('\r\n')
                if not line:
                    continue
                key, val = line.split('|', maxsplit=1)
                self.itos[int(key)] = val
        with open(stoi_path, encoding=encoding) as stoi_file:
            for line in stoi_file:
                line = line.rstrip('\r\n')
                if not line:
                    continue
                # The index is always the last field; the token may contain '|'.
                key, val = line.rsplit('|', maxsplit=1)
                self.stoi[key] = int(val)

    def tokenize(self, text):
        """Split *text* with the spaCy English tokenizer and lowercase each token."""
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def numericalize(self, text):
        """Map each token of *text* to its vocabulary index, using ``<UNK>`` for unknown tokens."""
        tokenized_text = self.tokenize(text)
        return [self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text]

In [None]:
class VocabularyRU:
    """Russian (target-side) vocabulary built from a corpus by token frequency.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Every token seen at least ``freq_threshold`` times receives the next free
    index, in the order in which tokens reach the threshold.
    """

    def __init__(self, freq_threshold=5):
        # Reserved special tokens: index -> string.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        # Reverse mapping: string -> index.
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Split *text* with the spaCy Russian tokenizer and lowercase each token."""
        return [token.text.lower() for token in spacy_rus.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Add every token that occurs at least ``freq_threshold`` times in *sentence_list*.

        New indices continue after the existing entries, so a second call
        extends the vocabulary instead of overwriting already-assigned ids.
        """
        frequencies = Counter()
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1
                # Register a word the first time its count reaches the threshold.
                if word not in self.stoi and frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map each token of *text* to its vocabulary index, using ``<UNK>`` for unknown tokens."""
        tokenized_text = self.tokenize(text)
        return [self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text]

## Создание датасета и загрузчика данных (DataLoader)

In [None]:
class Dataset(Dataset):
    """Parallel English/Russian caption dataset.

    NOTE: the class keeps its original name, which shadows
    ``torch.utils.data.Dataset`` for the rest of the notebook.

    Line *i* of the English file is paired with line *i* of the Russian file.
    Each item is a pair of 1-D LongTensors (English ids, Russian ids), and each
    sequence is wrapped in ``<SOS>`` ... ``<EOS>``.

    Args:
        is_train: read the ``*_train.txt`` files if True, else ``*_valid.txt``.
        encoding: text encoding of the caption files. Set explicitly so that
            the Russian text does not depend on the platform locale.
    """

    def __init__(self, is_train=True, encoding='utf-8'):
        if is_train:
            path_ru = r'cap_ru_train.txt'
            path_en = r'cap_en_train.txt'
        else:
            path_ru = r'cap_ru_valid.txt'
            path_en = r'cap_en_valid.txt'

        self.captions_en = self._read_lines(path_en, encoding)
        self.captions_ru = self._read_lines(path_ru, encoding)

        # English vocabulary comes from pre-built mapping files.
        self.vocab_en = VocabularyEN()
        self.vocab_en.load_vocab()

        # The Russian vocabulary is always built from the full corpus file, so the
        # train and valid instances build it from the same data.
        self.captions_ru_all = self._read_lines('cap_ru_all.txt', encoding)
        self.vocab_ru = VocabularyRU()
        self.vocab_ru.build_vocab(self.captions_ru_all)

    @staticmethod
    def _read_lines(path, encoding):
        """Return the lines of *path* with newline characters removed."""
        with open(path, "r", encoding=encoding) as file:
            return [line.replace("\n", "") for line in file]

    @staticmethod
    def _encode(vocab, text):
        """Return ``[<SOS>] + numericalized text + [<EOS>]`` using *vocab*."""
        return [vocab.stoi["<SOS>"]] + vocab.numericalize(text) + [vocab.stoi["<EOS>"]]

    def __len__(self):
        return len(self.captions_ru)

    def __getitem__(self, idx):
        caption_vec_en = self._encode(self.vocab_en, self.captions_en[idx])
        caption_vec_ru = self._encode(self.vocab_ru, self.captions_ru[idx])
        return torch.tensor(caption_vec_en), torch.tensor(caption_vec_ru)

In [None]:
class CapsCollate:
    """Collate callable that pads the English and Russian sides of a batch separately.

    Args:
        pad_idx_en: padding value for the English (first) sequences.
        pad_idx_ru: padding value for the Russian (second) sequences.
        batch_first: forwarded to ``pad_sequence``. When True the outputs are
            (batch, max_len), otherwise (max_len, batch).
    """

    def __init__(self, pad_idx_en, pad_idx_ru, batch_first=False):
        self.pad_idx_en = pad_idx_en
        self.pad_idx_ru = pad_idx_ru
        self.batch_first = batch_first

    def __call__(self, batch):
        sources = [pair[0] for pair in batch]
        padded_en = pad_sequence(sources, batch_first=self.batch_first,
                                 padding_value=self.pad_idx_en)

        targets = [pair[1] for pair in batch]
        padded_ru = pad_sequence(targets, batch_first=self.batch_first,
                                 padding_value=self.pad_idx_ru)

        return padded_en, padded_ru

In [None]:
def get_data_loader(dataset, batch_size, shuffle=False, num_workers=1):
    """Wrap *dataset* in a DataLoader whose batches are padded batch-first.

    The padding ids are read from ``dataset.vocab_en`` and ``dataset.vocab_ru``.
    """
    collate_fn = CapsCollate(
        pad_idx_en=dataset.vocab_en.stoi["<PAD>"],
        pad_idx_ru=dataset.vocab_ru.stoi["<PAD>"],
        batch_first=True,
    )
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, collate_fn=collate_fn)

In [None]:
# Training and validation splits (each instance loads/builds its own vocabularies).
dataset_train, dataset_valid = Dataset(is_train=True), Dataset(is_train=False)

In [None]:
# DataLoader settings.
BATCH_SIZE = 256
NUM_WORKERS = 2

data_loader_train = get_data_loader(dataset_train, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Validation data is not shuffled. Every epoch is then scored on the same batches
# in the same order, so validation losses are comparable across epochs.
data_loader_valid = get_data_loader(dataset_valid, BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

## Проверка работы загрузчика данных

In [None]:
# Decode one training batch back to text to sanity-check the data pipeline.
dataiter = iter(data_loader_train)
batch = next(dataiter)
caption_en, caption_ru = batch

# Iterate over the rows actually present; a batch can be smaller than BATCH_SIZE.
for cap_en, cap_ru in zip(caption_en, caption_ru):
    caption_label_en = [dataset_train.vocab_en.itos[token] for token in cap_en.tolist()]
    eos_index_en = caption_label_en.index('<EOS>')
    # Drop <SOS> and everything from <EOS> on (including padding).
    caption_label_en = ' '.join(caption_label_en[1:eos_index_en])

    caption_label_ru = [dataset_train.vocab_ru.itos[token] for token in cap_ru.tolist()]
    eos_index_ru = caption_label_ru.index('<EOS>')
    caption_label_ru = ' '.join(caption_label_ru[1:eos_index_ru])

    print(caption_label_en, caption_label_ru)

# Создание модели

## Слой многоголового внимания (Multi-Head Attention)

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention split across ``n_heads`` heads.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: kept for backward compatibility. The scaling factor is now a
            plain Python float, so nothing here needs to be placed on a device.

    Raises:
        ValueError: if ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        if hid_dim % n_heads != 0:
            raise ValueError(f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})")

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        # sqrt(head_dim) as a float. Unlike a tensor stored as a plain attribute,
        # it stays valid after the module is moved with .to()/.cuda()/.cpu().
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value, mask = None):
        """Attend from *query* over *key*/*value*.

        Args:
            query: expected (batch, q_len, hid_dim).
            key, value: expected (batch, k_len, hid_dim).
            mask: optional, broadcastable to (batch, n_heads, q_len, k_len).
                Positions where ``mask == 0`` receive (effectively) zero weight.

        Returns:
            Tensor of shape (batch, q_len, hid_dim).
        """
        batch_size = query.shape[0]

        # Project, then split the last dim into heads: (batch, n_heads, len, head_dim).
        Q = self.fc_q(query).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.fc_k(key).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.fc_v(value).view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # (batch, n_heads, q_len, k_len)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # A large finite negative (not -inf) keeps a fully masked row's softmax finite.
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim = -1)

        x = torch.matmul(self.dropout(attention), V)
        # Merge heads back: (batch, q_len, hid_dim).
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.hid_dim)

        return self.fc_o(x)

## Слой Position-wise Feedforward

In [None]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer MLP applied independently at every position: fc_2(dropout(relu(fc_1(x)))).

    Args:
        hid_dim: input/output width.
        pf_dim: hidden width of the inner layer.
        dropout: dropout probability after the ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        # Attribute names are kept: they are the state_dict keys of saved checkpoints.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = torch.relu(self.fc_1(x))
        hidden = self.dropout(hidden)
        return self.fc_2(hidden)

## Слой кодировщика

In [None]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer with post-LayerNorm residual sub-layers.

    Sub-layer 1 is self-attention; sub-layer 2 is a position-wise feed-forward
    network. Each output goes through dropout, is added to its input, and is
    then normalised.

    Args:
        hid_dim: model width.
        n_heads: number of attention heads.
        pf_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
        device: forwarded to ``MultiHeadAttentionLayer``.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()

        # Submodule names and creation order are kept: they determine the
        # state_dict keys of saved checkpoints.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        attended = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(attended))

        transformed = self.positionwise_feedforward(src)
        return self.ff_layer_norm(src + self.dropout(transformed))

## Кодировщик

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: token + learned positional embeddings, then a stack of ``EncoderLayer``s.

    Args:
        input_dim: source vocabulary size.
        hid_dim: model width.
        n_layers: number of encoder layers.
        n_heads: attention heads per layer.
        pf_dim: hidden width of the feed-forward sub-layers.
        dropout: dropout probability.
        device: device on which ``forward`` places the input ids and positions.
        max_length: number of learned positions, i.e. the longest supported source.
    """

    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        self.max_length = max_length

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)
        # sqrt(hid_dim) as a plain float, so it is valid on whichever device the module lives.
        self.scale = math.sqrt(hid_dim)

    def forward(self, src, src_mask):
        """Encode a batch of source token ids.

        Args:
            src: token ids indexed as (batch, src_len).
            src_mask: mask passed unchanged to every encoder layer.

        Returns:
            Tensor of shape (batch, src_len, hid_dim).

        Raises:
            ValueError: if ``src_len`` exceeds ``max_length``. The positional
                embedding lookup would otherwise fail with an opaque index error.
        """
        # Use this module's own device rather than any notebook-level global.
        src = src.to(self.device)
        batch_size, src_len = src.shape[0], src.shape[1]
        if src_len > self.max_length:
            raise ValueError(f"source length {src_len} exceeds max_length={self.max_length}")

        pos = torch.arange(0, src_len, device=self.device).unsqueeze(0).repeat(batch_size, 1)
        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))

        for layer in self.layers:
            src = layer(src, src_mask)

        return src

## Слой декодера

In [None]:
class DecoderLayer(nn.Module):
    """One Transformer decoder layer with post-LayerNorm residual sub-layers.

    The sub-layers, in order, are:
    1. masked self-attention over the target;
    2. attention over the encoder output;
    3. a position-wise feed-forward network.

    Each output goes through dropout, is added to its input, and is then normalised.

    Args:
        hid_dim: model width.
        n_heads: number of attention heads.
        pf_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
        device: forwarded to ``MultiHeadAttentionLayer``.
    """

    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        # Submodule names/order define the state_dict keys of saved checkpoints.
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
                                                                     pf_dim,
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Run one decoder layer.

        Args:
            trg: embedded target sequence, presumably (batch, trg_len, hid_dim).
            enc_src: encoder output, presumably (batch, src_len, hid_dim).
            trg_mask: mask for target self-attention (padding and future positions).
            src_mask: mask hiding padded source positions.
        """
        _trg = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
        _trg = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        _trg = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(_trg))

        return trg

## Декодер

In [None]:
class Decoder(nn.Module):
    """Transformer decoder: embeddings, a stack of ``DecoderLayer``s, and a projection to target-vocabulary logits.

    Args:
        output_dim: target vocabulary size (number of output logits).
        hid_dim: model width.
        n_layers: number of decoder layers.
        n_heads: attention heads per layer.
        pf_dim: hidden width of the feed-forward sub-layers.
        dropout: dropout probability.
        device: device on which ``forward`` places the input ids and positions.
        max_length: number of learned positions, i.e. the longest supported target.
    """

    def __init__(self,
                 output_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device
        self.max_length = max_length

        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([DecoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        # sqrt(hid_dim) as a plain float, so it is valid on whichever device the module lives.
        self.scale = math.sqrt(hid_dim)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target ids given the encoder output.

        Args:
            trg: target token ids indexed as (batch, trg_len).
            enc_src: encoder output.
            trg_mask: mask for target self-attention.
            src_mask: mask hiding padded source positions.

        Returns:
            Logits of shape (batch, trg_len, output_dim).

        Raises:
            ValueError: if ``trg_len`` exceeds ``max_length``.
        """
        trg = trg.to(self.device)
        batch_size, trg_len = trg.shape[0], trg.shape[1]
        if trg_len > self.max_length:
            raise ValueError(f"target length {trg_len} exceeds max_length={self.max_length}")

        pos = torch.arange(0, trg_len, device=self.device).unsqueeze(0).repeat(batch_size, 1)
        trg = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(pos))

        for layer in self.layers:
            trg = layer(trg, enc_src, trg_mask, src_mask)

        return self.fc_out(trg)

## Модель Seq2Seq (кодировщик + декодер)

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds the attention masks and chains both halves.

    Args:
        encoder: module called as ``encoder(src, src_mask)``.
        decoder: module called as ``decoder(trg, enc_src, trg_mask, src_mask)``.
        src_pad_idx: padding id of the source vocabulary.
        trg_pad_idx: padding id of the target vocabulary.
        device: device on which the masks are created.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        super().__init__()

        # `encoder` / `decoder` attribute names are state_dict key prefixes.
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a (batch, 1, 1, src_len) bool mask that is False at source padding."""
        not_pad = src != self.src_pad_idx
        return not_pad.unsqueeze(1).unsqueeze(2).to(self.device)

    def make_trg_mask(self, trg):
        """Return a (batch, 1, trg_len, trg_len) mask hiding target padding and future positions."""
        trg_len = trg.shape[1]
        pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2).to(self.device)
        causal_mask = torch.ones((trg_len, trg_len), device=self.device).tril().bool()
        return pad_mask & causal_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        encoded = self.encoder(src, src_mask)
        return self.decoder(trg, encoded, trg_mask, src_mask)

# Создание экземпляра модели

In [None]:
# Vocabulary sizes: English ids feed the encoder, and the decoder predicts Russian ids.
INPUT_DIM = len(dataset_train.vocab_en)
OUTPUT_DIM = len(dataset_train.vocab_ru)

# Transformer hyper-parameters (same for encoder and decoder).
HID_DIM = 256
ENC_LAYERS = DEC_LAYERS = 3
ENC_HEADS = DEC_HEADS = 4
ENC_PF_DIM = DEC_PF_DIM = 256
ENC_DROPOUT = DEC_DROPOUT = 0.2

# Special-token ids: PAD_IDX is taken from the English (source) vocabulary,
# while SOS_IDX / EOS_IDX come from the Russian (target) vocabulary.
PAD_IDX = dataset_train.vocab_en.stoi['<PAD>']
SOS_IDX = dataset_train.vocab_ru.stoi['<SOS>']
EOS_IDX = dataset_train.vocab_ru.stoi['<EOS>']

In [None]:
# Padding ids of each side, used by Seq2Seq to build its attention masks.
SRC_PAD_IDX = dataset_train.vocab_en.stoi['<PAD>']
TRG_PAD_IDX = dataset_train.vocab_ru.stoi['<PAD>']

enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device)
model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

In [None]:
def count_parameters(model):
    """Return the total number of elements in *model*'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable = count_parameters(model)
print(f'The model has {trainable:,} trainable parameters')

In [None]:
def init_weights(m):
    """Initialise the parameters owned directly by module *m*; intended for ``model.apply(init_weights)``.

    ``nn.Module.apply`` already visits every submodule, so only *m*'s own
    parameters are touched here. The original approach iterated
    ``named_parameters()`` recursively, which re-initialised each parameter
    once per ancestor.

    * ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, 0.05); Linear biases = 0.
    * Every other module, notably ``nn.LayerNorm``, keeps PyTorch's default
      (gain 1, bias 0). Drawing LayerNorm gains from N(0, 0.05) would scale
      every normalised activation to near zero and randomly flip its sign.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        nn.init.normal_(m.weight, mean=0, std=0.05)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)
        
# Re-initialise the freshly built model: nn.Module.apply calls init_weights on
# every submodule and on the model itself.
model.apply(init_weights)

In [None]:
LEARNING_RATE = 0.001

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
# The loss is computed on Russian (target) token ids, so padding must be ignored
# using the *target* vocabulary's pad id, not the English one.
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

# Обучение и оценка

In [None]:
def train(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: Seq2Seq model called as ``model(src, trg_input)``.
        iterator: iterable of ``(src, trg)`` id batches, indexed (batch, len).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss over flattened logits vs. flattened target ids.
        clip: maximum total L2 norm of the gradients (``clip_grad_norm_``).

    Returns:
        Average loss over the batches processed (0.0 if *iterator* is empty).
    """
    model.train()
    device = next(model.parameters()).device

    epoch_loss = 0.0
    n_batches = 0

    for i, (src, trg) in enumerate(iterator):
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()

        # Teacher forcing: the decoder input is trg without its last token, and
        # the prediction target is trg shifted left (without <SOS>).
        output = model(src, trg[:, :-1])
        output_dim = output.shape[-1]
        output = output.reshape(-1, output_dim)
        target = trg[:, 1:].reshape(-1)

        if i == 0:
            # Quick visual check: first 25 predicted vs. reference tokens of the epoch.
            print(" ".join([dataset_train.vocab_ru.itos[token] for token in output.argmax(1)[:25].tolist()]))
            print(" ".join([dataset_train.vocab_ru.itos[token] for token in target[:25].tolist()]))

        loss = criterion(output, target)
        loss.backward()
        # Gradient clipping: `clip` was previously accepted but never applied.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    return epoch_loss / n_batches if n_batches else 0.0

In [None]:
def evaluate(model, iterator, criterion):
    """Compute the mean teacher-forced loss of *model* over *iterator* without updating it.

    Every batch is scored, including a final partial one. The result is the
    average over the batches actually scored.

    Args:
        model: Seq2Seq model called as ``model(src, trg_input)``.
        iterator: iterable of ``(src, trg)`` id batches, indexed (batch, len).
        criterion: loss over flattened logits vs. flattened target ids.

    Returns:
        Average batch loss (0.0 if *iterator* is empty).
    """
    model.eval()
    device = next(model.parameters()).device

    epoch_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for src, trg in iterator:
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg[:, :-1])
            output = output.reshape(-1, output.shape[-1])
            target = trg[:, 1:].reshape(-1)

            epoch_loss += criterion(output, target).item()
            n_batches += 1

    return epoch_loss / n_batches if n_batches else 0.0

In [None]:
def epoch_time(start_time, end_time):
    """Split the elapsed time between two ``time.time()`` stamps into whole (minutes, seconds)."""
    duration = end_time - start_time
    minutes = int(duration / 60)
    seconds = int(duration - minutes * 60)
    return minutes, seconds

In [None]:
N_EPOCHS = 30
CLIP = 1
date = '29-05'

best_valid_loss = float('inf')
k = 0  # suffix of the next checkpoint file; incremented after every save

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_iterator = iter(data_loader_train)
    valid_iterator = iter(data_loader_valid)

    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Checkpoint only when validation loss improves; every save gets a new file name.
    saved_best_model = valid_loss < best_valid_loss
    if saved_best_model:
        best_valid_loss = valid_loss
        name = f'translater_{date}-{k}.pt'
        k += 1
        torch.save(model.state_dict(), name)

    print(f'Epoch: {epoch + 1:02} Time: {epoch_mins}m {epoch_secs}s '
          f'| Train: Loss = {train_loss:.7f}  | Val: Loss = {valid_loss:.7f}  '
          f'| SAVE = {saved_best_model}')

In [None]:
# The training loop saves a checkpoint only when validation loss improves and
# increments `k` after each save, so suffix k - 1 holds the best model.
# (The previous version overwrote checkpoint 0 with the final model.)
if k == 0:
    raise RuntimeError('No checkpoint was saved during training.')
name_best = 'translater_' + date + '-' + str(k - 1) + '.pt'
# Build a fresh encoder/decoder: reusing `enc`/`dec` would make model_loaded
# share (and overwrite) the parameters of `model`.
model_loaded = Seq2Seq(
    Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device),
    Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device),
    SRC_PAD_IDX, TRG_PAD_IDX, device,
).to(device)
model_loaded.load_state_dict(torch.load(name_best, map_location=device))

In [None]:
# Load a previously trained checkpoint. map_location lets a checkpoint saved on
# a GPU also be loaded on a CPU-only machine.
model.load_state_dict(torch.load(r'translater_29-05-1.pt', map_location=device))

# Проверка

In [None]:
def translate_sentence(sentence, model, device, max_len = 25):
    """Greedily translate an English *sentence* into a list of Russian tokens.

    The sentence is encoded once. The decoder is then re-run on the growing
    prefix, and at each step the arg-max token is appended, until ``<EOS>`` is
    produced or *max_len* tokens have been generated.

    Returns:
        The generated tokens without the leading ``<SOS>``. The list ends with
        ``'<EOS>'`` only if the model emitted it within *max_len* steps.
    """
    model.eval()

    vocab_en = dataset_train.vocab_en
    vocab_ru = dataset_train.vocab_ru

    src_ids = [vocab_en.stoi["<SOS>"], *vocab_en.numericalize(sentence), vocab_en.stoi["<EOS>"]]
    src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_ids = [vocab_ru.stoi["<SOS>"]]
    eos_id = vocab_ru.stoi["<EOS>"]

    for _ in range(max_len):
        trg_tensor = torch.LongTensor(trg_ids).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        with torch.no_grad():
            logits = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)

        next_id = logits.argmax(2)[:, -1].item()
        trg_ids.append(next_id)
        if next_id == eos_id:
            break

    return [vocab_ru.itos[idx] for idx in trg_ids[1:]]

In [None]:
src = ['A computer desk with a laptop and a keyboard.',
       "A woman is cooking a dinner",
       "A refrigerator with a bunch of food on it",
       "A group of people sitting on a train platform",
       "A cookie sitting on top of a white plate.",
       "A vintage stove and washing tub on a brick floor.",
       "A man on a motorcycle is going down the street.",
       "A man riding a motorcycle next to a lush green forest.",
       "A person is riding a motorcycle down a country road.",
       "A motorcyclist travels down a country two-lane highway.",
       "a person riding a motorcycle on a road with trees",
       "A smiling woman holding a baby has a camera in her hand.",
       "a woman takes a photo of herself and her child."]

for sentence in src:
    caption_label_ru = translate_sentence(sentence, model, device)
    # translate_sentence ends with '<EOS>' only when the model produced it within
    # max_len steps. Strip it only in that case so no real word is dropped.
    if caption_label_ru and caption_label_ru[-1] == '<EOS>':
        caption_label_ru = caption_label_ru[:-1]
    print(sentence, ' '.join(caption_label_ru))

In [None]:
# Compare reference and model translations on one validation batch.
# The training vocabularies are used for decoding. vocab_en is loaded from the
# same files by every Dataset, and vocab_ru is built from the same cap_ru_all.txt.
dataiter = iter(data_loader_valid)
batch = next(dataiter)
caption_en, caption_ru = batch

# Iterate over the rows actually present; a batch can be smaller than BATCH_SIZE.
for cap_en, cap_ru in zip(caption_en, caption_ru):
    caption_label_en = [dataset_train.vocab_en.itos[token] for token in cap_en.tolist()]
    eos_index_en = caption_label_en.index('<EOS>')
    caption_label_en = ' '.join(caption_label_en[1:eos_index_en])
    caption_label_ru_tr = translate_sentence(caption_label_en, model, device)

    caption_label_ru = [dataset_train.vocab_ru.itos[token] for token in cap_ru.tolist()]
    eos_index_ru = caption_label_ru.index('<EOS>')
    caption_label_ru = ' '.join(caption_label_ru[1:eos_index_ru])

    # Strip '<EOS>' only when the translation actually ended with it.
    if caption_label_ru_tr and caption_label_ru_tr[-1] == '<EOS>':
        caption_label_ru_tr = caption_label_ru_tr[:-1]
    model_ru = ' '.join(caption_label_ru_tr)
    print("CAPTION :", caption_label_en, "\nGT :", caption_label_ru, "\nTRANSLATION :", model_ru, "\n----------------------------------------")

# BLEU

In [None]:
from torchtext.data.metrics import bleu_score
import nltk

def calculate_bleu(data, model, device, max_len = 25):
    """Compute corpus-level BLEU of the model's translations on *data*.

    Args:
        data: dataset with parallel ``captions_en`` / ``captions_ru`` lists and
            a ``vocab_ru`` providing ``tokenize``.
        model: trained Seq2Seq model.
        device: device to run translation on.
        max_len: maximum number of generated tokens per sentence.

    Returns:
        ``(torchtext_bleu, nltk_bleu)`` computed on the same tokenized corpus.
    """
    trgs = []
    pred_trgs = []

    for src, trg in zip(data.captions_en, data.captions_ru):
        pred_trg = translate_sentence(src, model, device, max_len)
        # Drop the trailing '<EOS>' only if generation actually produced it.
        if pred_trg and pred_trg[-1] == '<EOS>':
            pred_trg = pred_trg[:-1]
        pred_trgs.append(pred_trg)
        # Both BLEU implementations expect each reference as a list of tokens.
        # A raw string would be scored character by character. Tokenize with
        # the model's own Russian tokenizer so the n-grams are comparable.
        trgs.append([data.vocab_ru.tokenize(trg)])

    return bleu_score(pred_trgs, trgs), nltk.translate.bleu_score.corpus_bleu(trgs, pred_trgs)

In [None]:
bleu_torchtext, bleu_nltk = calculate_bleu(dataset_valid, model, device, max_len = 25)
print((bleu_torchtext, bleu_nltk))