In [1]:
import collections
import re

import evaluate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset, concatenate_datasets, load_dataset
from sklearn.metrics import classification_report, f1_score
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    pipeline,
)

SEED = 42
MODEL_NAME = "DeepPavlov/rubert-base-cased"
DATASET_NAME = "Davlan/sib200"
DATASET_LANGUAGE = "rus_Cyrl"
MINIBATCH_SIZE = 8
MAX_LEN = 512

np.random.seed(SEED)
torch.manual_seed(SEED)


<torch._C.Generator at 0x7f0ea21c6190>

Модель та же, но уменьшил размер минибатча, дало прирост итоговой метрики, и аналогично ограничил максимальную длину. Также простенькая нормализация для стабильности.


In [None]:
def normalize_text(text):
    if not isinstance(text, str):
        return ""
    import unicodedata

    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def normalize_batch_texts(texts):
    return [normalize_text(text) for text in texts]


Базовые быстрые и простые аугментации чтобы увеличить количество примеров на слабо представленные классы.


In [None]:
def split_sentences(text):
    parts = re.split(r"([.!?])", text)
    sents = []
    for i in range(0, len(parts), 2):
        seg = parts[i].strip()
        if not seg:
            continue
        end = parts[i + 1] if i + 1 < len(parts) else ""
        sents.append((seg + end).strip())
    return [s for s in sents if s]


def aug_sentence_shuffle(text, p=0.35):
    if np.random.rand() > p:
        return text
    sents = split_sentences(text)
    if len(sents) < 2:
        return text
    np.random.shuffle(sents)
    return " ".join(sents)


def aug_punct_swap(text, p=0.25):
    if np.random.rand() > p:
        return text
    puncts = [".", ",", "!", "?"]
    chars = list(text)
    for i, ch in enumerate(chars):
        if ch in puncts and np.random.rand() < 0.2:
            chars[i] = np.random.choice(puncts)
    return "".join(chars)


def aug_truncate_middle(text, p=0.25):
    if np.random.rand() > p:
        return text
    words = text.split()
    n = len(words)
    if n < 12:
        return text
    cut = int(n * np.random.uniform(0.1, 0.3))
    start_keep = int((n - cut) / 2)
    new_words = words[:start_keep] + words[start_keep + cut :]
    return " ".join(new_words)


def apply_augmentations(text):
    text = aug_sentence_shuffle(text, p=0.4)
    text = aug_punct_swap(text, p=0.3)
    text = aug_truncate_middle(text, p=0.3)
    return text

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
train_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split="train")
validation_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split="validation")
test_set = load_dataset(DATASET_NAME, DATASET_LANGUAGE, split="test")

In [None]:
def normalize_dataset(dataset):
    normalized_texts = normalize_batch_texts(dataset["text"])
    return dataset.remove_columns(["text"]).add_column("text", normalized_texts)


train_set = normalize_dataset(train_set)
validation_set = normalize_dataset(validation_set)
test_set = normalize_dataset(test_set)

print(f"Normilized text example: {train_set[0]['text'][:100]}...")

Normilized text example: Турция с трёх сторон окружена морями: на западе — Эгейским, на севере — Чёрным и на юге — Средиземны...


In [None]:
train_labels = train_set["category"]
label_counts = collections.Counter(train_labels)

print(f"All classes: {len(label_counts)}")
print(f"Examples: {len(train_labels)}")
print("\nBy class:")
for label, count in label_counts.most_common():
    print(f"  {label}: {count} examples ({count / len(train_labels) * 100:.2f}%)")

All classes: 7
Examples: 701

By class:
  science/technology: 176 examples (25.11%)
  travel: 138 examples (19.69%)
  politics: 102 examples (14.55%)
  sports: 85 examples (12.13%)
  health: 77 examples (10.98%)
  entertainment: 65 examples (9.27%)
  geography: 58 examples (8.27%)


In [8]:
counts = collections.Counter(train_set["category"])
rare_classes = {"entertainment", "geography"}
if len(counts) > 0:
    target = int(max(counts.values()) * 0.95)
    aug_texts, aug_labels = [], []
    rng = np.random.default_rng(SEED)
    for label in counts:
        if label not in rare_classes:
            continue
        need = max(0, target - counts[label])
        if need == 0:
            continue
        idxs = [i for i, l in enumerate(train_set["category"]) if l == label]
        for _ in range(need):
            i = int(rng.choice(idxs))
            base_text = train_set[i]["text"]
            new_text = apply_augmentations(base_text)
            aug_texts.append(new_text)
            aug_labels.append(label)

    if aug_texts:
        print(f"Adding augmented examples: {len(aug_texts)} (target level ~{target})")
        aug_ds = Dataset.from_dict({"text": aug_texts, "category": aug_labels})
        train_set = concatenate_datasets([train_set, aug_ds])
        train_set = train_set.shuffle(seed=SEED)


