In [1]:
!pip install -U -q torchmetrics transformers wandb

In [1]:
import os
import pandas as pd
import torch
import json
from torch.utils.data import (TensorDataset,
                              Dataset,
                              DataLoader,
                              RandomSampler,
                              SequentialSampler)
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from transformers import BertTokenizerFast, BertForSequenceClassification, BertModel
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup, set_seed
import torchmetrics
from sklearn.metrics import classification_report
from torch.optim import AdamW
import matplotlib.pyplot as plt
import wandb

In [3]:
# Record the installed PyTorch version next to the results.
torch.__version__

'2.0.1'

In [4]:
# Read the version through the module object so that no bare `__version__`
# name is bound in the notebook's global namespace.
import transformers
transformers.__version__

'4.35.2'

In [5]:
# Show the GPU visible to this kernel (model, driver version, memory in use).
!nvidia-smi

Sat Nov 25 07:16:48 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A10G                    On  | 00000000:00:1E.0 Off |                    0 |
|  0%   24C    P8              16W / 300W |      2MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Load the cleaned incident dataset; the path is relative to the notebook's directory.
df = pd.read_csv('../cleaned.csv')

In [3]:
# Preview the first rows to check the text column and the three label columns.
df.head()

Unnamed: 0,Исполнитель,Группа тем,Текст инцидента,Тема
0,Лысьвенский городской округ,Благоустройство,"Сегодня, 20.08.22, моя мать шла по улице Ленин...",★ Ямы во дворах
1,Министерство социального развития ПК,Социальное обслуживание и защита,"Пермь г, +79194692145. В Перми с ноября 2021 г...",Оказание гос. соц. помощи
2,Министерство социального развития ПК,Социальное обслуживание и защита,Скажите пожалуйста если подовала на пособие с ...,Дети и многодетные семьи
3,Город Пермь,Общественный транспорт,Каждая из них не о чем. Люди на остановках хот...,Содержание остановок
4,Министерство здравоохранения,Здравоохранение/Медицина,В Березниках у сына привитого откоронавируса з...,Технические проблемы с записью на прием к врачу


In [4]:
# Distinct class names of each target column, in order of first appearance.
# The position of a name in its list becomes its class index downstream.
executor_label = pd.unique(df['Исполнитель']).tolist()
theme_group_label = pd.unique(df['Группа тем']).tolist()
theme_label = pd.unique(df['Тема']).tolist()

In [5]:
# Class name -> integer index lookups (index = position in the label list).
executor2idx = dict(zip(executor_label, range(len(executor_label))))
theme_group2idx = dict(zip(theme_group_label, range(len(theme_group_label))))
theme2idx = dict(zip(theme_label, range(len(theme_label))))

In [6]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)
if device.type == 'cuda':
    # Report which GPU (index 0) will be used.
    print(torch.cuda.get_device_name(0))

cuda
NVIDIA A10G


In [7]:
# --- Experiment configuration -------------------------------------------------
# Hugging Face checkpoint id used to load the tokenizer.
MODEL_NAME = 'ai-forever/ruBert-large'
# Seed for `set_seed` below and for the train/test split.
SEED = 42
# Number of full passes over the training dataloader.
EPOCHS = 16
# Samples per batch for both the train and test dataloaders.
BATCH_SIZE = 8
# Learning rate handed to AdamW.
LEARNING_RATE = 3e-5
# Every text is padded/truncated to exactly this many tokens.
MAX_LEN = 390
# Dropout probability intended for the classification heads.
DROPOUT = .1
# NOTE: despite the name this is a *fraction* of the total number of training
# steps (it is multiplied by len(train_dataloader) * EPOCHS when the scheduler
# is built), not an absolute step count.
WARMUP_STEPS = 0.1

# Seed the random number generators before any model/data code runs.
set_seed(seed=SEED)

