# Download Embeddings


In [None]:
!wget -c "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz"
!gzip -dk /content/GoogleNews-vectors-negative300.bin.gz

# Load word embeddings

In [3]:
from gensim.models import KeyedVectors

# Load pretrained model
model = KeyedVectors.load_word2vec_format('/content/GoogleNews-vectors-negative300.bin', binary=True)

# Class to contain a node

In [130]:
class Node: 

  def __init__(self, name, depth, cost):
    self.name = name
    self.depth = depth
    self.cost = cost
    self.value = 1 - cost + depth

# Text treatment

In [129]:
import urllib.request
import re


""" returns article link given title """
def get_link_title(title):
  return "https://en.wikipedia.org/wiki/"+"_".join(title.split())


""" returns article link given href """
def get_link_href(href):
  return "https://en.wikipedia.org/wiki/"+"_".join(href.split())


""" returns references of an article """
def get_refs(article):

  # fetch html code
  link = get_link_title(article)
  html = urllib.request.urlopen(link).read().decode('utf-8')

  # get the tiltles of the other articles
  references = re.findall("href=\"\/wiki/([A-Za-z0-9_]+?)\"", html)
  references = list(map(lambda x : " ".join(x.split("_")), references))

  return references

# Define Functions

In [131]:
import numpy as np
import time

""" get embedding of a word """
def get_vector(word):

  matrix = []
  for i in word.split():
    matrix.append( model[i] )
  
  return np.average(matrix, axis=0)


""" Compute cosine distance between two words """
def get_distance(node, goal):
  
  vector_1 = get_vector(node)
  vector_2 = get_vector(goal)

  return np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2))

""" tells if it is the goal """
def is_goal(article, goal, depth):

  references = get_refs(article)
  
  # if it's a goal, we return the 
  for ref in references :
    if ref.lower() == goal.lower():
      return True, get_link_href(ref)

  length = len(references)
  i = 0

  while i < length :

    if references[i] in model.vocab : 

      cost = get_distance(references[i], goal)
      references[i] = Node(references[i], depth, cost)
      i += 1
    
    else :

      del references[i]
      length -= 1
  
  return False, references


" Order closed and open lists, note that we are inserting in a sorted list"
def combine(open, result):

  length = len(open)

  for res in result :
    index = 0
    
    while index < length and res.value > open[index].value :
      index += 1

    open = open[:index] + [res] + open[index:]
  return open



# A* Algorithm

In [127]:
""" the actual algorithm """
def A_star(start, goal):


  came_from = {}

  node = Node(start, 1, 1)

  open = [node]
  closed = []

  while open :


    node = open[0]
    open = open[1:]

    if not ( node.name in closed ) :

      closed.append(node.name)

      found, result = is_goal(node.name, goal, node.depth)

      if found :
        path = [node]

        while came_from.get(node):
          parent = came_from.get(node)
          path = [parent] + path
          node = parent

        node_goal = Node(goal, parent.depth+1, 0)
        path.append(node_goal)
        return path

      # memorise path
      for i in result :
        came_from[i] = node

      # combine according to evaluation function
      open = combine(open, result)


""" Interface """
def find_path(start, goal):

  top = time.time()
  path = A_star(start, goal)
  end = time.time()

  t = end - top

  print("Solution found in ", t, "seconds\n")
  for p in path:
    print(p.name, get_link_title(p.name))

# Test

In [132]:
find_path("Tree", "Helium")

Solution found in  0.13705968856811523 seconds

Tree https://en.wikipedia.org/wiki/Tree
Lumberjack https://en.wikipedia.org/wiki/Lumberjack
Biochar https://en.wikipedia.org/wiki/Biochar
Oxygen https://en.wikipedia.org/wiki/Oxygen
Helium https://en.wikipedia.org/wiki/Helium
