In [None]:
import re
import time
import string
import pickle
import numpy as np
import pandas as pd

from datasets import load_dataset

from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Load the previously built token-node dataset (a dict keyed by split name) and
# report how many nodes each split holds.
# NOTE(review): 'conll_graph_all.pickle' is written further down this notebook;
# on a fresh checkout this cell fails until that cell has been run once.
# Only unpickle files you produced yourself (pickle can execute arbitrary code).
with open('conll_graph_all.pickle', 'rb') as fh:
    dataset = pickle.load(fh)
print(', '.join(f'{split} : {len(dataset[split])}' for split in dataset))

In [None]:
# For every normalized word (first entry of each node's 'word' list), collect
# the coarse label of every node it appears in, across all splits.
word2labels = {}
for split_nodes in dataset.values():
    for node in split_nodes:
        word2labels.setdefault(node['word'][0], []).append(node['label'])

In [None]:
word2labels

In [None]:
counter = 0
anticounter = 0
for w in word2labels:
    if len(set(word2labels[w])) > 1:
        print(w, Counter(word2labels[w]).most_common())
        counter += 1
    else:
        anticounter += 1

In [None]:
counter

In [None]:
# Word -> iterable of category names; used below to emit '<CATEGORY>' class tags.
# (Schema presumed from how it is indexed later — TODO confirm.)
# Only unpickle files you trust: pickle can execute arbitrary code.
with open('word2categories.pickle', 'rb') as fh:
    word2categories = pickle.load(fh)

In [None]:
cn = pd.read_csv('../conceptnet_en.csv')
# Lookup set of distinct ConceptNet subjects: each is coerced with str() and has
# '_' rewritten as '-' so it can be compared with the normalized CoNLL tokens.
cn_words = {str(subject).replace('_', '-') for subject in cn['subject'].unique().tolist()}

In [None]:
# Raw CoNLL-2003 splits; later cells iterate each split and read the per-doc
# 'words', 'pos', 'chunk' and 'ner' fields.
conll_dataset = load_dataset("conll2003")

In [None]:
# Flat token-level streams of surface tokens, POS tags and chunk tags over every
# split, used by the vocabulary-statistics cells at the end of the notebook.
all_tokens = []
all_pos = []
all_chunks = []

for split in conll_dataset:
    for doc in tqdm(conll_dataset[split], desc=f'Loading split {split}'):
        # zip() over all four columns keeps the truncate-to-shortest behaviour;
        # the NER column itself is not needed here.
        for token, pos, chunk, _ner in zip(doc['words'], doc['pos'], doc['chunk'], doc['ner']):
            all_tokens.append(token)
            all_pos.append(pos)
            all_chunks.append(chunk)

In [None]:
def replace_nums(s):
    """Collapse a leading run of ASCII digits into a ``<NUM>`` placeholder.

    Examples: ``'1990s'`` -> ``'<NUM>s'``, ``'3rd'`` -> ``'<NUM>rd'``.

    Strings that do not start with a digit are returned unchanged (previously
    ``'abc'`` wrongly became ``'<NUM>abc'``). Strings made only of digits, and
    the empty string, are also returned unchanged, as before.

    Args:
        s: token to normalize.

    Returns:
        ``'<NUM>'`` followed by the non-digit remainder, or ``s`` itself.
    """
    rest = s.lstrip('0123456789')
    # rest == s: no leading digits; not rest: nothing but digits.
    if rest == s or not rest:
        return s
    return '<NUM>' + rest

In [None]:
# Sanity check: the leading digit run is collapsed, giving '<NUM>s'.
replace_nums('1990s')

In [None]:
# Surface tokens of the train split alone, and of every split combined.
train_words = []
all_words = []
for split in conll_dataset:
    is_train = split == 'train'
    for doc in tqdm(conll_dataset[split], desc=f'Loading split {split}'):
        tokens = doc['words']
        all_words.extend(tokens)
        if is_train:
            train_words.extend(tokens)

