In [None]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,expandable_segments:True"

import random
import json
import numpy as np
import torch
from itertools import combinations
from collections import defaultdict

from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling,
    Trainer, TrainingArguments, EarlyStoppingCallback
)
from sentence_transformers import SentenceTransformer, models, InputExample, losses, evaluation
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

In [None]:
SEED = 42

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")

print("torch.cuda.is_available():", has_cuda)

# Raw text corpus used for the MLM domain-adaptation pass.
CORPUS_FILE = "/content/text_corpus_grls_rlsnet.txt"

# JSON files with labelled synonym clusters used to build the training pairs.
SYN_DATA_FILE_LIST = [
    "/content/clusters_2025_08_20_15.json",
    "/content/clusters_2025_08_20_200.json",
    "/content/clusters_2025_08_20_400.json",
    "/content/clusters_synosym_dict.json",
]

# Directory where the fine-tuned sentence-embedding model is written.
OUTPUT_DIR = "trained_synonym_model"

# 1. Domain-adaptive pretraining (MLM) on short sequences
class TextDataset(Dataset):
    """Pre-tokenized text corpus for masked-language-model training.

    All texts are tokenized once in the constructor, truncated and padded to
    ``max_len`` tokens. Each item exposes ``input_ids`` and, when the
    tokenizer provides it, the matching ``attention_mask``.

    Args:
        texts: list of raw strings to tokenize.
        tokenizer: callable tokenizer returning a mapping with ``input_ids``
            (and normally ``attention_mask``) as tensors.
        max_len: fixed sequence length after truncation/padding.
    """

    def __init__(self, texts, tokenizer, max_len=128):
        toks = tokenizer(texts, truncation=True, padding="max_length",
                         max_length=max_len, return_tensors="pt")
        self.input_ids = toks["input_ids"]
        # Keep the real attention mask. If an item carries only input_ids the
        # collator's tokenizer.pad() call synthesises an all-ones mask, which
        # makes the encoder attend to the max_length padding of every sample.
        self.attention_mask = toks["attention_mask"] if "attention_mask" in toks else None

    def __len__(self):
        """Return the number of tokenized texts."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the feature dict for sample ``idx``."""
        item = {"input_ids": self.input_ids[idx]}
        if self.attention_mask is not None:
            item["attention_mask"] = self.attention_mask[idx]
        return item

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# Read the corpus, keeping only lines with more than three whitespace-separated tokens.
with open(CORPUS_FILE, encoding="utf-8") as f:
    lines = [l.strip() for l in f if len(l.split()) > 3]
# De-duplicate while preserving file order. list(set(...)) yields a different
# order on every interpreter start (string hash randomisation), which would
# make the run irreproducible despite the fixed SEED.
texts = list(dict.fromkeys(lines))

dataset = TextDataset(texts, tokenizer, max_len=128)
# Standard MLM objective: 15% of the non-special tokens are selected for masking.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="dapt_mlm",
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch of 16 sequences per optimizer step
    num_train_epochs=3,
    # Mixed precision is only valid on CUDA; requesting it unconditionally
    # makes TrainingArguments fail on the CPU fallback chosen for `device`.
    fp16=torch.cuda.is_available(),
    save_steps=5000,
    logging_steps=2000,
    learning_rate=3e-5,
    seed=SEED,
    # Supported way to switch off logging integrations (WANDB_DISABLED is deprecated).
    report_to="none",
)

# NOTE(review): the load warning in this notebook's output shows that this
# checkpoint ships no MLM head, so the lm_head.* weights start from random
# initialisation — confirm that is acceptable for the adaptation pass.
model = AutoModelForMaskedLM.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Some weights of XLMRobertaForMaskedLM were not initialized from the model checkpoint at sentence-transformers/paraphrase-multilingual-mpnet-base-v2 and are newly initialized: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Run the masked-language-model training configured on `trainer`.
trainer.train()
# Save the adapted weights and the tokenizer side by side in "dapt_mlm";
# the sentence-embedding model is later built from this directory.
model.save_pretrained("dapt_mlm")
tokenizer.save_pretrained("dapt_mlm")

