# Knot Image Dataset Generator

Note: A modified version of pyknotid needs to be downloaded to use this program

## The Problem
No available datasets exist for machine learning applications of knot theory research. The purpose of this program is to generate an arbitrary realistic dataset of labeled knot images and gauss codes to train an AI on. Additionally, while packages exist to generate images of knots, there are currently no known packages that create knot projections with gaps between strands to denote a crossing strand going 'under' another strand. To this end, this project builds on the Pyknotid package and adds this functionality to it.

### The Knot Catalogue is downloaded for testing purposes

In [None]:
# Download the pyknotid knot catalogue database, then import the lookup helper.
from pyknotid.catalogue.getdb import download_database
download_database()
# The lookup helpers live in the `identify` module (the original spelling
# `indentify` is not an importable module name).
from pyknotid.catalogue.identify import get_knot

In [None]:
from PIL import Image
import numpy as np
%gui qt
%pylab qt
import os, os.path
import numpy as np
from pyknotid.catalogue import get_knot, from_invariants
from shapely.geometry import Point, LineString
from pyknotid.spacecurves import Knot
from pyknotid.representations import GaussCode, Representation
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from collections import defaultdict
from tqdm import tqdm
import glob

### Preparation for conversion from Pyknotid gauss code to Knot_Awesome gauss code
__Overview:__ The gauss code format used in Pyknotid is unlike the Knot_Awesome gauss code. The purpose of the splitter() function is to split a single element of Pyknotid gauss code into its components to prepare for translating the Pyknotid gauss code into Knot_Awesome gauss code format.

_Example:_ Pyknotid code: '1+c,2-c,3+c,1-c,2+c,3-c'<br>
>s = '1+c'<br>
>splitter(s) $\rightarrow$  number: '1', sign: '+', hand: 'c'


In [None]:
def splitter(s):
    '''
    Break one element of Pyknotid gauss code into its three parts.

    input:
        s      - a string holding one element of gauss code ( ex. s = '1+c' )

    output:
        number - the text in front of the sign (ex. number = '1')
        sign   - the sign character, '+' or '-' (ex. sign = '+')
        hand   - the text behind the sign       (ex. hand = 'c')

        False is returned instead of the triple when 's' contains neither
        '+' nor '-', or when 's' has fewer than two distinct characters.
    '''

    # The guard counts *distinct* characters, not the length of the string.
    has_distinct_chars = len(set(s)) > 1

    # '+' is tried before '-', so an element holding both is split on '+'.
    for sign in ('+', '-'):
        if has_distinct_chars and sign in s:
            # Unpacking into two names means exactly one 'sign' is expected.
            number, hand = s.split(sign)
            return number, sign, hand

    return False

### Convert Pyknotid gauss code into Knot_Awesome gauss code
__Overview:__ Since Pyknotid gauss code is unlike Knot_Awesome gauss code the gauss_reverter() function accepts a Pyknotid gauss code and translates it into a Knot_Awesome gauss code format.

_Example:_ Pyknotid code: g = '1+c,2-c,3+c,1-c,2+c,3-c'<br>
>gauss_reverter(g)$\rightarrow$ numbers: [1,-2,3,-1,2,-3],ex_g_code: [1,1,1] <br>
Knot_Awesome code: [[1,-2,3,-1,2,-3],[1,1,1]]

