# README

* Decimals
  * This program uses Python's Decimal type to compute with a high, fixed precision (10000 significant digits by default)
  * The precision of the Decimal type is adjustable at the beginning of the program
  * Use **D()** to initialize a Decimal value
  * To convert decimals into quadratic roots, compute the continued fraction and reverse-engineer the root once the periodic pattern is clear. (See the paper for more details.)

* How to Use
  * **init_polydisk(b)** initializes the polydisk P(1,b)
  * **mutate(x)** mutates along the nodal ray at the x-th vertex
  * **plot_nodes()** plots the current polygon
  * **print_embd()** prints the embedding for the current polygon

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import decimal
from decimal import Decimal as D
# Working precision: significant digits carried by every Decimal operation.
decimal.getcontext().prec = 10 ** 4
# Decimal places kept when results are rounded for printing (round(..., N)).
N = 29

In [None]:
class node (object):
  """One corner of the polygon.

  A node holds a vertex, the nodal ray attached to that vertex, and the edge
  departing from that vertex, plus the affine length of that edge.
  vertex, nodal_ray and edge are stored as fresh two-element lists of
  Decimals (other code updates their entries in place); affine_length is a
  Decimal.
  """
  def __init__ (self, vertex, nodal_ray, edge, affine_length):
    # Copy each 2-vector into a new list so the caller's sequences are never shared.
    self.vertex, self.nodal_ray, self.edge = (
      [D(pair[0]), D(pair[1])] for pair in (vertex, nodal_ray, edge))
    self.affine_length = D(affine_length)

In [None]:
def init_polydisk (x):
  """Reset the global polygon to the polydisk P(1,x).

  The polygon has the four vertices (0,0), (0,1), (x,1), (x,0), in that
  order. Each node's nodal ray is a diagonal direction pointing into the
  rectangle, and its edge runs to the next vertex with affine length 1 or x.
  Sets the globals n (= 4) and nodes (a list of node objects).
  """
  global n
  global nodes

  n = 4
  nodes = [None] * n
  # (vertex, nodal ray, outgoing edge direction, affine length of that edge)
  corners = (
    ([0,0], [1,1],   [0,1],  1.),
    ([0,1], [1,-1],  [1,0],  x),
    ([x,1], [-1,-1], [0,-1], 1),
    ([x,0], [-1,1],  [-1,0], x),
  )
  for k, spec in enumerate(corners):
    nodes[k] = node(*spec)

In [None]:
def dist (x,y):
  """Return the Euclidean distance between the points x and y.

  x and y are 2-element sequences. The sum of squared coordinate
  differences must be a Decimal, because Decimal.sqrt() takes the root.
  """
  dx = x[0] - y[0]
  dy = x[1] - y[1]
  squared = dx**2 + dy**2
  return squared.sqrt()

In [None]:
def dot (mat, vec):
  """Multiply a 2x2 matrix (given as a list of two rows) by a 2-vector.

  Returns a new two-element list [row0 . vec, row1 . vec].
  """
  return [row[0]*vec[0] + row[1]*vec[1] for row in (mat[0], mat[1])]

In [None]:
def intersect_one (i,j):
  """Intersect the i-th nodal ray with the j-th edge.

  The nodal ray is treated as the full line through nodes[i].vertex with
  direction nodes[i].nodal_ray; the sign of the ray parameter is not
  checked. The j-th edge is the closed segment from nodes[j].vertex to
  nodes[(j+1)%n].vertex, with direction nodes[j].edge.

  Returns the intersection point [x, y] when it lies on that segment.
  Otherwise returns the sentinel [-1, -1]. The sentinel is also returned
  when there is no unique intersection (the line is parallel to the edge)
  or when the edge has zero length.
  """
  global n
  global nodes

  # copy as local variables
  n1 = nodes[i].vertex          # point on the nodal line
  n2 = nodes[j].vertex          # start of edge j
  n3 = nodes[(j+1)%n].vertex    # end of edge j
  v1 = nodes[i].nodal_ray
  v2 = nodes[j].edge

  # Cross product of the two directions. Zero means the nodal line is
  # parallel to edge j, so the linear system below has no unique solution.
  det = v1[0]*v2[1] - v1[1]*v2[0]
  if det == 0:
    return [-1,-1]

  # Write each line as dir[1]*X - dir[0]*Y = c and solve the 2x2 system
  # using the inverse of its coefficient matrix (adjugate / det).
  vec = [ v1[1]*n1[0]-v1[0]*n1[1], v2[1]*n2[0]-v2[0]*n2[1] ]
  mat = [[ -v2[0], v1[0] ],
         [ -v2[1], v1[1] ]]

  itx = dot(mat, vec)
  itx[0] = itx[0] / det
  itx[1] = itx[1] / det

  # Check the intersection is on the edge. Parametrize the edge as
  # n3 + lmbda*(n2 - n3), so lmbda is 1 at its start and 0 at its end.
  # Read lmbda off a coordinate in which the edge actually varies.
  if n2[0] - n3[0] != 0:
    lmbda = (itx[0]-n3[0]) / (n2[0]-n3[0])
  elif n2[1] - n3[1] != 0:
    lmbda = (itx[1]-n3[1]) / (n2[1]-n3[1])
  else:
    return [-1,-1]  # zero-length edge: nothing to intersect

  if lmbda < 0 or lmbda > 1:
    return [-1,-1]

  return itx

