In [None]:
! pip -q install datasets

In [None]:
import os
import pickle
import re
import time
from collections import Counter

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [None]:
# CoNLL-2003 via `datasets.load_dataset`; later cells read each example's
# 'words', 'pos' and 'ner' fields.
from datasets import load_dataset

dataset = load_dataset(path="conll2003")

In [None]:
# ConceptNet tables: the full English dump (cn) and its IsA subset (cn_isa).
# NOTE(review): the first file is read from the working directory and the second
# from data/ — confirm both locations are intended.
cn_isa = pd.read_csv('data/conceptnet_isa.csv')
cn = pd.read_csv('conceptnet_en.csv')

In [None]:
# Distinct IsA objects (the candidate feature labels).
cn_isa.object.unique()

In [None]:
# All ConceptNet subjects; tokens absent from this set get a '<not_in_dict>' tag.
cn_keys = set(cn.subject.values)

In [None]:
len(cn_keys)

In [None]:
# Spot-check a slice of subjects.
cn.subject.values[1800:1823]

In [None]:
# Map every IsA subject to the list of its IsA objects, keeping row order.
# A single groupby pass replaces one boolean-mask scan of cn_isa per subject
# (which was O(subjects x rows)). sort=False keeps keys in first-appearance
# order, matching `cn_isa.subject.unique()`. Rows with a missing subject are
# dropped by groupby; they could never match a token anyway.
word2labels = cn_isa.groupby('subject', sort=False)['object'].agg(list).to_dict()

In [None]:
# Persist the word -> IsA-labels mapping to edges/word2labels.pickle.
# This is the first write into edges/, so make sure the directory exists.
os.makedirs('edges', exist_ok=True)
# `with` guarantees the file is flushed and closed (the bare open() never was).
with open('edges/word2labels.pickle', 'wb') as fh:
    pickle.dump(word2labels, fh)

In [None]:
# Contents of the local data/ directory.
! ls data

In [None]:
# First three training examples: their 'words', 'pos' and 'ner' fields.
dataset['train']['words'][:3], dataset['train']['pos'][:3], dataset['train']['ner'][:3] 

In [None]:
# Scratch: the punctuation list used by the normalization cell, as one string.
''.join(['!', '$', '%', '&', "'", '*', '+', ',', '-', '.', ':', ';', '<', '=', '>', '?', '@', '`'])

In [None]:
%%time
# Normalize CoNLL-2003 tokens and attach per-token feature tags ("extras").
# For each split this fills:
#   data[split]       -> list of (tokens, labels, extras), one entry per sentence
#   vocabulary[split] -> set of lower-cased normalized tokens
# Non-alphabetic tokens are also collected into the special_* lists for inspection.
punctuation = ['!', '$', '%', '&', "'", '*', '+', ',', '-', '.', ':', ';', '<', '=', '>', '?', '@', '`']
PUNCTUATION_CHARS = ''.join(punctuation)
NUMBER_SEPARATORS = set(',.-')

vocabulary = {}
data = {}

special_cases = []      # non-alphabetic tokens carrying an entity label
special_num_cases = []  # numeric tokens (before collapsing to '<NUM>')
special_O_cases = []    # non-alphabetic tokens labelled 'O'


def _is_number(token):
    """True if token is <NUM> placeholders joined only by ',', '.' or '-' characters.

    Checks every character of each separator piece. A substring test against
    ',.-' would reject separators such as '--' or '.,'.
    """
    return '<NUM>' in token and all(
        set(piece) <= NUMBER_SEPARATORS for piece in token.split('<NUM>')
    )


def _token_extras(token, pos):
    """Feature tags for one normalized, non-empty token.

    Tags: the lower-cased POS tag, ConceptNet IsA labels, '<not_in_dict>',
    and the shape flags '<all_caps>', '<accronym>', '<capitalized>'.
    """
    lowered = token.lower()
    extras = ['<' + pos.lower() + '>']
    extras.extend('<' + lbl.lower() + '>' for lbl in word2labels.get(lowered, []))
    if lowered not in cn_keys:
        extras.append('<not_in_dict>')
    # Judge letter case without the '<NUM>' placeholder: its letters are
    # upper-case and would otherwise mark every number as all-caps.
    if token.replace('<NUM>', '').isupper():
        extras.append('<all_caps>')
    # Any dotted token, e.g. C.J or C.J. (the tag spelling is kept as-is).
    if '.' in token:
        extras.append('<accronym>')
    # Upper-case first character followed by lower-case text, e.g. "Paris".
    if token[0].isupper() and token[1:] == token[1:].lower():
        extras.append('<capitalized>')
    return extras


