This notebook trains the model on positive and negative sentence pairs.

In [28]:
import os
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'

import numpy as np
import pandas as pd
import wandb
import torch
import argparse
from tqdm import tqdm
from torch.optim import AdamW
from datetime import datetime
from sentence_transformers import losses, SentenceTransformer
from types import SimpleNamespace

from modules.ModelFunctions import get_ST_model, auto_load_model
from modules.timed_logger import logger
from modules.metrics import evaluate_embedding_similarity_with_mrr
from modules.STHardNegMiner import mine_negatives, _pairs_to_dataset
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import InputExample
from datasets import Dataset


def to_hf_dataset(dataset_iterable):
    """Convert an iterable of pair records into a Hugging Face ``Dataset``.

    Each item must support ``item["sentence1"]``, ``item["sentence2"]`` and
    ``item["label"]``; the label is cast with ``int()``.

    The columns are collected explicitly (instead of row dicts) so that an
    empty iterable still produces a dataset with the three expected, empty
    columns rather than failing when the columns are selected.

    Args:
        dataset_iterable: Iterable of mapping-like records (a progress bar is
            shown while it is consumed).

    Returns:
        A ``Dataset`` with the columns ``sentence1``, ``sentence2`` and
        ``label``, in that order.
    """
    columns = {"sentence1": [], "sentence2": [], "label": []}
    for item in tqdm(dataset_iterable):
        columns["sentence1"].append(item["sentence1"])
        columns["sentence2"].append(item["sentence2"])
        columns["label"].append(int(item["label"]))  # store labels as plain ints
    ds = Dataset.from_dict(columns)
    # Ensure correct column order for the loss: [inputs..., label]
    return ds.select_columns(["sentence1", "sentence2", "label"])



# Reset the project logger's timer before the run starts.
# NOTE(review): presumably this makes the "Elapsed" / "Since last" values in
# later log lines count from here - confirm against modules.timed_logger.
logger.reset_timer()

In [29]:

# Run configuration, exposed as attributes (``args.<name>``).
# NOTE(review): the cells visible in this notebook only read test_mode,
# batch_size, model_checkpoint, use_relation and the n_pos_* / n_neg_* counts;
# confirm whether the remaining options are still consumed elsewhere.
args = SimpleNamespace(
    test_mode=False,

    # --- negative-mining options ---
    no_relation=False,          # Disable relation data even if files exist or config says True
    range_min=10,               # Minimum rank for candidate negatives
    range_max=50,               # Maximum rank for candidate negatives
    relative_margin=0.01,       # Relative margin for mining
    num_neg_matching=None,      # Negatives per anchor for matching mining (override)
    num_neg_relation=None,      # Negatives per anchor for relation mining (override)
    sampling_strategy="top",    # Negative sampling strategy from candidates
    batch_size=256,             # Batch size for training
    no_faiss=False,             # Disable FAISS acceleration in miner
    model_checkpoint="none",    # Path to model checkpoint to load (default: none - use HF base model)

    # --- dataset construction options ---
    use_relation=True,
    n_pos_matching=5,
    n_neg_matching=5,
    n_fp_matching=5,
    n_pos_relation=5,
    n_neg_relation=5,
    n_fp_relation=5,
)


# Initialize Model

In [30]:

logger.log("Loading Model (initial training)")

# Every run writes into its own timestamped output directory.
output_dir = f"output/finetune_initial/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
base_model = 'all-MiniLM-L6-v2'

# Start with nothing loaded; an explicit checkpoint (any non-empty value other
# than the "none" sentinel) is tried first.
model, tokenizer = None, None
if args.model_checkpoint and args.model_checkpoint != "none":
    model, tokenizer = auto_load_model(args.model_checkpoint)

# Fall back to the base model when no checkpoint was requested or the
# checkpoint loader did not return a model.
if model is None:
    logger.log(f"Loading base model with special tokens: {base_model}")
    model, tokenizer = get_ST_model(base_model)

2025-10-02 15:10:28 - Loading Model (initial training) (Elapsed: 0.35s, Since last: 0.35s)
2025-10-02 15:10:28 - Loading base model with special tokens: all-MiniLM-L6-v2 (Elapsed: 0.35s, Since last: 0.00s)
2025-10-02 15:10:28 - Use pytorch device_name: cuda:0
2025-10-02 15:10:28 - Load pretrained SentenceTransformer: models/all-MiniLM-L6-v2_ST_childof_parentof\auto_save_1_20251001_164241


Loaded latest auto-saved model from: models/all-MiniLM-L6-v2_ST_childof_parentof\auto_save_1_20251001_164241


