In [None]:
import json
import csv
import os
from transformers import EncoderDecoderModel, BertTokenizer, BertModel, BertForNextSentencePrediction, BatchEncoding
import torch
import sys

sys.path.append('..')
from tools.TextProcessing import build_word_tree, process_keywords
from tools.BasicUtils import my_write, my_csv_read, my_read, my_json_read, ntopidx, batch
from joint_score_func import SparseRetrieveSentForPairCoOccur, ScoreFunction1, demo_score_function, ScoreFunction2, Reader1, demo_reader, Reader2

In [None]:
# Human-readable index of the generated files, one "<file> ---- <description>" entry each.
# NOTE(review): the first entry names "keyword_f.txt", while the data-preparation
# cell of this notebook writes 'data/keyword.txt' -- confirm which name is intended.
file_description = [
    "keyword_f.txt ---- CS keywords",
    "wordtree.json ---- word tree for cs keywords",
    "entity.txt ---- Reformed cs keywords with '_' replacing ' '"
]

# The index goes to readme.txt in the current working directory (not under data/).
my_write('readme.txt', file_description)

## Prepare the data

In [None]:
# Collect keywords from terms-cs-cfl-epoch200.txt
# Rows are read as tab-separated (term, score) pairs; only terms whose score is
# above 0.1 are kept as candidates.
r = my_csv_read('../data/raw_data/terms-cs-cfl-epoch200.txt', delimiter='\t')
candidate_kw_list = [item[0] for item in r if float(item[1]) > 0.1]
# process_keywords returns a pair; only the first element (stable_kw) is saved below.
stable_kw, unstable_kw = process_keywords(candidate_kw_list)
# Save keywords
# makedirs(exist_ok=True) creates the folder when missing and is a no-op when it
# already exists, so the cell can be re-run without a separate existence check.
os.makedirs('data', exist_ok=True)
my_write('data/keyword.txt', stable_kw)
# Generate word tree (25 seconds)
build_word_tree('data/keyword.txt', 'data/wordtree.json', 'data/entity.txt')

In [None]:
# Go to the py folder and run the following commands in the background:
# "python gen_co_occur.py ../joint_score_func/data/wordtree.json ../data/corpus/small_sent.txt ../joint_score_func/data/co_occur.txt"
# "python gen_occur.py ../joint_score_func/data/keyword.txt ../joint_score_func/data/co_occur.txt ../joint_score_func/data/occur.json"

In [None]:
# Knowledge Graph filtering
# Keeps only the triples whose two entities both match a known keyword, then
# writes the id->text maps and the filtered triples under data/.

# Load known cs keywords
kw_set = set(my_read('data/keyword.txt'))
# Get potential cs entity id (entity_names.txt rows are tab-separated (id, name) pairs;
# names are lowercased before matching against the keyword set)
eid_set = {eid for eid, ent in my_csv_read('../data/raw_data/wikidata/entity_names.txt', delimiter='\t') if ent.lower() in kw_set}
# Get the subgraph that both entities are potential cs keywords
# (triples.txt rows are space-separated and unpacked as (entity id, entity id, relation id))
kg_cs_triples = [(eid1, eid2, rid) for eid1, eid2, rid in my_csv_read('../data/raw_data/wikidata/triples.txt', delimiter=' ') if eid1 in eid_set and eid2 in eid_set]
# Get cs entities and relations from subgraph
cs_eid_set = set()
cs_rid_set = set()
for eid1, eid2, rid in kg_cs_triples:
    cs_eid_set.update((eid1, eid2))
    cs_rid_set.add(rid)
# Map id to text (restricted to the ids that actually appear in the subgraph)
eid2ent_dict = {eid:ent.lower() for eid, ent in my_csv_read('../data/raw_data/wikidata/entity_names.txt', delimiter='\t') if eid in cs_eid_set}
rid2rel_dict = {rid:rel.lower() for rid, rel in my_csv_read('../data/raw_data/wikidata/relation_names.txt', delimiter='\t') if rid in cs_rid_set}
# Save files
# Context managers close (and therefore flush) each file before the next cell
# reads it back, instead of relying on garbage collection of the handles.
with open('data/eid2ent.json', 'w') as eid2ent_file:
    json.dump(eid2ent_dict, eid2ent_file)
with open('data/rid2rel.json', 'w') as rid2rel_file:
    json.dump(rid2rel_dict, rid2rel_file)
# newline='' is what the csv module asks for so the writer controls line endings itself.
with open('data/kg_cs_triples.csv', 'w', newline='') as triples_file:
    csv.writer(triples_file).writerows(kg_cs_triples)

## Load data

In [None]:
# Reload the artefacts written by the knowledge-graph filtering cell.
eid2ent_dict = my_json_read('data/eid2ent.json')
rid2rel_dict = my_json_read('data/rid2rel.json')
kg_cs_triples = my_csv_read('data/kg_cs_triples.csv')
# Unique relation texts in sorted order. Iterating a set of strings yields a
# different order on each interpreter start (hash randomisation), which would
# make any index-based use of rel_list irreproducible across kernel restarts.
rel_list = sorted(set(rid2rel_dict.values()))

In [None]:
# Preview the leading graph triples (indices 0 through 10) as
# "<entity>--<entity>--<relation>" text, translating ids through the lookup maps.
for i, item in enumerate(kg_cs_triples):
    ent1 = eid2ent_dict[item[0]]
    ent2 = eid2ent_dict[item[1]]
    rel = rid2rel_dict[item[2]]
    print('%s--%s--%s' % (ent1, ent2, rel))
    # Stop once the row with index 10 has been printed.
    if i >= 10:
        break

