In [1]:
import torch
from transformers import AutoTokenizer, AutoModel

In [21]:
from typing import List, Optional

In [2]:
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [3]:
model = AutoModel.from_pretrained('distilbert-base-uncased', output_hidden_states=True)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
test_texts = ["my hovercraft is full of eels", "where do we go now?", "to-ma-to, to-mah-to", "don't call me shirley"]

In [23]:
def gen_embeds(sents: List[str], average: bool = True) -> torch.Tensor:
    """
    sents:  List of sentences to generate embeddings for
    average: Return average over sentences

    Returns:
    torch.Tensor of size num_sents x emb_dim if average = True else
                    size num_sents x max_len, emb_dim
    """
    with torch.no_grad():
        inputs = tokenizer(sents, padding=True, return_tensors="pt")
        logits = model(**inputs).last_hidden_state

        if average:
            # Average across the entire sentence
            return torch.mean(logits, dim=1)
        else:
            return logits


In [24]:
gen_embeds(test_texts)

tensor([[ 0.1746,  0.0238,  0.0530,  ..., -0.0598,  0.0308,  0.1818],
        [ 0.3365, -0.0234,  0.1607,  ...,  0.0293,  0.0692, -0.0304],
        [ 0.0729,  0.2888,  0.0086,  ..., -0.0316,  0.0133,  0.4798],
        [ 0.3095,  0.1425,  0.0594,  ...,  0.1269, -0.0609,  0.1577]])