In [75]:
import sys
import os
from IPython import get_ipython

sys.path.append(os.path.abspath("./lib"))
if 'autoreload' not in get_ipython().magics_manager.magics['line']:
    %load_ext autoreload
%autoreload 2

from lib import dataloading as dl
from lib import tokenizer as tk
import torch
import tokenizers
import warnings
import numpy as np
import pandas as pd
from matplotlib_venn import venn2, venn3  
from matplotlib import pyplot as plt 
import plotly.express as px
from collections import defaultdict

In [76]:
data_df = dl.load_conllu(
    r"D:\Dropbox\Bachlorarbeit\Datasets\Universal Dependencies 2.15\ud-treebanks-v2.15\UD_English-GUM\en_gum-ud-train.conllu"
)
data_df = dl.clear_non_UPOS_tags(data_df)
print(data_df.head())

Dropped 2810 rows with non-UPOS tags 
Tags dropped: ['_']
            FORM         LEMMA   UPOS XPOS        FEATS HEAD DEPREL  \
ID                                                                    
1      Aesthetic     aesthetic    ADJ   JJ   Degree=Pos    2   amod   
2   Appreciation  appreciation   NOUN   NN  Number=Sing    0   root   
3            and           and  CCONJ   CC            _    5     cc   
4        Spanish       Spanish    ADJ   JJ   Degree=Pos    5   amod   
5            Art           art   NOUN   NN  Number=Sing    2   conj   

          DEPS                                               MISC  
ID                                                                 
1       2:amod  Discourse=organization-heading:1->57:8:grf-ly-...  
2       0:root                       Entity=1)|MSeg=Appreciat-ion  
3         5:cc                                                  _  
4       5:amod     Entity=(2-abstract-new-cf2-2-sgl|MSeg=Span-ish  
5   2:conj:and                      

In [77]:
data_df.describe()

Unnamed: 0,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL,DEPS,MISC
count,164108,164108,164108,164108,164108,164108,164108,164108,164108
unique,17226,12957,17,47,181,106,51,10849,68967
top,",",",",NOUN,NN,_,4,punct,0:root,_
freq,8647,8647,27288,20260,53205,10049,22748,9409,62470


In [78]:
vocab_size = 1000
upos_tags = [
        "ADJ",
        "ADP",
        "ADV",
        "AUX",
        "CCONJ",
        "DET",
        "INTJ",
        "NOUN",
        "NUM",
        "PART",
        "PRON",
        "PROPN",
        "PUNCT",
        "SCONJ",
        "SYM",
        "VERB",
        "X",
    ]
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]

In [79]:
merges_upos["ADJ"]

