In [3]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import time, random

In [4]:
class Timer:
    """Stopwatch that prints the time elapsed between successive marks.

    ``CURTIME`` holds the clock reading (in microseconds) of the most recent
    mark, or of construction if no mark has been taken yet.
    """
    def __init__(self):
        # Start the clock at construction so the first mark reports the time
        # since the Timer was created instead of the time since the epoch.
        self.CURTIME = time.perf_counter_ns()/1000
    def dtime(self, msg):
        """Print ``msg`` with the microseconds elapsed since the previous mark.

        msg: label inserted into the printed line.
        """
        # Read the (monotonic) clock once so the reported interval and the
        # stored mark refer to the same instant.
        now = time.perf_counter_ns()/1000
        o, self.CURTIME = now-self.CURTIME, now
        tf.print(f"mark: {msg}, time since called: {o}")

In [5]:
class Node(dict):
    """Binary expression-tree node stored as an attribute-accessible dict.

    Keyword arguments given to the constructor become both dict entries and
    attributes, because the instance ``__dict__`` is the dict itself.  The
    methods below read four of them:

    * ``left`` / ``right`` -- either an ``int`` index into the input handed
      to ``__call__`` or a callable sub-node evaluated on that same input;
    * ``op`` -- callable invoked as ``op([left_value, right_value], weights)``;
    * ``weights`` -- forwarded to ``op``; ``__repr__`` reads entries 0..5.
    """

    @staticmethod
    def bino(x, w):
        """Sum of ``w * binomial(x)``: a weighted quadratic in ``x[0]``, ``x[1]``.

        Declared static so that it is never bound to an instance; both
        ``Node.bino`` and ``node.bino`` yield the plain two-argument function.
        """
        return tf.einsum("i->", w*binomial(x))

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        # Alias the attribute namespace to the dict: node.left <-> the "left" entry.
        self.__dict__ = self

    @staticmethod
    def _branch_value(branch, x):
        """Resolve one branch: index into ``x`` for an int, otherwise call it."""
        if isinstance(branch, int):
            return x[branch]
        return branch(x)

    def __call__(self, x):
        """Evaluate the subtree rooted at this node on the input ``x``.

        x: iterable of input values; converted to a list unless it already
           is one, then indexed by int branches and passed to sub-nodes.
        Returns whatever ``op`` returns for the two resolved branch values.
        """
        if not isinstance(x, list):
            x = list(x)
        operands = [
            self._branch_value(self.left, x),
            self._branch_value(self.right, x),
        ]
        return self.op(operands, self.weights)

    def __getitem__(self, key):
        """Return the left (``0``) or right (``1``) branch.

        This replaces dict item lookup, so every other key -- including the
        string keys of the underlying dict -- raises ``IndexError``.
        """
        if key == 0:
            return self.left
        if key == 1:
            return self.right
        raise IndexError(f"Tree has only 2 branches, can't get branch {key}")

    def __setitem__(self, key, value):
        """Replace the left (``0``) or right (``1``) branch; else ``IndexError``."""
        if key == 0:
            self.left = value
        elif key == 1:
            self.right = value
        else:
            raise IndexError(f"Tree has only 2 branches,  can't set branch {key}")

    def __hash__(self):
        """Hash the node from ``op``, ``weights``, ``left`` and ``right``.

        Weights that expose a ``ref()`` method (such as TensorFlow variables,
        which are not directly hashable in eager mode) are hashed through the
        reference object it returns.
        """
        weights = self.weights
        ref = getattr(weights, "ref", None)
        weights_key = ref() if callable(ref) else weights
        return hash((self.op, weights_key, self.left, self.right))

    def __repr__(self):
        """Render the quadratic built from ``weights[0..5]`` and both branches."""
        w = self.weights
        eq = f"{w[0]}x^2 + {w[1]}x + {w[2]}xy + {w[3]}y + {w[4]}y^2 + {w[5]}"
        return ' '.join(('( eq:', eq, 'l:', str(self.left), 'r:', str(self.right), ')'))
        


In [6]:
def binomial(x):
    """Return the six quadratic terms of a two-component input as one vector.

    x: indexable whose first two entries are scalars (``x[0]``, ``x[1]``).
    Returns a length-6 tensor ``[x0^2, x0, x0*x1, x1, x1^2, 1]``.
    """
    first, second = x[0], x[1]
    # The terms are rank-0 values, so they are joined with tf.stack:
    # tf.concat is only defined for inputs of rank >= 1.
    return tf.stack(
        [
            first**2,
            first,
            first*second,
            second,
            second**2,
            1.
        ], 0)

