In [2]:
# Install the packages listed in ../requirements.txt into the running kernel's
# environment (%pip targets the kernel; -q keeps pip's output quiet).
%pip install -r ../requirements.txt -q

Note: you may need to restart the kernel to use updated packages.


In [14]:
import os
import sys
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel, utils, FlaxBertForTokenClassification, FlaxBertForMultipleChoice
from bertviz import model_view, head_view
from transformers import pipeline
import jax
import matplotlib.pyplot as plt
import seaborn as sns


# Restrict the transformers library's logger to errors only.
utils.logging.set_verbosity_error()

# Left disabled: uncomment to set the TOKENIZERS_PARALLELISM environment variable to "false".
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [15]:
# Checkpoint identifiers passed to `from_pretrained` / `pipeline` later in the notebook.
# Generic BERT-family checkpoints.
MODEL_BERT_BASE = "bert-base-uncased"
MODEL_BERT_TINY = "prajjwal1/bert-tiny"
MODEL_DISTILL = "distilbert-base-uncased"
# Microsoft XtremeDistil checkpoints (12-layer and 6-layer variants, per their names).
MODEL_MSFT_L12_H384 = "microsoft/xtremedistil-l12-h384-uncased"
MODEL_MSFT_L6_H384 = "microsoft/xtremedistil-l6-h384-uncased"

In [16]:
# Load the table of words with one "specific" and one "abstract" example sentence each.
# Paths are relative to the notebook's working directory.
DATA_DIR = '../data'
SPECIFIC_ABSTRACT_CSV = '/'.join([DATA_DIR, 'merge', 'specific_abstract.csv'])
SPECIFIC_ABSTRACT_DATA = pd.read_csv(SPECIFIC_ABSTRACT_CSV)
# Bare last expression so the frame is rendered as the cell output.
SPECIFIC_ABSTRACT_DATA

Unnamed: 0,word,specific,abstract
0,Beautiful,Beautiful girl,Beautiful soul
1,World,The world is very old,He lives in his own world
2,School,The school is near the park.,School is a garden to nurture the mind
3,Oxygen,Oxygen is crucial to life,Music is my oxygen


