In [None]:
import math
import os
from collections import Counter

import pandas
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from scripts.custom_dataset import CustomDataset
from scripts.model import TransformerModel, subsequent_mask
from scripts.tokenizer import SeparatorTokenizer
from scripts.vectorizer import Seq2Seq_Vectorizer
from scripts.vocabulary import Vocabulary

In [None]:
# --- Data split -------------------------------------------------------------
TEST_PROPORTION = 0.0
EVAL_PROPORTION = 0.2

# Tokens occurring TOKENS_TRESHOLD_FREQ times or fewer are not added to the vocabularies
TOKENS_TRESHOLD_FREQ = 10

# --- Training loop ----------------------------------------------------------
SHUFFLE = True
DROP_LAST = True
EPOCHS = 10
# NOTE(review): 0.01 is unusually high for Adam on a Transformer (around 1e-4,
# often with warmup, is more common) — confirm this is intentional.
LEARNING_RATE = 0.01

LR_SCHEDULER_FACTOR = 0.5
LR_SCHEDULER_PATIENCE = 2

# --- Model architecture -----------------------------------------------------
BATCH_SIZE = 128
EMBEDDING_DIM = 128
MODEL_DIM = 256
# Multi-head attention splits the model dimension evenly across heads, so MODEL_DIM
# must be divisible by NUM_HEAD (assuming `model.transformer` is an nn.Transformer
# with d_model=MODEL_DIM). The previous value 6 does not divide 256.
NUM_HEAD = 8
NUM_ENCODER_LAYERS = 5
NUM_DECODER_LAYERS = 5
FC_HIDDEN_DIM = MODEL_DIM*4
DROPOUT = 0.1
TEMPERATURE = 0.9
BATCH_FIRST = True

# --- Sequence lengths -------------------------------------------------------
MAX_SOURCE_SEQ_LEN = 100
MAX_TARGET_SEQ_LEN = 114
MAX_SEQ_LEN = 200

# --- Paths ------------------------------------------------------------------
MODEL_SAVE_FILEPATH = 'data/model_params.pt'
DATASET_PATH = 'D:/Files/Datasets/NMT_ru_en'

# --- Reproducibility --------------------------------------------------------
RANDOM_STATE = 42
# Seeds DataLoader shuffling, weight init and sampling in `generate`.
torch.manual_seed(RANDOM_STATE)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

assert MODEL_DIM % NUM_HEAD == 0, 'MODEL_DIM must be divisible by NUM_HEAD'

In [None]:
# Show which device ('cuda' or 'cpu') training and inference will run on.
print(DEVICE)

In [None]:
def generate_batches(dataset, batch_size, shuffle=True, drop_last=True, device='cpu'):
    '''Iterate over ``dataset`` in batches, moving every batch tensor to ``device``.

    Args:
        dataset: a torch ``Dataset`` whose items are dicts of tensors.
        batch_size: number of items per batch.
        shuffle: reshuffle the items every time the generator is created.
        drop_last: discard the final batch if it is smaller than ``batch_size``.
        device: target device for the batch tensors.

    Yields:
        dict mapping each key of the collated batch to its tensor on ``device``.
    '''
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    for batch in loader:
        yield {key: value.to(device) for key, value in batch.items()}

In [None]:
def normalize_sizes(y_pred, y_true):
    """Flatten batched sequence tensors so they can be fed to per-token losses/metrics.

    Args:
        y_pred (torch.Tensor): model output. A 3-D tensor ``(d0, d1, C)`` is reshaped
            to ``(d0 * d1, C)``; tensors of any other rank are returned unchanged.
        y_true (torch.Tensor): targets. A 2-D tensor is flattened to 1-D; tensors of
            any other rank are returned unchanged.

    Returns:
        tuple: ``(y_pred, y_true)`` after reshaping.
    """
    flat_pred = y_pred.reshape(-1, y_pred.shape[-1]) if y_pred.dim() == 3 else y_pred
    flat_true = y_true.reshape(-1) if y_true.dim() == 2 else y_true
    return flat_pred, flat_true

