In [None]:
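# Uncomment to install dependencies into the running kernel (%pip targets the kernel's env; pin versions for reproducible runs):
# %pip install -q sentence-transformers torch tqdm pandas numpy faiss-cpu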

In [None]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import random
import faiss
import torch
import os

# Process-level settings, applied before any training or CUDA allocation:
# - WANDB_DISABLED turns off Weights & Biases reporting from the HF trainer.
# - expandable_segments lets PyTorch's CUDA caching allocator grow segments,
#   which reduces memory fragmentation.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "WANDB_DISABLED": "true",
    }
)

In [None]:
# Load the raw (compound_id, name) table. The python engine with
# on_bad_lines="skip" tolerates malformed rows but drops them silently.
raw_df = pd.read_csv(
    "finalData.csv",
    engine="python",
    on_bad_lines="skip",
    encoding="utf-8",
)
print(raw_df.shape)

df = (
    raw_df
    # Drop missing names BEFORE astype(str): otherwise every NaN becomes the
    # literal string "nan", which then looks like one synonym shared by many
    # unrelated compounds and poisons both the positive pairs and cid lookups.
    .dropna(subset=["ascii_name_clean"])
    # Normalise: lowercase, brackets/commas/hyphens -> spaces, collapse spaces.
    .assign(
        ascii_name_clean=lambda d: (
            d["ascii_name_clean"]
            .astype(str)
            .str.lower()
            .str.replace(r"[\(\)\[\],\-]", " ", regex=True)
            .str.replace(r"\s+", " ", regex=True)
            .str.strip()
        )
    )
    # Names made only of stripped punctuation end up empty and carry no signal.
    .loc[lambda d: d["ascii_name_clean"] != ""]
    # Rows that became identical after normalisation are redundant.
    .drop_duplicates(subset=["compound_id", "ascii_name_clean"])
    .reset_index(drop=True)
)
print(df.shape)

(387828, 2)


In [None]:
# Peek at the first five rows of the cleaned table.
df.head(n=5)

Unnamed: 0,compound_id,ascii_name_clean
0,3,r 3 hydroxybutanoyl n 2
1,7,s + 3 carene
2,7,1s 3 7 7 trimethylbicyclo 4.1.0 hept 3 ene
3,7,+ 3 carene
4,7,1s 6r 3 7 7 trimethylbicyclo 4.1.0 hept 3 ene


In [None]:
MODEL_NAME = "all-mpnet-base-v2"
BATCH_SIZE = 1024     # logical batch size for in-batch-negatives training
MAX_SEQ_LENGTH = 32   # compound names are short; truncating saves memory/compute

model = SentenceTransformer(
    MODEL_NAME,
    tokenizer_kwargs={"model_max_length": MAX_SEQ_LENGTH},
)
# SentenceTransformer truncates inputs to model.max_seq_length, which is read
# from the checkpoint's own sentence-transformers config, not from the
# tokenizer's model_max_length. Set it explicitly so the cap actually applies
# and a fresh model matches the checkpoint-reload path below.
model.max_seq_length = MAX_SEQ_LENGTH

# To resume after a crash, load a saved epoch checkpoint instead:
# CHECKPOINT_PATH = ""
# model = SentenceTransformer(CHECKPOINT_PATH)
# model.max_seq_length = MAX_SEQ_LENGTH

In [None]:
#NOTE check in case of loading a model to continue training
# Sanity check (mainly after reloading a checkpoint): a known synonym pair
# should score a high cosine similarity. The embeddings are L2-normalised, so
# their dot product is the cosine similarity. Printing raw vectors is not a
# check anyone can read.
sanity_embeds = model.encode(["aspirin", "acetylsalicylic acid"], normalize_embeddings=True)
print(sanity_embeds.shape, "cosine similarity:", float(sanity_embeds[0] @ sanity_embeds[1]))

[[-0.00400591 -0.05172623 -0.02116138 ... -0.03963629 -0.05777787
  -0.05471736]
 [-0.0388556  -0.0110454  -0.04449013 ...  0.04872469 -0.00099425
  -0.04864759]]


In [None]:
# Build (anchor, positive) pairs for in-batch-negatives training: any two
# distinct names that share a compound_id are treated as synonyms.
MAX_POS_PER_ANCHOR = 8      # positives sampled per anchor name
MAX_TRAIN_PAIRS = 500_000   # cap on the total number of pairs kept
PAIR_SEED = 42              # seeds pair sampling so reruns build the same set

pair_rng = random.Random(PAIR_SEED)
train_examples = []
cid_groups = df.groupby("compound_id")["ascii_name_clean"].apply(list)

for names in cid_groups:
    # dict.fromkeys dedupes while keeping first-seen order. list(set(...))
    # order depends on per-process string hashing, which made reruns differ.
    names = list(dict.fromkeys(names))
    if len(names) < 2:
        continue  # need at least two distinct names to form a pair

    for anchor in names:
        positives = [n for n in names if n != anchor]
        sampled = pair_rng.sample(positives, min(MAX_POS_PER_ANCHOR, len(positives)))
        train_examples.extend(InputExample(texts=[anchor, pos]) for pos in sampled)

print("Training pairs:", len(train_examples))

if len(train_examples) > MAX_TRAIN_PAIRS:
    # random.sample already returns the kept pairs in random order, so a full
    # shuffle of the (much larger) list beforehand is wasted work.
    train_examples = pair_rng.sample(train_examples, MAX_TRAIN_PAIRS)
else:
    pair_rng.shuffle(train_examples)  # break up the compound-by-compound ordering

print("Training pairs kept:", len(train_examples))

Training pairs: 1319118


In [None]:
# In-batch-negatives ranking loss with gradient caching: the logical batch is
# BATCH_SIZE, while forward/backward passes run in chunks of 64 examples.
train_loss = losses.CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=64)

# drop_last keeps every step at the full logical batch size.
train_dataloader = DataLoader(
    dataset=train_examples,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=6,
    pin_memory=True,
)

In [None]:
# Directory for per-epoch checkpoints, so a crash loses at most one epoch.
# Must be non-empty: os.makedirs("") raises, and a "" prefix would have put
# "/epoch_N" at the filesystem root.
SAVE_DIR = "checkpoints"  # NOTE: path to save the model after each epoch in case it crashes
os.makedirs(SAVE_DIR, exist_ok=True)

TOTAL_EPOCHS = 20
START_EPOCH = 0  # set to the number of finished epochs when resuming from a checkpoint

# Each model.fit() call below runs exactly ONE epoch with a fresh optimizer and
# LR scheduler, so warmup must be sized against one call's steps. Sizing it
# against all TOTAL_EPOCHS makes warmup twice as long as the first fit, so
# epoch 1 would only ramp up to about half the target learning rate.
WARMUP_STEPS = int(0.1 * len(train_dataloader))

for epoch in range(START_EPOCH, TOTAL_EPOCHS):
    print(f"\n Epoch {epoch+1}/{TOTAL_EPOCHS}")

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=1,
        warmup_steps=WARMUP_STEPS if epoch == START_EPOCH else 0,
        show_progress_bar=True,
        use_amp=False,
    )

    epoch_ckpt = os.path.join(SAVE_DIR, f"epoch_{epoch+1}")
    model.save(epoch_ckpt)
    print(f"Epoch {epoch+1} checkpoint saved → {epoch_ckpt}")

In [None]:
# Corpus for hard-negative mining. Build `names` explicitly: the pair-building
# loop also uses `names` as its loop variable, so relying on it here would
# encode only the LAST compound's synonyms instead of the whole dataset.
name_cid = df[["ascii_name_clean", "compound_id"]].drop_duplicates()
# A name listed under several compounds has no single correct label. Keeping it
# would yield contradictory positives/negatives when mining, so leave it out.
name_cid = name_cid[~name_cid["ascii_name_clean"].duplicated(keep=False)]
names = name_cid["ascii_name_clean"].tolist()
cid_lookup = dict(zip(names, name_cid["compound_id"]))

embeds = model.encode(
    names, batch_size=256,
    normalize_embeddings=True,
    show_progress_bar=True,
).astype("float32")

torch.cuda.empty_cache()

In [None]:
def mine_hard_negatives_fast(
    embeds, names,
    cid_lookup, k=10,
    max_anchors=10_000,
    seed=None,
):
    """Mine (anchor, positive, hard-negative) triples from nearest neighbours.

    Builds an approximate HNSW inner-product index over ``embeds`` (callers
    pass L2-normalised rows, so inner product equals cosine similarity) and
    samples up to ``max_anchors`` anchors. For each anchor it walks the ``k``
    nearest neighbours in rank order. The first neighbour with the same
    compound_id becomes the positive; the first with a different compound_id
    becomes the hard negative.

    Args:
        embeds: float32 array with one row per entry of ``names``.
        names: list of name strings aligned with the rows of ``embeds``.
        cid_lookup: mapping name -> compound_id.
        k: neighbours retrieved per anchor (the anchor itself is usually one).
        max_anchors: number of anchors to sample; clamped to ``len(names)``.
        seed: optional seed for anchor sampling. ``None`` uses the global
            ``np.random`` state, as before.

    Returns:
        List of ``InputExample(texts=[anchor, positive, negative])``. Anchors
        lacking a same-compound or a different-compound neighbour are skipped.
    """
    n_names = len(names)
    if n_names == 0:
        return []

    d = embeds.shape[1]
    index = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_INNER_PRODUCT)
    # efSearch is the HNSW search beam width; keep it at least k.
    index.hnsw.efSearch = max(64, k)
    index.add(embeds)

    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sampling without replacement raises if asked for more anchors than exist.
    anchor_ids = rng.choice(n_names, min(max_anchors, n_names), replace=False)
    _, I = index.search(embeds[anchor_ids], k)

    hard_triples = []
    for idx, nbrs in zip(anchor_ids, I):
        anchor = names[idx]
        anchor_cid = cid_lookup[anchor]
        pos = neg = None

        for j in nbrs:
            # HNSW is approximate, so the anchor is not guaranteed to be the
            # first hit. faiss also pads missing results with -1. Filter both
            # explicitly instead of blindly skipping nbrs[0].
            if j < 0 or j == idx:
                continue
            candidate = names[j]
            if cid_lookup[candidate] == anchor_cid:
                if pos is None:
                    pos = candidate
            elif neg is None:
                neg = candidate
            # Compare with None: an empty-string name is falsy but still valid.
            if pos is not None and neg is not None:
                hard_triples.append(InputExample(texts=[anchor, pos, neg]))
                break

    return hard_triples

# One-off mining pass on the current embeddings to check how many triples we get.
hard_triples = mine_hard_negatives_fast(embeds, names, cid_lookup=cid_lookup)
print(f"Hard triples: {len(hard_triples)}")

In [None]:
# Refinement stage: re-embed the corpus with the current model, mine fresh hard
# negatives, and fine-tune on them with a cosine-distance triplet loss.
N_MINING_ROUNDS = 3
EPOCHS_PER_ROUND = 2
HARD_BATCH_SIZE = 128

triplet_loss = losses.TripletLoss(
    model=model,
    distance_metric=losses.TripletDistanceMetric.COSINE,
    triplet_margin=0.4,
)

for mining_round in range(N_MINING_ROUNDS):  # avoid shadowing builtin `round`
    print(f"\n Hard-negative mining round {mining_round+1}")

    embeds = model.encode(
        names, batch_size=512,
        normalize_embeddings=True,
        show_progress_bar=True,
    ).astype("float32")

    hard_triples = mine_hard_negatives_fast(
        embeds, names, cid_lookup,
        k=15, max_anchors=20_000,
    )
    if not hard_triples:
        print("No hard triples mined; skipping this round")
        continue

    hard_loader = DataLoader(
        hard_triples, batch_size=HARD_BATCH_SIZE,
        shuffle=True, num_workers=8,
        pin_memory=True, persistent_workers=True,
    )

    # fit() defaults to 10_000 warmup steps, but a round is only a few hundred
    # steps (at most max_anchors / batch size per epoch). With the default, the
    # learning rate would stay at a few percent of its target for the whole
    # round. Warm up over 10% of this round's steps instead.
    round_steps = len(hard_loader) * EPOCHS_PER_ROUND
    model.fit(
        train_objectives=[(hard_loader, triplet_loss)],
        epochs=EPOCHS_PER_ROUND,
        warmup_steps=int(0.1 * round_steps),
        show_progress_bar=True,
        use_amp=False,
    )

    # Checkpoint each round so a crash does not discard earlier rounds.
    model.save(os.path.join(SAVE_DIR, f"hard_round_{mining_round+1}"))
    torch.cuda.empty_cache()

# os.path.join avoids producing a root-level path when SAVE_DIR is "".
model.save(os.path.join(SAVE_DIR, "final_model_retrain2"))
print("Final model saved")