In [1]:
import torch, transformers
import re

class MultiTokenMaskLM:
    """Score and propose replacements for multi-token text spans with a masked LM.

    Every subword token of the span is replaced by the tokenizer's mask token.
    The masks are then processed one at a time, from the last masked position
    to the first (the last token is often the family name), each step
    conditioning on the tokens already filled in.
    """

    def __init__(self, model_name="roberta-large"):
        """Load the tokenizer and masked-LM model for ``model_name``."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.model = transformers.AutoModelForMaskedLM.from_pretrained(model_name)
        # Inference only: make sure dropout layers are disabled.
        self.model.eval()

    def get_log_probability(self, text, text_to_mask):
        """Return the log probability of a span included in a larger text.

        Args:
            text: the full input text.
            text_to_mask: the span to score; every occurrence of it in
                ``text`` is masked. ``None`` or ``""`` masks nothing.

        Returns:
            The summed log probability of the original tokens at the masked
            positions. Positions are scored right to left, and each true
            token is revealed before the next one is scored. Returns ``0.0``
            when nothing is masked.

        Raises:
            RuntimeError: if ``text_to_mask`` does not occur in ``text``.
        """
        orig_tokens, masked_tokens = self._tokenize_and_mask(text, text_to_mask)
        mask_ids = self._mask_positions(masked_tokens)

        total_log_prob = 0.0
        for mask_id in reversed(mask_ids):
            log_probs = self._log_probs_at(masked_tokens, mask_id)
            actual_id = orig_tokens["input_ids"][0, mask_id]
            total_log_prob += log_probs[actual_id].item()
            # Reveal the true token so that later (leftward) steps condition on it
            masked_tokens["input_ids"][0, mask_id] = actual_id
        return total_log_prob

    def get_alternatives(self, text, text_to_mask, beam_size=20, verbose=True):
        """Return likely replacements for a span included in a larger text.

        Beam search over the masked positions, processed right to left. At
        each step, every hypothesis is expanded with its ``beam_size`` most
        likely tokens, and only the ``beam_size`` best cumulative hypotheses
        are kept.

        Args:
            text: the full input text.
            text_to_mask: the span to replace (every occurrence is masked).
            beam_size: number of expansions per hypothesis and the beam width.
            verbose: if True, print the best partial replacement after each step.

        Returns:
            A dict mapping each decoded replacement string to its cumulative log
            probability, best first. When several token sequences decode to
            the same string, the highest log probability is kept.

        Raises:
            ValueError: if ``beam_size`` is smaller than 1.
            RuntimeError: if ``text_to_mask`` does not occur in ``text``.
        """
        if beam_size < 1:
            raise ValueError("beam_size must be at least 1")

        _, tokens = self._tokenize_and_mask(text, text_to_mask)
        mask_ids = self._mask_positions(tokens)

        beam = [(tokens, 0)]
        for mask_id in reversed(mask_ids):
            candidates = [
                (filled, logprob + step_logprob)
                for current, logprob in beam
                for filled, step_logprob in self._fill(current, mask_id, beam_size)
            ]
            # Ascending order by score: the best hypothesis is the last element
            beam = sorted(candidates, key=lambda x: x[1])[-beam_size:]
            if verbose:
                best = beam[-1][0]["input_ids"][0, mask_ids]
                print("best replacement so far:", self.tokenizer.decode(best))

        results = {}
        for solution, logprob in reversed(beam):
            decoded = self.tokenizer.decode(solution["input_ids"][0, mask_ids])
            # Iterating best-first, so the first occurrence carries the best score
            results.setdefault(decoded, logprob)
        return results

    def _tokenize_and_mask(self, text, text_to_mask):
        """Return the tokenized text and a copy with the span's tokens masked.

        Only tokens whose character offsets lie entirely inside an occurrence
        of ``text_to_mask`` are masked. A token that straddles the span
        boundary is left unmasked. ``None`` or ``""`` masks nothing.

        Raises:
            RuntimeError: if ``text_to_mask`` does not occur in ``text``.
        """
        if text_to_mask is not None and text_to_mask not in text:
            raise RuntimeError("Text to mask must be included in full text")

        # Offset mapping gives the character span of every token
        orig_tokens = self.tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
        offsets = orig_tokens["offset_mapping"][0].tolist()
        del orig_tokens["offset_mapping"]

        masked_tokens = {key: value.clone().detach() for key, value in orig_tokens.items()}
        if not text_to_mask:
            return orig_tokens, masked_tokens

        for match in re.finditer(re.escape(text_to_mask), text):
            for i, (tok_start, tok_end) in enumerate(offsets):
                # Skip zero-width offsets (presumably special tokens such as <s>)
                if match.start() <= tok_start < tok_end <= match.end():
                    masked_tokens["input_ids"][0, i] = self.tokenizer.mask_token_id

        return orig_tokens, masked_tokens

    def _mask_positions(self, tokens):
        """Return the ascending list of sequence positions holding the mask token."""
        is_mask = tokens["input_ids"][0] == self.tokenizer.mask_token_id
        return torch.nonzero(is_mask).flatten().tolist()

    def _log_probs_at(self, tokens, position):
        """Run the model and return log-softmax scores over the vocabulary at ``position``."""
        with torch.no_grad():
            logits = self.model(**tokens).logits
            return torch.nn.functional.log_softmax(logits[0, position], dim=0)

    def _fill(self, tokens, mask_id, beam_size=100):
        """Yield copies of ``tokens`` with position ``mask_id`` filled in.

        The ``beam_size`` most likely tokens are yielded in ascending order of
        log probability, each as a ``(new_tokens, logprob)`` pair.
        """
        if beam_size < 1:
            raise ValueError("beam_size must be at least 1")
        # The forward pass runs in a regular function. This way grad mode is
        # not left disabled in the caller while this generator is suspended
        # at a ``yield``.
        log_probs = self._log_probs_at(tokens, mask_id)
        for replacement_id in torch.argsort(log_probs)[-beam_size:]:
            new_tokens = {key: value.clone().detach() for key, value in tokens.items()}
            new_tokens["input_ids"][0, mask_id] = replacement_id
            yield new_tokens, log_probs[replacement_id].item()

        
# Demo: score the name "Charles Darwin" in its biography sentence,
# then beam-search for alternative names that fit the same context.
filler = MultiTokenMaskLM(model_name="roberta-large")

text = (
    "Charles Darwin (12 February 1809 – 19 April 1882) was an English naturalist, "
    "geologist, and biologist, widely known for contributing to the understanding "
    "of evolutionary biology."
)
span = "Charles Darwin"

print("Log probability:", filler.get_log_probability(text, span))
print("Alternative results:", filler.get_alternatives(text, span))


  from .autonotebook import tqdm as notebook_tqdm


Log probability: -2.7288677991600707


2022-11-25 20:15:45.350769: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-25 20:15:46.931706: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


best replacement so far: <mask> Darwin
best replacement so far: Charles Darwin
Alternative results: {'Charles Darwin': -2.7288677991600707, 'William Thomson': -4.342340171337128, 'George Burgess': -5.420966029167175, 'John Smith': -5.676791429519653, 'Edward Smith': -5.809417009353638, 'John Taylor': -5.944948792457581, 'Thomas Hardy': -6.007612705230713, 'John Bates': -6.015552520751953, 'William Hudson': -6.075858116149902, 'Edward Robinson': -6.120614171028137, 'William Smith': -6.148982286453247, 'William Bates': -6.241352081298828, 'James Burgess': -6.290881037712097, 'George Smith': -6.3188393115997314, 'James Thomson': -6.3247315883636475, 'Joseph Smith': -6.367221117019653, 'Edward Edwards': -6.403763771057129, 'Edward Fisher': -6.4369553327560425, 'John Brown': -6.451346755027771, 'Henry Hudson': -6.464964866638184}
