# Evaluation Scripts Development

In [1]:
import spacy
import codecs
import torch 

import numpy as np 

from sklearn.metrics import v_measure_score
from tokenizations import get_alignments
from transformers import BertTokenizer, BertConfig, BertModel
from tqdm import tqdm
from collections import Counter, OrderedDict

import sys 
sys.path.append('..')
from frtorch import LinearChainCRF
from data_utils import News20Data

In [159]:
import pickle

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
def read_data(data_path='../../data/news/'):
    """Read the 20news corpus and return the lines of its dev split.

    Loads ``20news.txt`` (UTF-8) and ``dev_idx.npy`` from ``data_path`` and
    returns the lines selected by the indices, in the order the indices give.

    Args:
        data_path: directory holding ``20news.txt`` and ``dev_idx.npy``.
            A trailing path separator is optional.

    Returns:
        list of str: the selected lines with their trailing newline removed.
    """
    import os

    # Original author's note: use the cased data for NER, otherwise spaCy does
    # not work with uncased text.
    # NOTE(review): tokens displayed later in this notebook are lowercase --
    # confirm which variant of 20news.txt is actually read here.
    # codecs.open is kept as the reader because the indices in dev_idx.npy
    # presumably refer to the line splitting it produces -- TODO confirm.
    with codecs.open(os.path.join(data_path, '20news.txt'), encoding='utf-8') as fd:
        data = fd.readlines()
    idx = np.load(os.path.join(data_path, 'dev_idx.npy'))
    # Drop only a real trailing newline: a final line without one must keep
    # its last character.
    return [data[i][:-1] if data[i].endswith('\n') else data[i] for i in idx]

# Load the dev split once (default data path); later cells iterate over and index into this list.
dev_data = read_data()

## Get tags from spaCy

In [5]:
# Load the small English pipeline without the dependency parser (only tokens,
# POS tags and entity types are read below). `disable` is passed as a list of
# component names; a bare string is treated as a sequence of characters by
# some spaCy versions and only works through substring matching.
nlp = spacy.load("en_core_web_sm", disable=['parser'])

In [7]:
# One inner list per dev sentence: token texts, coarse POS ids, entity-type ids.
pos_tags, ent_tags, spacy_tokenized = [], [], []

for s in tqdm(dev_data):
    doc = nlp(s)
    # Three passes over the parsed doc, one per attribute we keep.
    spacy_tokenized.append([tok.text for tok in doc])
    pos_tags.append([tok.pos for tok in doc])
    ent_tags.append([tok.ent_type for tok in doc])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26141/26141 [02:52<00:00, 151.18it/s]


## Load model


In [6]:
# Shared settings for this cell.
# NOTE(review): the checkpoint path is absolute and user-specific, and the
# device is fixed to CUDA -- adjust both when running elsewhere.
bert_name = 'bert-base-uncased'
cuda = 'cuda'
ckpt_path = '/home/s1946695/Scale-CRF-Latent-Space/models/bertnet_0.0.6.1/ckpt-e16.pt'

# Trained latent-state matrix, read from the checkpoint and moved to the GPU.
ckpt = torch.load(ckpt_path)
state_matrix = ckpt['state_matrix'].to(cuda)

# Pretrained BERT pieces: config, tokenizer and encoder (encoder on the GPU).
bert_config = BertConfig.from_pretrained(bert_name)
tokenizer = BertTokenizer.from_pretrained(bert_name)
bert = BertModel.from_pretrained(bert_name).to(cuda)

# CRF helper from the project-local frtorch package, built with its defaults.
crf = LinearChainCRF()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Some Tests

### Use model to infer latent tags

In [42]:
# Pick a single dev sentence to exercise the model by hand.
s = dev_data[100]
# Encode it with the BERT tokenizer, returning PyTorch tensors ('pt').
inputs = tokenizer(s, return_tensors='pt')

In [44]:
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

['[CLS]',
 '8',
 ',',
 'st',
 'louis',
 ',',
 'shan',
 '##aha',
 '##n',
 '51',
 '(',
 'emerson',
 ')',
 '19',
 ':',
 '38',
 '.',
 '[SEP]']

In [45]:
inputs

{'input_ids': tensor([[  101,  1022,  1010,  2358,  3434,  1010, 17137, 23278,  2078,  4868,
          1006, 12628,  1007,  2539,  1024,  4229,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [46]:
# Encode the example sentence with BERT and keep the first element of the output.
# Only inference is needed, so run under no_grad (as the batched decoding loop
# further down does) to avoid building and retaining an autograd graph.
with torch.no_grad():
    x_emb = bert(inputs['input_ids'].to('cuda'), attention_mask=inputs['attention_mask'].to('cuda'))[0]

In [52]:
# Decode latent tags for the single example sentence; no gradients are needed.
with torch.no_grad():
    # State-by-state scores; computed here but not passed to either CRF call below.
    transition = torch.matmul(state_matrix, state_matrix.transpose(1, 0))
    # Score of every latent state at every position: embeddings times state vectors.
    emission = torch.matmul(x_emb, state_matrix.transpose(1, 0))
    # Number of attended positions per sequence (sum of the attention mask over the last axis).
    lens = inputs['attention_mask'].to('cuda').sum(-1)
    # presumably the argmax path under the CRF's proposal with 50 summed states -- TODO confirm in frtorch
    tags = crf.proposal_argmax(state_matrix, emission, lens, sum_size=50)
    # Only the 4th of the 7 returned values is kept; presumably a sampled tag sequence -- TODO confirm in frtorch
    _, _, _, z_sample, _, _, _ = crf.rsample_approx(state_matrix, emission, lens, sum_size=50, proposal='softmax')

In [53]:
tags

tensor([1565, 1575, 1246, 1246,  638, 1246, 1745, 1745, 1745,  896, 1838, 1565,
        1924,  443, 1044, 1977,  228,  151], device='cuda:0')

In [35]:
emission.argmax(-1)

tensor([[1565, 1575, 1246,  380,  638, 1246,  648,  346,  752,  896, 1838,  638,
         1924,  443, 1652, 1977, 1879,  184]], device='cuda:0')

In [None]:
z_sample

tensor([[1565,  443,  563,  638,  638,  563, 1683,  881, 1883,  896, 1838, 1565,
         1924,  443,  752, 1977,  228,  151]], device='cuda:0')

### Align Latent BERT Tokenization with Spacy Tokenization

In [13]:
# Tokenize one example with both spaCy and BERT, then align the two tokenizations.
s = dev_data[100]
doc = nlp(s)
spacy_tokenized = [token.text for token in doc]
# The BERT tokenization and the alignment depend on the whole sentence, so they
# are computed once after the spaCy tokens are collected (not once per token).
bert_tokenized = tokenizer.tokenize(s)
# bert2spacy[i] lists the spaCy token indices aligned to BERT piece i;
# spacy2bert is the reverse mapping.
bert2spacy, spacy2bert = get_alignments(bert_tokenized, spacy_tokenized)

In [17]:
bert_tokenized

['8',
 ',',
 'st',
 'louis',
 ',',
 'shan',
 '##aha',
 '##n',
 '51',
 '(',
 'emerson',
 ')',
 '19',
 ':',
 '38',
 '.']

In [18]:
spacy_tokenized

['8',
 ',',
 'st',
 'louis',
 ',',
 'shanahan',
 '51',
 '(',
 'emerson',
 ')',
 '19',
 ':',
 '38',
 '.']

In [16]:
bert2spacy

[[0],
 [1],
 [2],
 [3],
 [4],
 [5],
 [5],
 [5],
 [6],
 [7],
 [8],
 [9],
 [10],
 [11],
 [12],
 [13]]

## Get Spacy Tags and Alignments

In [7]:
# Per-sentence results; one entry for every dev sentence that is kept.
bert_tokenized_all, spacy_tokenized_all = [], []
ent_tags_all, pos_tags_all, pos_fine_tags_all = [], [], []
bert_to_spacy_all, spacy_to_bert_all = [], []

# Label string -> surface forms seen with that label (lists now, Counters at the end).
pos_word_dict, pos_fine_word_dict, ent_word_dict = {}, {}, {}

# spaCy integer id -> label string, for coarse POS, fine POS and entity type.
id_to_pos, id_to_fine_pos, id_to_ent = {}, {}, {}

for s in tqdm(dev_data):
    encoded = tokenizer(s)
    # Two ids or fewer means the sentence is effectively empty: skip it.
    if len(encoded['input_ids']) <= 2:
        continue

    doc = nlp(s)

    # Collect POS, NER and token text from spaCy.
    # TODO: need to decide if pos_word_dict need to convert tokenization to BERT -- currently do not convert
    spacy_tokenized, pos_tags, pos_fine_tags, ent_tags = [], [], [], []
    for tok in doc:
        pos_word_dict.setdefault(tok.pos_, []).append(tok.text)
        pos_fine_word_dict.setdefault(tok.tag_, []).append(tok.text)

        spacy_tokenized.append(tok.text)
        pos_tags.append(tok.pos)
        pos_fine_tags.append(tok.tag)
        ent_tags.append(tok.ent_type)

        id_to_pos[tok.pos] = tok.pos_
        id_to_fine_pos[tok.tag] = tok.tag_
        id_to_ent[tok.ent_type] = tok.ent_type_

    ent_tags_all.append(ent_tags)
    pos_tags_all.append(pos_tags)
    pos_fine_tags_all.append(pos_fine_tags)

    # Entity word statistics use whole entity spans, not single tokens.
    for span in doc.ents:
        ent_word_dict.setdefault(span.label_, []).append(span.text)

    # BERT word pieces for the same sentence, aligned to the spaCy tokens.
    bert_tokenized = tokenizer.tokenize(s)
    bert2spacy, spacy2bert = get_alignments(bert_tokenized, spacy_tokenized)

    bert_tokenized_all.append(bert_tokenized)
    spacy_tokenized_all.append(spacy_tokenized)
    bert_to_spacy_all.append(bert2spacy)
    spacy_to_bert_all.append(spacy2bert)

# Turn every word list into a frequency Counter, in place.
for word_dict in (pos_word_dict, pos_fine_word_dict, ent_word_dict):
    for label, words in word_dict.items():
        word_dict[label] = Counter(words)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26141/26141 [03:00<00:00, 145.16it/s]


In [113]:
def compute_representative_words(tag_word_dict, thres=0.9):
    """Keep, for every tag, its most frequent words up to a cumulative share.

    Args:
        tag_word_dict: mapping tag -> ``Counter`` of word frequencies.
        thres: words are added in decreasing frequency order until their
            cumulative share of the tag's total count strictly exceeds this.

    Returns:
        dict mapping tag -> ``OrderedDict`` of word -> count, most frequent
        first. The word that crosses the threshold is included.
    """
    representative = {}
    for tag, counter in tag_word_dict.items():
        total = float(sum(counter.values()))
        kept = OrderedDict()
        running = 0
        for word, count in counter.most_common():
            kept[word] = count
            running += count
            if running / total > thres:
                break
        representative[tag] = kept
    return representative

In [114]:
# Representative (default 90% cumulative share) words for each spaCy label set.
pos_word_dict_repr, pos_fine_word_dict_repr, ent_word_dict_repr = (
    compute_representative_words(word_dict)
    for word_dict in (pos_word_dict, pos_fine_word_dict, ent_word_dict)
)

In [116]:
pos_word_dict_repr['PRON']

OrderedDict([('i', 6206),
             ('it', 4266),
             ('you', 3590),
             ('that', 1803),
             ('they', 1792),
             ('this', 1294),
             ('what', 1176),
             ('we', 1148),
             ('he', 1127),
             ('my', 1075),
             ('there', 1053),
             ('your', 933),
             ('who', 813),
             ('me', 765),
             ('them', 663),
             ('which', 661),
             ('their', 656),
             ('his', 525),
             ('all', 370),
             ('us', 326),
             ('_', 317),
             ('our', 305),
             ('something', 284),
             ('its', 276)])

In [146]:
ent_word_dict_repr

{'CARDINAL': {'4': 148,
  '3': 281,
  '70': 5,
  '156': 2,
  '804': 1,
  '38': 16,
  '958': 1,
  '300 fifth': 1,
  '5': 154,
  '1': 369,
  '12': 65,
  'n1': 1,
  '8': 112,
  '51': 6,
  '19': 35,
  '73': 2,
  '331 334': 1,
  'two': 268,
  '82': 5,
  '680x0': 3,
  '40': 18,
  'eight': 6,
  '4 11': 1,
  '3b': 2,
  '28': 14,
  '15430': 1,
  'one': 397,
  '216': 2,
  '368': 2,
  'dozens': 3,
  'thousands': 22,
  '10': 62,
  '05': 5,
  '33': 12,
  'half': 27,
  '508': 3,
  '2': 360,
  '11': 56,
  '13': 31,
  '7': 85,
  'as much as 6': 1,
  '0': 73,
  '215': 2,
  '358': 2,
  '800 753': 1,
  '62': 2,
  'more than 1440k': 1,
  '48': 15,
  '15': 55,
  'three': 101,
  '#': 60,
  '1 6': 2,
  '3 2': 1,
  '6 1': 1,
  '1 2': 2,
  '1 4': 3,
  'four': 44,
  '5 million': 6,
  'six': 14,
  '296': 1,
  '350': 5,
  '43': 10,
  '129': 6,
  '89': 6,
  '602': 4,
  '66mhz': 1,
  '6': 138,
  '185 / 65hr390': 1,
  '9': 62,
  'about an 80': 1,
  '145': 2,
  'more than one': 11,
  '37': 8,
  '117': 4,
  '000': 74,

## Decode Latent Tags

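The decoding loop below currently uses `crf.proposal_argmax`; the sampling (`rsample_approx`) and exact `crf.argmax` calls are kept commented out. TODO: update the Viterbi algorithm.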

In [10]:
# NOTE(review): absolute, user-specific path that differs from read_data's default
# ('../../data/news/'); later cells zip results from both sources, so confirm they
# point at the same corpus and dev split.
dataset = News20Data(data_path='/home/s1946695/RDP/data/news/')
# Validation loader; the decoding loop reads 'input_ids' and 'attention_mask' from each batch.
dev_loader = dataset.val_dataloader()

Processing dataset ...
Reading data ...
... 0 seconds
Tokenizing and sorting train data ...
... 68 seconds
Tokenizing dev data ...
... 9 seconds
Tokenizing test data ...
... 20 seconds


In [54]:
# Decode one latent tag per BERT word piece for every dev sentence.
latent_tags = []

# Loop-invariant: depends only on state_matrix, so compute it once. It is only
# needed by the commented-out exact argmax below.
with torch.no_grad():
    transition = torch.matmul(state_matrix, state_matrix.transpose(1, 0))

for batch in tqdm(dev_loader):
    with torch.no_grad():
        attention_mask = batch['attention_mask'].to('cuda')
        x_emb = bert(batch['input_ids'].to('cuda'), attention_mask=attention_mask)[0]
        emission = torch.matmul(x_emb, state_matrix.transpose(1, 0))
        # Number of attended positions per sequence.
        lens = attention_mask.sum(-1)
        # tags, _, s, bp, log_potentials = crf.argmax(transition, emission, lens) # TBC
        # _, _, _, z_sample, _, _, _ = crf.rsample_approx(state_matrix, emission, lens, sum_size=50, proposal='softmax')
        z_sample = crf.proposal_argmax(state_matrix, emission, lens, sum_size=50)
        z_sample = z_sample.cpu().numpy()
        lens = lens.cpu().numpy()
        for li, l in enumerate(lens):
            # Drop the first and last attended positions (the special tokens
            # the tokenizer adds around each sentence).
            latent_tags.append(z_sample[li][1:l-1])

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2613/2613 [01:35<00:00, 27.47it/s]


In [55]:
## Align latent tags to spaCy tokenization

# zip() stops silently at the shorter input; a length mismatch would pair tags
# with the wrong sentences, so fail loudly instead.
assert len(bert_to_spacy_all) == len(latent_tags), \
    (len(bert_to_spacy_all), len(latent_tags))

latent_tags_spacy = []
for bert2spacy, tags in tqdm(zip(bert_to_spacy_all, latent_tags), total=len(latent_tags)):
    prev_spacy_idx = -1
    tags_converted = []
    # One latent tag per BERT word piece.
    assert len(bert2spacy) == len(tags)
    for si_, t in zip(bert2spacy, tags):
        for si in si_:
            # spaCy indices must advance by at most one; no token may be skipped.
            assert si == prev_spacy_idx or si == prev_spacy_idx + 1
            # If several consecutive BERT tokens correspond to the same spaCy
            # token, only the tag of the first BERT token is used.
            if si == prev_spacy_idx + 1:
                prev_spacy_idx += 1
                tags_converted.append(t)
    latent_tags_spacy.append(tags_converted)

26121it [00:00, 62009.70it/s]


In [56]:
len(spacy_tokenized)

17

In [141]:
def compute_tag_word_dict(tag_all, token_all):
    """Count which tokens occur under each tag.

    Args:
        tag_all: iterable of per-sentence tag sequences.
        token_all: iterable of per-sentence token sequences, paired with
            ``tag_all`` position by position (pairing stops at the shorter one).

    Returns:
        dict mapping tag -> ``Counter`` of the tokens seen with that tag.
    """
    grouped = {}
    for sent_tags, sent_tokens in zip(tag_all, token_all):
        for tag, tok in zip(sent_tags, sent_tokens):
            grouped.setdefault(tag, []).append(tok)
    return {tag: Counter(words) for tag, words in grouped.items()}

In [146]:
# Group BERT word pieces by their decoded latent state (BERT tokenization, no spaCy re-alignment).
latent_word_dict_bert = compute_tag_word_dict(latent_tags, bert_tokenized_all)
# Keep the most frequent pieces of each state up to the default cumulative threshold (0.9).
latent_word_dict_bert_repr = compute_representative_words(latent_word_dict_bert)

In [147]:
len(latent_word_dict_bert_repr)

1404

In [117]:
# Group spaCy tokens by the latent state aligned to them, reusing the helper
# defined above instead of repeating its loop inline.
latent_word_dict = compute_tag_word_dict(latent_tags_spacy, spacy_tokenized_all)
# Most frequent words of each state up to the default cumulative threshold (0.9).
latent_word_dict_repr = compute_representative_words(latent_word_dict)

26121it [00:00, 112578.43it/s]


In [144]:
len(latent_word_dict)

1286

In [58]:
latent_word_dict_repr

{335: {'you': 1052},
 110: {"'re": 22, 'mine': 1, '.': 1574, ')': 354, "'ve": 47},
 1094: {'right': 279,
  'help': 184,
  'advice': 25,
  'correct': 22,
  'rights': 71,
  'liberty': 4,
  'assistance': 6,
  'talent': 2,
  'freedom': 11,
  'aid': 8,
  'proper': 6,
  'assist': 3,
  'save': 5,
  'war': 1,
  'righthanded': 1,
  'liberties': 2,
  'power': 45,
  'butler': 1,
  'wrong': 72,
  'wounded': 1,
  'is': 1,
  'helper': 1,
  'thanks': 73,
  'hp': 1,
  'paradise': 1,
  'hospital': 1,
  'helps': 10,
  'helped': 9,
  'fair': 6,
  'appreciated': 1,
  'powers': 4,
  'ok': 9,
  'obey': 2,
  'time': 1,
  'organization': 1,
  'saved': 8,
  'credit': 2,
  'helpful': 7,
  'just': 2,
  'correctly': 6,
  'grace': 4,
  'reached': 1,
  'fine': 13,
  'stick': 2,
  'guide': 3,
  'ensure': 1,
  'good': 1,
  'guys': 1,
  'period': 1,
  'glory': 1,
  'quality': 1,
  'support': 8,
  'enemies': 1,
  'services': 8,
  'relief': 4,
  'favor': 4,
  'helping': 3,
  'along': 1,
  'hurt': 4,
  'left': 20,
  'tru

In [149]:
# Latent states that appear at the BERT word-piece level but not after alignment to spaCy tokens.
bert_only_states = [state for state in latent_word_dict_bert_repr
                    if state not in latent_word_dict_repr]
for cnt, state in enumerate(bert_only_states):
    print(cnt, latent_word_dict_bert_repr[state])
cnt = len(bert_only_states)

0 OrderedDict([('.', 1)])
1 OrderedDict([('##p', 1), ('##ile', 1)])
2 OrderedDict([('##pel', 1), ('##q', 1)])
3 OrderedDict([('##cl', 1), ('##os', 1)])
4 OrderedDict([('##tra', 1), ('##9', 1)])
5 OrderedDict([('##ary', 1)])
6 OrderedDict([('##ra', 1)])
7 OrderedDict([('##cian', 1)])
8 OrderedDict([('##pg', 1), ('.', 1)])
9 OrderedDict([('.', 1)])
10 OrderedDict([('.', 1)])
11 OrderedDict([('s', 1)])
12 OrderedDict([('##ff', 1), ('.', 1)])
13 OrderedDict([('t', 1)])
14 OrderedDict([('.', 1)])
15 OrderedDict([('.', 1)])
16 OrderedDict([('.', 1), ('##1', 1)])
17 OrderedDict([('##hic', 1)])
18 OrderedDict([('##he', 1), ('##lm', 1)])
19 OrderedDict([('##he', 1)])
20 OrderedDict([('##mo', 1), ('##dp', 1)])
21 OrderedDict([('.', 1)])
22 OrderedDict([('.', 1)])
23 OrderedDict([('.', 1)])
24 OrderedDict([('.', 1)])
25 OrderedDict([('##cm', 1)])
26 OrderedDict([('##rom', 1)])
27 OrderedDict([('##wi', 1), ('.', 1)])
28 OrderedDict([('##ft', 1), ('##8', 1)])
29 OrderedDict([('##86', 1)])
30 Ordere

## Compute V measure

### Inferred Tags

In [59]:
# Flatten the per-sentence tag lists into one token-level sequence each,
# which is the input format the clustering metric below is given.
latent_tags_spacy_ = [tag for sent in latent_tags_spacy for tag in sent]
ent_tags_all_ = [tag for sent in ent_tags_all for tag in sent]
pos_tags_all_ = [tag for sent in pos_tags_all for tag in sent]
pos_fine_tags_all_ = [tag for sent in pos_fine_tags_all for tag in sent]

In [60]:
v_measure_score(np.array(ent_tags_all_), np.array(latent_tags_spacy_))

0.05114213914043182

In [61]:
v_measure_score(np.array(pos_tags_all_), np.array(latent_tags_spacy_))

0.4173038418387589

In [62]:
v_measure_score(np.array(pos_fine_tags_all_), np.array(latent_tags_spacy_))



0.46497892654755996

### Random Tags

In [19]:
# Random baseline: one uniformly drawn tag in [0, 2000) per token. A dedicated
# seeded generator keeps the baseline scores below reproducible across runs.
# NOTE(review): 2000 is presumably the number of latent states -- confirm
# against state_matrix.shape[0].
random_tags = np.random.RandomState(0).randint(0, 2000, len(latent_tags_spacy_))

In [20]:
# Random baseline for the word-overlap analysis: give every spaCy token a
# uniformly random tag in [0, 2000), then build the same statistics as for the
# latent tags. The generator is seeded so the baseline is reproducible.
random_rng = np.random.RandomState(0)
random_tags_per_sentence = [random_rng.randint(0, 2000, len(tokens))
                            for tokens in spacy_tokenized_all]
# Reuse the shared helper rather than repeating its grouping loop inline.
random_word_dict = compute_tag_word_dict(random_tags_per_sentence, spacy_tokenized_all)
random_word_dict_repr = compute_representative_words(random_word_dict)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26121/26121 [00:00<00:00, 32771.16it/s]


In [21]:
v_measure_score(np.array(ent_tags_all_), random_tags)

0.007241495219772309

In [22]:
v_measure_score(np.array(pos_tags_all_), random_tags)

0.007521136557212558

In [23]:
v_measure_score(np.array(pos_fine_tags_all_), random_tags)



0.01908628145166098

## Compute Aligned Latent Tags

In [118]:
def align_tags(latent_tags, defined_tags, thres=0.9):
    """Map each latent tag to the first defined tag whose words cover it.

    For a latent tag, the overlap with a defined tag is the summed count of the
    latent tag's words that also appear among the defined tag's words. The
    latent tag is assigned to the first defined tag (in ``defined_tags``
    iteration order) whose overlap share strictly exceeds ``thres``.

    Args:
        latent_tags: mapping latent tag -> mapping word -> count.
        defined_tags: mapping defined tag -> mapping keyed by that tag's words
            (only the keys are used).
        thres: minimum share of the latent tag's total count that must be
            covered, compared with ``>``.

    Returns:
        dict mapping latent tag -> defined tag. Latent tags with no match, or
        with no counted words at all, are omitted.
    """
    # Build each defined tag's vocabulary once instead of once per (latent, tag) pair.
    vocabularies = [(tag, set(words.keys())) for tag, words in defined_tags.items()]

    latent_to_defined = {}
    for latent, word_counts in latent_tags.items():
        total = sum(word_counts.values())
        if total == 0:
            # Nothing observed for this latent tag: there is no share to
            # compute (this used to raise ZeroDivisionError).
            continue
        for tag, vocab in vocabularies:
            overlap = sum(count for word, count in word_counts.items() if word in vocab)
            if overlap / total > thres:
                latent_to_defined[latent] = tag
                break
    return latent_to_defined

In [119]:
latent_to_pos = align_tags(latent_word_dict_repr, pos_word_dict_repr)

In [206]:
pos_word_dict_repr.keys()

dict_keys(['PRON', 'AUX', 'ADJ', 'PUNCT', 'DET', 'NOUN', 'ADP', 'CCONJ', 'ADV', 'VERB', 'NUM', 'PART', 'PROPN', 'SCONJ', 'SYM', 'X', 'INTJ'])

In [124]:
len(latent_word_dict)

1286

In [121]:
len(latent_to_pos)

632

In [122]:
# Total number of token occurrences across all latent states (the coverage denominator).
full_acc = 0
for word_counts in latent_word_dict.values():
    full_acc += np.sum(list(word_counts.values()))

In [123]:
# Share of token occurrences whose latent state was aligned to a coarse POS tag.
pos_covered = 0
for state in latent_to_pos:
    state_counts = list(latent_word_dict[state].values())
    pos_covered += np.sum(state_counts)
print(pos_covered / float(full_acc))

0.5209022129840176


In [125]:
latent_to_fine_pos = align_tags(latent_word_dict_repr, pos_fine_word_dict_repr)

In [209]:
latent_to_fine_pos

{335: 'PRP',
 867: ',',
 1176: "''",
 877: 'DT',
 667: 'VBZ',
 1335: "''",
 898: 'IN',
 1284: 'PRP',
 316: 'CC',
 1309: 'NN',
 928: 'VBP',
 343: 'DT',
 1834: 'NN',
 843: 'DT',
 1821: 'NN',
 81: 'NNP',
 1575: 'CD',
 498: ',',
 563: ',',
 34: 'IN',
 71: 'DT',
 184: 'NN',
 1131: '.',
 694: 'PRP',
 1852: 'RB',
 1684: 'NN',
 948: 'IN',
 517: 'IN',
 1784: 'NN',
 1050: 'NN',
 1838: '-LRB-',
 1701: 'NNP',
 121: 'NN',
 1322: 'IN',
 893: 'DT',
 1687: '-LRB-',
 1105: 'IN',
 1717: 'PRP',
 1202: 'DT',
 894: 'DT',
 734: ',',
 718: 'NN',
 446: 'DT',
 670: 'PRP$',
 569: 'IN',
 1785: 'CC',
 310: 'WRB',
 1933: 'DT',
 601: 'VBD',
 1045: 'IN',
 213: ',',
 795: 'DT',
 1102: 'RB',
 105: 'NN',
 1113: 'NN',
 1223: 'NN',
 75: 'NN',
 228: 'NN',
 293: 'NFP',
 1601: 'CC',
 608: 'NNS',
 935: ',',
 132: 'PRP',
 926: 'NNP',
 272: 'DT',
 978: 'IN',
 199: 'RB',
 1870: 'IN',
 1557: 'NN',
 1730: ',',
 1779: 'DT',
 339: 'WDT',
 466: 'DT',
 25: 'IN',
 1760: 'IN',
 463: 'VBZ',
 1441: 'DT',
 42: 'RB',
 229: 'IN',
 380: 'NNP

In [126]:
len(latent_to_fine_pos)

589

In [127]:
# Share of token occurrences whose latent state was aligned to a fine-grained POS tag.
fine_pos_covered = 0
for state in latent_to_fine_pos:
    state_counts = list(latent_word_dict[state].values())
    fine_pos_covered += np.sum(state_counts)
print(fine_pos_covered / float(full_acc))

0.43903017451836557


In [128]:
latent_to_ent = align_tags(latent_word_dict_repr, ent_word_dict_repr)

In [129]:
len(latent_to_ent)

51

In [130]:
# Share of token occurrences whose latent state was aligned to an entity type.
ent_covered = 0
for state in latent_to_ent:
    state_counts = list(latent_word_dict[state].values())
    ent_covered += np.sum(state_counts)
print(ent_covered / float(full_acc))

0.014278141260828927


In [131]:
# Latent states aligned to at least one of: coarse POS, fine POS, entity type.
total_aligned = set(latent_to_pos) | set(latent_to_fine_pos) | set(latent_to_ent)
# Remaining states out of 2000.
# NOTE(review): 2000 is presumably the total number of latent states -- confirm.
num_not_aligned = 2000 - len(total_aligned)
print(num_not_aligned)

1359


In [132]:
# Share of token occurrences that fall in latent states with no alignment at all.
total_covered = 0
for state in total_aligned:
    state_counts = list(latent_word_dict[state].values())
    total_covered += np.sum(state_counts)
print(1 - total_covered / float(full_acc))

0.4790745365009835


In [133]:
not_aligned = set(range(2000)) - total_aligned

## States not covered

In [134]:
# (state, occurrence count) for every unaligned state that was actually
# observed, most frequent first. Counts come from the representative-word
# dictionaries, i.e. only the words kept for each state are summed.
not_aligned_occ = sorted(
    ((state, np.sum(list(latent_word_dict_repr[state].values())))
     for state in not_aligned if state in latent_word_dict_repr),
    key=lambda pair: pair[1],
    reverse=True,
)

In [135]:
not_aligned_occ[2]

(1552, 2097)

In [140]:
len(not_aligned_occ)

645

In [156]:
# How much of the unaligned mass is concentrated in the 160 most frequent unaligned states.
unaligned_counts = [count for _, count in not_aligned_occ]
not_align_occ_total = np.sum(unaligned_counts)
not_align_occ_160 = np.sum(unaligned_counts[:160])
print(not_align_occ_160, not_align_occ_total, not_align_occ_160 / not_align_occ_total)

168992 186355 0.9068283652169247


In [158]:
# Show the representative words of the 20 most frequent unaligned states.
for l, _ in not_aligned_occ[:20]:
    print(l, latent_word_dict_repr[l], sep='\n')

1861
OrderedDict([(':', 2265), ('=', 248), ('called', 104), ('\\', 100), ('call', 79)])
1479
OrderedDict([('time', 336), ('day', 169), ('days', 89), ('night', 76), ('week', 59), ('1993', 50), ('year', 50), ('today', 49), ('months', 47), ('hours', 46), ('weeks', 40), ('1992', 36), ('ago', 36), ('month', 34), ('times', 34), ('years', 33), ('period', 32), ('morning', 31), ('1991', 25), ('1990', 25), ('minutes', 25), ('hour', 21), ('sunday', 19), ('date', 18), ('yesterday', 18), ('daily', 17), ('while', 17), ('1989', 16), ('tonight', 16), ('century', 15), ('moment', 15), ('forever', 15), ('seconds', 15), ('weekend', 14), ('1988', 13), ('saturday', 13), ('.', 12), ('minute', 12), ('1982', 11), ('1986', 11), ('on', 11), ('friday', 11), ('era', 10), ('tuesday', 10), ('pm', 10), ('tomorrow', 10), ('late', 10), ('evening', 9), ('1983', 9), ('clock', 9), ('afternoon', 8), ('second', 8), ('1972', 8), ('1987', 7), ('93', 7), ('wait', 7), ('1980', 6), ('thursday', 6), ('1967', 6), ('monthly', 6), (

## Alignment of Random Tags

In [220]:
random_to_pos = align_tags(random_word_dict_repr, pos_word_dict_repr)

In [221]:
len(random_to_pos)

0

In [222]:
random_to_fine_pos = align_tags(random_word_dict_repr, pos_fine_word_dict_repr)

In [223]:
len(random_to_fine_pos)

0

In [224]:
random_to_ent = align_tags(random_word_dict_repr, ent_word_dict_repr)

In [225]:
len(random_to_ent)

0

## Compute Precision and Recall

In [179]:
id_to_ent

{0: '',
 397: 'CARDINAL',
 380: 'PERSON',
 383: 'ORG',
 384: 'GPE',
 391: 'DATE',
 392: 'TIME',
 396: 'ORDINAL',
 386: 'PRODUCT',
 381: 'NORP',
 394: 'MONEY',
 387: 'EVENT',
 395: 'QUANTITY',
 9191306739292312949: 'FAC',
 389: 'LANGUAGE',
 393: 'PERCENT',
 385: 'LOC',
 390: 'LAW',
 388: 'WORK_OF_ART'}

In [180]:
id_to_pos

{95: 'PRON',
 87: 'AUX',
 84: 'ADJ',
 97: 'PUNCT',
 90: 'DET',
 92: 'NOUN',
 85: 'ADP',
 89: 'CCONJ',
 86: 'ADV',
 100: 'VERB',
 93: 'NUM',
 94: 'PART',
 96: 'PROPN',
 98: 'SCONJ',
 99: 'SYM',
 101: 'X',
 91: 'INTJ'}

In [182]:
id_to_fine_pos

{13656873538139661788: 'PRP',
 9188597074677201817: 'VBP',
 10554686591937588953: 'JJ',
 2593208677638477497: ',',
 14143520107006108953: "''",
 15267657372422890137: 'DT',
 15308085513773655218: 'NN',
 13927759927860985106: 'VBZ',
 1292078113972184607: 'IN',
 17571114184892886314: 'CC',
 164681854541413346: 'RB',
 12646065887601541794: '.',
 783433942507015291: 'NNS',
 14200088355797579614: 'VB',
 272890857012483650: 'JJR',
 8427216679587749980: 'CD',
 11532473245541075862: ':',
 15794550382381185553: 'NNP',
 17111077179131903759: '-LRB-',
 2465883113906300949: '-RRB-',
 74: 'POS',
 16235386156175103506: 'MD',
 1534113631682161808: 'VBG',
 4062917326063685704: 'PRP$',
 5595707737748328492: 'TO',
 17524233984504158541: 'WRB',
 17109001835818727656: 'VBD',
 3822385049556375858: 'VBN',
 14872845191859177490: 'NFP',
 6860118812490040284: 'RP',
 17202369883303991778: 'WDT',
 15361090031084224697: 'EX',
 99: 'SYM',
 16530679158541427010: 'LS',
 4969857429396651903: '``',
 189557958894700426

In [71]:
# Token-level precision/recall of entity types predicted through aligned latent states.
#   pred_ent   -- tokens whose predicted entity type equals the spaCy one (correct)
#   prec_ent   -- tokens whose latent state is aligned to some entity type (predicted)
#   recall_ent -- tokens that spaCy marks with a non-empty entity type (gold)
ent_to_id = {name: idx for idx, name in id_to_ent.items()}
pred_ent = prec_ent = recall_ent = 0
for e, l in zip(ent_tags_all_, latent_tags_spacy_):
    # id 0 is the empty entity type, i.e. the token is not part of an entity.
    is_gold_entity = e in id_to_ent and e != 0
    if l in latent_to_ent:
        el = ent_to_id[latent_to_ent[l]]
        prec_ent += 1
        if el == e and is_gold_entity:
            pred_ent += 1
    if is_gold_entity:
        recall_ent += 1
print('prec', pred_ent / prec_ent)
print('recl', pred_ent / recall_ent)

prec 0.5446921209446038
recl 0.09227893064124135


In [195]:
# Number of gold entity tokens. `total_ent` is never assigned in this notebook,
# so the cell only worked on leftover kernel state; the count computed in the
# precision/recall cell above is `recall_ent`, and the saved output (26745) is
# consistent with the recall printed there.
recall_ent

26745

In [72]:
# Token-level precision/recall of coarse POS tags predicted through aligned latent states.
#   pred_pos   -- tokens whose predicted POS equals the spaCy one (correct)
#   prec_pos   -- tokens whose latent state is aligned to some POS tag (predicted)
#   recall_pos -- tokens carrying a known spaCy POS id (gold)
pos_to_id = {name: idx for idx, name in id_to_pos.items()}
pred_pos = prec_pos = recall_pos = 0
for e, l in zip(pos_tags_all_, latent_tags_spacy_):
    is_known_pos = e in id_to_pos
    if l in latent_to_pos:
        el = pos_to_id[latent_to_pos[l]]
        prec_pos += 1
        if el == e and is_known_pos:
            pred_pos += 1
    if is_known_pos:
        recall_pos += 1
print('prec', pred_pos / prec_pos)
print('recl', pred_pos / recall_pos)

prec 0.7487416049851187
recl 0.3807806592916033


In [73]:
# Token-level precision/recall of fine-grained POS tags predicted through aligned latent states.
#   pred_fine_pos -- tokens whose predicted tag equals the spaCy one (correct)
#   prec_fine_pos -- tokens whose latent state is aligned to some fine tag (predicted)
#   recl_fine_pos -- tokens carrying a known spaCy fine-tag id (gold)
fine_pos_to_id = {name: idx for idx, name in id_to_fine_pos.items()}
prec_fine_pos = recl_fine_pos = pred_fine_pos = 0
for e, l in zip(pos_fine_tags_all_, latent_tags_spacy_):
    is_known_tag = e in id_to_fine_pos
    if l in latent_to_fine_pos:
        el = fine_pos_to_id[latent_to_fine_pos[l]]
        prec_fine_pos += 1
        if el == e and is_known_tag:
            pred_fine_pos += 1
    if is_known_tag:
        recl_fine_pos += 1
print('prec', pred_fine_pos / prec_fine_pos)
print('recl', pred_fine_pos / recl_fine_pos)

prec 0.7227927725536764
recl 0.31418188412873344


## Save Everything

In [161]:
len(spacy_tokenized_all)

26121

In [162]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('spacy_tokenized.pkl', 'wb') as f:
    pickle.dump(spacy_tokenized_all, f)

In [163]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('bert_tokenized.pkl', 'wb') as f:
    pickle.dump(bert_tokenized_all, f)

In [164]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('pos_tags.pkl', 'wb') as f:
    pickle.dump(pos_tags_all, f)

In [165]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('ent_tags.pkl', 'wb') as f:
    pickle.dump(ent_tags_all, f)

In [166]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('fine_pos_tags.pkl', 'wb') as f:
    pickle.dump(pos_fine_tags_all, f)

In [167]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('bert_to_spacy.pkl', 'wb') as f:
    pickle.dump(bert_to_spacy_all, f)

In [168]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('pos_word_dict_repr.pkl', 'wb') as f:
    pickle.dump(pos_word_dict_repr, f)

In [169]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('fine_pos_word_dict_repr.pkl', 'wb') as f:
    pickle.dump(pos_fine_word_dict_repr, f)

In [170]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('ent_word_dict_repr.pkl', 'wb') as f:
    pickle.dump(ent_word_dict_repr, f)

In [171]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('id_to_pos.pkl', 'wb') as f:
    pickle.dump(id_to_pos, f)

In [172]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('id_to_ent.pkl', 'wb') as f:
    pickle.dump(id_to_ent, f)

In [174]:
# Context manager so the file is flushed and closed as soon as the dump finishes.
with open('id_to_fine_pos.pkl', 'wb') as f:
    pickle.dump(id_to_fine_pos, f)