In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from scripts.custom_dataset import CustomDataset
from scripts.model import MHAModel
from scripts.tokenizer import SeparatorTokenizer
from scripts.vectorizer import Vectorizer
from scripts.vocabulary import Vocabulary
import pandas as pd
import numpy as np
import json
import os

In [None]:
TEST_PROPORTION = 0.0
EVAL_PROPORTION = 0.0

ADD_BOS_EOS_TOKENS = False

SHUFFLE = True
DROP_LAST = True
EPOCHS = 5
LEARNING_RATE = 1e-4

LR_SCHEDULER_FACTOR = 0.5
LR_SCHEDULER_PATIENCE = 2

USE_PRETRAINED = False

BATCH_SIZE = 32
BIAS = True
EMBEDDING_DIM = 512
ATTENTION_DIM = 768
NUM_HEAD = 12
NUM_ENCODER_LAYERS = 8
ENCODER_FC_HIDDEN_DIM = ATTENTION_DIM*4 # Как в классическом трансформере
CLASSIFIER_FC_HIDDEN_DIM = ATTENTION_DIM*2
DROPOUT = 0.1
TEMPERATURE = 1
BATCH_FIRST = True

MODEL_SAVE_FILEPATH = 'data/model_params.pt'
DATASET_PATH = 'dataset'

RANDOM_STATE = 42

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())

True
True
True


In [4]:
def find_max_source_len(dataframe:pd.DataFrame)->int:
    '''Возвращает максимальную длину входной последовательности в датафрейме'''
    max_source_tokens = 0
    for i in range(len(dataframe)):
        max_source_tokens = max(len(dataframe.loc[i, 'source_tokens']), max_source_tokens)
    return max_source_tokens

In [5]:
def generate_batches(dataset:CustomDataset, batch_size:int, shuffle:bool=True, drop_last:bool=True, device='cpu'):
    '''Создает батчи из датасета и переносит данные на девайс.'''
    dataloader = DataLoader(dataset, batch_size, shuffle, drop_last=drop_last)
    for data_dict in dataloader:
        out_data_dict = {}
        for name, _ in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict

In [6]:
def save_results_to_file(model, model_filepath:str, train_states:list=None, validation_states:list=None):
    '''Сохраняет параметры модели и метрики обучения в файлы.'''
    torch.save(model, model_filepath)
    if train_states is not None:
        with open("data/train_states.json", "w", encoding="utf-8") as file:
            json.dump(train_states, file, indent=4, ensure_ascii=False)

    if validation_states is not None:
        with open("data/validation_states.json", "w", encoding="utf-8") as file:
            json.dump(validation_states, file, indent=4, ensure_ascii=False)

In [7]:
def preprocess_df(df:pd.DataFrame, source_column_name:str):
    for row in range(len(df)):
        tokens = df[source_column_name].iloc[row].copy()  # Создаем копию
        for i in range(len(tokens)):
            tokens[i] = tokens[i].lower()
        df.at[row, source_column_name] = tokens

In [8]:
def normalize_sizes(predictions:dict[str:torch.tensor], targets:dict[str:torch.tensor], target_names:list[str]):
    for key in target_names:
        # Для predictions: [B, S, C] -> [B*S, C]
        if len(predictions[key].size()) == 3:
            predictions[key] = predictions[key].contiguous().view(-1, predictions[key].size(-1))
        
        # Для targets: [B, S] -> [B*S]
        if len(targets[key].size()) == 2:
            targets[key] = targets[key].contiguous().view(-1)
    
    return predictions, targets

In [9]:
def compute_loss(predictions:dict[str:torch.tensor], targets:dict[str:list[int]], target_names:list[str], target_weights:dict[str:float], mask_idx:int=0):
    predictions, targets = normalize_sizes(predictions, targets, target_names)
    losses = {}
    total_loss = 0
    for key in target_names:
        losses[key] = torch.nn.functional.cross_entropy(predictions[key], targets[key], ignore_index=mask_idx)
        total_loss += losses[key] * target_weights[key]

    return total_loss, losses

In [10]:
def compute_accuracy(predictions:dict[str:torch.tensor], targets:dict[str:list[int]], target_names:list[str], mask_idx:int=0)->dict[str:float]:
    predictions, targets = normalize_sizes(predictions, targets, target_names)
    
    accuracies = {}
    for key in target_names:
        _, pred_indices = predictions[key].max(dim=1)
        
        correct_indices = torch.eq(pred_indices, targets[key]).float()
        valid_indices = torch.ne(targets[key], mask_idx).float()
        
        n_correct = (correct_indices * valid_indices).sum().item()
        n_valid = valid_indices.sum().item()
        accuracies[key] = n_correct / n_valid 

    return accuracies

In [11]:
train_df = pd.read_parquet(os.path.join(DATASET_PATH, 'ru_syntagrus-ud-train.parquet'))
validation_df = pd.read_parquet(os.path.join(DATASET_PATH, 'ru_syntagrus-ud-dev.parquet'))