In [None]:
 def gauss_reverter(g):
    
    '''
    input: 
        g         - a string of a list of integers written in Pyknotid gauss code format ( ex. g = '1+c,2-c,3+c,1-c,2+c,3-c' ) 
    
    output:
        The two parts that make up the Knot_Awesome extended gauss code
        numbers   - a list of integers written in gauss code format (ex. numbers = [1,-2,3,-1,2,-3])
        ex_g_code - a list of integers written in extended gauss code format (ex. ex_g_code = [1,1,1])
    '''
    
    # Define empty lists to keep track of each aspect of the Gauss Code
    numbers = []
    signs = []
    hands = []
    
    
    gauss = g.split(',')   # Split g into a list separated by ','
    for i in gauss:        # Iterate through each element in 'gauss'
        
        if len(set(i))<=2: # If the length of the string is smaller than 3
            ex_g_code = []                   # ex_g_code will be the unknot and therefore
            return numbers,ex_g_code         # Return an empty pair of lists
        
        
        else:              # However, if the length of the string is greater than 2     
            number, sign, hand = splitter(i) # run 'i' through the splitter function and
                                             # append the results to the 'numbers','signs', and 'hands'
                                             # lists respectively
            numbers.append(int(number))      
            signs.append(sign)
            hands.append(hand)
    
    
    # Now create a numpy array of all zeros the same length as the maximal value of numbers
    ex_g_code = np.zeros(max(numbers))

    
    for i,j in enumerate(numbers): # Iterate through 'numbers' getting the value and index of each element
        
        if signs[i] == '-':          # If the sign at index 'i' of 'signs' is '-'
            numbers[i] *= -1                 # Then multiply the number at index 'i' by -1
        
        if ex_g_code[abs(j)-1] == 0: # If ex_g_code at the absolute value of 'j-1' is 0
        
            if hands[i] == 'a':              # and If 'hands' at index 'i' is 'a' (for 'anticlockwise')
                ex_g_code[abs(j)-1] = -1            # Then set ex_g_code at index 'j-1' to -1 (meaning an anticlockwise crossing) 
            
            else:                            # and if 'hands' at index 'i' is NOT 'a' (meaning 'clockwise')
                ex_g_code[abs(j)-1] = 1             # Then set ex_g_code at index 'j-1' to 1  (meaning a clockwise crossing)
    
    #Once the loop is done return the lists 'numbers' and 'ex_g_code'
    return numbers, ex_g_code

### Split array of points into x-values, y-values, crossing-types, and crossing-hand
__Overview:__ The get_new_points() function takes as input a list of real-valued coordinates in $\mathbb{R}^3$ of at least length 10 and returns:
- lists of the x and y coordinates of the knot projection
- lists of whether each crossing strand is an over or undercrossing
- a list of the 'handedness' of each crossing type. 

In [None]:
def get_new_points(points):

    """
    Extract the projected curve and the crossing data of a knot.

    input:
        points     - a list of real-valued coordinates (x,y,z)

    output:
        x          - the x data of the first line drawn by Knot.plot_projection
        y          - the y data of the first line drawn by Knot.plot_projection
        crossersx  - entry [0][0] of every non-empty crossing record (as int),
                     followed by entry [0][1] of every record
        crossersy  - entry [0][1] of every non-empty crossing record (as int),
                     followed by entry [0][0] of every record
        hand_clock - entry [2][0] of every non-empty crossing record, followed
                     by the same values negated

        None is returned when fewer than 10 points are supplied.

    NOTE(review): 'raw_crossings2' comes from the modified pyknotid; the values
    in crossersx/crossersy are presumably indices into 'points' (that is how
    get_intersections uses them) - confirm against that implementation.
    """

    if len(points) < 10:  # too few points to work with
        return

    knot = Knot(points)

    # Let pyknotid draw the projection, read the curve back from the current
    # axes, then discard the figure.
    knot.plot_projection(show=False)
    curve = plt.gca().lines[0]
    x = curve.get_xdata()
    y = curve.get_ydata()
    plt.close()

    raw_records = knot.raw_crossings2(include_closure=False, recalculate=False)
    records = [record for record in raw_records if len(record) != 0]

    first = [record[0][0] for record in records]
    second = [record[0][1] for record in records]
    handed = [record[2][0] for record in records]

    # Each output lists every crossing twice: once from the point of view of
    # each of the two strands involved. Only the first half is cast to int.
    crossersx = [int(value) for value in first] + second
    crossersy = [int(value) for value in second] + first
    hand_clock = handed + [-value for value in handed]

    return x, y, crossersx, crossersy, hand_clock

### Find the coordinates of the knot's intersections
__Overview:__ Pyknotid does not store the x and y coordinates of the actual crossings. For each crossing, the get_intersections() function takes the four points that bracket it (two consecutive points on each of the two strands), builds the line segment through each pair, and labels the point where the two segments intersect as the crossing coordinates. It then returns the coordinates of every crossing.

