EDA

In [None]:
# Download the NEREL train/dev/test splits (JSON Lines) from the Hugging Face Hub
# into the working directory. Only train.jsonl is read by the rest of this notebook.
!wget -O train.jsonl "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/data/train.jsonl"
!wget -O dev.jsonl   "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/data/dev.jsonl"
!wget -O test.jsonl  "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/data/test.jsonl"

# Entity-type and relation-type files published next to the data
# (presumably the label inventories; they are not read anywhere below).
!wget -O ent_types.jsonl "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/ent_types.jsonl"
!wget -O rel_types.jsonl "https://huggingface.co/datasets/iluvvatar/NEREL/resolve/main/rel_types.jsonl"

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Data loading.
# Only the first MAX_RECORDS documents of the training split are used in this
# notebook; raise the limit to work with more of the corpus.
MAX_RECORDS = 200
train_path = "train.jsonl"
records = []
with open(train_path, "r", encoding="utf-8") as f:
    for line in f:
        if len(records) >= MAX_RECORDS:
            break
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # letting json.loads fail on an empty string.
            continue
        records.append(json.loads(line))

print(f"Загружено {len(records)} записей")
if records:
    # Show which fields a document carries (guarded so an empty file does not crash).
    print(records[0].keys())

In [None]:
from collections import Counter

# Corpus statistics gathered in a single pass over the loaded documents:
# entity-type and relation-type frequencies, document length in whitespace
# words, and the number of entity annotations per document.
entity_counter = Counter()
relation_counter = Counter()
text_lengths = []
entities_per_doc = []

for doc in records:
    doc_text = doc["text"]
    doc_entities = doc.get("entities", [])
    doc_relations = doc.get("relations", [])

    text_lengths.append(len(doc_text.split()))
    entities_per_doc.append(len(doc_entities))

    # Every annotation is a tab-separated line; its type is the first
    # whitespace-separated word of the second field.
    entity_counter.update(ann.split("\t")[1].split()[0] for ann in doc_entities)
    relation_counter.update(ann.split("\t")[1].split()[0] for ann in doc_relations)

print("Топ 20 типов сущностей:", entity_counter.most_common(20))
print("Топ 20 типов отношений:", relation_counter.most_common(20))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid", font="DejaVu Sans")

# Top-15 entity types as a horizontal bar chart.
top_entities = entity_counter.most_common(15)
df_entities = pd.DataFrame(top_entities, columns=["Entity", "Count"])

plt.figure(figsize=(10, 6))
# Passing `palette` without `hue` is deprecated in seaborn >= 0.13, so the
# category column is also given as `hue`. `dodge=False` keeps the bars at full
# width, and the redundant legend (drawn by older seaborn) is removed if present.
ax = sns.barplot(
    data=df_entities, x="Count", y="Entity",
    hue="Entity", palette="viridis", dodge=False,
)
legend = ax.get_legend()
if legend is not None:
    legend.remove()
plt.title("Топ 15 типов сущностей")
plt.xlabel("Частота")
plt.ylabel("Тип сущности")
plt.show()


def _plot_document_histogram(values, xlabel, title, bins=30):
    """Draw a histogram of a per-document quantity (y axis = number of documents)."""
    plt.figure(figsize=(10, 6))
    sns.histplot(values, bins=bins, kde=False)
    plt.xlabel(xlabel)
    plt.ylabel("Количество документов")
    plt.title(title)
    plt.show()


# Distribution of document lengths (in whitespace-separated words).
_plot_document_histogram(
    text_lengths, "Длина текста (в словах)", "Распределение длины текстов"
)

# Distribution of the number of entity annotations per document.
_plot_document_histogram(
    entities_per_doc, "Число сущностей на документ", "Распределение числа сущностей"
)


In [None]:
from IPython.display import Markdown

Markdown("""
### Выводы
1. В корпусе встречаются очень частые типы сущностей (PERSON, PROFESSION etc.), но и редкие типы (AWARD, IDEOLOGY), может возникнуть дисбаланс классов.
2. Длина документов также распределена неравномерно: есть и короткие тексты, и длинные, поэтому при моделировании важно применить padding и truncation.
""")

In [None]:

import re
from collections import Counter

