In [2]:
from uniparse import Vocabulary
from uniparse.dataprovider import batch_by_buckets
import random
from collections import defaultdict
import numpy as np
from sklearn.utils import class_weight
from collections import defaultdict
import numpy as np
from uniparse import Model, Vocabulary

from uniparse.callbacks import ModelSaveCallback
from uniparse.dataprovider import batch_by_buckets

# uniparse intro

In [None]:
# Read the PTB training treebank through the vocabulary's CoNLL reader with
# tokenize=False. Both `vocab` and `samples` are reused by the cells below.
vocab = Vocabulary()
samples = vocab._read_conll("../data/ptb_conllu/train.conllu", tokenize=False)

In [None]:
# Pick one sentence at random for inspection. No seed is set in the visible
# cells, so the chosen sentence differs from run to run.
sample = random.choice(samples)

In [None]:
# First five entries of field 2 of the sampled sentence; count_3rd_order_sets
# passes this same field to extract_3O_paths as the tag sequence.
sample[2][:5]

# extract parent -> head -> modifier from treebank

In [3]:
from collections import Counter
import pandas as pd

def count_3rd_order_sets(treebank_file, vocab):
    """Count parent->head->modifier tag triples over a CoNLL treebank.

    Args:
        treebank_file: path handed to ``vocab._read_conll``.
        vocab: object exposing ``_read_conll(path, tokenize=False)``. For each
            returned sample, ``sample[3]`` is used as the head list and
            ``sample[2]`` as the tag list.

    Returns:
        pandas.DataFrame with the columns ``afactor`` (the triple rendered as
        ``"parent->head->modifier"``) and ``count``, sorted by ``count`` in
        descending order. A treebank without any sample yields an empty frame
        that still carries both columns.
    """
    counter = Counter()
    for sample in vocab._read_conll(treebank_file, tokenize=False):
        extract_3O_paths(sample[3], sample[2], counter)

    factors = [
        {"afactor": "%s->%s->%s" % key, "count": occurrences}
        for key, occurrences in counter.items()
    ]
    # Naming the columns explicitly keeps sort_values from raising KeyError
    # when nothing was counted (an empty record list has no columns at all).
    df = pd.DataFrame(factors, columns=["afactor", "count"])
    return df.sort_values("count", ascending=False)

def extract_3O_paths(tree, tags, counter, root_token="<root>"):
    """Add one (parent, head, modifier) tag triple per token to ``counter``.

    Args:
        tree: sequence where ``tree[m]`` is the head index of token ``m``.
        tags: sequence of tags aligned with ``tree``.
        counter: ``collections.Counter`` updated in place.
        root_token: label substituted for any tag equal to ``1``.

    Returns:
        None; the result is the mutation of ``counter``.
    """
    def tag_at(index):
        # Negative indices fall back to position 0; a tag equal to 1 is
        # rendered as the root label.
        tag = tags[index] if index >= 0 else tags[0]
        return root_token if tag == 1 else tag

    for modifier, head in enumerate(tree):
        # NOTE(review): a negative head makes tree[head] index from the end of
        # the sequence (plain Python indexing) - confirm this is intended.
        parent = tree[head]
        counter[(tag_at(parent), tag_at(head), tag_at(modifier))] += 1

def factor_to_str(t):
    """Render an iterable of tag strings as a single ``a->b->c`` key."""
    arrow = "->"
    return arrow.join(t)

def get_factor_weights(dataframe):
    """Turn factor counts into "balanced" class weights.

    Each factor with a positive count receives

        total_count / (n_factors * count)

    where ``total_count`` is the sum of all positive counts and ``n_factors``
    is the number of rows with a positive count. This is the value of
    scikit-learn's ``compute_class_weight("balanced", ...)`` on a label list
    in which every factor appears ``count`` times, computed directly so that
    no such list has to be materialised.

    Args:
        dataframe: frame with an ``afactor`` column (factor strings, expected
            to be unique) and an integer ``count`` column.

    Returns:
        dict mapping factor string -> weight (numpy float64). Rows whose count
        is not positive are left out; an empty frame gives an empty dict.
    """
    factors = np.asarray(dataframe["afactor"], dtype=object)
    counts = np.asarray(dataframe["count"], dtype=np.int64)

    # Only factors that actually occur take part in the balancing.
    seen = counts > 0
    n_factors = int(seen.sum())
    if n_factors == 0:
        return {}

    seen_counts = counts[seen]
    total = int(seen_counts.sum())
    weights = total / (n_factors * seen_counts.astype(np.float64))
    return dict(zip(factors[seen], weights))

