# Comparing the overlap between inferred tags and the tags of a supervised tagger

In [1]:
import spacy
import codecs

import numpy as np 

from tokenizations import get_alignments
from transformers import BertTokenizer
from tqdm import tqdm
from collections import Counter

In [2]:
def read_data(data_path='../../data/news/'):
    """Read the training split of the 20news corpus.

    The cased text is used on purpose: spaCy's NER does not work well on
    uncased input.

    Args:
        data_path: directory containing ``20news.txt`` (one document per
            line) and ``train_idx.npy`` (indices of the training lines).
            A trailing path separator is optional.

    Returns:
        list of str: the training documents in ``train_idx`` order, each
        with its line terminator removed.
    """
    import os

    with codecs.open(os.path.join(data_path, '20news.txt'), encoding='utf-8') as fd:
        # NOTE(review): codecs' readlines splits on every Unicode line
        # boundary (e.g. form feed), not only '\n' -- confirm train_idx was
        # produced from the same line splitting.
        data = fd.readlines()
    train_idx = np.load(os.path.join(data_path, 'train_idx.npy'))
    # Drop the terminator (including a full '\r\n') instead of blindly
    # slicing off the last character, which would truncate a final line
    # that has no trailing newline.
    return [data[i].splitlines()[0] if data[i] else data[i] for i in train_idx]

In [3]:
# Load the (cased) training documents from the default data directory.
train_data = read_data()

In [4]:
# Accumulators filled by the tagging loop:
#   pos_dict:         POS tag -> list of token texts tagged with it
#   ent_dict:         entity label -> list of entity texts
#   token_bert2spacy: BERT word piece -> list of spaCy tokens it aligned to
pos_dict = {}
ent_dict = {}
token_bert2spacy = {}

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# `disable` takes a list of pipe names. The dependency parser is not needed
# because only POS tags and entities are read from each doc.
nlp = spacy.load("en_core_web_sm", disable=['parser'])

In [5]:
# Tag the first 1000 training documents with spaCy. For every document:
#   * record each token text under its POS tag and each entity under its label;
#   * align BERT word pieces with spaCy tokens and record, per word piece,
#     the spaCy tokens it overlaps.
for s in tqdm(train_data[:1000]):
    doc = nlp(s)

    spacy_tokenized = []
    for token in doc:
        pos_dict.setdefault(token.pos_, []).append(token.text)
        # Append every token (not only repeats of an already-seen POS tag) so
        # the alignment below sees the complete spaCy tokenization.
        spacy_tokenized.append(token.text)

    for ent in doc.ents:
        ent_dict.setdefault(ent.label_, []).append(ent.text)

    bert_tokenized = tokenizer.tokenize(s)

    # bert2spacy[i] holds the indices of the spaCy tokens that BERT piece i
    # overlaps (see the recorded output of `bert2spacy` further below).
    bert2spacy, spacy2bert = get_alignments(bert_tokenized, spacy_tokenized)
    for bert_tok, spacy_idxs in zip(bert_tokenized, bert2spacy):
        aligned = token_bert2spacy.setdefault(bert_tok, [])
        aligned.extend(spacy_tokenized[wi] for wi in spacy_idxs)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 131.22it/s]


In [27]:
# BERT word pieces of the last document processed by the tagging loop.
bert_tokenized

['(', 'hard', 'to', 'believe', ',', 'isn', "'", 't', 'it', '?', ')']

In [28]:
# spaCy tokens collected for that same (last) document.
spacy_tokenized

['(', 'hard', 'to', 'believe', ',', 'is', "n't", 'it', '?', ')']

In [26]:
# For each BERT piece of the last document, the indices of the spaCy tokens it aligns to.
bert2spacy

[[0], [1], [2], [3], [4], [5, 6], [6], [6], [7], [8], [9]]

In [6]:
# Replace each list of observed strings with a Counter (string -> frequency),
# in place, for all three accumulators.
for table in (pos_dict, ent_dict, token_bert2spacy):
    for key, observed in table.items():
        table[key] = Counter(observed)

In [10]:
# Per-POS-tag Counter of token texts.
pos_dict

