In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Define model path
MODEL_PATH ="/Users/sir/Downloads/HuggingFace/sentence_transformer/intfloat_e5-large-v2"

# use mps if available, else cuda, else cpu
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Load model and tokenizer
print("Loading model...")
model = AutoModel.from_pretrained(MODEL_PATH).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("Model loaded.")

Using device: mps
Loading model...
Model loaded.


In [26]:
print(model)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)


#### Training
Training using Matryoshka Representation Learning (MRL) is quite elementary: rather than applying some loss function on only the full-size embeddings, we also apply that same loss function on truncated portions of the embeddings. For example, if a model has an embedding dimension of 768 by default, it can now be trained on 768, 512, 256, 128, 64 and 32. Each of these losses will be added together, optionally with some weight:

In [25]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss


model = SentenceTransformer(MODEL_PATH).to(DEVICE)

base_loss = CoSENTLoss(model=model)
loss = MatryoshkaLoss(model=model, loss=base_loss, matryoshka_dims=[1024, 768, 512, 256, 128, 64])

Additionally, this can be combined with the `AdaptiveLayerLoss` such that the resulting model can be reduced both in the size of the output dimensions, but also in the number of layers for faster inference. See also the [Adaptive Layers](https://sbert.net/examples/sentence_transformer/training/adaptive_layer/README.html) for more information on reducing the number of model layers. In Sentence Transformers, the combination of these two losses is called `Matryoshka2dLoss`, and a shorthand is provided for simpler training.

Reference: [Matryoshka2dLoss](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#matryoshka2dloss)

In [12]:
# ----------------------------
# 1. Matryoshka Text Encoder
# ----------------------------
class MatryoshkaTextEncoder(nn.Module):
    def __init__(self, model_name=MODEL_PATH, dims=[128, 256, 384, 512]):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.projection = nn.Linear(self.encoder.config.hidden_size, max(dims))
        self.dims = dims

    def forward(self, input_ids, attention_mask=None, **kwargs):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        cls_emb = out.last_hidden_state[:, 0]             # [CLS] token
        z = self.projection(cls_emb)                      # [batch, max_dim]
        z = F.normalize(z, dim=-1)
        return z


In [13]:
def matryoshka_text_loss(z1, z2, dims=[128, 256, 384, 512], temperature=0.05):
    total_loss = 0
    for m in dims:
        z1_m = F.normalize(z1[:, :m], dim=-1)
        z2_m = F.normalize(z2[:, :m], dim=-1)
        logits = (z1_m @ z2_m.T) / temperature
        labels = torch.arange(z1.size(0), device=z1.device)
        loss = F.cross_entropy(logits, labels)
        total_loss += loss
    return total_loss / len(dims)

In [20]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = MatryoshkaTextEncoder().to(DEVICE)

texts = [
    "The cat sits on the mat.",
    "A small feline rests on a rug."
]

enc = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(DEVICE)

z1 = model(**enc)
z2 = model(**enc)

loss = matryoshka_text_loss(z1, z2)
loss.backward()
print("Training loss:", loss.item())

Training loss: 0.058626338839530945


In [21]:
dims = [128, 256, 384, 512]

for m in dims:
    sim = F.cosine_similarity(z1[0, :m].unsqueeze(0), z2[1, :m].unsqueeze(0)).item()
    print(f"Cosine similarity ({m}D): {sim:.4f}")

Cosine similarity (128D): 0.8485
Cosine similarity (256D): 0.8547
Cosine similarity (384D): 0.8628
Cosine similarity (512D): 0.8701