def batch_to_weights(sample, factor_weights, tag_map):
    """Build the per-token weight matrix for one batch.

    For every position ``m`` of every row ``b`` the (parent, head, modifier)
    tag-id triple is read from ``tags``, each id is translated through
    ``tag_map``, the three tags are joined with ``factor_to_str`` and the
    resulting key is looked up in ``factor_weights``.

    Args:
        sample: ``((words, tags), (trees, labels))``. ``trees`` must expose
            ``.shape`` and iterate row by row; ``tags`` is indexed as
            ``tags[b, i]``. ``words`` and ``labels`` are not used.
        factor_weights: mapping from factor string to weight.
        tag_map: mapping from tag id to tag string.

    Returns:
        numpy array of zeros-initialised floats with the shape of ``trees``,
        holding the weight of the factor at each position.

    Raises:
        KeyError: if a tag id is missing from ``tag_map`` or a factor has no
            entry in ``factor_weights`` (the message names the factor).
    """
    (words, tags), (trees, labels) = sample
    # trees :: (b, n)
    output = np.zeros(trees.shape)
    for b, tree in enumerate(trees):
        for m, h in enumerate(tree):
            # NOTE(review): a negative head makes tree[h] index from the end
            # of the row (plain Python/numpy indexing) - confirm intended.
            p = tree[h]
            factor_ids = tuple(
                tags[b, i] if i >= 0 else tags[b, 0]
                for i in (p, h, m)
            )
            factor = factor_to_str([tag_map[e] for e in factor_ids])
            try:
                output[b, m] = factor_weights[factor]
            except KeyError as err:
                # Carry the offending factor in the exception instead of
                # printing it, and let every other error propagate untouched.
                raise KeyError(
                    "no weight for factor %r (tag ids %r)" % (factor, factor_ids)
                ) from err

    return output

In [None]:
# Factor counts for the PTB training treebank (columns: afactor, count).
ptb_df = count_3rd_order_sets("../data/ptb_conllu/train.conllu", vocab)

In [None]:
# Number of distinct factors (rows) and the most frequent ones; the frame is
# already sorted by count in descending order.
print(ptb_df.shape)
ptb_df.head()

In [None]:
# Share of PTB factors that were observed exactly once.
ptb_df["count"].eq(1).sum() / len(ptb_df)

In [None]:
# Histogram of the factor counts in 100 bins.
ptb_df.plot.hist(bins=100)

In [None]:
# Boolean mask of factors seen fewer than 10 times, plotted as 0/1 to compare
# rare against frequent factors. `x` is rebound by a later cell.
x = ptb_df["count"] < 10
x.astype(int).plot.hist()

# UD EWT

In [None]:
# Same factor counts, this time for the UD English EWT training treebank.
ud_en_df = count_3rd_order_sets("../data/en_ewt-ud-train.conllu", vocab)

In [None]:
# Number of distinct factors (rows) and the most frequent ones.
print(ud_en_df.shape)
ud_en_df.head()

In [None]:
# Map every EWT factor string to its weight. The bare expression below
# displays the whole dictionary, one entry per distinct factor.
ud_factor_weights = get_factor_weights(ud_en_df)
ud_factor_weights

# Let's test it all out

In [None]:
# Rebind `vocab` to a fresh vocabulary and fit it on the EWT training file.
vocab = Vocabulary()
vocab.fit("../data/en_ewt-ud-train.conllu")

In [None]:
# Tokenize the EWT training file with the fitted vocabulary.
X = vocab.tokenize_conll("../data/en_ewt-ud-train.conllu")

In [None]:
# Group the tokenized data into batches of 32, shuffled.
# NOTE: `X` is overwritten with its own batched form, so this cell is not
# idempotent - re-run the tokenize cell above before running it again.
from uniparse.dataprovider import batch_by_buckets
X = batch_by_buckets(X, batch_size=32, shuffle=True)

In [None]:
# Keep only the second element of the batching result (the batches); the
# first is discarded. This rebinds `samples`, which earlier held the raw PTB
# sentences, and assigning to `_` overrides IPython's last-output variable.
_, samples = X

In [None]:
# Split the first batch into its two halves; batch_to_weights unpacks them as
# (words, tags) and (trees, labels).
x, y = samples[0]

In [None]:
# Weight matrix for the first batch only, as a smoke test.
label_weights = batch_to_weights(samples[0], ud_factor_weights, vocab._id2tag)

In [None]:
# One weight matrix per batch. Each batch has to be scored against its own
# trees and tags, so the loop variable is passed rather than samples[0].
ud_sample_weights = [batch_to_weights(sample, ud_factor_weights, vocab._id2tag) for sample in samples]

