In [1]:
# Names of three text fields — presumably the dataset columns to embed with BERT
# (not referenced in the visible cells; TODO confirm against the data-loading code).
BERT_columns = ['question_title', 'question_body', 'answer']
# Presumably the maximum number of tokens per input passed to BERT — TODO confirm;
# nothing in the visible cells truncates to this length.
BERT_max_sentence_size = 512
# Presumably the width of one BERT hidden-state vector; the features(...) cells
# further down report a last dimension of 768.
BERT_embedding_size = 768

In [34]:
import torch
from transformers import BertTokenizer, BertModel
# Relative path the model below is loaded from.
BERT_BASE_UNCASED_LOCATION = "../models/bert-base-uncased"
# NOTE(review): the tokenizer is loaded by the bare name 'bert-base-uncased' while the
# model comes from BERT_BASE_UNCASED_LOCATION — confirm whether the tokenizer should
# use the same local path so both come from the same source.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# return_dict=False presumably makes the model return plain tuples rather than an
# output object (confirm against the transformers docs); _get_hidden_states has a
# tuple fallback (outputs[2]) for that case.
model = BertModel.from_pretrained(BERT_BASE_UNCASED_LOCATION, return_dict=False)



In [40]:
# Encode one sample sentence (return_tensors="pt") and run it through the model with
# hidden states requested; `inputs` and `outputs` are the notebook-level values that
# the process_BERT_features(...) demo cells below operate on.
inputs = tokenizer("Example sentence for encoding", return_tensors="pt")
# The 'input_ids' entry of the encoding (not used again in the visible cells).
tokens = inputs['input_ids']
outputs = model(**inputs, output_hidden_states=True)

In [81]:

def _get_layer_width(inputs):
    """Return the size of dimension 1 of ``inputs['input_ids']`` (the token count)."""
    return inputs['input_ids'].shape[1]


def _get_hidden_states(outputs):
    """Return the hidden states from a model output.

    Uses the ``hidden_states`` attribute when the output object has one;
    otherwise treats ``outputs`` as a tuple and takes element 2.
    """
    if hasattr(outputs, 'hidden_states'):
        return outputs.hidden_states
    return outputs[2]


def process_BERT_features(
    inputs,
    outputs,
    layer_start,
    layer_end=0,
    pooling_function=torch.nn.AvgPool2d,
    pooling_function_args=(1,),
):
    """
    Pool one hidden-state layer, or a concatenation of several, across tokens.

    Parameters
    ----------
    inputs
        Tokenizer output; only ``inputs['input_ids'].shape[1]`` (the token
        count) is read, and it becomes the height of the pooling kernel.
    outputs
        Model output exposing ``.hidden_states``, or a tuple holding the
        hidden states at index 2.
    layer_start : int
        Index of the (first) hidden state to use. Negative values count from
        the end, so ``-1`` selects the last hidden state.
    layer_end : int or None
        ``0`` (default): use only the hidden state at ``layer_start``.
        Any other int: concatenate ``hidden_states[layer_start:layer_end]``
        along dim 2. ``None``: concatenate ``hidden_states[layer_start:]``.
    pooling_function
        Callable that takes a kernel-size tuple and returns the pooling
        module (``torch.nn.AvgPool2d`` by default, ``torch.nn.MaxPool2d``
        also works).
    pooling_function_args
        Remaining kernel dimensions, appended after the token count. The
        default ``(1,)`` yields a ``(n_tokens, 1)`` kernel, i.e. pooling over
        tokens only. Any iterable (list or tuple) is accepted.

    Returns
    -------
    torch.Tensor
        The pooled features. With the default kernel and hidden states of
        shape ``(1, n_tokens, 768)`` this is ``(1, 1, 768)`` for one layer and
        ``(1, 1, 768 * k)`` when ``k`` layers are concatenated.
    """
    hidden_states = _get_hidden_states(outputs)

    if layer_end == 0:
        # Index the single layer directly. The former slice
        # [layer_start:layer_start + 1] collapsed to the empty [-1:0] when
        # layer_start was -1, so the last hidden state could not be selected.
        selected = (hidden_states[layer_start],)
    else:
        selected = tuple(hidden_states[layer_start:layer_end])

    # Kernel height equals the token count, so every token is reduced at once.
    pooling = pooling_function((_get_layer_width(inputs), *pooling_function_args))
    features = torch.cat(selected, dim=2)
    return pooling(features)

In [39]:
# Pool the hidden state at index 1 only (layer_end is left at its default of 0).
# NOTE(review): the recorded width of 763 below differs from the 768 reported by the
# later features(...) cells; this output may predate the current function definition —
# re-run the notebook top to bottom to confirm.
process_BERT_features(inputs, outputs, 1).shape

torch.Size([1, 1, 763])

In [16]:
# Same single-layer pooling, now for the hidden state at index 2.
process_BERT_features(inputs, outputs, 2).shape

torch.Size([1, 1, 763])

In [17]:
# Same single-layer pooling, now for the hidden state at index 5.
process_BERT_features(inputs, outputs, 5).shape

torch.Size([1, 1, 763])

In [18]:
# Hidden state 5 again, pooled with MaxPool2d instead of the default AvgPool2d.
process_BERT_features(inputs, outputs, 5, pooling_function=torch.nn.MaxPool2d).shape

torch.Size([1, 1, 763])

In [23]:
# Concatenate hidden states 5 and 6 (slice 5:7) along dim 2, then max-pool.
process_BERT_features(inputs, outputs, 5, 7, pooling_function=torch.nn.MaxPool2d).shape

torch.Size([1, 1, 1531])

In [19]:
# Same 5:7 concatenation, pooled with the default AvgPool2d.
process_BERT_features(inputs, outputs, 5, 7).shape

torch.Size([1, 1, 1531])

In [20]:
# A negative layer_start counts from the end: only the third-to-last hidden state.
process_BERT_features(inputs, outputs, -3).shape

torch.Size([1, 1, 763])

In [21]:
# Slice -3:-1 — the third- and second-to-last hidden states, concatenated along dim 2.
process_BERT_features(inputs, outputs, -3, -1).shape

torch.Size([1, 1, 1531])

In [22]:
# layer_end=None leaves the slice open-ended (-3:), i.e. the last three hidden states.
process_BERT_features(inputs, outputs, -3, None).shape

torch.Size([1, 1, 2299])

In [69]:
def features(s):
    """Tokenize ``s``, run it through ``model`` and pool hidden state 5.

    Parameters
    ----------
    s : str
        Text to encode with the notebook-level ``tokenizer``.

    Returns
    -------
    torch.Tensor
        Result of ``process_BERT_features(inputs, outputs, 5)``; the recorded
        runs in this notebook show ``torch.Size([1, 1, 768])`` regardless of
        input length. The tensor carries no autograd history.
    """
    inputs = tokenizer(s, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no graph is kept alive for
    # every hidden state the model returns.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    return process_BERT_features(inputs, outputs, 5)

In [82]:
# Shape check on a one-word input.
features("testing").shape

torch.Size([1, 1, 768])

In [83]:
# Longer input: the pooling kernel height follows the token count, and the recorded
# shape is the same as for the one-word input.
features("testing and testing").shape

torch.Size([1, 1, 768])

In [84]:
# Longer still; the recorded shape again stays at [1, 1, 768].
features("testing and testing and testing").shape

torch.Size([1, 1, 768])