Adding augmented examples: 211 (target level ~167)


In [9]:
USE_PSEUDO_LABELING = True
PSEUDO_FRACTION = 0.2
PSEUDO_MIN_PROB = 0.80

if USE_PSEUDO_LABELING:
    base_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(set(train_set["category"])),
        classifier_dropout=0.1,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
    )

    tmp_tokenized = train_set.map(
        lambda batch: tokenizer(batch["text"], truncation=True, max_length=MAX_LEN),
        batched=True,
    )
    tmp_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])

    base_model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model.to(device)

    pseudo_texts = []
    pseudo_labels = []

    n_pseudo_candidates = int(len(tmp_tokenized) * PSEUDO_FRACTION)
    idxs_for_pseudo = np.random.choice(
        len(tmp_tokenized), size=n_pseudo_candidates, replace=False
    )

    softmax = nn.Softmax(dim=-1)
    with torch.no_grad():
        for idx in idxs_for_pseudo:
            sample = tmp_tokenized[idx]
            input_ids = sample["input_ids"].unsqueeze(0).to(device)
            attention_mask = sample["attention_mask"].unsqueeze(0).to(device)
            outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probs = softmax(logits)[0].cpu().numpy()
            pred_idx = int(np.argmax(probs))
            conf = float(probs[pred_idx])

            if conf >= PSEUDO_MIN_PROB:
                label = base_model.config.id2label.get(pred_idx, None)
                if label == "entertainment":
                    pseudo_texts.append(train_set[idx]["text"])
                    pseudo_labels.append(label)

    if pseudo_texts:
        print(f"Pseudolabled class 'entertainment': {len(pseudo_texts)}")
        pseudo_ds = Dataset.from_dict({"text": pseudo_texts, "category": pseudo_labels})
        train_set = concatenate_datasets([train_set, pseudo_ds])
        train_set = train_set.shuffle(seed=SEED)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Псевдолейблинг для классов. Главная проблема этого датасета - сильный дисбаланс классов. В частности, довольно маленький класс entertainment, кроме того в нём ещё и довольно плохие и "жидкие" примеры. По сути аугментации, псевдолейблинг и тп используются главным образом чтобы побороть именно этот один класс.


In [10]:
list_of_categories = sorted(
    list(
        set(train_set["category"])
        | set(validation_set["category"])
        | set(test_set["category"])
    )
)
indices_of_categories = list(range(len(list_of_categories)))
n_categories = len(list_of_categories)
id2label = dict(zip(indices_of_categories, list_of_categories))
label2id = dict(zip(list_of_categories, indices_of_categories))

print("\nList categories:")
for i, c in enumerate(list_of_categories):
    print(i, c)



List categories:
0 entertainment
1 geography
2 health
3 politics
4 science/technology
5 sports
6 travel


In [None]:
def compute_class_weights(labels, label2id, method="sqrt"):
    counts_local = collections.Counter(labels)
    classes_in_order = [lbl for lbl, _ in sorted(label2id.items(), key=lambda x: x[1])]
    n_classes = len(classes_in_order)

    class_counts_by_idx = [counts_local.get(lbl, 0) for lbl in classes_in_order]

    if method == "balanced":
        total = sum(class_counts_by_idx)
        weights = [
            (total / (n_classes * c)) if c > 0 else 1.0 for c in class_counts_by_idx
        ]
    elif method == "inverse":
        max_c = max(class_counts_by_idx) if any(class_counts_by_idx) else 1
        weights = [(max_c / c) if c > 0 else 1.0 for c in class_counts_by_idx]
    elif method == "sqrt":
        max_c = max(class_counts_by_idx) if any(class_counts_by_idx) else 1
        weights = [np.sqrt(max_c / c) if c > 0 else 1.0 for c in class_counts_by_idx]
    elif method == "custom":
        max_c = max(class_counts_by_idx) if any(class_counts_by_idx) else 1
        weights = [np.sqrt(max_c / c) if c > 0 else 1.0 for c in class_counts_by_idx]
        for i, lbl in enumerate(classes_in_order):
            if lbl == "entertainment":
                weights[i] *= 5.0
            if lbl == "geography":
                weights[i] *= 1.5
    else:
        raise ValueError(f"Unknown method: {method}")

    print("\nВеса классов:")
    for i, lbl in enumerate(classes_in_order):
        print(
            f"  idx={i}, label={lbl}, count={class_counts_by_idx[i]}, weight={weights[i]:.4f}"
        )

    return torch.tensor(weights, dtype=torch.float32)


