In [1]:
import json
import csv
import os
from transformers import EncoderDecoderModel, BertTokenizer, BertModel, BertForNextSentencePrediction, BatchEncoding
import torch
from torch.nn import Softmax, Linear, CrossEntropyLoss
from torch.nn.functional import normalize
from typing import List, Iterable
from collections import defaultdict
import sys

sys.path.append('..')
from tools.TextProcessing import build_word_tree, process_keywords, batched_sent_tokenize, clean_text
from tools.BasicUtils import my_write, my_csv_read, my_read, my_json_read, ntopidx

In [None]:
# Sanity check of CrossEntropyLoss on random data: 3 samples, 5 classes.
# A dedicated seeded generator makes the cell reproducible without touching the
# global RNG, and `logits` avoids shadowing the builtin input().
demo_gen = torch.Generator().manual_seed(0)
loss = CrossEntropyLoss()
logits = torch.randn(3, 5, generator=demo_gen, requires_grad=True)
target = torch.randint(5, (3,), generator=demo_gen)  # int64 class indices in [0, 5)
output = loss(logits, target)
output.backward()

In [None]:
# Display the class indices sampled in the CrossEntropyLoss check above.
target

In [None]:
# Describe every file this notebook (and the helper scripts it calls) writes
# into ../data/temp/joint_score_function.
file_description = [
    "keyword_f.txt ---- CS keywords",
    "wordtree.json ---- word tree for cs keywords",
    "entity.txt ---- Reformed cs keywords with '_' replacing ' '",
    "co_occur.txt ---- output of py/gen_co_occur.py",
    "occur.json ---- output of py/gen_occur.py, loaded by the sentence retriever",
    "eid2ent.json ---- wikidata entity id -> lowercased entity name, for entities in kg_cs_triples.csv",
    "rid2rel.json ---- wikidata relation id -> lowercased relation name, for relations in kg_cs_triples.csv",
    "kg_cs_triples.csv ---- wikidata triples (eid1, eid2, rid) whose two entities are both cs keywords",
]

# makedirs also creates missing parent folders (e.g. ../data/temp), and
# exist_ok replaces the racy exists-then-mkdir check.
os.makedirs('../data/temp/joint_score_function', exist_ok=True)

my_write('../data/temp/joint_score_function/readme.txt', file_description)

## Prepare the data

In [None]:
# Collect keywords from terms-cs-cfl-epoch200.txt: keep the first-column term of
# every tab-separated row whose second-column value exceeds the threshold.
KEYWORD_SCORE_THRESHOLD = 0.1
term_rows = my_csv_read('../data/raw_data/terms-cs-cfl-epoch200.txt', delimiter='\t')
candidate_kw_list = [row[0] for row in term_rows if float(row[1]) > KEYWORD_SCORE_THRESHOLD]
# process_keywords returns two lists; only the first (stable) one is saved.
stable_kw, unstable_kw = process_keywords(candidate_kw_list)
# Save keywords
my_write('../data/temp/joint_score_function/keyword_f.txt', stable_kw)
# Generate word tree (25 seconds)
build_word_tree('../data/temp/joint_score_function/keyword_f.txt', '../data/temp/joint_score_function/wordtree.json', '../data/temp/joint_score_function/entity.txt')

In [None]:
# From the py folder, run the following commands in the background (e.g. in a separate terminal).
# They produce co_occur.txt and occur.json; occur.json is loaded by the sentence retriever below.
# "python gen_co_occur.py ../data/temp/joint_score_function/wordtree.json ../data/corpus/small_sent.txt ../data/temp/joint_score_function/co_occur.txt"
# "python gen_occur.py ../data/temp/joint_score_function/keyword_f.txt ../data/temp/joint_score_function/co_occur.txt ../data/temp/joint_score_function/occur.json"

In [None]:
# Knowledge Graph filtering: keep the Wikidata triples whose two entities are
# both cs keywords, then save the id->text maps and the filtered triples.

