In [1]:
import re
import time
import string
import pickle
import numpy as np
import pandas as pd

from datasets import load_dataset

from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [3]:
dataset = pickle.load(open('conll_graph_all.pickle', 'rb'))
print(', '.join([split + f' : {len(dataset[split])}' for split in dataset]))

train : 178610, validation : 44900, test : 40760


In [8]:
word2labels = {}
for split in dataset:
    for doc in dataset[split]:
        if doc['word'][0] not in word2labels:
            word2labels[doc['word'][0]] = []
        word2labels[doc['word'][0]].append(doc['label'])

In [15]:
counter = 0
anticounter = 0
for w in word2labels:
    if len(set(word2labels[w])) > 1:
        print(w, Counter(word2labels[w]).most_common())
        counter += 1
    else:
        anticounter += 1

german [('MISC', 139), ('ORG', 2), ('PER', 1)]
british [('MISC', 105), ('ORG', 21), ('LOC', 5), ('O', 1)]
blackburn [('ORG', 16), ('PER', 1)]
brussels [('LOC', 44), ('MISC', 2), ('ORG', 2)]
<UNK> [('O', 30339), ('PER', 1146), ('ORG', 695), ('MISC', 372), ('LOC', 188)]
the [('O', 12228), ('ORG', 34), ('LOC', 30), ('MISC', 17), ('PER', 1)]
european [('MISC', 107), ('ORG', 37)]
commission [('ORG', 51), ('O', 24), ('MISC', 4)]
said [('O', 2690), ('PER', 4)]
on [('O', 3108), ('MISC', 6), ('ORG', 1)]
shun [('O', 1), ('MISC', 1)]
disease [('O', 39), ('MISC', 2)]
can [('O', 100), ('LOC', 1)]
germany [('LOC', 240), ('ORG', 2)]
s [('O', 2277), ('ORG', 54), ('LOC', 7), ('MISC', 4), ('PER', 1)]
union [('ORG', 70), ('O', 33), ('LOC', 7)]
veterinary [('O', 8), ('ORG', 1)]
committee [('O', 32), ('ORG', 9), ('MISC', 2)]
wednesday [('O', 287), ('ORG', 13)]
countries [('O', 54), ('MISC', 1)]
than [('O', 219), ('PER', 1)]
britain [('LOC', 167), ('ORG', 1)]
do [('O', 144), ('PER', 3), ('ORG', 1)]
n't [('O