In [None]:
def get_intersections(cross_x,cross_y,points):
    '''
    Compute the (x, y) location of every crossing.

    input:
        cross_x - for each crossing, the index into 'points' where the first
                  strand's segment starts
        cross_y - for each crossing, the index into 'points' where the second
                  strand's segment starts
        points  - the list of points describing the knot; only components
                  [0] and [1] of each point are used

    output:
        intersectx - a list of the x coordinates for each crossing
        intersecty - a list of the y coordinates for each crossing
    '''

    def segment(start):
        # Line from points[start] to the next point along the curve
        stop = start + 1
        return LineString([(points[start][0], points[start][1]),
                           (points[stop][0], points[stop][1])])

    intersectx, intersecty = [], []

    for position, first_start in enumerate(cross_x):
        first_line = segment(int(first_start))
        second_line = segment(int(cross_y[position]))

        # Where the two short segments meet is taken as the crossing location
        meeting = first_line.intersection(second_line)
        intersectx.append(meeting.coords[0][0])
        intersecty.append(meeting.coords[0][1])

    return intersectx, intersecty



### Include the crossing coordinates in the arrays of points
__Overview:__ Now that the exact coordinates of each crossing have been found, they need to be added to the list of points in the knot. The new_points() function does just this. It takes the list of points describing the knot and figures out the correct placement for each x and y coordinate, so that when the knot is plotted, the crossings show up in the correct locations. Additionally, if the current crossing strand is an understrand, a gap is included in the points so that when the knot is plotted a visual gap will exist where the 'rope' goes under the crossing. Anywhere the crossing strand is an overstrand, the points are made continuous so that when the knot is plotted the strand will continuously go through the crossing. Finally, since a b-spline interpolation creates a non-closed loop, points are added to the ends of the list of coordinates to close the knot.

In [None]:
def new_points(x,y,cx,cy,ix,iy,points,hand,cross):
    '''
    Split the knot curve into strands, inserting the crossing coordinates.

    input:
        x,y    - lists of the x,y points of the knot
        cx     - for each crossing, an index into 'points'; a crossing is
                 matched when points[cx[j]] has the same x and y as the
                 current curve point
        cy     - not read by this function
        ix,iy  - list of crossing coordinates
        points - a list of the original knot points
        hand   - not read by this function
        cross  - not read by this function; the name is rebound to an empty
                 list before first use
        
    output:
        newx   - a list of strands (lists of x values) including the crossing coordinates
        newy   - a list of strands (lists of y values) including the crossing coordinates
        cross  - a list holding one 0 per entry of 'x'

    NOTE(review): several things here look unintended and should be confirmed
    with the author before relying on this function:
      * 'cross' only ever receives 0s, so the 'cross[holder] == 1' branch
        (over-crossing) can never be taken; every matched crossing ends the
        current strand.
      * 'cross[holder]' presumably raises IndexError if more crossings have
        been matched than curve points visited so far.
      * 'x[i:]' / 'y[i:]' are appended as a single (sequence) element rather
        than extended point by point.
      * 'newx[0]' raises IndexError when no crossing was matched, and 'i' is
        undefined when 'x' is empty.
      * the same list object ends up as both the first and the last strand.
    '''
    
    # define empty lists to hold the finished strands
    newx = []
    newy = []
    
    # define empty lists to hold the current strand
    subx = []
    suby = []
    cross = []   # discards the 'cross' argument
    l = len(x)
    m = len(cx)
    q = 0        # never used
    holder = 0   # counts the crossings matched so far
    
    for i in range(l): #Loop through each entry of x
    
        subx.append(x[i]) # append x at i to the current x-strand
        suby.append(y[i]) # append y at i to the current y-strand
        cross.append(0)   # for each appended entry append a 0 to cross
        
        for j in range(m): # loop through each crossing point in the list of crossing points
        
            if (points[int(cx[j])][0]==x[i]) and (points[int(cx[j])][1]==y[i]): #if the crossing point is the current point        
            
                if cross[holder] == 1: # intended over-crossing case: keep the strand continuous
                    subx.append(ix[j]) #append the crossing coordinate
                    suby.append(iy[j])
                    holder+=1
                
                else: # under-crossing case: end the strand at the crossing and start a new one there
                    subx.append(ix[j])
                    suby.append(iy[j])
                    newx.append(subx)
                    newy.append(suby)
                    subx = []
                    suby = []
                    subx.append(ix[j])
                    suby.append(iy[j])
                    holder+=1
    
    # Close the knot: the trailing strand is joined onto the front of the first
    # strand, and the merged strand is stored both first and last.
    subx.append(x[i:])
    suby.append(y[i:])
    subx += newx[0]
    suby += newy[0]
    newx[0] = subx
    newy[0] = suby
    newx.append(subx)
    newy.append(suby)
    
    return newx,newy,cross