# Load known cs keywords
kw_set = set(my_read('../data/temp/joint_score_function/keyword_f.txt'))
# Ids of Wikidata entities whose lowercased name is a known cs keyword
eid_set = {eid for eid, ent in my_csv_read('../data/raw_data/wikidata/entity_names.txt', delimiter='\t') if ent.lower() in kw_set}
# Subgraph: triples whose two entities are both potential cs keywords
kg_cs_triples = [(eid1, eid2, rid) for eid1, eid2, rid in my_csv_read('../data/raw_data/wikidata/triples.txt', delimiter=' ') if eid1 in eid_set and eid2 in eid_set]
# Entities and relations that actually appear in the subgraph
cs_eid_set = set()
cs_rid_set = set()
for eid1, eid2, rid in kg_cs_triples:
    cs_eid_set.update((eid1, eid2))
    cs_rid_set.add(rid)
# Map id to lowercased text. entity_names.txt is read a second time because
# cs_eid_set is only known after the triples have been filtered.
eid2ent_dict = {eid:ent.lower() for eid, ent in my_csv_read('../data/raw_data/wikidata/entity_names.txt', delimiter='\t') if eid in cs_eid_set}
rid2rel_dict = {rid:rel.lower() for rid, rel in my_csv_read('../data/raw_data/wikidata/relation_names.txt', delimiter='\t') if rid in cs_rid_set}
# Save files. `with` flushes and closes each file right away, so the
# "Load data" cell can read them back. The csv module requires newline='' on
# files handed to csv.writer; otherwise an extra '\r' is written per row on
# platforms with '\r\n' line endings.
with open('../data/temp/joint_score_function/eid2ent.json', 'w') as out_f:
    json.dump(eid2ent_dict, out_f)
with open('../data/temp/joint_score_function/rid2rel.json', 'w') as out_f:
    json.dump(rid2rel_dict, out_f)
with open('../data/temp/joint_score_function/kg_cs_triples.csv', 'w', newline='') as out_f:
    csv.writer(out_f).writerows(kg_cs_triples)

## Load data

In [None]:
eid2ent_dict = my_json_read('../data/temp/joint_score_function/eid2ent.json')
rid2rel_dict = my_json_read('../data/temp/joint_score_function/rid2rel.json')
# Materialized as a list because both the preview cell and the training loop
# iterate over it. Empty rows (blank lines in the CSV) are dropped.
# NOTE(review): confirm whether my_csv_read returns a list or a one-shot iterator.
kg_cs_triples = [row for row in my_csv_read('../data/temp/joint_score_function/kg_cs_triples.csv') if row]
# Sorted so that each relation gets the same class index on every run. The
# iteration order of a set of strings depends on the per-process hash seed,
# and the readers assign class indices by position in this list.
rel_list = sorted(set(rid2rel_dict.values()))

In [None]:
# Preview the first few triples as "entity1--entity2--relation".
n_preview = 10  # number of triples to print
for i, (eid1, eid2, rid) in enumerate(kg_cs_triples):
    # Check before printing, so exactly n_preview lines are shown.
    if i >= n_preview:
        break
    ent1, ent2, rel = eid2ent_dict[eid1], eid2ent_dict[eid2], rid2rel_dict[rid]
    print('%s--%s--%s' % (ent1, ent2, rel))

## Set up the models

In [None]:
class RetrieveSentForPairCoOccur:
    """Look up corpus sentences in which two keywords co-occur.

    Args:
        sent_file: sentence corpus, read with ``my_read`` and indexed by position.
        occur_file: JSON object mapping each keyword to a list of indices into
            the sentences of ``sent_file``.
    """
    def __init__(self, sent_file:str, occur_file:str):
        self._sents = my_read(sent_file)
        # Use `with` so the JSON file handle is closed right after loading.
        with open(occur_file) as occur_f:
            self._occur_dict = {k: set(v) for k, v in json.load(occur_f).items()}

    def retrieve(self, kw1:str, kw2:str):
        """Return the sentences containing both ``kw1`` and ``kw2``, in corpus order.

        Unknown keywords yield an empty list. ``.get`` is used so that lookups of
        unknown keywords do not insert empty entries into the index.
        """
        co_occur_index = self._occur_dict.get(kw1, set()) & self._occur_dict.get(kw2, set())
        return [self._sents[idx] for idx in sorted(co_occur_index)]

# Build the keyword co-occurrence retriever over the small sentence corpus.
retriever = RetrieveSentForPairCoOccur(
    sent_file='../data/corpus/small_sent.txt',
    occur_file='../data/temp/joint_score_function/occur.json',
)

