In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib widget

from tqdm.notebook import tqdm
import os
import json

In [None]:
class Node:
    """A point of the tree; the segment from ``parent`` to this node is one branch.

    Nothing in this file moves a node after construction, so the derived
    quantities ``length`` and ``theta`` are computed once and cached.
    """

    def __init__(self, x, y, parent=None):
        self.x = x
        self.y = y
        self.parent = parent  # parent Node, or None for the root
        self.children = []  # child Nodes, in creation order
        self.thickness = 1  # drawing weight of the parent->self segment, see increase_thickness

        self._length = None
        self._theta = None
        self._preferred_theta = None  # preferred angle at which to branch (cached, reset by stem)
        self.fertile = True  # may grow a child on the next growth step
        self.dead = False  # permanently stopped from growing

    @property
    def length(self):
        """Length of the segment from the parent to this node; -1 for the root."""
        if self._length is None:
            if self.parent is None:
                self._length = -1
            else:
                self._length = np.sqrt((self.x - self.parent.x)**2 + (self.y - self.parent.y)**2)
        return self._length

    @property
    def theta(self):
        """Direction (radians, from arctan2) of the parent->self segment; 0 for the root."""
        if self.parent is None:
            return 0
        if self._theta is None:
            self._theta = np.arctan2(self.y - self.parent.y, self.x - self.parent.x)
        return self._theta

    @property
    def preferred_theta(self):
        """Direction in which this node prefers to grow its next child.

        Without children this is ``theta`` (keep going straight). Otherwise it
        points opposite to the weighted average of the unit vectors toward each
        child (weight: the child's thickness) and toward the parent (weight:
        ``1.2 * self.thickness``), i.e. away from the existing branches.
        The value is cached until the next call to ``stem``.
        """
        if self._preferred_theta is None:
            if not self.children:
                self._preferred_theta = self.theta
                return self._preferred_theta

            x = np.sum([(child.x - self.x)/child.length*child.thickness for child in self.children])
            y = np.sum([(child.y - self.y)/child.length*child.thickness for child in self.children])
            w = np.sum([child.thickness for child in self.children])

            if self.parent is not None:
                # the parent direction counts slightly more than a child of equal thickness
                parent_bias = 1.2
                x += parent_bias*(self.parent.x - self.x)/self.length*self.thickness
                y += parent_bias*(self.parent.y - self.y)/self.length*self.thickness
                w += parent_bias*self.thickness

            x = x/w
            y = y/w
            # + pi: grow away from the average direction of the existing branches
            self._preferred_theta = np.pi + np.arctan2(y, x)

        return self._preferred_theta

    def field(self, x, y):
        """Repulsion field this node exerts at ``(x, y)``: a number in (0, 1].

        Gaussian bump ``exp(-4 d^2 / L^2)`` where ``d`` is the distance to this
        node and ``L`` its segment length (for the root ``L = -1``, so ``L^2 = 1``).
        """
        return np.exp(-4*((self.x - x)**2 + (self.y - y)**2)/self.length**2)

    def increase_thickness(self):
        """Add 1 to the thickness of this node and of every ancestor up to the root.

        Walks the parent chain iteratively so very deep trees cannot exceed
        Python's recursion limit.
        """
        node = self
        while node is not None:
            node.thickness += 1
            node = node.parent

    def stem(self, length, dtheta):
        """Grow and return a new child ``length`` away along ``preferred_theta + dtheta``.

        Side effects: appends the child to ``children``, invalidates the cached
        preferred direction, clears ``fertile`` and, when this is not the first
        child, thickens this node and all its ancestors.
        """
        angle = self.preferred_theta + dtheta
        child = Node(self.x + length*np.cos(angle), self.y + length*np.sin(angle), self)
        self.children.append(child)
        self._preferred_theta = None
        if len(self.children) > 1:
            self.increase_thickness()
        self.fertile = False
        return child

    