# Функции парсинга строкового формата NEREL
def parse_entity_line(line: str):
    """Parse one tab-separated entity annotation line.

    The line must have at least three tab-separated fields: an id, a
    "<TYPE> <start> <end>" descriptor and the surface text. Only the leading
    "<TYPE> <start> <end>" part of the descriptor is read; anything after the
    second number is ignored.

    Returns a dict with keys 'id', 'type', 'start', 'end', 'text' (offsets as
    ints), or None when the line has fewer than three fields or the descriptor
    does not match.
    """
    fields = line.split('\t')
    if len(fields) >= 3:
        matched = re.match(r'(\S+)\s+(\d+)\s+(\d+)', fields[1].strip())
        if matched:
            label, begin, finish = matched.groups()
            return {
                'id': fields[0].strip(),
                'type': label,
                'start': int(begin),
                'end': int(finish),
                'text': fields[2].strip(),
            }
    return None

def parse_relation_line(line: str):
    """Parse one tab-separated relation annotation line.

    The second tab-separated field must start with
    "<TYPE> Arg1:<id> Arg2:<id>".

    Returns a dict with keys 'id', 'type', 'arg1', 'arg2', or None when the
    line has fewer than two fields or the body does not match.
    """
    fields = line.split('\t')
    if len(fields) < 2:
        return None

    found = re.match(r'(\S+)\s+Arg1:(\S+)\s+Arg2:(\S+)', fields[1].strip())
    if found is None:
        return None

    kind, first_arg, second_arg = found.groups()
    return {'id': fields[0].strip(), 'type': kind, 'arg1': first_arg, 'arg2': second_arg}



In [None]:
def whitespace_tokenize_with_offsets(text: str):
    """Split `text` on whitespace and locate every token in the original string.

    Returns a pair `(tokens, spans)`: `tokens` is `text.split()` and `spans`
    holds one `(start, end)` character offset pair per token, with `end`
    exclusive. Each token is searched for starting at the end of the previous
    one, so repeated tokens map to successive occurrences.
    """
    words = text.split()
    offsets = []
    cursor = 0
    for word in words:
        begin = text.find(word, cursor)
        cursor = begin + len(word)
        offsets.append((begin, cursor))
    return words, offsets


def build_examples_from_nerel(records, event_list):
    """Convert raw records into token-tagging + multi-label examples.

    For every record the text is whitespace-tokenized, entity annotations are
    projected onto tokens as flat BIO tags, and relation annotations are turned
    into a multi-hot vector over `event_list`.

    Args:
        records: iterable of dicts with keys "id" and "text" and optional
            "entities" / "relations" lists of annotation lines.
        event_list: relation types used as classification targets; position in
            this list is the index in `cls_vec`.

    Returns:
        A list of dicts with keys "id", "text", "tokens", "token_spans",
        "tags" (string BIO labels) and "cls_vec" (list of 0/1 ints).

    Tagging is flat: a token keeps the first label assigned to it, so later
    overlapping or nested entities only label tokens that are still "O". Every
    uninterrupted run of tokens labelled by one entity starts with "B-"; this
    keeps the BIO sequence well-formed even when the entity's first token (or
    a token in its middle) already belongs to an earlier entity.
    """
    examples = []
    event2idx = {ev: i for i, ev in enumerate(event_list)}

    for rec in records:
        text = rec["text"]
        tokens, token_spans = whitespace_tokenize_with_offsets(text)
        token_labels = ["O"] * len(tokens)
        cls_vec = [0] * len(event_list)

        # Entities -> BIO tags.
        for e in rec.get("entities", []):
            ent = parse_entity_line(e)
            if not ent:
                continue
            span_start, span_end = ent["start"], ent["end"]

            # Tokens whose character span intersects the entity span.
            overlapping_idxs = [
                i for i, (t_start, t_end) in enumerate(token_spans)
                if t_start < span_end and t_end > span_start
            ]

            continues_chunk = False  # True while the previous token was labelled by this entity
            for tok_idx in overlapping_idxs:
                if token_labels[tok_idx] != "O":
                    # Token already taken by an earlier entity: the run is
                    # interrupted, so the next free token must open a new chunk.
                    continues_chunk = False
                    continue
                prefix = "I" if continues_chunk else "B"
                token_labels[tok_idx] = f"{prefix}-{ent['type']}"
                continues_chunk = True

        # Relations -> multi-hot presence vector (types outside event_list are ignored).
        for r in rec.get("relations", []):
            rel = parse_relation_line(r)
            if not rel:
                continue
            if rel["type"] in event2idx:
                cls_vec[event2idx[rel["type"]]] = 1

        examples.append({
            "id": rec["id"],
            "text": text,
            "tokens": tokens,
            "token_spans": token_spans,
            "tags": token_labels,
            "cls_vec": cls_vec
        })

    return examples