In [12]:
MAX_SOURCE_LENGTH = max(find_max_source_len(train_df), find_max_source_len(validation_df)) + 2 # Прибавляем 2 для учета доп. токенов BOS и EOS
target_names = ['upos']
source_name = 'source_tokens'
preprocess_df(train_df, source_name)
preprocess_df(validation_df, source_name)

In [13]:
train_df.head(5)

Unnamed: 0,source_tokens,lemmas,upos,xpos,feats,head,deprel,misc
0,"[анкета, .]","[анкета, .]","[NOUN, PUNCT]","[None, None]","[{'Animacy': 'Inan', 'Case': 'Nom', 'Gender': ...","[0, 1]","[root, punct]","[{'SpaceAfter': 'No'}, None]"
1,"[начальник, областного, управления, связи, сем...","[начальник, областной, управление, связь, Семе...","[NOUN, ADJ, NOUN, NOUN, PROPN, PROPN, AUX, NOU...","[None, None, None, None, None, None, None, Non...","[{'Animacy': 'Anim', 'Case': 'Nom', 'Gender': ...","[8, 3, 1, 3, 1, 5, 8, 0, 8, 11, 8, 13, 11, 11,...","[nsubj, amod, nmod, nmod, appos, flat:name, co...","[None, None, None, None, None, None, None, Non..."
2,"[в, приемной, его, с, утра, ожидали, посетител...","[в, приемная, он, с, утро, ожидать, посетитель...","[ADP, NOUN, PRON, ADP, NOUN, VERB, NOUN, PUNCT...","[None, None, None, None, None, None, None, Non...","[None, {'Animacy': 'Inan', 'Case': 'Loc', 'Gen...","[2, 6, 6, 5, 6, 0, 6, 13, 13, 13, 13, 13, 7, 1...","[case, obl, obj, case, obl, root, nsubj, punct...","[None, None, None, None, None, None, {'SpaceAf..."
3,"[однако, стиль, работы, семена, еремеевича, за...","[однако, стиль, работа, Семен, Еремеевич, закл...","[ADV, NOUN, NOUN, PROPN, PROPN, VERB, ADP, PRO...","[None, None, None, None, None, None, None, Non...","[{'Degree': 'Pos'}, {'Animacy': 'Inan', 'Case'...","[6, 6, 2, 3, 4, 0, 8, 6, 11, 11, 8, 13, 11, 16...","[advmod, nsubj, nmod, nmod, flat:name, root, c...","[None, None, None, None, None, None, None, {'S..."
4,"[приемная, была, обставлена, просто, ,, но, по...","[приемная, быть, обставить, просто, ,, но, по-...","[NOUN, AUX, VERB, ADV, PUNCT, CCONJ, ADV, PUNCT]","[None, None, None, None, None, None, None, None]","[{'Animacy': 'Inan', 'Case': 'Nom', 'Gender': ...","[3, 3, 0, 3, 7, 7, 4, 3]","[nsubj:pass, aux:pass, root, advmod, punct, cc...","[None, None, None, {'SpaceAfter': 'No'}, None,..."


In [14]:
source_vocab = Vocabulary(add_bos_eos_tokens=ADD_BOS_EOS_TOKENS)
target_vocabs = {target_name: Vocabulary(add_bos_eos_tokens=ADD_BOS_EOS_TOKENS) for target_name in target_names}
for i in range(len(train_df)):
    source_vocab.add_tokens(train_df[source_name].iloc[i])
    for target_name in target_names:
        target_vocabs[target_name].add_tokens(train_df[target_name].iloc[i])

mask_index = source_vocab.mask_idx
source_vocab_len = len(source_vocab)
cls_names_params = {key:len(target_vocabs[key]) for key in target_names}
target_weights = {key : 1.0 for key in target_names}

In [15]:
print(f'Количество батчей = {len(train_df)//BATCH_SIZE}')

print(f'Длина словаря токенов = {len(source_vocab)}')
for key in target_names:
    print(f'Длина словаря признака {key} = {len(target_vocabs[key])}')

Количество батчей = 2175
Длина словаря токенов = 121665
Длина словаря признака upos = 20


In [16]:
if USE_PRETRAINED:
    with open("data/train_states.json", "r", encoding="utf-8") as file:
        train_states = json.load(file)

    with open("data/validation_states.json", "r", encoding="utf-8") as file:
        validation_states = json.load(file)
    
    model = torch.load(MODEL_SAVE_FILEPATH, weights_only=False)
else:
    train_states = []
    validation_states = []
    model = MHAModel(MAX_SOURCE_LENGTH, source_vocab_len, EMBEDDING_DIM, ATTENTION_DIM, NUM_HEAD, NUM_ENCODER_LAYERS, CLASSIFIER_FC_HIDDEN_DIM, ENCODER_FC_HIDDEN_DIM,\
                     cls_names_params, DROPOUT, TEMPERATURE, BATCH_FIRST, BIAS, mask_index)