class Tree:
    """A 2D tree grown stochastically, one generation of branches per ``grow`` call.

    All non-root nodes live in ``self.nodes``; every node (root included) is
    also bucketed in a uniform spatial grid whose cell size is ``self.length``
    so that repulsion checks only visit nearby nodes.
    """

    def __init__(self, root_x=0, root_y=0, length=1, theta=0):
        """Create the root at ``(root_x, root_y)`` and a first segment of ``length`` along ``theta``.

        ``length`` also sets the grid cell size and the default neighbor radius;
        ``theta`` is the reference "forward" direction used by the drives in
        ``grow``.
        """
        self.theta = theta
        self.length = length
        self.root = Node(root_x, root_y)
        self.nodes = []  # all non-root nodes, in creation order
        self.grid = {}  # (cell_x, cell_y) -> nodes whose position falls in that cell
        self.add_to_grid(self.root)
        self.add_node(self.root.stem(length, theta))

    def add_node(self, node):
        """Register a newly grown node in the node list and the spatial grid."""
        self.nodes.append(node)
        self.add_to_grid(node)

    def grid_index(self, x, y):
        """Return the integer grid cell ``(i, j)`` containing the point ``(x, y)``.

        Floor division keeps every cell the same size on both sides of the
        origin (``int()`` truncates toward zero, which would merge
        ``(-length, length)`` into a single double-width cell 0).
        """
        return (int(x // self.length), int(y // self.length))

    def add_to_grid(self, node):
        """Bucket ``node`` into the grid cell containing its position."""
        self.grid.setdefault(self.grid_index(node.x, node.y), []).append(node)

    def neighbors(self, x, y, r=None):
        """Return the grid-registered nodes strictly closer than ``r`` to ``(x, y)``.

        ``r`` defaults to ``self.length``. Every cell overlapping the bounding
        box ``[x - r, x + r] x [y - r, y + r]`` is scanned, so nearby nodes are
        found in every direction and for any radius.
        """
        if r is None:
            r = self.length
        kx_min, ky_min = self.grid_index(x - r, y - r)
        kx_max, ky_max = self.grid_index(x + r, y + r)
        r2 = r * r
        found = []
        for kx in range(kx_min, kx_max + 1):
            for ky in range(ky_min, ky_max + 1):
                for node in self.grid.get((kx, ky), ()):
                    if (node.x - x)**2 + (node.y - y)**2 < r2:
                        found.append(node)
        return found

    @staticmethod
    def _kill(node):
        """Mark ``node`` as permanently unable to branch."""
        node.fertile = False
        node.dead = True

    def _repulsion_field(self, x, y, node):
        """Sum the repulsion fields at ``(x, y)`` of nearby nodes other than ``node`` and its parent."""
        field = 0
        for other in self.neighbors(x, y):
            if other is node or other is node.parent:
                continue
            field += other.field(x, y)
        return field

    def _sample_dtheta(self, node, dtheta_sigma, twist, forward_drive, transverse_drive, drive_mode):
        """Draw a branching-angle offset for ``node``; return None if a cosine_threshold drive rejects it."""
        base = node.preferred_theta
        dtheta = dtheta_sigma*np.random.uniform(-np.pi/2, np.pi/2) + twist*np.pi/2

        if transverse_drive:
            if drive_mode == 'avg':
                # blend the proposed direction with the direction perpendicular to self.theta
                v = np.array([np.cos(base + dtheta), np.sin(base + dtheta)])
                v_transverse = np.array([-np.sin(self.theta), np.cos(self.theta)])
                v = (1-np.abs(transverse_drive))*v + transverse_drive*v_transverse
                dtheta = np.arctan2(v[1], v[0]) - base
            elif drive_mode == 'cosine_threshold':
                if np.cos(base + dtheta - self.theta - np.pi/2*np.sign(transverse_drive)) + 1 < np.abs(transverse_drive):
                    return None

        if forward_drive:
            if drive_mode == 'avg':
                # blend the proposed direction with the forward direction self.theta
                v = np.array([np.cos(base + dtheta), np.sin(base + dtheta)])
                v_forward = np.array([np.cos(self.theta), np.sin(self.theta)])
                v = (1-np.abs(forward_drive))*v + forward_drive*v_forward
                dtheta = np.arctan2(v[1], v[0]) - base
            elif drive_mode == 'cosine_threshold':
                # cos(...) + 1 lies in [0, 2]: forward_drive=1 forbids backward growth, 2 allows only straight ahead
                if np.cos(base + dtheta - self.theta) + 1 < forward_drive:
                    return None

        return dtheta

    def grow(self, gamma=0.9, branching_prob=0.1, dtheta_sigma=0.1, repulsion=0,
             min_branch_length=0.01, twist=0, forward_drive=0, drive_mode='avg',
             transverse_drive=0, max_branching_attempts=3):
        """Run one growth step: every fertile node tries to grow one child.

        Args:
            gamma: child segment length = parent segment length * gamma.
            branching_prob: a living, non-fertile node becomes fertile with
                probability ``branching_prob / len(node.children)``.
            dtheta_sigma: scale of the uniform random angle offset drawn from
                ``[-pi/2, pi/2]``.
            repulsion: if non-zero, a fertile node is killed when the summed
                repulsion field at its position exceeds ``1 - repulsion``, and a
                candidate tip is rejected under the same condition.
            min_branch_length: a fertile node whose child would be shorter than
                this is killed instead.
            twist: constant offset ``twist * pi/2`` added to every angle.
            forward_drive: pull toward / constraint around ``self.theta``.
            drive_mode: ``'avg'`` (blend directions) or ``'cosine_threshold'``
                (reject directions too far from the drive direction).
            transverse_drive: like ``forward_drive`` but perpendicular to
                ``self.theta``; its sign picks the side.
            max_branching_attempts: rejected angle proposals allowed before the
                fertile node is killed (0 kills it immediately).

        Raises:
            ValueError: if a drive is active and ``drive_mode`` is unknown.
        """
        if (forward_drive or transverse_drive) and drive_mode not in ('avg', 'cosine_threshold'):
            raise ValueError(f'Unknown drive mode: {drive_mode}')

        fertile_nodes = []
        for node in self.nodes:
            if node.fertile:
                fertile_nodes.append(node)
            elif not node.dead and np.random.random() < branching_prob/len(node.children):
                node.fertile = True
                fertile_nodes.append(node)

        print(f'{len(fertile_nodes)} fertile nodes')

        for fertile_node in fertile_nodes:
            length = fertile_node.length*gamma

            if length < min_branch_length:  # we are too small: killing the fertile node
                self._kill(fertile_node)
                continue

            if repulsion and self._repulsion_field(fertile_node.x, fertile_node.y, fertile_node) > 1 - repulsion:
                # the node itself sits too close to another branch
                self._kill(fertile_node)
                continue

            success = False
            dtheta = None
            for _ in range(max_branching_attempts):
                dtheta = self._sample_dtheta(fertile_node, dtheta_sigma, twist,
                                             forward_drive, transverse_drive, drive_mode)
                if dtheta is None:  # direction rejected by a drive: try again
                    continue

                if repulsion:
                    # reject the proposal if the new tip would land too close to another branch
                    angle = fertile_node.preferred_theta + dtheta
                    tip_x = fertile_node.x + length*np.cos(angle)
                    tip_y = fertile_node.y + length*np.sin(angle)
                    if self._repulsion_field(tip_x, tip_y, fertile_node) > 1 - repulsion:
                        continue

                success = True
                break

            if not success:  # all attempts failed (or none allowed): killing the fertile node
                self._kill(fertile_node)
                continue

            # we are authorized to branch
            self.add_node(fertile_node.stem(length=length, dtheta=dtheta))
            fertile_node.fertile = False

    def plot(self, scale_thickness=0, scale_mode='linear', color='black', highlight_node_status=False, scatter_size=1):
        """Draw every parent->node segment on the current matplotlib axes.

        Args:
            scale_thickness: if non-zero, line width is derived from node
                thickness using ``scale_mode`` and multiplied by this factor.
            scale_mode: ``'linear'``, ``'sqrt'`` or ``'pow<exponent>'``
                (e.g. ``'pow0.8'``).
            color: line color.
            highlight_node_status: also scatter fertile nodes in green and dead
                nodes in red.
            scatter_size: marker size for the status scatter.

        Raises:
            ValueError: for an unknown ``scale_mode`` (when ``scale_thickness``).
        """
        for node in self.nodes:
            kw = {}
            if scale_thickness:
                if scale_mode == 'linear':
                    kw['linewidth'] = node.thickness*scale_thickness
                elif scale_mode == 'sqrt':
                    kw['linewidth'] = np.sqrt(node.thickness)*scale_thickness
                elif scale_mode.startswith('pow'):
                    kw['linewidth'] = node.thickness**float(scale_mode[3:])*scale_thickness
                else:
                    raise ValueError(f'Unknown scale mode: {scale_mode}')
                kw['solid_capstyle'] = 'round'
            plt.plot([node.parent.x, node.x], [node.parent.y, node.y], color=color, **kw)

        if highlight_node_status:
            for node in self.nodes:
                if node.fertile:
                    plt.scatter(node.x, node.y, s=scatter_size, color='green')
                elif node.dead:
                    plt.scatter(node.x, node.y, s=scatter_size, color='red')
    

## Make movie

In [None]:
folder = './tree-movies/tree-movie-35'  # name of the folder where to save the frames

t = Tree(theta=np.pi/2, length=1)
kw = dict(gamma=0.9,  # decay of branch length, closer to 1 means longer branches
          branching_prob=0.05,  # probability of branching
          dtheta_sigma=0.5,  # amplitude of random deviation when branching
          repulsion=0.98,  # strength of repulsion between branches, to avoid self intersections
          min_branch_length=0.1,  # minimum length of a branch, once a branch reaches this length it will stop branching
          drive_mode='cosine_threshold',  # stops the growth if branches are 'too downward facing'.
                                          # The higher forward_drive, the stricter the constraint: 0 means no constraint,
                                          # 1 means branches cannot grow downward, 2 means branches can grow only straight upward
          # drive_mode='avg',  # pushes the growth direction towards the initial growth direction
          forward_drive=0,  # strength of push towards the initial growth direction, 1 means forcing the tree to grow in a line
          transverse_drive=0,  # strength of push towards the transverse direction (0 means no push), 1 means forcing the tree to grow in a line sideways
          twist=0,  # bias that twists the growth direction between (-1,1), 0 means no bias
          max_branching_attempts=3  # maximum number of attempts to branch after which the fertile node is killed
          )

max_stalls = 15  # total number of growth steps without any new node tolerated before stopping (never reset)
maxiter = 1000  # maximum number of iterations
xlim = (-8, 8)  # x limits of the plot
ylim = (-5.5, 10.5)  # y limits of the plot
photo_every = 2  # how often to save a frame
scale_mode = 'pow0.8'  # how to scale the thickness of the branches
scale_thickness = 0.1
seed = np.random.randint(0, 2**16)
np.random.seed(seed)

args = {'max_stalls': max_stalls, 'maxiter': maxiter, 'xlim': xlim, 'ylim': ylim,
        'photo_every': photo_every, 'scale_mode': scale_mode, 'scale_thickness': scale_thickness, 'seed': seed}
args = {**kw, **args}


def save_frame(tree, path):
    """Draw ``tree`` on a fresh figure number 1 with the movie's fixed limits and save it to ``path``."""
    plt.close(1)  # reuse figure number 1 so frames do not pile up as open figures
    fig, ax = plt.subplots(num=1, figsize=((xlim[1]-xlim[0])/2, (ylim[1]-ylim[0])/2))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_aspect('equal')
    tree.plot(scale_thickness=scale_thickness, highlight_node_status=False, scale_mode=scale_mode)
    plt.axis('off')
    fig.tight_layout()
    fig.savefig(path, dpi=200)


frame = 1
stalls = 0
n_nodes = len(t.nodes)

# os.makedirs raises FileExistsError (naming the path) if the folder already
# exists, so a previous movie is never overwritten.
os.makedirs(folder)

with open(os.path.join(folder, 'args.json'), 'w') as f:
    json.dump(args, f)

saved_last = False  # whether the most recent iteration was written to disk
for i in tqdm(range(maxiter)):
    t.grow(**kw)
    n_nodes_prev, n_nodes = n_nodes, len(t.nodes)
    if n_nodes == n_nodes_prev:
        stalls += 1

    saved_last = i % photo_every == 0
    if saved_last:
        save_frame(t, os.path.join(folder, f'{frame:04d}.png'))
        frame += 1

    if stalls > max_stalls:
        break

if not saved_last:
    # the loop ended on an iteration that was not photographed: keep the final state in the movie
    save_frame(t, os.path.join(folder, f'{frame:04d}.png'))
    frame += 1