### INSTALL DEPENDENCIES

In [14]:
%%capture
!pip install --upgrade sentence-transformers datasets accelerate torch transformers

### IMPORT LIBRARIES

In [13]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import torch
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.util import cos_sim
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MatryoshkaLoss

### CONFIGURATIONS

In [3]:
DATA_PATH = "./final_mixed_train.jsonl"
MODEL_ID = "ibm-granite/granite-embedding-english-r2"
OUTPUT_DIR = "./granite-embedding"

MATRYOSHKA_DIMS = [768, 512, 256, 128, 64]

# List of metrics to display
metrics = [
    'ndcg@10',
    'mrr@10',
    'map@100',
    'accuracy@1',
    'accuracy@3',
    'accuracy@5',
    'accuracy@10',
    'precision@1',
    'precision@3',
    'precision@5',
    'precision@10',
    'recall@1',
    'recall@3',
    'recall@5',
    'recall@10'
]

### DATA LOADING & PREPARATION

In [15]:
data_rows = []

try:
  with open(DATA_PATH, 'r', encoding='utf-8') as f:
    for idx, line in enumerate(f):
      try:
        row = json.loads(line)
        question = row.get('instruction', '').strip()
        answer = row.get('output', '').strip()
        context = row.get('input', '').strip()

        # Combine context
        positive_text = f"Context: {context}\nAnswer: {answer}"

        # Convert ID to string (Must follow this order)
        data_rows.append({
          "anchor": question,
          "positive": positive_text,
          "id": str(idx),
        })

      except json.JSONDecodeError:
        continue
except FileNotFoundError:
    print(f"ERROR: Could not find {DATA_PATH}. Please upload it to Colab.")

# Convert to HF dataset
full_dataset = Dataset.from_list(data_rows)

# Split dataset
dataset_dict = full_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_dict["train"]
test_dataset = dataset_dict["test"]

### MODEL & LOSS SETUP

In [5]:
model = SentenceTransformer(MODEL_ID, device="cuda" if torch.cuda.is_available() else "cpu")
model.max_seq_length = 512

# Inner Loss: MNRL
inner_loss = losses.MultipleNegativesRankingLoss(model)

# Outer Loss: Matryoshka
train_loss = MatryoshkaLoss(
    model=model,
    loss=inner_loss,
    matryoshka_dims=MATRYOSHKA_DIMS
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/230 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/55.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/298M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/191 [00:00<?, ?B/s]

### EVALUATION SETUP

In [6]:
# Extract lists for evaluation from the Test set
corpus = {row["id"]: row["positive"] for row in test_dataset}
queries = {row["id"]: row["anchor"] for row in test_dataset}
relevant_docs = {row["id"]: {row["id"]} for row in test_dataset}

# Create a list of evaluators
evaluators_list = []
for dim in MATRYOSHKA_DIMS:
   # Define the evaluator
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to the respective dimension
        score_functions={"cosine": cos_sim},
    )
    # Add to list
    evaluators_list.append(ir_evaluator)

# Combine into one Sequential Evaluator
seq_evaluator = SequentialEvaluator(evaluators_list)

### EVALUATE BEFORE TRAINING

In [7]:
# Evaluate the model
base_results = seq_evaluator(model)

# Print header
print("\nBase Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)

# Print each metric
for metric in metrics:
    values = []
    for dim in MATRYOSHKA_DIMS:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(base_results[key])

    # Highlight NDCG@10
    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()

# Print sequential score
print("-" * 85)
print(f"{'seq_score:'} {base_results['sequential_score']:1f}")

W0102 15:50:59.445000 592 torch/_inductor/utils.py:1613] [1/0_1] Not enough SMs to use max_autotune_gemm mode



Base Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.6213       0.6214       0.6163       0.6036       0.5899 
mrr@10                 0.6098       0.6087       0.6045       0.5941       0.5807 
map@100                0.6151       0.6139       0.6095       0.5987       0.5856 
accuracy@1             0.5880       0.5860       0.5820       0.5760       0.5580 
accuracy@3             0.6280       0.6220       0.6200       0.6100       0.6040 
accuracy@5             0.6380       0.6420       0.6300       0.6160       0.6100 
accuracy@10            0.6580       0.6620       0.6540       0.6340       0.6180 
precision@1            0.5880       0.5860       0.5820       0.5760       0.5580 
precision@3            0.2093       0.2073       0.2

### TRAINING

In [8]:
# Training Arguments
args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=32,
    gradient_checkpointing=True,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    fp16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_dim_128_cosine_ndcg@10",       # Optimizing for the best ndcg@10 score for the 128 dimension
    report_to="none"
)

In [9]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
    evaluator=seq_evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [10]:
torch.cuda.empty_cache()
trainer.train()