In [None]:
def intersect_all (x):
  """Find the edge hit by the x-th nodal ray.

  Every edge except the two that meet at vertex x (edges x and (x-1)%n) is
  tested with intersect_one. Among the valid intersections, the one closest
  to nodes[x].vertex is kept; on ties, the lowest edge index wins.

  Returns (edge_index, intersection_point). If no edge yields a valid
  intersection, returns (x, []).
  """
  global n
  global nodes

  best_edge, best_point, best_dist = x, [], math.inf
  neighbours = (x, (x-1)%n)   # edges adjacent to vertex x

  for i in range(n):
    if i in neighbours:
      continue

    point = intersect_one(x, i)
    if point == [-1,-1]:
      continue   # the nodal ray misses edge i

    d = dist(nodes[x].vertex, point)
    if d < best_dist:
      best_edge, best_point, best_dist = i, point, d

  return (best_edge, best_point)

In [None]:
def solve_matrix (v1, v2, w1, w2):
  """Return the 2x2 matrix M (as a list of rows) with M*v1 = v2 and M*w1 = w2.

  Row k of M is the vector r with r.v1 = v2[k] and r.w1 = w2[k]. This is
  the system [[v1],[w1]] r = (v2[k], w2[k]), solved with the adjugate of
  the coefficient matrix divided by its determinant. The mutation routines
  call it with v1 == v2, which gives the map that fixes a nodal ray and
  sends one edge direction to another.
  """
  det = v1[0]*w1[1] - v1[1]*w1[0]
  adjugate = [ [w1[1], -v1[1]],
               [-w1[0], v1[0]] ]

  rows = []
  for k in range(2):
    r = dot(adjugate, [v2[k], w2[k]])
    rows.append([r[0] / det, r[1] / det])
  return rows

In [None]:
def sanity_check ():
  """Print every node and check that the polygon data is self-consistent.

  Walks the boundary starting at the origin, adding
  affine_length * edge for each node in turn. Each node's vertex must match
  the running position to within 1e-10. After the last node, the walk must
  return to nodes[0].vertex, i.e. the boundary must close up. The first
  mismatch prints a failure message and stops; otherwise a success message
  is printed.
  """
  global n
  global nodes

  tol = 1e-10
  loc = [D(0),D(0)]
  for i in range(n):
    cur = nodes[i]
    print("vertex: ", round(cur.vertex[0],N), round(cur.vertex[1],N), end='\t')
    print("nodal ray: ", round(cur.nodal_ray[0],N), round(cur.nodal_ray[1],N), end='\t')
    print("edge: ", round(cur.edge[0],N), round(cur.edge[1],N), end='\t')
    print("affine length: ", round(cur.affine_length,N), end='\t')
    print()

    if (dist(loc, cur.vertex) > tol):
      print("Failed the sanity check at the ", i, "-th node")
      return

    loc[0] = loc[0] + cur.affine_length * cur.edge[0]
    loc[1] = loc[1] + cur.affine_length * cur.edge[1]

  # The walk must come back to the first vertex. Otherwise the last
  # edge's length or direction is inconsistent and the polygon is not closed.
  if n > 0 and dist(loc, nodes[0].vertex) > tol:
    print("Failed the sanity check: the boundary does not close up")
    return

  print("Passed the sanity check.\n")