In [None]:
def compute_accuracy(y_pred, y_true, mask_index):
    """Token-level accuracy, in percent, over target positions that are not ``mask_index``.

    Args:
        y_pred (torch.Tensor): scores; 3-D ``(batch, seq, classes)`` or 2-D
            ``(N, classes)``. The predicted class is the argmax of the last dimension.
        y_true (torch.Tensor): target indices; 2-D ``(batch, seq)`` or 1-D ``(N,)``.
        mask_index (int): padding index, excluded from both numerator and denominator.

    Returns:
        float: percentage of correctly predicted non-padding tokens. Returns 0.0 when
        every target position is padding, instead of dividing by zero.
    """
    y_pred, y_true = normalize_sizes(y_pred, y_true)

    _, y_pred_indices = y_pred.max(dim=1)

    valid = torch.ne(y_true, mask_index)
    n_valid = valid.sum().item()
    if n_valid == 0:
        # Nothing to score (all padding); avoid ZeroDivisionError in the running mean.
        return 0.0

    n_correct = (torch.eq(y_pred_indices, y_true) & valid).sum().item()
    return n_correct / n_valid * 100

In [None]:
def sequence_loss(y_pred, y_true, mask_index):
    """Mean cross-entropy over all target tokens except those equal to ``mask_index``.

    Args:
        y_pred (torch.Tensor): unnormalized scores, 3-D ``(batch, seq, classes)`` or
            2-D ``(N, classes)``; flattened by ``normalize_sizes``.
        y_true (torch.Tensor): target indices, 2-D ``(batch, seq)`` or 1-D ``(N,)``.
        mask_index (int): padding index passed as ``ignore_index``.

    Returns:
        torch.Tensor: scalar loss.
    """
    flat_pred, flat_true = normalize_sizes(y_pred, y_true)
    loss = F.cross_entropy(input=flat_pred, target=flat_true, ignore_index=mask_index)
    return loss

In [None]:
def get_tokens_freq(dataframe : pandas.DataFrame) -> tuple[dict, dict]:
    '''Count token occurrences in the tokenized 'source_text' and 'target_text' columns.

    Args:
        dataframe: frame whose 'source_text' / 'target_text' cells are iterables of
            tokens (already tokenized).

    Returns:
        (source_freq, target_freq): plain dicts mapping token -> count. Keys appear in
        first-seen order (row by row, token by token); the vocabulary-building cell
        iterates these dicts in that order.
    '''
    def count(column: str) -> dict:
        counter = Counter()
        # Iterate the column values positionally. The former `.loc[i]` over
        # range(len(...)) looked rows up by *label* and broke (or read wrong rows)
        # whenever the index was not a 0..n-1 RangeIndex.
        for tokens in dataframe[column]:
            counter.update(tokens)
        return dict(counter)

    return count('source_text'), count('target_text')

In [None]:
def get_max_tokenized_seq_len(dataframe : pandas.DataFrame) -> tuple[int, int]:
    '''Return the longest token-sequence length in the source and target columns.

    Args:
        dataframe: frame whose 'source_text' and 'target_text' cells are token lists.

    Returns:
        (source_max_len, target_max_len) as ints; each is -1 when the frame has no rows.
    '''
    # Iterate the column values directly. `.loc[idx]` over range(len(...)) looks rows
    # up by label and fails on any index that is not 0..n-1.
    source_max_len = max((len(tokens) for tokens in dataframe['source_text']), default=-1)
    target_max_len = max((len(tokens) for tokens in dataframe['target_text']), default=-1)
    return source_max_len, target_max_len