In [None]:
# Case-insensitive token frequencies: train split only vs. all splits.
train_words_counter = Counter(map(str.lower, train_words))
all_words_counter = Counter(map(str.lower, all_words))

In [None]:
# Up to 100 lowercased words that occur in some split but never in train.
[w for w in all_words_counter if w not in train_words_counter][:100]

In [None]:
graph_dataset = {}

# Characters with no lexical content: a token made only of these becomes <NUM>.
non_alpha = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '`',
             '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '@', '[', ']']

# POS tags whose tokens are dropped from the graph entirely.
SKIPPED_POS = [":", "''", ')', '.', '"', '(', ',']
# POS families collapsed to their prefix (e.g. NNS/NNP -> NN, VBD -> VB, PRP$ -> PR).
POS_PREFIXES = ['WP', 'NN', 'VB', 'PR', 'JJ']
# Leading characters stripped from the lowercased token before lookup.
LEADING_PUNCT = "!$%&'*+,-.:;<=>?@`"
# Tokens unknown to ConceptNet and word2categories and seen fewer times than this become <UNK>.
MIN_WORD_COUNT = 3
NUM_TOKEN = '<NUM>'
UNK_TOKEN = '<UNK>'

vocabulary = {'word':set(), 'chunk': set(), 'pos':set(), 'classes':set(), 'extra':set()}
labels = set()
ignored = list()  # normalized form of every token occurrence replaced by <UNK>


def _normalize_pos(pos):
    """Collapse a POS tag to its family prefix (first match wins) and wrap it as '<TAG>'."""
    pos = next((prefix for prefix in POS_PREFIXES if pos.startswith(prefix)), pos)
    return '<' + pos.upper() + '>'


def _normalize_word(surface, ignored):
    """Map a surface token to its vocabulary entry.

    Lowercases, turns '`' into "'", drops one trailing '=', strips leading
    punctuation, maps digit/punctuation-only (or emptied) tokens to <NUM>, maps
    rare tokens unknown to ConceptNet / word2categories to <UNK> (appending the
    pre-<UNK> form to ``ignored``), and collapses a leading digit run with
    ``replace_nums``.
    """
    word = surface.lower().replace('`', "'")
    if word.endswith('='):
        word = word[:-1]
    word = word.lstrip(LEADING_PUNCT)

    # Also catches the empty string (all() of nothing is True).
    if all(c in non_alpha for c in word):
        # Return immediately: '<NUM>' never occurs in all_words_counter, so the
        # rarity filter below used to overwrite every <NUM> with <UNK>.
        return NUM_TOKEN

    if (word not in cn_words
            and word not in word2categories
            and all_words_counter[word] < MIN_WORD_COUNT):
        ignored.append(word)  # record the real word, not the placeholder
        return UNK_TOKEN

    if word[0] in '0123456789':
        word = replace_nums(word)
    return word


def _extra_tags(surface, word):
    """Orthographic feature tags derived from the surface form and normalized word."""
    extra = []
    # NOTE(review): this fires for any word containing '.', not only acronyms
    # such as 'c.j' / 'c.j.'; tighten if only single-letter segments are intended.
    if '.' in word:
        extra.append('<ACRONYM>')
    if surface == surface.upper():
        extra.append('<ALL CAPS>')
    if surface[0] == surface[0].upper() and surface[1:] == surface[1:].lower():
        extra.append('<CAPITALIZED>')
    return extra


