# Co-occurrence notebook
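+ This notebook handles keyword co-occurrence work: building the occurrence dictionary, the co-occurrence list and the co-occurrence graph, and exploring highly related keyword pairs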

## Import needed packages

In [1]:
import sys
import math
from typing import Dict, List
import numpy as np
import spacy
nlp = spacy.load('en_core_web_sm')

sys.path.append('..')
from tools.BasicUtils import my_read, my_write, MultiProcessing
from tools.TextProcessing import sent_lemmatize
from tools.DocProcessing.Occurrence import Occurrence, occurrence_dump, occurrence_load, occurrence_post_operation
from tools.DocProcessing.CoOccurGraph import build_graph, graph_load, graph_dump, get_subgraph
from tools.DocProcessing.CoOccurrence import gen_co_occur, co_occur_load, co_occur_dump
import networkx as nx

## Fundamental code

### Load fundamental data (50 seconds)

In [2]:
# Corpus sentences and the entity vocabulary, read from plain-text files
sent_list = my_read('../data/corpus/small_sent_reformed.txt')
keyword_list = my_read('../data/corpus/entity.txt')

# Reverse lookup: keyword -> its position in keyword_list
word2idx_dict = {kw: idx for idx, kw in enumerate(keyword_list)}

# Pre-computed artefacts; the "if needed" cells below regenerate them
pair_graph = graph_load('../data/corpus/ent_pair.gpickle')
co_occur_list = co_occur_load('../data/corpus/ent_co_occur.txt')
occur_dict = occurrence_load('../data/corpus/ent_occur.json')

### Generate occurrence dictionary if needed

In [11]:
# Generate sentence file with line number
# `grep -n ''` matches every line (the empty pattern always matches) and prefixes each one
# with "<line number>:"; the result is written to small_sent_line.txt for the occurrence step.
!grep -n '' ../data/corpus/small_sent_reformed.txt > ../data/corpus/small_sent_line.txt

In [None]:
# Generate occurrence file
# To run the code in the backend, use the gen_occur.py in the "py" folder
p = MultiProcessing()

# Read the numbered sentences inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open('../data/corpus/small_sent_line.txt') as numbered_sent_file:
    numbered_sent_lines = numbered_sent_file.readlines()

# The lambda is a factory that builds an Occurrence from the word tree and keyword files;
# the remaining arguments are the input lines, the literal 8 and the post-operation hook.
# NOTE(review): 8 is presumably the number of worker processes — confirm against MultiProcessing.run.
occur_dict = p.run(
    lambda: Occurrence('../data/corpus/wordtree.json', '../data/corpus/keyword_f.txt'),
    numbered_sent_lines,
    8,
    occurrence_post_operation,
)

# NOTE(review): this writes occur.json, while the load cell reads ent_occur.json — confirm
# whether these are meant to be the same file.
occurrence_dump('../data/corpus/occur.json', occur_dict)

In [None]:
# Remove the file with line number
# (small_sent_line.txt is only a temporary input for the occurrence generation cell above)
!rm ../data/corpus/small_sent_line.txt

### Generate co-occurrence list if needed

In [4]:
# Generate co_occurrence file
# gen_co_occur is given the occurrence dictionary, the number of corpus sentences
# and the keyword -> index mapping; its result is persisted next to the corpus.
sent_count = len(sent_list)
co_occur_list = gen_co_occur(occur_dict, sent_count, word2idx_dict)
co_occur_dump('../data/corpus/ent_co_occur.txt', co_occur_list)

### Generate co-occurrence graph if needed

In [None]:
# Generate pair graph (about 5 minutes)
# The graph is built from the co-occurrence list and the keyword vocabulary, then dumped
# to the same path the data-loading cell reads it back from.
pair_graph_path = '../data/corpus/ent_pair.gpickle'
pair_graph = build_graph(co_occur_list, keyword_list)
graph_dump(pair_graph, pair_graph_path)

## Play around below

### Test of highly related pairs

In [5]:
# Helper functions


def find_highly_related_keyword(g:nx.Graph, keyword:str, word2idx_dict:Dict[str, int], keyword_list:List[str]):
    """Print the keywords that are direct neighbours of ``keyword`` in ``g``.

    Args:
        g: graph whose nodes are keyword indices.
        keyword: the keyword to look up.
        word2idx_dict: mapping from keyword text to its node index.
        keyword_list: list used to map a node index back to keyword text.

    Returns:
        None. The list of neighbouring keywords is printed.

    Raises:
        KeyError: if ``keyword`` is not in ``word2idx_dict``.
    """
    keyword_idx = word2idx_dict[keyword]
    related_kws = []
    for neighbor_idx in g.neighbors(keyword_idx):
        related_kws.append(keyword_list[neighbor_idx])
    print(related_kws)

def mark_sent_in_html(sent:str, keyword_list:List[str], is_entity:bool=True):
    """Wrap every occurrence of the given keywords in ``sent`` in a red ``<font>`` tag.

    Args:
        sent: the sentence to mark.
        keyword_list: keywords to highlight.
        is_entity: when True, the sentence is split on whitespace and each keyword is
            treated as one token. When False, hyphens are padded with spaces, the
            sentence is passed through ``sent_lemmatize`` (presumably returning a list
            of tokens — confirm) and each keyword is split into several tokens.

    Returns:
        The tokens joined by single spaces, with an opening ``<font>`` tag before and a
        closing ``</font>`` tag after every maximal run of highlighted tokens.
    """
    reformed_sent = sent.split() if is_entity else sent_lemmatize(sent.replace('-', ' - '))
    reformed_keywords = [[k] for k in keyword_list] if is_entity else [k.replace('-', ' - ').split() for k in keyword_list]

    # Use the builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
    mask = np.zeros(len(reformed_sent), dtype=bool)
    for k in reformed_keywords:
        if not k:
            # A keyword that splits into no tokens cannot match anything.
            continue
        k_len = len(k)
        begin_idx = 0
        while begin_idx + k_len <= len(reformed_sent):
            if reformed_sent[begin_idx:begin_idx + k_len] == k:
                mask[begin_idx:begin_idx + k_len] = True
                # Continue after the match so occurrences of one keyword do not overlap.
                begin_idx += k_len
            else:
                # Advance a single token after a failed attempt so that a real match
                # starting inside a failed partial match is not skipped.
                begin_idx += 1

    # Emit the tokens, opening a tag at the start of each highlighted run and closing it
    # after the last token of the run.
    pieces = []
    last_idx = len(mask) - 1
    for idx, token in enumerate(reformed_sent):
        if mask[idx] and (idx == 0 or not mask[idx - 1]):
            pieces.append('<font style=\"color:red;\">')
        pieces.append(token)
        if mask[idx] and (idx == last_idx or not mask[idx + 1]):
            pieces.append('</font>')
    return ' '.join(pieces)

def gen_co_occur_report(report_file:str, g:nx.Graph, keyword:str, word2idx_dict:Dict[str, int], keyword_list:List[str], occur_dict:Dict[str, set], sent_list:List[str], is_entity:bool=True, kw_dist_max:int=6):
    """Write an HTML report of the sentences shared by ``keyword`` and each graph neighbour.

    The report starts with one link per neighbouring keyword, followed by one section
    per neighbour listing the sentences in which both keywords occur, with the two
    keywords highlighted by ``mark_sent_in_html``.

    Args:
        report_file: path handed to ``my_write``.
        g: graph whose nodes are keyword indices.
        keyword: the central keyword.
        word2idx_dict: mapping from keyword text to node index.
        keyword_list: list used to map a node index back to keyword text.
        occur_dict: mapping from keyword to the set of sentence indices it occurs in.
        sent_list: sentences, indexed by the values stored in ``occur_dict``.
        is_entity: when True, a sentence is kept only if both keywords appear exactly
            once as whitespace tokens and at most ``kw_dist_max`` tokens apart.
        kw_dist_max: maximum token distance between the two keywords (entity mode only).
    """
    related_kws = [keyword_list[node] for node in g.neighbors(word2idx_dict[keyword])]

    # Table of contents: one in-page link per related keyword
    content = []
    for other in related_kws:
        content.append('<a href=\"#%s__%s\">%s, %s</a><br>' % (keyword, other, keyword, other))

    for other in related_kws:
        content.append('<a id=\"%s__%s\"><h1>%s, %s</h1></a> ' % (keyword, other, keyword, other))
        shared_lines = occur_dict[keyword] & occur_dict[other]
        sents = [sent_list[line_no] for line_no in shared_lines]
        if is_entity:
            kept = []
            for sent in sents:
                tokens = sent.split()
                if tokens.count(keyword) != 1 or tokens.count(other) != 1:
                    continue
                if abs(tokens.index(keyword) - tokens.index(other)) > kw_dist_max:
                    continue
                # Re-join so the kept sentence has single-space separated tokens
                kept.append(' '.join(tokens))
            sents = kept
        for sent in sents:
            content.append('%s<br><br>' % mark_sent_in_html(sent, [keyword, other], is_entity=is_entity))

    my_write(report_file, content)

In [3]:
# Generate highly related subgraph
# NOTE(review): 0.3 and 3 are passed positionally to get_subgraph; what they control
# (presumably filtering thresholds) is not visible in this notebook — confirm against
# tools.DocProcessing.CoOccurGraph.
sub_g = get_subgraph(pair_graph, 0.3, 3)

In [9]:
neighbor_test_pairs = [('python', 'java'), ('stack', 'queue')]


def _dep_filtered_sents(kw, mid_text, kw_dist_max=6, dep_labels=('nsubj', 'dobj')):
    """Return the sentences shared by ``kw`` and ``mid_text`` that pass two filters.

    A sentence is kept when (1) both keywords occur exactly once as whitespace tokens
    and at most ``kw_dist_max`` tokens apart, and (2) after parsing with ``nlp`` at
    least one of the two keyword tokens carries a dependency label in ``dep_labels``.
    """
    candidate_sents = [sent_list[i].split() for i in occur_dict[kw] & occur_dict[mid_text]]
    candidate_sents = [
        ' '.join(sent) for sent in candidate_sents
        if sent.count(kw) == 1 and sent.count(mid_text) == 1
        and abs(sent.index(kw) - sent.index(mid_text)) <= kw_dist_max
    ]
    kept_sents = []
    for sent in candidate_sents:
        doc = nlp(sent)
        tokens = [word.text for word in doc]
        try:
            idx1, idx2 = tokens.index(kw), tokens.index(mid_text)
        except ValueError:
            # list.index raises ValueError when the parser did not produce the keyword as
            # a single token; skip only that case instead of swallowing every exception.
            continue
        if doc[idx1].dep_ in dep_labels or doc[idx2].dep_ in dep_labels:
            kept_sents.append(sent)
    return kept_sents


title = []
content = []
for pair in neighbor_test_pairs:
    # Keywords adjacent to both members of the pair in the highly related subgraph
    mid_set = set(sub_g.neighbors(word2idx_dict[pair[0]])) & set(sub_g.neighbors(word2idx_dict[pair[1]]))
    if not mid_set:
        print('%s and %s fail in one hop relation' % (pair[0], pair[1]))
        continue
    title.append('<a href=\"#%s__%s\">%s, %s</a><br>' % (pair[0], pair[1], pair[0], pair[1]))
    content.append('<a id=\"%s__%s\"></a> <h1>%s, %s</h1>' % (pair[0], pair[1], pair[0], pair[1]))
    for mid in mid_set:
        mid_text = keyword_list[mid]
        # Same report section for each side of the pair against the shared neighbour
        for kw in pair[:2]:
            content.append('<h2>%s, %s</h2>' % (kw, mid_text))
            content += ['%s<br><br>' % mark_sent_in_html(sent, [kw, mid_text]) for sent in _dep_filtered_sents(kw, mid_text)]

        # content.append('<h2>%s, %s, %s</h2>' % (pair[0], mid_text, pair[1]))
        # temp_sents = [sent_list[i].split() for i in occur_dict[pair[0]] & occur_dict[mid_text] & occur_dict[pair[1]]]
        # temp_sents = [' '.join(sent) for sent in temp_sents if sent.count(pair[0]) == 1 and sent.count(pair[1]) == 1 and sent.count(mid_text) == 1]
        # content += ['%s<br><br>' % mark_sent_in_html(sent, [pair[0], mid_text, pair[1]]) for sent in temp_sents]
my_write('overlap_test.html', title + content)

In [None]:
# Display the attributes stored on the sub_g edge between 'python' and 'just-in-time compilation'.
# NOTE(review): other cells use underscore-joined keywords (e.g. 'data_structure'); confirm that
# 'just-in-time compilation' (with a space) is really a key of word2idx_dict.
sub_g.edges[word2idx_dict['python'], word2idx_dict['just-in-time compilation']]

In [8]:
# Pass the distance limit by keyword: the 8th positional parameter of gen_co_occur_report
# is is_entity, so a positional 6 would be bound to is_entity instead of kw_dist_max.
gen_co_occur_report('ds_co_occur.html', sub_g, 'data_structure', word2idx_dict, keyword_list, occur_dict, sent_list, kw_dist_max=6)

In [10]:
# NOTE(review): find_dependency_path is neither defined nor imported in the visible cells of
# this notebook, so this call presumably raises NameError on a fresh kernel — confirm where it
# is defined and import it in the first cell.
find_dependency_path('efficient_point location in the sinr diagram , i.e. , building a data_structure to determine , for a query_point , whether any transmitter is heard there , and if so , which one , has been recently investigated .', 'data_structure', 'query_point')

'i_dobj prep pobj'

### Analyze the sentences with OLLIE, Stanford OpenIE or OpenIE5

In [None]:
# test_data = my_read('../data/test/co_occur_test.txt')
# test_data = [data.split(',') for data in test_data]
# test_dict = {data[0] : data[1:] for data in test_data}

In [None]:
# test_sent_dict = {central_kw : set() for central_kw in test_dict}
# for central_kw, kws in test_dict.items():
#     for kw in kws:
#         test_sent_dict[central_kw] |= occur_dict[kw]
#     test_sent_dict[central_kw] &= occur_dict[central_kw]

# for central_kw, sents in test_sent_dict.items():
#     content = [sent_list[i] for i in sents]
#     my_write('../data/temp/%s_wiki.txt' % central_kw.replace(' ', '_'), content)

In [None]:
# test_lines = occur_dict['python'] & (occur_dict['java'] | occur_dict['ruby'])
# my_write('python_java_ruby.txt', [sent_list[i] for i in test_lines])

In [None]:
# openie_data = my_read('../data/temp/pl_wiki_ollie_triple.txt')
# # openie_data = my_read('pjr_ollie_triple.txt')
# # keywords = set(['data structure', 'binary tree', 'hash table', 'linked list'])
# keywords = set(['programming language', 'python', 'java', 'javascript', 'lua', 'scala', 'lisp', 'php', 'ruby', 'smalltalk'])
# # keywords = set(['python', 'java', 'ruby'])

# qualified_triples = []
# for data in openie_data:
#     if data:
#         arg1, rel, arg2 = data.split(';')
#         for kw in keywords:
#             if kw in arg1:
#                 for kw in keywords:
#                     if kw in arg2:
#                         qualified_triples.append(data)
#                         break
#                 break
# my_write('pl_wiki_ollie_triple_f.txt', qualified_triples)

In [None]:
# data_structure_idx = occur_dict['data structure']

In [None]:
# len(co_occur_set)

In [None]:
# co_occur_set = {}
# for keyword, idx_set in occur_dict.items():
#     intersection = idx_set & data_structure_idx
#     if intersection:
#         co_occur_set[keyword] = list(intersection)

In [None]:
# sorted_co_occur_list = sorted(co_occur_set.items(), key=lambda x: len(x[1]), reverse=True)[:100]
# sorted_co_occur_count = [(word, len(idx)) for word, idx in sorted_co_occur_list]

In [None]:
# sorted_co_occur_count[:40]

In [None]:
# 'b-tree' in co_occur_set

In [None]:
# # sent_list[co_occur_set['b-tree'][0]]
# temp_list = [sent_list[idx] for idx in co_occur_set['b-tree']]
# my_write('ds_bt_sent.txt', temp_list)