In [None]:
# Shape of the weight matrix computed for the second batch.
ud_sample_weights[1].shape

In [None]:
# Append each batch's weight matrix to its input tuple, leaving the targets
# untouched, then unpack the extended inputs of the first batch.
# The comprehension's x, y, w are local to it and do not rebind the cell-level x, y.
X_hat = [(x+(w,),y)for (x,y), w in zip(X[1], ud_sample_weights)]
words, tags, weights = X_hat[0][0]

In [None]:
# Sanity check: shapes of the word input and of the weights for the first batch.
words.shape, weights.shape

# Let's do it from the top

In [4]:
from uniparse.models.dynet.syntax_att import Parser

def train(train_file, dev_file, test_file, n_epochs, parameter_file, vocab_file, model_class):
    """Train a parser with per-token factor weights and print its test scores.

    A vocabulary is fitted on ``train_file`` and the third-order tag factors
    of that file are counted and converted to weights. Every training batch
    then gets its weight matrix appended to its input tuple before training.
    After training, the parameters in ``parameter_file`` are loaded back and
    the model is evaluated on ``test_file``.

    Args:
        train_file: CoNLL-U file used for the vocabulary, factors and training.
        dev_file: CoNLL-U file passed to ``Model.train`` for validation.
        test_file: CoNLL-U file used for the final evaluation.
        n_epochs: number of training epochs.
        parameter_file: path given to ``ModelSaveCallback`` and reloaded
            before evaluation.
        vocab_file: where to save the vocabulary; skipped when falsy.
        model_class: callable building the model from the vocabulary.

    Returns:
        None; metrics are printed.
    """
    vocab = Vocabulary().fit(train_file)

    # e.g. "../data/en_ewt-ud-train.conllu"
    factor_counts = count_3rd_order_sets(train_file, vocab)

    # Persist the vocabulary so the run can be reproduced later.
    if vocab_file:
        print("> saving vocab to", vocab_file)
        vocab.save(vocab_file)

    # Tokenize and batch the three splits; only the test batches stay unshuffled.
    print(">> Loading in data")
    train_data = vocab.tokenize_conll(train_file)
    dev_data = vocab.tokenize_conll(dev_file)
    test_data = vocab.tokenize_conll(test_file)

    bucketed_train = batch_by_buckets(train_data, batch_size=32, shuffle=True)
    dev_batches = batch_by_buckets(dev_data, batch_size=32, shuffle=True)
    test_batches = batch_by_buckets(test_data, batch_size=32, shuffle=False)

    bucket_indices, train_samples = bucketed_train
    factor_weights = get_factor_weights(factor_counts)

    # Extend the inputs of every training batch with its weight matrix.
    weighted_samples = []
    for sample in train_samples:
        inputs, targets = sample
        token_weights = batch_to_weights(sample, factor_weights, vocab._id2tag)
        weighted_samples.append((inputs + (token_weights,), targets))
    train_batches = (bucket_indices, weighted_samples)

    model = model_class(vocab)
    save_callback = ModelSaveCallback(parameter_file)

    parser = Model(model, optimizer="adam", vocab=vocab)
    parser.train(
        train_batches,
        dev_file,
        dev_batches,
        epochs=n_epochs,
        callbacks=[save_callback],
        verbose=True,
    )
    # Evaluate with the parameters stored in parameter_file, not the last state.
    parser.load_from_file(parameter_file)

    metrics = parser.evaluate(test_file, test_batches, delete_output=False)
    uas_without_punct = metrics["nopunct_uas"]
    las_without_punct = metrics["nopunct_las"]

    print(metrics)

    print()
    print(">>> Model maxed on dev at epoch", save_callback.best_epoch)
    print(">>> Test score:", uas_without_punct, las_without_punct)

In [5]:
# UD English EWT splits, given relative to the notebook's directory.
TRAIN = "../data/en_ewt-ud-train.conllu"
DEV = "../data/en_ewt-ud-dev.conllu"
TEST = "../data/en_ewt-ud-test.conllu"
# 30 epochs; parameters go to model.params and the vocabulary to model.vocab.
train(TRAIN, DEV, TEST, 30, "model.params", "model.vocab", Parser)

> saving vocab to model.vocab
>> Loading in data


[1/30] arc 0.00, rel 0.71, loss 51.458: 100%|██████████| 451/451 [03:05<00:00,  2.44it/s] 


[1][195s] 0.19691, 0.15200 


[2/30] arc 0.00, rel 0.80, loss 64.475: 100%|██████████| 451/451 [02:54<00:00,  2.59it/s] 