In [None]:
from collections import Counter

def make_event_list(records, K=30):
    """Return the `K` most frequent relation types found in `records`.

    Relation lines that `parse_relation_line` cannot parse are skipped. The
    result is ordered from most to least frequent.
    """
    parsed_relations = (
        parse_relation_line(raw)
        for rec in records
        for raw in rec.get("relations", [])
    )
    frequencies = Counter(rel["type"] for rel in parsed_relations if rel)
    return [rel_type for rel_type, _ in frequencies.most_common(K)]


In [None]:
# Relation types used as classification targets (30 most frequent).
event_list = make_event_list(records, K=30)
print("События:", event_list)

# Token-level and document-level training examples; peek at the first one.
examples = build_examples_from_nerel(records, event_list)
first_example = examples[0]
print("Пример tokens, tags:", first_example["tokens"][:15], first_example["tags"][:15])
print("cls_vec:", first_example["cls_vec"])


In [None]:
from datasets import Dataset, DatasetDict

# Tag vocabulary: every BIO label seen in the examples, plus "O".
unique_labels = set()
for ex in examples:
    unique_labels.update(ex["tags"])
unique_labels.add("O")
label_list = sorted(unique_labels)
label2id = {lab: i for i, lab in enumerate(label_list)}
id2label = {i: lab for lab, i in label2id.items()}

# Encode tags into a NEW list of examples instead of overwriting `examples`
# in place. The cell stays idempotent: re-running it no longer mixes integer
# ids with string labels (which made `sorted` fail on the second run).
encoded_examples = [
    {**ex, "tags": [label2id[t] for t in ex["tags"]]}
    for ex in examples
]

# 90/10 train/test split with a fixed seed for reproducibility.
full_ds = Dataset.from_list(encoded_examples)
split = full_ds.train_test_split(test_size=0.1, seed=42)
dataset = DatasetDict({"train": split["train"], "test": split["test"]})
print(dataset)

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the DeepPavlov/rubert-base-cased checkpoint. A fast tokenizer is
# requested because the label alignment below calls word_ids() on the encoding.
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased", use_fast=True)

def tokenize_and_align_labels(examples_batch, tokenizer, max_length=256):
    """Tokenize pre-split words and align word-level tag ids to sub-word tokens.

    Args:
        examples_batch: batch mapping with "tokens" (list of word lists),
            "tags" (list of per-word label ids) and "cls_vec".
        tokenizer: callable tokenizer whose result exposes
            `word_ids(batch_index=...)`.
        max_length: every sequence is truncated and padded to this length.

    Returns:
        The tokenizer output extended with "labels" and "cls_labels" and with
        "offset_mapping" removed. In "labels" only the first sub-word of each
        word carries the word's label id; special tokens, padding and
        continuation sub-words are set to -100.
    """
    encoded = tokenizer(
        examples_batch["tokens"],
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_offsets_mapping=True
    )

    def _first_subword_labels(word_ids, word_labels):
        # Keep the label on the first piece of every word, mask everything else.
        aligned = []
        previous = None
        for current in word_ids:
            starts_word = current is not None and current != previous
            aligned.append(word_labels[current] if starts_word else -100)
            previous = current
        return aligned

    encoded["labels"] = [
        _first_subword_labels(encoded.word_ids(batch_index=row), row_labels)
        for row, row_labels in enumerate(examples_batch["tags"])
    ]
    encoded["cls_labels"] = examples_batch["cls_vec"]
    encoded.pop("offset_mapping")
    return encoded


