In [1]:
import json
import os
import pickle
import random

import datasets
import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import Dataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import train_test_split
from torch.nn.functional import cross_entropy, sigmoid, softmax
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

**Конфигурация**

In [2]:
# Presumably the maximum sentence length in tokens; it is not referenced by
# any of the cells visible in this notebook.
max_snt_len = 256

# Candidate checkpoints. `model_id` selects the one both task heads are built on.
models = [
    'michellejieli/emotion_text_classifier',
    'microsoft/deberta-v3-base',
    'sileod/deberta-v3-base-tasksource-nli',
    'microsoft/deberta-v3-xsmall',
]
model_id = 2
model_name = models[model_id]

**Фиксируем начальные значения (seed) генераторов случайных чисел для воспроизводимости результатов**

In [3]:
def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device) with ``seed`` and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: integer seed applied to all generators (default 42).
    """
    # cuDNN: always pick the same algorithms instead of auto-tuning per run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # The generators are independent, so the seeding order does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()

**Читаем заранее подготовленные и предобработанные датасеты Emocause и EmpatheticDialogues**

In [4]:
# Each pickle stores a (train, val, test) triple prepared beforehand.
# NOTE: pickle.load can execute arbitrary code from the file, so only load
# pickles that you produced yourself.
with open('./data/emocause/data.pickle', 'rb') as f:
    emocause_splits = pickle.load(f)
emocause_train, emocause_val, emocause_test = emocause_splits

with open('./data/empatheticdialogues/data.pickle', 'rb') as f:
    empdia_splits = pickle.load(f)
empdia_train, empdia_val, empdia_test = empdia_splits

**Создаём даталоадеры (DataLoader)**

In [5]:
torch.cuda.empty_cache()
batch_size = 16
num_workers = 2


def _make_dataloader(dataset, is_train):
    """Wrap ``dataset`` in a DataLoader.

    Training loaders shuffle and drop the last incomplete batch; evaluation
    loaders keep the original order and every sample.
    """
    return DataLoader(dataset=dataset, shuffle=is_train, batch_size=batch_size,
                      drop_last=is_train, num_workers=num_workers)


# Emocause splits (token classification).
emocause_train_dataloader = _make_dataloader(emocause_train, is_train=True)
emocause_val_dataloader = _make_dataloader(emocause_val, is_train=False)
emocause_test_dataloader = _make_dataloader(emocause_test, is_train=False)

# EmpatheticDialogues splits (sequence classification).
empdia_train_dataloader = _make_dataloader(empdia_train, is_train=True)
empdia_val_dataloader = _make_dataloader(empdia_val, is_train=False)
empdia_test_dataloader = _make_dataloader(empdia_test, is_train=False)

**Определяем модели**

In [6]:
# Both task models start from the same checkpoint. ignore_mismatched_sizes=True
# lets the checkpoint's own classifier be swapped for a freshly initialised head
# with the requested number of labels.

# Token-level head with 2 labels per token.
model_emo = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Sequence-level head with 32 classes.
model_classification = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=32,
    ignore_mismatched_sizes=True,
)

Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at sileod/deberta-v3-base-tasksource-nli and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at sileod/deberta-v3-base-tasksource-nli and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([32, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([32]) in the model instantiated
You should probably TRAI

Получаем интерфейс двух моделей, но основная часть (энкодер) у них общая, а различаются только головы. По сути, это одна модель с двумя головами.

In [7]:
# Share a single encoder between the two task models so that only the heads
# differ. `base_model_prefix` is the name of the backbone attribute on a
# Hugging Face model ("deberta" for the DeBERTa checkpoints), so looking it up
# keeps this cell working for the other entries of `models` instead of failing
# on a missing `.deberta` attribute.
backbone_attr = model_emo.base_model_prefix
assert backbone_attr == model_classification.base_model_prefix, \
    "both task models must be built on the same backbone architecture"
setattr(model_classification, backbone_attr, getattr(model_emo, backbone_attr))

model_emo.cuda()
model_classification.cuda()

DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine

**Процесс обучения и тестирования классификации токенов (Emocause)**

In [8]:
# An earlier experiment replaced the model's built-in loss with a cross-entropy
# weighted [1., 4.] over the flattened logits; the built-in loss is used now.
def emocause_train_loop(dataloader, model, optimizer):
    """Run one training epoch of the token-classification model.

    Args:
        dataloader: iterable of batches, each a dict of tensors. Every entry
            except ``'count'`` is passed to ``model`` as a keyword argument;
            the model output must expose ``.loss``.
        model: the model to train; batches are moved to the device its
            parameters live on.
        optimizer: optimizer stepped once per batch.
    """
    model.train()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        # 'count' is not a model input, so it is left out of the forward call.
        # Filtering (rather than deleting) leaves the caller's batch untouched
        # and tolerates batches without that key.
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'count'}
        loss = model(**inputs).loss
        loss.backward()
        optimizer.step()
        

def emocause_test_loop(dataloader, model):
    """Evaluate the token-classification model and print top-1/3/5 recall.

    For every instance the tokens are ranked by the predicted probability of
    class 1. Tokens labelled -100 are skipped, the labels of the first 1, 3
    and 5 remaining tokens are summed, and each sum is divided by the
    instance's ``count``. The printed figures are the means over all instances.

    Args:
        dataloader: iterable of batches (dicts of tensors) containing
            ``'labels'`` and ``'count'`` plus the model inputs. ``'count'`` is
            removed from each batch dict before the forward pass.
            NOTE(review): ``count`` is presumably the number of positive tokens
            in the instance; a zero count would yield nan/inf. Confirm against
            the preprocessing code.
        model: model whose output exposes ``.logits`` with the class dimension
            last; batches are moved to the device its parameters live on.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    top1, top3, top5 = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            counts = batch.pop('count')
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            proba = softmax(logits, dim=-1)[:, :, 1]
            # Move ranking and labels to the host once per batch; indexing
            # device tensors token by token forces a sync on every access.
            ranking = proba.argsort(dim=1, descending=True).cpu()
            labels = batch['labels'].cpu()
            for i, order in enumerate(ranking):
                ranked_labels = labels[i][order]
                # Labels of the five highest-ranked tokens that are not -100.
                best5 = ranked_labels[ranked_labels != -100][:5]
                count = counts[i]
                top1.append(best5[:1].sum() / count)
                top3.append(best5[:3].sum() / count)
                top5.append(best5.sum() / count)
    print("---TOKEN LABELING TEST METRICS---")
    print(f"Top-1 Recall: {sum(top1) / len(top1)}")
    print(f"Top-3 Recall: {sum(top3) / len(top3)}")
    print(f"Top-5 Recall: {sum(top5) / len(top5)}")