for split in conll_dataset:
    graph_dataset[split] = []
    print(split.upper())
    for doc in tqdm(conll_dataset[split], desc=f'Loading split {split}'):
        graph = []
        tokens = zip(doc['words'], doc['pos'], doc['chunk'], doc['ner'])
        for surface, pos, chunk, gt_label in tokens:
            if pos in SKIPPED_POS:
                continue

            pos = _normalize_pos(pos)
            chunk = '<' + chunk.split('-')[-1].upper() + '>'
            label = gt_label.split('-')[-1]  # text after the last '-', e.g. 'B-PER' -> 'PER'
            word = _normalize_word(surface, ignored)
            extra = _extra_tags(surface, word)

            # Category tags only when the surface form's first character is not
            # lowercase (capitalized words, but also digits/punctuation).
            classes = []
            if surface[0] == surface[0].upper() and word in word2categories:
                classes = ['<' + category.upper() + '>' for category in word2categories[word]]

            graph.append({'word': [word],
                          'label': label,
                          'gt_label': gt_label,
                          'surface': surface,
                          'pos': [pos],
                          'chunk': [chunk],
                          'classes': classes,
                          'extra': extra})

            labels.add(label)
            vocabulary['word'].add(word)
            vocabulary['pos'].add(pos)
            vocabulary['chunk'].add(chunk)
            vocabulary['classes'].update(classes)
            vocabulary['extra'].update(extra)

        # Each node sees the normalized words before and after it in the same doc.
        doc_words = [node['word'][0] for node in graph]
        for i, node in enumerate(graph):
            node['left_context'] = doc_words[:i]
            node['right_context'] = doc_words[i + 1:]

        graph_dataset[split].extend(graph)

In [None]:
# Inspect one token node from the training split.
graph_dataset['train'][2]

In [None]:
# Size of the normalized word vocabulary.
len(vocabulary['word'])

In [None]:
# Number of token occurrences replaced by '<UNK>' (one entry per occurrence).
len(ignored)

In [None]:
# NOTE(review): identical to the vocabulary-size cell two cells above; can be dropped.
len(vocabulary['word'])

In [None]:
pickle.dump(graph_dataset, open('conll_graph_all.pickle', 'wb'))

In [None]:
for key in vocabulary:
    vocabulary[key] = sorted(vocabulary[key])

In [None]:
pickle.dump(vocabulary, open('vocabulary_all.pickle', 'wb'))

In [None]:
sorted(labels)

In [None]:
pickle.dump(sorted(labels), open('labels.pickle', 'wb'))

In [None]:
# FIXME: `rare_words` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel. Presumably a per-split collection of rare
# words (the test/train overlap is measured here) — TODO restore its definition or remove the cell.
len(set(rare_words['test']).intersection(set(rare_words['train'])))

In [None]:
# Number of distinct raw tokens, POS tags and chunk tags across all splits.
len(set(all_tokens)), len(set(all_pos)), len(set(all_chunks))

In [None]:
# Distinct raw POS tags.
print(set(all_pos))

In [None]:
# Raw POS tag frequencies, most common first.
Counter(all_pos).most_common()

In [None]:
# Distinct raw chunk tags.
print(set(all_chunks))

In [None]:
# Case-insensitive raw vocabulary over every split.
raw_voc = {token.lower() for token in all_tokens}
len(raw_voc)

In [None]:
non_alpha = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '`',
             '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '=', '?', '@', '[', ']']
oov_words, probably_numbers = [], []

# Split raw tokens unknown to both ConceptNet and word2categories into
# digit/punctuation-only strings and remaining out-of-vocabulary words.
for token in tqdm(raw_voc):
    if token in cn_words or token in word2categories:
        continue
    if all(ch in non_alpha for ch in token):
        probably_numbers.append(token)
    else:
        oov_words.append(token)

In [None]:
# Number of raw tokens unknown to ConceptNet / word2categories that are not digit/punctuation-only.
len(oov_words)

In [None]:
# Unknown tokens made only of digits and punctuation.
probably_numbers

In [None]:
# Every character that occurs in the OOV words.
print(sorted(set([c for w in oov_words for c in w])))

In [None]:
# NOTE(review): exact duplicate of the previous cell; can be dropped.
print(sorted(set([c for w in oov_words for c in w])))

In [None]:
# Print OOV words frequent enough to matter and count the rest.
# `words_counter` was never defined (NameError); oov_words are lowercased tokens
# from every split, which is exactly what all_words_counter counts.
MIN_REPORT_COUNT = 10
not_accounted = 0
for w in sorted(oov_words):
    if all_words_counter[w] > MIN_REPORT_COUNT:
        print(w, '\t', all_words_counter[w])
    else:
        not_accounted += 1
not_accounted