In [5]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,expandable_segments:True"

import random
import json
import numpy as np
import torch
import torch.nn as nn
from itertools import combinations
from collections import defaultdict, Counter
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.metrics import precision_recall_curve
from torch.utils.data import DataLoader, Dataset

from transformers import (
    AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling,
    Trainer, TrainingArguments, EarlyStoppingCallback
)
from sentence_transformers import SentenceTransformer, models, InputExample, losses, evaluation
from sentence_transformers.evaluation import BinaryClassificationEvaluator

In [6]:
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("torch.cuda.is_available():", torch.cuda.is_available())

CORPUS_FILES = [
    "/content/text_corpus_grls_rlsnet.txt",
    "/content/rlsnet_texts_from_json.txt",
    "/content/corpus_modified_3.txt"
]
SYN_DATA_FILE_LIST = [
    "/content/clusters_2025_08_20_15.json",
    "/content/clusters_2025_08_20_200.json",
    "/content/clusters_2025_08_20_400.json",
    "/content/clusters_synosym_dict.json",
    "/content/clusters_side_e_dict.json"
]

OUTPUT_DIR = "trained_synonym_model"

torch.cuda.is_available(): True


In [7]:
# 1. Domain-Adaptive Pretraining (MLM) с короткими последовательностями
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten()
        }

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

texts = set()
for filename in CORPUS_FILES:
    with open(filename, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if len(line.split()) > 3:
                texts.add(line)
texts = list(texts)

dataset = TextDataset(texts, tokenizer, max_len=128)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="dapt_mlm",
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    fp16=True,
    save_steps=5000,
    logging_steps=2000,
    learning_rate=3e-5,
    weight_decay=0.01,
    max_grad_norm=1.0
)

model = AutoModelForMaskedLM.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.resize_token_embeddings(len(tokenizer))

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [8]:
trainer.train()
model.save_pretrained("dapt_mlm")
tokenizer.save_pretrained("dapt_mlm")

Step,Training Loss


('dapt_mlm/tokenizer_config.json',
 'dapt_mlm/special_tokens_map.json',
 'dapt_mlm/vocab.txt',
 'dapt_mlm/added_tokens.json',
 'dapt_mlm/tokenizer.json')

In [9]:
# --- Шаг 2: подготовка примеров с улучшенной балансировкой ---
pos_pairs = []
neg_pairs = []

for filename in SYN_DATA_FILE_LIST:
    with open(filename, 'r', encoding="utf-8") as f:
        clusters = json.load(f)["clusters"]

    for cluster_data in clusters.values():
        grouped = defaultdict(list)
        all_labels = []

        for label, group_id in cluster_data["labels"]:
            all_labels.append((label, group_id))
            if group_id is not None:
                grouped[group_id].append(label)

        # Положительные пары — внутри одного подмножества с одинаковыми group_id ≠ None
        for group in grouped.values():
            if len(group) > 1:
                pos_pairs.extend(combinations(group, 2))

        # Отрицательные пары — между разными номерами внутри кластера
        group_ids = list(grouped.keys())
        for i in range(len(group_ids)):
            for j in range(i + 1, len(group_ids)):
                group_a = grouped[group_ids[i]]
                group_b = grouped[group_ids[j]]
                neg_pairs.extend((a, b) for a in group_a for b in group_b)

        # Отрицательные пары для всех с null
        null_items = [label for label, gid in all_labels if gid is None]
        non_null_items = [label for label, gid in all_labels if gid is not None]

        # Более тонкий подход для null-элементов
        for null_label in null_items:
            for other_label in non_null_items:
                # Проверяем, не являются ли термины похожими по ключевым словам
                null_words = set(null_label.lower().split())
                other_words = set(other_label.lower().split())

                # Если есть значительное пересечение, возможно, это не строгий негатив
                intersection = null_words & other_words
                if len(intersection) < 2:  # Менее 2 общих слов - считаем строгим негативом
                    neg_pairs.append((null_label, other_label))

In [10]:
# Удаление дубликатов
pos_pairs = list(set(pos_pairs))
neg_pairs = list(set(neg_pairs) - set(pos_pairs))

# Улучшенная балансировка с oversampling положительных пар
pos_count = len(pos_pairs)
neg_count = len(neg_pairs)

