In [None]:
# Requires transformers>=4.51.0

import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select, for each sequence, the hidden state at its last attended token.

    Args:
        last_hidden_states: Model output indexed as ``[batch, position, ...]``.
        attention_mask: Mask indexed as ``[batch, position]``; 1 marks a real token.

    Returns:
        One hidden-state vector per batch row.
    """
    batch_size = last_hidden_states.shape[0]
    every_row_ends_with_token = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if every_row_ends_with_token:
        # Left padding (or no padding at all): the final column is always real.
        return last_hidden_states[:, -1]
    # Otherwise assume right padding with a contiguous run of ones from the
    # start, so the last real token sits at (number of ones - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(batch_size, device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query with its task instruction.

    The result is ``"Instruct: <task_description>"`` on the first line and
    ``"Query:<query>"`` on the second; this notebook uses it for queries only,
    never for documents.
    """
    instruction_line = f"Instruct: {task_description}"
    query_line = f"Query:{query}"
    return "\n".join((instruction_line, query_line))

# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
# Queries first, then documents: the score matrix below relies on this order.
input_texts = queries + documents

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-8B', padding_side='left')
model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-8B')

# We recommend enabling flash_attention_2 for better acceleration and memory saving.
# model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-8B', attn_implementation="flash_attention_2", torch_dtype=torch.float16).cuda()

max_length = 8192

# Tokenize the input texts
batch_dict = tokenizer(
    input_texts,
    padding=True,
    truncation=True,
    max_length=max_length,
    return_tensors="pt",
)
batch_dict = batch_dict.to(model.device)

# Inference only: without this, autograd records the forward pass and the
# global `outputs` would keep the saved activations of an 8B model alive.
with torch.inference_mode():
    outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    # L2-normalize so the dot products below are cosine similarities.
    embeddings = F.normalize(embeddings, p=2, dim=1)

# Rows = queries, columns = documents.
num_queries = len(queries)
scores = embeddings[:num_queries] @ embeddings[num_queries:].T
print(scores.tolist())
# Expected (digits may differ slightly across hardware / dtypes):
# [[0.7493016123771667, 0.0750647559762001], [0.08795969933271408, 0.6318399906158447]]


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[[0.7493017911911011, 0.0750647783279419], [0.08795967698097229, 0.6318397521972656]]


In [3]:
from typing import List, Iterable, Union, Optional

from langchain_core.embeddings import Embeddings

import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def _last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

class MyEmbeddings(Embeddings):
    """LangChain ``Embeddings`` backed by a Hugging Face model with last-token pooling.

    Texts are tokenized with left padding, run through ``AutoModel``, pooled with
    ``_last_token_pool`` and (optionally) L2-normalized.

    Args:
        model: Hugging Face model id or local path passed to ``from_pretrained``.
        max_length: Maximum tokens per text; longer texts are truncated.
        device: Torch device string. ``None`` picks CUDA if available, else MPS,
            else CPU.
        normalize: L2-normalize each embedding so dot products are cosine
            similarities.
        batch_size: Number of texts per forward pass; must be >= 1.
        query_instruction: Optional task description. When set, ``embed_query``
            sends ``"Instruct: <query_instruction>\\nQuery:<text>"`` to the model.
            Documents are never prefixed.
        model_kwargs: Extra keyword arguments forwarded to
            ``AutoModel.from_pretrained`` (e.g. ``torch_dtype``,
            ``attn_implementation``). Without them the weights load in the
            library's default dtype.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(
        self,
        model: str = "Qwen/Qwen3-Embedding-8B",
        *,
        max_length: int = 8192,
        device: Optional[str] = None,
        normalize: bool = True,
        batch_size: int = 32,
        query_instruction: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
    ):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.model_name = model
        self.max_length = max_length
        self.normalize = normalize
        self.batch_size = batch_size
        self.query_instruction = query_instruction

        # Resolve the device first so a bad device string fails before the
        # (large) weights are loaded.
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self._device = torch.device(device)

        # Left padding puts every sequence's final token in the last column,
        # which lets _last_token_pool take its simple path.
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, padding_side='left')
        self._model = AutoModel.from_pretrained(self.model_name, **(model_kwargs or {}))
        # _embed_texts moves inputs to self._device, so the weights must live there too.
        self._model.to(self._device)
        self._model.eval()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        return self._embed_texts(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text, prefixed with ``query_instruction`` when one is set."""
        if self.query_instruction is not None:
            text = f"Instruct: {self.query_instruction}\nQuery:{text}"
        return self._embed_texts([text])[0]

    @torch.inference_mode()
    def _embed_texts(self, texts: Union[List[str], Iterable[str]]) -> List[List[float]]:
        """Embed ``texts`` in chunks of ``self.batch_size``; one float list per text."""
        if not isinstance(texts, list):
            texts = list(texts)

        out: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch_texts = texts[start : start + self.batch_size]
            batch = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            batch = {k: v.to(self._device) for k, v in batch.items()}
            outputs = self._model(**batch)
            # Accept both ModelOutput objects and plain tuples.
            hidden = outputs.last_hidden_state if hasattr(outputs, "last_hidden_state") else outputs[0]
            pooled = _last_token_pool(hidden, batch["attention_mask"])
            # Upcast so reduced-precision models (via model_kwargs) normalize in float32.
            pooled = pooled.float()
            if self.normalize:
                pooled = F.normalize(pooled, p=2, dim=1)
            out.extend(pooled.cpu().tolist())
        return out

In [None]:
# Smoke test: embed two short documents with the default settings.
embeddings = MyEmbeddings()
doc_vectors = embeddings.embed_documents(["Hello", "world"])
# Each vector has thousands of floats; show the dimension and a short prefix
# instead of dumping everything into the notebook output.
print([len(vector) for vector in doc_vectors])
print([vector[:5] for vector in doc_vectors])

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]