In [None]:
def generate(model, tokenizer, vectorizer, query : str, max_seq_len : int = 100, temperature : float = 1.0, device : str = 'cpu', lowercase : bool = True):
    """Translate ``query`` autoregressively with the encoder-decoder ``model``.

    The source sentence is tokenized, vectorized and encoded once. The decoder is then
    fed its own output one token at a time, starting from ``<BOS>``, until ``<EOS>``
    is produced or ``max_seq_len`` new tokens have been generated.

    Args:
        model: trained Transformer model; it is moved to ``device`` and put in eval mode.
        tokenizer: tokenizer used to split ``query`` into tokens.
        vectorizer: vectorizer holding the source and target vocabularies.
        query: source sentence to translate.
        max_seq_len: maximum number of tokens to generate after ``<BOS>``.
        temperature: softmax temperature for sampling. Values ``<= 0`` select greedy
            (argmax) decoding.
        device: device to run inference on.
        lowercase: lowercase ``query`` before tokenizing. This matches the
            preprocessing applied to the training data, from which the vocabularies
            were built.

    Returns:
        torch.Tensor of shape ``(1, generated_len)`` with target-vocabulary indices,
        starting with ``<BOS>`` and ending with ``<EOS>`` if it was generated.
    """
    model.to(device)
    model.eval()

    target_vocab = vectorizer.target_vocab
    bos_index = target_vocab._bos_index
    eos_index = target_vocab._eos_index
    source_mask_index = vectorizer.source_vocab.mask_token_index
    target_mask_index = target_vocab.mask_token_index

    # The training text was lowercased before tokenization, so the vocabularies hold
    # lowercase tokens only. Without this, capitalized words would not be found.
    if lowercase:
        query = query.lower()
    tokenized_lst = tokenizer.tokenize(query)
    vectorized_dict = vectorizer.vectorize(source_tokens=tokenized_lst, use_dataset_max_len=False)

    # Inference only: without no_grad every decoding step would record an autograd graph.
    with torch.no_grad():
        # Encode the source once. unsqueeze(0) adds the batch dimension (batch-first
        # layout); this assumes 'source_vec' is 1-D — TODO confirm.
        source = torch.tensor(vectorized_dict['source_vec'], dtype=torch.long).to(device).unsqueeze(0)
        embeded = model.source_embedding(source) * math.sqrt(model.embed_dim)
        pos_embeded = model.pos_encoding_encoder(embeded)
        source_embed_projected = model.embed_to_model_projection(pos_embeded)

        # True at mask-token (padding) positions, which attention must ignore.
        source_key_padding_mask = source == source_mask_index
        encoder_output = model.transformer.encoder(source_embed_projected, src_key_padding_mask=source_key_padding_mask)

        # Decoding starts from a single <BOS> token: shape (1, 1).
        decoder_input = torch.tensor([[bos_index]], device=device, dtype=torch.long)

        for _ in range(max_seq_len):
            # Re-run the decoder on the whole prefix. subsequent_mask presumably
            # builds the causal mask.
            target_embed = model.target_embedding(decoder_input) * math.sqrt(model.embed_dim)
            target_embed = model.pos_encoding_decoder(target_embed)
            target_embed_projected = model.embed_to_model_projection(target_embed)

            decoder_output = model.transformer.decoder(
                target_embed_projected,
                encoder_output,
                tgt_mask=subsequent_mask(decoder_input.size(1), device=device),
                tgt_key_padding_mask=(decoder_input == target_mask_index),
                memory_key_padding_mask=source_key_padding_mask)

            # Only the last position's output predicts the next token.
            logits = model.classifier(decoder_output[:, -1, :])
            if temperature > 0:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            else:
                # Greedy decoding; dividing by a zero temperature would produce NaNs.
                next_token = logits.argmax(dim=-1, keepdim=True)

            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            if next_token.item() == eos_index:
                break

    return decoder_input


In [None]:
def decode_indices(indices : torch.Tensor, vectorizer):
    """Convert target-vocabulary index sequences back into text.

    Args:
        indices: integer tensor of shape ``(num_sequences, seq_len)``. A 1-D tensor is
            treated as a single sequence.
        vectorizer: object whose ``target_vocab`` provides ``mask_token_index``,
            ``_eos_index`` and ``get_token(index)``.

    Returns:
        list[str]: one string per sequence. Mask tokens are skipped, decoding stops
        after the first ``<EOS>`` (which is kept), and every token is followed by a
        single space, so strings end with a trailing space.
    """
    vocab = vectorizer.target_vocab
    if indices.dim() == 1:
        indices = indices.unsqueeze(0)

    decoded = []
    # tolist() converts to Python ints once instead of calling .item() per element.
    for row in indices.tolist():
        tokens = []
        for index in row:
            if index != vocab.mask_token_index:
                tokens.append(vocab.get_token(index))
            if index == vocab._eos_index:
                break
        decoded.append(''.join(token + ' ' for token in tokens))
    return decoded