In [None]:
def mutate_counterclockwise (head, tail, itx):
  # Mutate along the nodal ray of vertex `head` when that ray hits edge `tail`
  # with head < tail (mutate() calls this with head = x, tail = y).
  #
  # head: index of the vertex whose nodal ray is used; this node is removed.
  # tail: index of the edge hit by that ray.
  # itx:  the intersection point on edge `tail` (from intersect_all).
  #
  # Effect on the globals:
  #   - a new node is inserted at itx;
  #   - node `head` is deleted after its edge length is merged into node head-1;
  #   - the nodes that followed head, up to the old tail, are re-positioned and
  #     have their nodal rays and edges multiplied by `mat`.
  # One node is inserted and one deleted, so the node count n is unchanged.
  # Finally sanity_check() prints the new polygon.
  global n
  global nodes 

  # mat is the linear map that keeps the head nodal ray fixed and sends the
  # head edge direction to the direction of edge head-1.
  mat = solve_matrix( nodes[head].nodal_ray, nodes[head].nodal_ray, nodes[head].edge, nodes[(head-1)%n].edge )

  # Construct the new node at itx. It takes over the piece of edge `tail`
  # from itx to vertex tail+1. Its affine length is the tail edge's affine
  # length scaled by the Euclidean ratio of that piece to the whole edge.
  # Its nodal ray points opposite to the head nodal ray.
  new_length = nodes[tail].affine_length * dist(itx, nodes[(tail+1)%n].vertex) / dist(nodes[tail].vertex, nodes[(tail+1)%n].vertex)
  new = node(itx, [-nodes[head].nodal_ray[0], -nodes[head].nodal_ray[1]], nodes[tail].edge, new_length)
  # np.insert returns a new numpy array (this also converts a plain list).
  nodes = np.insert(nodes, tail+1, new)

  # Adjust the head and tail node:
  # - edge `tail` now only runs from vertex tail to itx;
  # - mat sends edge `head` onto the direction of edge head-1, so the two
  #   edges are merged by adding their affine lengths, and node head is dropped.
  # (head-1 may be -1 when head == 0, i.e. the last array element.)
  nodes[tail].affine_length -= new_length
  nodes[head-1].affine_length += nodes[head].affine_length
  nodes = np.delete(nodes, head)

  # Update remaining nodes. After the deletion, indices head..tail-1 hold the
  # old nodes head+1..tail. Each vertex is recomputed from its predecessor's
  # vertex + affine_length * edge, and its nodal ray and edge are transformed
  # by mat. The new node (now at index tail) is left as constructed.
  for i in range(head, tail):
    pre = nodes[(i-1)%n]
    nodes[i].vertex[0] = pre.vertex[0] + pre.affine_length * pre.edge[0]
    nodes[i].vertex[1] = pre.vertex[1] + pre.affine_length * pre.edge[1]
    nodes[i].nodal_ray = dot(mat, nodes[i].nodal_ray)
    nodes[i].edge = dot(mat, nodes[i].edge)

  sanity_check()

In [None]:
def mutate_clockwise (head, tail, itx):
  # Mutate along the nodal ray of vertex `tail` when that ray hits edge `head`
  # with head < tail (mutate() calls this with head = y, tail = x).
  #
  # head: index of the edge hit by the nodal ray.
  # tail: index of the vertex whose nodal ray is used; this node is removed.
  # itx:  the intersection point on edge `head` (from intersect_all).
  #
  # Effect on the globals:
  #   - a new node is inserted at itx, right after node `head`;
  #   - node `tail` is deleted after its edge length is merged into the node
  #     before it;
  #   - the new node and the following nodes, up to the old tail-1, are
  #     re-positioned and have their edges and nodal rays multiplied by `mat`.
  # One node is inserted and one deleted, so the node count n is unchanged.
  # Finally sanity_check() prints the new polygon.
  global n
  global nodes 

  # mat is the linear map that keeps the tail nodal ray fixed and sends the
  # direction of edge tail-1 to the direction of edge tail.
  mat = solve_matrix( nodes[tail].nodal_ray, nodes[tail].nodal_ray, nodes[(tail-1)%n].edge, nodes[tail].edge )

  # Construct the new node at itx. It takes over the piece of edge `head`
  # from itx to vertex head+1. Its affine length is the head edge's affine
  # length scaled by the Euclidean ratio of that piece to the whole edge.
  # Its nodal ray points opposite to the tail nodal ray.
  new_length = nodes[head].affine_length * dist(itx, nodes[(head+1)%n].vertex) / dist(nodes[head].vertex, nodes[(head+1)%n].vertex)
  new = node(itx, [-nodes[tail].nodal_ray[0], -nodes[tail].nodal_ray[1]], nodes[head].edge, new_length)
  # np.insert returns a new numpy array (this also converts a plain list).
  nodes = np.insert(nodes, head+1, new)

  # Adjust the old head and tail node. After the insert, index tail holds
  # the old node tail-1 and index tail+1 holds the old node tail.
  # - edge `head` now only runs from vertex head to itx;
  # - mat sends edge tail-1 onto the direction of edge tail, so the two
  #   edges are merged by adding their affine lengths, and the old tail node
  #   is dropped.
  nodes[head].affine_length -= new_length
  nodes[tail].affine_length += nodes[tail+1].affine_length
  nodes = np.delete(nodes, tail+1)

  # Indices head+1..tail now hold the new node followed by the old nodes
  # head+1..tail-1. Each vertex is recomputed from its predecessor's
  # vertex + affine_length * edge, and its edge and nodal ray are
  # transformed by mat.
  for i in range(head+1, tail+1):
    pre = nodes[(i-1)%n]
    nodes[i].vertex[0] = pre.vertex[0] + pre.affine_length * pre.edge[0]
    nodes[i].vertex[1] = pre.vertex[1] + pre.affine_length * pre.edge[1]
    nodes[i].edge = dot(mat, nodes[i].edge)
    # NOTE(review): the original comment here was cut off ("uncomment the
    # following line if ..."). Enabling the guard below would transform the
    # nodal ray only when the preceding edge is not (near) zero length; the
    # intended condition is not stated here, so confirm before enabling it,
    # and indent the next line under it if you do.
    #if (nodes[i-1].affine_length > D(1e-30)):
    nodes[i].nodal_ray = dot(mat, nodes[i].nodal_ray)
    

  sanity_check()

