In [71]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: if you want to have more information on what's happening under the hood, activate the logger as follows
#import logging
#logging.basicConfig(level=logging.INFO)

# See here: https://huggingface.co/transformers/quickstart.html#quick-tour-usage

In [73]:
# Single source of truth for the pretrained checkpoint used by both objects.
bert_checkpoint = 'bert-base-uncased'

# Tokenizer that goes with the checkpoint above.
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)

# output_hidden_states=True makes the forward pass also return the
# per-layer hidden states (read later as outputs[2]).
model = BertModel.from_pretrained(bert_checkpoint, output_hidden_states=True)

# Put the model in evaluation mode; the result is bound to a name only so
# the notebook cell does not echo the module's repr.
eval_res = model.eval()



INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/ubuntu/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/ubuntu/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.8f56353af4a709bf5ff0fbc915d8f5b42bfff892cbb6ac98c3c45f481a03c685
INFO:transformers.configuration_utils:Model config BertConfig {
  "_num_labels": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "

In [74]:
def get_sentence_embedding(sent, layer=-1):
    """Embed a sentence as the mean of BERT hidden states taken at one layer.

    The sentence is wrapped in ``[CLS]`` / ``[SEP]``, tokenized with the
    module-level ``tokenizer`` and passed through the module-level ``model``
    with gradient tracking disabled. The hidden states of the requested
    layer are then averaged over dim 1 (the token positions, special tokens
    included).

    Args:
        sent: Sentence text.
        layer: Index into the tuple of hidden states returned by the model
            (``outputs[2]``). The default ``-1`` selects the last entry.

    Returns:
        A tensor holding the averaged hidden state, located on the device
        the model ran on. For ``bert-base-uncased`` this notebook prints a
        shape of ``torch.Size([1, 768])``.
    """
    # Use the GPU when there is one, but fall back to the CPU instead of
    # crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move the model only when it is not already on the target device type,
    # so repeated calls (e.g. one per sentence in a batch) do not re-issue
    # the transfer every time.
    if next(model.parameters()).device.type != device.type:
        model.to(device)

    tokenized_text = tokenizer.tokenize("[CLS] {0} [SEP]".format(sent))
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    # Batch of one sentence; every token belongs to segment 0.
    tokens_tensor = torch.tensor([indexed_tokens], device=device)
    segments_tensors = torch.zeros_like(tokens_tensor)

    with torch.no_grad():
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)

    # outputs[2] is the tuple of hidden states made available by loading the
    # model with output_hidden_states=True.
    res = outputs[2][layer].mean(dim=1)
    return res
    
    

In [75]:
# Two example sentences to compare.
A = "This is a delicious restaurant"
B = "I like the food"

# Hidden-state index to read the embeddings from.
layer = 12

# Embed both sentences at the same layer.
vec_A, vec_B = (get_sentence_embedding(text, layer) for text in (A, B))

print(vec_A.shape)

# Cosine similarity between the two sentence vectors; as the last
# expression of the cell, the result is displayed by the notebook.
cos = torch.nn.CosineSimilarity(eps=1e-6)
cos(vec_A, vec_B)


torch.Size([1, 768])


tensor([0.7543], device='cuda:0')

In [90]:
PATH_TO_SENTEVAL = '../../SentEval/'
PATH_TO_DATA = '../../SentEval/data'
import sys
import numpy as np
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval

# Create dictionary
def create_dictionary(sentences, threshold=0):
    """Build a frequency-ranked vocabulary from tokenized sentences.

    Args:
        sentences: Iterable of sentences, each an iterable of word tokens.
        threshold: When greater than 0, words seen fewer than ``threshold``
            times are dropped before ranking.

    Returns:
        A pair ``(id2word, word2id)``: ``id2word`` lists the words from most
        to least frequent, and ``word2id`` maps each word to its position in
        that list. The special tokens ``<s>``, ``</s>`` and ``<p>`` are given
        very large counts so they occupy ids 0, 1 and 2.
    """
    counts = {}
    for sentence in sentences:
        for token in sentence:
            if token in counts:
                counts[token] += 1
            else:
                counts[token] = 1

    # Optional pruning of rare words (first-seen order is kept).
    if threshold > 0:
        counts = {token: n for token, n in counts.items() if n >= threshold}

    # Huge pseudo-counts pin the special tokens to the front of the ranking.
    counts['<s>'] = 1e9 + 4
    counts['</s>'] = 1e9 + 3
    counts['<p>'] = 1e9 + 2

    # Descending by count; the sort is stable, so ties keep first-seen order.
    id2word = sorted(counts, key=lambda token: -counts[token])
    word2id = {token: index for index, token in enumerate(id2word)}
    return id2word, word2id


def prepare(params, samples):
    """Preparation hook handed to SentEval: store a vocabulary on ``params``.

    Builds the dictionary for ``samples`` with ``create_dictionary`` and
    keeps only the word-to-id mapping, as ``params.word2id``. Returns None.
    """
    vocabulary = create_dictionary(samples)
    params.word2id = vocabulary[1]


def batcher(params, batch):
    """Batch hook handed to SentEval: embed each sentence of ``batch``.

    Each sentence arrives as a sequence of tokens; it is joined with single
    spaces and embedded by ``get_sentence_embedding`` at ``params.layer``.
    The per-sentence results are moved to the CPU and stacked vertically
    into one NumPy array, which is returned.
    """
    sentence_vectors = [
        get_sentence_embedding(" ".join(tokens), params.layer).cpu()
        for tokens in batch
    ]
    return np.vstack(sentence_vectors)


# The task list does not change between iterations, so build it once.
transfer_tasks = ['STS16']

# One hidden-state entry per transformer layer plus one extra entry
# (13 in total for bert-base-uncased). Deriving the count from the model
# config keeps the sweep correct if a different checkpoint is loaded.
num_hidden_states = model.config.num_hidden_layers + 1

# Evaluate the mean-pooled embeddings of every hidden-state index on STS16
# and print the Spearman correlation summary for each.
for layer in range(num_hidden_states):
    params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, 'layer' : layer}
    se = senteval.engine.SE(params, batcher, prepare)
    results = se.eval(transfer_tasks)
    print(layer, results['STS16']['all']['spearman'])

0 {'mean': 0.6193841212648833, 'wmean': 0.6236655610169665}
1 {'mean': 0.6737804119931133, 'wmean': 0.6766365900899679}
2 {'mean': 0.6602513719177379, 'wmean': 0.6633485511078577}
3 {'mean': 0.6438565600590082, 'wmean': 0.648103799477346}
4 {'mean': 0.6171613820033857, 'wmean': 0.6233625810943636}
5 {'mean': 0.6238467150554166, 'wmean': 0.6302992099157712}
6 {'mean': 0.6236550817510895, 'wmean': 0.6319174395739774}
7 {'mean': 0.6320902482095737, 'wmean': 0.6398140597593204}
8 {'mean': 0.620164495640599, 'wmean': 0.6289194011333401}
9 {'mean': 0.6048972089589701, 'wmean': 0.6134637966199321}
10 {'mean': 0.6406573294764532, 'wmean': 0.6456470950944564}
11 {'mean': 0.6385209046542732, 'wmean': 0.6422447982993017}
12 {'mean': 0.6493581642407218, 'wmean': 0.6520864230612654}
