In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter 
import os
from datetime import datetime
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, matthews_corrcoef
from sklearn.utils import class_weight
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Load the train / test splits; the first CSV column becomes the frame index.
train_df, test_df = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ("train_ajt_df.csv", "test_ajt_df.csv")
)

In [4]:
# Full dataset; used below to look up the `type_mistake` column for each split.
df = pd.read_csv(filepath_or_buffer="./ajt_dataset.csv")

In [5]:
# Append the `type_mistake` column to each split by looking up the split's
# index labels in the full dataset frame.
train_df = pd.concat([train_df, df.loc[train_df.index, 'type_mistake']], axis=1)
test_df = pd.concat([test_df, df.loc[test_df.index, 'type_mistake']], axis=1)

In [6]:
# Replace the index of both splits with a fresh 0..n-1 range (old index is
# dropped), so the frames can be addressed positionally by the datasets.
train_df, test_df = (
    split_df.reset_index(drop=True) for split_df in (train_df, test_df)
)

In [7]:
# Error-category name -> integer id. -1 marks "no error"; ids 0..5 are the
# error classes.
label_mapper = {
    "нет ошибки": -1,
    "речевая": 0,
    "стилистическая": 1,
    "пунктуационная": 2,
    "грамматическая": 3,
    "лексическая": 4,
    "логическая": 5,
}
# Integer id -> error-category name (inverse of label_mapper).
reverse_label_mapper = {label_id: label_name for label_name, label_id in label_mapper.items()}

In [8]:
# Replace the textual error category with its integer id from label_mapper
# (values missing from the mapping become NaN, as with any Series.map on a dict).
for _split_df in (train_df, test_df):
    _split_df['type_mistake'] = _split_df['type_mistake'].map(label_mapper)

In [9]:
# # Uncomment when using ru-en-RoSBERTa: prepend the task prefix to every text
# prefix = "classification: "
# train_df['text'] = prefix + train_df['text'] 
# test_df['text'] = prefix + test_df['text'] 

