In [1]:
BERT_columns = ['question_title', 'question_body', 'answer']
BERT_max_sentence_size = 512
BERT_embedding_size = 768

In [2]:
import torch
from transformers import BertTokenizer, BertModel
BERT_BASE_UNCASED_LOCATION = "../models/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained(BERT_BASE_UNCASED_LOCATION, return_dict=True)

In [17]:
inputs = tokenizer("Example sentence for encoding", return_tensors="pt")
tokens = inputs['input_ids']

In [4]:
outputs = model(**inputs, output_hidden_states=True)

In [7]:
len(outputs.hidden_states)

13

In [8]:
def extract_BERT_last_hidden_CLS(outputs):
    return outputs.last_hidden_state[0][0].detach().numpy()
extract_BERT_last_hidden_CLS(outputs).shape

(768,)

In [10]:
outputs.hidden_states[-1].shape

torch.Size([1, 6, 768])

In [22]:
layer_width = tokens.shape[1]

In [24]:
pooling = torch.nn.AvgPool2d(layer_width, 1)
pooling

AvgPool2d(kernel_size=6, stride=1, padding=0)

In [26]:
pooling(outputs.hidden_states[-1]).shape

torch.Size([1, 1, 763])

In [45]:
outputs.hidden_states[2:4][0].shape

torch.Size([1, 6, 768])

In [49]:
torch.cat(outputs.hidden_states[2:4], dim=2).shape

torch.Size([1, 6, 1536])

In [54]:
outputs.hidden_states[-2:None]

(tensor([[[-0.6934, -0.2054, -0.6048,  ..., -1.1626, -0.4721,  1.0977],
          [-0.1397,  0.1470, -0.7247,  ..., -1.1639,  0.3561,  0.3066],
          [-0.0867, -0.1655, -0.1916,  ..., -0.7379, -0.1337,  0.5020],
          [-0.8690,  0.2497,  0.2089,  ..., -0.1149, -0.5614,  0.1031],
          [-0.6641, -0.7928, -0.4635,  ..., -0.1858, -0.2659,  0.1560],
          [ 0.0379,  0.0096, -0.0318,  ..., -0.0077, -0.0451,  0.0137]]],
        grad_fn=<NativeLayerNormBackward>),
 tensor([[[-0.0346, -0.1905, -0.3956,  ..., -0.6355, -0.0444,  0.8815],
          [-0.0055, -0.0130, -0.6201,  ..., -0.3944,  0.5584,  0.5617],
          [-0.1010, -0.1315, -0.0035,  ..., -0.6751, -0.2122,  0.4347],
          [ 0.0680, -0.0280,  0.0513,  ..., -0.2038, -0.4270,  0.0056],
          [-0.3042, -0.6091, -0.3541,  ...,  0.0265, -0.1716,  0.3388],
          [ 1.0098, -0.0306, -0.4799,  ...,  0.2509, -0.9008, -0.3134]]],
        grad_fn=<NativeLayerNormBackward>))

In [74]:
def _get_layer_width(inputs):
    tokens = inputs['input_ids']
    return tokens.shape[1]

def process_BERT_features(
    inputs,
    outputs, 
    layer_start,
    layer_end=0,
    pooling_function=torch.nn.AvgPool2d, 
    pooling_function_args=[1]
    ):
    """
    performs pooling of layer of index `layer start` and returns torch.tensor(1,1,758)
    give layer_end to have a concat pooling, returns torch.tensor(1,1,758 * (layer-end - layer-start))
    put layer_end to None to get [layer_start:] splice
    """
    pooling = pooling_function(_get_layer_width(inputs), *pooling_function_args)
    features = torch.cat(outputs.hidden_states[layer_start:layer_end if layer_end != 0 else layer_start + 1], dim=2)
    return pooling(features)

In [75]:
process_BERT_features(inputs, outputs, 1).shape

torch.Size([1, 1, 763])

In [76]:
process_BERT_features(inputs, outputs, 2).shape

torch.Size([1, 1, 763])

In [77]:
process_BERT_features(inputs, outputs, 5).shape

torch.Size([1, 1, 763])

In [78]:
process_BERT_features(inputs, outputs, 5, pooling_function=torch.nn.MaxPool2d).shape

torch.Size([1, 1, 763])

In [79]:
process_BERT_features(inputs, outputs, 5, 7).shape

torch.Size([1, 1, 1531])

In [80]:
process_BERT_features(inputs, outputs, -3).shape

torch.Size([1, 1, 763])

In [81]:
process_BERT_features(inputs, outputs, -3, -1).shape

torch.Size([1, 1, 1531])

In [82]:
process_BERT_features(inputs, outputs, -3, None).shape

torch.Size([1, 1, 2299])