In [None]:
import pandas as pd 
import random 
import json 
import os 
import pickle 
import time 
import wikipedia as wp 
from wikipedia.exceptions import DisambiguationError,PageError
import networkx as nx 
import matplotlib.pyplot as plt 
import argparse

from gensim.models.doc2vec import Doc2Vec,TaggedDocument
import nltk 
nltk.download('punkt')


In [None]:
import numpy as np 
import torch 
from torch_geometric.utils.convert import from_networkx 
import warnings 
warnings.filterwarnings('ignore')

In [None]:
def create_graph(topics=["tests"], depth=20, max_size=20, simplify=False, plot=False, save_dir=None, max_nodes=None):
    """Crawl Wikipedia from the given topics and assemble a directed graph.

    Args:
        topics (list[str]): page titles to start scanning from; each one is
            scanned in turn with the same generator.
        depth, max_size, simplify, plot: accepted for call compatibility but
            not read anywhere in this function.
        save_dir: forwarded to ``RelationshipGenerator``.
        max_nodes (int | None): forwarded to every ``scan`` call.

    Returns:
        tuple: ``(G, nodelist, node_weights, model)`` where ``G`` is an
        ``nx.DiGraph`` whose nodes carry a ``"features"`` attribute,
        ``nodelist`` is the ranked list of node names, ``node_weights`` holds
        the matching rank scores multiplied by 100, and ``model`` is the
        model returned by ``doc2vec``.
    """
    rg = RelationshipGenerator(save_dir=save_dir)
    for seed_topic in topics:
        rg.scan(seed_topic, max_nodes=max_nodes)

    print(f"Created {len(rg.links)} links with {rg.rank_terms().shape[0]} nodes.")

    # Self-loops are dropped from the edge list only; the ranking below is
    # still computed from every recorded link.
    edges = remove_self_references(rg.links)

    ranking = rg.rank_terms()
    nodes = ranking.index.tolist()
    # Scale the rank scores so they can double as marker sizes when plotting.
    node_weights = [score * 100 for score in ranking.values.tolist()]

    graph = nx.DiGraph()
    graph.add_nodes_from(nodes)

    # Attach one embedding per node before wiring up the edges.
    feature_vectors, model = doc2vec(nodes, rg)
    nx.set_node_attributes(graph, feature_vectors, name="features")

    graph.add_weighted_edges_from(edges)
    return graph, nodes, node_weights, model

In [None]:
def remove_self_references(l):
    """Return the entries of ``l`` whose first two items differ.

    Each entry is indexed as ``entry[0]`` (source) and ``entry[1]`` (target);
    entries where both are equal (self-loops) are left out. Order is kept and
    the input is not modified.
    """
    kept = []
    for entry in l:
        source, target = entry[0], entry[1]
        if source != target:
            kept.append(entry)
    return kept

In [None]:
def doc2vec(nodes, rg):
    """Train a Doc2Vec model on the scanned pages and embed every node.

    One document is built per node, in node order, by joining the node title
    with the page text stored in ``rg.features``. Building the documents from
    ``nodes`` (rather than from every entry in ``rg.features``) keeps document
    ``i`` aligned with ``nodes[i]`` even when the generator holds pages that
    never became nodes, or when a node has no stored text.

    Args:
        nodes (list[str]): node names to embed.
        rg: object exposing ``features``, a dict of page title -> page text.

    Returns:
        tuple: ``(feature_vectors, model)`` where ``feature_vectors`` maps each
        node to the vector inferred for its document and ``model`` is the
        trained ``Doc2Vec`` instance.
    """
    # (page title, page content) per node; a node without stored content
    # falls back to an empty body so its title alone forms the document.
    documents = [(node, rg.features.get(node, "")) for node in nodes]
    tokenized_docs = [nltk.word_tokenize(' '.join(doc).lower()) for doc in documents]
    tagged_docs = [TaggedDocument(words=doc, tags=[str(i)]) for i, doc in enumerate(tokenized_docs)]

    # Model
    model = Doc2Vec(vector_size=300, min_count=1, epochs=50)
    model.build_vocab(tagged_docs)
    model.train(tagged_docs, total_examples=model.corpus_count, epochs=model.epochs)

    feature_vectors = {node: model.infer_vector(tokens) for node, tokens in zip(nodes, tokenized_docs)}

    return feature_vectors, model