[('a', 'l'),
 ('e', 'r'),
 ('▁', 's'),
 ('r', 'e'),
 ('t', 'i'),
 ('a', 'n'),
 ('l', 'e'),
 ('s', 't'),
 ('▁', 'm'),
 ('o', 'n'),
 ('a', 'r'),
 ('e', 'n'),
 ('i', 'c'),
 ('▁', 'f'),
 ('i', 'n'),
 ('▁', 'c'),
 ('▁', 'p'),
 ('▁', 'l'),
 ('o', 'r'),
 ('▁', 'o'),
 ('▁', 'n'),
 ('t', 'h'),
 ('▁', 'g'),
 ('▁', 'd'),
 ('▁', 'a'),
 ('▁', 'b'),
 ('▁', 'e'),
 ('u', 'l'),
 ('a', 't'),
 ('i', 'g'),
 ('o', 'u'),
 ('e', 'd'),
 ('en', 't'),
 ('v', 'e'),
 ('▁', 'h'),
 ('i', 't'),
 ('▁', 're'),
 ('b', 'le'),
 ('ti', 'c'),
 ('i', 'r'),
 ('▁n', 'e'),
 ('th', 'er'),
 ('s', 'i'),
 ('▁', 'in'),
 ('u', 'n'),
 ('a', 'ti'),
 ('on', 'al'),
 ('▁', 'w'),
 ('▁o', 'ther'),
 ('r', 'o'),
 ('m', 'p'),
 ('▁s', 'u'),
 ('o', 'l'),
 ('i', 'f'),
 ('c', 'h'),
 ('ou', 's'),
 ('ig', 'h'),
 ('i', 's'),
 ('▁m', 'o'),
 ('an', 't'),
 ('e', 'st'),
 ('▁ne', 'w'),
 ('r', 'a'),
 ('e', 'c'),
 ('▁f', 'ir'),
 ('▁fir', 'st'),
 ('u', 'r'),
 ('a', 'st'),
 ('t', 'er'),
 ('o', 'o'),
 ('▁', 'un'),
 ('in', 'g'),
 ('i', 'l'),
 ('▁m', 'an'),
 ('

In [80]:
tokenizers_upos = {}
merges_upos = {}
vocab_upos = {}
vocab_size = 1000
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]
for upos_tag in upos_tags:
    text = data_df[data_df["UPOS"] == upos_tag]["FORM"].values.tolist()
    tokenizers_upos[upos_tag] = tk.train_tokenizer(text, vocab_size)
    vocab_upos[upos_tag], merges_upos[upos_tag] = tk.extract_vocab_and_merges(tokenizers_upos[upos_tag])

target_allocation = data_df["UPOS"].value_counts(normalize=True).sort_index() # Sort by index to match upos_tags order
vocab_allocation = np.array([5] * len(upos_tags))  # Ensure space for five special tokens

vocab_set = set()
while len(vocab_set) < vocab_size:
    vocab_allocation += tk.assign_proportionally(vocab_allocation, target_allocation, vocab_size - len(vocab_set))
    for idx, upos_tag in enumerate(upos_tags):
        vocab_set.update(list(vocab_upos[upos_tag])[:vocab_allocation[idx]])

merges_set = set()
for upos_tag in upos_tags:
    for merge in merges_upos[upos_tag]:
        if all(token in vocab_set for token in [merge[0], merge[1], merge[0] + merge[1]]):
            merges_set.add(merge)

print(f"Total merges: {len(merges_set)}")
print(merges_set)
print(f"Total vocab size: {len(vocab_set)}")
print(vocab_set)

Total merges: 1025
{('u', 'ght'), ('at', 'ed'), ('▁', '"'), ('▁l', 'i'), ('ic', 'e'), ('al', 's'), ('▁res', 'earch'), ('▁w', 'e'), ('o', 'r'), ('▁p', 'er'), ('i', 'st'), ('enc', 'es'), ('▁l', 'ine'), ('▁com', 'p'), ('s', 't'), ('▁p', 'ar'), ('▁of', 'f'), ('▁up', 'on'), ('v', 'er'), ('e', 't'), ('▁p', 'ers'), ('ens', 'e'), ('i', 'v'), ('i', 'ds'), ('▁s', 'a'), ('▁num', 'ber'), ('i', 'te'), ('s', 's'), ('e', 'n'), ('e', 's'), ('▁', 'out'), ('▁an', 'oth'), ('ic', 'k'), ('▁meth', 'od'), ('▁ho', 'me'), ('▁vide', 'o'), ('▁bo', 'x'), ('▁your', 'self'), ('▁gover', 'n'), ('▁sin', 'ce'), ('e', 'p'), ('▁mon', 'th'), ('▁re', 'v'), ('▁r', 'es'), ('▁tod', 'ay'), ('b', 'er'), ('eve', 'r'), ('▁m', 'on'), ('▁', 'f'), ('o', 'ver'), ('▁y', 'ear'), ('▁th', 'ings'), ('▁h', 'o'), ("▁'", 's'), ('▁’', 's'), ('▁sh', 'ow'), ('u', 't'), ('▁i', 't'), ('▁fin', 'd'), ('▁dis', 'c'), ('p', 't'), ('▁o', 'ver'), ('▁m', 'od'), ('▁', 'us'), ('▁st', 'ud'), ('a', 'k'), ('u', 'n'), ('er', 'v'), ('if', 'e'), ('u', 's'), ('gh

In [81]:
vocab = {token: idx for idx, token in enumerate(vocab_set)}
merges = list(merges_set)

In [82]:
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(
    vocab=vocab,
    merges=merges,
    unknown_token="[UNK]",
    padding_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
))

In [83]:
df = px.data.tips()
fig = px.treemap(df, path=[px.Constant("all"), 'day', 'time', 'sex'], values='total_bill')
fig.update_traces(root_color="lightgrey")
fig.update_layout(margin = dict(t=50, l=25, r=25, b=25))
fig.show()