In [25]:
import os
import collections

# Data Preprocessing

In [3]:
# Dataset folder names: the four KBC sets plus the two path-query sets.
datasets = [
    "fb15k",
    "fb15k237",
    "pathqueryFB",
    "pathqueryWN",
    "wn18",
    "wn18rr",
]

## Preprocessing KBC Datasets

In [29]:
def get_unique_entities_relations(train, valid, test):
    """Collect the distinct entities and relations found in three triple files.

    Every line of each file must be a tab-separated (head, relation, tail)
    triple. Files are scanned in the order train, valid, test.

    Args:
        train: Path to the training triples file.
        valid: Path to the validation triples file.
        test: Path to the test triples file.

    Returns:
        A tuple ``(entity_list, relation_list)`` of dicts. Keys are the
        entity / relation strings in first-seen order; values are dense
        indices ``0..n-1`` assigned on first sight.

    Raises:
        AssertionError: If a line does not split into exactly three fields.
    """
    entity_list = {}
    relation_list = {}

    for input_file in [train, valid, test]:
        print(f"Working with {input_file} file")
        # tab-separated (head, relation, tail) triples
        with open(input_file, "r") as f:
            for line in f:
                tokens = line.strip().split("\t")
                assert len(tokens) == 3
                head, relation, tail = tokens
                # setdefault only assigns on first sight; a plain
                # `d[k] = len(d)` would overwrite the index of an existing
                # key and produce duplicate / unstable ids.
                entity_list.setdefault(head, len(entity_list))
                entity_list.setdefault(tail, len(entity_list))
                relation_list.setdefault(relation, len(relation_list))

    return entity_list, relation_list

In [30]:
def write_vocab(vocabulary, entity_list, relation_list):
    """Write the vocabulary file: reserved tokens, then entities, then relations.

    The file has one token per line. The first 100 lines are reserved
    tokens: ``[PAD]``, ``[unused0]`` .. ``[unused94]``, ``[UNK]``, ``[CLS]``,
    ``[SEP]``, ``[MASK]``. Entities and relations follow in iteration order.

    Args:
        vocabulary: Path of the vocabulary file to create (overwritten).
        entity_list: Iterable of entity strings (a dict iterates its keys).
        relation_list: Iterable of relation strings (a dict iterates its keys).
    """
    # The context manager closes the handle even if a write raises.
    with open(vocabulary, "w") as fout:
        fout.write("[PAD]" + "\n")
        for i in range(95):
            fout.write(f"[unused{i}]\n")
        for special in ("[UNK]", "[CLS]", "[SEP]", "[MASK]"):
            fout.write(special + "\n")
        for e in entity_list:
            fout.write(e + "\n")
        for r in relation_list:
            fout.write(r + "\n")

    # 100 = 1 ([PAD]) + 95 ([unusedN]) + 4 ([UNK], [CLS], [SEP], [MASK])
    vocab_size = 100 + len(entity_list) + len(relation_list)
    print(f"Vocabulary size {vocab_size}")

In [31]:
def load_vocab(vocab_file):
    """Read a vocabulary file into an ordered token -> id mapping.

    Each line is either ``token`` (the id is the zero-based line number) or
    ``token<TAB>id``. Reading stops at the first line with more than two
    tab-separated fields.

    Args:
        vocab_file: Path to the vocabulary file.

    Returns:
        collections.OrderedDict mapping each token to its integer id.

    Raises:
        ValueError: If an explicit id column is not an integer.
    """
    vocab = collections.OrderedDict()
    # The context manager closes the file once reading is done.
    with open(vocab_file) as fin:
        for num, line in enumerate(fin):
            items = line.strip().split("\t")
            if len(items) > 2:
                break
            token = items[0].strip()
            index = items[1] if len(items) == 2 else num
            vocab[token] = int(index)
    return vocab

In [32]:
def write_true_triples(train, valid, test, vocab, output_file):
    """Translate every triple of the three splits to vocabulary ids and save them.

    All input files are read first; the output file is only opened once
    every triple has been resolved against ``vocab``.

    Args:
        train: Path to the training triples file (tab-separated h, r, t).
        valid: Path to the validation triples file.
        test: Path to the test triples file.
        vocab: Mapping from entity / relation string to integer id.
        output_file: Path of the id-triples file to create (overwritten).

    Raises:
        AssertionError: If a head, relation or tail is missing from ``vocab``.
        ValueError: If a line does not hold exactly three tab-separated fields.
    """
    true_triples = []
    for input_file in [train, valid, test]:
        with open(input_file, "r") as f:
            for line in f:
                h, r, t = line.strip('\r \n').split('\t')
                assert (h in vocab) and (r in vocab) and (t in vocab), \
                    f"triple not covered by vocabulary in {input_file}: {(h, r, t)}"
                true_triples.append((vocab[h], vocab[r], vocab[t]))

    print(f"Number of true triples: {len(true_triples)}")
    # The context manager closes the handle even if a write raises.
    with open(output_file, "w") as fout:
        for hpos, rpos, tpos in true_triples:
            fout.write(f"{hpos}\t{rpos}\t{tpos}\n")