Step,Training Loss


('dapt_mlm/tokenizer_config.json',
 'dapt_mlm/special_tokens_map.json',
 'dapt_mlm/sentencepiece.bpe.model',
 'dapt_mlm/added_tokens.json',
 'dapt_mlm/tokenizer.json')

In [None]:
# --- Step 2: build the labelled example pairs ---
def _canonical_pair(a, b):
    """Return the two labels in sorted order so (a, b) and (b, a) compare equal.

    Assumes labels are mutually comparable (plain strings).
    """
    return (a, b) if a <= b else (b, a)


def _pairs_from_cluster(labelled_items):
    """Derive positive and negative pairs from one cluster.

    Args:
        labelled_items: iterable of ``(label, group_id)``; ``group_id`` is
            ``None`` for labels that belong to no synonym group.

    Returns:
        ``(positives, negatives)`` as two sets of canonical pairs:
        positives link labels sharing a non-null ``group_id``; negatives link
        labels of different groups, and every ungrouped label with every
        grouped label of the same cluster.
    """
    grouped = defaultdict(list)
    ungrouped = []
    for label, group_id in labelled_items:
        if group_id is None:
            ungrouped.append(label)
        else:
            grouped[group_id].append(label)

    positives, negatives = set(), set()

    # Positive pairs: inside one group (same non-null group_id).
    for members in grouped.values():
        positives.update(
            _canonical_pair(a, b) for a, b in combinations(members, 2) if a != b
        )

    # Negative pairs: across different groups of the same cluster.
    for members_a, members_b in combinations(list(grouped.values()), 2):
        negatives.update(
            _canonical_pair(a, b) for a in members_a for b in members_b if a != b
        )

    # Negative pairs: every ungrouped (null) label against every grouped label.
    grouped_labels = [label for members in grouped.values() for label in members]
    negatives.update(
        _canonical_pair(a, b) for a in ungrouped for b in grouped_labels if a != b
    )
    return positives, negatives


pos_set, neg_set = set(), set()

for filename in SYN_DATA_FILE_LIST:
    with open(filename, 'r', encoding="utf-8") as f:
        clusters = json.load(f)["clusters"]

    for cluster_data in clusters.values():
        cluster_pos, cluster_neg = _pairs_from_cluster(cluster_data["labels"])
        pos_set |= cluster_pos
        neg_set |= cluster_neg

# De-duplicate. Pairs are canonical, so a reversed copy of a positive pair can
# no longer survive as a negative or land in a different split. Sorting gives a
# stable order: iterating a set of strings differs between interpreter runs,
# which would defeat the fixed seed.
pos_pairs = sorted(pos_set)
neg_pairs = sorted(neg_set - pos_set)

# Balance the classes: keep at most as many negatives as there are positives.
# A dedicated RNG keeps the result identical when the cell is re-run.
pair_rng = random.Random(SEED)
pair_rng.shuffle(neg_pairs)
neg_pairs = neg_pairs[:len(pos_pairs)]

examples = [InputExample(texts=[a, b], label=1.0) for a, b in pos_pairs] + \
           [InputExample(texts=[a, b], label=0.0) for a, b in neg_pairs]
pair_rng.shuffle(examples)

In [None]:
# Split into training, validation and test sets (80% / 10% / 10%)
n = len(examples)
train_end, val_end = int(0.8 * n), int(0.9 * n)
train_ex = examples[:train_end]
val_ex = examples[train_end:val_end]
test_ex = examples[val_end:]


def _keep_examples(batch):
    """Collate function that hands the list of InputExample objects through unchanged."""
    return batch


train_loader = DataLoader(train_ex, batch_size=4, shuffle=True, collate_fn=_keep_examples)
val_loader = DataLoader(val_ex, batch_size=8, shuffle=False, collate_fn=_keep_examples)
test_loader = DataLoader(test_ex, batch_size=8, shuffle=False, collate_fn=_keep_examples)