In [56]:
# Upper bound on descent iterations performed by gradient_descent_node.
MAX_STEP = 1_000
# Period, in steps, of the step-size factor used by gradient_descent_node:
# the factor falls linearly from 1 over SSF steps and then starts over.
SSF = MAX_STEP / 5
# Module-level stopwatch instance.
timer = Timer()
def gradient_descent_node(test_inputs, test_results, descent_rate, max_grad=1, max_err=10, max_grad_conv=None, assumptions=None):
    """Fit the six weights of one quadratic node by normalised gradient descent.

    The model is ``y = weights . binomial(row)`` for every row of
    ``test_inputs``; the loss is the sum of squared residuals against
    ``test_results``.

    test_inputs:   tensor with one sample per row; each row is passed to
                   ``binomial``, which reads its first two entries.
    test_results:  target values, one per row of ``test_inputs``.
    descent_rate:  base step length (the gradient is normalised to unit norm).
    max_grad:      descent continues while the absolute summed gradient
                   exceeds this (and is still changing, see below).
    max_err:       descent also continues while the squared error exceeds this.
    max_grad_conv: minimum change of the absolute summed gradient between two
                   steps that still counts as progress; defaults to
                   ``max_grad*1e-2``.
    assumptions:   six multipliers applied to both the random initial weights
                   and every update; defaults to all ones.

    Returns ``{"weights": tf.Variable, "error": tensor}`` where ``error`` is
    the sum of squared residuals of exactly the returned weights.  At most
    ``MAX_STEP`` descent steps are taken.
    """
    if assumptions is None:
        assumptions = tf.ones(shape=(6), dtype=tf.float32)

    if max_grad_conv is None:
        max_grad_conv = max_grad*1e-2

    # Design matrix: one row of quadratic terms per sample.
    inputs_bino = tf.stack(
        [binomial(test_inputs[i]) for i in range(test_inputs.shape[0])]
    )

    def residuals(current_weights):
        # Prediction minus target for every sample.
        return tf.einsum("j,ij->i", current_weights, inputs_bino) - test_results

    # Draw random starting weights until one draw is accepted.
    # NOTE(review): the acceptance test favours draws with a LARGE initial
    # error (draws with error in (1, e] are never accepted) -- confirm intent.
    while True:
        weights = tf.Variable([random.random()*assumptions[_] for _ in range(6)], dtype=tf.float32)
        eps = residuals(weights)
        err = tf.einsum("i->", eps**2)
        if random.random() > 1/np.log(err):
            break

    # Sentinel values that force at least one pass through the convergence test.
    grad = 1001.
    gradp = 0.
    step = 0
    while (((tf.abs(tf.abs(grad)-tf.abs(gradp)) > max_grad_conv and tf.abs(grad) > max_grad)
            or err > max_err) and step < MAX_STEP):
        gradp = grad
        eps = residuals(weights)
        err = tf.einsum("i->", eps**2)
        # Gradient of half the squared error with respect to each weight.
        jacobi = tf.einsum("ij->i", eps*tf.transpose(inputs_bino))
        grad = tf.einsum("i->", jacobi)
        jacobi_norm = tf.norm(jacobi)
        if float(jacobi_norm) == 0.0:
            # Exact stationary point: normalising would divide by zero and
            # turn every weight into NaN, so stop here instead.
            break
        # Sawtooth step-size factor: falls from 1 over SSF steps, then resets.
        scale = (SSF-(step%SSF))/(SSF)
        weights.assign_sub(descent_rate*scale*jacobi*assumptions/jacobi_norm)
        step+=1

    # Re-evaluate so the reported error belongs to the returned weights rather
    # than to the weights before the final update.
    eps = residuals(weights)
    err = tf.einsum("i->", eps**2)
    return {"weights": weights, "error": err}