**Процесс обучения и тестирования классификации последовательностей (EmpatheticDialogues)**

In [9]:
def classification_train_loop(dataloader, model, optimizer):
    """Run one training epoch of the sequence classifier and print its accuracy.

    Args:
        dataloader: iterable of batches (dicts of tensors) that include
            ``'labels'``; every entry is passed to ``model`` as a keyword
            argument.
        model: model whose output exposes ``.loss`` and ``.logits``; batches
            are moved to the device its parameters live on.
        optimizer: optimizer stepped once per batch.
    """
    model.train()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    acc = evaluate.load("accuracy")
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        # Accuracy is taken from the logits computed before this step's update.
        predictions = torch.argmax(outputs.logits, dim=-1)
        acc.add_batch(predictions=predictions, references=batch["labels"])
    print("---CLASSIFICATION TRAIN METRICS---")
    print(acc.compute())
        

def classification_test_loop(dataloader, model):
    """Evaluate the sequence classifier and print its accuracy.

    Args:
        dataloader: iterable of batches (dicts of tensors) that include
            ``'labels'``; every entry is passed to ``model`` as a keyword
            argument.
        model: model whose output exposes ``.logits``; batches are moved to
            the device its parameters live on.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    acc = evaluate.load("accuracy")
    # A single no_grad context covers the whole evaluation pass.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            predictions = torch.argmax(outputs.logits, dim=-1)
            acc.add_batch(predictions=predictions, references=batch["labels"])
    print("---CLASSIFICATION TEST METRICS---")
    print(acc.compute())

In [10]:
# One Adam optimizer serves both tasks. The first group holds every parameter
# of the sequence classifier, which includes the encoder it shares with
# model_emo; the second group adds the token-classification head.
optimizer = Adam(
    [
        {'params': model_classification.parameters()},
        {'params': model_emo.classifier.parameters()},
    ],
    lr=2e-5,
)

**Обучение и валидация**

In [11]:
epochs = 8
# Create the checkpoint directory up front; otherwise the first torch.save,
# reached only after two epochs of training, fails if the folder is missing.
os.makedirs("./tmp_models", exist_ok=True)
for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}")
    # The token-classification task is trained every epoch.
    emocause_train_loop(emocause_train_dataloader, model_emo, optimizer)
    # NOTE(review): the sequence classifier is trained only on even epochs,
    # i.e. half as often as the token task. Confirm this is intended.
    if epoch % 2 == 0:
        emocause_test_loop(emocause_val_dataloader, model_emo)
        classification_train_loop(empdia_train_dataloader, model_classification, optimizer)
        classification_test_loop(empdia_val_dataloader, model_classification)
        # Checkpoint both models after each evaluated epoch.
        torch.save(model_emo.state_dict(), f"./tmp_models/model_emo{epoch}.bin")
        torch.save(model_classification.state_dict(), f"./tmp_models/model_class{epoch}.bin")


Epoch 1


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:59<00:00,  1.57it/s]



Epoch 2


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:58<00:00,  1.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:25<00:00,  1.87it/s]


---TOKEN LABELING TEST METRICS---
Top-1 Recall: tensor([0.3898])
Top-3 Recall: tensor([0.7707])
Top-5 Recall: tensor([0.9053])


100%|██████████████████████████████████████████████████████████████████████████████| 1112/1112 [11:18<00:00,  1.64it/s]


---CLASSIFICATION TRAIN METRICS---
{'accuracy': 0.43884892086330934}


100%|████████████████████████████████████████████████████████████████████████████████| 173/173 [00:36<00:00,  4.74it/s]


---CLASSIFICATION TEST METRICS---
{'accuracy': 0.5697211155378487}

Epoch 3


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:58<00:00,  1.58it/s]



Epoch 4


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:59<00:00,  1.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


---TOKEN LABELING TEST METRICS---
Top-1 Recall: tensor([0.3903])
Top-3 Recall: tensor([0.7695])
Top-5 Recall: tensor([0.9067])


100%|██████████████████████████████████████████████████████████████████████████████| 1112/1112 [11:26<00:00,  1.62it/s]


---CLASSIFICATION TRAIN METRICS---
{'accuracy': 0.608363309352518}


100%|████████████████████████████████████████████████████████████████████████████████| 173/173 [00:36<00:00,  4.77it/s]


---CLASSIFICATION TEST METRICS---
{'accuracy': 0.5874683085838465}

Epoch 5


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:59<00:00,  1.57it/s]



Epoch 6


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [02:01<00:00,  1.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:23<00:00,  2.01it/s]


---TOKEN LABELING TEST METRICS---
Top-1 Recall: tensor([0.3884])
Top-3 Recall: tensor([0.7648])
Top-5 Recall: tensor([0.9057])


100%|██████████████████████████████████████████████████████████████████████████████| 1112/1112 [11:28<00:00,  1.62it/s]


---CLASSIFICATION TRAIN METRICS---
{'accuracy': 0.6755845323741008}


100%|████████████████████████████████████████████████████████████████████████████████| 173/173 [00:37<00:00,  4.64it/s]


---CLASSIFICATION TEST METRICS---
{'accuracy': 0.6091995653748642}

Epoch 7


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:58<00:00,  1.59it/s]



Epoch 8


100%|████████████████████████████████████████████████████████████████████████████████| 188/188 [01:58<00:00,  1.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:27<00:00,  1.76it/s]


---TOKEN LABELING TEST METRICS---
Top-1 Recall: tensor([0.3864])
Top-3 Recall: tensor([0.7630])
Top-5 Recall: tensor([0.9097])


100%|██████████████████████████████████████████████████████████████████████████████| 1112/1112 [11:30<00:00,  1.61it/s]


---CLASSIFICATION TRAIN METRICS---
{'accuracy': 0.7302158273381295}


100%|████████████████████████████████████████████████████████████████████████████████| 173/173 [00:37<00:00,  4.60it/s]


---CLASSIFICATION TEST METRICS---
{'accuracy': 0.5983339369793553}


**Проверим метрики уже на тестовых датасетах**

In [13]:
# Final evaluation of both heads on the held-out test splits.
emocause_test_loop(dataloader=emocause_test_dataloader, model=model_emo)
classification_test_loop(dataloader=empdia_test_dataloader, model=model_classification)

100%|██████████████████████████████████████████████████████████████████████████████████| 53/53 [01:17<00:00,  1.45s/it]


---TOKEN LABELING TEST METRICS---
Top-1 Recall: tensor([0.2496])
Top-3 Recall: tensor([0.6230])
Top-5 Recall: tensor([0.8061])


100%|████████████████████████████████████████████████████████████████████████████████| 159/159 [00:33<00:00,  4.73it/s]

---CLASSIFICATION TEST METRICS---
{'accuracy': 0.5895316804407713}



