# Comparing the overlap between inferred tags with a supervised tagger

In [1]:
import spacy
import codecs

import numpy as np 

from tokenizations import get_alignments
from transformers import BertTokenizer
from tqdm import tqdm
from collections import Counter

In [4]:
def read_data(data_path='../../data/news/'):
    """Read 20news data, train only.

    Args:
        data_path: Directory holding ``20news_cased.txt`` and
            ``train_index.npy``. A trailing path separator is optional.

    Returns:
        A list of strings: the lines of ``20news_cased.txt`` selected by the
        indices stored in ``train_index.npy`` (in index order), each without
        its trailing line-break character.
    """
    import os  # function-scope import so the cell stays self-contained

    # use the cased data for NER, otherwise spacy does not work with uncased
    with codecs.open(os.path.join(data_path, '20news_cased.txt'), encoding='utf-8') as fd:
        data = fd.readlines()
    train_idx = np.load(os.path.join(data_path, 'train_index.npy'))

    train_data = []
    for i in train_idx:
        line = data[i]
        # Drop the final character only when it really is a line break (same
        # boundary test as str.splitlines); a line without a terminator, such
        # as an unterminated last line, keeps all of its characters.
        if line[-1:].splitlines() == ['']:
            line = line[:-1]
        train_data.append(line)
    return train_data

In [5]:
# Load the 20news training sentences from the default data directory.
train_data = read_data()

In [6]:
# Accumulators filled while processing the corpus:
#   pos_dict:         POS tag      -> token texts seen with that tag
#   ent_dict:         entity label -> entity texts seen with that label
#   token_bert2spacy: BERT token   -> spaCy token texts it was aligned to
pos_dict = {}
ent_dict = {}
token_bert2spacy = {}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Pass the disabled components as a list of names: a bare string is not a
# list of component names and is not handled uniformly across spaCy versions.
nlp = spacy.load("en_core_web_sm", disable=['parser'])

In [8]:
# Tag the first 10000 training sentences with spaCy and record, per sentence:
#   * every token text under its POS tag (pos_dict),
#   * every entity text under its entity label (ent_dict),
#   * which spaCy tokens each BERT token is aligned with (token_bert2spacy).
for s in tqdm(train_data[:10000]):
    doc = nlp(s)

    spacy_tokenized = []
    for token in doc:
        pos_dict.setdefault(token.pos_, []).append(token.text)
        # Every token must go into the spaCy token list, including the first
        # token seen for a new POS tag; otherwise the list handed to
        # get_alignments is missing tokens.
        spacy_tokenized.append(token.text)

    for ent in doc.ents:
        ent_dict.setdefault(ent.label_, []).append(ent.text)

    bert_tokenized = tokenizer.tokenize(s)

    # bert2spacy is expected to hold, for each BERT token position, the list
    # of aligned spaCy token positions (one entry per BERT token) -- TODO
    # confirm against the installed `tokenizations` version.
    bert2spacy, spacy2bert = get_alignments(bert_tokenized, spacy_tokenized)
    for bert_token, spacy_positions in zip(bert_tokenized, bert2spacy):
        aligned = token_bert2spacy.setdefault(bert_token, [])
        aligned.extend(spacy_tokenized[wi] for wi in spacy_positions)

100%|██████████| 10000/10000 [01:00<00:00, 165.43it/s]


In [11]:
# Replace every collected list of texts with a Counter of its elements,
# in place, for the POS table, the entity table and the BERT->spaCy table.
for table in (pos_dict, ent_dict, token_bert2spacy):
    for key in table:
        table[key] = Counter(table[key])

In [13]:
# Number of distinct POS tags, entity labels and BERT tokens in the three tables.
print(len(pos_dict), len(ent_dict), len(token_bert2spacy))

17 18 13398


In [33]:
# Show the POS tags that were collected.
pos_dict.keys()

dict_keys(['VERB', 'PRON', 'AUX', 'DET', 'ADJ', 'NOUN', 'PUNCT', 'PROPN', 'ADV', 'SCONJ', 'ADP', 'PART', 'CCONJ', 'SYM', 'NUM', 'INTJ', 'X'])

In [35]:
# Show the entity labels that were collected.
ent_dict.keys()

dict_keys(['ORG', 'DATE', 'PRODUCT', 'NORP', 'PERSON', 'CARDINAL', 'GPE', 'TIME', 'ORDINAL', 'LAW', 'QUANTITY', 'WORK_OF_ART', 'LANGUAGE', 'MONEY', 'FAC', 'EVENT', 'LOC', 'PERCENT'])

# Read state-word dictionary

In [29]:
# state_dict: state id -> list of (word, value) pairs, in file order
# state_freq: state id -> {'freq': ..., 'freq_no_sw': ...}
state_dict = {}
state_freq = {}
with codecs.open('../../local/bertnet_0.0.4.10.4_stored/bertnet_dev_epoch_17_s2w.txt', encoding='utf-8') as fd:
    lines = fd.readlines()
    for li, l in enumerate(lines):
        if l.startswith('state') and li % 3 == 0:
            # Header line: whitespace-separated fields with the state id at
            # position 1 and the two frequencies at positions 3 and 5
            # (presumably total frequency and frequency without stop words).
            fields = l[:-1].split()
            state_id = fields[1]
            state_freq[state_id] = {
                'freq': float(fields[3]),
                'freq_no_sw': float(fields[5]),
            }
            state_dict[state_id] = []
        elif lines[li - 1].startswith('state'):
            # Line following a header: "word value" entries separated by
            # ' | '; the piece after the last separator is discarded.
            entries = l.split(' | ')[:-1]
            state_dict[state_id].extend(
                (parts[0], float(parts[1]))
                for parts in (entry.split() for entry in entries)
            )

## Measurement of alignment 1, dominate word set
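* Definition: W(s), the 95% dominate word set, is the set of most frequent words that together account for 95% of the word occurrences of a state (or of a POS tag).
* If W(s) \subseteq W(POS), we say that state s aligns with that POS. (The implementation below relaxes this: s aligns with a POS when more than 50% of W(s) lies in W(POS).)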

## Measurement of alignment 2, word occurrence
* If 50% of a state's word occurrences fall inside a POS dominate set, we view this as an alignment.

## Implementation of dominate word set alignment

In [28]:
# find non stop states 

# Keep the states whose 'freq_no_sw' is more than half of their 'freq'
# (presumably: most of the state's mass survives stop-word removal).
non_sw_states = [
    state_id
    for state_id, stats in state_freq.items()
    if stats['freq_no_sw'] / float(stats['freq']) > 0.5
]
print('%d non stop states' % len(non_sw_states))

1351 non stop states


In [40]:
# Peek at the first ten state ids in non_sw_states.
non_sw_states[:10]

['171', '476', '710', '254', '1419', '1972', '403', '1488', '1556', '243']

In [43]:
# For each non-SW state, collect its dominate vocabulary: walk the state's
# (word, value) pairs in stored order and keep a word while the running sum,
# including that word, stays below 0.9 of the state's total frequency.
# NOTE(review): the threshold used here is 0.9, not the 95% mentioned in the
# text above -- confirm which one is intended.

state_dominate_set = {}
for s in non_sw_states:
    total = state_freq[s]['freq']
    running = 0
    kept = []
    state_dominate_set[s] = kept
    for word, value in state_dict[s]:
        running += value
        if running / total < 0.9:
            kept.append(word)

In [77]:
# remove subwords clusters

def test_subword_cluster(wset):
    """if half of the cluster is subwords, then this is a subword cluster"""
    set_len = len(wset)
    num_subwords = 0
    for w in wset:
        if(w.startswith('##')): num_subwords += 1
        if(len(w) <= 2): num_subwords += 1
    if(num_subwords > set_len // 2): return True
    else: return False
    
# Rebuild non_sw_states from the states whose dominate set is not flagged
# as a subword cluster.
non_sw_states = [
    s for s in state_dominate_set
    if not test_subword_cluster(state_dominate_set[s])
]

In [78]:
# Number of states left in non_sw_states.
len(non_sw_states)

1060

In [82]:
# Convert each state's dominate set from BERT tokens to spaCy tokens
# (despite the variable name, the mapping used is token_bert2spacy).
# A BERT token with known alignments is replaced by the lower-cased spaCy
# tokens it was aligned to; any other token is kept unchanged.

state_dominate_set_gpt_token = {}
for s in non_sw_states:
    converted = set()
    for w in state_dominate_set[s]:
        if w in token_bert2spacy:
            converted.update(t.lower() for t in token_bert2spacy[w])
        else:
            converted.add(w)
    state_dominate_set_gpt_token[s] = converted

In [80]:
# Peek at the first ten state ids in non_sw_states.
non_sw_states[:10]

['476', '710', '254', '1419', '1972', '403', '1488', '1556', '243', '1410']

In [84]:
# Inspect the converted word set of state '710'.
state_dominate_set_gpt_token['710']

{'abuse',
 'accident',
 'attack',
 'attacks',
 'bad',
 'badanes',
 'badatom',
 'badcolor',
 'badertscher',
 'badfont',
 'conflict',
 'crash',
 'crime',
 'crimes',
 'crimestrike',
 'criminals',
 'damage',
 'dangerous',
 'dead',
 'death',
 'deathbed',
 'deaths',
 'destroyed',
 'destruction',
 'die',
 'died',
 'disease',
 'error',
 'errors',
 'evil',
 'failure',
 'fault',
 'genocide',
 'harm',
 'harming',
 'hate',
 'hateful',
 'hatred',
 'hell',
 'hellish',
 'hellman',
 'hit',
 'hite',
 'hiten',
 'hurt',
 'incident',
 'incidental',
 'incidentally',
 'injury',
 "is'wrong",
 "jesus'death",
 'kill',
 'killed',
 'killing',
 "like'crime",
 'lose',
 'loss',
 'lossless',
 'lossy',
 'lost',
 "makedepend'problem",
 'massacre',
 'massacred',
 'massacres',
 'mistake',
 'murder',
 'murdered',
 'murders',
 'pain',
 'painless',
 'penalty',
 "pitchers'bad",
 'problem',
 'problems',
 'punish',
 'punishable',
 'punishes',
 'punishing',
 'punishment',
 'sick',
 'sickens',
 'sin',
 'sin66',
 'sinauer',
 'si

In [97]:
# calculate 95% vocab of POS 

# For every POS tag, walk its words from most to least frequent and keep
# (lower-cased) each word while the running count, including that word,
# stays below 95% of the tag's total occurrences.
pos_dominate_set = {}

for p in pos_dict:
    ranked = pos_dict[p].most_common()
    full_occ = float(sum(count for _, count in ranked))
    running = 0
    kept = []
    for word, count in ranked:
        running += count
        if running / full_occ < 0.95:
            kept.append(word.lower())
    pos_dominate_set[p] = set(kept)
    # The size reported is that of the list, before de-duplication by set().
    print('dominate set size for %s = %d' % (p, len(kept)))

dominate set size for VERB = 2405
dominate set size for PRON = 33
dominate set size for AUX = 29
dominate set size for DET = 16
dominate set size for ADJ = 1520
dominate set size for NOUN = 5183
dominate set size for PUNCT = 8
dominate set size for PROPN = 4711
dominate set size for ADV = 389
dominate set size for SCONJ = 14
dominate set size for ADP = 28
dominate set size for PART = 3
dominate set size for CCONJ = 6
dominate set size for SYM = 4
dominate set size for NUM = 603
dominate set size for INTJ = 108
dominate set size for X = 69


In [119]:
# Size of the converted word set of state '243'.
len(state_dominate_set_gpt_token['243'])

161

In [118]:
# How many words of state '243' also appear in the ADV dominate set.
len(state_dominate_set_gpt_token['243'].intersection(pos_dominate_set['ADV']))

71

In [125]:
# calculate coverage
# TODO: change vocabulary based matching to occurrence based matching
# Q: why no alignment with ADV?

# A state is aligned with a POS tag when more than `thres` of the state's
# word set is contained in that tag's dominate set.
#   alignment:      POS tag -> list of aligned state ids
#   covered_pos:    set of POS tags with at least one aligned state
#   aligned_states: number of distinct states aligned with at least one POS
alignment = {}
thres = 0.5
covered_pos = set()
for s, state_words in state_dominate_set_gpt_token.items():
    if not state_words:
        # An empty word set has no defined overlap ratio (division by zero),
        # so such a state cannot be aligned with anything.
        continue
    for p, pos_words in pos_dominate_set.items():
        overlap = len(state_words.intersection(pos_words)) / float(len(state_words))
        if overlap > thres:
            print('state %s aligned with POS %s, overlap %.2f' % (s, p, overlap))
            covered_pos.add(p)
            alignment.setdefault(p, []).append(s)
# Count every state once, even if it passes the threshold for several POS
# tags (POS dominate sets can share words).
aligned_states = len({s for states in alignment.values() for s in states})

state 710 aligned with POS NOUN, overlap 0.51
state 254 aligned with POS PROPN, overlap 0.73
state 403 aligned with POS PROPN, overlap 0.58
state 1488 aligned with POS ADJ, overlap 0.61
state 1410 aligned with POS ADJ, overlap 0.56
state 555 aligned with POS NOUN, overlap 0.55
state 865 aligned with POS NOUN, overlap 0.53
state 1572 aligned with POS PROPN, overlap 0.56
state 1683 aligned with POS NOUN, overlap 0.55
state 1584 aligned with POS PROPN, overlap 0.53
state 1643 aligned with POS NOUN, overlap 0.60
state 1966 aligned with POS PROPN, overlap 0.72
state 1656 aligned with POS NOUN, overlap 0.65
state 1702 aligned with POS NOUN, overlap 0.63
state 603 aligned with POS PROPN, overlap 0.51
state 1874 aligned with POS PROPN, overlap 0.53
state 1181 aligned with POS NOUN, overlap 0.56
state 1973 aligned with POS NOUN, overlap 0.69
state 202 aligned with POS PROPN, overlap 0.60
state 1022 aligned with POS PROPN, overlap 0.62
state 1183 aligned with POS PROPN, overlap 0.69
state 1142 a

In [126]:
# Report the aligned_states count computed by the coverage cell.
print(aligned_states)

237


**That is, under the dominate word set measurement, 237 out of 2000 states align with a pre-defined POS tag.**

## Implementation of word occurrence alignment (note: there may be a distribution shift problem here)