In [3]:
from tqdm import tqdm
import numpy as np
import torch
from transformers import BertTokenizer, BertModel

# Choose the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
gpu_available = torch.cuda.is_available()
device = torch.device('cuda' if gpu_available else 'cpu')

# Report which device was selected.
if not gpu_available:
    print('No GPU available, using the CPU instead.')
else:
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('Device:', torch.cuda.get_device_name(0))

No GPU available, using the CPU instead.


In [4]:
# Cased multilingual BERT vocabulary: do not lower-case the input text.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased', do_lower_case=False
)

# Ask the model to return every layer's hidden states (the last 4 layers are
# needed later), and place it on the selected device right away.
pretrained = BertModel.from_pretrained(
    'bert-base-multilingual-cased', output_hidden_states=True
).to(device)


Downloading:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/714M [00:00<?, ?B/s]

In [8]:
class Preprocess:

    '''
    Tokenize sentences and obtain hidden representations for annotated tokens.

    Each word of a pre-split sentence is tokenized on its own so that every
    sub-token can be traced back to the word it came from. The sub-tokens of
    annotated words are then represented by the sum of the last four hidden
    layers of the pretrained model, averaged over all annotated sub-tokens.
    '''
    def __init__(self, tokenizer, pretrained):

        '''
        Args:
        tokenizer: transformers' tokenizer object
        pretrained: transformers' model object (its output must expose the
                    hidden states of all layers at index 2)
        '''

        self.tokenizer = tokenizer
        self.pretrained = pretrained

    def __call__(self, sentence, ids):

        '''
        Args:
        sentence (List[str]): list of tokens
        ids (List[int and str]): List of instance ids for annotation (-1 if other)

        `sentence` and `ids` are paired with zip(), so surplus elements of the
        longer list are ignored. Words that the tokenizer encodes to nothing
        are skipped.

        Returns:
        torch.Tensor: embedding of the annotated word(s), see get_embeddings()

        Raises:
        RuntimeError: if no token ends up being marked as annotated
        '''

        tokenizer = self.tokenizer

        def encode(text):
            # tokenizer.encode returns a list of token ids
            return list(tokenizer.encode(text, add_special_tokens=False))

        # Collect plain Python ints and build a single tensor at the end
        # instead of concatenating one small tensor per word.
        token_ids = encode(tokenizer.cls_token)  # add [CLS] token
        # The mask is derived from the encoded length so that it always stays
        # aligned with token_ids.
        annotation_mask = [-1] * len(token_ids)

        for tok, idx in zip(sentence, ids):
            word_ids = encode(tok)
            # ignore empty / invalid tokens
            if not word_ids:
                continue

            token_ids.extend(word_ids)
            # annotated instance -> 1 for each of its sub-tokens, other -> -1
            flag = 1 if idx != -1 else -1
            annotation_mask.extend([flag] * len(word_ids))

        sep_ids = encode(tokenizer.sep_token)  # add [SEP] token
        token_ids.extend(sep_ids)
        annotation_mask.extend([-1] * len(sep_ids))

        # format: tensor([[a, b, c]]), i.e. a batch holding one sentence
        tokenized_sent = torch.tensor([token_ids])

        # run pretrained model to get hidden representations
        # - tokenized_sent = tensor of token ids
        # - annotation_mask = list of -1 or 1
        #   (tokens corresponding to the annotated word is 1 and -1 otherwise)
        return self.get_embeddings(tokenized_sent, annotation_mask)

    def get_embeddings(self, tokenized_sent, annotation_mask, freeze=True):

        '''
        Args:
        tokenized_sent (torch.Tensor): token ids of one sentence, expected in
                                       the format tensor([[a, b, c]])
        annotation_mask (List[int]): one entry per token id, 1 for tokens of
                                     the annotated word and -1 otherwise
        freeze (bool): if True, set requires_grad = False on every parameter
                       of the pretrained model (this change persists on the
                       model after the call)

        Returns:
        torch.Tensor: on the CPU; for every token whose mask is 1 the last
                      four hidden layers are summed, and the result is the
                      mean of those sums over the masked tokens

        Raises:
        RuntimeError: if annotation_mask contains no 1
        '''

        pretrained = self.pretrained

        # freeze parameters in BERT
        if freeze:
            for param in pretrained.parameters():
                param.requires_grad = False

        positions = [index for index, mask in enumerate(annotation_mask)
                     if mask == 1]
        if not positions:
            # Fail before the (expensive) forward pass and with a readable
            # message. RuntimeError is kept because that is what the former
            # torch.stack([]) call raised in this situation.
            raise RuntimeError('annotation_mask marks no token as annotated '
                               '(no entry equals 1); cannot build an embedding')

        # Send the input to wherever the model lives instead of relying on a
        # notebook-level `device` variable.
        first_param = next(iter(pretrained.parameters()), None)
        if first_param is not None:
            tokenized_sent = tokenized_sent.to(first_param.device)

        # run pretrained model to get hidden representations
        # output format = (last_hidden_state, pooler_output,
        #                  hidden_states[optional], attentions[optional])
        with torch.no_grad():
            embs = pretrained(tokenized_sent)

        # hidden states = input_emb + hidden_state_embs
        # --> hidden_state_embs = [1:]
        hidden_embs = embs[2][1:]

        # Sum the last four layers for the single sentence in the batch.
        # Indexing [0] drops only the batch axis, whereas a bare squeeze()
        # would also drop any other axis of size 1.
        summed = sum(hidden_embs[i][0] for i in (-1, -2, -3, -4))

        # average over the annotated tokens only
        annotated_emb = summed[positions].mean(dim=0)

        return annotated_emb.cpu()


In [11]:
# Bundle the tokenizer and the pretrained model into one callable.
preprocess = Preprocess(tokenizer, pretrained)

# Assume we have a sentence “I go to the bank to deposit some money .”,
# and the word “bank” is annotated (we want the embedding for “bank”).
sentence = 'I go to the bank to deposit some money .'.split(' ')
annotated_id = 4  # zero-based index: the 5-th word “bank” is annotated

# 1 marks the annotated word, -1 marks every other word
ids = [1 if position == annotated_id else -1
       for position in range(len(sentence))]

ctx_emb = preprocess(sentence, ids)

print(ctx_emb)