In [10]:
class TextDataset(Dataset):
    """Dataset that tokenizes one text on access and pairs it with both labels.

    Args:
        texts: indexable collection of raw texts.
        binary_labels: indexable collection with one binary label per text.
        error_labels: indexable collection with one error-type label per text.
        tokenizer: callable invoked as
            ``tokenizer(text, max_length=..., padding='max_length',
            truncation=True, return_tensors='pt')`` that returns a mapping
            with ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: value forwarded to the tokenizer's ``max_length``.
    """

    def __init__(self, texts, binary_labels, error_labels, tokenizer, max_length):
        self.texts = texts
        self.binary_labels = binary_labels
        self.error_labels = error_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of samples, taken from the length of ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the flattened token tensors and float label tensors for sample ``idx``."""
        text = self.texts[idx]
        tokenized = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # The tokenizer output is flattened to 1-D so the DataLoader can stack
        # samples into a batch.
        input_ids = tokenized['input_ids'].flatten()
        attention_mask = tokenized['attention_mask'].flatten()

        # Both labels are stored as float tensors; consumers cast as needed.
        binary_label = torch.tensor(self.binary_labels[idx], dtype=torch.float)
        error_label = torch.tensor(self.error_labels[idx], dtype=torch.float)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'binary_label': binary_label,
            'error_label': error_label,
        }

In [11]:
def get_pooling(outputs, attention_masks, pooling_name="cls", hidden_states=1):
    """Pool per-token hidden states into one vector per sequence.

    Args:
        outputs: model output exposing ``hidden_states`` (a sequence of
            per-layer tensors); the last ``hidden_states`` entries are used.
        attention_masks: mask indexed as (batch, seq); non-zero marks tokens
            that take part in "mean" and "max" pooling.
        pooling_name: one of ``"mean"``, ``"cls"`` or ``"max"``.
        hidden_states: number of final layers to concatenate along the
            feature dimension (1 means only the last layer).

    Returns:
        Tensor with one pooled vector per sequence. For ``"cls"`` the vector
        of the first token is returned after ``torch.nn.functional.normalize``.

    Raises:
        ValueError: if ``pooling_name`` is not a supported strategy.
    """
    token_states = outputs.hidden_states[-1]
    if hidden_states > 1:
        # Concatenate the last `hidden_states` layers, oldest first.
        token_states = torch.cat(
            tuple([outputs.hidden_states[-i] for i in range(hidden_states, 0, -1)]),
            dim=-1,
        )

    input_mask_expanded = attention_masks.unsqueeze(-1).expand(token_states.size()).float()

    if pooling_name == "mean":
        # Masked average; the clamp guards against division by zero for a
        # fully masked row.
        summed = torch.sum(token_states * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    if pooling_name == "cls":
        cls_pooling = token_states[:, 0, :]
        return torch.nn.functional.normalize(cls_pooling)

    if pooling_name == "max":
        # Masked positions must never win the max. Multiplying by the mask
        # would turn them into 0, which beats any all-negative feature, so
        # they are filled with the most negative representable value instead.
        fill_value = torch.finfo(token_states.dtype).min
        masked_states = token_states.masked_fill(input_mask_expanded == 0, fill_value)
        max_pooling, _ = torch.max(masked_states, dim=1)
        return max_pooling

    raise ValueError(
        f"Unknown pooling_name {pooling_name!r}; expected 'mean', 'cls' or 'max'"
    )


In [12]:
# Model for multi-task learning
class ModelMTL(nn.Module):
    """Shared pretrained encoder with a binary head and an error-type head.

    Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``.
        num_error_types: output size of the error-type classifier.
        num_hidden_states: number of final hidden layers whose features are
            concatenated before the heads (scales the head input size).
        pooling_name: pooling strategy forwarded to ``get_pooling``.
        freeze: when true, every entry of ``self.model.encoder.layer`` except
            the last three gets ``requires_grad = False``.
    """

    def __init__(self, model_name, num_error_types, num_hidden_states=1, pooling_name="cls", freeze=False):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.pooling_name = pooling_name
        self.num_hidden_states = num_hidden_states
        self.freeze = freeze

        # Both heads read the same pooled feature vector.
        feature_dim = self.num_hidden_states * self.model.config.hidden_size
        self.binary_classifier = nn.Linear(feature_dim, 2)
        self.error_classifier = nn.Linear(feature_dim, num_error_types)
        self.dropout = nn.Dropout(0.1)

        # Freeze encoder
        if self.freeze:
            frozen_layers = self.model.encoder.layer[:-3]
            for encoder_layer in frozen_layers:
                for weight in encoder_layer.parameters():
                    weight.requires_grad = False

        # Parameter statistics cover the encoder (self.model) only.
        encoder_params = list(self.model.parameters())
        total_params = sum(p.numel() for p in encoder_params)
        trainable_params = sum(p.numel() for p in encoder_params if p.requires_grad)
        frozen_params = total_params - trainable_params

        print(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
        print(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return ``(binary_logits, error_logits)``."""
        encoder_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        features = get_pooling(
            encoder_outputs,
            attention_mask,
            pooling_name=self.pooling_name,
            hidden_states=self.num_hidden_states,
        )
        features = self.dropout(features)

        # Binary head first, then the multi-class error head, on the same features.
        binary_logits = self.binary_classifier(features)
        error_logits = self.error_classifier(features)
        return binary_logits, error_logits

In [19]:
# Weight of the error-type loss in the combined objective; the binary loss
# gets the complementary weight (1 - ALPHA) in mtl_loss.
ALPHA = 0.4

binary_target = "is_mistake"
multi_target = "type_mistake"

# Error-type labels of the training rows, excluding rows labelled -1.
error_type_values = train_df.loc[train_df[multi_target] != -1, multi_target].values

binary_classes = np.unique(train_df[binary_target].values)
multi_classes = np.unique(error_type_values)