In [None]:
def save_model_to_file(model, filepath):
    """Serialize the whole ``model`` object (not only its state_dict) to ``filepath``.

    The file contains a pickled module, so it must be read back with
    ``torch.load(filepath, weights_only=False)``. Only load such files from trusted
    sources, because unpickling can execute code.
    The parent directory is created if it does not exist yet.
    """
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(model, filepath)

In [None]:
# Tokenizer shared by the preprocessing of the training data and by `generate` at inference time.
tokenizer = SeparatorTokenizer()

In [None]:
# Load the parallel corpus, rename the language columns to the generic source/target
# names used throughout the notebook, and mark every row as training data.
df = (
    pd.read_csv(os.path.join(DATASET_PATH, 'ru_en_small.csv'), index_col='Unnamed: 0')
    .rename(columns={'ru_text' : 'source_text', 'en_text' : 'target_text'})
    .assign(split='train')
)

# Reproducibly move EVAL_PROPORTION of the rows into the validation split.
selected_indices = df.sample(int(EVAL_PROPORTION*len(df)), random_state=RANDOM_STATE).index
df.loc[selected_indices, 'split'] = 'validation'

# Lowercase and tokenize both sides (per the original notes, the tokenizer also strips
# service characters).
for text_column in ('source_text', 'target_text'):
    df[text_column] = df[text_column].apply(lambda text: tokenizer.tokenize(text.lower()))

In [None]:
# Find the maximum source and target text length in tokens (get_max_tokenized_seq_len(df) computes the same)
# source_max_len = target_max_len = -1
# for i in range(len(df)):
#     source_max_len = max(source_max_len, len(df.loc[i, 'source_text']))
#     target_max_len = max(target_max_len, len(df.loc[i, 'target_text']))
# print(source_max_len)
# print(target_max_len)

In [None]:
# Build the source/target vocabularies, keeping only tokens that occur more than
# TOKENS_TRESHOLD_FREQ times.
# NOTE(review): `df` still holds both splits here, so the counts include the
# validation rows as well — confirm this is intended.
source_vocab = Vocabulary()
target_vocab = Vocabulary()

source_freq, target_freq = get_tokens_freq(df)

for vocab, freq in ((source_vocab, source_freq), (target_vocab, target_freq)):
    frequent_tokens = [token for token, count in freq.items() if count > TOKENS_TRESHOLD_FREQ]
    for token in frequent_tokens:
        vocab.add_token(token)

# Persist the vocabularies so later runs can reload them instead of recounting.
source_vocab.to_json('data/source_vocab.json')
target_vocab.to_json('data/target_vocab.json')

In [None]:
# Reload both vocabularies from their JSON files, so this cell can serve as the starting
# point when the vocabulary-building cell is skipped.
source_vocab = Vocabulary.from_json('data/source_vocab.json')
target_vocab = Vocabulary.from_json('data/target_vocab.json')

In [None]:
# Maps token lists to index vectors using both vocabularies and the maximum lengths
# (presumably padding/truncating to MAX_SOURCE_SEQ_LEN / MAX_TARGET_SEQ_LEN — confirm in the vectorizer).
vectorizer = Seq2Seq_Vectorizer(source_vocab, target_vocab, MAX_SOURCE_SEQ_LEN, MAX_TARGET_SEQ_LEN)
# Dataset over the tokenized frame; the active split is switched with set_dataframe_split().
dataset = CustomDataset(df, tokenizer, vectorizer)

In [None]:
# Vocabulary sizes are passed to the TransformerModel constructor. mask_index is the
# target padding index that the loss and accuracy ignore.
source_vocab_size, target_vocab_size = len(source_vocab), len(target_vocab)
mask_index = target_vocab.mask_token_index

In [None]:
# model = TransformerModel(source_vocab_size, target_vocab_size, EMBEDDING_DIM, MODEL_DIM, NUM_HEAD, NUM_ENCODER_LAYERS,\
#                          NUM_DECODER_LAYERS, FC_HIDDEN_DIM, DROPOUT, MAX_SEQ_LEN, BATCH_FIRST, mask_index)