[2][183s] 0.44211, 0.37686 


[3/30] arc 0.00, rel 0.90, loss 2.980: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s]  


[3][174s] 0.42772, 0.37827 


[4/30] arc 0.41, rel 0.92, loss 3.281: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s]  


[4][174s] 0.69156, 0.56833 


[5/30] arc 0.19, rel 0.84, loss 49.354: 100%|██████████| 451/451 [02:48<00:00,  2.68it/s] 


[5][176s] 0.76883, 0.67128 


[6/30] arc 0.51, rel 0.89, loss 7.437: 100%|██████████| 451/451 [02:53<00:00,  2.60it/s]  


[6][181s] 0.80032, 0.70885 


[7/30] arc 0.66, rel 0.89, loss 1.830: 100%|██████████| 451/451 [02:47<00:00,  2.70it/s]  


[7][175s] 0.80957, 0.73934 


[8/30] arc 0.48, rel 0.88, loss 23.262: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s] 


[8][173s] 0.81411, 0.74260 


[9/30] arc 0.95, rel 1.00, loss 0.139: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s]  


[9][174s] 0.82178, 0.75626 


[10/30] arc 0.70, rel 0.87, loss 2.994: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s]  


[10][173s] 0.82740, 0.74941 


[11/30] arc 0.83, rel 0.95, loss 0.312: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s]  


[11][174s] 0.83457, 0.77600 


[12/30] arc 0.75, rel 0.91, loss 2.192: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[12][174s] 0.83711, 0.77627 


[13/30] arc 0.71, rel 0.92, loss 1.563: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s]  


[13][173s] 0.84533, 0.78757 


[14/30] arc 1.00, rel 1.00, loss 0.000: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s]  


[14][174s] 0.84678, 0.78462 


[15/30] arc 0.82, rel 0.89, loss 0.390: 100%|██████████| 451/451 [02:46<00:00,  2.72it/s] 


[15][174s] 0.84973, 0.78770 


[16/30] arc 0.77, rel 0.90, loss 1.351: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[16][174s] 0.85290, 0.77319 


[17/30] arc 0.91, rel 0.96, loss 0.243: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[17][174s] 0.85318, 0.79383 


[18/30] arc 0.77, rel 0.93, loss 0.827: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s] 


[18][174s] 0.85018, 0.79397 


[19/30] arc 0.90, rel 0.95, loss 0.312: 100%|██████████| 451/451 [02:46<00:00,  2.70it/s] 


[19][174s] 0.85431, 0.78975 


[20/30] arc 0.80, rel 0.91, loss 2.016: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s]  


[20][174s] 0.85413, 0.79469 


[21/30] arc 0.82, rel 0.92, loss 0.522: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[21][174s] 0.85563, 0.80268 


[22/30] arc 0.84, rel 0.95, loss 0.563: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[22][174s] 0.85313, 0.79873 


[23/30] arc 0.83, rel 0.93, loss 0.716: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[23][174s] 0.85535, 0.80749 


[24/30] arc 0.86, rel 0.94, loss 0.412: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[24][174s] 0.85653, 0.80104 


[25/30] arc 0.80, rel 0.92, loss 1.658: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[25][174s] 0.86044, 0.80299 


[26/30] arc 0.90, rel 0.97, loss 0.109: 100%|██████████| 451/451 [02:46<00:00,  2.72it/s] 


[26][174s] 0.85758, 0.80948 


[27/30] arc 0.81, rel 0.92, loss 0.592: 100%|██████████| 451/451 [02:45<00:00,  2.72it/s] 


[27][173s] 0.86148, 0.81130 


[28/30] arc 0.85, rel 0.93, loss 0.467: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[28][174s] 0.86266, 0.81084 


[29/30] arc 0.95, rel 0.95, loss 0.040: 100%|██████████| 451/451 [02:46<00:00,  2.71it/s] 


[29][174s] 0.86007, 0.80590 


[30/30] arc 0.89, rel 0.95, loss 18.990: 100%|██████████| 451/451 [02:48<00:00,  2.67it/s]


[30][176s] 0.86384, 0.81298 
>> Finished at epoch 30
>> outputed predictions to /var/folders/0t/t8j215jn7z77m328dvr0dyfh3b9248/T/tmpob3j6o2w
{'uas': 0.8500159387950271, 'las': 0.8057857825948358, 'nopunct_uas': 0.8615959531513466, 'nopunct_las': 0.8120129426240714}

>>> Model maxed on dev at epoch 30
>>> Test score: 0.8615959531513466 0.8120129426240714