# Load OMOP data

In [31]:
logger.log("Loading training data")

# Data locations and the seed shared by the dataset builders and the trainer.
matching_base_path = "data/matching"
relation_base_path = "data/relation"
seed = 42

# Feather tables used to build the matching pairs.
target_concepts_path = os.path.join(matching_base_path, 'target_concepts.feather')
matching_name_bridge_path = os.path.join(matching_base_path, 'condition_matching_name_bridge_train.feather')
matching_name_table_path = os.path.join(matching_base_path, 'condition_matching_name_table_train.feather')

target_concepts = pd.read_feather(target_concepts_path)
matching_name_bridge = pd.read_feather(matching_name_bridge_path)
matching_name_table = pd.read_feather(matching_name_table_path)

# Show the three table shapes as the cell output.
tuple(table.shape for table in (target_concepts, matching_name_bridge, matching_name_table))

2025-10-02 15:10:29 - Loading training data (Elapsed: 0.84s, Since last: 0.49s)


((160288, 2), (643319, 2), (566536, 5))

# Training dataset creation

In [26]:
from modules.Dataset import PositiveDataset, NegativeDataset, CombinedDataset
# Arguments shared by the positive and negative matching datasets.
matching_shared_kwargs = dict(
    target_concepts=target_concepts,
    name_table=matching_name_table,
    seed=seed,
)

# Positive pairs, built from the name bridge.
matching_pos = PositiveDataset(
    name_bridge=matching_name_bridge,
    max_elements=args.n_pos_matching,
    **matching_shared_kwargs,
)
# Print the dataset summary followed by its first record.
print(matching_pos, matching_pos[0], sep="\n")

# Negative pairs; the same bridge is supplied as the blacklist.
# NOTE(review): presumably so known positive pairs are never sampled as
# negatives - confirm in modules.Dataset.
matching_neg = NegativeDataset(
    blacklist_bridge=matching_name_bridge,
    max_elements=args.n_neg_matching,
    **matching_shared_kwargs,
)
print(matching_neg, matching_neg[0], sep="\n")


# if args.test_mode:
#     anchor_positive_match = anchor_positive_match.iloc[:1000]

PositiveDataset(length=276059, label=1, seed=42)
{'sentence1': 'Epidermal burn of multiple sites of upper limb', 'sentence2': 'Corrosion of first degree of multiple sites of right shoulder and upper limb, except wrist and hand', 'label': 1}
NegativeDataset(length=801440, seed=42)
{'sentence1': 'Infection present (Deprecated)', 'sentence2': 'Familial Retinoblastomas', 'label': 0}


In [32]:
# Materialise the matching datasets as Hugging Face datasets with the columns
# sentence1 / sentence2 / label, printing each summary right after it is built.
matching_pos_pairs = to_hf_dataset(matching_pos)
print(matching_pos_pairs)
matching_neg_pairs = to_hf_dataset(matching_neg)
print(matching_neg_pairs)

  1%|▏         | 4056/276059 [00:00<00:06, 40451.35it/s]

100%|██████████| 276059/276059 [00:06<00:00, 41719.59it/s]


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 276059
})


100%|██████████| 801440/801440 [00:18<00:00, 43740.47it/s]


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 801440
})


In [34]:

if args.use_relation:
    from modules.Dataset import PositiveDataset as RelPositiveDataset
    from modules.Dataset import NegativeDataset as RelNegativeDataset

    # The relation tables are read only when relation training is enabled, so
    # switching it off skips this I/O and does not need these two files.
    name_table_relation = pd.read_feather(os.path.join(relation_base_path, 'name_table_relation.feather'))
    name_bridge_relation = pd.read_feather(os.path.join(relation_base_path, 'name_bridge_relation.feather'))

    # Positive relation pairs, built from the relation bridge.
    relation_pos = RelPositiveDataset(
        target_concepts=target_concepts,
        name_table=name_table_relation,
        name_bridge=name_bridge_relation,
        max_elements=args.n_pos_relation,
        seed=seed
    )
    print(relation_pos)
    print(relation_pos[0])

    # Negative relation pairs; the relation bridge is supplied as the blacklist.
    relation_neg = RelNegativeDataset(
        target_concepts=target_concepts,
        name_table=name_table_relation,
        blacklist_bridge=name_bridge_relation,
        max_elements=args.n_neg_relation,
        seed=seed
    )
    print(relation_neg)
    print(relation_neg[0])