In [None]:
# Tokenize both splits in batches; the raw text/token columns are dropped once
# the model inputs have been produced.
columns_to_drop = ["text", "tokens", "tags", "token_spans", "id"]
tokenized_dataset = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=columns_to_drop,
)

print(tokenized_dataset)
print(tokenized_dataset["train"][0])

In [None]:
from transformers import DataCollatorForTokenClassification
from torch.utils.data import DataLoader
import torch

# Collator that turns the token-level fields into tensors.
data_collator = DataCollatorForTokenClassification(tokenizer)

def custom_collator(batch):
    """Collate a list of dataset rows into one batch.

    "input_ids", "attention_mask" and "labels" are delegated to the
    token-classification collator; the multi-label "cls_labels" vectors are
    stacked into a float tensor and added under the same key.
    """
    token_fields = ("input_ids", "attention_mask", "labels")
    batch_enc = data_collator(
        [{field: row[field] for field in token_fields} for row in batch]
    )

    cls_rows = [row["cls_labels"] for row in batch]
    batch_enc["cls_labels"] = torch.tensor(cls_rows, dtype=torch.float)
    return batch_enc

# Both loaders share the batch size and the collate function; only the
# training loader shuffles.
loader_options = dict(batch_size=16, collate_fn=custom_collator)
train_dataloader = DataLoader(tokenized_dataset["train"], shuffle=True, **loader_options)
test_dataloader = DataLoader(tokenized_dataset["test"], shuffle=False, **loader_options)

# Sanity check: print the keys and tensor shapes of the first training batch.
for batch in train_dataloader:
    print(batch.keys())
    for field in ("input_ids", "labels", "cls_labels"):
        print(batch[field].shape)
    break

print("Готово. Примеры для обучения:", len(tokenized_dataset["train"]))

##### Модель: `JointModel` + custom loss (uncertainty weighting)


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig

class JointModel(nn.Module):
    def __init__(
        self,
        encoder_name: str = "DeepPavlov/rubert-base-cased",
        num_token_labels: int = 10,
        num_cls_labels: int = 30,
        dropout: float = 0.1,
        use_uncertainty_weight: bool = True,
        pos_weight_cls: torch.Tensor | None = None,
    ):
        super().__init__()
        self.config = AutoConfig.from_pretrained(encoder_name)
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.config.hidden_size

        self.dropout = nn.Dropout(dropout)
        self.token_classifier = nn.Linear(hidden, num_token_labels)
        self.cls_classifier   = nn.Linear(hidden, num_cls_labels)


        self.token_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        self.cls_loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight_cls)

        self.use_uncertainty_weight = use_uncertainty_weight
        self.log_sigma_token = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_cls   = nn.Parameter(torch.tensor(0.0))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        cls_labels: torch.Tensor | None = None,
    ):

        enc = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        seq = enc.last_hidden_state
        pooled = seq[:, 0, :]

        seq = self.dropout(seq)
        pooled = self.dropout(pooled)

        token_logits = self.token_classifier(seq)
        cls_logits   = self.cls_classifier(pooled)

        out = {"token_logits": token_logits, "cls_logits": cls_logits}

        if labels is not None and cls_labels is not None:
            token_loss = self.token_loss_fct(
                token_logits.view(-1, token_logits.size(-1)),
                labels.view(-1)
            )
            cls_loss = self.cls_loss_fct(cls_logits, cls_labels.float())

            if self.use_uncertainty_weight:
                # exp(-2*log_sigma)*L + log_sigma
                loss_token_term = torch.exp(-2.0 * self.log_sigma_token) * token_loss + self.log_sigma_token
                loss_cls_term   = torch.exp(-2.0 * self.log_sigma_cls)   * cls_loss   + self.log_sigma_cls
                loss = loss_token_term + loss_cls_term
                out.update({
                    "loss": loss,
                    "token_loss": token_loss.detach(),
                    "cls_loss": cls_loss.detach(),
                    "log_sigma_token": self.log_sigma_token.detach(),
                    "log_sigma_cls": self.log_sigma_cls.detach(),
                })
            else:
                loss = token_loss + cls_loss
                out.update({
                    "loss": loss,
                    "token_loss": token_loss.detach(),
                    "cls_loss": cls_loss.detach(),
                })

        return out