# 'balanced' weights, ordered like the sorted class arrays above.
# NOTE(review): multi_class_weights has one entry per class *present* in the
# training data; confirm that matches the size of the error-type head.
binary_class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                    classes=binary_classes,
                                                    y=train_df[binary_target].values)
multi_class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                    classes=multi_classes,
                                                    y=error_type_values)

# class label -> weight. (The previous comprehension unpacked the zipped
# (class, weight) pairs in swapped order, producing weight -> class.)
binary_label_weights = dict(zip(binary_classes, binary_class_weights))
multi_label_weights = dict(zip(multi_classes, multi_class_weights))

def mtl_loss(binary_logits, error_logits, binary_labels, error_labels, weight_error=ALPHA, loss_type='default'):
    """Combine the binary and the error-type cross-entropy into one loss.

    Args:
        binary_logits: logits of the binary head.
        error_logits: logits of the error-type head.
        binary_labels: binary targets; cast to ``long`` for the loss.
        error_labels: error-type targets; only rows whose binary label is 1
            and whose error label is non-negative contribute.
        weight_error: weight of the error-type loss; the binary loss is
            weighted by ``1 - weight_error``.
        loss_type: ``'balanced'`` applies the module-level
            ``binary_class_weights`` / ``multi_class_weights``; any other
            value uses unweighted cross-entropy.

    Returns:
        Tuple ``(total_loss, loss_binary, loss_error)``. ``loss_error`` is a
        zero tensor when no row in the batch qualifies for the error loss.
    """
    # Keep the class-weight tensors on the same device as the logits instead
    # of relying on the notebook-level `device` global.
    target_device = binary_logits.device

    if loss_type == 'balanced':
        binary_weight = torch.from_numpy(binary_class_weights).float().to(target_device)
        multi_weight = torch.from_numpy(multi_class_weights).float().to(target_device)
    else:
        binary_weight = None
        multi_weight = None

    binary_weighted_loss = nn.CrossEntropyLoss(weight=binary_weight)
    multi_weighted_loss = nn.CrossEntropyLoss(weight=multi_weight)

    loss_binary = binary_weighted_loss(binary_logits, binary_labels.long())

    # Rows flagged as mistakes but carrying a negative (or NaN) error label are
    # not valid class indices and would make CrossEntropyLoss raise, so they
    # are left out of the error-type loss.
    mask = (binary_labels == 1) & (error_labels >= 0)
    if mask.any():
        loss_error = multi_weighted_loss(error_logits[mask], error_labels[mask].long())
    else:
        loss_error = torch.tensor(0.0, device=target_device)

    # total_loss = loss_binary + weight_error * loss_error
    total_loss = (1 - weight_error) * loss_binary + weight_error * loss_error

    return total_loss, loss_binary, loss_error