In [None]:
class ScoreFunction1(torch.nn.Module):
    """Score candidate sentences against a query with BERT's next-sentence-prediction head.

    Each candidate is paired with the query (candidate as sentence A, query as
    sentence B). The score is the predicted probability that the query is a
    continuation of the candidate.

    Args:
        model_file: checkpoint name or path accepted by ``from_pretrained``.
        additional_special_tokens: extra tokens (e.g. ``'<RELATION>'``) added to
            the tokenizer; the model's embedding matrix is resized to match.
        device: torch device string; defaults to CUDA when available, else CPU.
        max_length: truncation length for each candidate/query pair.
    """
    def __init__(self, model_file:str, additional_special_tokens:List[str]=None, device:str=None, max_length:int=80):
        super().__init__()
        self._score_function = BertForNextSentencePrediction.from_pretrained(model_file)
        self._tokenizer = BertTokenizer.from_pretrained(model_file)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        self._max_length = max_length
        self._sm = Softmax(1)
        
        if additional_special_tokens is not None:
            self._tokenizer.add_special_tokens({'additional_special_tokens' : additional_special_tokens})
            self._score_function.resize_token_embeddings(len(self._tokenizer))
        self._score_function.to(self._device)

    def forward(self, candidate_sents:List[str], query:str):
        """Return a 1-D tensor with one score in [0, 1] per candidate sentence."""
        if not candidate_sents:
            # Candidate lists can be empty (e.g. keyword pairs with no co-occurring
            # sentence). The tokenizer and model cannot process an empty batch.
            return torch.empty(0, device=self._device)
        inputs = self._tokenizer(candidate_sents, [query]*len(candidate_sents), padding=True, truncation=True, max_length=self._max_length, return_tensors="pt").to(self._device)
        # No labels are passed: only the logits are used, so computing the NSP loss would be wasted work.
        logits = self._score_function(**inputs).logits
        # BertForNextSentencePrediction: class 0 = "B continues A", class 1 = "B is random".
        # Relevance of the query to the candidate is therefore P(class 0).
        return self._sm(logits)[:, 0]
        
# NSP-based scorer; '<RELATION>' joins the two entities in the query text.
sf1 = ScoreFunction1(
    model_file='bert-base-uncased',
    additional_special_tokens=['<RELATION>'],
)

In [None]:
class ScoreFunction2(torch.nn.Module):
    """Dual-encoder scorer: cosine similarity of candidate and query [CLS] embeddings.

    Candidates and the query share one tokenizer but are encoded by separate
    BERT models. The [CLS] vectors are L2-normalized and compared with a dot
    product.

    Args:
        context_model_file: checkpoint for the candidate-sentence encoder.
        query_model_file: checkpoint for the query encoder and the tokenizer.
        additional_special_tokens: extra tokens added to the tokenizer; only the
            query encoder's embeddings are resized for them.
        device: torch device string; defaults to CUDA when available, else CPU.
        context_max_length: truncation length for candidate sentences.
        query_max_length: truncation length for the query.
    """
    def __init__(self, context_model_file:str, query_model_file:str, additional_special_tokens:List[str]=None, device:str=None, context_max_length:int=80, query_max_length:int=20):
        super().__init__()
        self._context_encoder = BertModel.from_pretrained(context_model_file)
        self._query_encoder = BertModel.from_pretrained(query_model_file)
        self._tokenizer = BertTokenizer.from_pretrained(query_model_file)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        self._context_max_length = context_max_length
        self._query_max_length = query_max_length
        
        if additional_special_tokens is not None:
            self._tokenizer.add_special_tokens({'additional_special_tokens' : additional_special_tokens})
            # NOTE(review): the context encoder is not resized, so a candidate
            # sentence containing one of these tokens would likely index past its
            # embedding table.
            self._query_encoder.resize_token_embeddings(len(self._tokenizer))
        self._query_encoder.to(self._device)
        self._context_encoder.to(self._device)

    def forward(self, candidate_sents:List[str], query:str):
        """Return an (n, 1) tensor of cosine similarities, one row per candidate sentence."""
        if not candidate_sents:
            # Keep the (n, 1) shape for an empty candidate list; the tokenizer and
            # encoder cannot process an empty batch.
            return torch.empty(0, 1, device=self._device)
        context_inputs = self._tokenizer(candidate_sents, padding=True, truncation=True, max_length=self._context_max_length, return_tensors="pt").to(self._device)
        query_inputs = self._tokenizer(query, padding=True, truncation=True, max_length=self._query_max_length, return_tensors="pt").to(self._device)
        context_emb = normalize(self._context_encoder(**context_inputs).last_hidden_state[:, 0, :])  # (n, hidden)
        query_emb = normalize(self._query_encoder(**query_inputs).last_hidden_state[:, 0, :])  # (1, hidden)
        return torch.inner(context_emb, query_emb)