In [33]:
def generate_mask_type(input_file, output_file):
    """Duplicate every input line, tagging one copy MASK_HEAD and one MASK_TAIL.

    Leading/trailing carriage returns, spaces and newlines are removed from
    each line before the tab-separated tag is appended.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to create (overwritten).
    """
    # The output is opened before the input, as in the original cell.
    with open(output_file, "w") as sink, open(input_file, "r") as source:
        for raw in source:
            triple = raw.strip('\r \n')
            sink.write(triple + "\tMASK_HEAD\n")
            sink.write(triple + "\tMASK_TAIL\n")

In [42]:
def kbc_data_preprocess(old_train, old_valid, old_test,
                        vocabulary, triples_file,
                        new_train, new_valid, new_test):
    """Run the KBC preprocessing pipeline on one dataset.

    Steps, in order: collect entities and relations from the three raw
    splits, write the vocabulary file, reload it, write all triples as
    vocabulary ids, and write a mask-annotated copy of each split.

    Args:
        old_train: Path to the raw training triples.
        old_valid: Path to the raw validation triples.
        old_test: Path to the raw test triples.
        vocabulary: Path of the vocabulary file to create.
        triples_file: Path of the id-triples file to create.
        new_train: Path of the mask-annotated training file to create.
        new_valid: Path of the mask-annotated validation file to create.
        new_test: Path of the mask-annotated test file to create.
    """
    raw_splits = (old_train, old_valid, old_test)
    masked_splits = (new_train, new_valid, new_test)

    print("Extracting unique entities and relations...")
    entities, relations = get_unique_entities_relations(*raw_splits)

    print("Updating vocabulary...")
    write_vocab(vocabulary, entities, relations)

    # Reload from disk so the ids match the file that was just written.
    token_to_id = load_vocab(vocabulary)

    print("Writing triples...")
    write_true_triples(*raw_splits, token_to_id, triples_file)

    print("Generating masks...")
    for source, target in zip(raw_splits, masked_splits):
        generate_mask_type(source, target)

    print("Preprocessing successful!!")

In [38]:
# kbc_datasets = ["fb15k", "fb15k237", "wn18", "wn18rr"]
# Built with os.path.join so the separator matches the host OS; a literal
# "data\\fb15k" only resolves to a nested folder on Windows.
data = os.path.join("data", "fb15k")

In [43]:
# existing (input) files
old_train, old_valid, old_test = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("train.txt", "valid.txt", "test.txt")
)

# new (output) files
vocabulary, triples_file = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("vocab.txt", "all.txt")
)
new_train, new_valid, new_test = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("train.coke.txt", "valid.coke.txt", "test.coke.txt")
)

In [44]:
# Run the KBC preprocessing pipeline on the paths configured in the previous cell.
kbc_data_preprocess(
    old_train=old_train,
    old_valid=old_valid,
    old_test=old_test,
    vocabulary=vocabulary,
    triples_file=triples_file,
    new_train=new_train,
    new_valid=new_valid,
    new_test=new_test,
)

Extracting unique entities and relations...
Working with D:\Workspace\_RWTH\_msc_mi\SoSe 2022\XAI\CoKE\CoKE\data\fb15k\train.txt file
Working with D:\Workspace\_RWTH\_msc_mi\SoSe 2022\XAI\CoKE\CoKE\data\fb15k\valid.txt file
Working with D:\Workspace\_RWTH\_msc_mi\SoSe 2022\XAI\CoKE\CoKE\data\fb15k\test.txt file
Updating vocabulary...
Vocabulary size 16396
Writing triples...
Number of true triples: 592213
Generating masks...
Preprocessing successful!!


## Preprocessing Path Query Datasets

In [47]:
def pathquery_get_unique_entities_relations(train, valid, test):
    """Collect the distinct entities and relations found in three path-query files.

    Every line must hold three tab-separated fields: head entity, a
    comma-separated list of relations, and tail entity. Files are scanned
    in the order train, valid, test.

    Args:
        train: Path to the training path-query file.
        valid: Path to the validation path-query file.
        test: Path to the test path-query file.

    Returns:
        A tuple ``(entity_list, relation_list)`` of dicts. Keys are the
        entity / relation strings in first-seen order; values are dense
        indices ``0..n-1`` assigned on first sight.

    Raises:
        AssertionError: If a line does not split into exactly three fields.
    """
    entity_list = {}
    relation_list = {}

    for input_file in [train, valid, test]:
        with open(input_file, "r") as f:
            for line in f:
                tokens = line.strip().split("\t")
                assert len(tokens) == 3
                head, path, tail = tokens
                # setdefault only assigns on first sight; a plain
                # `d[k] = len(d)` would overwrite the index of an existing
                # key and produce duplicate / unstable ids.
                entity_list.setdefault(head, len(entity_list))
                entity_list.setdefault(tail, len(entity_list))
                for relation in path.split(","):
                    relation_list.setdefault(relation, len(relation_list))

    return entity_list, relation_list

