In [3]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import numpy as np
import torch

In [4]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir='/home/sunkaran/scr/sunkaraneni/.cache')
model = RobertaForSequenceClassification.from_pretrained('roberta-base', cache_dir='/home/sunkaran/scr/sunkaraneni/.cache')

In [5]:
SPECIAL_TOKENS = ["<s>", "<eos>", "<therapist>", "<client>", "<utterance>"]
ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<s>', 'eos_token': '</s>', 'pad_token': '<pad>',
                         'additional_special_tokens': ["<therapist>", "<patient>", "<utterance>", "</therapist>", "</patient>", "</utterance>"]}

def add_special_tokens_(model, tokenizer):
  """ Add special tokens to the tokenizer and the model if they have not already been added. """
  orig_num_tokens = len(tokenizer.encoder)
  num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN) # doesn't add if they are already there
  if num_added_tokens > 0:
      model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens) # doesn't mess with existing tokens

add_special_tokens_(model, tokenizer)

In [7]:
symbol_dict = {
    'SPECIAL_START_TOKEN_IDS': set(tokenizer.convert_tokens_to_ids(ATTR_TO_SPECIAL_TOKEN['additional_special_tokens'])[:3]),
    'SPECIAL_END_TOKEN_IDS': set(tokenizer.convert_tokens_to_ids(ATTR_TO_SPECIAL_TOKEN['additional_special_tokens'])[3:]),
    'BOS_TOKEN_ID': tokenizer.bos_token_id,
    'EOS_TOKEN_ID': tokenizer.eos_token_id,
    'PAD_TOKEN_ID': tokenizer.pad_token_id
  }
    
def generate_dialogue_attention_mask(batch):
  mask = -10000 * torch.ones((batch.shape[0], batch.shape[1], batch.shape[1])) # 12 heads is fixed
  for i in np.arange(batch.shape[0]):
    example_special_idx = torch.nonzero(sum(batch[i] == t for t in (set.union(symbol_dict['SPECIAL_START_TOKEN_IDS'], symbol_dict['SPECIAL_END_TOKEN_IDS'])))).flatten().tolist()
    last_idx = None
    for idx, token_id in enumerate(batch[i].tolist()):
      if token_id == symbol_dict['PAD_TOKEN_ID']:
        break
      if token_id == symbol_dict['BOS_TOKEN_ID'] or token_id == symbol_dict['EOS_TOKEN_ID']:
        mask[i, idx, example_special_idx] = 0 # attend to other special tokens
        mask[i, example_special_idx, idx] = 0 # let other special tokens attend to this
        mask[i, idx, idx] = 0 # attend to self
        if token_id == symbol_dict['EOS_TOKEN_ID']:
          mask[i, idx, 0] = 0 # eos attends to bos
          mask[i, 0, idx] = 0 # bos attends to eos
      elif token_id in symbol_dict['SPECIAL_START_TOKEN_IDS']:
         mask[i, idx, example_special_idx] = 0 # attend to other special tokens
         last_idx = idx
      elif token_id in symbol_dict['SPECIAL_END_TOKEN_IDS']:
         mask[i, idx, example_special_idx] = 0
         span_range = np.arange(last_idx, idx+1) # starts from the last opening special token to including this special token
         x, y = np.meshgrid(span_range, span_range)
         x, y = x.flatten(), y.flatten()
         span_product = np.array(list(zip(x, y))) # 2-D array
         mask[i, span_product[:, 0], span_product[:, 1]] = 0
  # because we use multi-headed attention
  return mask

In [8]:
tokenizer.decode(tokenizer.encode('<s> <therapist> I really dislike you </therapist> <patient> <utterance> I know, you suck</utterance><patient></s>', add_special_tokens=False))

'<s> <therapist>  I really dislike you </therapist> <patient> <utterance>  I know, you suck </utterance> <patient> </s>'

In [9]:
examples = [
    '<s> <therapist> I really dislike you </therapist> <patient> <utterance> I know, you suck</utterance><patient></s>',
    '<s> No context<utterance> I know, you suck </utterance> <therapist> </s>'
]

In [13]:
inputs = tokenizer.batch_encode_plus(examples, add_special_tokens=False, pad_to_max_length=True, return_tensors='pt')

In [11]:
inputs.update({
    "attention_mask": generate_dialogue_attention_mask(inputs['input_ids'])
})

In [14]:
model(**inputs)

(tensor([[-0.1333,  0.0890],
         [-0.1192,  0.0830]], grad_fn=<AddmmBackward>),)

In [None]:
transformers.modeling_bert.BertSelfAttention