In [9]:
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn



In [10]:

# Tokenize a small batch of example sentences for the encoder defined below.
# The checkpoint name matches SentenceTransformer's default `model_name`, so the
# tokenizer and the encoder are loaded from the same pretrained checkpoint.
# Padding/truncation give every sequence a common length so the batch can be
# returned as stacked PyTorch tensors ("pt"), including an attention mask.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
sentences = [
    "I love machine learning.",
    "Transformers are powerful.",
]
inputs = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    return_tensors="pt",
)


In [11]:
class SentenceTransformer(nn.Module):
    """Sentence encoder: DistilBERT token embeddings mean-pooled over unmasked tokens.

    Args:
        model_name: checkpoint name or path passed to
            ``DistilBertModel.from_pretrained``.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return one mean-pooled vector per sequence.

        Args:
            input_ids: token ids, shape (batch, seq_len), as produced by the tokenizer.
            attention_mask: shape (batch, seq_len); presumably 1 for real tokens and
                0 for padding (tokenizer output). Tokens with mask 0 are excluded
                from the average.

        Returns:
            Tensor of shape (batch, hidden_size) in the encoder's output dtype. A row
            whose mask is entirely 0 yields a zero vector rather than NaN.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = output.last_hidden_state  # (batch, seq_len, hidden)

        # Cast the (typically integer) mask explicitly instead of relying on implicit
        # type promotion. Casting to the embedding dtype keeps the output dtype equal
        # to what float * int promotion produced before.
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq_len, 1)
        summed = (token_embeddings * mask).sum(dim=1)  # (batch, hidden)

        # The token count is an integer >= 0. Clamping at 1 makes an all-padding row
        # 0 / 1 = 0 (same as the old 1e-9 floor intended), and unlike 1e-9 the value
        # 1 does not underflow to 0 in float16. Shape (batch, 1) broadcasts over hidden.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

In [12]:
# Encode the tokenized batch. eval() switches the module tree to evaluation mode and
# no_grad() avoids recording an autograd graph, since no gradients are needed here.
model = SentenceTransformer()
model.eval()
with torch.no_grad():
    embeddings = model(inputs['input_ids'], inputs['attention_mask'])

# Expect one pooled vector per input sequence, each of the encoder's hidden size.
assert embeddings.shape == (inputs['input_ids'].shape[0], model.bert.config.hidden_size)
print("Embeddings shape:", embeddings.shape)  # (batch_size, hidden_size)


Embeddings shape: torch.Size([2, 768])