In [49]:
def filter_base_data(old_train, old_valid, old_test,
                     train_base, valid_base, test_base):
    """Copy the single-relation lines of each split into a "base" file.

    A line is kept when its middle tab-separated field holds exactly one
    relation, i.e. contains no comma. Kept lines are written unchanged.

    Args:
        old_train: Path to the raw training file.
        old_valid: Path to the raw validation file.
        old_test: Path to the raw test file.
        train_base: Path of the filtered training file to create.
        valid_base: Path of the filtered validation file to create.
        test_base: Path of the filtered test file to create.

    Raises:
        AssertionError: If a line does not split into exactly three fields.
    """
    def fil_base(input_file, output_file):
        """Filter one file; return the number of lines that were kept."""
        base_n = 0
        # Both handles are context-managed so neither leaks if the
        # assertion or a write fails. The output is opened first, as before.
        with open(output_file, "w") as fout, open(input_file, "r") as f:
            for line in f:
                tokens = line.strip().split("\t")
                assert len(tokens) == 3
                # No comma means the relation field holds a single relation.
                if "," not in tokens[1]:
                    fout.write(line)
                    base_n += 1
        return base_n

    # The per-file counts are not used by this function, so they are not stored.
    fil_base(old_train, train_base)
    fil_base(old_valid, valid_base)
    fil_base(old_test, test_base)

In [50]:
def generate_onlytail_mask_type(input_file, output_file):
    """Copy every input line to the output with a MASK_TAIL tag appended.

    Leading/trailing carriage returns, spaces and newlines are removed from
    each line before the tab-separated tag is appended.

    Args:
        input_file: Path of the file to read.
        output_file: Path of the file to create (overwritten).
    """
    # The output is opened before the input, as in the original cell.
    with open(output_file, "w") as sink, open(input_file, "r") as source:
        sink.writelines(
            raw.strip('\r \n') + "\tMASK_TAIL\n" for raw in source
        )

In [None]:
def pathquery_data_preprocess(old_train, old_valid, old_test,
                              vocab_path, sen_candli_file, trivial_sen_file,
                              new_train, new_valid, new_test,
                              train_base, valid_base, test_base):
    """Run the path-query preprocessing pipeline on one dataset.

    Steps, in order: collect entities and relations from the three raw
    splits, write the vocabulary file, write the single-relation "base"
    files, write the mask-annotated splits (head+tail masks for train,
    tail-only masks for valid and test), and reload the vocabulary.

    Args:
        old_train: Path to the raw training file.
        old_valid: Path to the raw validation file.
        old_test: Path to the raw test file.
        vocab_path: Path of the vocabulary file to create.
        sen_candli_file: Not referenced by the body of this function.
        trivial_sen_file: Not referenced by the body of this function.
        new_train: Path of the mask-annotated training file to create.
        new_valid: Path of the mask-annotated validation file to create.
        new_test: Path of the mask-annotated test file to create.
        train_base: Path of the filtered training file to create.
        valid_base: Path of the filtered validation file to create.
        test_base: Path of the filtered test file to create.
    """
    raw_splits = (old_train, old_valid, old_test)

    print("Extracting unique entities and relations...")
    entities, relations = pathquery_get_unique_entities_relations(*raw_splits)

    print("Updating vocabulary...")
    write_vocab(vocab_path, entities, relations)

    filter_base_data(*raw_splits, train_base, valid_base, test_base)

    # Train gets both mask types; valid and test only get the tail mask.
    generate_mask_type(old_train, new_train)
    for source, target in ((old_valid, new_valid), (old_test, new_test)):
        generate_onlytail_mask_type(source, target)

    # Loaded but not used by any statement in this function body.
    vocab = load_vocab(vocab_path)
    
#     generate_eval_files(vocab_path, old_test, 
#                         train_base, valid_base, test_base, 
#                         sen_candli_file, trivial_sen_file)

In [45]:
# pathquery_datasets = ["pathqueryFB", "pathqueryWN"]
# Built with os.path.join so the separator matches the host OS; a literal
# "data\\pathqueryFB" only resolves to a nested folder on Windows.
data = os.path.join("data", "pathqueryFB")

In [46]:
# existing (input) files
old_train, old_valid, old_test = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("train", "valid", "test")
)

# mask-annotated (output) files
new_train, new_valid, new_test = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("train.coke.txt", "valid.coke.txt", "test.coke.txt")
)

# vocabulary and evaluation (output) files
vocab_file, sen_candli_file, trivial_sen_file = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("vocab.txt", "sen_candli.txt", "trivial_sen.txt")
)

# single-relation "base" (output) files
train_base, valid_base, test_base = (
    os.path.join(os.getcwd(), data, file_name)
    for file_name in ("train.base.txt", "valid.base.txt", "test.base.txt")
)

In [None]:
#  pathquery_data_preprocess(old_train, old_valid, old_test,
#                               vocab_file, sen_candli_file, trivial_sen_file,
#                               new_train, new_valid, new_test,
#                               train_base, valid_base, test_base)