# Token-level analysis

In [176]:
from BERTQG.token_generation import load_model, generate_token
import numpy as np

In [177]:
# model checkpoint paths
bert_model = 'BERTQG/models/bert-base-uncased'
lm_qg_path = 'BERTQG/models/lm_10k_QG/pytorch_model.bin'
uqa_qg_path = 'BERTQG/models/uqa_10k_QG/pytorch_model.bin'

# Load both QG checkpoints on top of the same BERT base. load_model appears to
# return (model, tokenizer, device) -- TODO confirm against BERTQG.token_generation.
# The first call's tokenizer/device are discarded; both calls use the same
# bert_model, so the second call's tokenizer/device are kept for the whole notebook.
lm_qg, _, _ = load_model(bert_model, lm_qg_path)
uqa_qg, tokenizer, device = load_model(bert_model, uqa_qg_path)

05/05/2020 13:48:36 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/05/2020 13:48:36 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/05/2020 13:48:36 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


05/05/2020 13:48:38 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/05/2020 13:48:38 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/05/2020 13:48:38 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


In [178]:
# util functions

def topk_retriever(probs, topk, tok=None):
    '''
    Return the `topk` highest-scoring vocabulary tokens, best first.

    Input:
        - probs: 1-D score vector over the vocab (output distribution); a torch
          tensor (CPU or GPU, with or without grad), a numpy array, or a list
        - topk: number of tokens to return; values <= 0 yield an empty list
        - tok: object exposing convert_ids_to_tokens(list_of_ids); defaults to
          the notebook-level `tokenizer`
    Output:
        - list of tokens (top 1 to k)
    '''
    if topk <= 0:
        # np.argsort(...)[-0:] would select the entire vocab, not nothing.
        return []
    if tok is None:
        tok = tokenizer
    # Bring torch tensors to host memory first: calling .numpy() directly fails
    # on CUDA tensors and on tensors that require grad.
    if hasattr(probs, 'detach'):
        probs = probs.detach().cpu().numpy()
    scores = np.asarray(probs)
    # Ascending argsort, reversed so the highest score comes first.
    top_ids = np.argsort(scores)[::-1][:topk]
    return tok.convert_ids_to_tokens(top_ids.tolist())

def ignore_probs(probs, idx_list):
    '''
    Suppress selected vocabulary entries by overwriting them with -inf.

    The update is done in place: the very object passed as `probs` is modified
    and then returned, so callers can use either the return value or their
    original reference.

    Input:
        - probs: indexable score vector over the vocab (output distribution)
        - idx_list: iterable of vocab ids to suppress
    Output:
        - the same `probs` object, with every listed id set to -inf
    '''
    masked_value = float('-inf')
    for vocab_id in idx_list:
        probs[vocab_id] = masked_value
    return probs

In [179]:
context = 'On November 4, 2010, Jakes released his third album, Burning Down The Underground.  Jakes recorded the album primarily at his home studio in San Francisco, with bandmate, Matthew Whitemyer, co-engineering.  The Music Cycle wrote, "Burning Down The Underground stays true to the deep bluesy rock sound that has come to garner Jakes such critical acclaim.  At the same time, however, Jakes continues to push the boundaries of sound and genre creating unique and interesting fusions of music".'
question_text = "what does jakes release"
ans = "Burning Down The Underground"
# str.find returns the FIRST occurrence (the title appears again later, inside
# the quote) and -1 when the answer is absent -- fail loudly in that case.
ans_start = context.find(ans)
assert ans_start != -1, 'answer {!r} not found in context'.format(ans)

In [180]:
# Run generate_token for both checkpoints on exactly the same inputs; each call
# is unpacked into (token, vocab id, per-vocab scores).
shared_args = (tokenizer, device, context, question_text, ans, ans_start)
lm_qi_token, lm_qi_idx, lm_qi_probs = generate_token(lm_qg, *shared_args)
uqa_qi_token, uqa_qi_idx, uqa_qi_probs = generate_token(uqa_qg, *shared_args)

In [181]:
# (token string, vocab id) returned by generate_token for the LM-based checkpoint
(lm_qi_token, lm_qi_idx)

('?', 1029)

In [182]:
# (token string, vocab id) returned by generate_token for the UQA-based checkpoint
(uqa_qi_token, uqa_qi_idx)

('your', 2115)

In [183]:
# Top-10 vocab ids and tokens of the LM-based checkpoint's scores, highest first.
# NOTE(review): `reversed(...).tolist()` appears to rely on np.argsort handing back
# a torch tensor here (reversed() on a numpy array yields an iterator with no .tolist()).
print('topk idx:', reversed(np.argsort(lm_qi_probs)[-10:]).tolist())
print('topk tokens:', topk_retriever(lm_qi_probs, 10))