print(f"Positive pairs: {pos_count}, Negative pairs: {neg_count}")

# Oversampling положительных пар до 1:4
target_neg_count = pos_count * 4
if neg_count > target_neg_count:
    random.shuffle(neg_pairs)
    neg_pairs = neg_pairs[:target_neg_count]
else:
    # Если негативов меньше, чем нужно, добавляем случайные пары
    all_terms = list(set([p for pair in pos_pairs + neg_pairs for p in pair]))
    additional_neg = []
    while len(additional_neg) < target_neg_count - neg_count:
        t1, t2 = random.sample(all_terms, 2)
        if (t1, t2) not in pos_pairs and (t2, t1) not in pos_pairs:
            additional_neg.append((t1, t2))
    neg_pairs.extend(additional_neg)

print(f"After balancing - Positive pairs: {len(pos_pairs)}, Negative pairs: {len(neg_pairs)}")

examples = [InputExample(texts=[a, b], label=1.0) for a, b in pos_pairs] + \
           [InputExample(texts=[a, b], label=0.0) for a, b in neg_pairs]
random.shuffle(examples)

Positive pairs: 5607, Negative pairs: 11754
After balancing - Positive pairs: 5607, Negative pairs: 16821


In [11]:
# Деление на тренировочную, валидационную и тестовую выборки
n = len(examples)
train_ex = examples[:int(0.8 * n)]
val_ex   = examples[int(0.8 * n):int(0.9 * n)]
test_ex  = examples[int(0.9 * n):]

train_loader = DataLoader(train_ex, shuffle=True, batch_size=4, collate_fn=lambda batch: batch)
val_loader   = DataLoader(val_ex, shuffle=False, batch_size=8, collate_fn=lambda batch: batch)
test_loader  = DataLoader(test_ex, shuffle=False, batch_size=8, collate_fn=lambda batch: batch)

# --- Шаг 3: SentenceTransformer из DAPT-чекпоинта с ClinicalBERT ---
transformer = models.Transformer("dapt_mlm", max_seq_length=128)
pooling     = models.Pooling(transformer.get_word_embedding_dimension(),
                             pooling_mode_mean_tokens=True,
                             pooling_mode_cls_token=False,
                             pooling_mode_max_tokens=False)
st_model    = SentenceTransformer(modules=[transformer, pooling], device=device)


Some weights of BertModel were not initialized from the model checkpoint at dapt_mlm and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Улучшенная система "жестких" негативов с warm-up
class HardNegativeMiner:
    def __init__(self, model, warmup_epochs=2):
        self.model = model
        self.warmup_epochs = warmup_epochs
        self.current_epoch = 0

    def set_epoch(self, epoch):
        self.current_epoch = epoch

    def mine_hard_negatives(self, batch, k=2):
        # Включаем hard negatives только после warm-up
        if self.current_epoch < self.warmup_epochs:
            return []

        texts = [t for ex in batch for t in ex.texts]
        emb = self.model.encode(texts, convert_to_tensor=True, batch_size=4)
        emb = torch.nn.functional.normalize(emb, dim=1)
        sims = emb @ emb.T

        hard_negatives = []
        for i in range(0, len(texts), 2):
            row = sims[i].clone()
            row[i] = row[i+1] = -1  # Исключаем саму пару

            # Только если сходство высокое, считаем "жестким" негативом
            neg_indices = torch.where(row > 0.7)[0]
            if len(neg_indices) > 0:
                top_k = min(k, len(neg_indices))
                selected = neg_indices[torch.topk(row[neg_indices], top_k).indices]
                for idx in selected:
                    hard_negatives.append(InputExample(texts=[texts[i], texts[idx]], label=0.0))

        return hard_negatives

hard_miner = HardNegativeMiner(st_model, warmup_epochs=2)

class HardNegDataset(Dataset):
    def __init__(self, examples, hard_miner):
        self.examples = examples
        self.hard_miner = hard_miner

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def collate(self, batch):
        hard_negs = self.hard_miner.mine_hard_negatives(batch)
        return batch + hard_negs

hard_ds = HardNegDataset(train_ex, hard_miner)
train_loader_hard = DataLoader(
    hard_ds,
    shuffle=True,
    batch_size=4,
    collate_fn=lambda batch: hard_ds.collate(batch)
)