else:
    # Keep the names defined so later cells can test them instead of hitting a NameError.
    name_table_relation = name_bridge_relation = None
    relation_pos = relation_neg = None



PositiveDataset(length=135385, label=1, seed=42)
{'sentence1': 'Meningococcal infectious disease', 'sentence2': '<|parent of|>Acute meningococcal pericarditis', 'label': 1}
NegativeDataset(length=801440, seed=42)
{'sentence1': 'Infection present (Deprecated)', 'sentence2': '<|parent of|>Papillary adenocarcinoma, NOS, of tonsillar fossa', 'label': 0}


In [38]:
# Default to "no relation pairs"; replaced below when relation training is on.
relation_pos_pairs = relation_neg_pairs = None

if args.use_relation:
    relation_pos_pairs = to_hf_dataset(relation_pos)
    print(relation_pos_pairs)
    relation_neg_pairs = to_hf_dataset(relation_neg)
    print(relation_neg_pairs)

  0%|          | 0/135385 [00:00<?, ?it/s]

100%|██████████| 135385/135385 [00:03<00:00, 42606.31it/s]


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 135385
})


100%|██████████| 801440/801440 [00:18<00:00, 42676.59it/s]


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 801440
})


## Combine and create the final training dataset

In [40]:
from datasets import concatenate_datasets
# The matching pairs are always part of the training set; the relation pairs
# are appended only when relation training is enabled.
all_pairs = [matching_pos_pairs, matching_neg_pairs]
if args.use_relation:
    all_pairs += [relation_pos_pairs, relation_neg_pairs]

ds_all = concatenate_datasets(all_pairs)
ds_all


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 2014324
})

# Validation dataset

In [19]:
logger.log("Loading validation data")


def _read_feather_and_print_columns(base_path, file_name):
    """Read ``file_name`` under ``base_path`` into a DataFrame and print its columns."""
    frame = pd.read_feather(os.path.join(base_path, file_name))
    print(frame.columns)
    return frame


# Held-out matching set plus two training subsets (matching and relation)
# that are passed to the evaluators.
condition_matching_valid = _read_feather_and_print_columns(
    matching_base_path, 'condition_matching_valid.feather'
)
condition_matching_train_subset = _read_feather_and_print_columns(
    matching_base_path, 'condition_matching_train_subset.feather'
)
condition_relation_train_subset = _read_feather_and_print_columns(
    relation_base_path, 'condition_relation_train_subset.feather'
)

2025-10-02 15:06:15 - Loading validation data (Elapsed: 27.39s, Since last: 27.39s)


Index(['corpus_name', 'query_name', 'corpus_id', 'query_id', 'label'], dtype='object')
Index(['query_name', 'corpus_name', 'query_id', 'corpus_id', 'label'], dtype='object')
Index(['query_name', 'corpus_name', 'query_id', 'corpus_id', 'label'], dtype='object')


# Model Training settings

In [20]:
from sentence_transformers.evaluation import SentenceEvaluator
# Create a custom evaluator using your existing evaluation function
class CustomMRREvaluator(SentenceEvaluator):
    """Evaluator that wraps ``evaluate_embedding_similarity_with_mrr``.

    Calling the evaluator runs the project evaluation function on the stored
    validation data and returns *all* of its metrics, each key prefixed with
    ``"<name>/"`` so results from several evaluators do not collide.

    NOTE(review): ``SentenceEvaluator.__init__`` is not invoked here - confirm
    that this is intended.
    """

    def __init__(self, validation_data, name="custom_eval"):
        """Remember the evaluation data and the prefix used for metric keys."""
        self.name = name
        self.validation_data = validation_data

    def __call__(self, model, output_path=None, epoch=None, steps=None):
        """Evaluate ``model`` and return a dict of ``"<name>/<metric>"`` values.

        ``output_path``, ``epoch`` and ``steps`` are accepted for interface
        compatibility but are not used.
        """
        raw_metrics = evaluate_embedding_similarity_with_mrr(model, self.validation_data)
        prefix = self.name
        return {f"{prefix}/{metric}": value for metric, value in raw_metrics.items()}
    

# One evaluator per evaluation set; ``name`` becomes the metric-key prefix.
evaluator_valid = CustomMRREvaluator(
    validation_data=condition_matching_valid, name="eval_valid"
)
evaluator_train_matching = CustomMRREvaluator(
    validation_data=condition_matching_train_subset, name="train_matching"
)
evaluator_train_relation = CustomMRREvaluator(
    validation_data=condition_relation_train_subset, name="train_relation"
)

# Baseline metrics of the model on the validation set before fine-tuning.
evaluator_valid(model)