trainer.save_model()
model.save_pretrained(OUTPUT_DIR)

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
1,4.6642,No log,0.6,0.654,0.672,0.694,0.6,0.218,0.1344,0.0694,0.6,0.654,0.672,0.694,0.646504,0.631457,0.639363,0.602,0.652,0.668,0.692,0.602,0.217333,0.1336,0.0692,0.602,0.652,0.668,0.692,0.645902,0.631288,0.639443,0.602,0.652,0.674,0.702,0.602,0.217333,0.1348,0.0702,0.602,0.652,0.674,0.702,0.648908,0.63231,0.639138,0.606,0.66,0.676,0.702,0.606,0.22,0.1352,0.0702,0.606,0.66,0.676,0.702,0.653431,0.63801,0.644946,0.606,0.664,0.676,0.694,0.606,0.221333,0.1352,0.0694,0.606,0.664,0.676,0.694,0.651311,0.637536,0.644821,0.651311
2,3.2254,No log,0.606,0.67,0.69,0.714,0.606,0.223333,0.138,0.0714,0.606,0.67,0.69,0.714,0.66026,0.643205,0.650942,0.602,0.668,0.68,0.714,0.602,0.222667,0.136,0.0714,0.602,0.668,0.68,0.714,0.657544,0.639673,0.647451,0.61,0.666,0.68,0.72,0.61,0.222,0.136,0.072,0.61,0.666,0.68,0.72,0.662287,0.644305,0.651842,0.614,0.662,0.688,0.71,0.614,0.220667,0.1376,0.071,0.614,0.662,0.688,0.71,0.660386,0.644677,0.652137,0.604,0.664,0.674,0.722,0.604,0.221333,0.1348,0.0722,0.604,0.664,0.674,0.722,0.659451,0.640146,0.646575,0.659451
3,2.395,No log,0.604,0.666,0.688,0.716,0.604,0.222,0.1376,0.0716,0.604,0.666,0.688,0.716,0.659004,0.640958,0.64839,0.604,0.666,0.694,0.72,0.604,0.222,0.1388,0.072,0.604,0.666,0.694,0.72,0.660801,0.642055,0.649497,0.61,0.672,0.69,0.72,0.61,0.224,0.138,0.072,0.61,0.672,0.69,0.72,0.663319,0.645433,0.653093,0.612,0.67,0.69,0.712,0.612,0.223333,0.138,0.0712,0.612,0.67,0.69,0.712,0.661194,0.644956,0.652008,0.602,0.656,0.68,0.722,0.602,0.218667,0.136,0.0722,0.602,0.656,0.68,0.722,0.658404,0.6386,0.644846,0.658404


### EVALUATE AFTER TRAINING

In [11]:
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)

# Evaluate the model
ft_results = seq_evaluator(fine_tuned_model)

# Print header
print("Fine Tuned Model Evaluation Results")
print("-" * 85)
print(f"{'Metric':15} {'768d':>12} {'512d':>12} {'256d':>12} {'128d':>12} {'64d':>12}")
print("-" * 85)

# Print each metric
for metric in metrics:
    values = []
    for dim in MATRYOSHKA_DIMS:
        key = f"dim_{dim}_cosine_{metric}"
        values.append(ft_results[key])

    # Highlight NDCG@10
    metric_name = f"=={metric}==" if metric == "ndcg@10" else metric
    print(f"{metric_name:15}", end="  ")
    for val in values:
        print(f"{val:12.4f}", end=" ")
    print()

# Print sequential score
print("-" * 85)
print(f"{'seq_score:'} {ft_results['sequential_score']:1f}")

Fine Tuned Model Evaluation Results
-------------------------------------------------------------------------------------
Metric                  768d         512d         256d         128d          64d
-------------------------------------------------------------------------------------
==ndcg@10==            0.6609       0.6634       0.6663       0.6604       0.6594 
mrr@10                 0.6435       0.6443       0.6487       0.6440       0.6400 
map@100                0.6510       0.6514       0.6563       0.6511       0.6462 
accuracy@1             0.6080       0.6060       0.6140       0.6120       0.6040 
accuracy@3             0.6660       0.6660       0.6740       0.6660       0.6580 
accuracy@5             0.6900       0.6960       0.6920       0.6880       0.6780 
accuracy@10            0.7160       0.7240       0.7220       0.7120       0.7220 
precision@1            0.6080       0.6060       0.6140       0.6120       0.6040 
precision@3            0.2220       0.2220     

### UPLOAD TO HUGGINGFACE

In [18]:
from huggingface_hub import login, HfApi

login()

trainer.model.push_to_hub("shatonix/granite-embedding-math-cs")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...u6cbjuh/model.safetensors:   0%|          |  680kB /  596MB            

'https://huggingface.co/shatonix/granite-embedding-math-cs/commit/0791216e11245c91095b51990bc8ed53049bf1c8'