In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Define model path
MODEL_PATH ="/Users/sir/Downloads/HuggingFace/sentence_transformer/intfloat_e5-large-v2"

# use mps if available, else cuda, else cpu
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Load model and tokenizer
print("Loading model...")
model = AutoModel.from_pretrained(MODEL_PATH).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("Model loaded.")

Using device: mps
Loading model...
Model loaded.


In [12]:
# ----------------------------
# 1. Matryoshka Text Encoder
# ----------------------------
class MatryoshkaTextEncoder(nn.Module):
    def __init__(self, model_name=MODEL_PATH, dims=[128, 256, 384, 512]):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.projection = nn.Linear(self.encoder.config.hidden_size, max(dims))
        self.dims = dims

    def forward(self, input_ids, attention_mask=None, **kwargs):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        cls_emb = out.last_hidden_state[:, 0]             # [CLS] token
        z = self.projection(cls_emb)                      # [batch, max_dim]
        z = F.normalize(z, dim=-1)
        return z


In [13]:
def matryoshka_text_loss(z1, z2, dims=[128, 256, 384, 512], temperature=0.05):
    total_loss = 0
    for m in dims:
        z1_m = F.normalize(z1[:, :m], dim=-1)
        z2_m = F.normalize(z2[:, :m], dim=-1)
        logits = (z1_m @ z2_m.T) / temperature
        labels = torch.arange(z1.size(0), device=z1.device)
        loss = F.cross_entropy(logits, labels)
        total_loss += loss
    return total_loss / len(dims)

In [14]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = MatryoshkaTextEncoder().to(DEVICE)

texts = [
    "The cat sits on the mat.",
    "A small feline rests on a rug."
]

enc = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(DEVICE)

z1 = model(**enc)
z2 = model(**enc)

loss = matryoshka_text_loss(z1, z2)
loss.backward()
print("Training loss:", loss.item())

Training loss: 0.07977110892534256


In [16]:
dims = [128, 256, 384, 512]

for m in dims:
    sim = F.cosine_similarity(z1[0, :m].unsqueeze(0), z2[1, :m].unsqueeze(0)).item()
    print(f"Cosine similarity ({m}D): {sim:.4f}")

Cosine similarity (128D): 0.8901
Cosine similarity (256D): 0.8814
Cosine similarity (384D): 0.8615
Cosine similarity (512D): 0.8640