Batches:   0%|          | 0/84 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Batches:   0%|          | 0/16 [00:00<?, ?it/s]

{'eval_valid/roc_auc': 0.7680990833089902,
 'eval_valid/average_precision': 0.4821933667885799,
 'eval_valid/f1_score': 0.44631901840490795,
 'eval_valid/precision': 0.36014851485148514,
 'eval_valid/recall': 0.5866935483870968,
 'eval_valid/accuracy': 0.7435168738898756,
 'eval_valid/best_hit1': 0.7157258064516129,
 'eval_valid/best_hit3': 0.8487903225806451,
 'eval_valid/best_hit5': 0.9153225806451613,
 'eval_valid/best_hit10': 1.0,
 'eval_valid/best_hit20': 1.0,
 'eval_valid/best_hit50': 1.0,
 'eval_valid/best_hit100': 1.0,
 'eval_valid/best_reciprocal_rank': 0.802284946236559,
 'eval_valid/worst_hit1': 0.7157258064516129,
 'eval_valid/worst_hit3': 0.8487903225806451,
 'eval_valid/worst_hit5': 0.9153225806451613,
 'eval_valid/worst_hit10': 1.0,
 'eval_valid/worst_hit20': 1.0,
 'eval_valid/worst_hit50': 1.0,
 'eval_valid/worst_hit100': 1.0,
 'eval_valid/worst_reciprocal_rank': 0.802284946236559}

In [43]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

# Checkpoint retention and run length.
max_saves = 4
epoch_num = 10
batch_size = args.batch_size

# Step intervals: logging, evaluation, and checkpointing (every second evaluation).
wandb_report_steps = 256
arg_eval_steps = 2048
arg_saving_steps = 2 * arg_eval_steps

learning_rate = 2e-5

loss_func = losses.ContrastiveLoss(model=model)
# NOTE(review): this optimizer is created but not passed to the trainer in the
# cell that constructs it - confirm whether it is still needed.
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Create training arguments equivalent to your custom training loop
training_args = SentenceTransformerTrainingArguments(
    # Output and bookkeeping
    output_dir=output_dir,
    run_name=output_dir,
    report_to="none" if args.test_mode else "wandb",
    logging_steps=wandb_report_steps,
    # Optimisation schedule
    num_train_epochs=epoch_num,
    learning_rate=learning_rate,
    warmup_ratio=0.1,
    fp16=False,
    bf16=False,
    # Batching and data loading
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    remove_unused_columns=False,
    dataloader_drop_last=False,
    dataloader_num_workers=0,  # Avoid multiprocessing issues on Windows
    # Reproducibility
    seed=seed,
    data_seed=seed,
    # Evaluation and checkpointing
    eval_strategy="steps",
    eval_steps=arg_eval_steps,
    save_strategy="steps",
    save_steps=arg_saving_steps,
    save_total_limit=max_saves,
    # Best-model selection
    metric_for_best_model="eval_valid/roc_auc",
    greater_is_better=True,
    load_best_model_at_end=True,
)



In [44]:
# Initialize the trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_all,
    eval_dataset=None,  # Using custom evaluator instead
    loss=loss_func,
    evaluator=[evaluator_valid, evaluator_train_matching, evaluator_train_relation],
    tokenizer=tokenizer,
)

# Resume only when an explicit checkpoint was configured. This applies the same
# test as the model-loading cell, so None, an empty string and the "none"
# sentinel all mean "start a fresh run" (previously an empty string would have
# been passed on as a checkpoint path).
resume_checkpoint = (
    args.model_checkpoint
    if args.model_checkpoint and args.model_checkpoint != "none"
    else None
)

# Start training
logger.log(f"Starting SentenceTransformer training. Model saved to: {output_dir}")
trainer.train(resume_from_checkpoint=resume_checkpoint)

# Save the final model
trainer.save_model()
logger.log(f"Training completed. Model saved to {output_dir}")
logger.done()

2025-10-02 15:15:45 - Starting SentenceTransformer training. Model saved to: output/finetune_initial/2025-10-02_15-10-28 (Elapsed: 316.57s, Since last: 34.95s)


  0%|          | 0/78690 [00:00<?, ?it/s]

