# MedVec-Scratch

![MedVec-Architecture](https://raw.githubusercontent.com/Vivek02Sharma/MedVec-Scratch/main/assets/MedVec-scratch.png)


Dataset URL: https://huggingface.co/datasets/abhinand/MedEmbed-training-triplets-v1

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

import math
import pandas as pd

## Positional Encoding

In [12]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position information to a batch of embeddings.

  A table of shape [1, max_len, d_model] is precomputed and stored as the
  non-trainable buffer `pe`; even feature columns hold sines and odd feature
  columns hold cosines of the scaled position.

  Args:
    d_model: size of the last (feature) dimension of the inputs.
    max_len: number of positions to precompute; longer inputs are rejected.
  """

  def __init__(self, d_model, max_len = 5000):
    super().__init__()
    pe = torch.zeros(max_len, d_model, dtype = torch.float32)
    pos = torch.arange(0, max_len, dtype = torch.float32).unsqueeze(1)
    denominators = torch.exp(torch.arange(0, d_model, 2, dtype = torch.float32) * (-math.log(10000.0)/d_model))
    angles = pos * denominators
    pe[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns,
    # so only the first d_model // 2 frequencies are used for the cosines.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return `x` plus the encodings of its first `x.size(1)` positions.

    `x` is indexed as [batch, seq_len, d_model] (dimension 1 is the sequence).

    Raises:
      ValueError: if the sequence is longer than the precomputed table.
    """
    seq_len = x.size(1)
    max_len = self.pe.size(1)
    if seq_len > max_len:
      # Fail with a clear message instead of a broadcasting error (or a
      # silent broadcast when max_len == 1).
      raise ValueError(
          f"sequence length {seq_len} exceeds the maximum supported length {max_len}"
      )
    x = x + self.pe[:, :seq_len]
    return x

## Transformer Encoder block

In [13]:
class TransformerEncoderBlock(nn.Module):
  """One encoder layer: self-attention, then a two-layer ReLU feed-forward.

  Each sub-layer output goes through dropout, is added to its input
  (residual connection) and is then layer-normalized.

  Args:
    d_model: feature size of the inputs and outputs.
    num_heads: number of attention heads.
    d_ff: hidden size of the feed-forward sub-layer.
    dropout: dropout probability, shared by attention and feed-forward.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
    super().__init__()
    # batch_first = True: inputs are indexed as [batch, seq, feature].
    self.self_attention = nn.MultiheadAttention(
        d_model,
        num_heads,
        dropout = dropout,
        batch_first = True,
    )
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, src_mask = None, src_key_padding_mask = None):
    """Run the layer on `x`; both masks are forwarded to the attention module."""
    # Attention sub-layer; the attention weights (second result) are not used.
    attended = self.self_attention(
        x, x, x,
        attn_mask = src_mask,
        key_padding_mask = src_key_padding_mask,
    )[0]
    normed = self.norm1(self.dropout(attended) + x)

    # Feed-forward sub-layer, with dropout after the activation and after
    # the output projection.
    hidden = self.dropout(torch.relu(self.linear1(normed)))
    projected = self.dropout(self.linear2(hidden))
    return self.norm2(normed + projected)


## Transformer Encoder embedder

In [14]:
class SentenceEmbeddingModel(nn.Module):
  """Transformer encoder that maps token ids to one vector per sequence.

  Pipeline: token embedding (scaled by sqrt(d_model)) -> positional encoding
  -> `num_layers` encoder blocks -> mask-aware mean pooling over the sequence.

  Args:
    vocab_size: number of rows in the token embedding table.
    d_model: embedding / hidden size.
    nhead: attention heads per encoder block.
    num_layers: number of stacked encoder blocks.
    dim_ff: feed-forward hidden size inside each block.
    max_seq_len: number of positions the positional encoding precomputes.
    dropout: dropout probability passed to every encoder block.
  """

  def __init__(self,
               vocab_size,
               d_model = 512,
               nhead = 8,
               num_layers = 6,
               dim_ff = 2048,
               max_seq_len = 512,
               dropout = 0.1
               ):
    super().__init__()
    self.d_model = d_model
    # Token id 0 is the padding index of the embedding table.
    self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx = 0)
    self.positional_encoding = PositionalEncoding(d_model, max_len = max_seq_len)

    blocks = []
    for _ in range(num_layers):
      blocks.append(TransformerEncoderBlock(d_model, nhead, dim_ff, dropout))
    self.encoder_layers = nn.ModuleList(blocks)

  def mean_pooling(self, token_embeddings, attention_mask):
    """Average `token_embeddings` over dimension 1, counting only mask == 1 positions.

    The per-feature count is clamped to at least 1e-9, so a row whose mask
    is entirely zero divides by that floor instead of by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand_as(token_embeddings).to(torch.float32)
    masked_total = (token_embeddings * weights).sum(dim = 1)
    token_counts = weights.sum(dim = 1).clamp(min = 1e-9)
    return masked_total / token_counts

  def forward(self, src, attention_mask):
    """Encode token ids `src` and return the pooled sentence vectors.

    Positions where `attention_mask` is 0 are passed to every encoder block
    as the key padding mask and are excluded from the mean pooling.
    """
    hidden = self.token_embedding(src) * math.sqrt(self.d_model)
    hidden = self.positional_encoding(hidden)

    padding_positions = attention_mask == 0
    for encoder in self.encoder_layers:
      hidden = encoder(hidden, src_key_padding_mask = padding_positions)

    return self.mean_pooling(hidden, attention_mask)

## Tokenization

In [15]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

In [16]:
def get_training_corpus(dataset, batch_size = 1000):
    """Yield the dataset's text in batches, for tokenizer training.

    Args:
        dataset: object supporting `len()` and slice indexing, where a slice
            returns a mapping with list-valued 'query', 'pos' and 'neg'
            entries (as used below).
        batch_size: number of rows taken per slice; defaults to 1000.

    Yields:
        One list per slice: the slice's queries, followed by its positives,
        followed by its negatives.

    Raises:
        ValueError: if `batch_size` is smaller than 1 (raised when iteration
            starts, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    for start in range(0, len(dataset), batch_size):
        batch = dataset[start : start + batch_size]
        yield batch['query'] + batch['pos'] + batch['neg'] # combine the data

In [17]:
def train_custom_bpe_tokenizer(dataset, vocab_size = 30000):
    """Train a BPE tokenizer on the text yielded by `get_training_corpus`.

    Args:
        dataset: dataset passed straight to `get_training_corpus`.
        vocab_size: vocabulary size requested from the BPE trainer.

    Returns:
        The trained `Tokenizer`.
    """
    # "[PAD]" is listed first; TripletDataset pads with pad_id = 0, so this
    # presumably relies on special tokens receiving ids in list order
    # (confirm against the tokenizers documentation).
    special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
    bpe_trainer = BpeTrainer(vocab_size = vocab_size, special_tokens = special_tokens)

    # Tokens missing from the vocabulary are mapped to "[UNK]".
    bpe_tokenizer = Tokenizer(BPE(unk_token = "[UNK]"))
    # Text is pre-split with the library's Whitespace pre-tokenizer before
    # BPE merges are learned.
    bpe_tokenizer.pre_tokenizer = Whitespace()

    print("training BPE tokenizer...")
    corpus = get_training_corpus(dataset)
    bpe_tokenizer.train_from_iterator(corpus, bpe_trainer)

    return bpe_tokenizer

## Data integration

In [18]:
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

In [19]:
class TripletDataset(Dataset):
    """Map-style dataset that tokenizes (query, pos, neg) text triplets.

    Args:
        hf_dataset: indexable collection of rows with 'query', 'pos' and
            'neg' text fields.
        tokenizer: tokenizer exposing `enable_padding`, `enable_truncation`
            and `encode`.
        max_len: fixed length used for both padding and truncation.
    """

    def __init__(self, hf_dataset, tokenizer, max_len = 128):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

        # NOTE: this configures the tokenizer object that was passed in, so
        # the caller's tokenizer pads/truncates to `max_len` from here on.
        tokenizer.enable_padding(pad_id = 0, pad_token = "[PAD]", length = max_len)
        tokenizer.enable_truncation(max_length = max_len)

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def encode_text(self, text):
        """Tokenize `text`; return `(ids, attention_mask)` as int64 tensors."""
        encoding = self.tokenizer.encode(text)
        token_ids = torch.tensor(encoding.ids, dtype = torch.long)
        attention = torch.tensor(encoding.attention_mask, dtype = torch.long)
        return token_ids, attention

    def __getitem__(self, idx):
        """Return the tokenized triplet at `idx` as a dict of six tensors.

        Keys are '<p>_ids' and '<p>_mask' for p in q (query), p (positive)
        and n (negative).
        """
        row = self.dataset[idx]

        sample = {}
        for prefix, column in (('q', 'query'), ('p', 'pos'), ('n', 'neg')):
            ids, mask = self.encode_text(row[column])
            sample[f'{prefix}_ids'] = ids
            sample[f'{prefix}_mask'] = mask
        return sample

## Training

In [20]:
if __name__ == "__main__":
    # Training entry point: mounts Google Drive, loads the triplet dataset,
    # loads or trains the tokenizer, then trains the embedding model with a
    # triplet margin loss and writes checkpoints to Drive.

    # Tokenizer / data settings
    VOCAB_SIZE = 30000
    MAX_SEQ_LEN = 128
    BATCH_SIZE = 32
    # Model architecture (must be repeated when the model is rebuilt to load
    # a checkpoint)
    D_MODEL = 256
    NHEAD = 4
    NUM_LAYERS = 4
    DIM_FF = 1024
    DROPOUT = 0.1
    # Use the GPU when one is available, otherwise the CPU
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    NUM_EPOCHS = 10
    # A checkpoint is written whenever (epoch + 1) is a multiple of this value
    SAVE_EVERY_EPOCH = 1

    # Tokenizer and model checkpoints are stored on Google Drive
    # (`google.colab` is only importable inside a Colab runtime)
    import os
    from google.colab import drive

    print("Mounting google drive...")

    drive.mount('/content/drive')
    CHECKPOINT_DIR = "/content/drive/MyDrive/MedEmbed_Checkpoints"

    # Create the checkpoint directory on first use
    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)
        print(f"Created directory in Drive: {CHECKPOINT_DIR}")
    else:
        print(f"Using existing directory in Drive: {CHECKPOINT_DIR}")

    # Load the "train" split of the triplet dataset from the Hugging Face Hub
    print("Loading medical dataset from huggingface...")
    full_dataset = load_dataset("abhinand/MedEmbed-training-triplets-v1", split = "train")

    # Report the dataset size and column names
    print("Dataset loaded. Rows: ", len(full_dataset))
    print("Columns: ", full_dataset.column_names)

    tokenizer_path = os.path.join(CHECKPOINT_DIR, "tokenizer.json")

    # Reuse a previously saved tokenizer if there is one; otherwise train a
    # new tokenizer on the full dataset and save it for later runs
    if os.path.exists(tokenizer_path):
        print(f"Loading tokenizer from Drive: {tokenizer_path}")
        tokenizer = Tokenizer.from_file(tokenizer_path)
    else:
        print("Training tokenizer...")
        tokenizer = train_custom_bpe_tokenizer(full_dataset, vocab_size = VOCAB_SIZE)
        tokenizer.save(tokenizer_path)
        print(f"Tokenizer saved to Drive: {tokenizer_path}")

    # Build the dataset, dataloader and model.
    # NOTE: TripletDataset enables padding/truncation on `tokenizer` here,
    # i.e. after the tokenizer.save() call above, so the tokenizer.json
    # written by this run does not contain those settings.
    train_dataset = TripletDataset(full_dataset, tokenizer, max_len = MAX_SEQ_LEN)
    train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 2)

    # The vocabulary size is taken from the tokenizer itself rather than
    # from VOCAB_SIZE
    model = SentenceEmbeddingModel(
        vocab_size = tokenizer.get_vocab_size(),
        d_model = D_MODEL,
        max_seq_len = MAX_SEQ_LEN,
        nhead = NHEAD,
        num_layers = NUM_LAYERS,
        dim_ff = DIM_FF,
        dropout = DROPOUT
    ).to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr = 1e-4)
    # Triplet loss over (anchor, positive, negative) embeddings;
    # p = 2 is the norm degree of the pairwise distance
    criterion = nn.TripletMarginLoss(margin = 1.0, p = 2)

    # Training loop; the model is put in training mode once, before the
    # first epoch
    print("\nStarting Training...")
    model.train()

    for epoch in range(NUM_EPOCHS):
        # Running totals used for the per-epoch average loss
        total_loss = 0
        batch_count = 0

        for batch in train_dataloader:
            optimizer.zero_grad()

            # Move every tensor of the batch to the training device
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            q_ids, p_ids, n_ids = batch['q_ids'], batch['p_ids'], batch['n_ids']
            q_mask, p_mask, n_mask = batch['q_mask'], batch['p_mask'], batch['n_mask']

            # The same model (shared weights) embeds query, positive and
            # negative texts
            q_emb = model(q_ids, q_mask)
            p_emb = model(p_ids, p_mask)
            n_emb = model(n_ids, n_mask)

            loss = criterion(q_emb, p_emb, n_emb)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

            # Every 100 batches, log the loss of the current batch (not a
            # running average)
            if batch_count % 100 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_count}/{len(train_dataloader)}, Loss: {loss.item():.4f}")

        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Save the model and optimizer state together with the 1-based epoch
        # number and that epoch's average loss
        if (epoch + 1) % SAVE_EVERY_EPOCH == 0:
            save_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, save_path)

            print(f"Saved checkpoint: {save_path}")

Mounting google drive...
Mounted at /content/drive
Created directory in Drive: /content/drive/MyDrive/MedEmbed_Checkpoints
Loading medical dataset from huggingface...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/59.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/232684 [00:00<?, ? examples/s]

Dataset loaded. Rows:  232684
Columns:  ['query', 'pos', 'neg', 'query_id', 'pos_id', 'neg_id']
Training tokenizer...
training BPE tokenizer...
Tokenizer saved to Drive: /content/drive/MyDrive/MedEmbed_Checkpoints/tokenizer.json

Starting Training...
Epoch 1, Batch 100/7272, Loss: 0.7643
Epoch 1, Batch 200/7272, Loss: 0.6434
Epoch 1, Batch 300/7272, Loss: 0.4567
Epoch 1, Batch 400/7272, Loss: 0.5051
Epoch 1, Batch 500/7272, Loss: 0.7626
Epoch 1, Batch 600/7272, Loss: 0.6204
Epoch 1, Batch 700/7272, Loss: 0.5324
Epoch 1, Batch 800/7272, Loss: 0.5192
Epoch 1, Batch 900/7272, Loss: 0.5735
Epoch 1, Batch 1000/7272, Loss: 0.5651
Epoch 1, Batch 1100/7272, Loss: 0.5170
Epoch 1, Batch 1200/7272, Loss: 0.5679
Epoch 1, Batch 1300/7272, Loss: 0.4556
Epoch 1, Batch 1400/7272, Loss: 0.6683
Epoch 1, Batch 1500/7272, Loss: 0.3793
Epoch 1, Batch 1600/7272, Loss: 0.3390
Epoch 1, Batch 1700/7272, Loss: 0.3630
Epoch 1, Batch 1800/7272, Loss: 0.3780
Epoch 1, Batch 1900/7272, Loss: 0.4421
Epoch 1, Batch 20

## Inference

In [23]:
import torch.nn.functional as F

# Load the tokenizer and the trained weights saved by the training cell.
checkpoint_dir = "/content/drive/MyDrive/MedEmbed_Checkpoints"
tokenizer = Tokenizer.from_file(f"{checkpoint_dir}/tokenizer.json")

# The architecture arguments must match the values used for training,
# otherwise load_state_dict() below fails on mismatched parameter shapes.
model = SentenceEmbeddingModel(
    vocab_size=tokenizer.get_vocab_size(),
    d_model=256,
    nhead=4,
    num_layers=4,
    dim_ff=1024,
    max_seq_len=128
)

# The checkpoint may have been written from a model living on a CUDA device;
# map_location="cpu" lets it load on a CPU-only runtime as well. The model
# itself is kept on the CPU here.
checkpoint = torch.load(f"{checkpoint_dir}/model_epoch_10.pt", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode (disables dropout).
model.eval()

# generate embedding
def get_embedding(text):
    """Return the L2-normalized sentence embedding of `text`, shape [1, Dim].

    Uses the module-level `tokenizer` and `model`. The token sequence is cut
    to the number of positions the model's positional encoding holds, because
    a tokenizer loaded from file is not guaranteed to have truncation enabled
    and a longer sequence cannot be encoded by the model.
    """
    enc = tokenizer.encode(text)

    # `pe` is the positional table of shape [1, max_len, d_model]
    max_len = model.positional_encoding.pe.size(1)

    # Explicit int64 dtype, matching TripletDataset.encode_text (and keeping
    # an empty id list from becoming a float tensor)
    ids = torch.tensor(enc.ids[:max_len], dtype = torch.long).unsqueeze(0)
    mask = torch.tensor(enc.attention_mask[:max_len], dtype = torch.long).unsqueeze(0)

    with torch.no_grad():
        embedding = model(ids, mask)
        return F.normalize(embedding, p = 2, dim = 1)


# Small demo corpus and the queries to run against it
docs = [
    "The patient presented to the emergency room with severe abdominal pain in the left lower quadrant, which was diagnosed as a cystic mass on the left anterior wall of the uterus. Presenting Symptoms: Severe abdominal pain in the left lower quadrant",
    "The patient presented to the emergency department with a high-grade fever, myalgia, and headache.",
    "He received induction chemotherapy with daunorubicin and cytarabine (3+7), and during his hospital course, he developed febrile neutropenia but eventually resolved on Day 21.",
]

queries = [
    "cystic mass on uterus symptoms",
    "dengue fever symptoms",
    "AML M-4 treatment options",
]

print("\nGenerating document embeddings...")
# Embed each document once and stack the [1, Dim] rows into [N_docs, Dim]
doc_vectors = [get_embedding(doc) for doc in docs]
docs_embeddings = torch.cat(doc_vectors, dim = 0)

separator = "-" * 50
print(separator)
print("Starting Semantic Search...")
print(separator)

for query_text in queries:
    print(f"\nQuery: '{query_text}'")

    # Query vector, shape [1, Dim]
    query_vec = get_embedding(query_text)

    # Both sides are L2-normalized, so the dot products are cosine
    # similarities: [1, Dim] x [Dim, N_docs] -> [1, N_docs]
    scores = torch.mm(query_vec, docs_embeddings.t())

    # k = len(docs): every document is returned, highest score first
    top_k_scores, top_k_indices = scores[0].topk(k = len(docs))

    # Print rank, score and document id, then the full document text
    # followed by "..."
    for rank, (score, idx) in enumerate(zip(top_k_scores, top_k_indices), start = 1):
        doc_idx = idx.item()
        print(f"   Rank {rank}: Score {score.item():.4f} | Doc ID {doc_idx}")
        print(f"   -> {docs[doc_idx]}...")


Generating document embeddings...
--------------------------------------------------
Starting Semantic Search...
--------------------------------------------------

Query: 'cystic mass on uterus symptoms'
   Rank 1: Score 0.4981 | Doc ID 0
   -> The patient presented to the emergency room with severe abdominal pain in the left lower quadrant, which was diagnosed as a cystic mass on the left anterior wall of the uterus. Presenting Symptoms: Severe abdominal pain in the left lower quadrant...
   Rank 2: Score 0.1999 | Doc ID 1
   -> The patient presented to the emergency department with a high-grade fever, myalgia, and headache....
   Rank 3: Score 0.0971 | Doc ID 2
   -> He received induction chemotherapy with daunorubicin and cytarabine (3+7), and during his hospital course, he developed febrile neutropenia but eventually resolved on Day 21....

Query: 'dengue fever symptoms'
   Rank 1: Score 0.5865 | Doc ID 1
   -> The patient presented to the emergency department with a high-grade f