In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import Layout, GridspecLayout
from IPython.display import display
import time
import threading
import networkx as nx
import nltk
from nltk import word_tokenize,sent_tokenize,ne_chunk
import json
from textblob import TextBlob
import openie5_client
import pdfplumber
import re
import copy

from IPython.display import clear_output

from fuzzywuzzy import fuzz
from fuzzywuzzy import process
import spacy
import textacy
from sklearn.metrics.pairwise import cosine_similarity
from nltk.corpus import stopwords
from nltk import download
download('stopwords')  # Download stopwords list.
stop_words = stopwords.words('english')
import gensim.downloader as api

import warnings
warnings.filterwarnings('ignore')


[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Zenya\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
class KnowledgeGraph:
    """ Knowledge Graph Object, stores variables and information about the knowledge graph. """
    
    def __init__(self, triplets_df):
        """ 
        Takes in a pandas dataframe with columns confidence, sentence, subject, relation, object. 
        ---------------------------------------------------
        
        Stores the following information:
        > ALL = "ALL"
        > triplets_map
        > numbers_map
        > map_triplets_df
        > topic_triplets_df
        > G
        > pos
        > edges
        > topics_list
        > relation_list
        > subject_list
        > object_list
        """
        self.ALL = "ALL"
        self.triplets_df = triplets_df
        self.triplets_map, self.numbers_map = self.get_triplets_map()
        self.map_triplets_df = self.map_triplets()
        self.topic_triplets_df = self.find_initial_nouns()
        
        # generate main graph
        self.G, self.pos, self.edges, topics = self.create_kg()
        self.topics_list = self.unique_sorted_values_plus_ALL(pd.Series(topics))
        self.relation_list = self.unique_sorted_values_plus_ALL(pd.Series(list(self.edges.values())))
        self.subject_list = self.unique_sorted_values_plus_ALL(self.triplets_df.subject, inc_all=False)
        self.object_list = self.unique_sorted_values_plus_ALL(self.triplets_df.object, inc_all=False)
        
    def create_kg(self):
        """
        Build the directed multigraph for the rows of ``self.topic_triplets_df``.

        Subjects and objects become nodes keyed by their numeric ids
        (``subject_map`` / ``object_map``). Every entry of ``subject_tags`` /
        ``object_tags`` becomes a topic node with an edge pointing at the
        tagged subject/object. Subject -> object edges carry a ``relation``
        attribute and object nodes carry the row's ``sentence``.

        Returns
        -------
        tuple
            ``(G, pos, edges, topic_nodes)``:
            ``G`` is an ``nx.MultiDiGraph``; ``pos`` is ``nx.spring_layout(G)``;
            ``edges`` maps ``(subject_id, object_id)`` to the relation label
            (a later row overwrites an earlier one for the same pair);
            ``topic_nodes`` lists every tag encountered, duplicates included.

        KIV:
        -- Store previous plot as a variable. Compare current graph with
           previous plot when replotting. If same, reuse the previous plot.
           Most useful when plotting entire graph. Plotting subsets are fast
           and do not require this.
        """
        frame = self.topic_triplets_df
        subject_nodes = []
        object_nodes = []
        topic_nodes = []
        relationship_edges = []
        labels = []
        topic_edges = []

        # Rows are read by position (.iloc) so a frame whose index is not
        # exactly 0..n-1 no longer raises KeyError.
        for position in range(len(frame)):
            triplet = frame.iloc[position]
            subject_id = triplet['subject_map']
            object_id = triplet['object_map']
            subject_nodes.append(subject_id)
            object_nodes.append((object_id, triplet['sentence']))
            relationship_edges.append((subject_id, object_id))
            labels.append(triplet['relation'])
            for topic in triplet['subject_tags']:
                topic_nodes.append(topic)
                topic_edges.append((topic, subject_id))
            for topic in triplet['object_tags']:
                topic_nodes.append(topic)
                topic_edges.append((topic, object_id))

        G = nx.MultiDiGraph()
        # add nodes: subjects first, then objects (with sentence), then topics
        for node in subject_nodes:
            G.add_node(node)
        for node_id, sentence in object_nodes:
            G.add_node(node_id, sentence=sentence)
        for node in topic_nodes:
            G.add_node(node)
        # add edges: topic edges are unlabelled, triplet edges carry the relation
        for source, target in topic_edges:
            G.add_edge(source, target)
        for (source, target), label in zip(relationship_edges, labels):
            G.add_edge(source, target, relation=label)
        edges = dict(zip(relationship_edges, labels))

        # Layout choice. Alternatives tried previously: multipartite_layout and
        # the graphviz layouts (neato, dot, twopi, fdp) via nx.nx_agraph.
        pos = nx.spring_layout(G)
        return G, pos, edges, topic_nodes
    
    def display_full_graph(self):
        return self.draw_graph_triplets(self.get_triplet_position(), self.get_triplet_edges())
    
    def display_filtered_graph(self, rel=None, top=None):
        """
        Eventhandler for filtering, calls relevant function depending on whether
        relation and/or topic is filtered for. 
        """
        
        if (rel is None) and (top is None):
            self.display_full_graph()
        elif (top is None):
            self.draw_rel_filter(rel)
        elif (rel is None):
            self.draw_topic_filter(top)
        else:
            self.draw_double_filter(rel, top)
            
    def draw_rel_filter(self, rel):
        """
        Draw the sub-graph induced by the nodes joined by relation ``rel``.

        Nodes are shown with their text (via ``numbers_map``) rather than
        their numeric id. Prints 'No Results' when no edge matches or when a
        node id cannot be translated.

        The figure is only created once there is something to draw, so the
        'No Results' paths no longer leave an empty 16x8 figure open.
        """

        try:
            filtered_graph = self.G.subgraph(self.get_nodes(rel))
            if nx.is_empty(filtered_graph):
                print('No Results')
                return
            # Translate layout keys from ids to text; a missing id raises
            # KeyError, which is reported as 'No Results' below.
            filtered_pos = {
                self.numbers_map[k]: v
                for k, v in self.pos.items()
                if k in filtered_graph.nodes
            }
            filtered_graph = nx.relabel_nodes(filtered_graph, self.numbers_map, copy=True)
            plt.figure(figsize=(16, 8))
            nx.draw(filtered_graph, filtered_pos, with_labels=True)
            plt.show()
        except KeyError:
            print('No Results')
        
    def draw_topic_filter(self, topic):
        """
        Draw the part of the graph reachable from ``topic``.

        The sub-graph contains ``topic`` and every node found by a
        depth-first search from it. Numeric node ids are replaced by their
        text, and subject -> object edges are labelled with their relation.
        Prints 'No Results' when the sub-graph has no edges or when a lookup
        raises ``KeyError``.

        Two fixes over the earlier version:
        * ids are recognised by membership in ``numbers_map`` instead of
          ``type(k) == np.int32``, which only matched on platforms where the
          default integer is 32-bit and never matched plain ``int`` ids;
        * edge labels are looked up in ``self.edges`` (keyed by id pairs) and
          then translated to text. Previously id pairs were looked up in a
          dictionary keyed by text pairs, so no label was ever drawn.
        """

        try:
            reachable = [
                node
                for children in nx.dfs_successors(self.G, topic).values()
                for node in children
            ]
            reachable.append(topic)
            filteredG = self.G.subgraph(reachable)
            if nx.is_empty(filteredG):
                print('No Results')
                return

            # Iterating a multigraph's edges yields (u, v, key) triples.
            filtered_edges = {}
            for u, v, _key in list(filteredG.edges):
                if (u, v) in self.edges:
                    named_pair = (self.numbers_map[u], self.numbers_map[v])
                    filtered_edges[named_pair] = self.edges[(u, v)]

            # Topic nodes are not in numbers_map and keep their own name.
            relabel_pos = {
                self.numbers_map.get(k, k): v
                for k, v in self.pos.items()
                if k in filteredG.nodes
            }
            filteredG = nx.relabel_nodes(filteredG, self.numbers_map, copy=True)

            plt.figure(figsize=(16, 8))
            nx.draw_networkx_nodes(filteredG, relabel_pos)
            nx.draw_networkx_labels(filteredG, relabel_pos)
            nx.draw_networkx_edges(filteredG, relabel_pos, alpha=0.5)
            nx.draw_networkx_edge_labels(
                filteredG, relabel_pos, edge_labels=filtered_edges, font_color='red'
            )
            plt.show()
        except KeyError:
            print('No Results')
    
    def draw_double_filter(self, rel, topic):
        """
        Draw the graph filtered by both ``topic`` and relation ``rel``.

        First the nodes reachable from ``topic`` (depth-first) are selected,
        then only those joined by an edge whose relation equals ``rel`` are
        kept. Numeric ids are replaced by their text and edges are labelled
        with their relation. Prints 'No Results' when nothing is left or when
        a lookup raises ``KeyError``.

        Two fixes over the earlier version:
        * ids are recognised by membership in ``numbers_map`` instead of
          ``type(k) == np.int32`` (platform-dependent, misses plain ``int``);
        * edge labels are looked up in ``self.edges`` by id pair and then
          translated to text. Previously id pairs were looked up in a
          dictionary keyed by text pairs, so no label was ever drawn.
        """

        try:
            reachable = [
                node
                for children in nx.dfs_successors(self.G, topic).values()
                for node in children
            ]
            reachable.append(topic)
            topic_graph = self.G.subgraph(reachable)
            filteredG = self.G.subgraph(self.get_nodes(rel, topic_graph))
            if nx.is_empty(filteredG):
                print('No Results')
                return

            # Iterating a multigraph's edges yields (u, v, key) triples.
            filtered_edges = {}
            for u, v, _key in list(filteredG.edges):
                if (u, v) in self.edges:
                    named_pair = (self.numbers_map[u], self.numbers_map[v])
                    filtered_edges[named_pair] = self.edges[(u, v)]

            # Anything not in numbers_map keeps its own name.
            relabel_pos = {
                self.numbers_map.get(k, k): v
                for k, v in self.pos.items()
                if k in filteredG.nodes
            }
            filteredG = nx.relabel_nodes(filteredG, self.numbers_map, copy=True)

            plt.figure(figsize=(16, 8))
            nx.draw_networkx_nodes(filteredG, relabel_pos)
            nx.draw_networkx_labels(filteredG, relabel_pos)
            nx.draw_networkx_edges(filteredG, relabel_pos, alpha=0.5)
            nx.draw_networkx_edge_labels(
                filteredG, relabel_pos, edge_labels=filtered_edges, font_color='red'
            )
            plt.show()
        except KeyError:
            print('No Results')

    def get_triplets_map(self):
        """ 
        Maps object and subject strings to a unique number. 
        """
        
        triplets_map = {}
        numbers_map = {}
        index = 0
        for row in range(len(self.triplets_df)):
            triplet = self.triplets_df.loc[row]
            sub = triplet['subject'] 
            obj = triplet['object']
            if sub not in triplets_map:
                triplets_map[sub] = index
                numbers_map[index] = sub
                index += 1
            if obj not in triplets_map:
                triplets_map[obj] = index
                numbers_map[index] = obj
                index += 1
                
        return triplets_map, numbers_map
    
    def map_triplets(self):
        """ 
        Returns triplets_df with its mapped values. 
        """
        
        df = self.triplets_df.copy(deep = True)
        for row in range(len(df)):
            triplet = df.loc[row]
            sub = triplet['subject'] 
            obj = triplet['object']
            df.loc[row, 'subject_map'] = self.triplets_map[sub]
            df.loc[row, 'object_map'] = self.triplets_map[obj]
        df.subject_map = df.subject_map.astype(int)
        df.object_map = df.object_map.astype(int)
        
        return df
    
    def find_initial_nouns(self):
        """
        Finds the nouns for each subject/object to generate topic nodes.
        """
        
        df = self.map_triplets_df.copy(deep = True)
        df['subject_tags'] = np.empty((len(df), 0)).tolist()
        df['object_tags'] = np.empty((len(df), 0)).tolist()
        proper_nouns = [] # proper nouns
        subject_nouns = []
        object_nouns = []
        
        for row in range(len(df)):
            triplet = df.loc[row]                
            sentence = triplet['sentence'] 
            tokens = self.split_tokens(sentence)
            postags = self.POS_tagging(tokens)
            nounphrases = self.phrase_extraction(sentence)
            #postag_dict[sentence] = postags        
            sub = triplet['subject']
            obj = triplet['object']        
            subject_tags = []
            object_tags = []
            for tag in postags:
                if tag[1] == 'NNPS' or tag[1] == 'NNP':                
                    noun = tag[0]                
                    if noun in sub and noun not in subject_tags:
                        subject_tags.append(noun)
                        proper_nouns.append(noun)
                    if noun in obj and noun not in object_tags:
                        object_tags.append(noun)
                        proper_nouns.append(noun)
            for noun in nounphrases:
                if noun in sub and noun not in subject_tags:
                    if noun.upper() not in proper_nouns:
                        subject_tags.append(noun)
                    else:
                        subject_tags.append(noun.upper())
                if noun in obj and noun not in object_tags:
                    if noun.upper() not in proper_nouns:
                        object_tags.append(noun)
                    else:
                        object_tags.append(noun.upper())
                
            subject_nouns.append(subject_tags)
            object_nouns.append(object_tags)        
        df['subject_tags'] = subject_nouns
        df['object_tags'] = object_nouns 
        return df
    
    def find_nouns(self, text):
        """
        Takes in a string and returns proper nouns and noun phrases. 
        """
        
        tokens = self.split_tokens(text)
        postags = self.POS_tagging(tokens) 
        nounphrases = self.phrase_extraction(text)
        new_topics = []
        for tag in postags:
            if tag[1] == 'NNPS' or tag[1] == 'NNP':               
                noun = tag[0]                
                if noun not in new_topics:
                    new_topics.append(noun)
        for noun in nounphrases:
            if noun not in new_topics:
                new_topics.append(noun)
        return new_topics
    
    def get_triplet_position(self):    
        ''' return a dictionary of the position of triplets in the knowledge graph '''
        
        triplet_position = {}    
        for position in self.G:
            if position in self.numbers_map.keys():
                triplet_position[self.numbers_map[position]] = self.pos[position]
            elif position in self.pos.keys():
                triplet_position[position] = self.pos[position]
            #else:
                #triplet_position[position] = pos[position]
        return triplet_position
    
    def get_triplet_edges(self):    
        ''' return a dictionary of the relationship of triplets in the knowledge graph '''
        triplet_edges = {}    
        for k in self.edges:
            #print(k)
            sub = self.numbers_map[k[0]]
            obj = self.numbers_map[k[1]]
            rel = self.edges[k]    
            triplet_edges[(sub,obj)] = rel    
        return triplet_edges
    
    def draw_graph_triplets(self, triplet_pos, triplet_edges):
        """
        Draw the whole graph on an 80x40 figure.

        A copy of ``self.G`` is relabelled through ``numbers_map`` so nodes
        appear under their text. ``triplet_pos`` supplies the position of
        each label and ``triplet_edges`` the edge labels (drawn in red).
        """
        named_graph = nx.relabel_nodes(self.G, self.numbers_map, copy=True)

        plt.figure(figsize=(80, 40))
        nx.draw_networkx_nodes(named_graph, triplet_pos, node_size=40)
        nx.draw_networkx_labels(named_graph, triplet_pos)
        nx.draw_networkx_edges(named_graph, triplet_pos, alpha=0.5)
        nx.draw_networkx_edge_labels(
            named_graph,
            triplet_pos,
            edge_labels=triplet_edges,
            font_color='red',
        )
        plt.show()
    
    def get_nodes(self, edge, filteredG=None):
        ''' get nodes connected by specified edge'''
        node = []
        if (filteredG is None):
            filteredG = self.G
        for u,v,e in filteredG.edges(data=True):
            if e == {}:
                continue
            if e['relation'] == edge:
                if u not in node:
                    node.append(u)
                if v not in node:
                    node.append(v)
        return node
    
    def unique_sorted_values_plus_ALL(self, array, inc_all = True):
        '''generates a unique list with ALL appended to the top'''
        unique = array.unique().tolist()
        unique.sort()
        if inc_all:
            unique.insert(0, self.ALL)
        return unique
    
    def add_triplet_to_df(self,sentence, sub, rel, obj, sub_map, obj_map, sub_top, obj_top):
        """
        Adds a triplet set to topic_triplets_df
        """
        
        new_entry = {}
        new_entry['confidence'] = 1
        new_entry['sentence'] = sentence
        new_entry['subject'] = sub
        new_entry['relation'] = rel
        new_entry['object'] = obj
        new_entry['subject_map'] = sub_map
        new_entry['object_map'] = obj_map
        new_entry['subject_tags'] = sub_top
        new_entry['object_tags'] = obj_top
        self.topic_triplets_df = self.topic_triplets_df.append(new_entry, ignore_index=True) 
    
    def add_as_subject(self, text):
        """
        Records a string as a subject. Adds as a node if not yet a node.
        """
        
        self.subject_list.append(text)
        if text not in self.triplets_map:
            return self.add_as_node(text)
    
    def add_as_object(self, text):
        """
        Records a string as an object. Adds as a node if not yet a node.
        """
        
        self.object_list.append(text)
        if text not in self.triplets_map:
            return self.add_as_node(text)
    
    def add_as_node(self, text):
        """
        Add ``text`` as a node and connect it to the topics found in it.

        Topics come from ``find_nouns(text)``. A topic that is not known yet
        is appended to ``topics_list`` and added to the graph. ``text`` gets
        the next free numeric id unless it already has one. The layout is
        then recomputed with the current positions as the starting point.

        Fixes over the earlier version:
        * when a noun only exists as an upper-case topic, the edge is
          attached to that existing topic; before, the edge used the
          lower-case spelling, which silently created a second node that was
          not listed in ``topics_list``;
        * an empty ``numbers_map`` no longer makes ``max()`` raise; the first
          id is 0.
        """

        # extract topics from text and resolve each to the name used in the graph
        linked_topics = []
        for noun in self.find_nouns(text):
            upper = noun.upper()
            if noun in self.topics_list:
                topic = noun
            elif upper != self.ALL and upper in self.topics_list:
                # the "ALL" filter entry is not a real topic, hence the exclusion
                topic = upper
            else:
                topic = noun
                self.topics_list.append(noun)
                self.G.add_node(noun)
            if topic not in linked_topics:
                linked_topics.append(topic)

        # map text to an index (not added to self.map_triplets_df)
        if text in self.triplets_map:
            index = self.triplets_map[text]
        else:
            index = max(self.numbers_map, default=-1) + 1
            self.numbers_map[index] = text
            self.triplets_map[text] = index

        # add index to graph and create the topic -> text edges
        self.G.add_node(index)
        for topic in linked_topics:
            self.G.add_edge(topic, index)

        # update pos
        self.pos = nx.spring_layout(self.G, pos=self.pos)
    
    def add_as_edge(self, sub, obj, rel):
        """ 
        Adds edge to G, edges, and also to relation_list if not yet in relation_list.
        """
        
        if sub not in self.subject_list:
            self.add_as_subject(sub)
        if obj not in self.object_list:
            self.add_as_object(obj)
        sub_id = self.triplets_map[sub]
        obj_id = self.triplets_map[obj]
        if self.G.has_edge(sub_id, obj_id):
            return False
        self.G.add_edge(sub_id, obj_id, relation = rel)
        if rel not in self.relation_list:
            self.relation_list.append(rel)
        self.edges[(self.triplets_map[sub], self.triplets_map[obj])] = rel
        return True
    
    def edit_node(self, old_node, new_node):
        """
        Changes the stored string data of a given node. 
        Modifies data in graph and dataframe.
        """
        
        value = self.triplets_map.pop(old_node)
        # change name in triplets_map
        self.triplets_map[new_node] = value
        # change name in numbers_map
        self.numbers_map[value] = new_node
        # change name in subject_list
        try:
            idx = self.subject_list.index(old_node)
            if new_node not in self.subject_list:
                self.subject_list[idx] = new_node
            else:
                self.subject_list.pop(idx)
        except ValueError:
            pass
        # change name in object_list
        try:
            idx = self.object_list.index(old_node)
            if new_node not in  self.object_list:
                self.object_list[idx] = new_node
            else:
                self.object_list.pop(idx)
        except ValueError:
            pass
        # change name in dataframe
        self.topic_triplets_df.loc[self.topic_triplets_df.subject == old_node, "subject"] = new_node
        self.topic_triplets_df.loc[self.topic_triplets_df.object == old_node, "object"] = new_node
        
    def edit_edge(self, old_rel, new_rel):
        """
        Changes the stored string data of a given edge. 
        Modifies data in graph and dataframe.
        """
        
        # change name in relation_list
        idx = self.relation_list.index(old_rel)
        if new_rel not in self.relation_list:
            self.relation_list[idx] = new_rel
        else:
            self.relation_list.pop(idx)
        # change name in edges
        self.edges = {k:(new_rel if old_rel == v else v) for k,v in self.edges.items()}
        # change name in graph
        for n, nbrsdict in self.G.adjacency():
            for nbr, keydict in nbrsdict.items():
                for key, eattr in keydict.items():
                    keydict[key] = {k:(new_rel if v==old_rel else v) for k,v in eattr.items()}
        # change name in dataframe
        self.topic_triplets_df.loc[self.topic_triplets_df.subject == old_rel, "relation"] = new_rel
        
    def edit_topic(self, old_topic, new_topic):
        """
        Changes the stored string data of a given topic. 
        Modifies data in graph and dataframe.
        """
        
        # change name in topics_list
        idx = self.topics_list.index(old_topic)
        # change name in pos
        self.pos[new_topic] = self.pos.pop(old_topic)
        if new_topic not in self.topics_list:
            self.topics_list[idx] = new_topic
        else:
            self.topics_list.pop(idx)
        # change name in graph
        nx.relabel_nodes(self.G, {old_topic:new_topic}, copy=False)
        # change name in dataframe
        for idx,row in self.topic_triplets_df.iterrows():
            if (old_topic in row.subject_tags):
                i = row.subject_tags.index(old_topic)
                row.subject_tags[i] = new_topic
            if (old_topic in row.object_tags):
                i = row.object_tags.index(old_topic)
                row.object_tags[i] = new_topic
                
    def edit_ans(self, sub, rel, obj, ans):
        """
        Changes the stored string data of a given answer. 
        Modifies data in dataframe.
        """
        
        # check if row exists
        cond = (self.topic_triplets_df.subject == sub) & (self.topic_triplets_df.relation == rel) & (self.topic_triplets_df.object == obj)
        if any(cond):
            self.topic_triplets_df.loc[cond, "sentence"] = ans
            return True
        else:
            return False
    
    def remove_node(self, text):
        """
        Removes a node from the graph and dataframe.
        """
        
        # get index
        index = self.triplets_map[text]
        # find node's edges and remove them, remove edges from relation, refresh lists
        for k,v in list(self.edges.items()):
            _ = self.edges.pop(k) if (index in k) else None
        affected_rel = []
        # remove from successors
        for successor in list(self.G.succ[index]):
            self.G.remove_edge(index, successor)
            affected_rel.append([index, successor])
        # remove from predecessors
        for predecessor in list(self.G.pred[index]):
            self.G.remove_edge(predecessor, index)
            affected_rel.append([predecessor, index])
        # drop from df    
        for item in affected_rel:
            try: # quick fix for topic not being in numbers_map
                sub = self.numbers_map[item[0]]
                obj = self.numbers_map[item[1]]
                row = self.topic_triplets_df[(kg.topic_triplets_df.subject == sub)&(kg.topic_triplets_df.object == obj)].index[0]
                self.topic_triplets_df.drop(row)
            except KeyError:
                pass
        self.relation_list = self.unique_sorted_values_plus_ALL(pd.Series(self.edges.values()))
        # remove from triplets map, from numbers_map
        del self.triplets_map[text]
        del self.numbers_map[index]
        # remove from pos
        del self.pos[index]
        self.G.remove_node(index)
        # remove from subject and object list
        if text in self.subject_list:
            self.subject_list.remove(text)
        if text in self.object_list:
            self.object_list.remove(text)        
        # find node's topic, check if topic is empty, if yes, remove topic from graph and filter
        new_topics = self.find_nouns(text)
        for noun in new_topics:
            # QUICK HACK FIX: Sometimes, noun phrases found may not be in list.
            # Try to fix if time permits
            try:
                if len(nx.dfs_successors(self.G, noun)) == 0:
                    self.G.remove_node(noun)
                    del self.pos[noun]
                    self.topics_list.remove(noun)
            except KeyError:
                pass
            
    def add_doc(self, filename, port='9080', **kwargs):
        """
        Extract triplets from a document and rebuild the graph with them.

        ``filename`` is handed to ``Document`` (defined elsewhere; presumably
        a path to a pdf -- confirm against that class). Each of its sentences
        is sent to an OpenIE client started on ``port`` (extra ``kwargs`` go
        to the client). An extraction without second arguments is ignored; one
        with several yields one triplet per argument. The new triplets are
        appended to ``triplets_df`` and all derived state (maps, tagged
        frame, graph, layout, filter lists) is regenerated.
        """

        doc = Document(filename)
        extractions = []
        with openie5_client.OpenIEClient(port=port, **kwargs) as extractor:
            #time.sleep(120) #change accordingly to how fast your pc is
            for sentence in doc.sentences:
                extractions.extend(extractor.extract(sentence))
            extractor.server.kill()

        # One output row per (extraction, second argument) pair.
        rows = []
        for triplet in extractions:
            extraction = triplet['extraction']
            if len(extraction['arg2s']) == 0:
                continue
            for argument in extraction['arg2s']:
                rows.append({
                    'confidence': triplet['confidence'],
                    'sentence': triplet['sentence'],
                    'subject': extraction['arg1']['text'],
                    'relation': extraction['rel']['text'],
                    'object': argument['text'],
                })

        add_df = pd.DataFrame(rows)
        self.triplets_df = pd.concat([self.triplets_df, add_df], ignore_index=True)
        self.triplets_map, self.numbers_map = self.get_triplets_map()
        self.map_triplets_df = self.map_triplets()
        self.topic_triplets_df = self.find_initial_nouns()

        # generate main graph
        self.G, self.pos, self.edges, topic_nodes = self.create_kg()
        topic_series = pd.Series(topic_nodes)
        relation_series = pd.Series(list(self.edges.values()))
        self.topics_list = self.unique_sorted_values_plus_ALL(topic_series)
        self.relation_list = self.unique_sorted_values_plus_ALL(relation_series)
        self.subject_list = self.unique_sorted_values_plus_ALL(
            self.triplets_df.subject, inc_all=False
        )
        self.object_list = self.unique_sorted_values_plus_ALL(
            self.triplets_df.object, inc_all=False
        )
    
    def remove_edge(self, sub, obj):
        """ 
        Attempts to remove edge from graph.
        """
        
        sub_id = self.triplets_map[sub]
        obj_id = self.triplets_map[obj]
        try:
            del self.edges[(sub_id, obj_id)]
            self.G.remove_edge(sub_id, obj_id)
            index = self.topic_triplets_df[(kg.topic_triplets_df.subject == sub)&(kg.topic_triplets_df.object == obj)].index[0]
            self.topic_triplets_df = self.topic_triplets_df.drop(index)
        except:
            pass
        self.relation_list = self.unique_sorted_values_plus_ALL(pd.Series(self.edges.values()))
        
    def get_sentence(self, sub, rel, obj):
        """ 
        Given a set of  triplets. extract the sentence from the dataframe. 
        """
        
        cond = (self.topic_triplets_df.subject == sub) & (self.topic_triplets_df.relation == rel) & (self.topic_triplets_df.object == obj)
        if any(cond):
            return self.topic_triplets_df.loc[cond, "sentence"]
        else: 
            return None
        
    def get_tags(self, item, loc):
        tags = []
        if loc == "subject":
            idx = (self.topic_triplets_df.subject == item).idxmax()
            tags = self.topic_triplets_df.loc[idx, "subject_tags"]
        if loc == "object":
            idx = (self.topic_triplets_df.object == item).idxmax()
            tags = self.topic_triplets_df.loc[idx, "object_tags"]
        return tags
    
    def split_tokens(self, text):
        """
        Tokenise ``text`` into words.

        Slashes are turned into spaces first, so terms joined by '/' become
        separate tokens.
        """
        return nltk.word_tokenize(text.replace('/', ' '))
    
    def POS_tagging(self, text):
        """
        Part-of-speech tag a sequence of tokens.

        ``text`` is passed straight to ``nltk.tag.pos_tag``, whose result is
        returned unchanged.
        """
        return nltk.tag.pos_tag(text)
    
    def phrase_extraction(self, text):
        """
        Extract noun phrases from ``text`` using TextBlob.

        Returns the ``noun_phrases`` attribute of a ``TextBlob`` built from
        ``text``.
        """
        return TextBlob(text).noun_phrases

In [18]:
class Dashboard:
    """ 
    Uses a knowledge graph object to create the dashboard using ipython widgets. 
    
    Contains the following widgets:
    -------------------------------
    
    Main Class:
        > Knowledge Graph as 'kg'
        > Main output as 'output'
    
    Filter Widget:
        > dropdown_relation
        > dropdown_entity
        > refresh_btn
        > filter (aggregated widgets)
        
    Adding Widget:
        > node_text
        > subject_text
        > object_text
        > relation_text
        > answer_text
        > triplet_button
        > triplet_output
        > add (aggregated widgets)
        
    Add Document Widget:
        > doc_text
        > doc_button
        > doc_output
        > doc (aggregated widgets)
        
    Editing Widget (Tab Interface):
        > edit_tab
        
    Editing Widget (Node):
        > edit_node_dropdown
        > edit_node_confirm_button
        > edit_node_textbox
        > edit_node
        
    Editing Widget (Relation):
        > edit_rel_dropdown
        > edit_rel_confirm_button
        > edit_rel_textbox
        > edit_rel
        
    Editing Widget (Topic):
        > edit_topic_dropdown
        > edit_topic_confirm_button
        > edit_topic_textbox
        > edit_topic 
        
    Editing Widget (Candidate Answer):
        > edit_ans_sub_filter 
        > edit_ans_rel_filter 
        > edit_ans_obj_filter
        > edit_ans_confirm_button 
        > edit_ans_field
        > edit_ans
        
    Editing Widget (Overall)
        > edit (aggregated widget)
        
    Removing Widget:
        > node_text2
        > edge_subject_text2
        > edge_object_text2 
        > edge_relation_text2
        > node_button2 
        > edge_button2 
        > remv (aggregated widget)
        
    Querying Widget:
        > Q
        > query_input
        > query_button
        > query_response_textarea
        > query
        
    """
    
    def __init__(self, kg):
        """
        Build every dashboard panel around the knowledge graph `kg`, arrange
        the panels in one grid, display the grid and draw the full graph into
        the main output area.
        """
        self.kg = kg
        self.ALL = "ALL"

        # Each initializer creates one panel and stores it on the instance.
        for build_panel in (self.initialize_filters,
                            self.initialize_adds,
                            self.initialize_edit,
                            self.initialize_remv,
                            self.initialize_add_doc,
                            self.initialize_query):
            build_panel()
        self.output = widgets.Output(layout=Layout(width='100%', height='100%'))

        # 18 x 11 grid: graph and filter bar on the left, add / document /
        # remove forms on the right, edit and query panels along the bottom.
        self.grid = GridspecLayout(18, 11)
        self.grid[1:13, :7] = self.output
        self.grid[0, :7] = self.filters
        self.grid[0:4, 8:] = self.add
        self.grid[4:6, 8:] = self.doc
        self.grid[6:14, 8:] = self.remv
        self.grid[14:, :5] = self.edit
        self.grid[14:, 6:10] = self.query
        display(self.grid)

        # Anything drawn inside this context lands in the main output area.
        with self.output:
            self.kg.display_full_graph()

    def initialize_filters(self):
        """ Builds the filter bar: relation filter, topic filter and a refresh button. """
        # Relation filter; starts on the ALL entry.
        self.dropdown_relation = widgets.Combobox(
            description="Filter by relation",
            options=self.kg.relation_list,
            value=self.ALL,
            style={'description_width': 'initial'},
            layout=Layout(width='40%', height='auto'),
        )
        self.dropdown_relation.observe(self.dropdown_relation_eventhandler, names='value')

        # Topic filter; starts on the ALL entry as well.
        self.dropdown_entity = widgets.Combobox(
            description="Filter by topic",
            options=self.kg.topics_list,
            value=self.ALL,
            style={'description_width': 'initial'},
            layout=Layout(width='40%', height='auto'),
        )
        self.dropdown_entity.observe(self.dropdown_entity_eventhandler, names='value')

        # Button that redraws the graph with the filters currently entered.
        self.refresh_btn = widgets.Button(
            description="Refresh Graph",
            style={'description_width': 'initial'},
            layout=Layout(width='20%', height='45%'),
        )
        self.refresh_btn.on_click(self.refresh_btn_eventhandler)

        self.filters = widgets.HBox(
            (self.dropdown_relation, self.dropdown_entity, self.refresh_btn),
            layout=Layout(width='100%', height='100%'),
        )

        
    def initialize_adds(self):
        """
        Initializes the adding portion of the dashboard.

        Creates the subject / relation / object comboboxes, the candidate
        answer text area, the "Add to KG" button and its feedback output, lays
        them out in a two-column grid and stores the titled panel in `self.add`.
        """

        def in_area(name, height='auto'):
            # Layout pinned to one named area of the grid template below.
            return Layout(height=height, width='auto', grid_area=name)

        cell_title = widgets.HTML("<b>Add Triplets</b>")

        # Subject and object boxes both suggest every known node.
        node_options = list(set(self.kg.subject_list).union(set(self.kg.object_list)))

        subject_label = widgets.HTML("Subject", layout=in_area('subject_label'))
        self.subject_text = widgets.Combobox(placeholder="Add new or existing subject.",
                                             options=node_options,
                                             layout=in_area('subject_text'))
        relation_label = widgets.HTML("Relation", layout=in_area('relation_label'))
        self.relation_text = widgets.Combobox(placeholder="Add new or existing relation to subject/object.",
                                              options=self.kg.relation_list,
                                              layout=in_area('relation_text'))
        object_label = widgets.HTML("Object", layout=in_area('object_label'))
        self.object_text = widgets.Combobox(placeholder="Add new or existing object.",
                                            options=node_options,
                                            layout=in_area('object_text'))
        answer_label = widgets.HTML("Candidate Answer", layout=in_area('answer_label'))
        self.answer_text = widgets.Textarea(placeholder='Write the candidate answer for the query here.',
                                            layout=in_area('answer_text', height='5rem'))

        # The area names must match the last row of grid_template_areas;
        # otherwise the widgets are only placed by the grid's auto-flow.
        self.triplet_button = widgets.Button(description="Add to KG", layout=in_area('button'))
        self.triplet_button.on_click(self.triplet_button_click_eventhandler)

        # output widget to provide feedback
        self.triplet_output = widgets.Output(layout=in_area('output_feedback'))

        display_cell = widgets.GridBox(
            children=[subject_label, self.subject_text,
                      relation_label, self.relation_text,
                      object_label, self.object_text,
                      answer_label, self.answer_text,
                      self.triplet_button, self.triplet_output],
            layout=Layout(
                grid_template_rows='20% 20% 20% 40%',
                grid_template_columns='30% 70%',
                grid_template_areas='''
                    "subject_label subject_text"
                    "relation_label relation_text"
                    "object_label object_text"
                    "answer_label answer_text"
                    "button output_feedback"
                    ''')
        )
        self.add = widgets.VBox((cell_title, display_cell))
        
    def initialize_edit(self):
        """
        Initializes the editing portion of the dashboard.

        Builds four forms -- node, relation, topic and candidate answer --
        puts each in its own tab of `self.edit_tab` and stores the titled
        panel in `self.edit`.
        """
        cell_title = widgets.HTML("<b>Edit Components</b>")

        # The node, relation and topic tabs share one layout: pick the old
        # value, type the new one, confirm.
        all_nodes = list(set(self.kg.subject_list).union(set(self.kg.object_list)))
        (self.edit_node_dropdown, self.edit_node_textbox, self.edit_node_confirm_button,
         self.edit_node_output, self.edit_node) = self._build_rename_form(
            "node", "Old Node", "New Node", "node", all_nodes,
            "Edit Node", self.edit_node_eventhandler)

        # [1:] skips the first list entry (presumably the "ALL" placeholder
        # used by the filters -- verify against unique_sorted_values_plus_ALL).
        (self.edit_rel_dropdown, self.edit_rel_textbox, self.edit_rel_confirm_button,
         self.edit_rel_output, self.edit_rel) = self._build_rename_form(
            "rel", "Old relation", "New relation", "relation", self.kg.relation_list[1:],
            "Edit Relation", self.edit_rel_eventhandler)

        (self.edit_topic_dropdown, self.edit_topic_textbox, self.edit_topic_confirm_button,
         self.edit_topic_output, self.edit_topic) = self._build_rename_form(
            "topic", "Old topic", "New topic", "topic", self.kg.topics_list[1:],
            "Edit Topic", self.edit_topic_eventhandler)

        # Candidate-answer tab: three filters select a triplet, the text area
        # holds its answer.
        def in_area(name, width='auto', height='auto'):
            return Layout(width=width, height=height, grid_area=name)

        edit_ans_sub_label = widgets.HTML("Subject", layout=in_area('edit_ans_sub_label'))
        self.edit_ans_sub_filter = widgets.Combobox(options=self.kg.subject_list,
                                                    placeholder='Choose the subject',
                                                    layout=in_area('edit_ans_sub_dropdown'))
        edit_ans_rel_label = widgets.HTML("Relation", layout=in_area('edit_ans_rel_label'))
        self.edit_ans_rel_filter = widgets.Combobox(options=self.kg.relation_list[1:],
                                                    placeholder='Choose the relation',
                                                    layout=in_area('edit_ans_rel_dropdown'))
        edit_ans_obj_label = widgets.HTML("Object", layout=in_area('edit_ans_obj_label'))
        self.edit_ans_obj_filter = widgets.Combobox(options=self.kg.object_list,
                                                    placeholder='Choose the object',
                                                    layout=in_area('edit_ans_obj_dropdown'))
        self.edit_ans_confirm_button = widgets.Button(description="Edit Answer",
                                                      layout=in_area('edit_ans_button'))
        self.edit_ans_confirm_button.on_click(self.edit_ans_eventhandler)
        edit_ans_field_label = widgets.HTML("Answer", layout=in_area('edit_ans_field_label'))
        # Created disabled; nothing in this method enables it.
        self.edit_ans_field = widgets.Textarea(placeholder="No sentence available for current selection.",
                                               disabled=True,
                                               layout=in_area('edit_ans_field', width='100%', height='10rem'))
        self.edit_ans_output = widgets.Output(layout=in_area('edit_ans_output'))
        self.edit_ans_output.append_stdout(" ")
        self.edit_ans = widgets.GridBox(
            children=[edit_ans_sub_label, self.edit_ans_sub_filter,
                      edit_ans_rel_label, self.edit_ans_rel_filter,
                      edit_ans_obj_label, self.edit_ans_obj_filter,
                      edit_ans_field_label, self.edit_ans_field,
                      self.edit_ans_confirm_button, self.edit_ans_output],
            layout=Layout(
                grid_template_rows='auto auto auto',
                grid_template_columns='auto auto auto auto auto auto auto auto auto',
                grid_template_areas='''
                    "edit_ans_sub_label edit_ans_sub_dropdown edit_ans_sub_dropdown edit_ans_rel_label edit_ans_rel_dropdown edit_ans_rel_dropdown edit_ans_obj_label edit_ans_obj_dropdown edit_ans_obj_dropdown"
                    "edit_ans_field_label edit_ans_field edit_ans_field edit_ans_field edit_ans_field edit_ans_field edit_ans_field edit_ans_field edit_ans_button"
                    "edit_ans_output edit_ans_output edit_ans_output edit_ans_output edit_ans_output edit_ans_output edit_ans_output edit_ans_output edit_ans_output"
                ''')
        )

        # generate observer
        self.edit_ans_sub_filter.observe(self.ans_sub_filter_eventhandler, names="value")
        self.edit_ans_rel_filter.observe(self.ans_rel_filter_eventhandler, names="value")
        self.edit_ans_obj_filter.observe(self.ans_obj_filter_eventhandler, names="value")

        self.edit_tab = widgets.Tab(children=[self.edit_node, self.edit_rel, self.edit_topic, self.edit_ans])
        # Use the public set_title() API rather than writing the private
        # `_titles` attribute, which is an implementation detail of the widget.
        tab_titles = ("Edit Nodes", "Edit Relations", "Edit Topics", "Edit Candidate Answers")
        for index, title in enumerate(tab_titles):
            self.edit_tab.set_title(index, title)
        self.edit = widgets.VBox((cell_title, self.edit_tab))

    def _build_rename_form(self, key, old_label_text, new_label_text, target, options, button_text, handler):
        """
        Build one "old value -> new value" editing form.

        key is the infix of the grid-area names ("node", "rel", "topic");
        target is the word used in the two placeholders; handler is attached
        to the confirm button's click event.

        Returns (dropdown, textbox, confirm_button, output, form), where form
        is the GridBox containing the other widgets and their two labels.
        """
        def in_area(name):
            return Layout(width='auto', height='auto', grid_area=name)

        old_label = widgets.HTML(old_label_text, layout=in_area(f'edit_{key}_label'))
        dropdown = widgets.Combobox(options=options,
                                    placeholder=f'Select {target} to edit',
                                    layout=in_area(f'edit_{key}_dropdown'))
        new_label = widgets.HTML(new_label_text, layout=in_area(f'update_{key}_label'))
        textbox = widgets.Text(placeholder=f"Type the edited {target} here.",
                               layout=in_area(f'update_{key}_textbox'))
        confirm_button = widgets.Button(description=button_text, layout=in_area(f'edit_{key}_button'))
        confirm_button.on_click(handler)
        output = widgets.Output(layout=in_area(f'edit_{key}_output'))
        form = widgets.GridBox(
            children=[old_label, dropdown, new_label, textbox, confirm_button, output],
            layout=Layout(
                grid_template_rows='auto auto auto',
                grid_template_columns='30% 70%',
                grid_template_areas=f'''
                    "edit_{key}_label edit_{key}_dropdown"
                    "update_{key}_label update_{key}_textbox"
                    "edit_{key}_button edit_{key}_output"
                ''')
        )
        return dropdown, textbox, confirm_button, output, form
        
    def initialize_remv(self):
        """ Creates the "Remove Nodes" and "Remove Edges" forms and stacks them into `self.remv`. """

        def in_area(name):
            # Auto-sized layout pinned to a named area of the enclosing GridBox.
            return Layout(height='auto', width='auto', grid_area=name)

        # ---- remove a node ----
        nodes_title = widgets.HTML("<b>Remove Nodes</b>")
        node_label = widgets.HTML("Node", layout=in_area('label'))
        every_node = list(set(self.kg.subject_list).union(set(self.kg.object_list)))
        self.node_text2 = widgets.Combobox(options=every_node,
                                           placeholder='Select node',
                                           layout=in_area('input'))
        self.node_output2 = widgets.Output(layout=in_area('output'))
        self.node_output2.append_stdout(" ")

        # ---- remove an edge ----
        edges_title = widgets.HTML("<b>Remove Edges</b>")
        sub_label = widgets.HTML("Subject", layout=in_area('sub_label'))
        self.edge_subject_text2 = widgets.Combobox(options=self.kg.subject_list,
                                                   placeholder="Select subject",
                                                   layout=in_area('subject'))
        obj_label = widgets.HTML("Object", layout=in_area('obj_label'))
        self.edge_object_text2 = widgets.Combobox(options=self.kg.object_list,
                                                  placeholder="Select object",
                                                  layout=in_area('object'))
        rel_label = widgets.HTML("Relation", layout=in_area('rel_label'))
        # Read-only display of the relation between the chosen subject and object.
        self.edge_relation_text2 = widgets.Output(layout=in_area('relation'))

        # Typing in either box re-filters the other one and updates the relation text.
        self.edge_subject_text2.observe(self.edge_subject_text2_eventhandler, names='value')
        self.edge_object_text2.observe(self.edge_object_text2_eventhandler, names='value')
        self.edge_output2 = widgets.Output(layout=in_area('output'))
        self.edge_output2.append_stdout(" ")

        self.node_button2 = widgets.Button(description="Remove node from KG", layout=in_area('button'))
        self.edge_button2 = widgets.Button(description="Remove edge from KG", layout=in_area('button'))
        self.node_button2.on_click(self.node_button_click_eventhandler2)
        self.edge_button2.on_click(self.edge_button_click_eventhandler2)

        # Fill in the initial option lists and relation text.
        self.update_rem_edge_output()

        node_form = widgets.GridBox(children=[self.node_text2, self.node_button2,
                                              self.node_output2,
                                              node_label],
                                    layout=Layout(
                                        grid_template_rows='auto auto auto',
                                        grid_template_columns='30% 70%',
                                        grid_template_areas='''
                                            "label input"
                                            "button button"
                                            "output output"
                                            ''')
                                    )
        edge_form = widgets.GridBox(children=[sub_label, self.edge_subject_text2,
                                              obj_label, self.edge_object_text2,
                                              rel_label, self.edge_relation_text2,
                                              self.edge_button2,
                                              self.edge_output2],
                                    layout=Layout(
                                        grid_template_rows='auto auto auto auto auto',
                                        grid_template_columns='30% 70%',
                                        grid_template_areas='''
                                            "sub_label subject"
                                            "obj_label object"
                                            "rel_label relation"
                                            "button button"
                                            "output output"
                                            ''')
                                    )

        self.remv = widgets.VBox((nodes_title, node_form, edges_title, edge_form))
        
    def initialize_add_doc(self):
        """ Builds the "Add a document" form (path box, button, feedback output) and stores it in `self.doc`. """

        def in_area(name):
            # Auto-sized layout pinned to a named area of the form grid.
            return Layout(height='auto', width='auto', grid_area=name)

        heading = widgets.HTML("<b>Add a document</b>")
        path_label = widgets.HTML("File location: ", layout=in_area('doc_label'))
        self.doc_text = widgets.Text(placeholder="Path to file", layout=in_area('doc_text'))
        self.doc_button = widgets.Button(description="Add Document", layout=in_area('doc_button'))
        self.doc_output = widgets.Output(layout=in_area('doc_output'))
        self.doc_button.on_click(self.doc_button_click_eventhandler)

        # Row 1: label + path box, row 2: button, row 3: feedback across both columns.
        doc_form = widgets.GridBox(children=[path_label, self.doc_text, self.doc_button,
                                             self.doc_output],
                                   layout=Layout(
                                       grid_template_rows='auto auto auto',
                                       grid_template_columns='30% 70%',
                                       grid_template_areas='''
                                                                "doc_label doc_text" 
                                                                "doc_button ."
                                                                "doc_output doc_output"
                                                            ''')
                                   )
        self.doc = widgets.VBox((heading, doc_form))
        
        

    '''Querying Input Box'''
    def initialize_query(self):
        """
        Builds the query panel and stores it in `self.query`.

        Creates a `Query` object over the graph's `topic_triplets_df`, a text
        box and submit button for the question, and a disabled text area for
        the result. Both pressing enter in the box and clicking the button
        call `query_submit`.
        """

        def stretched(name):
            # Layout that fills its named grid area.
            return Layout(height='100%', width='100%', grid_area=name)

        heading = widgets.HTML("<b>Query</b>")
        self.Q = Query(self.kg.topic_triplets_df)

        # Question row.
        query_label = widgets.HTML("Enter Query: ", layout=stretched('query_label'))
        self.query_input = widgets.Text(placeholder='Enter a query here.',
                                        layout=stretched('query_textarea'))
        self.query_button = widgets.Button(description='Submit',
                                           layout=Layout(height='auto', width='auto', grid_area='query_button'))

        # Result row.
        query_response_label = widgets.HTML("Query Result: ", layout=stretched('query_response_label'))
        self.query_response_textarea = widgets.Textarea(
            placeholder="Answer will be shown here.", disabled=True, resize=False,
            layout=Layout(height='10rem', width='100%', grid_area='query_response_textarea'))

        self.query_input.on_submit(self.query_submit)
        self.query_button.on_click(self.query_submit)

        # The first template row is left empty as spacing above the question row.
        query_form = widgets.GridBox(children=[query_label, self.query_input, self.query_button,
                                               query_response_label, self.query_response_textarea],
                                     layout=Layout(
                                         grid_template_rows='20% 15% 35%',
                                         grid_template_columns='20% 50% 20%',
                                         grid_template_areas='''
                                                                ". . ."
                                                                "query_label query_textarea query_button"
                                                                "query_response_label query_response_textarea query_response_textarea"
                                                            ''')
                                     )
        self.query = widgets.VBox((heading, query_form))
                                           
    def common_filtering(self,rel,top):
        """
        Redraw the graph in the main output area with the given filters.

        A filter equal to 'ALL' is treated as "no restriction" and is simply
        not passed on to `display_filtered_graph`.
        """
        self.output.clear_output()
        # Forward only the filters that actually restrict something.
        active_filters = {}
        if rel != 'ALL':
            active_filters['rel'] = rel
        if top != 'ALL':
            active_filters['top'] = top
        with self.output:
            self.kg.display_filtered_graph(**active_filters)

    def dropdown_relation_eventhandler(self, change):
        """ Observer for the relation filter: redraws with the new relation and the current topic filter. """
        chosen_relation = change.new
        current_topic = self.dropdown_entity.value
        self.common_filtering(chosen_relation, current_topic)

    def dropdown_entity_eventhandler(self, change):
        """ Observer for the topic filter: redraws with the current relation filter and the new topic. """
        current_relation = self.dropdown_relation.value
        chosen_topic = change.new
        self.common_filtering(current_relation, chosen_topic)

    def refresh_btn_eventhandler(self, b):
        """ Click handler of the refresh button: redraws with both filters as currently entered. """
        current_relation = self.dropdown_relation.value
        current_topic = self.dropdown_entity.value
        self.common_filtering(current_relation, current_topic)
        
    def node_button_click_eventhandler(self,b):
        """ Click handler: runs `add_node_event` on a daemon thread; the button argument `b` is unused. """
        # Create the worker thread and start it right away.
        threading.Thread(target=self.add_node_event, daemon=True).start()
    
    def node_button_click_eventhandler2(self,b):
        """ Click handler of the remove-node button: runs `rem_node_event` on a daemon thread; `b` is unused. """
        threading.Thread(target=self.rem_node_event, daemon=True).start()
    
    def edge_subject_text2_eventhandler(self, change):
        """ Observer for the remove-edge subject box: refreshes the remove-edge form; `change` is unused. """
        self.update_rem_edge_output()
    
    def edge_object_text2_eventhandler(self, change):
        """ Observer for the remove-edge object box: refreshes the remove-edge form; `change` is unused. """
        self.update_rem_edge_output()
        
    def update_rem_edge_output(self):
        """
        Synchronizes the "Remove Edges" form with what is currently entered.

        * If the subject box holds a known subject, the object suggestions are
          narrowed to that node's graph successors; otherwise every object is
          offered.
        * If the object box holds a known object, the subject suggestions are
          narrowed to that node's graph predecessors; otherwise every subject
          is offered.
        * The relation display shows the relation stored for the
          (subject, object) pair, or "NO RELATION " when there is none.

        Node ids that have no entry in `numbers_map` are skipped in both
        directions instead of raising KeyError.
        """
        kg = self.kg

        subject = self.edge_subject_text2.value
        if subject in kg.subject_list:
            subject_id = kg.triplets_map[subject]
            reachable = {kg.numbers_map[node_id] for node_id in kg.G.succ[subject_id]
                         if node_id in kg.numbers_map}
            self.edge_object_text2.options = tuple(reachable.intersection(kg.object_list))
        else:
            self.edge_object_text2.options = tuple(kg.object_list)

        # Read the object only after its options were updated above.
        target = self.edge_object_text2.value
        if target in kg.object_list:
            target_id = kg.triplets_map[target]
            sources = {kg.numbers_map[node_id] for node_id in kg.G.pred[target_id]
                       if node_id in kg.numbers_map}
            self.edge_subject_text2.options = tuple(sources.intersection(kg.subject_list))
        else:
            self.edge_subject_text2.options = tuple(kg.subject_list)

        with self.edge_relation_text2:
            # Drop the previous text before writing the new one.
            self.edge_relation_text2.outputs = ()
            try:
                sub = kg.triplets_map[self.edge_subject_text2.value]
                obj = kg.triplets_map[self.edge_object_text2.value]
                txt = kg.edges[(sub, obj)]
            except KeyError:
                # Unknown subject/object, or no edge between the two.
                txt = "NO RELATION "
            self.edge_relation_text2.append_stdout(txt)
                                           
    def add_node_event(self):
        """
        Add the text of `self.node_text` to the knowledge graph as a node.

        Writes a confirmation into `self.node_output`, calls
        `kg.add_as_node`, empties the text box, refreshes the widgets and
        clears the confirmation again after three seconds.

        NOTE(review): `node_text` and `node_output` are not created by any
        initializer visible in this part of the file -- confirm they exist.
        """
        with self.node_output:
            self.node_output.outputs = ()
            node_name = self.node_text.value
            self.node_output.append_stdout(f"\'{node_name}\' added to KG as a node. ")
            self.kg.add_as_node(node_name) # to-do : add_as_node
            self.node_text.value = "" # reset text inside

            # refresh dropdown list
            self.refresh_all()

            # Leave the confirmation visible for a moment, then remove it.
            time.sleep(3.0)
            self.node_output.outputs = ()
    
    def edit_node_eventhandler(self, b):
        """ Click handler of the "Edit Node" button: runs `edit_node_event` on a daemon thread; `b` is unused. """
        threading.Thread(target=self.edit_node_event, daemon=True).start()
    
    def edit_node_event(self):
        """
        Rename the node chosen in the edit-node form.

        Reports "Node not found. " when the old name is neither a subject nor
        an object, and "Update failed. Edit box empty." when no new name was
        typed. Otherwise renames the node in the knowledge graph, refreshes
        the widgets and the graph, and resets the form. In every case the
        feedback text is cleared after two seconds.
        """
        with self.edit_node_output:
            old_node = self.edit_node_dropdown.value
            new_node = self.edit_node_textbox.value
            node_exists = old_node in (set(self.kg.subject_list) | set(self.kg.object_list))

            if node_exists and new_node != '':
                self.edit_node_output.append_stdout("Updating node... ")
                self.kg.edit_node(old_node, new_node)
                self.refresh_all()
                # redraw the graph with the filters currently entered
                self.common_filtering(self.dropdown_relation.value,self.dropdown_entity.value)
                # replace the progress text by the success message
                self.edit_node_output.outputs = ()
                self.edit_node_output.append_stdout("Successfully updated node. ")
                self.edit_node_dropdown.value = new_node
                self.edit_node_textbox.value = ''
            elif node_exists:
                self.edit_node_output.append_stdout("Update failed. Edit box empty.")
            else:
                self.edit_node_output.append_stdout("Node not found. ")

            self.refresh_all()
            # keep the feedback visible briefly, then remove it
            time.sleep(2.0)
            self.edit_node_output.outputs = ()
        
    def edit_rel_eventhandler(self, b):
        """Button callback: run ``edit_rel_event`` on a daemon thread.

        ``b`` (the argument supplied by the click) is not used.
        """
        worker = threading.Thread(daemon=True, target=self.edit_rel_event)
        worker.start()
        
    def edit_rel_event(self):
        """Rename the relation chosen in the edit dropdown to the text-box value.

        A relation that is not in ``kg.relation_list`` or an empty replacement
        only produces a status message. Otherwise the relation is renamed via
        ``kg.edit_edge``, the option lists are refreshed and the dropdown is
        pointed at the new name. The option lists are always refreshed at the
        end and the status area is cleared after two seconds.
        """
        status = self.edit_rel_output
        with status:
            current_rel = self.edit_rel_dropdown.value
            replacement = self.edit_rel_textbox.value

            # work out whether the request can be carried out at all
            if current_rel not in self.kg.relation_list:
                problem = "Relation not found. "
            elif replacement == '':
                problem = "Update failed. Edit box empty."
            else:
                problem = None

            if problem is not None:
                status.append_stdout(problem)
            else:
                status.append_stdout("Updating relation... ")
                self.kg.edit_edge(current_rel, replacement)
                self.refresh_all()
                status.append_stdout("Successfully updated relation. ")
                self.edit_rel_dropdown.value = replacement
                self.edit_rel_textbox.value = ''

            self.refresh_all()
            time.sleep(2.0)
            status.outputs = ()
    
    def edit_topic_eventhandler(self, b):
        """Button callback: run ``edit_topic_event`` on a daemon thread.

        ``b`` (the argument supplied by the click) is not used.
        """
        background = threading.Thread(daemon=True, target=self.edit_topic_event)
        background.start()
    
    def edit_topic_event(self):
        """Rename the topic chosen in the edit dropdown to the text-box value.

        A topic that is not in ``kg.topics_list`` or an empty replacement only
        produces a status message. Otherwise the topic is renamed via
        ``kg.edit_topic``, all option lists are refreshed and the edit dropdown
        is pointed at the new name. The option lists are always refreshed at
        the end and the status area is cleared after two seconds.
        """
        with self.edit_topic_output:
            old_topic = self.edit_topic_dropdown.value
            new_topic = self.edit_topic_textbox.value
            if old_topic not in self.kg.topics_list:
                self.edit_topic_output.append_stdout("Topic not found. ")
            elif new_topic == '':
                self.edit_topic_output.append_stdout("Update failed. Edit box empty. ")
            else:
                self.edit_topic_output.append_stdout("Updating topic... ")
                self.kg.edit_topic(old_topic, new_topic)
                # refresh_all() reloads edit_topic_dropdown.options (and
                # dropdown_entity.options) from kg.topics_list, so the renamed
                # topic is an available option before it is selected below --
                # the same order edit_node_event and edit_rel_event use.
                self.refresh_all()
                self.edit_topic_output.append_stdout("Successfully updated topic. ")
                self.edit_topic_dropdown.value = new_topic
                self.edit_topic_textbox.value = ''
            self.refresh_all()
            time.sleep(2.0)
            self.edit_topic_output.outputs = ()
            
    def edit_ans_eventhandler(self, b):
        """Button callback: run ``edit_ans_event`` on a daemon thread.

        ``b`` (the argument supplied by the click) is not used.
        """
        threading.Thread(daemon=True, target=self.edit_ans_event).start()
            
    def edit_ans_event(self):
        """Store the edited candidate answer for the selected triplet.

        The subject, relation and object come from the three answer filters
        and the new text from the answer field. ``kg.edit_ans`` is called with
        them and its truthiness decides which status message is shown. The
        message is replaced by a single blank after two seconds.
        """
        sub = self.edit_ans_sub_filter.value
        rel = self.edit_ans_rel_filter.value
        obj = self.edit_ans_obj_filter.value
        ans = self.edit_ans_field.value
        response = self.kg.edit_ans(sub, rel, obj, ans)
        with self.edit_ans_output:
            self.edit_ans_output.outputs = ()
            if response:
                # spelling of the user-facing message corrected ("Candidate")
                self.edit_ans_output.append_stdout("Candidate answer successfully updated.")
            else:
                self.edit_ans_output.append_stdout("Error. Triplet not found between selection.")
            time.sleep(2.0)
            self.edit_ans_output.outputs = ()
            self.edit_ans_output.append_stdout(" ")
            
            
    def ans_sub_filter_eventhandler(self, change):
        """Observer for the subject filter.

        Re-runs ``ans_filter`` with the new subject (``change.new``) and the
        current relation and object filter values, returning its result.
        """
        subject = change.new
        relation = self.edit_ans_rel_filter.value
        target = self.edit_ans_obj_filter.value
        return self.ans_filter(subject, relation, target)
    
    def ans_rel_filter_eventhandler(self, change):
        """Observer for the relation filter.

        Re-runs ``ans_filter`` with the new relation (``change.new``) and the
        current subject and object filter values, returning its result.
        """
        subject = self.edit_ans_sub_filter.value
        relation = change.new
        target = self.edit_ans_obj_filter.value
        return self.ans_filter(subject, relation, target)
    
    def ans_obj_filter_eventhandler(self, change):
        """Observer for the object filter.

        Re-runs ``ans_filter`` with the new object (``change.new``) and the
        current subject and relation filter values, returning its result.
        """
        subject = self.edit_ans_sub_filter.value
        relation = self.edit_ans_rel_filter.value
        target = change.new
        return self.ans_filter(subject, relation, target)
    
    def ans_filter(self, sub, rel, obj):
        """Narrow the answer-editor filters to the triplets matching the selection.

        Parameters
        ----------
        sub, rel, obj
            Current subject, relation and object selections. A value counts as
            "selected" only when it is found in ``kg.subject_list``,
            ``kg.relation_list`` or ``kg.object_list`` respectively.

        The option lists of the filters that are *not* pinned by the selection
        are replaced with the values found in ``kg.topic_triplets_df`` for the
        pinned ones; with nothing selected all three lists are reset to the
        full lists. Afterwards the answer field is filled from
        ``kg.get_sentence(sub, rel, obj)`` and enabled, or disabled when that
        call returns ``None``.
        """
        self.edit_ans_field.value = ""
        kg = self.kg
        triplets = kg.topic_triplets_df

        def matching(mask, column):
            # values of `column` for the rows selected by `mask`, as a tuple
            return tuple(triplets.loc[mask, column].values)

        if sub in kg.subject_list:
            by_sub = triplets.subject == sub
            if rel in kg.relation_list:
                self.edit_ans_obj_filter.options = matching(by_sub & (triplets.relation == rel), "object")
            elif obj in kg.object_list:
                self.edit_ans_rel_filter.options = matching(by_sub & (triplets.object == obj), "relation")
            else:
                self.edit_ans_rel_filter.options = matching(by_sub, "relation")
                self.edit_ans_obj_filter.options = matching(by_sub, "object")
        elif rel in kg.relation_list:
            by_rel = triplets.relation == rel
            if obj in kg.object_list:
                self.edit_ans_sub_filter.options = matching(by_rel & (triplets.object == obj), "subject")
            else:
                # Only the relation is pinned: the matching subjects belong in
                # the *subject* filter (they were previously written into the
                # relation filter, replacing its options with subjects).
                self.edit_ans_sub_filter.options = matching(by_rel, "subject")
                self.edit_ans_obj_filter.options = matching(by_rel, "object")
        elif obj in kg.object_list:
            by_obj = triplets.object == obj
            self.edit_ans_sub_filter.options = matching(by_obj, "subject")
            self.edit_ans_rel_filter.options = matching(by_obj, "relation")
        else:
            self.edit_ans_sub_filter.options = tuple(kg.subject_list)
            self.edit_ans_rel_filter.options = tuple(kg.relation_list)
            self.edit_ans_obj_filter.options = tuple(kg.object_list)

        ans = kg.get_sentence(sub, rel, obj)
        if ans is None:
            self.edit_ans_field.disabled = True
        else:
            # get_sentence is expected to return an object exposing .item()
            self.edit_ans_field.value = ans.item()
            self.edit_ans_field.disabled = False
    
    def rem_node_event(self):
        """Remove the node named in the removal box from the knowledge graph.

        When the name is a known subject or object it is removed via
        ``kg.remove_node``, the option lists and the edge-removal display are
        refreshed and the box is emptied; otherwise only a "not found" message
        is shown. After two seconds the status area is reset to one blank.
        """
        status = self.node_output2
        with status:
            known_nodes = set(self.kg.subject_list) | set(self.kg.object_list)
            target = self.node_text2.value
            if target in known_nodes:
                status.append_stdout("Removing node... ")
                self.kg.remove_node(target)
                status.append_stdout("Node successfully removed. ")
                self.refresh_all()
                self.update_rem_edge_output()
                self.node_text2.value = ""
            else:
                status.append_stdout("Node not found in knowledge graph. ")
            time.sleep(2.0)
            status.outputs = ()
            status.append_stdout(" ")

    def triplet_button_click_eventhandler(self, b):
        """Button callback: run ``add_triplet_event`` on a daemon thread.

        ``b`` (the argument supplied by the click) is not used.
        """
        threading.Thread(daemon=True, target=self.add_triplet_event).start()
    
    def edge_button_click_eventhandler2(self,b):
        """Button callback: run ``rem_edge_event`` on a daemon thread.

        ``b`` (the argument supplied by the click) is not used.
        """
        remover = threading.Thread(daemon=True, target=self.rem_edge_event)
        remover.start()
        
    def add_triplet_event(self):
        """
        Add the triplet entered in the form to the graph.

        ``kg.add_as_edge`` is asked to create the edge; when it reports
        success the triplet (with the answer text, the node ids from
        ``kg.triplets_map`` and the topic tags of both ends) is also appended
        to the dataframe. Otherwise a message says a relationship already
        exists. The four inputs are emptied, the option lists refreshed and
        the status area cleared after two seconds.
        """
        status = self.triplet_output
        with status:
            subject = self.subject_text.value
            target = self.object_text.value
            relation = self.relation_text.value

            if self.kg.add_as_edge(subject, target, relation):
                status.append_stdout("Triplet successfully added.")

                subject_topics = self.kg.get_tags(subject, "subject")
                object_topics = self.kg.get_tags(target, "object")

                # mirror the new edge in the triplets dataframe
                node_ids = self.kg.triplets_map
                self.kg.add_triplet_to_df(
                    self.answer_text.value, subject, relation, target,
                    node_ids[subject], node_ids[target],
                    subject_topics, object_topics,
                )
            else:
                status.append_stdout("Unable to add edge. A relationship already exists between the two nodes.")

            # empty every input of the form
            for field in (self.subject_text, self.object_text, self.relation_text, self.answer_text):
                field.value = ""

            # make the new entries selectable everywhere
            self.refresh_all()

            time.sleep(2.0)
            status.outputs = ()
            
    def rem_edge_event(self):
        """Remove the edge between the selected subject and object.

        The relation display (``edge_relation_text2``) decides whether there
        is anything to remove: when it shows 'NO RELATION ' -- or shows
        nothing at all -- only a message is printed. Otherwise the edge is
        removed via ``kg.remove_edge``, the relation filter options are
        reloaded and the relation display is updated. The status area is
        cleared after two seconds.
        """
        with self.edge_output2:
            shown = self.edge_relation_text2.outputs
            # An empty display means no relation has been looked up yet; treat
            # it like 'NO RELATION ' instead of failing on shown[0].
            if not shown or shown[0]['text'] == 'NO RELATION ':
                self.edge_output2.append_stdout("No relation to remove.")
            else:
                self.edge_output2.outputs = ()
                self.edge_output2.append_stdout("Removing edge... ")
                self.kg.remove_edge(self.edge_subject_text2.value, self.edge_object_text2.value)
                self.edge_output2.append_stdout("Edge successfully removed. ")
                self.dropdown_relation.options = self.kg.relation_list
                self.update_rem_edge_output()
            time.sleep(2.0)
            self.edge_output2.outputs = ()
            
    def doc_button_click_eventhandler(self, b):
        """Button callback: run ``add_doc_event`` on a daemon thread and wait for it.

        ``b`` (the argument supplied by the click) is not used. Unlike the
        other click handlers this one joins the thread, so it returns only
        after ``add_doc_event`` has finished.
        """
        worker = threading.Thread(daemon=True, target=self.add_doc_event)
        worker.start()
        worker.join()
        
    def add_doc_event(self):
        """Add the document named in the document box to the knowledge graph.

        Progress and completion messages are shown in ``doc_output`` (each for
        two seconds), the option lists are refreshed and the box is emptied.
        Afterwards the main output is cleared, the filter / add / remove
        sections are initialised again and the full graph is drawn into the
        main output.
        """
        status = self.doc_output
        with status:
            status.outputs = ()
            status.append_stdout("Document processing...\n")
            self.kg.add_doc(self.doc_text.value)
            self.refresh_all()

            time.sleep(2.0)
            status.outputs = ()
            added = self.doc_text.value
            status.append_stdout(f'Document of {added} added to the knowledge graph')
            self.doc_text.value = ''
            time.sleep(2.0)
            status.outputs = ()

        # rebuild the dashboard sections and redraw the whole graph
        self.output.clear_output()
        for rebuild in (self.initialize_filters, self.initialize_adds, self.initialize_remv):
            rebuild()
        with self.output:
            self.kg.display_full_graph()
            
    def refresh_all(self):
        """Reload the option lists of the selection widgets from the knowledge graph.

        Node pickers receive the union of subjects and objects (as a list
        built from a set); the edit pickers for relations and topics, and the
        answer relation filter, skip the first list entry ("ALL").
        """
        kg = self.kg

        def every_node():
            # subjects and objects merged, duplicates removed
            return list(set(kg.subject_list).union(set(kg.object_list)))

        # (widget attribute, new options), applied in this order
        plan = (
            ("subject_text", every_node()),
            ("object_text", every_node()),
            ("relation_text", kg.relation_list),
            ("node_text2", every_node()),
            ("dropdown_relation", kg.relation_list),
            ("dropdown_entity", kg.topics_list),
            ("edge_subject_text2", kg.subject_list),
            ("edge_object_text2", kg.object_list),
            ("edit_node_dropdown", every_node()),
            ("edit_rel_dropdown", kg.relation_list[1:]),    # without ALL
            ("edit_topic_dropdown", kg.topics_list[1:]),    # without ALL
            ("edit_ans_sub_filter", kg.subject_list),
            ("edit_ans_rel_filter", kg.relation_list[1:]),  # without ALL
            ("edit_ans_obj_filter", kg.object_list),
        )
        for widget_name, choices in plan:
            getattr(self, widget_name).options = choices
            
    def query_submit(self, b):
        """Button callback: answer the question typed into the query box.

        The response area is emptied, a thank-you line echoing the question is
        printed, the answer from ``self.Q.get_sentence`` is written to the
        response area and the query box is emptied. ``b`` is not used.
        """
        answer_box = self.query_response_textarea
        answer_box.value = ''
        question = self.query_input.value
        print("Thank you for your question! \n" + question)
        answer_box.value = self.Q.get_sentence(question)
        self.query_input.value = ''

In [5]:
class Document:
    """Turns a PDF into a flat list of sentences.

    Attributes set by the pipeline:
      file          -- the object returned by ``pdfplumber.open``
      sentences     -- list of sentence strings (final result)
      tables        -- cleaned first table of every page that has one
      table_count   -- number of pages on which a table was found
      sentenceCount -- ``len(sentences)``
      wordCounts    -- whitespace-separated word count of each sentence
    """

    def __init__(self, filename):
        """Open ``filename`` with pdfplumber and run the whole pipeline."""
        import pdfplumber
        self.file = pdfplumber.open(filename)
        self.process_text()
        self.process_table()
        self.sentence_split()
        self.process_bullets()
        self.sentenceCount = len(self.sentences)
        self.wordCounts = list(map(lambda x: len(x.split()), self.sentences))

    def process_text(self):
        """Rebuild sentences from the text lines of every page into ``self.sentences``.

        Each page is handled on its own: lines are normalised (short leading
        numbering removed, bullet glyphs and '- ' turned into '~ ', curly
        quotes straightened) and then merged line by line using the last /
        first characters of each line as cues. Text still buffered when a page
        ends is not emitted.
        """
        self.sentences = []
        for page in self.file.pages:
            extract = page.extract_text()
            if not extract:
                # nothing extractable on this page (None or empty string)
                continue
            lines = [line for line in extract.split('\n') if len(line) > 1]

            # ---- pass 1: normalise each line --------------------------------
            new_sentences = []
            for sentence in lines:
                words = sentence.split(".")
                if len(words) > 1 and '$' not in words[0] and len(words[0])<5:
                    # drop a short leading token such as "1." / "a."; the
                    # remaining pieces are joined without their dots
                    s = ''.join(words[1:])
                else:
                    s = sentence
                #\uf0d8 is the unicode bullet point, might need to update accordingly if there are other bullet points
                s = re.sub('\uf0d8  ','~ ', s)
                s = re.sub('\uf0d8 ','~ ', s)
                s = re.sub('- ','~ ', s)
                s = re.sub('“','"',s)
                s = re.sub('”','"',s)
                s = re.sub("’","'",s)
                s = s.strip(' \n')
                if len(s) <= 1:
                    continue
                new_sentences.append(s)

            # ---- pass 2: merge lines into sentences -------------------------
            # `temp` buffers text that has not been emitted yet.
            sentences = []
            temp = ''
            for s in new_sentences:
                s = s.strip()
                if not s:
                    # only whitespace other than ' ' / '\n' was left; s[-1]
                    # below would fail on an empty string
                    continue
                if s[-1] == '?':
                    if len(temp) != 0:
                        if s[0].isupper() and not s[1].isupper():
                            sentences.append(temp)
                            sentences.append(s)
                        else:
                            sentences.append(temp+s)
                    else:
                        sentences.append(s)
                    temp = ''
                    continue
                elif s[0:2] == '~ ' and s[-1] != ';' and temp == '':
                    # first line of a bullet item
                    temp+=s.strip()
                    continue
                elif temp[0:2] == '~ ':
                    # a bullet item is being buffered
                    if s[0:2] == '~ ':
                        temp = temp+';'
                        temp = temp.strip()
                        sentences.append(temp)
                        temp = s
                        continue
                    elif s[0] == '*':
                        temp+=';'
                        temp = temp.strip()
                        sentences.append(temp)
                        temp = s
                        continue
                    elif s[-1] not in '.;:?!*':
                        temp+=' '
                        temp+=s
                        continue
                    elif s[-1] == ';':
                        # s cannot start with '~ ' here (handled above)
                        temp = temp+' '+s
                        temp = temp.strip()
                        sentences.append(temp)
                    elif s[0].isupper():
                        temp = temp+';'
                        temp = temp.strip()
                        sentences.append(temp)
                        sentences.append(s)
                        # NOTE(review): temp is not reset on this path, so the
                        # bullet text just emitted stays buffered -- confirm.
                        continue
                    else:
                        temp = temp+' '+s+';'
                        temp = temp.strip()
                        sentences.append(temp)
                    temp = ''
                    continue
                if s[-1] not in '.;:?!*':
                    # unterminated line: keep accumulating
                    if len(temp) >= 1:
                        if temp[-1] in '.;:?!*':
                            temp = temp.strip()
                            sentences.append(temp)
                            temp = ''
                    temp += ' '
                    temp += s
                    temp += ' '
                    continue
                elif s[0].isupper() and not s[1].isupper():
                    if len(temp) <= 1:
                        # NOTE(review): with nothing buffered the terminated
                        # line `s` is discarded here rather than kept --
                        # confirm this is intended.
                        temp = ''
                        continue
                    temp = temp.strip()
                    temp+='.'
                    sentences.append(temp)
                    temp = s
                    continue
                else:
                    temp = temp+' '+s
                    temp = temp.strip()
                    sentences.append(temp)
                    temp = ''
                    continue
            self.sentences.extend(sentences)

    def process_table(self):
        """Replace table text in ``self.sentences`` by one sentence per table cell.

        Only the first table of each page is used. Empty cells and empty rows
        are dropped, every cell text is removed from the sentences that
        contain it, and for each data row a sentence
        "<column> is <value> for <first header> <row label>." is appended for
        every column after the first. Cells a row does not have are skipped.
        """
        self.table_count = 0
        self.tables = []
        for page in self.file.pages:
            found = page.extract_tables()
            if len(found) == 0:
                continue
            self.table_count+=1
            table = [[cell for cell in row if cell != '' and cell is not None] for row in found[0]]
            table = [row for row in table if len(row) != 0]
            # process_text turned '- ' into '~ ', so match that spelling
            flat_table = [cell.replace('-','~') for row in table for cell in row]
            self.tables.append(table)
            for idx in range(len(self.sentences)):
                if any(sub in self.sentences[idx] for sub in flat_table):
                    for sub in flat_table:
                        self.sentences[idx] = self.sentences[idx].replace(sub, '')
                    self.sentences[idx] = self.sentences[idx].strip()
            if not table:
                continue
            header = table[0]
            for row in table[1:]:
                for col, column in enumerate(header[1:], start=1):
                    if col >= len(row):
                        # row is shorter than the header (cells were empty)
                        break
                    # use the value of *this* column, not always row[1]
                    self.sentences.append(
                        column + ' is ' + row[col] + ' for ' + header[0] + ' ' + row[0] + '.'
                    )

    def sentence_split(self):
        """Split every entry of ``self.sentences`` on '. ' and end each piece with '.'.

        A piece that already ends with '.' is left as it is (no doubled
        period); empty pieces are dropped.
        """
        new_final = []
        for s in self.sentences:
            for sub in s.split(sep='. '):
                if len(sub) > 0:
                    if not sub.endswith('.'):
                        sub += '.'
                    new_final.append(sub)
        self.sentences = new_final

    def process_bullets(self):
        """Strip short leading labels and fold bullet lists into full sentences.

        First a prefix of at most three characters ending in the first '.'
        (e.g. "1." or "a).") is removed and entries are split on '. ' again;
        entries that end up empty are dropped. Then a line ending in ':'
        becomes the lead-in for the following lines ending in ';'.
        """
        sentences = []
        for text in self.sentences:
            dot = text.find('.')
            if 0 <= dot < 3:
                text = text[dot+1:]
            text = text.strip()
            strings = text.split(sep='. ')
            if len(strings) > 1:
                strings = [piece + '.' for piece in strings[:-1]] + [strings[-1]]
            # an entry that consisted only of a short label is now '' and
            # would break the s[-1] look-ups below
            sentences.extend(piece for piece in strings if piece)
        #process bullet points
        # NOTE(review): sentence_split() ends pieces that finish with ':' or
        # ';' with an extra '.', so when run through __init__ the ':' / ';'
        # branches below do not appear to fire -- confirm the intended order.
        processed_bullet_sentences = []
        temp = ''
        for s in sentences:
            if s[-1] == ':' and temp == '':
                temp += s[:-1]
                temp += ' '
                continue
            elif s[-1] == ';':
                if s[0:2].isupper():
                    t = s[:-1]
                else:
                    t = s[0].lower()+s[1:-1]
                processed_bullet_sentences.append(temp+t)
                continue
            elif s[-1] == ':' and temp != '':
                temp = ''
                temp += s[:-1]
                temp += ' '
                continue
            elif temp != '':
                temp = ''
                continue
            processed_bullet_sentences.append(s)
        self.sentences = processed_bullet_sentences

In [6]:
class Query:
    """Answers a free-text question by matching it against knowledge-graph triplets.

    The subject of the question is fuzzy-matched against the ``subject``
    column and its verb phrase is compared with the ``relation`` column using
    word vectors; the ``sentence`` of the best-ranked triplet is the answer.
    """

    def __init__(self, kg_df):
        """
        Takes in a pandas dataframe with columns confidence, sentence, subject, relation, object
        which the Query object will do its querying on
        """
        self.kg_df = kg_df
        self.fzwz_threshold = 80 #able to adjust
        self.model = api.load('word2vec-google-news-300')
        self.nlp = spacy.load('en_core_web_sm')

    def get_subject(self, text):
        '''Call get_subject on query to obtain 1 single string of subject'''
        doc = self.nlp(text)
        return " ".join(str(chunk) for chunk in doc.noun_chunks)

    def get_relation(self, text):
        '''Call get_relation on query to obtain 1 single string of relation'''
        pattern = r'(<VERB>?<ADV>*<VERB>+)'
        doc = textacy.make_spacy_doc(text, lang='en_core_web_sm')
        verb_phrases = textacy.extract.pos_regex_matches(doc, pattern)
        relation_phrase = [str(chunk.text) for chunk in verb_phrases]
        # fallback when no verb phrase matched; note this is a substring test
        if len(relation_phrase) == 0 and "is" in text:
            return "is"
        return " ".join(relation_phrase)

    def get_top_entities(self, kg_df, qns_entity):
        '''Entities Filtering

        Returns the rows of kg_df whose subject has a fuzzywuzzy partial ratio
        with qns_entity of at least self.fzwz_threshold, best first, with the
        score in a 'fuzzywuzzy_score' column. kg_df itself is not modified.'''
        scores = kg_df['subject'].apply(lambda x: fuzz.partial_ratio(qns_entity, x))
        # assign() works on a copy, so the (shared) input frame does not
        # silently gain a score column
        ranked = kg_df.assign(fuzzywuzzy_score=scores).sort_values('fuzzywuzzy_score', ascending = False)
        return ranked[ranked['fuzzywuzzy_score']>= self.fzwz_threshold]

    def get_mean_vector(self, word2vec_model, relation):
        '''Handle multi word relation phrase cosine similarity comparison

        Returns the mean vector of the in-vocabulary words of relation, or an
        empty list when none of its words is in the vocabulary.'''
        relation_words = [w for w in relation.lower().split()]
        # remove out-of-vocabulary words
        words = [word for word in relation_words if word in word2vec_model.vocab]
        if len(words) >= 1:
            return np.mean(word2vec_model[words], axis=0)
        else:
            return []

    def compute_cosine_similarity(self, qns_relation, kg_df):
        '''Cosine Similarity between every shortlisted relation in the kg_df and the relation in the query

        Returns a copy of kg_df with a 'cosine similarity' column. A relation
        (or query) without any in-vocabulary word scores 0.0 instead of
        raising on mismatched vector sizes.'''
        qns_relation_vec = self.get_mean_vector(self.model, qns_relation)
        cosine_scores = []
        for relation in kg_df['relation']:
            vect = self.get_mean_vector(self.model, relation)
            if len(qns_relation_vec) == 0 or len(vect) == 0:
                # no vector to compare with
                cosine_scores.append(0.0)
                continue
            cosine_scores.append(cosine_similarity(np.array([qns_relation_vec]), np.array([vect]))[0][0])
        # the column name contains a space, hence the dict form
        return kg_df.assign(**{'cosine similarity': cosine_scores})

    def preprocess(self, sentence):
        '''Lower-case sentence and split it on whitespace.'''
        return [w for w in sentence.lower().split()]

    def process_question(self, query):
        '''remove the what/when/who/why/where/how that is commonly used in question phrasing
        so that phrase noun and relation can be identified better

        A single trailing "?" is removed first; only the capitalised forms
        listed below are dropped.'''
        if query.endswith("?"):
            query = query[:-1]
        fiveW_oneH = ["What", "When", "Who", "Why", "Where", "How"]
        return " ".join([item for item in query.split() if item not in fiveW_oneH])

    def compare_relation(self, kg_df, relation):
        '''Word Mover distance smaller the distance, the more similar.
        Using Word2Vec google news from gensim pacakge for relationship word vectors comparison

        kg_df must already carry the 'fuzzywuzzy_score' and 'cosine similarity'
        columns. Returns the sentence of the best-ranked row, or an apology
        when kg_df is empty. kg_df itself is not modified.'''
        query_tokens = self.preprocess(relation)
        wmd_relation = [self.model.wmdistance(query_tokens, self.preprocess(candidate))
                        for candidate in kg_df['relation']]
        sorted_wmd = kg_df.assign(wmd_relation=wmd_relation).sort_values(
            by=['fuzzywuzzy_score', 'cosine similarity', 'wmd_relation'],
            ascending=[False, False, True])
        if sorted_wmd.shape[0] == 0:
            return "Sorry, we were unable to understand your question."
        return sorted_wmd.iloc[0]['sentence']

    def get_sentence(self, query):
        '''Compile the above functions into a singe function call'''
        refined_query = self.process_question(query)
        query_relation = self.get_relation(refined_query)
        query_subject = self.get_subject(refined_query)
        if query_subject == "" or query_relation == "":
            # unable to get subject or relation from the query
            return "Sorry, we were unable to understand your question."
        top_entities = self.get_top_entities(self.kg_df, query_subject)
        if top_entities.shape[0] == 0:
            return "Sorry, no information related to your query was found."
        cos_sim_relation = self.compute_cosine_similarity(query_relation, top_entities)
        cos_sim_relation = cos_sim_relation.sort_values(by=['fuzzywuzzy_score', 'cosine similarity'], ascending=[False, False])
        top10_cos_sim = cos_sim_relation.head(10)
        answer_sentence = self.compare_relation(top10_cos_sim, query_relation)
        if answer_sentence.endswith("?"):
            #check for the retrieved sentence being a question mark
            #does not value-add to return a question but it is present because the triplets were extracted from the FAQ document.
            return "Sorry, no information related to your query was found."
        return answer_sentence

In [19]:
# Load the extracted triplets from the JSON file into a DataFrame.
triplets_df = pd.read_json(r"json_extract_5.json")
# Build the knowledge graph from the triplets.
kg = KnowledgeGraph(triplets_df)
# Build the dashboard on top of the graph.
# NOTE(review): the Dashboard constructor is not shown in this part of the
# notebook; presumably it also displays the widget layout seen in the cell
# output -- confirm.
d = Dashboard(kg)

GridspecLayout(children=(Output(layout=Layout(grid_area='widget001', height='100%', width='100%')), HBox(child…