In [20]:
# Train
def train(model, dataloader, optimizer, device, writer, epoch, scheduler):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning ``(binary_logits, error_logits)``.
        dataloader: yields dict batches with ``'input_ids'``,
            ``'attention_mask'``, ``'binary_label'`` and ``'error_label'``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        writer: object with ``add_scalar`` used for per-step logging;
            logging is skipped when falsy.
        epoch: zero-based epoch number, used to compute the global step.
        scheduler: learning-rate scheduler stepped once per batch; skipped
            when falsy.

    Returns:
        Sum of the per-batch total losses divided by the number of batches.
    """
    model.train()
    num_batches = len(dataloader)
    running_loss = 0

    for step, batch in tqdm(enumerate(dataloader), total=num_batches):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        binary_labels = batch['binary_label'].to(device)
        error_labels = batch['error_label'].to(device)

        binary_logits, error_logits = model(input_ids, attention_mask)
        loss, loss_bin, loss_err = mtl_loss(
            binary_logits, error_logits,
            binary_labels, error_labels,
        )

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Log in tensorboard
        if writer:
            global_step = epoch * num_batches + step
            writer.add_scalar('Loss/train', batch_loss, global_step)
            writer.add_scalar('Loss/train_binary', loss_bin.item(), global_step)
            writer.add_scalar('Loss/train_error', loss_err.item(), global_step)
            if scheduler:
                writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], global_step)

    return running_loss / num_batches

# Validation
def evaluate(model, dataloader, device, writer, epoch):
    """Evaluate the model on a dataloader, log and print metrics.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning ``(binary_logits, error_logits)``.
        dataloader: yields dict batches with ``'input_ids'``,
            ``'attention_mask'``, ``'binary_label'`` and ``'error_label'``.
        device: device the batch tensors are moved to.
        writer: object with ``add_scalar`` used for per-epoch logging;
            logging is skipped when falsy.
        epoch: value used as the logging step.

    Returns:
        Mean total loss over the batches of ``dataloader``.
    """
    model.eval()
    total_loss = 0
    binary_preds, binary_labels = [], []
    error_preds, error_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            binary_labels_batch = batch['binary_label'].to(device)
            error_labels_batch = batch['error_label'].to(device)

            binary_logits, error_logits = model(input_ids, attention_mask)
            loss, loss_bin, loss_err = mtl_loss(
                binary_logits, error_logits,
                binary_labels_batch, error_labels_batch
            )

            total_loss += loss.item()

            # Save predictions and targets for the epoch-level metrics
            binary_preds.extend(torch.argmax(binary_logits, dim=1).cpu().numpy())
            binary_labels.extend(binary_labels_batch.cpu().numpy())
            error_preds.extend(torch.argmax(error_logits, dim=1).cpu().numpy())
            error_labels.extend(error_labels_batch.cpu().numpy())

    avg_loss = total_loss / len(dataloader)

    # Metrics for binary clf
    binary_accuracy = accuracy_score(binary_labels, binary_preds)
    mcc = matthews_corrcoef(binary_labels, binary_preds)
    binary_f1 = f1_score(binary_labels, binary_preds, average='binary')
    binary_precision = precision_score(binary_labels, binary_preds, average='binary')
    binary_recall = recall_score(binary_labels, binary_preds, average='binary')

    # Metrics for multi-class clf: samples labelled -1 (no error) are excluded
    error_mask = [label != -1 for label in error_labels]
    error_preds_filtered = [p for p, m in zip(error_preds, error_mask) if m]
    error_labels_filtered = [l for l, m in zip(error_labels, error_mask) if m]

    if error_labels_filtered:
        error_accuracy = accuracy_score(error_labels_filtered, error_preds_filtered)
        error_f1 = f1_score(error_labels_filtered, error_preds_filtered, average='macro')
        error_mcc = matthews_corrcoef(error_labels_filtered, error_preds_filtered)
        error_report = classification_report(error_labels_filtered, error_preds_filtered)
    else:
        # Every metric that is logged or printed below needs a fallback;
        # error_mcc was previously left undefined on this path.
        error_accuracy, error_f1, error_mcc = 0, 0, 0
        error_report = "No errors in validation set"

    # Log in tensorboard
    if writer:
        writer.add_scalar('Loss/val', avg_loss, epoch)
        writer.add_scalar('Accuracy/binary_val', binary_accuracy, epoch)
        writer.add_scalar('MCC/binary_val', mcc, epoch)
        writer.add_scalar('F1/binary_val', binary_f1, epoch)
        writer.add_scalar('Accuracy/error_val', error_accuracy, epoch)
        writer.add_scalar('MCC/error_val', error_mcc, epoch)
        writer.add_scalar('F1/error_val', error_f1, epoch)

    # Metrics output
    print(f"Validation Loss: {avg_loss:.3f}")
    print(f"Binary Accuracy: {binary_accuracy:.3f}, F1: {binary_f1:.3f}, MCC: {mcc:.3f}")
    print(f"Error Accuracy: {error_accuracy:.3f}, F1: {error_f1:.3f}, MCC: {error_mcc:.3f}")
    # print("Error Classification Report:")
    # print(error_report)

    return avg_loss

In [21]:
# Pretrained encoder checkpoint; alternatives are kept commented out.
# model_name = "DeepPavlov/rubert-base-cased"
# model_name = "ai-forever/ru-en-RoSBERTa"
model_name = "RussianNLP/ruRoBERTa-large-rucola"

# Model / tokenization
num_error_types = 6   # output size of the error-type head
max_length = 128      # max_length passed to the tokenizer

# Optimization
batch_size = 16
num_epochs = 20
lr = 2e-5

In [22]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training split: texts plus binary and error-type targets, batches shuffled.
train_texts = train_df['text'].to_numpy().tolist()
train_labels = train_df['is_mistake'].to_numpy().tolist()
train_errors = train_df['type_mistake'].to_numpy().tolist()
train_dataset = TextDataset(train_texts, train_labels, train_errors, tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Validation uses the test split and keeps the original sample order.
val_texts = test_df['text'].to_numpy().tolist()
val_labels = test_df['is_mistake'].to_numpy().tolist()
val_errors = test_df['type_mistake'].to_numpy().tolist()
val_dataset = TextDataset(val_texts, val_labels, val_errors, tokenizer, max_length)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# Encoder with CLS pooling; freeze=True disables gradients for all encoder
# layers except the last three (see ModelMTL).
model = ModelMTL(model_name, num_error_types, pooling_name='cls', freeze=True)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=lr)

Some weights of RobertaModel were not initialized from the model checkpoint at RussianNLP/ruRoBERTa-large-rucola and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 90.84M
Frozen parameters: 264.52M


In [23]:
# One scheduler step is taken per training batch, so the schedule length is
# batches-per-epoch times epochs; the first 10% of the steps are warm-up.
total_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=int(total_steps * 0.1),
)

In [None]:
# TensorBoard run directory, named after the start time of this run.
log_dir = os.path.join('./runs', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
# Last path component of the checkpoint id; [-1] also works for names that
# have no "org/" prefix.
short_model_name = model_name.split('/')[-1]
save_model_path = "./trained_models/" + short_model_name + f"_lr-{lr}"
writer = SummaryWriter(log_dir=log_dir)
print(f"Model name: {short_model_name}")

# Train-val cycle. The writer is closed in `finally` so that pending events
# are flushed even when training raises or is interrupted from the notebook.
try:
    for epoch in range(num_epochs):
        train_loss = train(model, train_dataloader, optimizer, device, writer, epoch, scheduler=scheduler)
        print(f"Epoch {epoch + 1}/{num_epochs} Train Loss: {train_loss:.3f}")

        # Val
        val_loss = evaluate(model, val_dataloader, device, writer, epoch)
        print(f"Epoch {epoch + 1}/{num_epochs} Validation Loss: {val_loss:.3f}")

        # Save model
        # torch.save(model.state_dict(), f"{save_model_path}_epoch-{epoch}.pt")
        # print(f"Model saved to {save_model_path}_epoch-{epoch}")
finally:
    writer.close()

In [None]:
# # rubert-base-cased best
# Binary Accuracy: 0.787, F1: 0.731, MCC: 0.600
# Error Accuracy: 0.372, F1: 0.206, MCC: 0.162
# Epoch 13/20 Validation Loss: 1.003

In [None]:
# # ru-en-RoSBERTa
# Binary Accuracy: 0.701, F1: 0.642, MCC: 0.405
# Error Accuracy: 0.298, F1: 0.169, MCC: 0.072
# Epoch 11/20 Validation Loss: 1.047

In [None]:
# # ruRoBERTa-large-rucola
# Binary Accuracy: 0.736, F1: 0.644, MCC: 0.512
# Error Accuracy: 0.394, F1: 0.239, MCC: 0.215
# Epoch 9/20 Validation Loss: 1.024

In [None]:
# # sbert_large_nlu_ru
# Binary Accuracy: 0.772, F1: 0.720, MCC: 0.558
# Error Accuracy: 0.383, F1: 0.171, MCC: 0.176
# Epoch 14/20 Validation Loss: 1.011