In [17]:
vectorizer = Vectorizer(source_vocab, target_vocabs, MAX_SOURCE_LENGTH, mask_index)
dataset = CustomDataset(vectorizer, train_df, target_names, add_bos_eos_tokens=ADD_BOS_EOS_TOKENS, valid_df=validation_df)

model = model.to(device=DEVICE)
optimizer = optim.Adam(model.parameters(), LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=LR_SCHEDULER_FACTOR, patience=LR_SCHEDULER_PATIENCE)

In [18]:
try:
    for epoch in range(EPOCHS):
        dataset.set_dataframe_split('train')
        batch_generator = generate_batches(dataset, BATCH_SIZE, SHUFFLE, DROP_LAST, DEVICE)
        epoch_sum_train_loss = 0.0
        epoch_running_train_loss = 0.0
        train_running_acc = {key:0.0 for key in target_names}
        model.train()
        for batch_idx, batch_dict in enumerate(batch_generator):
            optimizer.zero_grad()
            
            predictions = model(batch_dict['source_x'])

            total_loss, losses = compute_loss(predictions, batch_dict, target_names, target_weights, mask_index)

            total_loss.backward()

            scheduler.step(epoch_running_train_loss)

            optimizer.step()

            # Средние потери и точность
            epoch_running_train_loss += (total_loss.item() - epoch_running_train_loss) / (batch_idx + 1)
            epoch_sum_train_loss += total_loss.item()

            acc_t = compute_accuracy(predictions, batch_dict, target_names, mask_index)
            for key in target_names:
                train_running_acc[key] += (acc_t[key] - train_running_acc[key]) / (batch_idx + 1)

        dataset.set_dataframe_split('validation')
        batch_generator = generate_batches(dataset, BATCH_SIZE, SHUFFLE, DROP_LAST, DEVICE)
        epoch_sum_valid_loss = 0.0
        epoch_running_valid_loss = 0.0
        valid_running_acc = {key:0.0 for key in target_names}
        model.eval()

        with torch.no_grad():
            for batch_idx, batch_dict in enumerate(batch_generator):
                
                predictions = model(batch_dict['source_x'])

                total_loss, losses = compute_loss(predictions, batch_dict, target_names, target_weights, mask_index)

                # Средние потери и точность
                epoch_running_valid_loss += (total_loss.item() - epoch_running_valid_loss) / (batch_idx + 1)
                epoch_sum_valid_loss += total_loss.item()

                acc_t = compute_accuracy(predictions, batch_dict, target_names, mask_index)
                for key in target_names:
                    valid_running_acc[key] += (acc_t[key] - valid_running_acc[key]) / (batch_idx + 1)

        train_states.append({'epoch' : epoch+1, 'epoch_sum_train_loss' : epoch_sum_train_loss, 'epoch_running_train_loss' : epoch_running_train_loss, 'accuracy' : train_running_acc})
        validation_states.append({'epoch' : epoch+1, 'epoch_sum_valid_loss' : epoch_sum_valid_loss, 'epoch_running_valid_loss' : epoch_running_valid_loss, 'accuracy' : valid_running_acc})
        
        print('-'*40)
        print(f'Epoch {epoch+1}')
        print(f'Train: Суммированная ошибка эпохи {epoch_sum_train_loss}')
        print(f'Train: Средняя ошибка эпохи {epoch_running_train_loss}')
        for key in target_names:
            print(f'Train: Точность на признаке {key}: {train_running_acc[key]*100}')

        print('-'*10)
        print(f'Validation: Суммированная ошибка эпохи {epoch_sum_valid_loss}')
        print(f'Validation: Средняя ошибка эпохи {epoch_running_valid_loss}')
        for key in target_names:
            print(f'Validation: Точность на признаке {key}: {valid_running_acc[key]*100}')

except KeyboardInterrupt:
    print('Принудительная остановка')

----------------------------------------
Epoch 1
Train: Суммированная ошибка эпохи 5274.725795984268
Train: Средняя ошибка эпохи 2.425161285510011
Train: Точность на признаке upos: 35.84785403339689
----------
Validation: Суммированная ошибка эпохи 645.2154250144958
Validation: Средняя ошибка эпохи 2.3209187950161736
Validation: Точность на признаке upos: 37.95035796643219
----------------------------------------
Epoch 2
Train: Суммированная ошибка эпохи 4937.2707624435425
Train: Средняя ошибка эпохи 2.270009545951055
Train: Точность на признаке upos: 38.20643496385146
----------
Validation: Суммированная ошибка эпохи 611.0797562599182
Validation: Средняя ошибка эпохи 2.1981286196399936
Validation: Точность на признаке upos: 39.53194746407379
----------------------------------------
Epoch 3
Train: Суммированная ошибка эпохи 4708.779181003571
Train: Средняя ошибка эпохи 2.1649559452889924
Train: Точность на признаке upos: 40.02086826406932
----------
Validation: Суммированная ошибка эпо

In [19]:
save_results_to_file(model, MODEL_SAVE_FILEPATH, train_states)