# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation

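# In [None]:
#### 3D generalization of the 2D Barnes-Hut code (see the 2D version for a detailed walkthrough). The generalization is straightforward: each cell now splits into eight octants, although the tree class is still named QuadTree.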

class Node():
    """One cubic cell of the Barnes-Hut tree.

    Parameters
    ----------
    pos : sequence of 3 floats
        Centre ``(x, y, z)`` of the cell.
    length : float
        Side length of the cell.
    points : array-like
        Rows ``[x, y, z, mass]`` of the points held in the cell; may be empty.

    Attributes
    ----------
    coords : ndarray or None
        All columns of ``points`` except the last (the coordinates), or
        ``None`` for an empty cell.
    mass : float or None
        Total mass of the points in the cell (``None`` if empty).
    com : list of 3 floats or None
        Mass-weighted centre of mass ``[x, y, z]`` (``None`` if empty).
    children : list
        Child cells; filled in by the tree's ``divide`` method.
    """

    def __init__(self,pos,length,points):
        points=np.asarray(points) #In case given as list
        self.pos=pos
        self.length=length
        self.points=points
        self.children=[]

        if len(points)<1:
            # Empty cell: no coordinates, no mass, no centre of mass.
            self.coords=None
            self.mass=None
            self.com=None
            return

        self.coords=points[:,:-1]
        masses=points[:,3]  # column 3 is the mass; columns 0-2 are x, y, z
        totmass=np.sum(masses)
        # Weight each coordinate by the particle mass (column 3), not by z.
        # NOTE(review): a cell whose masses sum to 0 yields nan/inf here.
        self.com=[np.sum(points[:,axis]*masses)/totmass for axis in range(3)]
        self.mass=totmass
        
        
class QuadTree():
    """Barnes-Hut tree over 3D point masses.

    Despite the name, every internal cell is split into eight octant children
    (this is the 3D generalisation of the 2D quadtree).  Rows of ``points``
    are ``[x, y, z, mass]`` and rows of ``velocities`` are ``[vx, vy, vz]``.
    The root cell is a cube of side ``boxsize`` centred on the origin.  Call
    ``rootdiv()`` to build the tree before calling ``Verlet``.

    Points lying outside the root cube stay in ``root.points`` but are not
    assigned to any child cell.
    """

    # Cells smaller than boxsize / 2**max_depth are never split, so points
    # with (near-)identical coordinates end up sharing one leaf instead of
    # making divide() recurse without bound.
    max_depth = 50

    # Softening added to r**3 in the force law to tame close encounters.
    softening = 0.01

    # Child layout: (region code, x side, y side, z side), listed in the order
    # the children are stored in node.children.  -1 = lower half, +1 = upper
    # half of the parent along that axis.  S/N = low/high y, W/E = low/high x,
    # F/B (front/back) = low/high z.
    _OCTANTS = (
        (4, -1, -1, -1),  # SWF
        (1, -1, +1, -1),  # NWF
        (3, +1, -1, -1),  # SEF
        (2, +1, +1, -1),  # NEF
        (8, -1, -1, +1),  # SWB
        (5, -1, +1, +1),  # NWB
        (7, +1, -1, +1),  # SEB
        (6, +1, +1, +1),  # NEB
    )
    _REGION_SIDES = {code: (sx, sy, sz) for code, sx, sy, sz in _OCTANTS}

    def __init__(self,points,velocities,boxsize): #Initialize with points to put in and size of total grid
        points=np.asarray(points)
        self.points=points
        self.vel=velocities
        self.size=boxsize
        self.root=Node([0,0,0],boxsize,points) #root of tree, cube centred on the origin

    def rootdiv(self):
        """Build the whole tree by recursively splitting the root cell."""
        self.divide(self.root)

    def divide(self,node):
        """Recursively split ``node`` into its eight octant children.

        Recursion stops at cells holding fewer than two points, or once the
        cell is smaller than ``self.size / 2**self.max_depth``; such a leaf
        may hold several points, which then act through the leaf's combined
        mass and centre of mass.  Children are stored in ``node.children`` in
        the order SWF, NWF, SEF, NEF, SWB, NWB, SEB, NEB.
        """
        if len(node.points)<2 or node.length<self.size/2**self.max_depth:
            return
        half=node.length/2
        quarter=node.length/4
        cx,cy,cz=node.pos[0],node.pos[1],node.pos[2]
        children=[]
        for region,sx,sy,sz in self._OCTANTS:
            xmin,xmax=(cx,cx+half) if sx>0 else (cx-half,cx)
            ymin,ymax=(cy,cy+half) if sy>0 else (cy-half,cy)
            zmin,zmax=(cz,cz+half) if sz>0 else (cz-half,cz)
            pts=self.PtsInChild(xmin,xmax,ymin,ymax,zmin,zmax,node.points,region)
            center=[cx+sx*quarter,cy+sy*quarter,cz+sz*quarter]
            child=Node(center,half,pts)
            self.divide(child)
            children.append(child)
        node.children=children

    def PtsInChild(self,xmin,xmax,ymin,ymax,zmin,zmax,points,region):
        """Return, as a list of rows, the points that fall inside one child cell.

        ``region`` (1-8, see ``_OCTANTS``) says which half of the parent the
        child occupies along each axis.  Upper bounds are inclusive; a lower
        bound is exclusive when the child is the upper half along that axis.
        A point on a dividing plane therefore goes to the lower-half child,
        so each point inside the parent is returned for exactly one child
        (including a point sitting exactly at the parent's centre).

        Raises KeyError for a region code outside 1-8.
        """
        sides=self._REGION_SIDES[region]
        points=np.asarray(points)
        if len(points)==0:
            return []
        mask=np.ones(len(points),dtype=bool)
        bounds=zip((xmin,ymin,zmin),(xmax,ymax,zmax),sides)
        for axis,(lo,hi,side) in enumerate(bounds):
            c=points[:,axis]
            above_lo=(c>lo) if side>0 else (c>=lo)
            mask&=above_lo&(c<=hi)
        return list(points[mask])

    def find_children(self,node):
        """Return all leaf cells below ``node`` (``[node]`` if it is a leaf)."""
        if not node.children:
            return [node]
        leaves=[]
        for child in node.children:
            leaves+=self.find_children(child)
        return leaves

    def find_usedpts(self,pt,theta,node):
        """Return the cells whose aggregate mass acts on point ``pt``.

        A cell of side ``s`` whose centre of mass is at distance ``d`` from
        ``pt`` is used as a whole when ``s/d < theta``; otherwise its children
        are examined.  Leaves (including empty ones, whose mass is ``None``)
        are always returned as-is.
        """
        if not node.children: #node.children is false as boolean if empty
            return [node]
        d=np.sqrt((node.com[0]-pt[0])**2 + (node.com[1]-pt[1])**2 + (node.com[2]-pt[2])**2)
        # d == 0 means pt sits on the centre of mass: always open the cell
        # (the same outcome as s/0 -> inf, without the divide-by-zero warning).
        if d>0 and node.length/d<theta:
            return [node]
        used=[]
        for child in node.children:
            used+=self.find_usedpts(pt,theta,child)
        return used

    @staticmethod
    def _separation(node,x,y,z):
        """Return (dx, dy, dz, r) from position (x, y, z) to ``node``'s centre of mass."""
        xdis=node.com[0]-x
        ydis=node.com[1]-y
        zdis=node.com[2]-z
        return xdis,ydis,zdis,(xdis**2+ydis**2+zdis**2)**(1/2)

    def _accel(self,sources,x,y,z):
        """Softened acceleration at (x, y, z) from the cells in ``sources`` (units with G = 1)."""
        ax=ay=az=0
        for node in sources:
            xdis,ydis,zdis,r=self._separation(node,x,y,z)
            if r>1e-9:
                denom=r**3+self.softening
                ax+=node.mass*xdis/denom
                ay+=node.mass*ydis/denom
                az+=node.mass*zdis/denom
        return ax,ay,az

    def Verlet(self,pt,vel,step,theta=0.5):
        """Advance one point by one kick-drift-kick (velocity-Verlet) step.

        Parameters
        ----------
        pt : sequence
            ``[x, y, z, mass]`` of the point.
        vel : sequence
            ``[vx, vy, vz]`` of the point.
        step : float
            Time step ``h``.
        theta : float, optional
            Barnes-Hut opening threshold passed to ``find_usedpts``.

        Returns
        -------
        (list, list)
            ``[x, y, z, mass]`` and ``[vx, vy, vz]`` at ``t + h``.

        Both force evaluations use the tree as built at time ``t``.
        """
        h=step
        x,y,z=pt[0],pt[1],pt[2]
        vx,vy,vz=vel[0],vel[1],vel[2]

        # Keep only cells that carry mass and are not the point's own leaf
        # (whose centre of mass coincides with the point).  Reusing this set
        # for the second evaluation stops the point from attracting its own
        # old position, without discarding other equal-mass particles.
        sources=[]
        for node in self.find_usedpts(pt,theta,self.root):
            if node.mass is None:
                continue
            if self._separation(node,x,y,z)[3]>1e-9:
                sources.append(node)

        ax,ay,az=self._accel(sources,x,y,z)
        vx+=(h/2)*ax #Half kick: velocities at t+h/2
        vy+=(h/2)*ay
        vz+=(h/2)*az
        xnew=x+h*vx #Drift: positions at t+h
        ynew=y+h*vy
        znew=z+h*vz

        ax,ay,az=self._accel(sources,xnew,ynew,znew) #Accelerations at t+h
        vxnew=vx+(h/2)*ax
        vynew=vy+(h/2)*ay
        vznew=vz+(h/2)*az
        return [xnew,ynew,znew,pt[3]],[vxnew,vynew,vznew]

    def Clear(self):
        """Drop the root's children and reset the tree's ``points``, ``mass`` and ``com`` attributes."""
        del self.root.children[:]
        self.points=None
        self.mass=None
        self.com=None
        return
    
    
    

        
def evolve(time,tree,nsteps=10000):
    """Advance every particle in ``tree`` over a total duration ``time``.

    Parameters
    ----------
    time : float
        Total integration time.  If it is not positive, ``tree`` is returned
        unchanged.
    tree : QuadTree
        Tree holding the particles (presumably already built with
        ``rootdiv()`` -- TODO confirm at call sites).
    nsteps : int, optional
        Number of equal steps of size ``time/nsteps`` (default 10000).

    Returns
    -------
    QuadTree
        A freshly built tree holding the evolved points and velocities.

    Within a step, every particle is advanced using the tree built from the
    positions at the start of that step.  After the step the old tree is
    cleared and a new one is built from the updated positions.  The step
    index is printed every 100 steps as progress output.
    """
    if nsteps<1 or not time>0:
        return tree
    step=time/nsteps
    # A fixed step count replaces `while t < time: t += step`, whose
    # floating-point accumulation can fall just short of `time` and run an
    # extra step.
    for N in range(nsteps):
        points=tree.points
        vels=tree.vel
        New_Points=np.empty([len(points),4]) #Temp arrays to hold evolved quantities
        New_Velocities=np.empty([len(vels),3])
        for i,(pt,vel) in enumerate(zip(points,vels)):
            New_Points[i],New_Velocities[i]=tree.Verlet(pt,vel,step)

        if N%100==0:
            print(N)

        tree.Clear() #Kill tree
        tree=QuadTree(New_Points,New_Velocities,tree.size) #Create new tree with updated points
        tree.rootdiv()
    return tree



def evolve_once(time,tree):
    """Advance every particle in ``tree`` by a single step of size ``time/10000``.

    All particles are moved with ``tree.Verlet`` using the current tree.  The
    old tree is then cleared, and a new tree is built (and divided) from the
    updated points and velocities and returned.
    """
    h=time/10000
    old_points=tree.points
    old_vels=tree.vel
    n_points=len(old_points)
    advanced_points=np.empty([n_points,4])
    advanced_vels=np.empty([len(old_vels),3])
    for idx in range(n_points):
        advanced_points[idx],advanced_vels[idx]=tree.Verlet(old_points[idx],old_vels[idx],h)

    box=tree.size
    tree.Clear()
    fresh=QuadTree(advanced_points,advanced_vels,box)
    fresh.rootdiv()
    return fresh

