In [1]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Загружаем модель и токенизатор
model_name = "cointegrated/rubert-tiny-toxicity"
num_labels = 2  # Бинарная классификация: токсичный/не токсичный

from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Добавляем параметр ignore_mismatched_sizes=True
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, 
    num_labels=num_labels,
    ignore_mismatched_sizes=True  # Игнорируем несоответствие размеров
)

# # Замораживаем базовые слои
# for param in model.base_model.parameters():
#     param.requires_grad = False

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny-toxicity and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([5]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([5, 312]) in the checkpoint and torch.Size([2, 312]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
import pandas as pd
df = pd.read_csv('data/labeled.csv')
# Преобразование столбца float_column в int
df["toxic"] = df["toxic"].astype(int)

In [15]:
from sklearn.model_selection import train_test_split

# Разделяем данные на train и val
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["comment"].tolist(), df["toxic"].tolist(), test_size=0.2, random_state=42
)

# Токенизация
def tokenize_function(examples):
    return tokenizer(examples["comment"], padding="max_length", truncation=True)

# Создаем датасеты
train_dataset = Dataset.from_dict({"comment": train_texts, "toxic": train_labels})
val_dataset = Dataset.from_dict({"comment": val_texts, "toxic": val_labels})

# Применяем токенизацию
train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)

# Удаляем столбец text, оставляем только токены и метки
train_dataset = train_dataset.remove_columns(["comment"])
val_dataset = val_dataset.remove_columns(["comment"])

# Преобразуем метки в формат PyTorch
train_dataset.set_format("torch")
val_dataset.set_format("torch")

Map: 100%|██████████| 11529/11529 [00:00<00:00, 11679.95 examples/s]
Map: 100%|██████████| 2883/2883 [00:00<00:00, 10642.10 examples/s]


In [30]:
device = 'mps'

In [31]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm

class CustomModel(nn.Module):
    def __init__(self, base_model_name, num_labels=2):
        super(CustomModel, self).__init__()
        # Загружаем модель
        self.model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels, ignore_mismatched_sizes=True)
        
        # Замораживаем все слои, кроме последнего
        for param in self.model.base_model.parameters():
            param.requires_grad = False
        
        # Меняем последний слой (linear)
        self.model.classifier = nn.Linear(self.model.config.hidden_size, num_labels)
    
    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

class ToxicityTrainer:
    def __init__(self, model, train_dataset, val_dataset, batch_size=16, lr=2e-5, epochs=3):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = epochs
        
        # Даталоадеры для тренировочных и валидационных данных
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size)
        
        # Оптимизатор
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)
        
        # Функция потерь для бинарной классификации
        self.loss_fn = nn.BCEWithLogitsLoss()

    def train(self):
        for epoch in range(self.epochs):
            self.model.train()
            running_loss = 0.0
            
            # Обучаем модель
            for batch in tqdm(self.train_loader, desc=f"Training epoch {epoch+1}"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                labels = batch['labels'].to(device).float()  # Преобразуем метки в тип float для BCEWithLogitsLoss

                # Обнуляем градиенты
                self.optimizer.zero_grad()
                
                # Пропускаем данные через модель
                outputs = self.model(input_ids, attention_mask, token_type_ids)
                logits = outputs.logits
                
                # Рассчитываем потерю
                loss = self.loss_fn(logits.view(-1), labels.view(-1))  # Вычисляем loss
                loss.backward()
                self.optimizer.step()
                
                running_loss += loss.item()
            
            avg_train_loss = running_loss / len(self.train_loader)
            print(f"Epoch {epoch+1} - Loss: {avg_train_loss:.4f}")
            
            # Проводим валидацию
            self.validate()

    def validate(self):
        self.model.eval()
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validating"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                token_type_ids = batch['token_type_ids'].to(device)
                labels = batch['labels'].to(device)

                outputs = self.model(input_ids, attention_mask, token_type_ids)
                logits = outputs.logits
                
                preds = torch.sigmoid(logits).cpu().numpy()  # Для бинарной классификации применяем sigmoid
                all_preds.extend(preds)
                all_labels.extend(labels.cpu().numpy())
        
        # Применяем порог для бинарной классификации (0.5)
        all_preds = [1 if pred >= 0.5 else 0 for pred in all_preds]
        
        # Рассчитываем метрики
        accuracy = accuracy_score(all_labels, all_preds)
        print(f"Validation Accuracy: {accuracy:.4f}")

    def save_model(self, save_path):
        torch.save(self.model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

# Инициализация
model_name = "cointegrated/rubert-tiny-toxicity"  # Модель
num_labels = 2  # Бинарная классификация

# Загружаем токенизатор и датасеты
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Датасеты, которые вы подготовили для обучения (train_dataset, val_dataset)

# Создаем модель
model = CustomModel(base_model_name=model_name, num_labels=num_labels)

# Создаем класс тренера
trainer = ToxicityTrainer(model=model, train_dataset=train_dataset, val_dataset=val_dataset, epochs=5, batch_size=32)

# Запуск обучения
trainer.train()

# Сохраняем модель
trainer.save_model("toxicity_model.pth")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny-toxicity and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([5]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([5, 312]) in the checkpoint and torch.Size([2, 312]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training epoch 1:   0%|          | 0/361 [00:00<?, ?it/s]


KeyError: 'labels'