blue [('O', 9), ('ORG', 5)]
brewers [('ORG', 7), ('O', 6)]
jose [('PER', 32), ('LOC', 8), ('ORG', 2)]
eindhoven [('ORG', 15), ('LOC', 1)]
arthur [('PER', 3), ('ORG', 2), ('LOC', 1)]
numan [('PER', 2), ('O', 1)]
defence [('O', 39), ('ORG', 4)]
breda [('ORG', 10), ('LOC', 1)]
swiss [('MISC', 42), ('ORG', 2), ('O', 1)]
hamburg [('ORG', 12), ('LOC', 7)]
antonio [('PER', 11), ('ORG', 3), ('LOC', 1)]
stuttgart [('ORG', 10), ('LOC', 6)]
lausanne [('LOC', 1), ('ORG', 1)]
milan [('ORG', 27), ('LOC', 11), ('PER', 2)]
house [('O', 44), ('ORG', 17), ('LOC', 17)]
grand [('MISC', 46), ('O', 20), ('LOC', 1)]
glass [('O', 5), ('PER', 1)]
carl [('PER', 11), ('MISC', 2)]
max [('PER', 11), ('O', 1)]
miles [('O', 72), ('PER', 4)]
des [('PER', 5), ('LOC', 5)]
miguel [('PER', 14), ('ORG', 2)]
williams [('PER', 17), ('ORG', 8)]
robert [('PER', 41), ('ORG', 2)]
uefa [('MISC', 10), ('ORG', 5)]
formula [('O', 4), ('ORG', 4), ('MISC', 1)]
yellow [('O', 20), ('MISC', 1)]
greece [('LOC', 22), ('ORG', 1)]
police [(

pace [('O', 11), ('PER', 8)]
ferrari [('ORG', 6), ('MISC', 2)]
irvine [('PER', 3), ('LOC', 2)]
bradford [('ORG', 17), ('PER', 1)]
wigan [('ORG', 15), ('LOC', 1)]
helens [('ORG', 10), ('LOC', 1)]
sheffield [('ORG', 21), ('MISC', 5), ('PER', 3), ('LOC', 2)]
leeds [('ORG', 23), ('LOC', 2), ('MISC', 1)]
speak [('O', 17), ('PER', 2)]
dundee [('ORG', 14), ('PER', 2)]
morton [('ORG', 6), ('PER', 1)]
johnstone [('ORG', 6), ('PER', 1)]
hamilton [('ORG', 5), ('PER', 3), ('LOC', 1)]
inverness [('ORG', 6), ('LOC', 1)]
ross [('PER', 9), ('ORG', 6)]
newcastle [('ORG', 18), ('LOC', 1)]
villa [('ORG', 8), ('PER', 3), ('O', 2)]
derby [('ORG', 7), ('O', 1)]
ham [('ORG', 4), ('O', 1)]
southampton [('ORG', 9), ('LOC', 1)]
birmingham [('ORG', 12), ('LOC', 6)]
stoke [('ORG', 10), ('O', 1)]
vale [('ORG', 14), ('PER', 1)]
bristol [('ORG', 19), ('LOC', 3)]
chesterfield [('ORG', 8), ('LOC', 3)]
preston [('ORG', 8), ('PER', 5), ('LOC', 1)]
hull [('ORG', 8), ('O', 2)]
doncaster [('ORG', 7), ('MISC', 1)]
beograd [

In [17]:
counter

1698

In [12]:
word2categories = pickle.load(open('word2categories.pickle', 'rb'))

In [33]:
cn = pd.read_csv('../conceptnet_en.csv')
cn_words = set([str(w).replace('_', '-') for w in cn.subject.unique().tolist()])

In [10]:
conll_dataset = load_dataset("conll2003")

Reusing dataset conll2003 (/opt/tmp/huggingface/datasets/conll2003/conll2003/1.0.0/26b70ce2b0f32cb35a27151dbfa2dbe88c82bcdaf8f29433bcdc612a9b314e83)


In [18]:
all_tokens = []
all_pos = []
all_chunks = []

for split in conll_dataset:
    for doc in tqdm(conll_dataset[split], desc=f'Loading split {split}'):
            for i, (token, pos, chunk, label) in enumerate(zip(doc['words'], doc['pos'], doc['chunk'], doc['ner'])):
                all_tokens.append(token)
                all_pos.append(pos)
                all_chunks.append(chunk)

HBox(children=(FloatProgress(value=0.0, description='Loading split train', max=14041.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Loading split validation', max=3250.0, style=ProgressStyl…




HBox(children=(FloatProgress(value=0.0, description='Loading split test', max=3453.0, style=ProgressStyle(desc…




In [253]:
def replace_nums(s):
    numbers = '0123456789'
    nums = []
    for c in s:
        if c in numbers:
            nums.append(c)
        else:
            break
    if len(nums) < len(s):
        return '<NUM>' + s[len(nums):]
    else:
        return s

In [254]:
replace_nums('1990s')

'<NUM>s'

In [303]:
train_words = []
all_words = []
for split in conll_dataset:
    for doc in tqdm(conll_dataset[split], desc=f'Loading split ' + split):
        if split == 'train':
            train_words.extend(doc['words'])
        all_words.extend(doc['words'])

HBox(children=(FloatProgress(value=0.0, description='Loading split train', max=14041.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Loading split validation', max=3250.0, style=ProgressStyl…




HBox(children=(FloatProgress(value=0.0, description='Loading split test', max=3453.0, style=ProgressStyle(desc…




In [304]:
train_words_counter = Counter([w.lower() for w in train_words])
all_words_counter = Counter([w.lower() for w in all_words])

In [317]:
[w for w in all_words_counter if w not in train_words_counter][:100]

['296',
 'discard',
 '213',
 'bundle',
 '372',
 '37-run',
 'scuttled',
 'dumped',
 'butcher',
 '429',
 '234',
 '123',
 "o'gorman",
 '471',
 '233',
 'frustration',
 'dismiss',
 '214',
 'gritty',
 'ex-england',
 'mccague',
 'stumps',
 '4-38',
 'dale',
 'blenkiron',
 '4-43',
 '4-55',
 '108-3',
 '429-7',
 'kersey',
 'ratcliffe',
 'bicknell',
 '4-37',
 '197-8',
 'hegg',
 '109-5',
 '133-5',
 "t.o'gorman",
 '6-82',
 '185-6',
 'ashes',
 'intinerary',
 'six-test',
 'itinerary',
 'duke',
 'norfolk',
 'arundel',
 '27-29',
 '5-9',
 '14-16',
 '19-23',
 '25-27',
 '28-30',
 '19-21',
 '24-28',
 '21-25',
 'costliest',
 'footballer',
 '23.4',
 'platt',
 'captaincy',
 'mates',
 'teddy',
 'sheringham',
 '92-90',
 '47-47',
 'cosmin',
 'contra',
 'mihai',
 'tararache',
 'danius',
 'gleveckas',
 '13rd',
 'bottles',
 'gorokhovsky',
 'boxing',
 'duran',
 'sands',
 'panamanian',
 'age-defying',
 'sustain',
 'little-known',
 'ariel',
 'middleweight',
 'non-title',
 'bout',
 'boxer',
 'classes',
 'lightweight',
 

In [307]:
graph_dataset = {}

non_alpha = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '`',
             '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '@', '[', ']']

vocabulary = {'word':set(), 'chunk': set(), 'pos':set(), 'classes':set(), 'extra':set()}
labels = set()
ignored = list()

for split in conll_dataset:
    graph_dataset[split] = []
    print(split.upper())
    for doc in tqdm(conll_dataset[split], desc=f'Loading split {split}'):
        graph = []
        doc = zip(doc['words'], doc['pos'], doc['chunk'], doc['ner'])
        for i, (word, pos, chunk, label) in enumerate(doc):

            if pos in [":", "''", ')', '.', '"', '(', ',']:
                continue
            
            for pos_prefix in ['WP', 'NN', 'VB', 'PR', 'JJ']:
                pos = pos_prefix if pos.startswith(pos_prefix) else pos
                    
            pos = '<' + pos.upper() + '>'
            chunk = '<' + chunk.split('-')[-1].upper() + '>'
            gt_label = label
            label = gt_label.split('-')[-1]

            surface = word
            word = surface.lower().replace('`', "'")
            
            if word.endswith('='):
                word = word[:-1]

            while word and word[0] in "!$%&'*+,-.:;<=>?@`":
                word = word[1:]

            
            if all([c in non_alpha for c in word]):
                word = '<NUM>'

            if word not in cn_words and \
               word not in word2categories and \
               all_words_counter[word] < 3:
                word = '<UNK>'
                ignored.append(word)
                        
            if word[0] in '0123456789':
                word = replace_nums(word)
            
            extra = []
            if word.count('.') > 0 and (word.count('.') + 1) == len(word.split('.')): # C.J or C.J.
                extra.append('<ACRONYM>')
            if surface == surface.upper():
                extra.append('<ALL CAPS>')
            if surface[0] == surface[0].upper() and surface[1:] == surface[1:].lower(): 
                extra.append('<CAPITALIZED>')
            
            classes = []
            # add classes only when the word is uppercased
            if surface[0] == surface[0].upper() and word in word2categories:
                classes = ['<'+l.upper()+'>' for l in word2categories[word]]
            
            graph.append({'word': [word], 
                          'label': label,
                          'gt_label': gt_label,
                          'surface': surface, 
                          'pos': [pos], 
                          'chunk': [chunk], 
                          'classes': classes,
                          'extra': extra})
            
            labels.add(label)
            vocabulary['word'].add(word)
            vocabulary['pos'].add(pos)
            vocabulary['chunk'].add(chunk)
            for tag in classes: vocabulary['classes'].add(tag)
            for tag in extra: vocabulary['extra'].add(tag)
        
        for i, node in enumerate(graph):
            graph[i]['left_context'] = [n['word'][0] for n in graph[:i]]
            graph[i]['right_context'] = [n['word'][0] for n in graph[i+1:]]
        
        graph_dataset[split].extend(graph)

TRAIN


HBox(children=(FloatProgress(value=0.0, description='Loading split train', max=14041.0, style=ProgressStyle(de…


VALIDATION


HBox(children=(FloatProgress(value=0.0, description='Loading split validation', max=3250.0, style=ProgressStyl…


TEST


HBox(children=(FloatProgress(value=0.0, description='Loading split test', max=3453.0, style=ProgressStyle(desc…




In [308]:
graph_dataset['train'][2]

{'word': ['german'],
 'label': 'MISC',
 'gt_label': 'B-MISC',
 'surface': 'German',
 'pos': ['<JJ>'],
 'chunk': ['<NP>'],
 'classes': ['<GEOREGION>', '<NAME>', '<GIVEN NAME>', '<FAMILY NAME>'],
 'extra': ['<CAPITALIZED>'],
 'left_context': ['eu', 'rejects'],
 'right_context': ['call', 'to', 'boycott', 'british', 'lamb']}

In [309]:
len(vocabulary['word'])

18993

In [310]:
len(ignored)

32740

In [311]:
len(vocabulary['word'])

18993

In [312]:
pickle.dump(graph_dataset, open('conll_graph_all.pickle', 'wb'))

In [313]:
for key in vocabulary:
    vocabulary[key] = sorted(vocabulary[key])

In [314]:
pickle.dump(vocabulary, open('vocabulary_all.pickle', 'wb'))

In [315]:
sorted(labels)

['LOC', 'MISC', 'O', 'ORG', 'PER']

In [316]:
pickle.dump(sorted(labels), open('labels.pickle', 'wb'))

In [302]:
len(set(rare_words['test']).intersection(set(rare_words['train'])))

291

In [19]:
len(set(all_tokens)), len(set(all_pos)), len(set(all_chunks))

(30289, 45, 21)

In [21]:
print(set(all_pos))

{'WRB', 'NNPS', 'RBS', 'TO', 'FW', "''", ':', '$', 'JJ', 'WP$', 'IN', 'VBN', 'DT', 'VB', 'UH', 'PRP$', ')', '.', 'POS', 'VBP', 'PRP', 'RP', 'NNP', 'VBD', 'WP', 'CD', 'LS', 'MD', '"', 'JJS', '(', 'JJR', 'WDT', 'VBZ', 'NN|SYM', 'PDT', 'EX', 'NNS', 'NN', 'SYM', 'RBR', 'VBG', 'RB', 'CC', ','}


In [126]:
Counter(all_pos).most_common()

[('NNP', 51545),
 ('NN', 34856),
 ('CD', 29962),
 ('IN', 28059),
 ('DT', 19773),
 ('JJ', 17267),
 ('NNS', 14580),
 ('VBD', 12222),
 ('.', 10898),
 (',', 10877),
 ('VB', 6304),
 ('VBN', 5964),
 ('RB', 5852),
 ('CC', 5350),
 ('TO', 5193),
 ('PRP', 4630),
 ('(', 4233),
 (')', 4232),
 ('VBG', 3769),
 (':', 3609),
 ('VBZ', 3439),
 ('"', 3239),
 ('POS', 2323),
 ('PRP$', 2238),
 ('VBP', 2132),
 ('MD', 1767),
 ('NNPS', 1010),
 ('RP', 784),
 ('WDT', 769),
 ('WP', 769),
 ('SYM', 642),
 ('$', 622),
 ('JJR', 579),
 ('WRB', 551),
 ('JJS', 388),
 ('RBR', 259),
 ('FW', 228),
 ('EX', 210),
 ('RBS', 62),
 ("''", 60),
 ('PDT', 47),
 ('UH', 42),
 ('WP$', 41),
 ('LS', 37),
 ('NN|SYM', 5)]

In [22]:
print(set(all_chunks))

{'B-PP', 'I-CONJP', 'I-ADVP', 'I-NP', 'B-NP', 'B-CONJP', 'O', 'B-ADJP', 'B-ADVP', 'B-PRT', 'I-PRT', 'I-ADJP', 'I-PP', 'I-SBAR', 'B-SBAR', 'I-INTJ', 'B-VP', 'B-LST', 'I-VP', 'I-LST', 'B-INTJ'}


In [37]:
raw_voc = set([w.lower() for w in all_tokens])
len(raw_voc)

26869

In [54]:
oov_words = []
non_alpha = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '`',
             '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '@', '[', ']']
probably_numbers = []

for word in tqdm(raw_voc):
    if word not in cn_words and word not in word2categories:
        if all([c in non_alpha for c in word]):
            probably_numbers.append(word)
        else:
            oov_words.append(word)

HBox(children=(FloatProgress(value=0.0, max=26869.0), HTML(value='')))




In [55]:
len(oov_words)

3544

In [56]:
probably_numbers

['52.82',
 '9:30',
 '1:46.18',
 '52.05',
 '5.79',
 '91.36',
 '.522',
 '7-178',
 '24.05',
 '11.55',
 '482.13',
 '.413',
 '2-70',
 '274.4',
 '138.0-149.0',
 '1.59',
 '0.6',
 '42-36',
 '60.92',
 '38:43.253',
 '5-160',
 '8188.2',
 '1,002',
 '407,748',
 '9-0-49-1',
 '1/8',
 '5.60',
 '11,900-12,100',
 '30,300-30,400',
 '2006',
 '12.211',
 '13:06.12',
 '115.0',
 '10/01',
 '102',
 '5452',
 '0.25-1.00',
 '22.58',
 '8,000',
 '439.9',
 '003',
 '21-11',
 '3.03',
 '3-0-11-1',
 '11.3',
 '302',
 '-----',
 '13:06.65',
 '212-859-1650',
 '113.0',
 '4-77',
 '0.79',
 '95/96',
 '19996/97',
 '55.13',
 '1:45.27',
 '4-146',
 '20,167',
 '.359',
 '4.85',
 '3,390,000',
 '1.97',
 '19999.777',
 '30,700-30,800',
 '12:01',
 '33-1',
 '1.91',
 '662',
 '2173',
 '0.96',
 '10-3-19-0',
 '49.31',
 '233,713',
 '1.03',
 '1:19.44',
 '7.73',
 '2.2-0-13-0',
 '47,600',
 '4.50',
 '208,978',
 '7,138',
 ',',
 '2.40',
 '+3.7',
 '746,000',
 '610,000',
 '0.75',
 '2-7',
 '7630',
 '429-7',
 '992,860',
 '15:07.85',
 '8.15',
 '192',
 '7-1

In [53]:
print(sorted(set([c for w in oov_words for c in w])))

['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '@', '[', ']', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [None]:
print(sorted(set([c for w in oov_words for c in w])))

In [79]:
not_accounted = 0
for w in sorted(oov_words):
    if words_counter[w] > 10:
        print(w, '\t', words_counter[w])
    else:
        not_accounted +=1

'll 	 24
'm 	 41
're 	 40
's 	 2339
've 	 29
a$ 	 20
a.m. 	 24
agassi 	 11
atletico 	 21
auxerre 	 14
banharn 	 13
banisadr 	 11
benetton 	 16
bistrita 	 12
brunswijk 	 15
c$ 	 29
cofinec 	 18
corp. 	 18
dinamo 	 11
f.c. 	 14
feyenoord 	 20
first-round 	 12
five-year 	 12
four-day 	 18
gencor 	 13
graafschap 	 11
guingamp 	 11
haitai 	 13
hanwha 	 13
hapoel 	 30
inzamam-ul-haq 	 14
ivac 	 11
jansher 	 15
jayasuriya 	 14
juppe 	 11
juventus 	 12
kdp 	 29
maskhadov 	 17
masterkova 	 11
newmont 	 15
nymex 	 14
philippoussis 	 18
rkc 	 11
rtrs 	 14
ruutel 	 11
schalke 	 16
shr 	 13
ssangbangwool 	 14
st. 	 21
suu 	 16
three-run 	 14
tranmere 	 11
two-run 	 17
u.n. 	 71
u.s. 	 558
under-21 	 16
vitesse 	 11
volendam 	 14
