In [1]:
import itertools

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# The terms whose pairwise embedding similarities are compared further down.
terms = [
    "hyperkalemia",
    "hypermetropia",
    "eye disease",
]

def calculate_embeddings(model_name, tokenizer_name, terms, *, pooling="cls"):
    """Embed each term with a Hugging Face encoder; return one vector per term.

    Args:
        model_name: Hub ID or local path passed to ``AutoModel.from_pretrained``.
        tokenizer_name: Hub ID or local path passed to
            ``AutoTokenizer.from_pretrained``.
        terms: Iterable of strings; each one is tokenized and encoded on its own.
        pooling: How a term's token states are reduced to one vector.
            ``"cls"`` (default, the original behaviour) takes the first
            token's final hidden state. ``"mean"`` averages the final hidden
            states of all tokens covered by the attention mask.

    Returns:
        A list with one 1-D tensor per term, in input order.

    Raises:
        ValueError: If ``pooling`` is neither ``"cls"`` nor ``"mean"``.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout is disabled

    # Some tokenizers ship without a usable model_max_length. With
    # truncation=True alone, transformers then warns "Default to no
    # truncation" and lets over-long inputs through. Cap the length at the
    # smaller of the tokenizer limit and the model's position-embedding
    # limit so long inputs are truncated instead of failing in the model.
    max_length = min(
        tokenizer.model_max_length,
        getattr(model.config, "max_position_embeddings", tokenizer.model_max_length),
    )

    embeddings = []
    for term in terms:
        inputs = tokenizer(
            term,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        with torch.no_grad():
            # One string per call, so the leading (batch) dimension has size 1.
            hidden = model(**inputs).last_hidden_state
        if pooling == "cls":
            vector = hidden[:, 0, :]
        else:
            # Zero out padding positions, then divide by the real token count.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            vector = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        embeddings.append(vector.squeeze(0))
    return embeddings

def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two embedding vectors as a Python float.

    Each input gets a leading axis of size 1. ``F.cosine_similarity``
    reduces over ``dim=1``, so the two vectors are compared as a single
    pair of rows, and the one-element result is unwrapped with ``.item()``.
    """
    left, right = vec1[None], vec2[None]
    return F.cosine_similarity(left, right, dim=1).item()

# Hugging Face model IDs. For the two BERT encoders the same ID names both
# the tokenizer and the model, so define each ID once.
PUBMEDBERT_ID = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
BIOBERT_ID = "dmis-lab/biobert-base-cased-v1.1"
SBERT_ID = "sentence-transformers/paraphrase-MiniLM-L6-v2"

# Generate embeddings using PubMedBERT
print("Generating embeddings using PubMedBERT...")
pubmedbert_embeddings = calculate_embeddings(PUBMEDBERT_ID, PUBMEDBERT_ID, terms)

# Generate embeddings using BioBERT
print("Generating embeddings using BioBERT...")
biobert_embeddings = calculate_embeddings(BIOBERT_ID, BIOBERT_ID, terms)

# Generate embeddings using SBERT: one batched encode call. With
# convert_to_tensor=True the result is a 2-D tensor; iterating it yields
# one 1-D row per term, in the same order as `terms`.
print("Generating embeddings using SBERT...")
sbert_model = SentenceTransformer(SBERT_ID)
sbert_embeddings = list(sbert_model.encode(terms, convert_to_tensor=True))

# Label -> per-term embeddings; dict order fixes the printed row order.
embeddings_by_model = {
    "PubMedBERT": pubmedbert_embeddings,
    "BioBERT": biobert_embeddings,
    "SBERT": sbert_embeddings,
}

# Cosine similarity for every unordered pair of distinct terms.
print("\nCosine Similarity:")
for i, j in itertools.combinations(range(len(terms)), 2):
    print(f"Terms: {terms[i]} vs {terms[j]}")
    for label, vectors in embeddings_by_model.items():
        similarity = cosine_similarity(vectors[i], vectors[j])
        print(f"  {label}: {similarity:.4f}")


Generating embeddings using PubMedBERT...


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Generating embeddings using BioBERT...


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Generating embeddings using SBERT...


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]


Cosine Similarity:
Terms: hyperkalemia vs hypermetropia
  PubMedBERT: 0.9251
  BioBERT: 0.9168
  SBERT: 0.4888
Terms: hyperkalemia vs eye disease
  PubMedBERT: 0.8716
  BioBERT: 0.8752
  SBERT: 0.1756
Terms: hypermetropia vs eye disease
  PubMedBERT: 0.9037
  BioBERT: 0.9315
  SBERT: 0.2563