In [96]:
class TF:
    """Convenience wrapper around one Hugging Face checkpoint.

    Bundles a tokenizer, a PyTorch encoder (configured to return attentions
    and hidden states), a Flax multiple-choice model and a fill-mask
    pipeline, and exposes small helpers for attention visualisation
    (BertViz), masked-token prediction, choice scoring and embeddings.

    All PyTorch forward passes run under ``torch.no_grad()``: the wrapper is
    inference-only, so no autograd graph is built and returned tensors do
    not require grad.
    """

    def __init__(self, model_name):
        """Load the tokenizer and every model variant for ``model_name``.

        Args:
            model_name: Checkpoint identifier forwarded unchanged to each
                ``from_pretrained`` call and to the fill-mask ``pipeline``.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, resume_download=True)
        self.model = AutoModel.from_pretrained(model_name, resume_download=True, output_attentions=True, output_hidden_states=True)
        # Inference only: switch to eval mode and keep the encoder on the CPU.
        self.model.eval()
        self.model.to('cpu')

        # NOTE(review): this always loads a *Bert* Flax class, whatever the
        # checkpoint's architecture is. It also presumably gets a freshly
        # initialised multiple-choice head for base checkpoints (the scores
        # produced by `choose` are close to uniform) - confirm before
        # interpreting `choose` results.
        self.mc_model = FlaxBertForMultipleChoice.from_pretrained(model_name, resume_download=True, output_attentions=True, output_hidden_states=True)

        # The pipeline is given the model *name*, so it loads its own copy of the weights.
        self.unmasker = pipeline('fill-mask', model=self.model_name, tokenizer=self.tokenizer)

    def __call__(self, text):
        """Tokenize ``text`` and run it through the PyTorch encoder.

        Args:
            text: Input forwarded to the tokenizer.

        Returns:
            The raw model output object returned by ``self.model``.
        """
        inputs = self.tokenizer(text, return_tensors="pt")
        # No gradients are needed for inspection; avoid building an autograd graph.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs

    def get_attention_and_hidden_states(self, text):
        """Return ``(attentions, hidden_states)`` from the encoder output for ``text``."""
        outputs = self(text)
        return outputs.attentions, outputs.hidden_states

    def info(self):
        """Summarise the encoder as a one-row DataFrame indexed by model name.

        Returns:
            DataFrame with columns ``Model Name``, ``Layers``, ``Heads`` and
            ``Params``, read from the model config and ``num_parameters()``.
        """
        return pd.DataFrame({
            'Model Name': self.model_name,
            'Layers': self.model.config.num_hidden_layers,
            'Heads': self.model.config.num_attention_heads,
            'Params': self.model.num_parameters()
        }, index=[self.model_name])

    def _process_sentences(self, sentence_a, sentence_b):
        """Encode one sentence (or a pair) and collect what BertViz needs.

        Args:
            sentence_a: First sentence.
            sentence_b: Optional second sentence; a falsy value means
                single-sentence mode.

        Returns:
            Tuple ``(attention, tokens, sentence_b_start)`` where
            ``attention`` is the encoder output's ``attentions`` field,
            ``tokens`` is the list of token strings for the encoded input,
            and ``sentence_b_start`` is the index of the first position whose
            token type id is 1 (``None`` in single-sentence mode).

        Raises:
            KeyError: If ``sentence_b`` is given but the tokenizer does not
                return ``token_type_ids``.
        """
        inputs = self.tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
        input_ids = inputs['input_ids']
        with torch.no_grad():
            if sentence_b:
                token_type_ids = inputs['token_type_ids']
                outputs = self.model(input_ids, token_type_ids=token_type_ids)
                # First position marked as belonging to the second segment.
                sentence_b_start = token_type_ids[0].tolist().index(1)
            else:
                outputs = self.model(input_ids)
                sentence_b_start = None
        # Read the field by name rather than by tuple position.
        attention = outputs.attentions
        input_id_list = input_ids[0].tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(input_id_list)
        return attention, tokens, sentence_b_start

    def head_view(self, sentence_a, sentence_b=None):
        """Render BertViz's head view for the given sentence or sentence pair."""
        attention, tokens, sentence_b_start = self._process_sentences(sentence_a, sentence_b)
        # The bare name resolves to the module-level bertviz function, not this method.
        head_view(attention, tokens, sentence_b_start)

    def model_view(self, sentence_a, sentence_b=None):
        """Render BertViz's model view for the given sentence or sentence pair."""
        attention, tokens, sentence_b_start = self._process_sentences(sentence_a, sentence_b)
        # The bare name resolves to the module-level bertviz function, not this method.
        model_view(attention, tokens, sentence_b_start)

    def unmask(self, text):
        """Run the fill-mask pipeline on ``text`` and return its predictions."""
        return self.unmasker(text)

    def choose(self, prompt, choices):
        """Score each candidate continuation of ``prompt`` with the multiple-choice model.

        Args:
            prompt: Context sentence, paired with every choice.
            choices: Non-empty sequence of candidate sentences. Duplicates
                are kept as separate rows.

        Returns:
            DataFrame with columns ``choice`` and ``probability`` (softmax
            over the model's logits), sorted by descending probability. The
            index keeps each choice's original position.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        choices = list(choices)
        if not choices:
            raise ValueError("choices must contain at least one candidate")
        prompts = [prompt] * len(choices)
        encoding = self.tokenizer(prompts, choices, return_tensors='jax', padding=True, truncation=True)
        # Add a leading batch axis: every encoded array becomes (1, num_choices, ...).
        outputs = self.mc_model(**{k: v[None, :] for k, v in encoding.items()})
        probs = jax.nn.softmax(outputs.logits, axis=-1)
        probs_list = probs.tolist()[0]
        # Build from parallel columns so repeated choices are not collapsed
        # the way a choice -> probability dict would collapse them.
        df = pd.DataFrame({'choice': choices, 'probability': probs_list})
        return df.sort_values('probability', ascending=False)

    def embedding(self, text):
        """Return the encoder's ``pooler_output`` for ``text`` (computed without gradients)."""
        return self(text).pooler_output

    def are_related(self, text1, text2):
        """Placeholder for a relatedness check between two texts; not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

In [97]:
# Build the wrapper for the bert-base-uncased checkpoint and show its size summary.
bert_base = TF(model_name=MODEL_BERT_BASE)
bert_base.info()

Unnamed: 0,Model Name,Layers,Heads,Params
bert-base-uncased,bert-base-uncased,12,12,109482240


In [117]:
# Row 1 is the "World" pair; visualise per-head attention between its two sentences.
world = SPECIFIC_ABSTRACT_DATA.iloc[1]
bert_base.head_view(sentence_a=world['specific'], sentence_b=world['abstract'])

<IPython.core.display.Javascript object>

In [113]:
# Build the wrapper for the 6-layer XtremeDistil checkpoint and show its size summary.
msft_l6 = TF(MODEL_MSFT_L6_H384)
msft_l6.info()

Unnamed: 0,Model Name,Layers,Heads,Params
microsoft/xtremedistil-l6-h384-uncased,microsoft/xtremedistil-l6-h384-uncased,6,12,22713216


In [114]:
# Row 0 is the "Beautiful" pair; it is reused by the fill-mask cell further down.
beautiful = SPECIFIC_ABSTRACT_DATA.iloc[0]
msft_l6.head_view(sentence_a=beautiful['specific'], sentence_b=beautiful['abstract'])

<IPython.core.display.Javascript object>

In [115]:
# Row 2 is the "School" pair: a literal sentence versus a metaphorical one.
school = SPECIFIC_ABSTRACT_DATA.iloc[2]
msft_l6.head_view(sentence_a=school['specific'], sentence_b=school['abstract'])

<IPython.core.display.Javascript object>

In [118]:
# Whole-model attention overview for the "World" pair (`world` comes from an earlier cell).
msft_l6.model_view(sentence_a=world['specific'], sentence_b=world['abstract'])

<IPython.core.display.Javascript object>

In [85]:
# Ask the fill-mask pipeline for the [MASK] token after a context that uses
# both the specific and the abstract "Beautiful" phrases.
bert_base.unmask(
    text=f'She is {beautiful.specific} & has a {beautiful.abstract}. She is not [MASK] person.'
)

[{'score': 0.45051640272140503,
  'token': 1037,
  'token_str': 'a',
  'sequence': 'she is beautiful girl & has a beautiful soul. she is not a person.'},
 {'score': 0.03543044999241829,
  'token': 3819,
  'token_str': 'perfect',
  'sequence': 'she is beautiful girl & has a beautiful soul. she is not perfect person.'},
 {'score': 0.03371156379580498,
  'token': 3376,
  'token_str': 'beautiful',
  'sequence': 'she is beautiful girl & has a beautiful soul. she is not beautiful person.'},
 {'score': 0.030174285173416138,
  'token': 2008,
  'token_str': 'that',
  'sequence': 'she is beautiful girl & has a beautiful soul. she is not that person.'},
 {'score': 0.018938617780804634,
  'token': 2919,
  'token_str': 'bad',
  'sequence': 'she is beautiful girl & has a beautiful soul. she is not bad person.'}]

In [107]:
# Candidate follow-up statements; the same list is reused by the next two cells.
choices = [
    "she is not human",
    "she is human",
    "she is bad person.",
    "she is good person.",
    "she looks good.",
    "she is woman",
]

# Prompt that uses "beautiful" only in its abstract sense.
abstract_prompt = "She has a beautiful soul."

bert_base.choose(prompt=abstract_prompt, choices=choices)

Unnamed: 0,choice,probability
0,she is not human,0.185077
1,she is human,0.181877
3,she is good person.,0.16429
2,she is bad person.,0.16064
5,she is woman,0.154544
4,she looks good.,0.153572


In [108]:
# Same candidates, scored against a prompt that uses "beautiful" literally.
specific_prompt = "she is beautiful."
bert_base.choose(prompt=specific_prompt, choices=choices)

Unnamed: 0,choice,probability
0,she is not human,0.176771
1,she is human,0.17617
2,she is bad person.,0.164809
3,she is good person.,0.16411
4,she looks good.,0.163744
5,she is woman,0.154396


In [109]:
# Same candidates, scored against a prompt combining the literal and the abstract sense.
specific_abstract_prompt = "She has a beautiful face and a beautiful soul."
bert_base.choose(prompt=specific_abstract_prompt, choices=choices)

Unnamed: 0,choice,probability
0,she is not human,0.184592
1,she is human,0.179992
3,she is good person.,0.163912
2,she is bad person.,0.159642
5,she is woman,0.158469
4,she looks good.,0.153392