# Dual-encoder scorer; both towers start from the same BERT checkpoint.
sf2 = ScoreFunction2(
    context_model_file='bert-base-uncased',
    query_model_file='bert-base-uncased',
    additional_special_tokens=['<RELATION>'],
)

In [None]:
class Reader1(torch.nn.Module):
    """Classify the relation of an entity pair from a set of scored sentences.

    Sentence [CLS] embeddings are pooled into one vector, using the retrieval
    scores as weights. A linear layer then maps that vector to one logit per
    relation in ``rels``.

    Args:
        encoder_model: BERT checkpoint name or path for the encoder and tokenizer.
        rels: relation names; the position in this list is the class index.
        device: torch device string; defaults to CUDA when available, else CPU.
        max_length: truncation length for each sentence.
    """
    def __init__(self, encoder_model:str, rels:List[str], device:str=None, max_length:int=80):
        super().__init__()
        self._rel2cls = {rel:i for i, rel in enumerate(rels)}
        self._rels = rels
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        self._encoder = BertModel.from_pretrained(encoder_model).to(self._device)
        # Size the classifier from the encoder config instead of assuming 768 (bert-base).
        self._classifier = Linear(self._encoder.config.hidden_size, len(self._rel2cls), device=self._device)
        self._tokenizer = BertTokenizer.from_pretrained(encoder_model)
        self._max_length = max_length
        self._loss_cal = CrossEntropyLoss()

    def forward(self, sents:List[str], score:torch.Tensor, rel:str=None):
        """Predict the relation expressed by ``sents``.

        Args:
            sents: candidate sentences (non-empty).
            score: one weight per sentence, shape (n,) or (n, 1).
            rel: gold relation; when given, the cross-entropy loss is also returned.

        Returns:
            The predicted relation name, or ``(predicted_relation, loss)`` when ``rel`` is given.

        Raises:
            ValueError: if ``sents`` is empty or ``score`` does not have one entry per sentence.
        """
        if not sents:
            raise ValueError('sents must not be empty')
        # Accept both the (n,) and (n, 1) score shapes.
        weights = score.reshape(-1)
        if weights.numel() != len(sents):
            raise ValueError('expected one score per sentence, got %d scores for %d sentences' % (weights.numel(), len(sents)))
        inputs = self._tokenizer(sents, padding=True, truncation=True, max_length=self._max_length, return_tensors="pt").to(self._device)
        sents_emb = self._encoder(**inputs).last_hidden_state[:, 0, :]  # (n, hidden) [CLS] vectors
        # Score-weighted sum over the sentence axis -> (hidden,). Gradients flow
        # back into `score`.
        merged_emb = weights.to(device=sents_emb.device, dtype=sents_emb.dtype) @ sents_emb
        cls_ret = self._classifier(merged_emb)  # (num_rels,)
        # argmax of the logits equals argmax of their softmax.
        pred_rel = self._rels[int(torch.argmax(cls_ret))]
        if rel is None:
            return pred_rel
        target = torch.tensor([self._rel2cls[rel]], dtype=torch.long, device=cls_ret.device)
        return pred_rel, self._loss_cal(cls_ret.unsqueeze(0), target)

# Relation reader over the (deterministically ordered) relation list.
reader1 = Reader1(encoder_model='bert-base-uncased', rels=rel_list)