In [None]:
class RelationshipGenerator():
    """Generates relationships between terms, based on wikipedia links"""

    def __init__(self, save_dir):
        """Create an empty generator.

        Args:
            save_dir: stored on the instance; nothing in this class reads it.
        """
        self.save_dir = save_dir
        self.links = []  # [start, end, weight]
        self.features = {}  # {page: page_content}
        self.page_links = {}  # {page: [lower-cased titles that page links to]}

    def scan(self, start=None, repeat=0, max_nodes=None):
        """Start scanning from a specific word, or from internal database

        Args:
            start (str): the term to start searching from, can be None to let
                algorithm decide where to start
            repeat (int): the number of extra passes after the first one;
                every pass after the first picks its own starting term via
                ``find_starting_point``
            max_nodes (int): stop (return) once this many linked pages have
                been fetched during this call; None means no limit
        """
        print("On depth: ", repeat)
        nodes_visited = 0
        while repeat >= 0:
            if max_nodes is not None and nodes_visited == max_nodes:
                return

            # Links found from an explicitly requested term get a weight boost.
            term_search = start is not None

            # If a start isn't defined, we should find one
            if start is None:
                start = self.find_starting_point()

            # Scan the starting point specified for links
            print(f"Scanning page {start}...")
            try:
                # Fetch the page through the Wikipedia API
                page = wp.page(start)
                self.features[start] = page.content
                links = list(set(page.links))

                # ignore some uninteresting terms
                links = [l for l in links if not self.ignore_term(l)]

                # Weight every link, then normalise by the largest weight.
                # A page with no usable links leaves the list empty, so the
                # normalisation is skipped instead of calling max() on [].
                link_weights = [self.weight_link(page, link) for link in links]
                if link_weights:
                    heaviest = max(link_weights)
                    link_weights = [w / heaviest for w in link_weights]

                # add the links
                for i, link in enumerate(links):
                    if max_nodes is not None and nodes_visited == max_nodes:
                        return

                    try:
                        link = link.lower()
                        if link not in self.features or link not in self.page_links:
                            # Random pause before each fetch of a linked page.
                            time.sleep(np.random.randint(0, 10))
                            linked_page = wp.page(link)
                            self.features[link] = linked_page.content
                            self.page_links[link] = [l.lower() for l in linked_page.links]
                            print("Page Accessed: ", link)
                            nodes_visited += 1
                        else:
                            print("Page has previously been accessed: ", link)

                        # Connect the linked page to every already-known target
                        # it links to, using a small fixed weight.
                        total_nodes = set(l[1].lower() for l in self.links)
                        outgoing = set(l.lower() for l in self.page_links[link])
                        for links_to in outgoing.intersection(total_nodes):
                            self.links.append([link, links_to, 0.1])
                        self.links.append([start, link, link_weights[i] + 2 * int(term_search)])

                    except (DisambiguationError, PageError):
                        print("Page not found: ", link)

                # Print some data to the user on progress
                explored_nodes = set(l[0] for l in self.links)
                explored_nodes_count = len(explored_nodes)
                total_nodes = set(l[1] for l in self.links)
                total_nodes_count = len(total_nodes)
                # NOTE(review): this compares the original-cased titles with
                # targets that were lower-cased on insertion - confirm the
                # reported count is the intended one.
                new_nodes = [l.lower() for l in links if l not in total_nodes]
                new_nodes_count = len(new_nodes)
                print(f"New nodes added: {new_nodes_count}, Total Nodes: {total_nodes_count}, Explored Nodes: {explored_nodes_count}")

            except (DisambiguationError, PageError):
                # This happens if the page has disambiguation or doesn't exist
                # We just ignore the page for now, could improve this
                print("ERROR, I DID NOT GET THIS PAGE")

            repeat -= 1
            start = None

    def find_starting_point(self):
        """Find the best place to start when no input is given

        Returns:
            (str): the highest-ranked non-empty term that has not yet been
                used as the start of a link

        Raises:
            Exception: if there are no links yet, or every ranked term has
                already been used as a start
        """
        # Need some links to work with.
        if len(self.links) == 0:
            raise Exception("Unable to start, no start defined or existing links")

        # Terms ordered from highest to lowest accumulated weight
        ranked_terms = self.rank_terms().index

        # Terms already used as a start, minus identifiers (these are on many
        # Wikipedia pages)
        all_starts = {l[0] for l in self.links if '(identifier)' not in l[0]}

        # Iterate over the top links, until we find a new one
        for term in ranked_terms:
            if term not in all_starts and len(term) > 0:
                return term

        # no link found
        raise Exception("No starting point found within links")

    @staticmethod
    def weight_link(page, link):
        """Weight an outgoing link for a given source page

        Args:
            page (obj): object exposing ``content`` and ``summary`` strings
            link (str): the outgoing link of interest

        Returns:
            (float): 0.1, plus the number of case-insensitive occurrences of
                the link text in the page content, plus 3 if the link text
                appears in the page summary. The value is not normalised
                here; ``scan`` divides by the largest weight on the page.
        """
        weight = 0.1

        link_counts = page.content.lower().count(link.lower())
        weight += link_counts

        if link.lower() in page.summary.lower():
            weight += 3

        return weight

    def rank_terms(self, with_start=True):
        """Rank terms by the total weight of the links touching them

        Args:
            with_start (bool): when True every link also counts towards its
                start term (each link is mirrored before summing); when False
                only end terms are ranked

        Returns:
            (pandas.Series): summed weight per term, sorted descending
        """
        df = pd.DataFrame(self.links, columns=["start", "end", "weight"])

        if with_start:
            # DataFrame.append was removed in pandas 2.0; concat performs the
            # same stacking of the links with their mirrored copies.
            mirrored = df.rename(columns={"end": "start", "start": "end"})
            df = pd.concat([df, mirrored])

        return df.groupby("end").weight.sum().sort_values(ascending=False)

    def get_key_terms(self, n=20):
        """Return the top ``n`` ranked terms (identifiers removed) as one
        string of single-quoted names separated by ``', '``"""
        top_terms = [t for t in self.rank_terms().head(n).index.tolist() if "(identifier)" not in t]
        return "'" + "', '".join(top_terms) + "'"

    @staticmethod
    def ignore_term(term):
        """Return True for terms to ignore: identifier pages and ``doi``"""
        return "(identifier)" in term or term == "doi"