In [13]:
import re

class PostProcessor:
    # Для греческих букв
    GREEK_MAP = {'альфа': 'α', 'альфa': 'α', 'alpha': 'α', 'а': 'α',
                 'бета': 'β', 'бетa': 'β', 'beta': 'β', 'в': 'β',
                 'гамма': 'γ', 'гаммa': 'γ', 'gamma': 'γ',
                 'дельта': 'δ', 'дельтa': 'δ', 'delta': 'δ'}

    # Корни электролитов
    ELECTROLYTES = {'натри', 'кали', 'магни', 'кальц', 'хлорид', 'фосфат', 'магнез'}

    # Регулярки
    WORDS_RU = re.compile(r'[а-яё]+', re.IGNORECASE)
    CYP_PATTERN = re.compile(r'cyp(\d+[a-z]+\d+)', re.IGNORECASE)

    @staticmethod
    def has(text, *words):
        text = text.lower()
        return any(w.lower() in text for w in words)

    @staticmethod
    def extract_greek(text):
        text = text.lower()
        return {PostProcessor.GREEK_MAP[key] for key in PostProcessor.GREEK_MAP if key in text}

    @staticmethod
    def extract_cyp(text):
        return set(PostProcessor.CYP_PATTERN.findall(text))

    @staticmethod
    def extract_electrolytes(text):
        text = text.lower()
        return {stem for stem in PostProcessor.ELECTROLYTES if stem in text}

    @staticmethod
    def should_filter_similarity(term1, term2, similarity):
        t1, t2 = term1.lower(), term2.lower()

        # 1. Разные греческие буквы (α, β, γ)
        g1, g2 = PostProcessor.extract_greek(t1), PostProcessor.extract_greek(t2)
        if g1 and g2 and g1 != g2:
            return min(similarity, 0.88)

        # 2. Разные CYP
        c1, c2 = PostProcessor.extract_cyp(t1), PostProcessor.extract_cyp(t2)
        if c1 and c2 and c1 != c2:
            return min(similarity, 0.85)

        # 3. Разные электролиты
        e1, e2 = PostProcessor.extract_electrolytes(t1), PostProcessor.extract_electrolytes(t2)
        if e1 and e2 and e1 != e2:
            return min(similarity, 0.83)

        # 4. Бета1 vs бета2 + блокатор
        is_beta1_t1 = PostProcessor.has(t1, 'бета1', 'beta1')
        is_beta2_t1 = PostProcessor.has(t1, 'бета2', 'beta2')
        is_beta1_t2 = PostProcessor.has(t2, 'бета1', 'beta1')
        is_beta2_t2 = PostProcessor.has(t2, 'бета2', 'beta2')

        if ((is_beta1_t1 and is_beta2_t2) or (is_beta2_t1 and is_beta1_t2)):
            if PostProcessor.has(t1, 'блокатор', 'антагонист', 'ингибитор') or \
               PostProcessor.has(t2, 'блокатор', 'антагонист', 'ингибитор'):
                return min(similarity, 0.80)
            return min(similarity, 0.84)

        return similarity  # не меняем


# Инициализация
post_processor = PostProcessor()

In [14]:
# --- Оценка с динамическим порогом ---
def find_optimal_threshold(labels, scores):
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    optimal_idx = np.argmax(f1_scores)
    return thresholds[optimal_idx] if optimal_idx < len(thresholds) else 0.5

def evaluate_with_dynamic_threshold(loader, model, post_processor=None):
    labels, scores = [], []
    processed_examples = []

    for batch in loader:
        t1 = [ex.texts[0] for ex in batch]
        t2 = [ex.texts[1] for ex in batch]
        e1 = model.encode(t1, convert_to_tensor=True, batch_size=4)
        e2 = model.encode(t2, convert_to_tensor=True, batch_size=4)
        sim = torch.nn.functional.cosine_similarity(e1, e2).cpu().numpy()

        # Применяем пост-фильтрацию
        if post_processor:
            filtered_sim = []
            for i, s in enumerate(sim):
                filtered_score = post_processor.should_filter_similarity(t1[i], t2[i], s)
                filtered_sim.append(filtered_score)
            sim = np.array(filtered_sim)

        labels.extend([ex.label for ex in batch])
        scores.extend(sim.tolist())
        processed_examples.extend([(t1[i], t2[i], sim[i], batch[i].label) for i in range(len(batch))])

    # Находим оптимальный порог
    optimal_threshold = find_optimal_threshold(labels, scores)

    # Вычисляем метрики
    preds = [1 if s >= optimal_threshold else 0 for s in scores]
    p, r, f, _ = precision_recall_fscore_support(labels, preds, average='binary')
    auc = roc_auc_score(labels, scores)

    return {
        'precision': p,
        'recall': r,
        'f1': f,
        'roc_auc': auc,
        'optimal_threshold': optimal_threshold,
        'processed_examples': processed_examples
    }


