In [32]:
import sys
sys.path.append('..')
from tools.DocProcessing.CoOccurrence import co_occur_load
from tools.BasicUtils import MultiProcessing, my_read, my_write
from tools.DocProcessing.CoOccurGraph import graph_load, get_subgraph
# from tools.DocProcessing.DatasetGenerator import TriEntities
import time
import pandas as pd
import csv

In [None]:
# Load the three inputs handed to TriEntities: the keyword list, the per-line
# keyword-index lists, and the keyword-pair graph.
keyword_list = my_read('../data/corpus/entity.txt')
co_occur_list = co_occur_load('../data/corpus/ent_co_occur.txt')
# NOTE(review): 0.3 and 3 are passed straight to get_subgraph; presumably pruning
# thresholds -- confirm against tools.DocProcessing.CoOccurGraph.
pair_graph = get_subgraph(graph_load('../data/corpus/ent_pair.gpickle'), 0.3, 3)

In [28]:
from typing import List
import networkx as nx
import spacy
# Module-level spaCy pipeline; TriEntities.line_operation calls it to read each token's dependency label.
nlp = spacy.load('en_core_web_sm')
class TriEntities:
    """Collect sentences where one keyword is paired with both a subject and an object keyword.

    Each call to ``line_operation`` inspects one ``"<line id>:<sentence>"`` line and,
    when the sentence qualifies, appends ``(sentence, connecting keyword, subject
    keyword, object keyword)`` tuples to ``self.line_record``.

    Relies on the module-level spaCy pipeline ``nlp`` for dependency labels.
    """

    def __init__(self, entity_list:List[str], co_occur_list:List[List[int]], pair_graph:nx.Graph, kw_dist_max:int=6, sent_length_max:int=32):
        """Store the lookup structures and thresholds.

        Args:
            entity_list: Keywords; the integers in ``co_occur_list`` index into it.
            co_occur_list: One list of keyword indices per line, looked up with ``line id - 1``.
            pair_graph: Graph whose edges (between keyword indices) mark admissible pairs.
            kw_dist_max: Largest allowed token distance between the two keywords of a pair.
            sent_length_max: Sentences with more whitespace tokens than this are ignored.
        """
        self.line_record = []
        self.co_occur_list = co_occur_list
        self.pair_graph = pair_graph
        self.keyword_list = entity_list
        self.keyword_set = set(entity_list)
        self.sent_length_max = sent_length_max
        self.kw_dist_max = kw_dist_max

    def line_operation(self, line:str):
        """Inspect one corpus line and record any subject/object triples found in it.

        Args:
            line: Text of the form ``"<1-based line id>:<sentence>"``. Everything after
                the first ``:`` is kept verbatim as the sentence.

        Returns:
            None. Matches are appended to ``self.line_record``.
        """
        line_id, sent = line.split(':', 1)
        words = sent.split()
        if len(words) > self.sent_length_max:
            return

        vocab = self.keyword_list

        # Keep only the keyword indices whose surface form occurs exactly once.
        unique_ids = []
        for kw_id in self.co_occur_list[int(line_id)-1]:
            if words.count(vocab[kw_id]) == 1:
                unique_ids.append(kw_id)
        if len(unique_ids) < 3:
            return

        # Pairs that share a graph edge and sit within kw_dist_max tokens of each other.
        close_pairs = []
        for left, first_id in enumerate(unique_ids[:-1]):
            for second_id in unique_ids[left+1:]:
                if not self.pair_graph.has_edge(first_id, second_id):
                    continue
                first_kw, second_kw = vocab[first_id], vocab[second_id]
                if abs(words.index(first_kw) - words.index(second_kw)) <= self.kw_dist_max:
                    close_pairs.append((first_kw, second_kw))
        if len(close_pairs) < 2:
            return

        # A keyword that takes part in at least two pairs is a candidate connector;
        # remember it together with the keywords it is paired with.
        names = [vocab[kw_id] for kw_id in unique_ids]
        hubs = []
        for name in names:
            touching = [pair for pair in close_pairs if name in pair]
            if len(touching) < 2:
                continue
            partners = set()
            for pair in touching:
                partners.update(pair)
            partners.remove(name)
            hubs.append((name, partners))
        if not hubs:
            return

        # From here on positions refer to spaCy's tokenisation, not the whitespace split.
        doc = nlp(sent)
        spacy_tokens = [token.text for token in doc]
        names = [name for name in names if name in spacy_tokens]
        for media_kw, partners in hubs:
            subject = None
            obj = None
            for partner in partners:
                if partner not in names:
                    continue
                dep = doc[spacy_tokens.index(partner)].dep_
                if dep == 'nsubj':
                    subject = partner
                elif dep == 'dobj':
                    obj = partner
            # Record only when both a subject and a direct object were found.
            if subject is not None and obj is not None:
                self.line_record.append((sent, media_kw, subject, obj))
            