### Roll the array so the first element drawn is an open-ended strand
__Overview:__ If the integer list cross is left alone, then when crossing types are factored into the image plotting, the plotted crossings will be off by one (crossing 2 will be plotted as crossing 1 etc.) The roll_cross() function rolls the list of crossings to prevent this.

_Example:_ cross = [1,-2,3,-1,2,-3]<br>
>roll_cross(cross) $\rightarrow$ [-2,3,-1,2,-3,1]

In [None]:
def roll_cross(cross):
    """
    Drop the zero entries of 'cross' and cast what remains to int.

    input:
        cross - a list of numbers making up the gauss code, possibly padded
                with zeros

    output:
        cross - a new list holding the non-zero entries as integers, in their
                original order, or None when no non-zero entry exists

    NOTE(review): despite its name, this function does not rotate the list.
    """

    kept = []
    for value in cross:
        if value != 0:
            kept.append(int(value))

    return kept if kept else None



### Get a random gauss code from pyknotid
__Overview:__ To make the dataset robust, we need non-contrived examples of knots. The get_code() function generates a random-length list of random points, uses the points to generate a pyknotid Knot() object, and uses the gauss_code() function to check the knot object's corresponding Pyknotid gauss code. It repeats this until the gauss code has the required length and then returns the knot object.

In [None]:
def get_code(low=5,high=25):
    """
    Generate random knots until one has a gauss code of the required length.

    input:
        low  - integer; it is doubled, and the doubled value is both the
               smallest number of random points used and the exact value
               len(k.gauss_code()) must equal for the knot to be accepted
        high - integer; exclusive upper bound on the number of random points

    output:
        k - the accepted pyknotid Knot object (not its gauss code)

    raises:
        ValueError - if high <= 2*low, because no point count can be drawn

    NOTE(review): what len() of a pyknotid gauss code counts (crossings or
    code entries) is not visible here - confirm that 2*low is the intended
    target.
    """

    target_len = 2 * low

    # np.random.randint(a, b) needs a < b; fail early with a readable message
    # instead of numpy's generic "low >= high".
    if target_len >= high:
        raise ValueError(
            "get_code needs high > 2*low (got low={}, high={}): the number of "
            "points is drawn from [2*low, high)".format(low, high))

    while True:
        # Random number of points, each with components in [0, 1)
        point_amt = np.random.randint(target_len, high)
        unit_points = np.random.rand(point_amt, 3)

        # Stretch every point by its own random factor and random sign
        points = [p*np.random.randint(1,100)*np.random.choice([-1,1]) for p in unit_points]

        k = Knot(points,add_closure=True)
        if len(k.gauss_code()) == target_len:
            return k
            


### Get the coordinates of the plotted knot from the random gauss code
__Overview:__ If the knot from the pyknotid gauss code from the get_code() function was plotted, the resulting projection would be pointy and boxy. This hardly represents realistic looking knots and so a little more cleaning is necessary to make the images look nice. To this end, the get_points() function generates a random knot using the get_code() function, simplifies some trivial moves in the knot, and then interpolates the resulting points with a b-spline interpolation. The function then returns the knot object holding the interpolated points along with the knot's gauss code. When these points are plotted, they result in a nice-looking smooth curve.

In [None]:
def get_points(l,h,r,numpo,terp,cur_cross):
    """
    Produce a simplified, interpolated random knot with a required code length.

    input:
        l         - passed to get_code() as 'low'
        h         - passed to get_code() as 'high'
        r         - the 'runs' argument given to the knot's octree_simplify()
        numpo     - the 'num_points' argument given to the knot's reparameterised()
        terp      - the argument given to the knot's interpolate()
        cur_cross - the value len() of the recalculated gauss code must equal
                    for the knot to be accepted

    output:
        k  - the generated knot object
        gc - str() of the generated knot's gauss code
    """
    code = []

    # Keep generating knots until the processed knot has the wanted code length
    while not len(code) == cur_cross:
        knot = get_code(low=l, high=h)
        knot.octree_simplify(runs=r)
        knot = knot.reparameterised(num_points=numpo)
        knot.interpolate(terp)
        code = knot.gauss_code(recalculate=True)

    return knot, str(code)