In [None]:
# Head sizes are derived from the data: one output per BIO label id and one
# per entry of the multi-hot relation vector.
num_token_labels = len(id2label)
num_cls_labels = len(examples[0]["cls_vec"])

model = JointModel(
    encoder_name="DeepPavlov/rubert-base-cased",
    num_token_labels=num_token_labels,
    num_cls_labels=num_cls_labels,
    dropout=0.1,
    use_uncertainty_weight=True,
    pos_weight_cls=None,
)
# Use the GPU when one is available.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")


##### Training / Validation



In [None]:
import torch
from transformers import get_linear_schedule_with_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training hyper-parameters.
epochs = 10
lr = 5e-5

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

# Linear schedule stepped once per batch: warm-up over the first 10% of all
# optimisation steps, then linear decay.
num_training_steps = epochs * len(train_dataloader)
num_warmup_steps = int(0.1 * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
!pip install seqeval
from seqeval.metrics import f1_score as seqeval_f1
from sklearn.metrics import f1_score, precision_score, recall_score

def compute_token_f1(preds, labels, id2label):
    """Macro-averaged seqeval F1 over the token-tagging predictions.

    `preds` and `labels` are parallel sequences of per-position label ids.
    Positions whose gold id is -100 are dropped from both sides; the remaining
    ids are mapped to label strings through `id2label` before scoring.
    """
    true_labels, pred_labels = [], []

    for pred_row, gold_row in zip(preds, labels):
        # Keep only positions that carry a real gold label.
        kept = [(gold, pred) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        true_labels.append([id2label[gold] for gold, _ in kept])
        pred_labels.append([id2label[pred] for _, pred in kept])

    return seqeval_f1(true_labels, pred_labels, average="macro")

def compute_cls_metrics(preds, labels):
    """Micro-averaged F1, precision and recall for the multi-label head.

    `preds` holds raw logits; a label counts as predicted when its logit is
    strictly positive. Returns the tuple `(micro_f1, precision, recall)`.
    """
    binarized = (preds > 0).astype(int)
    scoring = dict(average="micro", zero_division=0)
    return (
        f1_score(labels, binarized, **scoring),
        precision_score(labels, binarized, **scoring),
        recall_score(labels, binarized, **scoring),
    )


In [None]:
import numpy as np

# One dict of metrics per epoch; filled by the loop below.
log_table = []

for epoch in range(1, epochs+1):
    # ---- Training pass over the whole training loader ----
    model.train()
    total_loss = 0.0
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        cls_labels = batch["cls_labels"].to(device)

        optimizer.zero_grad()
        # Both label tensors are passed, so the model returns the joint loss.
        out = model(input_ids, attention_mask, labels=labels, cls_labels=cls_labels)
        loss = out["loss"]
        loss.backward()
        # Clip the gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The learning-rate schedule advances once per batch.
        scheduler.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = total_loss / len(train_dataloader)

    # ---- Evaluation pass over the held-out loader ----
    model.eval()
    all_preds_token, all_labels_token = [], []
    all_preds_cls, all_labels_cls = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            cls_labels = batch["cls_labels"].to(device)

            # No labels are passed here, so only logits are returned.
            out = model(input_ids, attention_mask)
            token_logits = out["token_logits"].cpu().numpy()
            cls_logits = out["cls_logits"].cpu().numpy()

            # Token head: arg-max label id per position.
            preds = np.argmax(token_logits, axis=-1)
            all_preds_token.extend(preds.tolist())
            all_labels_token.extend(labels.cpu().numpy().tolist())

            # Document-level head: keep raw logits; they are thresholded in compute_cls_metrics.
            all_preds_cls.extend(cls_logits)
            all_labels_cls.extend(cls_labels.cpu().numpy())

    token_f1 = compute_token_f1(all_preds_token, all_labels_token, id2label)
    cls_f1, cls_prec, cls_rec = compute_cls_metrics(np.array(all_preds_cls), np.array(all_labels_cls))

    # Store rounded metrics for the summary table.
    log_table.append({
        "epoch": epoch,
        "train_loss": round(avg_train_loss, 4),
        "token_f1": round(token_f1, 4),
        "cls_f1": round(cls_f1, 4),
        "cls_prec": round(cls_prec, 4),
        "cls_rec": round(cls_rec, 4)
    })

    print(f"Epoch {epoch} | loss {avg_train_loss:.4f} | token-F1 {token_f1:.4f} | "
          f"cls-F1 {cls_f1:.4f} | P {cls_prec:.4f} | R {cls_rec:.4f}")


In [None]:
import pandas as pd
# Per-epoch metrics collected during training, rendered as a table.
df_log = pd.DataFrame(log_table)
display(df_log)

In [None]:
from IPython.display import Markdown

Markdown("""
### Выводы
1. Модель уверенно обучается: train loss падает от 0.54 до 0.29.
2. CLS-задача стабильно даёт довольно высокий F1 ≈ 0.70 уже с первых эпох.
3. Token-level F1 находится не на очень высоком уровне, значит выделение сущностей сложнее и требует дообучения.
""")

##### Инференс, квантизация и анализ ошибок

In [None]:
import torch
import numpy as np

def predict(model, tokenizer, text, id2label, cls_event_list, device="cpu", max_length=128):
    """Run the joint model on a single text.

    The text is split on whitespace, tokenized as pre-split words and
    truncated/padded to `max_length`.

    Returns:
        tokens: the tokenizer's string tokens for every input position.
        pred_labels: one label per input position (sub-word level, special
            tokens included); positions with attention mask != 1 are "PAD".
        cls_pred: dict mapping each entry of `cls_event_list` to the sigmoid
            probability of the corresponding document-level logit.
    """
    model.eval()
    encoding = tokenizer(
        text.split(),
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
    )

    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    # Token head: arg-max label id per position of the single sequence.
    pred_ids = np.argmax(outputs["token_logits"].cpu().numpy()[0], axis=-1)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    pred_labels = []
    for is_real, label_id in zip(attention_mask[0].tolist(), pred_ids):
        pred_labels.append(id2label[label_id] if is_real == 1 else "PAD")

    # Document-level head: independent sigmoid probability per event type.
    cls_logits = outputs["cls_logits"].cpu().numpy()[0]
    cls_probs = torch.sigmoid(torch.tensor(cls_logits)).numpy()
    cls_pred = {cls_event_list[k]: float(prob) for k, prob in enumerate(cls_probs)}

    return tokens, pred_labels, cls_pred


In [None]:
def qualitative_analysis(model, tokenizer, dataset, id2label, cls_event_list, n=10, device="cpu"):
    """Print gold vs. predicted tags and relation types for the first `n` examples.

    Args:
        model, tokenizer: passed through to `predict`; the tokenizer must
            provide `word_ids()` on its encodings.
        dataset: indexable collection of examples with "text", "tags" (label
            ids, one per whitespace word) and "cls_vec".
        id2label: mapping from label id to label string.
        cls_event_list: relation type for every position of "cls_vec".
        n: number of examples to show; capped at `len(dataset)`.
        device: device the model lives on.

    `predict` returns one label per sub-word position (special tokens
    included), while the gold tags are one per whitespace word. To make the
    two printed lists comparable, predictions are reduced to the first
    sub-word of every word before printing.
    """
    for i in range(min(n, len(dataset))):
        ex = dataset[i]
        text = ex["text"]
        true_tags = [id2label[t] for t in ex["tags"]]
        true_cls = ex["cls_vec"]

        tokens, pred_tags, cls_pred = predict(model, tokenizer, text, id2label, cls_event_list, device=device)

        # Re-encode the same words (same truncation length as the padded
        # prediction) only to learn which positions start a word.
        word_ids = tokenizer(
            text.split(),
            is_split_into_words=True,
            truncation=True,
            max_length=len(pred_tags),
        ).word_ids()
        pred_word_tags = []
        prev_word_idx = None
        for pos, word_idx in enumerate(word_ids):
            if word_idx is not None and word_idx != prev_word_idx:
                pred_word_tags.append(pred_tags[pos])
            prev_word_idx = word_idx

        print(f"\nПример {i+1}")
        print("Текст:", text[:200], "...")
        print("GT tags :", true_tags[:20])
        print("Pred tags:", pred_word_tags[:20])

        true_cls_labels = [cls_event_list[j] for j, v in enumerate(true_cls) if v == 1]
        pred_cls_labels = [k for k, v in cls_pred.items() if v > 0.5]

        print("GT CLS :", true_cls_labels)
        print("Pred CLS:", pred_cls_labels)

        # Relation types predicted but absent from gold, and vice versa.
        fp = set(pred_cls_labels) - set(true_cls_labels)
        fn = set(true_cls_labels) - set(pred_cls_labels)
        if fp:
            print("False Positives:", fp)
        if fn:
            print("False Negatives:", fn)


In [None]:
# Inspect the first 10 held-out examples with the trained model.
qualitative_analysis(
    model, tokenizer, dataset["test"],
    id2label=id2label, cls_event_list=event_list,
    n=10, device=device
)

In [None]:
import torch.nn as nn
import torch
import time

import copy

# nn.Module.to() moves a module in place, so `model.to("cpu")` would also move
# the trained `model` off its device and leave it out of sync with `device`
# (re-running an earlier cell would then fail). Work on a CPU copy instead.
model_cpu = copy.deepcopy(model).to("cpu").eval()

# Dynamic quantization of all nn.Linear layers to int8.
quantized_model = torch.quantization.quantize_dynamic(
    model_cpu, {nn.Linear}, dtype=torch.qint8
)

print("Обычная модель (fp32):")
print(model_cpu)
print("Квантизированная модель (int8):")
print(quantized_model)

In [None]:
text = "Россия и США подписали новое соглашение о ядерном разоружении."

N = 50


def _mean_predict_latency(candidate_model, n_runs):
    """Return the mean wall-clock seconds per `predict` call on `text`.

    One untimed warm-up call is made first. `time.perf_counter` is used because
    it is a monotonic, high-resolution clock intended for measuring intervals.
    """
    predict(candidate_model, tokenizer, text, id2label, event_list)  # warm-up
    started = time.perf_counter()
    for _ in range(n_runs):
        predict(candidate_model, tokenizer, text, id2label, event_list)
    return (time.perf_counter() - started) / n_runs


time_fp32 = _mean_predict_latency(model_cpu, N)
time_int8 = _mean_predict_latency(quantized_model, N)

print(f"Среднее время fp32: {time_fp32*1000:.2f} ms")
print(f"Среднее время int8: {time_int8*1000:.2f} ms")


In [None]:
def evaluate_model(eval_model, dataloader, id2label, event_list, device="cpu"):
    """Score a model on every batch of `dataloader`.

    Returns `(token_f1, cls_f1, cls_prec, cls_rec)`: the token-tagging F1 from
    `compute_token_f1` and the multi-label metrics from `compute_cls_metrics`.
    `event_list` is accepted for signature compatibility but is not used.
    """
    eval_model.eval()
    token_preds, token_gold = [], []
    cls_scores, cls_gold = [], []

    with torch.no_grad():
        for batch in dataloader:
            out = eval_model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            )

            # Token head: arg-max label id per position.
            token_preds.extend(out["token_logits"].cpu().numpy().argmax(axis=-1).tolist())
            token_gold.extend(batch["labels"].cpu().numpy().tolist())

            # Document-level head: raw logits, thresholded later.
            cls_scores.extend(out["cls_logits"].cpu().numpy())
            cls_gold.extend(batch["cls_labels"].cpu().numpy())

    token_f1 = compute_token_f1(token_preds, token_gold, id2label)
    cls_f1, cls_prec, cls_rec = compute_cls_metrics(np.array(cls_scores), np.array(cls_gold))
    return token_f1, cls_f1, cls_prec, cls_rec

# Same held-out loader for both variants so the metric tuples are directly comparable.
for title, candidate in (("FP32 модель:", model_cpu), ("INT8 модель:", quantized_model)):
    print(title)
    print(evaluate_model(candidate, test_dataloader, id2label, event_list))


In [None]:
from IPython.display import Markdown

Markdown("""
### Выводы
1. Модель уверенно предсказывает частые сущности и отношения, но систематически пропускает редкие классы.
2. В CLS-задаче заметен избыток ложноположительных меток: модель добавляет отношения, которых нет в тексте.
3. Квантизация заметно ускорила инференс (примерно на 25%), качество при этом значительно не пострадало.
""")