In [None]:
class Reader2(torch.nn.Module):
    """Classify the relation of an entity pair from a set of scored sentences.

    Sentence [CLS] embeddings are pooled into one vector, using the retrieval
    scores as weights. A linear layer then maps that vector to one logit per
    relation in ``rels``.

    NOTE(review): this class currently duplicates Reader1 line for line; consider
    subclassing Reader1 or removing it until the two readers actually differ.

    Args:
        encoder_model: BERT checkpoint name or path for the encoder and tokenizer.
        rels: relation names; the position in this list is the class index.
        device: torch device string; defaults to CUDA when available, else CPU.
        max_length: truncation length for each sentence.
    """
    def __init__(self, encoder_model:str, rels:List[str], device:str=None, max_length:int=80):
        super().__init__()
        self._rel2cls = {rel:i for i, rel in enumerate(rels)}
        self._rels = rels
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        self._encoder = BertModel.from_pretrained(encoder_model).to(self._device)
        # Size the classifier from the encoder config instead of assuming 768 (bert-base).
        self._classifier = Linear(self._encoder.config.hidden_size, len(self._rel2cls), device=self._device)
        self._tokenizer = BertTokenizer.from_pretrained(encoder_model)
        self._max_length = max_length
        self._loss_cal = CrossEntropyLoss()

    def forward(self, sents:List[str], score:torch.Tensor, rel:str=None):
        """Predict the relation expressed by ``sents``.

        Args:
            sents: candidate sentences (non-empty).
            score: one weight per sentence, shape (n,) or (n, 1).
            rel: gold relation; when given, the cross-entropy loss is also returned.

        Returns:
            The predicted relation name, or ``(predicted_relation, loss)`` when ``rel`` is given.

        Raises:
            ValueError: if ``sents`` is empty or ``score`` does not have one entry per sentence.
        """
        if not sents:
            raise ValueError('sents must not be empty')
        # Accept both the (n,) and (n, 1) score shapes.
        weights = score.reshape(-1)
        if weights.numel() != len(sents):
            raise ValueError('expected one score per sentence, got %d scores for %d sentences' % (weights.numel(), len(sents)))
        inputs = self._tokenizer(sents, padding=True, truncation=True, max_length=self._max_length, return_tensors="pt").to(self._device)
        sents_emb = self._encoder(**inputs).last_hidden_state[:, 0, :]  # (n, hidden) [CLS] vectors
        # Score-weighted sum over the sentence axis -> (hidden,). Gradients flow
        # back into `score`.
        merged_emb = weights.to(device=sents_emb.device, dtype=sents_emb.dtype) @ sents_emb
        cls_ret = self._classifier(merged_emb)  # (num_rels,)
        # argmax of the logits equals argmax of their softmax.
        pred_rel = self._rels[int(torch.argmax(cls_ret))]
        if rel is None:
            return pred_rel
        target = torch.tensor([self._rel2cls[rel]], dtype=torch.long, device=cls_ret.device)
        return pred_rel, self._loss_cal(cls_ret.unsqueeze(0), target)

# Second relation reader over the same relation list.
reader2 = Reader2(encoder_model='bert-base-uncased', rels=rel_list)

In [None]:
# Scratch BERT encoder/tokenizer handles for interactive inspection.
# NOTE(review): `m` and `t` appear unused by the rest of the notebook.
m, t = (
    BertModel.from_pretrained('bert-base-uncased'),
    BertTokenizer.from_pretrained('bert-base-uncased'),
)

In [None]:
# Pick the compute device and load the tokenizer whose special-token ids are
# used to configure the Bert2Bert model in the next cell.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Warm-start a BERT-to-BERT encoder-decoder from two bert-base-uncased checkpoints.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
# Decoding starts from [CLS] and ends at [SEP].
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
# EncoderDecoderModel needs pad_token_id to build decoder inputs from `labels`
# during training; it raises a ValueError when the id is unset.
bert2bert.config.pad_token_id = tokenizer.pad_token_id

## Training

In [None]:
TOP_K = 10  # number of highest-scoring sentences kept per triple
for eid1, eid2, rid in kg_cs_triples:
    ent1, ent2, rel = eid2ent_dict[eid1], eid2ent_dict[eid2], rid2rel_dict[rid]
    sents = retriever.retrieve(ent1, ent2)
    if not sents:
        # No sentence mentions both entities, so there is nothing to score.
        continue
    score = sf1(sents, ' '.join((ent1, '<RELATION>', ent2)))
    # torch.topk works directly on the (possibly CUDA, grad-tracking) score
    # tensor. k is capped because a pair may have fewer than TOP_K sentences.
    # Results are ordered by descending score.
    top_score, top_idx = torch.topk(score, k=min(TOP_K, score.numel()))
    sub_sents = [sents[idx] for idx in top_idx.tolist()]
    # Kept as a tensor (not a list of scalars) so it can be passed on as a
    # reader's `score` argument, with gradients flowing back to sf1.
    sub_score = top_score
    