# Simple Fine-Tuning Example with Sentence Transformers

This notebook shows a **minimal example** of fine-tuning an existing Sentence Transformer model
(`all-MiniLM-L6-v2`) on a small subset of the **STSb** (STS Benchmark) dataset, using the
`SentenceTransformerTrainer` API introduced in Sentence Transformers v3. The steps are:

1. Install dependencies
2. Import libraries and check the setup
3. Load a small training and validation split from STSb (2,000 and 500 pairs)
4. Load the base model and define the loss (`CoSENTLoss`)
5. Create an evaluator (`EmbeddingSimilarityEvaluator`) for the validation split
6. Configure `SentenceTransformerTrainingArguments`
7. Train with `SentenceTransformerTrainer`
8. Evaluate and save the fine-tuned model
9. Check the embeddings qualitatively


## 1. Install dependencies

```python
# The [train] extra also installs `datasets` and `accelerate`, which the trainer needs
!pip install -U "sentence-transformers[train]"
```
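`SentenceTransformerTrainer` and `SentenceTransformerTrainingArguments` only exist in Sentence Transformers v3 and later, so an older install fails at import time. A minimal standard-library sketch for checking the installed release before importing; the helpers `major_version` and `installed_major` are hypothetical, not part of any library:

```python
from importlib.metadata import PackageNotFoundError, version


def major_version(version_string):
    """Return the leading integer of a version string such as '3.4.1' or '3rc1'."""
    head = version_string.split(".")[0]
    digits = ""
    for ch in head:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits)


def installed_major(package):
    """Major version of an installed distribution, or None if it is not installed."""
    try:
        return major_version(version(package))
    except PackageNotFoundError:
        return None


st_major = installed_major("sentence-transformers")
if st_major is None:
    print("sentence-transformers is not installed")
elif st_major < 3:
    print(f"sentence-transformers v{st_major} found; SentenceTransformerTrainer needs v3 or later")
else:
    print(f"sentence-transformers v{st_major} found; the trainer API is available")
```

If the check reports an old version right after upgrading, restarting the notebook kernel is usually enough to pick up the new install.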

## 2. Imports and basic setup

```python
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CoSENTLoss
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers

import torch

print("Torch version:", torch.__version__)
```

## 3. Load a small STSb training & validation set

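```python
# Load STSb (semantic textual similarity): 2,000 training pairs and 500 validation pairs
raw_train = load_dataset("sentence-transformers/stsb", split="train[:2000]")
raw_dev = load_dataset("sentence-transformers/stsb", split="validation[:500]")

# Original STSb scores range from 0 to 5, but the "sentence-transformers/stsb" copy on the
# Hub already stores them normalized to [0, 1]. CoSENTLoss only uses the relative order of
# the scores, so their scale does not change the loss; [0, 1] labels are still convenient
# because scale-sensitive losses such as CosineSimilarityLoss expect them.
def normalize_scores(batch):
    batch["score"] = [s / 5.0 for s in batch["score"]]
    return batch

# Rescale only if the scores are still on the 0-5 scale (avoids dividing by 5 twice)
needs_rescale = max(raw_train["score"]) > 1.0
train_dataset = raw_train.map(normalize_scores, batched=True) if needs_rescale else raw_train
dev_dataset = raw_dev.map(normalize_scores, batched=True) if needs_rescale else raw_dev

print("Rescaled scores:", needs_rescale)
print(train_dataset)
print(dev_dataset)
```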
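With `batched=True`, `Dataset.map` calls the function on a batch (1,000 rows by default): a dict mapping each column name to a list of values, and the returned dict supplies the new column values. The same function can be exercised offline on a hand-built batch; the three rows below are illustrative, not taken from the loaded split:

```python
def normalize_scores(batch):
    # Same function as in the cell above: divide every score in the batch by 5
    batch["score"] = [s / 5.0 for s in batch["score"]]
    return batch


# A hypothetical batch of three rows, shaped the way `datasets` passes it with batched=True
batch = {
    "sentence1": ["A plane is taking off.", "A man is playing a flute.", "A cat sleeps."],
    "sentence2": ["An airplane is taking off.", "A man plays a bamboo flute.", "A dog barks."],
    "score": [5.0, 3.8, 0.2],
}

# The same guard as above: only rescale scores that are still on the 0-5 scale
needs_rescale = max(batch["score"]) > 1.0
out = normalize_scores(batch) if needs_rescale else batch
print(needs_rescale, out["score"])
```

The text columns pass through untouched; only `score` is rewritten, and the guard keeps already-normalized data from being divided again.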

## 4. Load base model and define loss

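```python
# Start from an existing Sentence Transformer model (small and fast)
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# CoSENTLoss trains on (sentence1, sentence2, score) pairs. The trainer treats a column named
# "score" or "label" as the label and passes the remaining columns, in order, as inputs.
loss = CoSENTLoss(model)
```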
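To make the loss less of a black box, here is a pure-Python re-derivation of the CoSENT objective for a single batch. It is a sketch, not the library's code: `cosent_loss` is a hypothetical helper, and `scale=20.0` mirrors what we understand to be the library default. For every two pairs `i`, `j` whose gold scores satisfy `label_i < label_j`, it adds `exp(scale * (cos_i - cos_j))` inside `log(1 + sum(...))`, so the loss is small when the model's cosine similarities are ordered like the gold scores. Since only the comparison `label_i < label_j` enters, rescaling the labels leaves the loss unchanged:

```python
import math


def cosent_loss(cos_sims, labels, scale=20.0):
    """CoSENT loss over one batch.

    cos_sims[k] is the cosine similarity the model assigns to pair k and
    labels[k] is its gold score. Only the ordering of the labels matters.
    """
    # The leading 0.0 is the "1 +" inside log(1 + sum(...)), since exp(0) == 1
    terms = [0.0]
    for i in range(len(labels)):
        for j in range(len(labels)):
            if labels[i] < labels[j]:
                # Positive (penalized) when pair i has the lower gold score but the higher cosine
                terms.append(scale * (cos_sims[i] - cos_sims[j]))
    # Numerically stable log-sum-exp
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


labels_0_5 = [4.8, 2.5, 0.4]
labels_0_1 = [s / 5.0 for s in labels_0_5]
aligned = [0.9, 0.5, 0.1]         # cosine order matches the gold order
reversed_order = [0.1, 0.5, 0.9]  # cosine order is the opposite

print("aligned, 0-5 labels:", cosent_loss(aligned, labels_0_5))
print("aligned, 0-1 labels:", cosent_loss(aligned, labels_0_1))
print("reversed, 0-5 labels:", cosent_loss(reversed_order, labels_0_5))
```

The first two losses are identical and close to zero, while the reversed ordering is heavily penalized, which is why the score normalization step affects readability more than training.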

## 5. Create evaluator for validation set

```python
# list(...) ensures plain Python lists regardless of the `datasets` version's column type
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=list(dev_dataset["sentence1"]),
    sentences2=list(dev_dataset["sentence2"]),
    scores=list(dev_dataset["score"]),
    main_similarity=SimilarityFunction.COSINE,
    name="stsb-dev",
)

# Optional: evaluate the base model before training to get a baseline
print("Base model STSb dev performance:")
print(dev_evaluator(model))
```
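`EmbeddingSimilarityEvaluator` scores the model by correlating its cosine similarities with the gold scores, reporting Pearson and Spearman correlation coefficients. A minimal pure-Python sketch of those two statistics follows; the helpers `pearson`, `ranks`, and `spearman` are hypothetical and not the library's implementation, and the four score/cosine pairs are made up. It shows that Spearman only looks at ranks, so a prediction ordered exactly like the gold scores reaches 1.0 whether the gold scores are on a 0-5 or 0-1 scale:

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length lists."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda k: values[k])
    r = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            r[order[k]] = avg
        i = j + 1
    return r


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


gold_0_5 = [4.8, 2.5, 0.4, 3.6]
gold_0_1 = [s / 5.0 for s in gold_0_5]
pred = [0.91, 0.40, 0.05, 0.62]      # same ordering as the gold scores
pred_bad = [0.40, 0.91, 0.05, 0.62]  # first two pairs swapped

print("spearman (0-5 gold):", spearman(pred, gold_0_5))
print("spearman (0-1 gold):", spearman(pred, gold_0_1))
print("pearson  (0-5 gold):", pearson(pred, gold_0_5))
print("spearman (swapped) :", spearman(pred_bad, gold_0_5))
```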

## 6. Define training arguments

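```python
# 2,000 examples / batch size 32 rounds up to 63 optimizer steps for the single epoch,
# so eval, save, and logging intervals must be well below 63 steps to trigger at all
args = SentenceTransformerTrainingArguments(
    # Required:
    output_dir="models/all-MiniLM-L6-v2-stsb-cosent",

    # Keep small for demo purposes
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=torch.cuda.is_available(),  # FP16 mixed precision only when a CUDA GPU is present
    bf16=False,
    # NO_DUPLICATES matters for in-batch-negative losses (e.g. MultipleNegativesRankingLoss);
    # CoSENTLoss compares labeled pairs, so the default sampler is sufficient here
    batch_sampler=BatchSamplers.BATCH_SAMPLER,

    # Logging / eval / saving
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=1,
    logging_steps=10,
    run_name="all-MiniLM-L6-v2-stsb-cosent-demo",
)
```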
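Eval, save, and logging intervals are counted in optimizer steps, so it helps to check them against the total step count before launching a run. A small standard-library calculation, assuming a single device, the default batch sampler, no gradient accumulation, and a final partial batch that is kept (`total_steps` and `trigger_points` are hypothetical helpers):

```python
import math


def total_steps(num_examples, batch_size, epochs=1):
    # One optimizer step per batch; the final partial batch still counts as a step
    return math.ceil(num_examples / batch_size) * epochs


def trigger_points(total, every):
    # Steps at which a "steps" strategy with interval `every` fires
    return [step for step in range(1, total + 1) if step % every == 0]


steps = total_steps(num_examples=2000, batch_size=32, epochs=1)
print("total steps:", steps)
for interval in (100, 50, 20):
    print(f"every {interval:>3} steps ->", trigger_points(steps, interval))
```

An interval of 100 yields an empty list: with 2,000 examples, a batch size of 32, and one epoch, a 100-step interval would never evaluate or save during training.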

## 7. Create trainer and fine-tune the model

```python
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,  # held-out pairs for the eval loss; reusing training data would hide overfitting
    loss=loss,
    evaluator=dev_evaluator,
)

trainer.train()
```
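During `trainer.train()`, `warmup_ratio=0.1` works together with the Trainer's default `linear` learning-rate scheduler: the rate ramps up from 0 to `learning_rate` over the warmup steps, then decays linearly back to 0 by the last step. A standard-library sketch of that shape, assuming the transformers convention of `ceil(total_steps * warmup_ratio)` warmup steps and the 63-step run configured above (`linear_warmup_lr` is a hypothetical helper, not a library function):

```python
import math


def linear_warmup_lr(step, total_steps, base_lr=2e-5, warmup_ratio=0.1):
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay from base_lr at the end of warmup down to 0 at the final step
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total = 63  # 2,000 examples / batch size 32, rounded up, for one epoch
for step in (0, 3, 7, 35, 63):
    print(step, f"{linear_warmup_lr(step, total):.2e}")
```

With only 63 steps, warmup covers the first 7 of them, so the peak learning rate is reached quickly and most of the run is spent decaying.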

## 8. Evaluate and save the fine-tuned model

```python
# trainer.train() updates `model` in place, so the same evaluator can be reused directly
print("Fine-tuned model STSb dev performance:")
print(dev_evaluator(model))

# Save locally
save_path = "models/all-MiniLM-L6-v2-stsb-cosent/final"
model.save_pretrained(save_path)
print("Model saved to:", save_path)

# Optional: push to the Hugging Face Hub (requires a configured HF token)
# model.push_to_hub("your-username/all-MiniLM-L6-v2-stsb-cosent-demo")
```

## 9. Quick qualitative check of embeddings

```python
sentences = [
    "A man is playing guitar on stage.",
    "Someone is performing music in front of an audience.",
    "A dog is running through a field.",
]

# normalize_embeddings=True gives every embedding unit length, so dot products equal cosine similarities
embeddings = model.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
cos_sim = torch.matmul(embeddings, embeddings.T)

print("Cosine similarity matrix:")
for i, s in enumerate(sentences):
    row = ", ".join(f"{cos_sim[i, j]:.3f}" for j in range(len(sentences)))
    print(f"{i} ({s}): [{row}]")
```
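The matrix product in this check only yields cosine similarities because every embedding was scaled to unit length first. A standard-library illustration with made-up 3-dimensional vectors (the helpers `cosine` and `l2_normalize` are hypothetical):

```python
import math


def cosine(u, v):
    """Cosine similarity: dot product divided by the product of the vector lengths."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def l2_normalize(v):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


u, v = [3.0, 4.0, 0.0], [1.0, 2.0, 2.0]
raw_dot = sum(a * b for a, b in zip(u, v))
unit_dot = sum(a * b for a, b in zip(l2_normalize(u), l2_normalize(v)))

print("raw dot product:       ", raw_dot)
print("cosine similarity:     ", cosine(u, v))
print("dot of unit-length vecs:", unit_dot)
```

Without normalization the raw dot product is not bounded to [-1, 1] and mixes vector length with direction; after normalization it matches the cosine similarity up to floating-point rounding.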