class_weights = compute_class_weights(train_set["category"], label2id, method="custom")



Веса классов:
  idx=0, label=entertainment, count=167, weight=5.1330
  idx=1, label=geography, count=167, weight=1.5399
  idx=2, label=health, count=77, weight=1.5119
  idx=3, label=politics, count=102, weight=1.3136
  idx=4, label=science/technology, count=176, weight=1.0000
  idx=5, label=sports, count=85, weight=1.4390
  idx=6, label=travel, count=138, weight=1.1293


Из-за дисбаланса классов сделал взвешенное обучение, но никакие методы кроме агрессивного custom для самого проблемного класса особо не помогли. Во-первыз аугментации уже выравнивают количество примеров в обучающей выборке, во-вторых - проблем с перекрытием между классами лёгкие подправки весов не меняют ситуацию. 5.0 - подобрано ручками, и похоже является оптимумом для данного класса и датасета.


In [12]:
def tok(batch):
    return tokenizer(batch["text"], truncation=True, max_length=MAX_LEN)


tokenized_train_set = train_set.map(tok, batched=True)
tokenized_validation_set = validation_set.map(tok, batched=True)

labeled_train_set = tokenized_train_set.add_column(
    "label", [label2id[val] for val in tokenized_train_set["category"]]
)
labeled_validation_set = tokenized_validation_set.add_column(
    "label", [label2id[val] for val in tokenized_validation_set["category"]]
)


Map:   0%|          | 0/912 [00:00<?, ? examples/s]

Map:   0%|          | 0/99 [00:00<?, ? examples/s]

In [13]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

In [14]:
cls_metric = evaluate.load("f1")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    f1_macro = cls_metric.compute(
        predictions=predictions, references=labels, average="macro"
    )["f1"]
    accuracy = (predictions == labels).mean()
    return {"f1": f1_macro, "accuracy": accuracy}


In [15]:
classifier = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=n_categories,
    id2label=id2label,
    label2id=label2id,
    classifier_dropout=0.2,
    hidden_dropout_prob=0.2,
    attention_probs_dropout_prob=0.2,
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Добавлен дропаут для улучшения обучаемости, хотя реальных изменений в данном случае не обнаружено.


In [16]:
training_args = TrainingArguments(
    output_dir="rubert_sib200_weighted_v2",
    learning_rate=2.5e-5,
    per_device_train_batch_size=MINIBATCH_SIZE,
    per_device_eval_batch_size=MINIBATCH_SIZE,
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    weight_decay=0.05,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_steps=50,
    warmup_ratio=0.12,
    lr_scheduler_type="linear",
    seed=SEED,
    data_seed=SEED,
    report_to=["none"],
    fp16=torch.cuda.is_available(),
    no_cuda=not torch.cuda.is_available(),
    label_smoothing_factor=0.03,
    gradient_checkpointing=True,
    max_grad_norm=1.0,
)


Подобраны оптимальные гиперпараметры. Сложно оценить как именно они повлияли и почему, но во многом всё упирается в размер датасета и то, что модель сходится довольно быстро.


In [17]:
class WeightedLossTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        labels = inputs.get("labels")
        outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
        logits = outputs.get("logits")

        gamma = 1.5

        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = labels.view(-1)

        log_probs = nn.functional.log_softmax(logits_flat, dim=-1)
        probs = log_probs.exp()

        labels_flat_long = labels_flat.long()
        idx = torch.arange(labels_flat_long.size(0), device=logits.device)
        log_p_t = log_probs[idx, labels_flat_long]
        p_t = probs[idx, labels_flat_long]

        class_weights = self.class_weights.to(logits.device)
        alpha_t = class_weights[labels_flat_long]

        focal_factor = (1.0 - p_t) ** gamma
        loss = -alpha_t * focal_factor * log_p_t
        loss = loss.mean()

        return (loss, outputs) if return_outputs else loss