In [15]:
# Контрастное обучение с callback для обновления эпохи
class HardNegativeCallback:
    def __init__(self, hard_miner):
        self.hard_miner = hard_miner

    def on_epoch_begin(self, *args, **kwargs):
        if hasattr(self.hard_miner, 'set_epoch'):
            self.hard_miner.set_epoch(kwargs.get('epoch', 0))

# Создаем evaluator с post-processing
class CustomBinaryClassificationEvaluator(BinaryClassificationEvaluator):
    def __init__(self, *args, post_processor=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.post_processor = post_processor

    def __call__(self, model, output_path=None, epoch=-1, steps=-1):
        # Переопределяем метод для использования post-processor
        if self.post_processor:
            # Временно заменяем encode метод
            original_encode = model.encode
            def patched_encode(sentences, *args, **kwargs):
                embeddings = original_encode(sentences, *args, **kwargs)
                return embeddings
            model.encode = patched_encode

            # Выполняем оценку
            result = super().__call__(model, output_path, epoch, steps)

            # Восстанавливаем оригинальный метод
            model.encode = original_encode
            return result
        else:
            return super().__call__(model, output_path, epoch, steps)

evaluator = CustomBinaryClassificationEvaluator.from_input_examples(
    val_ex, name="val", show_progress_bar=False, post_processor=post_processor
)

# Callback для обновления эпохи в hard miner
hard_negative_callback = HardNegativeCallback(hard_miner)

st_model.fit(
    train_objectives=[(train_loader_hard, losses.ContrastiveLoss(model=st_model))],
    evaluator=evaluator,
    epochs=6,
    evaluation_steps=max(100, len(train_loader_hard) // 5),
    warmup_steps=50,
    use_amp=True,
    output_path=OUTPUT_DIR,
    optimizer_params={'lr': 2e-5},
    callback=hard_negative_callback.on_epoch_begin if hasattr(hard_negative_callback, 'on_epoch_begin') else None
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Val Cosine Accuracy,Val Cosine Accuracy Threshold,Val Cosine F1,Val Cosine F1 Threshold,Val Cosine Precision,Val Cosine Recall,Val Cosine Ap,Val Cosine Mcc
897,0.0186,No log,0.914846,0.70886,0.839361,0.700812,0.834448,0.844332,0.902946,0.781457
1794,0.0094,No log,0.944271,0.712957,0.897498,0.68173,0.858025,0.940778,0.952257,0.860236
2691,0.0065,No log,0.946946,0.717437,0.900418,0.717437,0.890728,0.910321,0.957391,0.86436
3588,0.0063,No log,0.960767,0.736355,0.926544,0.728295,0.914333,0.939086,0.975504,0.899936
4485,0.0056,No log,0.963888,0.738066,0.933002,0.711693,0.912621,0.954315,0.978916,0.908726
4486,0.0056,No log,0.964779,0.716763,0.934002,0.716763,0.922442,0.945854,0.978794,0.91012
5382,0.0033,No log,0.967008,0.774922,0.93771,0.743715,0.932998,0.94247,0.980761,0.915295
6279,0.0038,No log,0.963888,0.774737,0.931298,0.774737,0.933673,0.928934,0.981944,0.90681
7176,0.0032,No log,0.970575,0.776519,0.944444,0.776519,0.939698,0.949239,0.983739,0.924455
8073,0.0035,No log,0.974588,0.79625,0.951342,0.70007,0.943428,0.959391,0.98672,0.933797


  a = torch.tensor(a)


Step,Training Loss,Validation Loss,Val Cosine Accuracy,Val Cosine Accuracy Threshold,Val Cosine F1,Val Cosine F1 Threshold,Val Cosine Precision,Val Cosine Recall,Val Cosine Ap,Val Cosine Mcc
897,0.0186,No log,0.914846,0.70886,0.839361,0.700812,0.834448,0.844332,0.902946,0.781457
1794,0.0094,No log,0.944271,0.712957,0.897498,0.68173,0.858025,0.940778,0.952257,0.860236
2691,0.0065,No log,0.946946,0.717437,0.900418,0.717437,0.890728,0.910321,0.957391,0.86436
3588,0.0063,No log,0.960767,0.736355,0.926544,0.728295,0.914333,0.939086,0.975504,0.899936
4485,0.0056,No log,0.963888,0.738066,0.933002,0.711693,0.912621,0.954315,0.978916,0.908726
4486,0.0056,No log,0.964779,0.716763,0.934002,0.716763,0.922442,0.945854,0.978794,0.91012
5382,0.0033,No log,0.967008,0.774922,0.93771,0.743715,0.932998,0.94247,0.980761,0.915295
6279,0.0038,No log,0.963888,0.774737,0.931298,0.774737,0.933673,0.928934,0.981944,0.90681
7176,0.0032,No log,0.970575,0.776519,0.944444,0.776519,0.939698,0.949239,0.983739,0.924455
8073,0.0035,No log,0.974588,0.79625,0.951342,0.70007,0.943428,0.959391,0.98672,0.933797


In [16]:
# --- Финальная оценка ---
best = SentenceTransformer(OUTPUT_DIR, device=device)
results = evaluate_with_dynamic_threshold(test_loader, best, post_processor)

print("Test metrics:")
print(f"Precision: {results['precision']:.4f}")
print(f"Recall: {results['recall']:.4f}")
print(f"F1-score: {results['f1']:.4f}")
print(f"ROC-AUC: {results['roc_auc']:.4f}")
print(f"Optimal threshold: {results['optimal_threshold']:.4f}")

# Анализ проблемных пар
print("\nAnalysis of problematic pairs:")
problematic_pairs = [
    ("увеличение содержание ион калий в сыворотка", "увеличивать выведение ион натрий"),
    ("cyp2c8", "cyp2c9"),
    ("стимулировать образование весь тип интерферон α", "стимулировать образование весь тип интерферон β"),
    ("рассабление гладкий мышца мочевыводящий путь", "рассабление гладкий мышца желчевыводить путь"),
    ("бета2-адренорецептор гладкий мускулатура матка", "бета1-адреноблокатор")
]

for t1, t2 in problematic_pairs:
    e1 = best.encode([t1], convert_to_tensor=True)
    e2 = best.encode([t2], convert_to_tensor=True)
    sim = torch.nn.functional.cosine_similarity(e1, e2).item()

    # Применяем пост-фильтрацию
    filtered_sim = post_processor.should_filter_similarity(t1, t2, sim)

    print(f"'{t1}' vs '{t2}': raw={sim:.4f}, filtered={filtered_sim:.4f}")

Test metrics:
Precision: 0.9764
Recall: 0.9847
F1-score: 0.9805
ROC-AUC: 0.9964
Optimal threshold: 0.7698

Analysis of problematic pairs:
'увеличение содержание ион калий в сыворотка' vs 'увеличивать выведение ион натрий': raw=0.3660, filtered=0.3660
'cyp2c8' vs 'cyp2c9': raw=0.9956, filtered=0.8500
'стимулировать образование весь тип интерферон α' vs 'стимулировать образование весь тип интерферон β': raw=0.9850, filtered=0.9850
'рассабление гладкий мышца мочевыводящий путь' vs 'рассабление гладкий мышца желчевыводить путь': raw=0.8372, filtered=0.8372
'бета2-адренорецептор гладкий мускулатура матка' vs 'бета1-адреноблокатор': raw=0.1199, filtered=0.1199


In [17]:
import shutil
import os

folder_name = "trained_synonym_model_2"

# Архивируем папку в .zip
shutil.make_archive(f'/content/{folder_name}', 'zip', f'/content/{folder_name}')

# Скачиваем архив
from google.colab import files
files.download(f'/content/{folder_name}.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>