In [None]:
def mutate (x):
  """Mutate the polygon once along the nodal ray of vertex x.

  Finds the edge y hit by the x-th nodal ray (intersect_all). If x < y it
  calls mutate_counterclockwise(x, y, itx), otherwise
  mutate_clockwise(y, x, itx).

  x may be a negative, Python-style index; it is normalized to 0 <= x < n.

  Returns the index of the newly created vertex in the updated node array:
  y in the first case and y+1 in the second.

  Raises IndexError if x is not a valid vertex index, and ValueError if the
  nodal ray meets no edge that is not adjacent to x.
  """
  global n
  global nodes

  if not -n <= x < n:
    raise IndexError("vertex index %s out of range for %s nodes" % (x, n))
  # intersect_all only skips the two edges adjacent to x correctly when
  # 0 <= x < n. A raw negative index would let the ray "hit" its own vertex.
  x %= n

  # y is the intersecting edge
  # itx is the intersection point
  (y, itx) = intersect_all(x)
  if not itx:
    # intersect_all signals "no intersection" with (x, []). Continuing would
    # corrupt the polygon.
    raise ValueError("the nodal ray at vertex %s does not meet any non-adjacent edge" % x)

  if x < y:
    mutate_counterclockwise(x, y, itx)
    return y

  mutate_clockwise(y, x, itx)
  return y + 1

In [None]:
def plot_nodes():
  """Draw the current polygon in a new matplotlib figure.

  The boundary is drawn as a closed line through the vertices of
  nodes[0..n-1], converted to float. Each vertex is then marked with a
  small black dot. The axes use an equal aspect ratio and a 0.5 margin
  around the polygon.
  """
  # credit to Jemma
  global n 
  global nodes

  # closed outline: every vertex in order, then the first one again
  order = list(range(n)) + [0]
  xs, ys = zip(*[(float(nodes[k].vertex[0]), float(nodes[k].vertex[1])) for k in order])

  margin = 0.5
  plt.figure(dpi=120)
  plt.gca().set_aspect('equal', adjustable='box')
  plt.xlim([min(xs)-margin, max(xs)+margin])
  plt.ylim([min(ys)-margin, max(ys)+margin])
  plt.plot(xs,ys)

  # mark each vertex individually
  for k in range(n):
    px, py = nodes[k].vertex
    plt.plot(px, py, marker='o', markersize=2, color='black')

In [None]:
def print_embd ():
  """Print the embedding (z, lambda) of the current polygon.

  Let a <= b be the affine lengths of nodes[0] and nodes[n-1]. This prints
  b/a and 1/a, each rounded to N decimal places, followed by a blank line.
  """
  global n
  global nodes

  lo, hi = sorted([nodes[0].affine_length, nodes[n-1].affine_length])
  print("Embedding: ", round(hi/lo,N), round(1/lo,N))
  print()

In [None]:
# Example that computes the accumulation point for P(1,b): start from the
# polydisk, then mutate at the listed vertices one after another, plotting
# the polygon and printing the embedding at every stage.
b = (6 + 5 * D(30).sqrt()) / 12

init_polydisk(b)
plot_nodes()
print_embd()

for vertex in (2, 2, 1, 3, 1, 1, 1, 1):
  mutate(vertex)
  plot_nodes()
  print_embd()