### Grab the intervals that contain over crossings and undercrossings
__Overview:__ The get_over_under() function takes the main portion of the gauss code as input and returns a list indicating whether each part of the crossing is the under crossing or the over crossing.

_Example:_ gc=[1,-2,3,-1,2,-3]<br>
>get_over_under(gc) $\rightarrow$[1,-1,1,-1,1,-1]

In [None]:
def get_over_under(gc):
    """
    Reduce every gauss code entry to its sign.

    input:
        gc - a list of integers representing the gauss code of a knot

    output:
        ov_und - a list with -1 for every negative entry of 'gc' and 1 for
                 every other entry (zero included)
    """
    return [-1 if entry < 0 else 1 for entry in gc]



### Rotate an image 45 degrees
__Overview:__ The forty_five() function takes as input an image of a knot and rotates the image 45 degrees

In [None]:
def forty_five(I):
    """
    Rotate an image by 45 degrees onto a white background.

    input:
        I - an image object

    output:
        out - an RGBA image holding 'I' rotated by 45 degrees (PIL's rotate()
              turns counter-clockwise for positive angles); the canvas is
              expanded to fit the rotated image and composited over white
    """
    rgba = I.convert('RGBA')

    # expand=1 enlarges the canvas so the rotated corners are not cut off
    turned = rgba.rotate(45, expand=1)

    # Opaque white canvas of the same size; the rotated image is passed as its
    # own mask when compositing onto it.
    white = Image.new('RGBA', turned.size, (255, 255, 255, 255))
    return Image.composite(turned, white, turned)

### Generate a random image
__Overview:__ The get_img_rp() function combines all of the previous functions together to generate a single knot image. This image is then saved in a specified file path under a specified name.