topk idx: [1029, 2006, 2144, 1999, 2005, 2013, 2017, 1010, 2004, 1996]
topk tokens: ['?', 'on', 'since', 'in', 'for', 'from', 'you', ',', 'as', 'the']


In [184]:
# Top-10 vocab ids and tokens of the UQA-based checkpoint's scores, highest first.
# NOTE(review): `reversed(...).tolist()` appears to rely on np.argsort handing back
# a torch tensor here (reversed() on a numpy array yields an iterator with no .tolist()).
print('topk idx:', reversed(np.argsort(uqa_qi_probs)[-10:]).tolist())
print('topk tokens:', topk_retriever(uqa_qi_probs, 10))

topk idx: [2115, 2010, 2006, 5180, 2256, 1029, 2017, 1010, 1996, 2026]
topk tokens: ['your', 'his', 'on', 'jake', 'our', '?', 'you', ',', 'the', 'my']


In [185]:
# You may want the LM model not to generate '?' (or any specific token) at the
# current step. Mask a copy so lm_qi_probs keeps the raw scores displayed above
# and this cell stays safe to re-run.
# assumes lm_qi_probs is a torch tensor (use .copy() instead for numpy).
lm_qi_probs_masked = ignore_probs(lm_qi_probs.clone(), [lm_qi_idx])

# re-check
print('topk idx:', reversed(np.argsort(lm_qi_probs_masked)[-10:]).tolist())
print('topk tokens:', topk_retriever(lm_qi_probs_masked, 10))

topk idx: [2006, 2144, 1999, 2005, 2013, 2017, 1010, 2004, 1996, 102]
topk tokens: ['on', 'since', 'in', 'for', 'from', 'you', ',', 'as', 'the', '[SEP]']


In [186]:
context = 'On November 4, 2010, Jakes released his third album, Burning Down The Underground.  Jakes recorded the album primarily at his home studio in San Francisco, with bandmate, Matthew Whitemyer, co-engineering.  The Music Cycle wrote, "Burning Down The Underground stays true to the deep bluesy rock sound that has come to garner Jakes such critical acclaim.  At the same time, however, Jakes continues to push the boundaries of sound and genre creating unique and interesting fusions of music".'
question_text = "how many people ."
ans = "Burning Down The Underground"
# str.find returns the FIRST occurrence (the title appears again later, inside
# the quote) and -1 when the answer is absent -- fail loudly in that case.
ans_start = context.find(ans)
assert ans_start != -1, 'answer {!r} not found in context'.format(ans)

In [187]:
# Run generate_token for both checkpoints on exactly the same inputs; each call
# is unpacked into (token, vocab id, per-vocab scores).
shared_args = (tokenizer, device, context, question_text, ans, ans_start)
lm_qi_token, lm_qi_idx, lm_qi_probs = generate_token(lm_qg, *shared_args)
uqa_qi_token, uqa_qi_idx, uqa_qi_probs = generate_token(uqa_qg, *shared_args)

In [188]:
# Top-10 vocab ids and tokens of the UQA-based checkpoint's scores for the
# "how many people ." prefix, highest first.
# NOTE(review): `reversed(...).tolist()` appears to rely on np.argsort handing back
# a torch tensor here (reversed() on a numpy array yields an iterator with no .tolist()).
print('topk idx:', reversed(np.argsort(uqa_qi_probs)[-10:]).tolist())
print('topk tokens:', topk_retriever(uqa_qi_probs, 10))

topk idx: [2040, 2054, 2073, 2043, 2129, 2339, 3005, 3183, 1000, 2024]
topk tokens: ['who', 'what', 'where', 'when', 'how', 'why', 'whose', 'whom', '"', 'are']


In [189]:
# You may want to keep a second wh-word from being generated once the question
# already starts with one.
ignore_idx_list = tokenizer.convert_tokens_to_ids(['what', 'who', 'where', 'when', 'how', 'which', 'why', 'whose', 'whom'])
print('ignore_idx_list:', ignore_idx_list)
# Mask a copy of the UQA scores and keep it under a UQA name, so neither
# uqa_qi_probs nor the LM scores (lm_qi_probs) are touched.
# assumes uqa_qi_probs is a torch tensor (use .copy() instead for numpy).
uqa_qi_probs_masked = ignore_probs(uqa_qi_probs.clone(), ignore_idx_list)

# re-check
print('topk idx:', reversed(np.argsort(uqa_qi_probs_masked)[-10:]).tolist())
print('topk tokens:', topk_retriever(uqa_qi_probs_masked, 10))

ignore_idx_list: [2054, 2040, 2073, 2043, 2129, 2029, 2339, 3005, 3183]
topk idx: [1000, 2024, 2065, 1998, 2003, 2006, 2001, 2021, 2030, 2004]
topk tokens: ['"', 'are', 'if', 'and', 'is', 'on', 'was', 'but', 'or', 'as']
