Mar 11

Build OWER datasets with more sentences. In fact, include all Ryn sentences
in the OWER dataset and limit the number of sentences just before training.
The same could be done for classes.

Also, include the classes, debug information like entity names and other
information from the Ryn dataset that is later required.

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import logging
from pathlib import Path
from shutil import copyfile
from typing import List, Tuple, Dict, Set

from IPython.core.display import display
from pandas import DataFrame

from dao.ower.ower_dir import OwerDir
from dao.ryn.ryn_dir import RynDir

# Config

In [None]:
# Root of the Ryn dataset that is read (wrapped in RynDir further down).
ryn_dir_path = '../data/ryn/irt.fb.irt.30.clean'
# Root of the OWER dataset that is written (wrapped in OwerDir further down).
ower_dir_path = '../data/ower/ower-fb-irt'
# Number of (rel, tail) classes requested via select_top_rel_tails().
class_count = 8
# Minimum number of sentences per entity; get_samples() skips entities with fewer.
sent_count = 3

# 1. Load Triples TXTs

In [None]:
# Wrap the Ryn dataset directory and run its check() before reading from it.
ryn_dir = RynDir('Ryn Directory', Path(ryn_dir_path))
ryn_dir.check()

# Load the four triple files of the split. Each triple is a 3-tuple of ints
# that the rest of this notebook unpacks as (head, rel, tail).
# ("cw" / "ow" presumably stand for closed-world / open-world - confirm.)
split_dir = ryn_dir.split_dir
cw_train_triples: List[Tuple[int, int, int]] = split_dir.cw_train_triples_txt.load_triples()
cw_valid_triples: List[Tuple[int, int, int]] = split_dir.cw_valid_triples_txt.load_triples()
ow_valid_triples: List[Tuple[int, int, int]] = split_dir.ow_valid_triples_txt.load_triples()
ow_test_triples: List[Tuple[int, int, int]] = split_dir.ow_test_triples_txt.load_triples()

# OWER train split = cw train + cw valid triples; the ow valid / ow test
# triples become the OWER valid / test splits unchanged.
train_triples = cw_train_triples + cw_valid_triples
valid_triples = ow_valid_triples
test_triples = ow_test_triples

# Show the size and the first four triples of every split.
display('train_triples', len(train_triples), train_triples[:4])
display('valid_triples', len(valid_triples), valid_triples[:4])
display('test_triples', len(test_triples), test_triples[:4])

In [None]:
# Load the entity / relation label lookups so triples can be shown with
# readable labels instead of raw ids.
ent_to_label = ryn_dir.split_dir.entity_labels_txt.load()
rel_to_label = ryn_dir.split_dir.relation_labels_txt.load()

df_cols = ['head', 'rel', 'tail']

# Render the train, valid and test triples (in that order) as labelled tables.
# A single loop replaces three copy-pasted stanzas; `df_data` and `df` end up
# bound to the test split's values, exactly as before.
for triples in (train_triples, valid_triples, test_triples):
    df_data = ((ent_to_label[head], rel_to_label[rel], ent_to_label[tail]) for head, rel, tail in triples)
    df = DataFrame(data=df_data, columns=df_cols)
    display(df)

# 2. Save triples to Triples DBs

In [None]:
# Wrap the OWER output directory and create it.
# NOTE(review): what create() does when the directory already exists is not
# visible here - confirm before re-running this cell.
ower_dir = OwerDir('OWER Directory', Path(ower_dir_path))
ower_dir.create()

# One triples DB per split under the OWER tmp dir: create its triples table,
# then insert that split's triples.
train_triples_db = ower_dir.tmp_dir.train_triples_db
train_triples_db.create_triples_table()
train_triples_db.insert_triples(train_triples)

valid_triples_db = ower_dir.tmp_dir.valid_triples_db
valid_triples_db.create_triples_table()
valid_triples_db.insert_triples(valid_triples)

test_triples_db = ower_dir.tmp_dir.test_triples_db
test_triples_db.create_triples_table()
test_triples_db.insert_triples(test_triples)

# 3. Copy Ryn Label TXTs to OWER Dir

In [None]:
# Copy the entity and relation label files from the Ryn split dir into the OWER dir.
# NOTE(review): this reaches into the private `_path` attribute of the file
# wrappers - prefer a public accessor if the wrappers offer one.
copyfile(ryn_dir.split_dir.entity_labels_txt._path, ower_dir.ent_labels_txt._path)
copyfile(ryn_dir.split_dir.relation_labels_txt._path, ower_dir.rel_labels_txt._path)

# From here on the label lookups are read from the copies in the OWER dir.
ent_to_label = ower_dir.ent_labels_txt.load()
rel_to_label = ower_dir.rel_labels_txt.load()

# 4. Query most common classes and write them to Classes TSV