In [8]:
def get_tensors(texts, executor_labels, theme_group_labels, theme_labels):
    """Tokenize incident texts and encode their three label columns as tensors.

    Relies on the notebook-level ``tokenizer``, ``MAX_LEN`` and the
    ``executor2idx`` / ``theme_group2idx`` / ``theme2idx`` mappings.

    Args:
        texts: list of raw incident texts.
        executor_labels: executor class name for every text.
        theme_group_labels: theme-group class name for every text.
        theme_labels: theme class name for every text.

    Returns:
        Tuple ``(input_ids, attention_mask, executor_ids, theme_group_ids,
        theme_ids)``. The first two come from the tokenizer with every row
        padded/truncated to ``MAX_LEN`` tokens; the last three are 1-D
        ``torch.long`` tensors of class indices.

    Raises:
        KeyError: if a label is missing from its mapping.
        ValueError: if the inputs do not all describe the same number of rows.
    """
    encoded = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_token_type_ids=False,
        return_tensors='pt')

    # Map class names to integer indices and pack them as long tensors,
    # the dtype expected by CrossEntropyLoss targets.
    executor_ids = torch.tensor([executor2idx[name] for name in executor_labels], dtype=torch.long)
    theme_group_ids = torch.tensor([theme_group2idx[name] for name in theme_group_labels], dtype=torch.long)
    theme_ids = torch.tensor([theme2idx[name] for name in theme_labels], dtype=torch.long)

    # Sanity check: compare against the number of tokenized rows, not against
    # the length of the tokenizer's result object.
    n_rows = encoded['input_ids'].shape[0]
    for column, ids in (('executor_labels', executor_ids),
                        ('theme_group_labels', theme_group_ids),
                        ('theme_labels', theme_ids)):
        if ids.shape[0] != n_rows:
            raise ValueError(
                f'{column} has {ids.shape[0]} entries but {n_rows} texts were tokenized'
            )

    return encoded['input_ids'], encoded['attention_mask'], executor_ids, theme_group_ids, theme_ids