In [None]:
# Crawl Wikipedia from four statistics seed topics; max_nodes caps the number
# of linked pages fetched for each topic's scan.
G, nodelist, node_weights, model = create_graph(
    topics=["z-test", "standard deviation", "t-test", "mean"],
    max_nodes=25,
)

In [None]:
def simplified_plot(G, nodelist, node_weights):
    """Draw the graph with a spring layout and show it.

    Args:
        G: networkx graph whose edges carry a ``'weight'`` attribute; that
            attribute is used as the drawn line width.
        nodelist (list): nodes to draw and label (each node is labelled with
            its own name).
        node_weights (list): marker size per node, in the same order as
            ``nodelist``.

    Returns:
        None. The figure is displayed with a single ``plt.show()`` call.
    """
    # positions for all nodes - seed for reproducibility
    pos = nx.spring_layout(G, k=1, seed=7)

    plt.figure(figsize=(12, 12))

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=nodelist,
        node_size=node_weights,
        node_color='lightblue',
        alpha=0.7
    )

    # Edge -> weight mapping; keys and values are read from the same dict so
    # each width stays paired with its edge.
    widths = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edges(
        G, pos,
        edgelist=list(widths.keys()),
        width=list(widths.values()),
        edge_color='lightblue',
        alpha=0.6
    )

    nx.draw_networkx_labels(G, pos=pos,
                            labels=dict(zip(nodelist, nodelist)),
                            font_color='black')

    # Show once; the figure handle is not reassigned from plt.show(), which
    # returns None.
    plt.show()

In [None]:
# Render the graph built above.
simplified_plot(
    G=G,
    nodelist=nodelist,
    node_weights=node_weights,
)