In [1]:
import torch
from transformers import DistilBertModel, DistilBertTokenizer

In [2]:
class SentenceTransformer:
    """Sentence encoder built on a pre-trained DistilBERT backbone.

    Each sentence is tokenized, passed through the transformer, and its
    per-token embeddings are averaged (mean pooling) over the non-padding
    positions to yield one fixed-length vector per sentence.

    Note: this is a small local class; it is unrelated to the class of the
    same name shipped by the ``sentence-transformers`` package.
    """

    def __init__(self, model_name="distilbert-base-uncased"):
        """Load the tokenizer and model weights for ``model_name``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        # Load the pre-trained model and tokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

    def encode(self, sentences):
        """Embed a batch of sentences.

        Args:
            sentences: The sentences to embed (a list of strings in this
                notebook). They are padded to a common length and truncated
                by the tokenizer.

        Returns:
            A tensor of shape ``[batch_size, hidden_size]`` holding one
            mean-pooled embedding per input sentence.
        """
        # Tokenize the sentences (batch processing)
        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

        # Forward pass through the model; gradients are not needed for inference
        with torch.no_grad():
            outputs = self.model(**inputs)

        # last_hidden_state has shape [batch_size, sequence_length, hidden_size]
        return self._mean_pool(outputs.last_hidden_state, inputs["attention_mask"])

    @staticmethod
    def _mean_pool(token_embeddings, attention_mask):
        """Average token embeddings over the positions the mask marks as real.

        Args:
            token_embeddings: Tensor of shape
                ``[batch_size, sequence_length, hidden_size]``.
            attention_mask: Tensor of shape ``[batch_size, sequence_length]``
                with 1 for real tokens and 0 for padding.

        Returns:
            Tensor of shape ``[batch_size, hidden_size]``. A row whose mask is
            entirely zero yields a zero vector instead of NaN.
        """
        # Cast the mask to the embedding dtype so all arithmetic stays in one dtype
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)

        # Zero out padding positions, then sum along the sequence dimension
        summed = (token_embeddings * mask).sum(dim=1)

        # Token counts are whole numbers, so clamping to 1 leaves every non-empty
        # row untouched and only prevents 0/0 for a fully masked row
        token_counts = mask.sum(dim=1).clamp(min=1)

        return summed / token_counts


In [3]:
# Sample inputs: two sentences about sentence transformers / NLP and two about
# the weather.
# NOTE(review): "senetences" in the first sample is misspelled; the recorded
# outputs below were produced with this exact text, so it is left untouched.
sentences = [
    "These are example senetences for the sentence transformer.",
    "Sentence transformers are useful for many NLP tasks.",
    "The weather is warm and sunny today.",
    "It looks like it will rain soon."
]

In [4]:
# Build the encoder on the DistilBERT checkpoint, then embed every sample
# sentence in a single call; `embeddings` holds one row per sentence.
model = SentenceTransformer(model_name="distilbert-base-uncased")
embeddings = model.encode(sentences)

In [5]:
# Sanity check: one fixed-length vector per input sentence
# (4 sentences -> 4 rows).
embeddings.shape

torch.Size([4, 768])

In [6]:
# Show the first sentence followed by its full embedding vector.
# NOTE(review): this prints every component of the vector; consider slicing
# (e.g. the first few values) to keep the notebook output short.
print(sentences[0], ':\n', embeddings[0])

These are example senetences for the sentence transformer. :
 tensor([-1.6575e-01, -2.1145e-01, -1.5407e-01,  2.4948e-04, -7.8773e-02,
         2.3819e-02,  2.1170e-01,  1.6797e-01,  1.0916e-01,  3.2143e-02,
        -2.4718e-01,  7.0778e-02, -2.0016e-01, -1.3198e-01, -1.6691e-01,
         2.3159e-01, -6.1706e-02,  3.5158e-02, -1.5826e-01, -2.1277e-01,
         2.9083e-01,  3.3889e-01, -1.3739e-03,  1.5036e-01,  3.6409e-01,
        -2.4928e-01, -2.5242e-03, -1.7506e-02, -3.8636e-01,  7.7822e-02,
         3.9544e-02,  5.4002e-01, -1.2652e-01, -1.6412e-01,  2.9374e-02,
         1.1566e-01,  2.3228e-01, -2.4044e-01, -1.3916e-01,  2.0177e-01,
        -4.4393e-01, -3.5111e-01,  4.0886e-02, -1.5443e-01, -2.3737e-01,
        -3.5476e-01, -6.4319e-02, -3.4519e-01, -7.3408e-02, -5.2618e-02,
        -8.6616e-01,  4.0336e-01,  3.3825e-02,  3.2807e-01, -6.9723e-02,
         5.7166e-01,  1.2126e-01, -5.7674e-01,  3.1775e-01, -2.7250e-01,
         8.8507e-02,  9.9551e-02, -1.1892e-01, -4.1509e-01,  3

In [7]:
# Show the second sentence followed by its full embedding vector
# (same display as the previous cell, for index 1).
print(sentences[1], ':\n', embeddings[1])

Sentence transformers are useful for many NLP tasks. :
 tensor([ 6.2342e-03, -1.0091e-01, -3.2863e-02,  1.3378e-01,  5.7312e-02,
        -3.2103e-01,  1.6157e-01, -1.8010e-02, -1.0450e-01, -2.2852e-02,
        -2.0801e-01,  3.0121e-02, -1.9861e-01, -1.2776e-01, -1.7679e-01,
         2.5852e-01, -2.5621e-02,  1.6422e-01, -3.7527e-02, -1.8756e-01,
         1.2707e-01,  3.7283e-02, -3.8539e-02,  2.5775e-01,  6.2193e-02,
        -2.0972e-01,  3.7296e-02,  1.8713e-01, -3.0179e-01, -6.7718e-02,
        -2.4342e-01,  2.3574e-01,  7.8746e-02, -1.3430e-01,  1.8251e-01,
         2.1505e-01,  9.4307e-02, -1.4295e-01, -3.8505e-02,  1.7322e-01,
        -1.4439e-01, -1.7160e-01, -2.3221e-02,  5.4074e-02, -1.9189e-01,
        -4.0092e-01, -9.9745e-02, -1.7989e-01,  2.2963e-01, -1.6416e-01,
        -9.0734e-01,  5.4769e-01, -1.1411e-01,  2.9595e-01,  1.1945e-01,
         3.1283e-01,  2.8820e-01, -5.2915e-01,  2.2332e-01, -1.3195e-01,
        -1.0924e-01,  7.5171e-02, -2.7328e-01, -3.5046e-01,  5.4048e

I used the distilbert-base-uncased model as the transformer backbone. To aggregate the token embeddings into a single fixed-length vector per sentence, I chose mean pooling, which is simple, effective, and computationally efficient. Before averaging, the token embeddings are multiplied by the tokenizer's attention mask, and the sum is divided by the number of real tokens, so padding tokens do not influence the sentence embedding.