In [1]:
!pip3 install transformers
!pip3 install datasets
!pip3 install textaugment
!pip3 install gensim==3.6.0



In [2]:
import nltk
# Fetch the NLTK corpora / models used later in the notebook.
# NOTE(review): presumably these are required by the textaugment augmenters — confirm.
nltk.download("stopwords")
nltk.download("wordnet")
nltk.download("omw-1.4")
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/kutuzov/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/kutuzov/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/kutuzov/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/kutuzov/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

Загрузим датасет RTE из набора GLUE и посмотрим, сколько примеров содержит каждая его часть и как в них сбалансированы классы

In [3]:
from datasets import load_dataset
from collections import Counter

# Load the "rte" configuration of the "glue" dataset (all available splits).
rte_dataset = load_dataset("glue", "rte")

# For every split, print how many examples carry each label value.
# In the recorded output the test split only contains the label -1.
for data_type in rte_dataset.keys():
    counts = Counter(rte_dataset[data_type]['label'])
    print(f"{data_type}: {','.join([f'{key} - {value}' for key, value in counts.items()])}")

Reusing dataset glue (/home/kutuzov/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

train: 1 - 1241,0 - 1249
validation: 1 - 131,0 - 146
test: -1 - 3000


Импортируем необходимые библиотеки и выберем устройство для вычислений (GPU, если он доступен, иначе CPU):

In [4]:
from typing import List, Dict
import random

import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from datasets import concatenate_datasets, DatasetDict

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# This module-level name is read by the training / evaluation helpers below.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [5]:
# Seed the PyTorch, NumPy and Python standard-library random generators
# so that repeated runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

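Проведём аугментацию обучающих данных, чтобы увеличить их объём: к исходным примерам добавим варианты с заменой слов на близкие по word2vec и варианты со случайной перестановкой слов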

In [None]:
from gensim import downloader
from textaugment.word2vec import Word2vec

def augment_dataset(augmenter):
    """Wrap a sentence-level augmenter into a per-example mapping function.

    Args:
        augmenter: callable taking one sentence and returning its augmented
            version.

    Returns:
        A function that takes an example (a mapping with the keys
        "sentence1", "sentence2" and "label") and returns a new dict in which
        both sentences were passed through ``augmenter`` and the label is
        copied unchanged. Any other keys of the example are not included.
    """
    def augment(x):
        # Augment the sentences in a fixed order: first, then second.
        first = augmenter(x["sentence1"])
        second = augmenter(x["sentence2"])
        return {"sentence1": first, "sentence2": second, "label": x["label"]}

    return augment

def random_swap_augment(sentence, n_swaps=3, augmenter=None):
    """Apply the random-swap augmentation to a sentence several times.

    The original version read a module-level ``swap_augmenter`` that was never
    created anywhere in the notebook, so it failed with ``NameError`` on a
    fresh kernel. The augmenter is now created on first use and cached in
    that same global name, so an instance defined by the user beforehand is
    still honoured.

    Args:
        sentence: text to augment.
        n_swaps: how many times ``random_swap`` is applied in a row
            (defaults to 3, the value that used to be hard-coded).
        augmenter: optional object exposing ``random_swap(sentence)``; when
            omitted, the shared ``swap_augmenter`` (a ``textaugment.EDA``
            instance) is used.

    Returns:
        The augmented sentence. Sentences with fewer than two
        whitespace-separated tokens are returned unchanged because there is
        nothing to swap.
    """
    global swap_augmenter

    if augmenter is None:
        try:
            augmenter = swap_augmenter
        except NameError:
            # Imported lazily so that the heavy dependency is only needed
            # when the default augmenter is actually requested.
            from textaugment import EDA

            augmenter = swap_augmenter = EDA()

    if len(sentence.split()) < 2:
        return sentence

    for _ in range(n_swaps):
        sentence = augmenter.random_swap(sentence)
    return sentence
    

# Load the pretrained 'fasttext-wiki-news-subwords-300' vectors through gensim's
# downloader and hand them to the textaugment Word2vec augmenter.
# NOTE(review): p=0.5 is presumably the per-word replacement probability — confirm
# against the textaugment documentation.
word2vec_model = downloader.load('fasttext-wiki-news-subwords-300')
word2vec_augmenter = Word2vec(model=word2vec_model, p=0.5)

# Build two augmented copies of the training split (labels are kept as-is by
# augment_dataset): one via the word2vec augmenter, one via random word swaps.
augmented_train_data_w2v = rte_dataset["train"].map(augment_dataset(word2vec_augmenter.augment))
augmented_train_data_swaper = rte_dataset["train"].map(augment_dataset(random_swap_augment))
augmented_train_data = concatenate_datasets((augmented_train_data_w2v, augmented_train_data_swaper))

# Final training split = original examples + both augmented copies.
# The validation split is taken from the original dataset without augmentation.
augmented_rte_dataset = DatasetDict()
augmented_rte_dataset["train"] = concatenate_datasets((rte_dataset["train"], augmented_train_data))
augmented_rte_dataset["validation"] = rte_dataset["validation"]

0ex [00:00, ?ex/s]

In [None]:
class EntailmentDataset(Dataset):
    """Torch dataset that serves sentence pairs as a single text plus a label.

    Each item of the wrapped container must provide the keys "sentence1",
    "sentence2" and "label".

    Args:
        data_container: indexable collection of examples (supports ``len``
            and integer indexing).
        separator: string inserted between the two sentences. The previous
            implementation concatenated them with nothing in between, which
            glued the last word of the first sentence to the first word of
            the second one; a single space is now used by default. Pass a
            model-specific separator (e.g. the tokenizer's separator token
            surrounded by spaces) if a stronger boundary is wanted.
    """

    def __init__(self, data_container: Dataset, separator: str = " "):
        self.data = data_container
        self.separator = separator

    def __len__(self):
        """Return the number of examples in the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(text, label)`` for the example at position ``idx``."""
        row_data = self.data[idx]

        sentences = self.separator.join((row_data["sentence1"], row_data["sentence2"]))
        y = row_data["label"]
        return sentences, y

# Number of target classes; the labelled splits printed above contain the labels 0 and 1.
n_classes = 2

In [None]:
def prepare_sentences_for_prediction(sentences, tokenizer):
    """Tokenize a batch of texts and move the resulting tensors to ``device``.

    Args:
        sentences: iterable of texts forming one batch.
        tokenizer: callable tokenizer; it is invoked with padding and
            truncation enabled and asked for PyTorch tensors.

    Returns:
        A plain dict mapping every key produced by the tokenizer to its
        tensor placed on the module-level ``device``.
    """
    batch = tokenizer(list(sentences), padding=True, truncation=True, return_tensors='pt')

    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(device)
    return on_device

def train_fn(model, tokenizer, dataloader: DataLoader,
             loss_fn: callable, metrics: Dict[str, callable], optimizer):
    """Train the model for one pass over ``dataloader``.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        tokenizer: tokenizer forwarded to ``prepare_sentences_for_prediction``.
        dataloader: yields ``(texts, labels)`` batches.
        loss_fn: callable ``(logits, targets) -> loss tensor``.
        metrics: mapping ``name -> callable(y_true, y_pred)``. Callables whose
            name contains "precision" or "recall" are additionally given
            ``zero_division=0``.
        optimizer: optimizer stepped once per batch.

    Returns:
        A tuple ``(mean_loss, metric_values)`` where ``mean_loss`` is the
        average of the per-batch losses and ``metric_values`` maps each metric
        name to its value computed over all predictions made during the pass.

    Note:
        Metrics used to be computed per batch and then averaged. That is not
        equal to the dataset-level precision / recall / F1 and gives a short
        final batch the same weight as a full one, so predictions are now
        collected and every metric is evaluated once.
    """
    model.train()

    running_loss = 0
    all_targets, all_preds = [], []

    for x, y in dataloader:
        x = prepare_sentences_for_prediction(x, tokenizer)

        optimizer.zero_grad()

        output = model(**x).logits
        y_preds = output.argmax(axis=-1)

        loss = loss_fn(output, y.to(device))
        loss.backward()

        # Clip the global gradient norm to 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        all_targets.append(y.cpu())
        all_preds.append(y_preds.detach().cpu())

    y_true = torch.cat(all_targets).numpy()
    y_pred = torch.cat(all_preds).numpy()

    metrics_values = {}
    for metric_name, metric_fn in metrics.items():
        if "precision" in metric_name or "recall" in metric_name:
            metrics_values[metric_name] = metric_fn(y_true, y_pred, zero_division=0)
        else:
            metrics_values[metric_name] = metric_fn(y_true, y_pred)

    return running_loss / len(dataloader), metrics_values

In [None]:
def calculate_weights(labels, n_classes=None):
    """Compute inverse-frequency class weights for a weighted loss.

    The weight of class ``c`` is ``2 * len(labels) / count(c)``; classes that
    never occur get the weight 0.

    The weight vector used to be sized by the number of *distinct* labels
    while being indexed by the label value itself. That raised ``IndexError``
    when a class was missing (e.g. only label 1 present) and silently wrote to
    the wrong slot for negative labels. The vector is now sized by the largest
    label, and negative labels are rejected.

    Args:
        labels: sized iterable of non-negative integer class ids.
        n_classes: optional minimum length of the returned vector, useful
            when some classes may be absent from ``labels``.

    Returns:
        A 1-D ``torch.float32`` tensor indexed by class id.

    Raises:
        ValueError: if ``labels`` contains a negative class id.
    """
    counts = Counter(labels)

    if any(label < 0 for label in counts):
        raise ValueError("labels must be non-negative class ids, got: "
                         f"{sorted(counts)}")

    size = max(counts) + 1 if counts else 0
    if n_classes is not None:
        size = max(size, n_classes)

    weights = [0.0] * size
    total = len(labels)

    for label, count in counts.items():
        weights[label] = 2 * total / count

    return torch.tensor(weights, dtype=torch.float32)


def eval_fn(model, tokenizer, dataloader: DataLoader,
             loss_fn: callable, metrics: Dict[str, callable]):
    """Evaluate the model on every batch of ``dataloader`` without training.

    Args:
        model: sequence-classification model whose output exposes ``.logits``.
        tokenizer: tokenizer forwarded to ``prepare_sentences_for_prediction``.
        dataloader: yields ``(texts, labels)`` batches.
        loss_fn: callable ``(logits, targets) -> loss tensor``.
        metrics: mapping ``name -> callable(y_true, y_pred)``. Callables whose
            name contains "precision" or "recall" are additionally given
            ``zero_division=0``.

    Returns:
        A tuple ``(mean_loss, metric_values)`` where ``mean_loss`` is the
        average of the per-batch losses and ``metric_values`` maps each metric
        name to its value computed over all predictions of the pass.

    Note:
        Metrics used to be computed per batch and then averaged, which is not
        equal to the dataset-level precision / recall / F1 and made the result
        depend on how examples were grouped into batches. Predictions are now
        collected and each metric is evaluated once. The previously
        accumulated but never used ``mean_prediction`` was removed.
    """
    model.eval()

    running_loss = 0
    all_targets, all_preds = [], []

    for x, y in dataloader:
        x = prepare_sentences_for_prediction(x, tokenizer)

        # No gradients are needed for the forward pass or the loss here.
        with torch.no_grad():
            output = model(**x).logits
            loss = loss_fn(output, y.to(device))

        running_loss += loss.item()
        all_targets.append(y.cpu())
        all_preds.append(output.argmax(axis=-1).cpu())

    y_true = torch.cat(all_targets).numpy()
    y_pred = torch.cat(all_preds).numpy()

    metrics_values = {}
    for metric_name, metric_fn in metrics.items():
        if "precision" in metric_name or "recall" in metric_name:
            metrics_values[metric_name] = metric_fn(y_true, y_pred, zero_division=0)
        else:
            metrics_values[metric_name] = metric_fn(y_true, y_pred)

    return running_loss / len(dataloader), metrics_values

def f1_score_wrapper(n_classes):
    """Pick the F1 callable that matches the number of classes.

    Args:
        n_classes: number of target classes.

    Returns:
        ``f1_score`` itself when there are exactly two classes; otherwise a
        two-argument function that calls ``f1_score`` with ``average="micro"``.
    """
    if n_classes == 2:
        return f1_score

    def micro_f1(x, y):
        return f1_score(x, y, average="micro")

    return micro_f1

def check_hypothesis(model_name, dataset, epochs, batch_size=32,
                     weighted=False, metrics_for_logging=dict(),
                     lr=1e-2, n_classes=2):
    """Fine-tune a pretrained classifier on ``dataset`` and record its history.

    Args:
        model_name: identifier passed to the ``transformers`` ``Auto*``
            ``from_pretrained`` loaders.
        dataset: mapping with "train" and "validation" splits whose rows have
            the keys "sentence1", "sentence2" and "label".
        epochs: number of training epochs.
        batch_size: batch size of the training loader (the validation loader
            always uses 32).
        weighted: when true, class weights from ``calculate_weights`` are used
            in the loss; otherwise all classes weigh 1.
        metrics_for_logging: mapping ``name -> metric callable``. Entries whose
            name contains "f1" are replaced by ``f1_score_wrapper(n_classes)``.
            The passed dict itself is no longer modified.
        lr: learning rate for AdamW. Note that the default (1e-2) is far
            larger than the 2e-5 .. 1e-4 range used by the experiment cell.
        n_classes: number of output labels of the classification head.

    Returns:
        ``(model, train_losses, train_metrics, eval_losses, eval_metrics)``
        where the four lists hold one entry per epoch.
    """
    model_config = AutoConfig.from_pretrained(model_name, num_labels=n_classes)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=model_config,
                                                              ignore_mismatched_sizes=True)
    model = model.to(device)

    train_dataset = EntailmentDataset(dataset["train"])
    eval_dataset = EntailmentDataset(dataset["validation"])

    if weighted:
        weights = calculate_weights(dataset["train"]["label"])
    else:
        weights = torch.FloatTensor([1 for _ in range(n_classes)])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation data does not need shuffling; a fixed order keeps evaluation
    # deterministic and does not consume random numbers.
    eval_dataloader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

    train_losses, train_metrics, eval_losses, eval_metrics = [], [], [], []

    # Work on a copy: the original code overwrote entries of the caller's dict
    # (and of the shared mutable default argument).
    metrics_for_logging = dict(metrics_for_logging)
    for key in metrics_for_logging:
        if "f1" in key:
            metrics_for_logging[key] = f1_score_wrapper(n_classes)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    # The learning rate is multiplied by 0.25 after epochs 2, 4, 6 and 8.
    scheduler = MultiStepLR(optimizer, milestones=[2, 4, 6, 8], gamma=0.25)

    # The loss does not change between epochs, so build it once.
    loss_fn = CrossEntropyLoss(weights.to(device))

    for epoch in tqdm(range(epochs)):
        train_loss, train_metric = train_fn(model, tokenizer, train_dataloader,
                                            loss_fn, metrics_for_logging, optimizer)

        eval_loss, eval_metric = eval_fn(model, tokenizer,
                                         eval_dataloader, loss_fn,
                                         metrics_for_logging)
        scheduler.step()

        train_losses.append(train_loss)
        train_metrics.append(train_metric)

        eval_losses.append(eval_loss)
        eval_metrics.append(eval_metric)
    return model, train_losses, train_metrics, eval_losses, eval_metrics

In [None]:
# One list per reported column; each finished experiment appends one value to every list.
logging_data = {x: [] for x in ("Model name", "Dataset", "Loss", "Lr",
                               "accuracy", "f1", "precision", "recall")}

# Grid over pretrained checkpoints and learning rates:
# 5 epochs, batch size 32, class-weighted loss.
for model_name in ["cardiffnlp/twitter-roberta-base-sentiment", "facebook/bart-base"]:
    for lr in [2e-5, 5e-5, 1e-4]:
        # The "f1" entry is replaced inside check_hypothesis by
        # f1_score_wrapper(n_classes).
        hypothesis_result = check_hypothesis(model_name, augmented_rte_dataset, 5, 32, True,
                                             {"accuracy": accuracy_score, "f1": f1_score,
                                              "precision":precision_score, "recall":recall_score},
                                             lr=lr)

        model, train_losses, train_metrics, eval_losses, eval_metrics = hypothesis_result

        logging_data["Model name"].append(model_name)
        logging_data["Dataset"].append("rte")
        # Best (lowest) validation loss and best validation accuracy over all epochs.
        logging_data["Loss"].append(min(eval_losses))
        logging_data["Lr"].append(lr)

        logging_data["accuracy"].append(max(map(lambda x: x["accuracy"], eval_metrics)))

        # f1 / precision / recall are all taken from the epoch with the highest validation F1.
        f1_maximum_index = np.argmax(list(map(lambda x: x["f1"], eval_metrics)))

        for key, value in eval_metrics[f1_maximum_index].items():
            if key != "accuracy":
                logging_data[key].append(value)

        # Print the newest row on a single line. The line break used to sit
        # inside the loop, which split the row into one line per field.
        for key, value in logging_data.items():
            print(f"{key}: {value[-1]}", end=", ")
        print()

In [None]:
import pandas as pd

# Turn the collected per-experiment results into a table (one row per run) and display it.
df = pd.DataFrame(logging_data)
df

In [None]:
# Save the results table to a CSV file in the working directory, without the index column.
df.to_csv("logs_entailment.csv", index=False)