In [None]:
# Ask the train triples DB for the top `class_count` (rel, tail) pairs.
# Each row is unpacked further down as (rel, tail, supp); `supp` is presumably
# the number of matching train triples - confirm against the DB query.
rel_tail_supps = train_triples_db.select_top_rel_tails(class_count)

# Bare expression: rendered as the cell output in the notebook.
rel_tail_supps

In [None]:
# Denominator for the class frequency: the number of entries in the entity label lookup.
ent_count = len(ent_to_label)

# One row per class: (rel, tail, support / ent_count, "<rel label> <tail label>").
rel_tail_freq_labels = [
    (rel_id, tail_id, support / ent_count, f'{rel_to_label[rel_id]} {ent_to_label[tail_id]}')
    for rel_id, tail_id, support in rel_tail_supps
]

# Persist the class rows to the OWER classes TSV.
ower_dir.classes_tsv.save(rel_tail_freq_labels)

# Bare expression: rendered as the cell output in the notebook.
rel_tail_freq_labels

# 5. Query classes' entities

In [None]:
# For every class (rel, tail) collect the entities that occur as head of a
# matching triple, separately per split. Entry i of each list belongs to
# rel_tail_supps[i].
#
# Comprehensions replace three copy-pasted module-level loops. Their targets
# stay local to the comprehension, so `_` - which IPython uses for the last
# cell output - is no longer rebound in the notebook namespace.
train_class_ents = [ower_dir.tmp_dir.train_triples_db.select_heads_with_rel_tail(rel, tail)
                    for rel, tail, _ in rel_tail_supps]

valid_class_ents = [ower_dir.tmp_dir.valid_triples_db.select_heads_with_rel_tail(rel, tail)
                    for rel, tail, _ in rel_tail_supps]

test_class_ents = [ower_dir.tmp_dir.test_triples_db.select_heads_with_rel_tail(rel, tail)
                   for rel, tail, _ in rel_tail_supps]

# Sanity check: size and up to ten entities of the first class in every split.
display('train, class 0:', len(train_class_ents[0]), list(train_class_ents[0])[:10])
display('valid, class 0:', len(valid_class_ents[0]), list(valid_class_ents[0])[:10])
display('test, class 0:', len(test_class_ents[0]), list(test_class_ents[0])[:10])

# 6. Create OWER Sample TSVs

In [None]:
# Per-split sentences from the Ryn text dir, keyed by entity: {ent: {sent}}.
train_ent_to_sents: Dict[int, Set[str]] = ryn_dir.text_dir.cw_train_sentences_txt.load()
valid_ent_to_sents: Dict[int, Set[str]] = ryn_dir.text_dir.ow_valid_sentences_txt.load()
test_ent_to_sents: Dict[int, Set[str]] = ryn_dir.text_dir.ow_test_sentences_txt.load()

def get_samples(ent_to_sents, class_ents, *, min_sent_count=None, labels=None):
    """
    Build one sample per entity that has enough sentences.

    Entities with fewer than the required number of sentences are skipped and
    a warning is logged for each of them.

    :param ent_to_sents: {ent: {sent}}
    :param class_ents: [[ent]], one collection of entities per class
    :param min_sent_count: minimum number of sentences an entity must have;
        defaults to the notebook-level `sent_count`
    :param labels: {ent: label}; defaults to the notebook-level `ent_to_label`
    :return: [(ent, label, [has class], {sent})], where [has class] holds one
        0/1 flag per entry of class_ents, in the same order
    """

    # Fall back to the notebook-level config so that existing two-argument
    # calls behave exactly as before.
    if min_sent_count is None:
        min_sent_count = sent_count
    if labels is None:
        labels = ent_to_label

    ent_lbl_classes_sents_list = []

    for ent, sents in ent_to_sents.items():
        label = labels[ent]

        # Guard first: no need to compute class flags for a skipped entity.
        if len(sents) < min_sent_count:
            logging.warning(f"Entity '{label}' ({ent}) has less than {min_sent_count} sentences. Skipping")
            continue

        # 1 if the entity belongs to the class, else 0 - one flag per class.
        ent_classes = [int(ent in ents) for ents in class_ents]

        ent_lbl_classes_sents_list.append((ent, label, ent_classes, sents))

    return ent_lbl_classes_sents_list

# Build the samples of each split from that split's sentences and that
# split's per-class entity collections.
train_samples = get_samples(train_ent_to_sents, train_class_ents)
valid_samples = get_samples(valid_ent_to_sents, valid_class_ents)
test_samples = get_samples(test_ent_to_sents, test_class_ents)

# Write one samples TSV per split into the OWER dir.
ower_dir.train_samples_tsv.save(train_samples)
ower_dir.valid_samples_tsv.save(valid_samples)
ower_dir.test_samples_tsv.save(test_samples)