<a href="https://colab.research.google.com/github/AlecTraas/computational-geo-lab/blob/main/Colab/Ridge/RangeSearchingV3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
class node:
  """A binary-tree node used by ``range_tree``.

  Attributes (all plain, set via the setters below):
    value: the point stored at this node; for an internal node built by
      ``range_tree`` this is the median point of the node's points.
    lchild, rchild: child nodes, or ``None``.
    subtree: root of an associated next-level tree, or ``None``.
    span: list of points stored below this node, or ``None``.
  """

  def __init__(self, value):
    self.value = value
    self.lchild = None
    self.rchild = None
    self.subtree = None
    self.span = None

  def get_value(self):
    return self.value

  def set_rchild(self, child):
    self.rchild = child

  def set_lchild(self, child):
    self.lchild = child

  def get_rchild(self):
    return self.rchild

  def get_lchild(self):
    return self.lchild

  def is_leaf(self):
    """Return True when the node has no children."""
    return self.lchild is None and self.rchild is None

  def has_lchild(self):
    return self.lchild is not None

  def has_rchild(self):
    return self.rchild is not None

  def set_subtree(self, root):
    self.subtree = root

  def get_subtree(self):
    return self.subtree

  def set_span(self, nodes):
    self.span = nodes

  def get_span(self):
    return self.span

  def display(self):
    """Print the tree rooted here breadth-first, one line per depth.

    Values on the same depth are printed back to back (``str(value)`` with
    no separator), left to right.
    """
    from collections import deque

    pending = deque([(self, 0)])  # deque: O(1) pops from the left end
    parts = []
    prev_depth = 0
    while pending:
      curnode, cur_depth = pending.popleft()
      if cur_depth > prev_depth:
        # First node of a deeper level: flush the finished line.
        print("".join(parts))
        parts = []
        prev_depth = cur_depth
      parts.append(str(curnode.value))
      if curnode.has_lchild():
        pending.append((curnode.lchild, cur_depth + 1))
      if curnode.has_rchild():
        pending.append((curnode.rchild, cur_depth + 1))

    print("".join(parts))

  def compareTo(self, otherNode, nvars, i):
    """Compare this node's point with ``otherNode``'s, starting at coordinate ``i``.

    Probes coordinates ``i, (i+1) % nvars, ...`` (``nvars`` probes in total)
    and returns the first non-zero difference ``self - other``; returns 0
    when every probed coordinate is equal.
    """
    for _ in range(nvars):
      out = self.value[i] - otherNode.value[i]
      if out != 0:
        return out
      i = (i + 1) % nvars
    return 0