{'DET': Counter({'a': 262,
          'the': 658,
          'an': 43,
          'some': 22,
          'no': 40,
          'this': 50,
          'those': 5,
          'each': 1,
          'these': 12,
          'either': 1,
          'that': 10,
          'every': 6,
          'any': 17,
          'another': 5,
          'all': 15,
          'whose': 3,
          'f': 1,
          'which': 4,
          'both': 2,
          'half': 1,
          'twice': 1,
          'whatever': 1,
          'quite': 1,
          'what': 2,
          'b': 1}),
 'NUM': Counter({'hundred': 1,
          '80': 2,
          '90': 2,
          '100': 3,
          '200': 1,
          '1000': 1,
          'one': 35,
          '2': 22,
          'two': 10,
          '300': 2,
          '250': 1,
          '7': 12,
          '1919': 1,
          '3': 9,
          '4': 6,
          '5': 12,
          '1940': 1,
          '1': 14,
          '000': 2,
          '10': 6,
          '20': 8,
          '13': 2,
          '

In [7]:
# Number of POS tags, entity labels, and distinct BERT word pieces observed.
print(len(pos_dict), len(ent_dict), len(token_bert2spacy))

17 15 4295


In [8]:
# POS tags produced by spaCy on the sample.
pos_dict.keys()

dict_keys(['DET', 'NUM', 'NOUN', 'AUX', 'VERB', 'PUNCT', 'INTJ', 'X', 'PROPN', 'CCONJ', 'PART', 'PRON', 'SCONJ', 'ADP', 'ADJ', 'SYM', 'ADV'])

In [9]:
# Entity labels produced by spaCy on the sample.
ent_dict.keys()

dict_keys(['CARDINAL', 'QUANTITY', 'MONEY', 'ORG', 'PERSON', 'DATE', 'ORDINAL', 'NORP', 'GPE', 'TIME', 'PERCENT', 'LAW', 'PRODUCT', 'LOC', 'EVENT'])

# Read state-word dictionary

In [12]:
# Parse the state -> word dictionary file.
# Layout assumed by this parser (inferred from the code; confirm against the
# writer of the file):
#   * header lines start with "state" and occur only at line indices divisible
#     by 3; whitespace-split tokens 1, 3 and 5 are the state id, its frequency
#     and its frequency without stopwords;
#   * the line right after a header holds " | "-separated "<word> <count>"
#     entries.
state_dict = {}  # state id (str) -> list of (word, count) pairs
state_freq = {}  # state id (str) -> {'freq': float, 'freq_no_sw': float}
with codecs.open('../../local/bertnet_0.0.4.10.4_outputs/bertnet_dev_epoch_17_s2w.txt', encoding='utf-8') as fd:
    lines = fd.readlines()
    for li, l in enumerate(lines):
        # The `li % 3` check presumably keeps a word line whose first word is
        # "state" from being mistaken for a header.
        if(l.startswith('state') and li %3 == 0):
            # header: strip the trailing newline, then read id / freq / freq_no_sw
            l = l[:-1].split()
            state_id = l[1]
            freq = l[3]
            freq_no_sw = l[5]
            state_freq[state_id] = {'freq': float(freq), 'freq_no_sw': float(freq_no_sw)}
            state_dict[state_id] = []
        elif(lines[li - 1].startswith('state')):
            # word line following a header. The last " | " field is skipped,
            # presumably the trailing newline after a final separator -- confirm.
            # NOTE(review): at li == 0 the lookback wraps to lines[-1].
            words = l.split(' | ')
            for w in words[:-1]:
                w = w.split()
                state_dict[state_id].append((w[0], float(w[1])))

## Alignment measure 1: dominant word set
* Definition: the 95% dominant word set $W(s)$ of a state (or of a POS tag) is the set of its most frequent words that together account for 95% of its word occurrences.
* If $W(s) \subseteq W(\text{POS})$, we say that state $s$ aligns with that POS tag.

## Alignment measure 2: word occurrence
* If at least 50% of a state's word occurrences fall in a POS tag's dominant word set, we count this as an alignment.

## Implementation of dominant word set alignment

In [13]:
# Find non-stopword states: states where more than half of the total frequency
# is non-stopword frequency (assuming 'freq_no_sw' is the frequency with
# stopwords excluded).

non_sw_states = [
    s for s, f in state_freq.items()
    # A state with zero total frequency would raise ZeroDivisionError and
    # carries no evidence either way, so it is skipped.
    if f['freq'] != 0 and f['freq_no_sw'] / f['freq'] > 0.5
]
print('%d non stop states' % len(non_sw_states))

1351 non stop states


In [14]:
# First few non-stopword state ids.
non_sw_states[:10]

['171', '476', '710', '254', '1419', '1972', '403', '1488', '1556', '243']

In [15]:
# For each non-stopword state, collect its dominant word set. Walk the state's
# (word, count) pairs in file order (presumably most frequent first -- confirm)
# and keep each word while the running count stays below STATE_MASS of the
# state's total frequency.
# NOTE(review): the markdown definition and the POS computation use 95%, but
# the states use 90%. Kept at 0.9 to preserve the reported results -- confirm.
STATE_MASS = 0.9

state_dominate_set = {}
for s in non_sw_states:
    freq = state_freq[s]['freq']
    wf_cumu = 0
    state_dominate_set[s] = []
    for w, wf in state_dict[s]:
        wf_cumu += wf
        if wf_cumu / freq < STATE_MASS:
            state_dominate_set[s].append(w)

In [16]:
# Remove subword clusters.

def test_subword_cluster(wset):
    """Return True if more than half of ``wset`` looks like subword material.

    A word is subword-like if it is a BERT continuation piece (starts with
    ``##``) or is at most two characters long. Each word is counted once,
    even if it meets both conditions.

    Args:
        wset: sized iterable of word strings (a state's dominant word set).

    Returns:
        bool: True when the number of subword-like words exceeds
        ``len(wset) // 2``. An empty ``wset`` gives False.
    """
    num_subwords = sum(1 for w in wset if w.startswith('##') or len(w) <= 2)
    return num_subwords > len(wset) // 2

# Narrow the non-stopword states down to those whose dominant set is not a
# subword cluster (order of state_dominate_set is preserved).
non_sw_states = [s for s in state_dominate_set
                 if not test_subword_cluster(state_dominate_set[s])]

In [17]:
# Number of states left after removing subword clusters.
len(non_sw_states)

1060

In [18]:
# Map each state's dominant set from BERT word pieces to (lower-cased) spaCy
# tokens, using the alignment statistics gathered above, so it can be compared
# with the spaCy POS dominant sets. Despite the variable name, no GPT-2
# tokenization is involved. Pieces never seen during alignment are kept as-is.

state_dominate_set_gpt_token = {}
for s in non_sw_states:
    mapped = set()
    for w in state_dominate_set[s]:
        if w in token_bert2spacy:
            # every spaCy token this word piece was aligned to at least once
            mapped.update(t.lower() for t in token_bert2spacy[w])
        else:
            mapped.add(w)
    state_dominate_set_gpt_token[s] = mapped

In [19]:
# First few state ids remaining after the subword filter.
non_sw_states[:10]

['476', '710', '254', '1419', '1972', '403', '1488', '1556', '243', '1410']

In [20]:
# Example: the spaCy-mapped dominant word set of state '710'.
state_dominate_set_gpt_token['710']

{'abuse',
 'accident',
 'attack',
 'attacks',
 'bad',
 'badwindow',
 'conflict',
 'crash',
 'crime',
 'crimes',
 'criminals',
 'damage',
 'dangerous',
 'dead',
 'death',
 'deathbed',
 'deaths',
 'destroyed',
 'destruction',
 'die',
 'died',
 'disease',
 'error',
 'errored',
 'errors',
 'evil',
 'failure',
 'fault',
 'genocide',
 'harm',
 'hate',
 'hatred',
 'hell',
 'hellbound',
 'hit',
 'hurt',
 'incident',
 'injury',
 'kill',
 'killed',
 'killfile',
 'killing',
 'lose',
 'loss',
 'lost',
 'massacre',
 'mistake',
 'murder',
 'murdered',
 'murders',
 'pain',
 'penalty',
 'problem',
 'problems',
 'punish',
 'punishment',
 'sick',
 'sin',
 'sinful',
 'sorry',
 'suicidede',
 'trouble',
 'violence',
 'worse',
 'worst',
 'wrong'}

In [21]:
# Dominant word set per POS tag: visit each tag's words from most to least
# frequent and keep (lower-cased) those whose running count stays below 95% of
# the tag's total. The printed size counts kept words before de-duplication,
# so case variants such as "The"/"the" can both be counted.
# NOTE(review): counts are accumulated per cased form before lower-casing, so
# a word whose occurrences are split across casings is ranked per casing --
# confirm this is intended.

pos_dominate_set = {}

for p, counts in pos_dict.items():
    total = float(sum(counts.values()))
    running = 0
    kept = []
    for word, c in counts.most_common():
        running += c
        if running / total < 0.95:
            kept.append(word.lower())
    pos_dominate_set[p] = set(kept)
    print('dominate set size for %s = %d' % (p, len(kept)))

dominate set size for DET = 7
dominate set size for NUM = 142
dominate set size for NOUN = 1585
dominate set size for AUX = 24
dominate set size for VERB = 768
dominate set size for PUNCT = 8
dominate set size for INTJ = 21
dominate set size for X = 27
dominate set size for PROPN = 670
dominate set size for CCONJ = 3
dominate set size for PART = 3
dominate set size for PRON = 31
dominate set size for SCONJ = 16
dominate set size for ADP = 21
dominate set size for ADJ = 482
dominate set size for SYM = 3
dominate set size for ADV = 160


In [22]:
# Size of state '243''s spaCy-mapped dominant set.
len(state_dominate_set_gpt_token['243'])

67

In [23]:
# How many of state '243''s words are also in the ADV dominant set.
len(state_dominate_set_gpt_token['243'].intersection(pos_dominate_set['ADV']))

37

In [24]:
# Compute coverage: a state aligns with a POS tag when more than `thres` of the
# state's spaCy-mapped dominant set lies in that tag's dominant set.
# TODO: change vocabulary based matching to occurance based matching
# NOTE: the earlier question "why no alignment with ADV?" is stale; state 243
# aligns with ADV in the recorded output.

alignment = {}          # POS tag -> list of aligned state ids
thres = 0.5
covered_pos = []
aligned_state_ids = set()
for s, state_words in state_dominate_set_gpt_token.items():
    if not state_words:
        # An empty dominant set would divide by zero below and cannot align.
        continue
    for p, pos_words in pos_dominate_set.items():
        overlap = len(state_words & pos_words) / float(len(state_words))
        if overlap > thres:
            print('state %s aligned with POS %s, overlap %.2f' % (s, p, overlap))
            covered_pos.append(p)
            aligned_state_ids.add(s)
            alignment.setdefault(p, []).append(s)
covered_pos = set(covered_pos)
# Count distinct states: one state may align with more than one POS tag.
aligned_states = len(aligned_state_ids)

state 254 aligned with POS PROPN, overlap 0.55
state 1972 aligned with POS NOUN, overlap 0.61
state 1488 aligned with POS ADJ, overlap 0.54
state 243 aligned with POS ADV, overlap 0.55
state 1410 aligned with POS ADJ, overlap 0.59
state 305 aligned with POS NOUN, overlap 0.60
state 127 aligned with POS NOUN, overlap 0.58
state 568 aligned with POS NOUN, overlap 0.52
state 555 aligned with POS NOUN, overlap 0.61
state 935 aligned with POS ADJ, overlap 0.54
state 66 aligned with POS NOUN, overlap 0.56
state 1514 aligned with POS NOUN, overlap 0.51
state 1291 aligned with POS NOUN, overlap 0.51
state 1579 aligned with POS NOUN, overlap 0.52
state 1461 aligned with POS NOUN, overlap 0.54
state 1064 aligned with POS NOUN, overlap 0.55
state 1286 aligned with POS NOUN, overlap 0.54
state 1450 aligned with POS NOUN, overlap 0.55
state 123 aligned with POS NOUN, overlap 0.52
state 1967 aligned with POS NOUN, overlap 0.62
state 165 aligned with POS NOUN, overlap 0.67
state 1816 aligned with POS

In [25]:
# Report the `aligned_states` value computed in the alignment cell.
print(aligned_states)

58


**That is, with the dominant word set measurement, 237 out of 2000 states align with a pre-defined POS tag.** (The run above printed 58; confirm which figure is current.)

## Implementation of word-occurrence alignment (note: there may be a distribution-shift problem here)