trainer = WeightedLossTrainer(
    class_weights=class_weights,
    model=classifier,
    args=training_args,
    train_dataset=labeled_train_set,
    eval_dataset=labeled_validation_set,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print(f"Using CUDA: {torch.cuda.is_available()}")
print(f"Class weights: {class_weights}")

trainer.train()


  super().__init__(*args, **kwargs)


Using CUDA: True
Class weights: tensor([5.1330, 1.5399, 1.5119, 1.3136, 1.0000, 1.4390, 1.1293])


Epoch,Training Loss,Validation Loss,F1,Accuracy
1,2.6931,2.06808,0.077662,0.151515
2,1.3258,0.59089,0.846934,0.858586
3,0.3494,0.434496,0.838882,0.838384
4,0.1303,0.655787,0.850209,0.848485
5,0.0855,0.793864,0.842218,0.848485
6,0.0222,0.696421,0.864,0.868687
7,0.0129,0.760679,0.860219,0.858586
8,0.0017,0.802027,0.855349,0.858586
9,0.0025,0.791048,0.85257,0.858586
10,0.0007,0.793358,0.85257,0.858586


TrainOutput(global_step=570, training_loss=0.40600183783262445, metrics={'train_runtime': 76.278, 'train_samples_per_second': 119.563, 'train_steps_per_second': 7.473, 'total_flos': 218063076192000.0, 'train_loss': 0.40600183783262445, 'epoch': 10.0})

In [18]:
results = trainer.evaluate()
print("\nValidation results:")
for key, value in results.items():
    if isinstance(value, (int, float)):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")



Validation results:
  eval_loss: 0.6964
  eval_f1: 0.8640
  eval_accuracy: 0.8687
  eval_runtime: 0.1326
  eval_samples_per_second: 746.8070
  eval_steps_per_second: 98.0660
  epoch: 10.0000


In [19]:
best_model = trainer.model


def create_classification_pipeline_with_normalization(model, tokenizer, device=-1):
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    def predict_with_normalization(texts):
        normalized_texts = normalize_batch_texts(texts)
        return clf(normalized_texts, truncation=True, max_length=MAX_LEN)

    return predict_with_normalization


clf_normalized = create_classification_pipeline_with_normalization(
    model=best_model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)


Device set to use cuda:0


In [20]:
validation_texts = list(validation_set["text"])
pred_val = [label2id[x["label"]] for x in clf_normalized(validation_texts)]
true_val = [label2id[val] for val in validation_set["category"]]

print("\n" + "=" * 50)
print("Validation report:")
print("=" * 50)
print(
    classification_report(
        y_true=true_val, y_pred=pred_val, target_names=list_of_categories, digits=4
    )
)



Validation report:
                    precision    recall  f1-score   support

     entertainment     1.0000    0.7778    0.8750         9
         geography     0.7500    0.7500    0.7500         8
            health     1.0000    0.7273    0.8421        11
          politics     0.9286    0.9286    0.9286        14
science/technology     0.8889    0.9600    0.9231        25
            sports     1.0000    0.9167    0.9565        12
            travel     0.7083    0.8500    0.7727        20

          accuracy                         0.8687        99
         macro avg     0.8965    0.8443    0.8640        99
      weighted avg     0.8827    0.8687    0.8702        99



In [21]:
test_texts = list(test_set["text"])
pred_test = [label2id[x["label"]] for x in clf_normalized(test_texts)]
true_test = [label2id[val] for val in test_set["category"]]

print("\n" + "=" * 50)
print("Test report:")
print("=" * 50)
print(
    classification_report(
        y_true=true_test, y_pred=pred_test, target_names=list_of_categories, digits=4
    )
)



Test report:
                    precision    recall  f1-score   support

     entertainment     0.8667    0.6842    0.7647        19
         geography     0.8889    0.9412    0.9143        17
            health     0.9130    0.9545    0.9333        22
          politics     1.0000    0.9000    0.9474        30
science/technology     0.8909    0.9608    0.9245        51
            sports     0.9231    0.9600    0.9412        25
            travel     0.9500    0.9500    0.9500        40

          accuracy                         0.9216       204
         macro avg     0.9189    0.9072    0.9108       204
      weighted avg     0.9224    0.9216    0.9201       204



Удалось немного побить бэйзлайн. Ключевые изменения - веса, именно агрессивное навязывание большого веса плохому классу помогло выиграть ~2% к f1 модели. Остальные изменения скорее стабилизировали обучение и внесли менее значимые изменения именно в метрику.


#### Трэшхолды

Попробовал изменить трэшхолды классификации, чтобы ещё немного поднять метрику. Результаты ниже.


In [None]:
def get_logits_and_labels(dataset, model, data_collator, batch_size=MINIBATCH_SIZE):
    model.eval()
    device = model.device

    all_logits = []
    all_labels = []

    torch_dataset = dataset.with_format(
        type="torch",
        columns=["input_ids", "attention_mask", "label"],
    )

    data_loader = DataLoader(
        torch_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=data_collator,
    )

    with torch.no_grad():
        for batch in data_loader:
            labels = batch["labels"].to(device)
            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
            }
            outputs = model(**inputs)
            logits = outputs.logits
            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    return all_logits, all_labels