In [3]:
def quicksort(arr, nvars, i, compare=None):
  """Return a new list holding the points of ``arr`` in ascending order.

  Points are ordered by ``compare(a, b, nvars, i)``, which defaults to
  ``compare_points``. A negative result puts ``a`` first and a positive
  result puts ``b`` first.

  Every element is kept. Elements that compare equal to the pivot are
  grouped in their original relative order. This makes the sort stable.
  Ties really happen: ``compare_points`` looks at only ``nvars`` coordinates,
  so distinct points can compare equal.

  Args:
    arr: sequence of points; it is not modified.
    nvars: number of coordinates the comparator probes.
    i: coordinate the comparison starts at.
    compare: optional comparator with the ``compare_points`` signature.
  """
  if compare is None:
    compare = compare_points

  if len(arr) <= 1:
    return list(arr)  # zero or one element is already sorted

  pivot = arr[len(arr) // 2]  # middle element as the pivot

  left, middle, right = [], [], []
  for x in arr:
    order = compare(x, pivot, nvars, i)
    if order < 0:
      left.append(x)
    elif order > 0:
      right.append(x)
    else:
      middle.append(x)  # the pivot itself and every tie with it

  return (quicksort(left, nvars, i, compare)
          + middle
          + quicksort(right, nvars, i, compare))

In [4]:
def compare_points(this_point, next_point, nvars, i):
  """Compare two points coordinate by coordinate, beginning at index ``i``.

  The first probe uses index ``i`` as given. Each later probe moves one
  index forward, wrapping modulo ``nvars``, for ``nvars`` probes in total.
  The first non-zero difference ``this_point - next_point`` is returned.
  0 means every probed coordinate matched.
  """
  for step in range(nvars):
    k = i if step == 0 else (i + step) % nvars
    diff = this_point[k] - next_point[k]
    if diff != 0:
      return diff
  return 0


In [5]:
class database:
  """Small in-memory table of numeric rows, indexed by a ``range_tree``.

  Each stored row is ``[id, v1, ..., v_nvars]``. Ids start at 1 and follow
  insertion order. The range tree is built in the constructor and rebuilt
  only by ``update_rt``. Rows added with ``add`` are therefore not seen by
  ``query`` until ``update_rt`` is called.
  """

  def __init__(self, nvars, values=None):
    self.cur_id = 1  # id the next stored row will receive
    self.table = []
    self.size = 0
    if values is not None:
      for row in values:
        self.table.append([self.cur_id] + list(row))
        self.size += 1
        self.cur_id += 1

    self.nvars = nvars
    # Next id as of construction. NOTE(review): add() does not update this;
    # cur_id is the live counter.
    self.id = self.size + 1
    self.db_rt = range_tree(self.table, self.nvars)

  def add(self, value):
    """Append ``value`` (any iterable of ``nvars`` numbers) as a new row.

    The range tree is not rebuilt; call ``update_rt`` before querying.
    """
    self.table.append([self.cur_id] + list(value))
    self.size += 1
    self.cur_id += 1

  def update_rt(self):
    """Rebuild the range tree from the current table."""
    self.db_rt = range_tree(self.table, self.nvars)

  def display(self):
    """Print a header, then each row preceded by a separator line."""
    header = "id"
    for k in range(1, self.nvars + 1):
      header += "\tvar" + str(k)
    print(header)
    for point in self.table[:self.size]:
      print("--------------------")
      print("".join(str(point[j]) + "\t" for j in range(self.nvars + 1)))

  def query(self, params):
    """Return the rows whose columns all lie in the given inclusive ranges.

    ``params`` holds one ``[low, high]`` pair per column, id column first
    (``nvars + 1`` pairs). The tree from the constructor or the last
    ``update_rt`` call is searched.
    """
    if self.db_rt.root is None:  # empty table: nothing to search
      return []
    myquery = range_search_query(self.db_rt.root, params, self.nvars)
    return myquery.answer




In [69]:
class range_tree:
  """Multi-level range tree over points with ``nvars + 1`` coordinates.

  Level ``i`` (``0 <= i <= nvars``) is a balanced binary tree ordered on
  coordinate ``i``. Ties are broken by the following coordinates, cyclically.

  - For ``i < nvars``, every node's ``subtree`` is a level-``i + 1`` tree
    over the same points.
  - On the last level (``i == nvars``), every node's ``span`` is the list of
    its points.
  - An internal node holds the median point. That point is also the
    smallest point of its right subtree.
  """

  def __init__(self, points, nvars):
    # Each level must be built from sorted input. Sort level 0 here so
    # callers need not pre-sort. The sort is stable, so input already in
    # this order stays unchanged.
    self.root = self.rec_range_tree(self._sorted_on(points, nvars, 0), nvars, 0)

  @staticmethod
  def _sorted_on(points, nvars, i):
    """Return ``points`` stably sorted by ``compare_points`` from coordinate ``i``.

    The comparison cycles through all ``nvars + 1`` coordinates, so two
    points compare equal only when every coordinate matches. Python's
    ``sorted`` keeps every element and is stable.
    """
    from functools import cmp_to_key

    ndims = nvars + 1
    return sorted(points, key=cmp_to_key(lambda a, b: compare_points(a, b, ndims, i)))

  def rec_range_tree(self, points, nvars, i):
    """Build the level-``i`` tree over ``points``, which must be sorted for level ``i``.

    Returns the root node, or ``None`` when ``points`` is empty.
    """
    if not points:
      return None

    if len(points) == 1:
      root = node(points[0])
    else:
      median = len(points) // 2
      root = node(points[median])
      root.set_lchild(self.rec_range_tree(points[:median], nvars, i))
      # The median goes to the right subtree, so it is that subtree's minimum.
      root.set_rchild(self.rec_range_tree(points[median:], nvars, i))

    if i < nvars:
      # Associated structure: the same points, organised on the next coordinate.
      # Leaves get a full next-level tree too, so the last level always has a span.
      ypoints = self._sorted_on(points, nvars, i + 1)
      root.set_subtree(self.rec_range_tree(ypoints, nvars, i + 1))
    else:
      root.set_span(points)

    return root

  def display(self):
    """Print the level-0 tree breadth-first, one line per depth.

    Prints nothing for an empty tree.
    """
    if self.root is None:
      return
    self.root.display()


In [94]:
class range_search_query:
  """Run an orthogonal range query against a range-tree root.

  The query runs in the constructor. Matching points are collected in
  ``self.answer`` as a flat list of points.

  ``params`` is a list of ``[low, high]`` pairs, both bounds inclusive. There
  is one pair per coordinate ``0..nvars`` (``nvars + 1`` pairs). ``nvars``
  is the same value the tree was built with.
  """

  def __init__(self, root, params, nvars):
    self.answer = []
    if root is None:  # an empty tree has nothing to report
      return
    self.rec_query(root, params, nvars, 0)

  @staticmethod
  def _bounds(params):
    """Return (lower-bound vector, upper-bound vector) from ``[[low, high], ...]``."""
    return [pair[0] for pair in params], [pair[1] for pair in params]

  def find_split_node(self, root, params, nvars, i):
    """Return the node on level ``i`` where the two bound search paths diverge.

    The search descends from ``root`` while the whole query range lies on one
    side of the current node. It goes left when the upper bound is below the
    node's point and right when the lower bound is at or above it.

    Points are compared with ``compare_points`` starting at coordinate ``i``.
    The comparison cycles through all ``nvars + 1`` coordinates, so ties on
    coordinate ``i`` are broken by the rest. The tree is assumed to be
    ordered the same way.
    """
    start, end = self._bounds(params)
    ndims = nvars + 1
    curnode = root
    while not curnode.is_leaf():
      value = curnode.get_value()
      if compare_points(end, value, ndims, i) < 0:
        curnode = curnode.get_lchild()
      elif compare_points(start, value, ndims, i) >= 0:
        curnode = curnode.get_rchild()
      else:
        break
    return curnode

  def check_leaf(self, leaf, start, end, i):
    """Return True when coordinate ``i`` of ``leaf``'s point lies in ``[start, end]``."""
    value = leaf.get_value()[i]
    return value >= start and value <= end

  def _report_subtree(self, subroot, param, nvars, i, last_var):
    """Handle a subtree whose points all satisfy the level-``i`` bounds."""
    if last_var:
      # Last level: every point below is an answer. Extend (do not nest)
      # so that ``answer`` stays a flat list of points.
      self.answer.extend(subroot.get_span())
    else:
      self.rec_query(subroot.get_subtree(), param, nvars, i + 1)

  def _visit_leaf(self, leaf, param, nvars, i, last_var):
    """Test a path-end leaf on coordinate ``i``; report it or recurse if it matches."""
    if not self.check_leaf(leaf, param[i][0], param[i][1], i):
      return
    if last_var:
      self.answer.append(leaf.get_value())
    else:
      # A leaf holds a single point, so it can stand in for its next-level tree.
      self.rec_query(leaf, param, nvars, i + 1)

  def rec_query(self, root, param, nvars, i):
    """Collect the points of the level-``i`` tree at ``root`` that match ``param``.

    Some subtrees lie entirely inside the range on coordinate ``i``. On the
    last level their points are reported directly (their ``span``). On
    earlier levels their next-level tree (``subtree``) is searched on
    coordinate ``i + 1``.
    """
    last_var = i >= nvars
    start, end = self._bounds(param)
    ndims = nvars + 1

    # The split must be found on the coordinate this level is ordered by.
    split = self.find_split_node(root, param, nvars, i)

    if split.is_leaf():
      self._visit_leaf(split, param, nvars, i, last_var)
      return

    # Lower-bound path. When start < node, the right child lies entirely in range.
    curnode = split.get_lchild()
    while not curnode.is_leaf():
      if compare_points(start, curnode.get_value(), ndims, i) < 0:
        self._report_subtree(curnode.get_rchild(), param, nvars, i, last_var)
        curnode = curnode.get_lchild()
      else:
        curnode = curnode.get_rchild()
    self._visit_leaf(curnode, param, nvars, i, last_var)

    # Upper-bound path. When end >= node, the left child lies entirely in range.
    curnode = split.get_rchild()
    while not curnode.is_leaf():
      if compare_points(end, curnode.get_value(), ndims, i) >= 0:
        self._report_subtree(curnode.get_lchild(), param, nvars, i, last_var)
        curnode = curnode.get_rchild()
      else:
        curnode = curnode.get_lchild()
    self._visit_leaf(curnode, param, nvars, i, last_var)

In [99]:
# Demo: five 2-D points that all share y = 23.
# Query x in [8, 100] and y in [23, 100].
mypoints = quicksort([[9, 23], [2, 23], [8, 23], [6, 23], [7, 23]], 2, 0)  # sort on (x, y)
rt = range_tree(mypoints, 1)  # nvars=1: tree levels on coordinates 0 and 1
rquery = range_search_query(rt.root, [[8, 100], [23, 100]], 1)
print(rquery.answer)

[[8, 23], [9, 23]]