for split in ['train', 'validation', 'test']:
    data[split] = []
    vocabulary[split] = set()

    for doc in tqdm(dataset[split], desc=split.upper()):
        tokens, labels, extras = [], [], []

        for token, pos, label in zip(doc['words'], doc['pos'], doc['ner']):
            if token == pos:
                continue  # this is punctuation (its POS tag equals the token)
            if pos == ',':
                pos = 'NNP'  # NOTE(review): reason for this remap is not shown — confirm

            if token.endswith('='):
                token = token[:-1]
            token = token.lstrip(PUNCTUATION_CHARS)
            token = re.sub(r'\d+', '<NUM>', token).replace('`', "'")
            if not token:
                continue

            if _is_number(token):
                special_num_cases.append((token, label))
                token = '<NUM>'
            elif not token.isalpha():
                bucket = special_O_cases if label == 'O' else special_cases
                bucket.append((token, label))

            vocabulary[split].add(token.lower())
            tokens.append(token)
            labels.append(label)
            extras.append(_token_extras(token, pos))

        data[split].append((tokens, labels, extras))

In [None]:
# First processed training sentence: (tokens, labels, extras).
data['train'][0]

In [None]:
# Every feature tag attached to a training token. The list is sorted so the
# printout is reproducible: a set of str iterates in a hash-randomized order
# that changes between interpreter runs.
extra_vocab = sorted({
    tag
    for example in data['train']
    for token_tags in example[2]
    for tag in token_tags
})
print(extra_vocab)

In [None]:
# print(set([x[0] for x in special_O_cases]))

In [None]:
# Union of the lower-cased vocabularies of all splits.
all_voc = set().union(*vocabulary.values())
print(len(all_voc))

In [None]:
# Vocabulary size per split (train, validation, test order).
[len(split_vocab) for split_vocab in vocabulary.values()]

In [None]:
# Token frequencies over the training split (case preserved).
all_words = [token for example in data['train'] for token in example[0]]
train_counter = Counter(all_words)

In [None]:
# NER label frequencies over the training split, matching the variable name
# and the train_counter cell above.
train_labels_counter = Counter(label for example in data['train'] for label in example[1])
train_labels_counter

In [None]:
# Bucket the labelled non-alphabetic tokens by surface shape, for inspection.
intials = []     # upper-case dotted tokens of length <= 2, e.g. "J."
accronyms = []   # upper-case dotted tokens longer than 2, e.g. "U.S."
whatelse = []    # everything not matched below
hyphenated = []  # hyphenated tokens whose first or second part is lower-case


def _dotted_caps(term):
    """Upper-case token that contains a dot.

    A '.' count always equals len(term.split('.')) - 1, so the dot test
    reduces to '.' in term.
    """
    return term == term.upper() and '.' in term


def _has_lowercase_side(term):
    """True if the first or second '-'-separated part is all lower-case."""
    head, tail = term.split('-')[:2]
    return head == head.lower() or tail == tail.lower()


for term, label in special_cases:
    if _dotted_caps(term):
        bucket = intials if len(term) <= 2 else accronyms
    elif '-' in term and _has_lowercase_side(term):
        bucket = hyphenated
    else:
        bucket = whatelse
    bucket.append((term, label))

print(len(whatelse))

# Build the edge lists

In [None]:
# Node vocabulary: training words, then the '<span>' placeholder (used for
# out-of-vocabulary words and sentence-boundary padding), then feature tags.
# word2id is later built by enumerating this list, so position == node id.
final_vocab = sorted(vocabulary['train']) + ['<span>'] + sorted(extra_vocab)
# A word colliding with a tag or with '<span>' would leave an id unused in
# word2id and silently skew the graph; fail loudly instead.
assert len(set(final_vocab)) == len(final_vocab), 'duplicate entries in final_vocab'
len(final_vocab)

In [None]:
# Feature tags collected from the training split.
extra_vocab

In [None]:
# Head of the node vocabulary.
final_vocab[:150]

In [None]:
# Node id of each vocabulary entry = its position in final_vocab.
word2id = dict(zip(final_vocab, range(len(final_vocab))))

In [None]:
# Spot-check a single id (raises KeyError if 'ismail' is not a training word).
word2id['ismail']

In [None]:
len(word2id)

In [None]:
# Persist the node vocabulary; list position is the node id used in the edge files.
# `with` guarantees the file is flushed and closed.
with open('edges/vocabulary.pickle', 'wb') as fh:
    pickle.dump(final_vocab, fh)