val_logits, val_labels = get_logits_and_labels(
    labeled_validation_set, trainer.model, data_collator
)

val_probs = F.softmax(val_logits, dim=-1).numpy()
val_true = val_labels.numpy()

ent_idx = label2id["entertainment"]


def predict_with_delta_entertainment(probs, delta):
    preds = []
    for p in probs:
        top1_idx = int(p.argmax())
        top1_prob = float(p[top1_idx])
        ent_prob = float(p[ent_idx])

        if top1_idx == ent_idx:
            preds.append(ent_idx)
            continue

        if (top1_prob - ent_prob) <= delta and ent_prob >= 0.20:
            preds.append(ent_idx)
        else:
            preds.append(top1_idx)
    return np.array(preds, dtype=np.int64)


candidate_deltas = [0.0, 0.03, 0.05, 0.07, 0.1, 0.15, 0.2, 0.3, 0.35, 0.4, 0.45]
best_delta = 0.0
best_f1 = -1.0

for d in candidate_deltas:
    preds_d = predict_with_delta_entertainment(val_probs, d)
    f1_macro_d = f1_score(val_true, preds_d, average="macro")
    print(f"delta={d:.3f} => macro F1 (val) = {f1_macro_d:.4f}")
    if f1_macro_d > best_f1:
        best_f1 = f1_macro_d
        best_delta = d

print(f"\nbest_delta={best_delta:.3f}; macro F1 on valid={best_f1:.4f}")


delta=0.000 => macro F1 (val) = 0.8640
delta=0.030 => macro F1 (val) = 0.8640
delta=0.050 => macro F1 (val) = 0.8640
delta=0.070 => macro F1 (val) = 0.8526
delta=0.100 => macro F1 (val) = 0.8526
delta=0.150 => macro F1 (val) = 0.8526
delta=0.200 => macro F1 (val) = 0.8526
delta=0.300 => macro F1 (val) = 0.8526
delta=0.350 => macro F1 (val) = 0.8526
delta=0.400 => macro F1 (val) = 0.8526
delta=0.450 => macro F1 (val) = 0.8526

best_delta=0.000; macro F1 on valid=0.8640


In [26]:
val_preds_threshold = predict_with_delta_entertainment(val_probs, best_delta)
print("\n" + "=" * 50)
print("Validation report (with entertainment threshold tuning):")
print("=" * 50)
print(
    classification_report(
        y_true=val_true,
        y_pred=val_preds_threshold,
        target_names=list_of_categories,
        digits=4,
    )
)

tokenized_test_set = test_set.map(tok, batched=True)
labeled_test_set = tokenized_test_set.add_column(
    "label", [label2id[val] for val in tokenized_test_set["category"]]
)

test_logits, test_labels = get_logits_and_labels(
    labeled_test_set, trainer.model, data_collator
)
test_probs = F.softmax(test_logits, dim=-1).numpy()
test_true = test_labels.numpy()

test_preds_threshold = predict_with_delta_entertainment(test_probs, best_delta)

print("\n" + "=" * 50)
print("Test report (with entertainment threshold tuning):")
print("=" * 50)
print(
    classification_report(
        y_true=test_true,
        y_pred=test_preds_threshold,
        target_names=list_of_categories,
        digits=4,
    )
)



Validation report (with entertainment threshold tuning):
                    precision    recall  f1-score   support

     entertainment     1.0000    0.7778    0.8750         9
         geography     0.7500    0.7500    0.7500         8
            health     1.0000    0.7273    0.8421        11
          politics     0.9286    0.9286    0.9286        14
science/technology     0.8889    0.9600    0.9231        25
            sports     1.0000    0.9167    0.9565        12
            travel     0.7083    0.8500    0.7727        20

          accuracy                         0.8687        99
         macro avg     0.8965    0.8443    0.8640        99
      weighted avg     0.8827    0.8687    0.8702        99


Test report (with entertainment threshold tuning):
                    precision    recall  f1-score   support

     entertainment     0.8667    0.6842    0.7647        19
         geography     0.8889    0.9412    0.9143        17
            health     0.9130    0.9545    0.9

Как видно, трэшхолды не влияют на улучшение метрики, значит скорее всего эта модель уже оптимум. В такой ситуации скорее всего может помось либо использования более крупной модели (хотел попробовать, но не влезло в VRAM), либо порабоать с классом entertainment, добавив новых примеров и уменьшив его схожесть с другими классами.