{'loss': 0.0113, 'grad_norm': 0.07827557623386383, 'learning_rate': 6.506544668954125e-07, 'epoch': 0.03}
{'loss': 0.0067, 'grad_norm': 0.042671654373407364, 'learning_rate': 1.301308933790825e-06, 'epoch': 0.07}
{'loss': 0.005, 'grad_norm': 0.023665642365813255, 'learning_rate': 1.951963400686237e-06, 'epoch': 0.1}
{'loss': 0.0044, 'grad_norm': 0.021462224423885345, 'learning_rate': 2.60261786758165e-06, 'epoch': 0.13}
{'loss': 0.0041, 'grad_norm': 0.026749519631266594, 'learning_rate': 3.253272334477062e-06, 'epoch': 0.16}
{'loss': 0.004, 'grad_norm': 0.023066626861691475, 'learning_rate': 3.903926801372474e-06, 'epoch': 0.2}
{'loss': 0.0039, 'grad_norm': 0.018647313117980957, 'learning_rate': 4.554581268267887e-06, 'epoch': 0.23}
{'loss': 0.0037, 'grad_norm': 0.03143948316574097, 'learning_rate': 5.2052357351633e-06, 'epoch': 0.26}


Batches:   0%|          | 0/84 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/86 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/103 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

{'eval_valid/roc_auc': 0.7756002309115442, 'eval_valid/average_precision': 0.4621983271783827, 'eval_valid/f1_score': 0.41245136186770426, 'eval_valid/precision': 0.2717948717948718, 'eval_valid/recall': 0.8548387096774194, 'eval_valid/accuracy': 0.5708703374777975, 'eval_valid/best_hit1': 0.6915322580645161, 'eval_valid/best_hit3': 0.8306451612903226, 'eval_valid/best_hit5': 0.9092741935483871, 'eval_valid/best_hit10': 1.0, 'eval_valid/best_hit20': 1.0, 'eval_valid/best_hit50': 1.0, 'eval_valid/best_hit100': 1.0, 'eval_valid/best_reciprocal_rank': 0.7875672043010753, 'eval_valid/worst_hit1': 0.6915322580645161, 'eval_valid/worst_hit3': 0.8306451612903226, 'eval_valid/worst_hit5': 0.9092741935483871, 'eval_valid/worst_hit10': 1.0, 'eval_valid/worst_hit20': 1.0, 'eval_valid/worst_hit50': 1.0, 'eval_valid/worst_hit100': 1.0, 'eval_valid/worst_reciprocal_rank': 0.7875672043010753, 'eval_train_matching/roc_auc': 0.796510080161091, 'eval_train_matching/average_precision': 0.5146340767190489

Batches:   0%|          | 0/84 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/86 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/103 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

{'eval_valid/roc_auc': 0.781222614029963, 'eval_valid/average_precision': 0.470859385346768, 'eval_valid/f1_score': 0.4044838860345633, 'eval_valid/precision': 0.2632218844984802, 'eval_valid/recall': 0.8729838709677419, 'eval_valid/accuracy': 0.5470692717584369, 'eval_valid/best_hit1': 0.6995967741935484, 'eval_valid/best_hit3': 0.8346774193548387, 'eval_valid/best_hit5': 0.9314516129032258, 'eval_valid/best_hit10': 1.0, 'eval_valid/best_hit20': 1.0, 'eval_valid/best_hit50': 1.0, 'eval_valid/best_hit100': 1.0, 'eval_valid/best_reciprocal_rank': 0.7945564516129031, 'eval_valid/worst_hit1': 0.6995967741935484, 'eval_valid/worst_hit3': 0.8346774193548387, 'eval_valid/worst_hit5': 0.9314516129032258, 'eval_valid/worst_hit10': 1.0, 'eval_valid/worst_hit20': 1.0, 'eval_valid/worst_hit50': 1.0, 'eval_valid/worst_hit100': 1.0, 'eval_valid/worst_reciprocal_rank': 0.7945564516129031, 'eval_train_matching/roc_auc': 0.8051521871116982, 'eval_train_matching/average_precision': 0.5147229242801051, 

2025-10-02 15:32:53 - Saving model checkpoint to output/finetune_initial/2025-10-02_15-10-28\checkpoint-4096
2025-10-02 15:32:53 - Save model to output/finetune_initial/2025-10-02_15-10-28\checkpoint-4096


{'loss': 0.003, 'grad_norm': 0.018925417214632034, 'learning_rate': 1.106112593722201e-05, 'epoch': 0.55}
{'loss': 0.0029, 'grad_norm': 0.0170204546302557, 'learning_rate': 1.1711780404117424e-05, 'epoch': 0.59}
{'loss': 0.0029, 'grad_norm': 0.02700810879468918, 'learning_rate': 1.2362434871012836e-05, 'epoch': 0.62}


KeyboardInterrupt: 