In [None]:
# Gather neighbours for every vocabulary entry across all splits:
#   before_edges[term] -> words that follow term (term comes before them)
#   after_edges[term]  -> words that precede term
#   isa_edges[term]    -> the feature tags attached to term's occurrences
# Words outside final_vocab (validation/test-only words) collapse to '<span>',
# which also pads contexts that run past a sentence boundary.
before_edges = {w: [] for w in final_vocab}
after_edges  = {w: [] for w in final_vocab}
isa_edges    = {w: [] for w in final_vocab}
vocab_dict   = {w: [] for w in final_vocab}  # used for membership tests

window_size = 2


def _vocab_or_span(word):
    """Return word if it is a node in final_vocab, else the '<span>' placeholder."""
    return word if word in vocab_dict else '<span>'


for split in data:
    for example in tqdm(data[split], desc=split.upper()):
        text = [w.lower() for w in example[0]]
        for i, word in enumerate(text):
            term = _vocab_or_span(word)
            left_context  = text[max(i - window_size, 0):i] + ([] if i >= window_size else ['<span>'])
            right_context = text[i + 1:i + 1 + window_size] + ([] if i + window_size < len(text) else ['<span>'])
            left_context  = [_vocab_or_span(w) for w in left_context]
            right_context = [_vocab_or_span(w) for w in right_context]
            # extra_vocab (and hence final_vocab) only holds tags seen in train;
            # a tag that first appears in validation/test has no node id and
            # would make the later word2id lookups raise KeyError, so drop it.
            isa_context = [tag for tag in example[2][i] if tag in vocab_dict]

            before_edges[term].extend(right_context)
            after_edges[term].extend(left_context)
            isa_edges[term].extend(isa_context)



In [None]:
# The 20 vocabulary entries with the most distinct feature tags.
sorted(isa_edges, key=lambda term: len(set(isa_edges[term])), reverse=True)[:20]

In [None]:
def _id_pairs(neighbours):
    """(source_id, target_id) for every stored neighbour, in vocabulary order."""
    return [
        (word2id[src], word2id[dst])
        for src in vocab_dict
        for dst in neighbours[src]
    ]


# Translate the word-level neighbour lists into node-id pairs (duplicates kept).
edge_list_before = _id_pairs(before_edges)
edge_list_after  = _id_pairs(after_edges)
edge_list_isa    = _id_pairs(isa_edges)

In [None]:
# Raw edge counts per relation (duplicate pairs included).
len(edge_list_before), len(edge_list_after), len(edge_list_isa)

In [None]:
# Distinct edge counts per relation.
len(set(edge_list_before)), len(set(edge_list_after)), len(set(edge_list_isa))

In [None]:
before_set = set(edge_list_before)
after_set = set(edge_list_after)
isa_set = set(edge_list_isa)

len_all_edges = len(before_set) + len(after_set) + len(isa_set)
# Distinct "context" pairs: union of before and after edges, sorted.
edges_list_unique = sorted(before_set | after_set)
len_unique_context_edges = len(edges_list_unique)
# The third value is the number of pairs that are both a before and an after edge.
len_all_edges, len_unique_context_edges, len_all_edges - len_unique_context_edges - len(isa_set)

In [None]:
# Every distinct edge of any relation, sorted by (source_id, target_id).
edges_list_all = sorted(set().union(edge_list_before, edge_list_after, edge_list_isa))

In [None]:
len(edges_list_all)

In [None]:
# List the working directory.
! ls

In [None]:
# Most frequent (source_id, target_id) "before" pairs.
Counter(edge_list_before).most_common(10)

In [None]:
# NOTE(review): hard-coded node id, presumably read off the output above; it
# points to a different word whenever the vocabulary changes.
final_vocab[7470]

In [None]:
final_vocab[:100]

In [None]:
# Output name -> edge list; each name becomes a file under edges/.
edge_lists = dict(
    before_edges=edge_list_before,
    after_edges=edge_list_after,
    isa_edges=edge_list_isa,
    context_edges=edges_list_unique,
    all_edges=edges_list_all,
)

In [None]:
# Pickle every edge list to edges/<name>.pickle. The `with` block closes each
# file; the bare open() in a loop left every handle open until garbage collection.
for filename, edges in edge_lists.items():
    with open('edges/' + filename + '.pickle', 'wb') as fh:
        pickle.dump(edges, fh)

In [None]:
# Plain-text edge lists, one "source_id target_id" pair per line.
# Lines end in '\n'. The previous bare '\r' terminator (classic-Mac line
# endings) makes line-oriented readers see the whole file as one line.
# newline='\n' disables platform newline translation so the output is the
# same on every OS.
for filename, edges in edge_lists.items():
    with open('edges/' + filename + '.edgelist', 'w', newline='\n') as f:
        f.writelines(f'{s} {t}\n' for s, t in edges)