In [None]:
# --- Step 3: SentenceTransformer built on the DAPT checkpoint ---
transformer = models.Transformer("dapt_mlm", max_seq_length=128)
embedding_dim = transformer.get_word_embedding_dimension()
# Sentence vector = mean of the token embeddings.
pooling = models.Pooling(embedding_dim, pooling_mode_mean_tokens=True)
st_model = SentenceTransformer(
    modules=[transformer, pooling],
    device=device,
)

# Hard negatives
def mine_hard(batch, model, k=1):
    """Mine in-batch hard negatives for the first text of every pair.

    For each example the first text is the anchor. The ``k`` most
    cosine-similar other texts of the batch are paired with it and labelled
    0.0. A candidate is skipped when it is:

    * the anchor's own string or its pair partner (including copies of
      either elsewhere in the batch);
    * a text linked to the anchor by a positive (label 1) example of this
      batch, which is a known synonym and must not become a negative.

    Only direct in-batch positives are known here; synonyms that are not
    paired with the anchor inside this batch cannot be detected.

    Args:
        batch: list of examples, each with exactly two ``texts`` and a ``label``.
        model: object whose ``encode(texts, convert_to_tensor=True, batch_size=...)``
            returns one embedding row per text.
        k: maximum number of negatives per anchor.

    Returns:
        List of ``InputExample`` with label 0.0; empty when the batch offers
        no valid candidate (e.g. a single-pair batch).
    """
    texts = [t for ex in batch for t in ex.texts]
    if k <= 0 or len(texts) <= 2:
        # A lone pair has nothing to mine from: its only "candidates" are the
        # anchor itself and its own partner.
        return []

    # Known synonyms inside this batch, taken from the positive examples.
    known_pos = defaultdict(set)
    for ex in batch:
        if ex.label == 1:
            first, second = ex.texts[0], ex.texts[1]
            known_pos[first].add(second)
            known_pos[second].add(first)

    emb = model.encode(texts, convert_to_tensor=True, batch_size=4)
    emb = torch.nn.functional.normalize(emb, dim=1)
    sims = emb @ emb.T  # cosine similarity between every two texts of the batch

    hard = []
    for i in range(0, len(texts), 2):
        anchor = texts[i]
        # `excluded` doubles as a seen-set so each distinct text is offered once.
        excluded = {anchor, texts[i + 1]} | known_pos.get(anchor, set())
        candidates = []
        for j, cand in enumerate(texts):
            if cand not in excluded:
                excluded.add(cand)
                candidates.append(j)
        if not candidates:
            continue
        scores = sims[i, candidates]
        top = torch.topk(scores, min(k, len(candidates))).indices.tolist()
        for pos in top:
            hard.append(InputExample(texts=[anchor, texts[candidates[pos]]], label=0.0))
    return hard