In [57]:
def regress(goal, 
            arg_constraints, 
            n_points, 
            assumptions=None,
            keep=3,
            err_goal=1,
            max_nodes=30,
            batch_size=15
           ):
    """Greedily grow layers of quadratic nodes that approximate ``goal``.

    goal: function [int] -> int; called with one sample (a row holding one
          value per argument) and expected to return a single value.
    arg_constraints: [[left bound: int, right bound: int]], one pair per
          argument; samples are drawn uniformly inside each pair.
    n_points: number of random samples.
    assumptions: forwarded to ``gradient_descent_node``.
    keep: number of lowest-error candidates kept per layer.
    err_goal: stop once the best error of a layer is at or below this.
    max_nodes: stop once ``keep`` times the number of layers reaches this.
    batch_size: number of random candidates fitted per layer.
          NOTE(review): the layer indexing below presumes every layer ends up
          with exactly ``keep`` nodes, i.e. ``batch_size >= keep`` -- confirm.

    Returns a list of layers; each layer is a list of ``Node`` objects ordered
    by increasing absolute error.
    """
    nargs = len(arg_constraints)
    input_vectors = [
        tf.constant([random.random()*(a[1]-a[0])+a[0] for _ in range(n_points)], dtype=tf.float32)
        for a in arg_constraints
    ]
    layer_structure = []

    # One row per sample, one column per original argument.  Nodes are always
    # evaluated on these rows because Node.__call__ indexes the original
    # arguments and recurses into sub-nodes itself.
    arg_rows = tf.transpose(tf.stack(input_vectors))

    # Per-sample values are joined with tf.stack (tf.concat is only defined
    # for rank >= 1) and flattened to a vector.
    out = tf.reshape(tf.stack([goal(row) for row in arg_rows]), [-1])

    def _abs_err(node):
        # Non-finite errors rank last, so a diverged fit is never reported as
        # the best node and never ends the search by comparing False.
        value = float(np.abs(node.err))
        return value if np.isfinite(value) else float("inf")

    error = err_goal+1

    batch = 0

    while np.abs(error) > err_goal and len(layer_structure)*keep < max_nodes:

        nodes = []
        batch += 1

        for i in range(batch_size):
            print(f"node {i} batch {batch}")
            left = random.randrange(nargs)

            # Once a layer exists, the right operand comes from the outputs
            # of the most recent layer; before that, from the raw arguments.
            if len(input_vectors) > nargs:
                right = random.randrange(len(input_vectors)-keep, len(input_vectors))
            else:
                right = random.randrange(nargs)

            inp = tf.transpose(tf.stack([input_vectors[left], input_vectors[right]]))

            g = gradient_descent_node(test_inputs=inp,
                                      test_results=out,
                                      descent_rate=1e-1,
                                      max_grad=1e-1,
                                      max_err=1,
                                      assumptions=assumptions)

            nodes.append(Node(left=left, right=right, op=Node.bino, weights=g["weights"], err=g["error"]))

        nodes.sort(key=_abs_err)
        error = _abs_err(nodes[0])
        nodes = nodes[:keep]

        layer_structure.append([])
        for n in nodes:
            if n.right >= nargs:
                # Swap the index into input_vectors for the node that produced it.
                n.right = layer_structure[(n.right-nargs)//keep][(n.right-nargs)%keep]
            # Evaluate each kept node on the original argument rows, not on
            # the operand matrix left over from the last candidate above.
            input_vectors.append(tf.reshape(tf.stack([n(x) for x in arg_rows]), [-1]))
            layer_structure[-1].append(n)

    return layer_structure

In [None]:
def t_f(x):
    """Target function for the demo: cube of x[0] plus square of x[1]."""
    return x[0]**3+x[1]**2


# Random sample points and the target evaluated on them.
t_p = [(random.random(), random.random()*np.pi) for _ in range(100)]
t_y = [t_f(point) for point in t_p]


def bino(x, w):
    """Sum of the six quadratic terms of ``x`` weighted by ``w``."""
    return tf.einsum("i->", w*binomial(x))


ass = tf.constant([1]*6, dtype=tf.float32)

# Fit the target on 100 random points drawn from [-5, 10] in both arguments.
o = regress(
    goal=t_f,
    arg_constraints=[[-5,10], [-5, 10]],
    n_points=100,
    err_goal=1e-3,
    max_nodes=12,
)

o

node 0 batch 1
node 1 batch 1
node 2 batch 1
node 3 batch 1
node 4 batch 1
node 5 batch 1
node 6 batch 1
node 7 batch 1
node 8 batch 1
node 9 batch 1
node 10 batch 1
node 11 batch 1
node 12 batch 1
node 13 batch 1
node 14 batch 1
node 0 batch 2
node 1 batch 2
node 2 batch 2
node 3 batch 2
node 4 batch 2
node 5 batch 2
node 6 batch 2
node 7 batch 2
node 8 batch 2
node 9 batch 2
node 10 batch 2
node 11 batch 2
node 12 batch 2
node 13 batch 2
node 14 batch 2
node 0 batch 3
node 1 batch 3
node 2 batch 3
node 3 batch 3
node 4 batch 3
node 5 batch 3
node 6 batch 3
node 7 batch 3
node 8 batch 3
node 9 batch 3
node 10 batch 3
node 11 batch 3
node 12 batch 3
node 13 batch 3
node 14 batch 3


In [None]:
from matplotlib import pyplot as plt

# First node of the last layer returned by regress.
n = o[-1][0]

r_x = np.linspace(-15, 15)

# Target function along three slices of its second argument (0, 5 and -5).
r_y1, r_y2, r_y3 = (
    [t_f([p, second]) for p in r_x]
    for second in (0, 5, -5)
)
# The fitted node evaluated on the same three slices.
n_y1, n_y2, n_y3 = (
    [n(tf.constant([p, second], dtype=tf.float32)) for p in r_x]
    for second in (0, 5, -5)
)

# Dashed lines show the target, circles the node; colours pair up the slices.
plt.plot(
    r_x, r_y1, "r--", r_x, r_y2, "g--", r_x, r_y3, "b--",
    r_x, n_y1, "ro", r_x, n_y2, "go", r_x, n_y3, "bo",
)