## Set up the models

In [None]:
# Instantiate the retriever with the sentence corpus file and data/occur.json
# (occur.json is the output path of the gen_occur.py command noted earlier in this notebook).
sparse_retriever = SparseRetrieveSentForPairCoOccur('../data/corpus/small_sent.txt', 'data/occur.json')

In [None]:
# Smoke test: retrieve for the pair ('java', 'python'), then show how many
# results came back and the first one.
# NOTE(review): test_sent[0] will fail if nothing is retrieved for this pair.
test_sent = sparse_retriever.retrieve('java', 'python')
print(len(test_sent))
print(test_sent[0])

In [None]:
# Some helper functions

def train(kg_cs_triples, eid2ent_dict, rid2rel_dict, retriever, sf):
    """Run one retrieve-and-score pass for the first knowledge-graph triple.

    Work in progress: the loop exits after its first iteration and nothing is
    returned or stored, so the function currently only exercises the pipeline.

    Args:
        kg_cs_triples: iterable of (entity id, entity id, relation id) triples.
        eid2ent_dict: lookup from entity id to entity text.
        rid2rel_dict: lookup from relation id to relation text.
        retriever: object whose ``retrieve(ent1, ent2)`` result can be indexed
            by position.
        sf: callable invoked as ``sf(sentences, query)``; its result must
            support ``.unsqueeze(0)`` (presumably a torch tensor of scores --
            TODO confirm).

    Returns:
        None.

    Raises:
        KeyError: when the lookups are dicts and an id of the processed triple
            is missing from them.
    """
    for head_id, tail_id, rel_id in kg_cs_triples:
        head = eid2ent_dict[head_id]
        tail = eid2ent_dict[tail_id]
        # Looked up but not used yet; kept so an unknown relation id still fails here.
        relation = rid2rel_dict[rel_id]

        sentences = retriever.retrieve(head, tail)
        query = '%s <RELATION> %s' % (head, tail)

        # First pass: score every candidate in chunks of 16 without tracking gradients.
        chunk_scores = []
        with torch.no_grad():
            for chunk in batch(sentences, 16):
                chunk_scores.append(sf(chunk, query))

        # NOTE(review): chunk_scores holds one entry per chunk, while the indices
        # returned below are used to index individual sentences -- confirm that
        # ntopidx flattens its input, otherwise the selection is per chunk.
        best_positions = ntopidx(5, chunk_scores)
        shortlist = [sentences[pos] for pos in best_positions]

        # Second pass over the shortlist, outside the no_grad block.
        shortlist_score = sf(shortlist, query).unsqueeze(0)
        torch.cuda.empty_cache()
        break

In [None]:
# ScoreFunction1 on 'bert-base-uncased', with '<RELATION>' passed as an additional
# special token (the same marker that train() places between the two entities in its query).
sf1 = ScoreFunction1('bert-base-uncased', additional_special_tokens=['<RELATION>'])

In [None]:
# ScoreFunction1 Example
# Runs the demo helper for the pair ('python', 'java') with sf1 and the sparse
# retriever; the helper's result is unpacked into scores and candidate_sents.
scores, candidate_sents = demo_score_function('python', 'java', sf1, sparse_retriever)

In [None]:
# ScoreFunction2 takes two model names (both 'bert-base-uncased' here) plus the
# same '<RELATION>' additional special token that sf1 is given.
sf2 = ScoreFunction2('bert-base-uncased', 'bert-base-uncased', additional_special_tokens=['<RELATION>'])

In [None]:
# ScoreFunction2 Example
# Same demo call as the ScoreFunction1 example, with sf2.
# NOTE: this overwrites scores and candidate_sents from the ScoreFunction1 example cell.
scores, candidate_sents = demo_score_function('python', 'java', sf2, sparse_retriever)

In [None]:
# Reader1 on 'bert-base-uncased', given rel_list (the unique relation texts
# collected in the "Load data" cell).
reader1 = Reader1('bert-base-uncased', rel_list)

In [None]:
# Reader1 Example
# Runs the demo helper for the pair ('python', 'java') with reader1 and the
# sparse retriever, then prints whatever it returns.
rels = demo_reader('python', 'java', reader1, sparse_retriever)
print(rels)

In [None]:
# Reader2 is constructed with the same arguments as reader1.
# NOTE(review): no cell visible in this part of the notebook uses reader2 yet.
reader2 = Reader2('bert-base-uncased', rel_list)

In [None]:
# Plain pretrained BERT encoder and its tokenizer.
# NOTE(review): m and t are not referenced by any other cell visible here, and t
# loads the same checkpoint as the `tokenizer` created in the following cell.
m = BertModel.from_pretrained('bert-base-uncased')
t = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Pick the GPU when CUDA is available (CPU otherwise) and load the
# 'bert-base-uncased' tokenizer.
# NOTE(review): device is not used by any cell visible in this part of the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Initialize Bert2Bert from pre-trained checkpoints: 'bert-base-uncased' is used
# for both the encoder and the decoder.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
# Special-token ids come from the BERT tokenizer: [CLS] starts decoding, [SEP] ends it.
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
# The encoder-decoder config also needs pad_token_id: it is used when labels are
# shifted to build decoder inputs during training.
bert2bert.config.pad_token_id = tokenizer.pad_token_id

## Training