In [35]:
import numpy as np
def mark_sent_in_html(sent:str, keyword_list:List[str], is_entity:bool=True):
    """Wrap every occurrence of the given keywords in red ``<font>`` tags.

    Args:
        sent: Sentence to mark up.
        keyword_list: Keywords to highlight.
        is_entity: When True, the sentence is split on whitespace and each keyword is
            matched as one whole token. When False, hyphens are spaced out, the sentence
            is passed through ``sent_lemmatize`` and keywords are matched as token
            sequences. NOTE(review): ``sent_lemmatize`` is not defined or imported in
            the visible cells -- confirm it is in scope before using this mode.

    Returns:
        The tokens joined by single spaces, with an opening tag token before and a
        ``</font>`` token after every maximal run of highlighted tokens.
    """
    reformed_sent = sent.split() if is_entity else sent_lemmatize(sent.replace('-', ' - '))
    reformed_keywords = [[k] for k in keyword_list] if is_entity else [k.replace('-', ' - ').split() for k in keyword_list]
    # The np.bool alias was removed in NumPy 1.24; the builtin bool is the same dtype.
    mask = np.zeros(len(reformed_sent), dtype=bool)
    for k in reformed_keywords:
        if not k:
            # A blank keyword yields no tokens, so there is nothing to search for.
            continue
        span = len(k)
        begin_idx = 0
        while k[0] in reformed_sent[begin_idx:]:
            begin_idx = reformed_sent.index(k[0], begin_idx)
            if reformed_sent[begin_idx:begin_idx+span] == k:
                mask[begin_idx:begin_idx+span] = True
                begin_idx += span
            else:
                # Step a single token after a failed match so that an occurrence
                # overlapping the partial match is not skipped.
                begin_idx += 1

    # Emit the tokens, opening a tag when a highlighted run starts and closing it when it ends.
    marked = []
    in_span = False
    for token, hit in zip(reformed_sent, mask):
        if hit and not in_span:
            marked.append('<font style=\"color:red;\">')
        elif in_span and not hit:
            marked.append('</font>')
        marked.append(token)
        in_span = bool(hit)
    if in_span:
        marked.append('</font>')
    return ' '.join(marked)

In [29]:
te = TriEntities(keyword_list, co_occur_list, pair_graph)
# Stream the corpus one line at a time: the handle is closed when the block exits
# and the whole file is never held in memory at once.
with open('../data/corpus/small_sent_line.txt') as sent_file:
    for idx, line in enumerate(sent_file):
        te.line_operation(line)
        # Progress marker every 100000 lines.
        if idx % 100000 == 0:
            print(idx)

0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
1700000
1800000
1900000


In [30]:
# Number of (sentence, connecting keyword, subject, object) records collected in te.line_record.
len(te.line_record)

1135

In [38]:
# Render each record as a heading naming the three keywords, followed by the
# sentence with those keywords highlighted, and hand the pieces to my_write.
content = []
for sent, media_kw, subj_text, obj_text in te.line_record:
    heading = '<h3>subject: %s, object: %s, connection: %s</h3><br>' % (subj_text, obj_text, media_kw)
    highlighted = mark_sent_in_html(sent, [subj_text, obj_text, media_kw])
    content.extend((heading, '%s<br><br>' % highlighted))
my_write('triples.html', content)


In [None]:
# Each record in te.line_record is a 4-tuple (sentence, connecting keyword, subject,
# object), so four column names are needed; passing only three makes pandas raise
# a ValueError.
df = pd.DataFrame(te.line_record, columns=['sent', 'media_kw', 'subj_text', 'obj_text'])
df.to_csv('../data/corpus/connected_kws.csv')

In [None]:
# Draw up to 1000 rows in random order. The fixed random_state keeps the sample
# identical across re-runs; sampling n rows avoids shuffling the whole frame first.
test_lines = df.sample(n=min(1000, len(df)), random_state=0).reset_index(drop=True)

In [None]:
# Spot check: is 'python' one of the connecting keywords in the sampled rows?
'python' in set(test_lines['media_kw'])