In [9]:
class JointClassifier(nn.Module):
    """BERT encoder with three classification heads trained jointly.

    The heads predict the executor, the theme group and the theme. The theme
    head additionally receives the theme-group logits, concatenated to the
    pooled encoder output.

    The number of classes per head is taken from the notebook-level
    ``executor_label``, ``theme_group_label`` and ``theme_label`` lists.

    Args:
        model_name: Hugging Face checkpoint to load the encoder from.
        dropout: dropout probability used inside every head.
    """

    def __init__(self, model_name: str = 'ai-forever/ruBert-large', dropout: float = .1):
        super().__init__()
        self.language_model = BertModel.from_pretrained(model_name)
        # Take the feature width from the loaded encoder instead of hard-coding
        # 1024, so checkpoints with a different hidden size work as well.
        hidden_size = self.language_model.config.hidden_size
        n_theme_groups = len(theme_group_label)

        self.executor_cls = self._make_head(hidden_size, len(executor_label), dropout)
        self.theme_group_cls = self._make_head(hidden_size, n_theme_groups, dropout)
        # The theme head consumes [pooled output ; theme-group logits].
        self.theme_cls = self._make_head(hidden_size + n_theme_groups, len(theme_label), dropout)

    @staticmethod
    def _make_head(in_features: int, n_classes: int, dropout: float) -> nn.Sequential:
        """Build one head: BatchNorm -> Dropout -> SiLU -> Linear."""
        return nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(p=dropout),
            nn.SiLU(),
            nn.Linear(in_features, n_classes),
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(executor_logits, theme_group_logits, theme_logits)``."""
        output = self.language_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        executor_logits = self.executor_cls(output)
        theme_group_logits = self.theme_group_cls(output)
        # Condition the theme prediction on the predicted theme group.
        theme_inputs = torch.cat(
            (output, theme_group_logits), 1
        )
        theme_logits = self.theme_cls(theme_inputs)
        return executor_logits, theme_group_logits, theme_logits

In [10]:
def train():
    """Run one training epoch over ``train_dataloader`` and log to wandb.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``device`` and (when ``WARMUP_STEPS > 0``) ``scheduler``.

    Per batch it logs the joint loss and the batch weighted-F1 of each head;
    at the end of the epoch it logs the mean batch loss and the weighted-F1
    accumulated over the whole epoch. All values belonging to one step are
    sent in a single ``wandb.log`` call so they share the same wandb step.
    """
    model.train()
    total_loss = 0
    # One running weighted-F1 metric per head; calling a metric on a batch
    # returns the batch value and also accumulates state for `compute()`.
    f1_metrics = {
        'executor': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(executor_label), average='weighted'),
        'theme_group': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(theme_group_label), average='weighted'),
        'theme': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(theme_label), average='weighted'),
    }
    for input_ids, attention_mask, executor_labels, theme_group_labels, theme_labels in train_dataloader:
        optimizer.zero_grad()

        executor_logits, theme_group_logits, theme_logits = model(input_ids.to(device), attention_mask.to(device))
        executor_loss = criterion(executor_logits, executor_labels.to(device))
        theme_group_loss = criterion(theme_group_logits, theme_group_labels.to(device))
        theme_loss = criterion(theme_logits, theme_labels.to(device))
        # Joint objective: unweighted mean of the three head losses.
        loss = (executor_loss + theme_group_loss + theme_loss) / 3
        step_loss = loss.item()
        total_loss += step_loss

        step_log = {'step_loss': step_loss}
        for head, logits, labels in (('executor', executor_logits, executor_labels),
                                     ('theme_group', theme_group_logits, theme_group_labels),
                                     ('theme', theme_logits, theme_labels)):
            # argmax over the logits already gives the predicted class;
            # softmax is monotonic and would not change it.
            preds = torch.argmax(logits.detach().cpu(), dim=-1)
            step_log[f'{head}_F1'] = f1_metrics[head](preds, labels).item()
        wandb.log(step_log)

        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The scheduler only exists when warmup is enabled.
        if WARMUP_STEPS > 0:
            scheduler.step()

    total_loss /= len(train_dataloader)
    epoch_log = {'epoch_train_loss': total_loss}
    for head, metric in f1_metrics.items():
        epoch_log[f'epoch_train_{head}_F1'] = metric.compute().item()
    wandb.log(epoch_log)

In [11]:
@torch.no_grad()
def evaluate():
    """Evaluate the model on ``test_dataloader`` and log the results to wandb.

    Uses the notebook-level ``model``, ``criterion`` and ``device``. Logs the
    mean batch loss and, per head, the weighted-F1 accumulated over the whole
    test set. Everything is sent in a single ``wandb.log`` call so the epoch's
    test metrics share one wandb step.
    """
    model.eval()
    total_loss = 0
    # One running weighted-F1 metric per head, accumulated across batches.
    f1_metrics = {
        'executor': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(executor_label), average='weighted'),
        'theme_group': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(theme_group_label), average='weighted'),
        'theme': torchmetrics.classification.MulticlassF1Score(
            num_classes=len(theme_label), average='weighted'),
    }
    for input_ids, attention_mask, executor_labels, theme_group_labels, theme_labels in test_dataloader:
        executor_logits, theme_group_logits, theme_logits = model(input_ids.to(device), attention_mask.to(device))
        executor_loss = criterion(executor_logits, executor_labels.to(device))
        theme_group_loss = criterion(theme_group_logits, theme_group_labels.to(device))
        theme_loss = criterion(theme_logits, theme_labels.to(device))
        # Same joint objective as in training: mean of the three head losses.
        loss = (executor_loss + theme_group_loss + theme_loss) / 3
        total_loss += loss.item()

        for head, logits, labels in (('executor', executor_logits, executor_labels),
                                     ('theme_group', theme_group_logits, theme_group_labels),
                                     ('theme', theme_logits, theme_labels)):
            # argmax over the logits is the predicted class; no softmax needed.
            preds = torch.argmax(logits.cpu(), dim=-1)
            f1_metrics[head](preds, labels)

    total_loss /= len(test_dataloader)
    epoch_log = {'epoch_test_loss': total_loss}
    for head, metric in f1_metrics.items():
        epoch_log[f'epoch_test_{head}_F1'] = metric.compute().item()
    wandb.log(epoch_log)

In [12]:
# Load the tokenizer for the checkpoint named in MODEL_NAME; get_tensors() reads
# this module-level `tokenizer`.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

In [13]:
# Hold out 10% of the rows for evaluation, stratified on the theme column so
# the class proportions match in both parts; seeded for repeatability.
df_train, df_test = train_test_split(
    df,
    test_size=.1,
    stratify=df['Тема'],
    random_state=SEED,
)

In [14]:
def _make_loader(frame, sampler_cls, drop_single_tail=False):
    """Tensorize one data split and wrap it in a DataLoader.

    Args:
        frame: DataFrame with the text column and the three label columns.
        sampler_cls: sampler class instantiated on the dataset
            (``RandomSampler`` for training, ``SequentialSampler`` for eval).
        drop_single_tail: when True, drop the final batch if (and only if) it
            would contain exactly one sample.

    Returns:
        ``(dataset, dataloader)``.
    """
    # Argument order of get_tensors: texts, executor_labels, theme_group_labels, theme_labels
    dataset = TensorDataset(*get_tensors(
        frame['Текст инцидента'].to_list(),
        frame['Исполнитель'].to_list(),
        frame['Группа тем'].to_list(),
        frame['Тема'].to_list(),
    ))
    drop_last = drop_single_tail and len(dataset) % BATCH_SIZE == 1
    loader = DataLoader(
        dataset,
        sampler=sampler_cls(dataset),
        batch_size=BATCH_SIZE,
        drop_last=drop_last,
    )
    return dataset, loader


# The classifier heads use BatchNorm1d, which raises in training mode when a
# batch holds a single sample, so a one-sample tail batch is dropped for the
# training loader only. Every other batch layout is left untouched.
train_data, train_dataloader = _make_loader(df_train, RandomSampler, drop_single_tail=True)

test_data, test_dataloader = _make_loader(df_test, SequentialSampler)

In [15]:
# Build the model from the configured hyperparameters so the encoder always
# matches the tokenizer checkpoint (both are driven by MODEL_NAME).
model = JointClassifier(model_name=MODEL_NAME, dropout=DROPOUT)
model.to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
if WARMUP_STEPS > 0:
    # One scheduler step is taken per optimizer step in train().
    total_steps = len(train_dataloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        # WARMUP_STEPS is a fraction of the run; the scheduler expects an
        # integer number of steps.
        num_warmup_steps=int(total_steps * WARMUP_STEPS),
        num_training_steps=total_steps,
    )
# Shared loss for all three heads.
criterion = torch.nn.CrossEntropyLoss()

In [16]:
# Authenticate with Weights & Biases before starting a run.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mblanchefort[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [17]:
# Start the tracking run and store the hyperparameters with it, so every logged
# curve can be traced back to the configuration that produced it.
wandb.init(
    project='loon-bit-loop-text-classifier',
    config={
        'model_name': MODEL_NAME,
        'seed': SEED,
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'max_len': MAX_LEN,
        'dropout': DROPOUT,
        'warmup_fraction': WARMUP_STEPS,
    },
)

In [None]:
os.makedirs('../models', exist_ok=True)
for epoch in range(EPOCHS):
    train()
    evaluate()
    # Checkpoint after every epoch. Copy the weights to CPU tensor by tensor
    # instead of moving the whole model off the GPU and back: the saved file
    # still holds CPU tensors, and the model never leaves `device`, even if
    # saving fails part-way.
    checkpoint = model.state_dict()
    for key, tensor in list(checkpoint.items()):
        checkpoint[key] = tensor.cpu()
    path = f'../models/{epoch}.pt'
    torch.save(checkpoint, path)

In [None]:
# NOTE(review): leftover scratch cell with no effect on the analysis; safe to delete.
1