class HardNegDataset(Dataset):
    """Example list bundled with the model used to mine hard negatives for it."""

    def __init__(self, examples, model):
        # Training examples, plus the encoder handed to mine_hard() in collate().
        self.examples, self.model = examples, model

    def __len__(self):
        """Number of stored examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.examples[idx]

    def collate(self, batch):
        """Return the batch followed by the hard negatives mined from it."""
        mined = mine_hard(batch, self.model)
        return batch + mined

hard_ds = HardNegDataset(train_ex, st_model)

# SentenceTransformer.fit() overwrites the DataLoader's collate_fn with its own,
# so negatives mined inside a custom collate_fn never reach the loss. Run the
# mining pass once up front and train on the materialised result instead.
# The negatives therefore reflect the pre-fine-tuning embeddings.
mining_loader = DataLoader(
    hard_ds,
    shuffle=True,
    batch_size=4,
    collate_fn=hard_ds.collate
)
train_ex_hard = [ex for mined_batch in mining_loader for ex in mined_batch]

train_loader_hard = DataLoader(
    train_ex_hard,
    shuffle=True,
    batch_size=4,
    collate_fn=lambda batch: batch
)

# Contrastive training
train_loss = losses.ContrastiveLoss(model=st_model)
evaluator  = evaluation.BinaryClassificationEvaluator.from_input_examples(
    val_ex, name="val", show_progress_bar=False
)

st_model.fit(
    train_objectives=[(train_loader_hard, train_loss)],
    evaluator=evaluator,
    epochs=10,
    evaluation_steps=len(train_loader_hard),  # evaluate once per epoch
    warmup_steps=50,
    # Automatic mixed precision needs CUDA; stay in full precision on the CPU fallback.
    use_amp=torch.cuda.is_available(),
    output_path=OUTPUT_DIR,
    optimizer_params={'lr': 2e-5}
)

Some weights of XLMRobertaModel were not initialized from the model checkpoint at dapt_mlm and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Val Cosine Accuracy,Val Cosine Accuracy Threshold,Val Cosine F1,Val Cosine F1 Threshold,Val Cosine Precision,Val Cosine Recall,Val Cosine Ap,Val Cosine Mcc
258,No log,No log,0.96124,0.695152,0.964029,0.695152,0.957143,0.971014,0.971978,0.922129
516,0.009200,No log,0.937984,0.797684,0.943662,0.596288,0.917808,0.971014,0.97474,0.876537
774,0.009200,No log,0.945736,0.90818,0.948148,0.90818,0.969697,0.927536,0.979386,0.892265
1032,0.002000,No log,0.945736,0.865521,0.948148,0.865521,0.969697,0.927536,0.980088,0.892265
1290,0.002000,No log,0.937984,0.888971,0.940299,0.888971,0.969231,0.913043,0.979718,0.877593
1548,0.001000,No log,0.937984,0.866359,0.942029,0.757279,0.942029,0.942029,0.983018,0.875362
1806,0.001000,No log,0.945736,0.700096,0.950355,0.700096,0.930556,0.971014,0.984847,0.891566
2064,0.000600,No log,0.930233,0.828539,0.937931,0.55908,0.894737,0.985507,0.983349,0.86394
2322,0.000600,No log,0.937984,0.776559,0.942029,0.776559,0.942029,0.942029,0.982385,0.875362
2580,0.000500,No log,0.937984,0.771797,0.942029,0.771797,0.942029,0.942029,0.981244,0.875362


  a = torch.tensor(a)


In [None]:
# --- Step 4: evaluation ---
def evaluate(loader, model, thr=0.92):
    """Score a pair classifier based on cosine similarity of sentence embeddings.

    Args:
        loader: iterable of batches; each batch is a list of examples with two
            ``texts`` and a ``label``.
        model: encoder exposing ``encode(texts, convert_to_tensor=True, batch_size=...)``.
        thr: a pair is predicted positive when its cosine similarity is >= ``thr``.

    Returns:
        Dict with ``precision``, ``recall``, ``f1`` (binary averaging at ``thr``)
        and the threshold-free ``roc_auc`` computed from the raw similarities.
    """
    gold, cosines = [], []
    for batch in loader:
        left = [ex.texts[0] for ex in batch]
        right = [ex.texts[1] for ex in batch]
        left_emb = model.encode(left, convert_to_tensor=True, batch_size=4)
        right_emb = model.encode(right, convert_to_tensor=True, batch_size=4)
        batch_cos = torch.nn.functional.cosine_similarity(left_emb, right_emb)
        gold.extend(ex.label for ex in batch)
        cosines.extend(batch_cos.cpu().numpy().tolist())

    predicted = [int(score >= thr) for score in cosines]
    precision, recall, f1, _support = precision_recall_fscore_support(
        gold, predicted, average='binary'
    )
    auc = roc_auc_score(gold, cosines)
    return {'precision': precision, 'recall': recall, 'f1': f1, 'roc_auc': auc}

# Reload the model saved under OUTPUT_DIR and score it on the held-out test split.
best = SentenceTransformer(OUTPUT_DIR, device=device)
test_metrics = evaluate(test_loader, best)
print("Test metrics:", test_metrics)

Test metrics: {'precision': 1.0, 'recall': 0.9538461538461539, 'f1': 0.9763779527559056, 'roc_auc': np.float64(0.9911057692307693)}