In [1]:
def get_img_rp(l,hg,r,numpo,terp,filenum,cur_cross,path):
    """
    Generate one random knot image and save it as a .png file.

    input:
        l         - passed through get_points() to get_code() as 'low'
        hg        - passed through get_points() to get_code() as 'high'
        r         - an integer indicating the number of simplification runs to perform on the generated knot
        numpo     - an integer indicating the number of points required to describe the knot
        terp      - an integer indicating the number of points used to interpolate the knot
        filenum   - an integer appended to the file name (used for organization purposes)
        cur_cross - the gauss code length the knot is required to have (see get_points)
        path      - a string indicating the folder where the dataset is to be stored;
                    the per-knot folder name is concatenated directly onto it, so it
                    should end with a path separator

    output:
        an integer: the number of entries in the folder the image was saved into

    The image is written to '<path><crossings>/<gauss code>,<extended code>,<filenum>.png',
    where <crossings> is half the number of gauss code entries.
    """

    # Keep generating knots until one yields consistent crossing data
    while True:
        knot, gc1 = get_points(l,hg,r,numpo,terp,cur_cross) # knot object and its Pyknotid gauss code string
        gc2, ext = gauss_reverter(gc1)                      # Knot_Awesome gauss code and extended part
        points = knot.points

        curve = get_new_points(points)
        if curve is None:  # get_new_points() gives up on knots with too few points
            continue
        x, y, cx, cy, _ = curve

        over_under = get_over_under(gc2)

        # The crossing data must line up with the gauss code, otherwise retry
        if len(over_under) != len(cx):
            continue

        # Only compute intersections for knots that passed the check
        ix, iy = get_intersections(cx,cy,points)
        newx, newy, _ = new_points(x,y,cx,cy,ix,iy,points,over_under,over_under)
        break

    # Every crossing appears twice in the gauss code, so half its length is the
    # number of crossings; images are grouped into folders by that number.
    folder = str(len(gc2)//2)

    plt.figure()

    for strand_x, strand_y in zip(newx, newy):
        # Trim both ends of every strand to leave a visible gap at crossings;
        # longer strands get a wider gap.
        size = len(strand_x)
        if size <= 10:
            gap = 2
        elif size <= 50:
            gap = 5
        elif size <= 200:
            gap = 10
        else:
            gap = 20

        plt.plot(strand_x[gap:-gap],strand_y[gap:-gap],linewidth = 1.5,color='k')

    plt.axis('off')

    filename = str(gc2)+','+str(ext)  # the label is encoded in the file name
    new_path = path+folder

    os.makedirs(new_path, exist_ok=True)  # create the folder on first use

    # os.path.join supplies the right separator on every platform
    new_p = os.path.join(new_path, filename+','+str(filenum)+'.png')

    plt.savefig(new_p)
    plt.close()

    return len(os.listdir(new_path))

### Building the Dataset
__Overview:__ Despite efforts to generate random, realistic knots, it is still possible that trends exist in the image data. To combat this and further diversify the data, each image is rotated by 45, 90, 135, 180, 225, 270, and 315 degrees and the result is saved. In the following cell, all functions are implemented to generate and save the new image dataset.

In [None]:
# Root folder of the dataset. get_img_rp() concatenates the per-knot folder
# name directly onto this string, so it must end with a path separator.
# TODO: point this at the desired dataset location.
firstpath = 'knot_dataset' + os.sep

# NOTE(review): get_code() doubles 'l' and then draws the number of points with
# np.random.randint(2*l, hg), which raises ValueError when hg <= 2*l - as is
# the case for l = 10, hg = 20 below. Confirm the intended bounds.
image_index = 0  # running number that keeps the file names unique

for ii in range(3,12):
    loop = tqdm(total=(1000),position=0,leave=False)
    file_size = 0
    # Keep generating until the folder written to holds 1000 entries
    while file_size < 1000:
        loop.set_description('ii:{},file size:{}'.format(ii,file_size))
        loop.update(1)
        file_size = get_img_rp(l = 10, hg = 20, r = 10, numpo = 1000, terp= 5000,
                               filenum=image_index,cur_cross = ii,path=firstpath)
        image_index += 1
    loop.close()
        

        
# Write the gauss codes to a txt file to create a dataset of feasible gauss codes
# NOTE(review): this cell does not match the get_img_rp() defined in this file:
#   * get_img_rp() requires a 'path' argument, which is not passed here;
#   * get_img_rp() returns a single integer (a file count), which cannot be
#     unpacked into 'gc, ext' (its 'return gc2_out,ext' line is commented out).
# Decide which return value is wanted before running this cell.
with open("ImageGC.txt","w") as file:
    h = 5000      # number of gauss codes to generate
    lengths = {}  # maps half the gauss code length to how often it occurred
    loop = tqdm(total=(h),position=0,leave=False)
    w = 6         # passed to get_img_rp() as cur_cross
    for i in range(h):
        loop.set_description('N:{}'.format(i+1))
        loop.update(1)
        gc,ext = get_img_rp(l = 10, hg = 20, r = 10, numpo = 1000, terp= 5000,filenum=i,cur_cross = w)
        gc_len = int(len(gc)/2)
        # Count how many codes of each length were produced
        if gc_len in lengths:
            lengths[gc_len] += 1
        else:
            lengths[gc_len] = 1
        # One line per knot: the gauss code, a comma, then the extended part
        file.write(str(gc)+','+str(ext))
        file.write("\n")
    loop.close()

    

# Create the 7 rotated copies of every image and save them next to the original.
# NOTE(review): the copies are saved as .png files in the same folder, so
# running this cell a second time would also rotate the earlier copies.
for folder_index in range(20):
    # All .png files inside the folder named after this crossing number
    pattern = os.path.join(firstpath + str(folder_index), '*.png')
    files = glob.glob(pattern)

    loop = tqdm(total=len(files),position=0,leave=False)
    for file in files:
        loop.update(1)
        stem, suffix = os.path.splitext(file)  # 'name' and '.png'

        with Image.open(file) as knotImage:  # closes the file handle when done
            # Right-angle rotations of the original image
            for angle in (90, 180, 270):
                knotImage.rotate(angle).save(stem+str(angle)+suffix)

            # 45 degree rotation, then right-angle rotations of that result
            diagonal = forty_five(knotImage)
            diagonal.save(stem+'45'+suffix)
            for extra in (90, 180, 270):
                diagonal.rotate(extra).save(stem+str(45+extra)+suffix)
    loop.close()

# Save the histogram of gauss code lengths once, after all folders are done
with open("ImageDict.txt","w") as dict_file:
    dict_file.write(str(lengths))

if lengths:  # zip(*...) cannot be unpacked when nothing was counted
    x,y = zip(*sorted(lengths.items()))
    plt.plot(x,y)
    plt.show()