In [None]:
# Resume from the saved checkpoint when it exists. Otherwise build a fresh model, so the
# notebook also runs end-to-end on a clean checkout.
# NOTE: the checkpoint is a pickled nn.Module (weights_only=False); only load trusted files.
if os.path.exists(MODEL_SAVE_FILEPATH):
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine (and vice versa).
    model = torch.load(MODEL_SAVE_FILEPATH, map_location=DEVICE, weights_only=False)
else:
    model = TransformerModel(source_vocab_size, target_vocab_size, EMBEDDING_DIM, MODEL_DIM, NUM_HEAD, NUM_ENCODER_LAYERS,
                             NUM_DECODER_LAYERS, FC_HIDDEN_DIM, DROPOUT, MAX_SEQ_LEN, BATCH_FIRST, mask_index)

In [None]:
# nn.Module.to() moves the parameters in place and returns the same module object.
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE)

In [None]:
def _run_epoch(split, training):
    """Run one pass over the ``split`` part of the global ``dataset``.

    When ``training`` is True, the model is put in train mode and updated with the
    global ``optimizer``. Otherwise it is put in eval mode with autograd disabled.

    Returns:
        (running mean of batch losses, running mean of batch accuracies in %,
        sum of batch losses)
    """
    dataset.set_dataframe_split(split)
    # Only the training pass shuffles and drops the incomplete last batch. Validation
    # must score every example; with drop_last a small validation split could even
    # yield zero batches.
    batch_generator = generate_batches(dataset, batch_size=BATCH_SIZE,
                                       shuffle=SHUFFLE and training,
                                       drop_last=DROP_LAST and training,
                                       device=DEVICE)
    running_loss = 0.0
    running_acc = 0.0
    epoch_err = 0.0

    model.train(training)
    # Validation has no backward pass, so building the autograd graph there would
    # only waste memory.
    with torch.set_grad_enabled(training):
        for batch_index, batch_dict in enumerate(batch_generator):
            if training:
                optimizer.zero_grad()

            y_pred = model(batch_dict['source_vec'],
                           batch_dict['target_x_vec'],
                           apply_softmax = False,
                           temperature = TEMPERATURE)

            loss = sequence_loss(y_pred, batch_dict['target_y_vec'], mask_index)

            if training:
                loss.backward()
                optimizer.step()

            # Incremental means over the batches seen so far.
            running_loss += (loss.item() - running_loss) / (batch_index + 1)
            epoch_err += loss.item()

            acc_t = compute_accuracy(y_pred, batch_dict['target_y_vec'], mask_index)
            running_acc += (acc_t - running_acc) / (batch_index + 1)

    return running_loss, running_acc, epoch_err


try:
    for epoch in range(EPOCHS):
        train_loss, train_acc, train_err = _run_epoch('train', training=True)
        print('-'*40)
        print(f'epoch {epoch+1}')
        print(f'train_epoch_error {train_err}')
        print(f'train loss {train_loss}   ,   train accuracy {train_acc}')

        valid_loss, valid_acc, valid_err = _run_epoch('validation', training=False)
        print(f'validation_epoch_error {valid_err}')
        print(f'validation loss {valid_loss}   ,   validation accuracy {valid_acc}')

except KeyboardInterrupt:
    # Interrupting the kernel stops training early; `model` keeps its current weights.
    print("Exiting loop")

In [None]:
# Sample Russian sentence to translate with the trained model.
query = 'Я ем яблоко'

In [None]:
# Sample a translation as target-vocabulary indices (starting with <BOS>), then map it back to text.
indices = generate(model, tokenizer, vectorizer, query, max_seq_len=MAX_SEQ_LEN, temperature=TEMPERATURE, device=DEVICE)
response = decode_indices(indices, vectorizer)

In [None]:
# Display the decoded translation: a list with one string per generated sequence.
response

In [None]:
# Persist the current model to the same path the model-loading cell reads from.
save_model_to_file(model, MODEL_SAVE_FILEPATH)