This notebook contains some functions and tools for use with the Rebound notebooks contained in this directory.

### Import statements

In [0]:
from numpy import *

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
#import matplotlib.rcParams as rcParams

%matplotlib inline

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams["animation.embed_limit"] = 1024
plt.rcParams["font.size"] = 14

from IPython.display import display, HTML


In [0]:
#pip install --upgrade git+git://github.com/jakevdp/JSAnimation.git

Collecting git+git://github.com/jakevdp/JSAnimation.git
  Cloning git://github.com/jakevdp/JSAnimation.git to /tmp/pip-req-build-_pd9a4d_
  Running command git clone -q git://github.com/jakevdp/JSAnimation.git /tmp/pip-req-build-_pd9a4d_
Building wheels for collected packages: JSAnimation
  Building wheel for JSAnimation (setup.py) ... [?25l[?25hdone
  Created wheel for JSAnimation: filename=JSAnimation-0.1-cp36-none-any.whl size=12282 sha256=3eb4526b920cef7cc5050a969c78f7835f0680a29197810a4d758c0f736b7232
  Stored in directory: /tmp/pip-ephem-wheel-cache-tmseubcs/wheels/e4/f9/b2/03401cb39997af9ae11e004cefd599250f584e89eca4a275ee
Successfully built JSAnimation
Installing collected packages: JSAnimation
Successfully installed JSAnimation-0.1


### Friendly functions

In [0]:
# Registry shared by the helper functions in this notebook:
#   names_dict : maps the order in which a body was added (0, 1, 2, ...) to its name
#   mycounter  : the index the next added body will receive
names_dict = {}
mycounter = 0

def initialize_simulation():
  """
  Start a fresh simulation and reset the name registry.

  Rebinds the module-level `names_dict` to an empty dict and `mycounter` to 0,
  then creates a `rebound.Simulation` whose units are ('yr', 'AU', 'Msun').

  NOTE: `rebound` is not imported in the cells visible here, so it must already
  be available in the namespace in which this function runs.

  Returns
  -------
  The newly created simulation object.
  """
  global names_dict, mycounter

  # forget every body registered for a previous simulation
  names_dict, mycounter = {}, 0

  # years / astronomical units / solar masses
  new_sim = rebound.Simulation()
  new_sim.units = ('yr', 'AU', 'Msun')
  return new_sim

In [0]:
def add_star(sim,name="star",mass=1):
  """
  Add a star to `sim` and record its name in the registry.

  Parameters
  ----------
  sim  : simulation object providing `add` and `move_to_com`
  name : label for the star; also passed to `sim.add` as its hash
  mass : mass of the star, assumed to be in solar masses
  """
  global names_dict, mycounter

  sim.add(m=mass, hash=name)
  sim.move_to_com()

  # remember which name belongs to this insertion index
  names_dict[mycounter] = name
  mycounter += 1

In [0]:
def add_planet(sim,name="planet",mass=1,a=1.,e=0.,i=0.,omega=0.,Omega=0.,f=0.):
  """
  Add a planet to `sim` and record its name in the registry.

  Parameters
  ----------
  sim   : simulation object providing `add` and `move_to_com`
  name  : label for the planet; also passed to `sim.add` as its hash
  mass  : planet mass in Jupiter masses (divided by 1047 before being handed to the simulation)
  a, e  : passed through unchanged to `sim.add` as `a` and `e`
  i, omega, Omega, f : angles in degrees; converted to radians and passed to
          `sim.add` as `inc`, `omega`, `Omega` and `f`

  Raises
  ------
  StopExecution (a local Exception subclass) after printing an explanation,
  when the converted mass is 0.1 or more.
  """
  global names_dict, mycounter

  class StopExecution(Exception):
    # The traceback hook returns None; presumably this is what keeps the
    # notebook from showing a traceback under the printed message -- TODO confirm.
    def _render_traceback_(self):
      return None

  # 1 jupiter is 1/1047 sun
  mass_msun = mass/1047.

  # Don't let people add a planet that's too massive
  if mass_msun >= 0.1:
    reduction = str(round(mass_msun/0.1, 1))
    print("You're trying to add a planet that is as big as a star! You can't do that!\n Reduce the mass of '"
          + name + "' by at least " + reduction + " times and try again.")
    raise StopExecution

  sim.add(m=mass_msun,
          a=a,
          e=e,
          inc=radians(i),
          omega=radians(omega),
          Omega=radians(Omega),
          f=radians(f),
          hash=name)
  sim.move_to_com()

  # remember which name belongs to this insertion index
  names_dict[mycounter] = name
  mycounter += 1

In [0]:
def plot_ae(sim,var,maxa,maxd):
    """
    Plot semi-major axis, distance from the star and eccentricity against time.

    Three stacked panels are drawn ('a', 'd', 'e'), one line per planet, and a
    legend with the planet names is placed above the top panel.

    Parameters
    ----------
    sim  : unused; kept so existing calls keep working
    var  : dict with key 't' (sequence of times) and keys 'a', 'd', 'e', each a
           2-D array indexed as [time, planet]
    maxa : the 'a' panel is never allowed to extend above 2.5*maxa
    maxd : the 'd' panel is never allowed to extend above 5*maxd

    Planet j is labelled with names_dict[j+1], i.e. registry entry 0 is skipped
    (presumably the star, which is expected to have been added first).
    """
    global names_dict
    from itertools import cycle

    colors = [(1.,0.,0.),(0.,0.75,0.75),(0.75,0.,0.75),(0.75, 0.75, 0,),(0., 0., 0.),(0., 0., 1.),(0., 0.5, 0.)]

    labels={'a': r'$a$ [AU]', 'd': 'Distance from\nthe star [AU]', 'e': r'$e$', 't': 'Time [years]'}
    maxlims={'a': [0,2.5*maxa], 'd': [0,5.*maxd], 'e':[-0.01,1.01] }
    # padding added below/above the data when only a single planet is plotted
    single_pad={'a': 0.2, 'd': 0.2, 'e': 0.01}

    fig,ax = plt.subplots(3,1,figsize=(15,8))
    for i,d in enumerate(['a','d','e']):
        coloriterator = cycle(colors) # restart the palette so a planet has the same colour in every panel
        nplanets = shape(var[d])[1]

        for j in range(nplanets):
            ax[i].plot(var['t'],var[d][:,j],c=next(coloriterator),label=names_dict[j+1])

        # only the bottom panel carries the time axis labels
        if i<len(ax)-1: ax[i].set_xticklabels([])
        else: ax[i].set_xlabel(labels['t'])
        ax[i].set_ylabel(labels[d])

        # clamp the automatic limits to the allowed window for this quantity
        yl=ax[i].get_ylim()
        ax[i].set_ylim([max([maxlims[d][0],yl[0]]),min([maxlims[d][1],yl[1]])])

        if nplanets==1:
            # A lone planet gets a fixed padding around its data so numerical
            # noise is not blown up to fill the panel. The upper bound follows
            # the maximum so varying data is not clipped.
            values = asarray(var[d])
            lowest, highest = float(values.min()), float(values.max())
            ax[i].set_ylim([lowest-single_pad[d],highest+single_pad[d]])

    ax[0].legend(loc='lower center', bbox_to_anchor=(0.5, 1.05),ncol=4)
    plt.subplots_adjust(hspace=0.1)

    return

In [0]:
def run_simulation(sim,end_time,number_frames=100):
  """
  Integrate `sim` up to `end_time`, animate it and plot its orbital elements.

  The interval from the current simulation time to `end_time` is divided into
  `number_frames` equal steps. After every step an animation frame is drawn and
  the semi-major axis, eccentricity and distance of every orbit are recorded.
  The run stops early if an orbit changes drastically (see the check in the loop).
  Finally the recorded elements are plotted with `plot_ae`.

  Parameters
  ----------
  sim           : the simulation to advance (modified in place)
  end_time      : time to integrate to, in the simulation's time unit
  number_frames : number of output steps / animation frames after the initial one

  Returns
  -------
  The matplotlib ArtistAnimation built from the frames, or None when nothing
  was run (simulation already at `end_time`, or no orbits to integrate).
  """
  if sim.t==end_time: print("Did you want to run a simulation? Specify an end time that isn't 0!"); return

  sim.move_to_com()

  sim.integrator='mercurius' # a hybrid symplectic that hopefully works well

  orbits=sim.calculate_orbits()
  if len(orbits)==0:
    # without this guard min() below fails with an unhelpful ValueError
    print("There is nothing orbiting the star yet! Add at least one planet and try again.")
    return

  # BUT: default timestep too small!!! pick one suited to long integrations
  Ps=[o.P for o in orbits]
  sim.dt = min(Ps)*0.03 # set timestep to 3% of the shortest orbital period
  sim.ri_ias15.min_dt = 1e-4 * sim.dt
  sim.ri_mercurius.hillfac = 2.

  # we'll also plot here
  fig,ax = plt.subplots(1,1,figsize=(5,5))
  frames = [make_rebound_frame(ax,sim)]
  orig_ani_lims=ax.get_ylim()

  # progress bar; newer tqdm versions expose the notebook bar as tqdm.notebook.tqdm
  try:
    from tqdm.notebook import tqdm as tqdm_notebook
  except ImportError:
    from tqdm import tqdm_notebook

  t_init=sim.t
  simtimes=linspace(t_init,end_time,number_frames+1)

  smas=[o.a for o in orbits] # initial SMAs
  eccs=[o.e+1E-6 for o in orbits] # initial e's plus small fidget to avoid divide by zero
  dists=[o.d for o in orbits] # initial distance from Sun
  max_sma=max(smas)

  # One row per output time; converted to 2-D arrays once the run is over
  # (cheaper than re-stacking the whole history after every step).
  history={'a':[smas],'e':[eccs],'d':[dists]}
  times=[t_init]

  def rel_change(new,old):
    # fractional change of `new` relative to `old`
    return abs(new-old)/old

  # The bar counts frames rather than whole years: truncating each step to an
  # integer number of years left it stuck at 0 for steps shorter than a year.
  pbar = tqdm_notebook(total=number_frames,desc='Progress')
  try:
    for s in simtimes[1:]:
      sim.integrate(s,exact_finish_time=True)
      frames.append(make_rebound_frame(ax,sim)) # Add the new frame to the list for animation
      pbar.update(1)

      # Output orbits for plotting
      orbits=sim.calculate_orbits()
      history['a'].append([o.a for o in orbits])
      history['e'].append([o.e for o in orbits])
      history['d'].append([o.d for o in orbits])
      times.append(s)

      # Check to see if orbits blew up--no need to waste computational time if we have major instability
      done=False
      for j,o in enumerate(orbits):
        if (rel_change(o.a,smas[j])>2. and o.a > 2.*max_sma) or (rel_change(o.e,eccs[j])>10. and o.e > 0.9):
          print(names_dict[j+1]+' had a catastrophic instability detected at time '+str(s)+' years')
          print('Its initial orbital elements at time=0 are: a='+str(round(smas[j],2))+' and e='+str(round(eccs[j],3)))
          print('Its orbital elements at time='+str(s)+' are: a='+str(round(o.a,2))+' and e='+str(round(o.e,3)))
          done=True
      if done: break
  finally:
    pbar.close() # also closes the bar when the integration raises

  var={key: array(rows) for key,rows in history.items()}
  var['t']=times

  # if the view grew to more than twice its initial size, pull it back to 1.5x the initial limits
  new_ani_lims=ax.get_ylim()
  if new_ani_lims[1]>2.*orig_ani_lims[1]: ax.set_xlim(array(orig_ani_lims)*1.5); ax.set_ylim(array(orig_ani_lims)*1.5)

  ani = animation.ArtistAnimation(fig, frames)
  plt.close(fig) # the animation object is returned; don't also show the static figure

  plot_ae(sim,var,max(smas),max(dists))

  return ani
  

### Fancy Rebound plotting things

In [0]:
def get_color(color):
    """
    Convert a matplotlib color specification into a 3-tuple of RGB values in [0, 1].

    A tuple of length three is assumed to already be RGB and is returned unchanged
    (without importing matplotlib).

    Parameters
    ----------
    color   : str or tuple
        Anything matplotlib recognizes as a color, e.g. a color name such as 'black',
        a hex string or a single-letter code; or a 3-tuple of RGB values.

    Raises
    ------
    ImportError     if matplotlib cannot be imported.
    AttributeError  if matplotlib does not recognize the color.
    """

    if isinstance(color, tuple) and len(color) == 3: # already a tuple of RGB values
        return color

    try:
        import matplotlib.colors as mplcolors
    except ImportError as err:
        raise ImportError("Error importing matplotlib. If running from within a jupyter notebook, try calling '%matplotlib inline' beforehand.") from err

    try:
        # to_rgb gives the same values as decoding the named color's hex string
        # by hand, and additionally understands the other matplotlib color formats
        return tuple(mplcolors.to_rgb(color))
    except (ValueError, TypeError) as err:
        raise AttributeError("Color not recognized in matplotlib.") from err


def fading_line(x, y, color='black', alpha_initial=1., alpha_final=0., glow=False, **kwargs):
    """
    Returns a matplotlib LineCollection connecting the points in the x and y lists, with a single color and alpha varying from alpha_initial to alpha_final along the line.
    Can pass any kwargs you can pass to LineCollection, like linewidth.

    Parameters
    ----------
    x       : list or array of floats for the positions on the (plot's) x axis
    y       : list or array of floats for the positions on the (plot's) y axis
    color   : matplotlib color for the line. Can also pass a 3-tuple of RGB values (default: 'black')
    alpha_initial:  Limiting value of alpha to use at the beginning of the arrays.
    alpha_final:    Limiting value of alpha to use at the end of the arrays.
    glow    : if True, a list of three LineCollections is returned instead of one:
              the same line drawn with widths 6, 2 and 1 at a quarter, half and
              full alpha respectively (any 'lw' in kwargs is overridden).

    Raises
    ------
    ImportError     if matplotlib or numpy cannot be imported.
    AttributeError  if x and y differ in length.
    """
    try:
        from matplotlib.collections import LineCollection
        from matplotlib.colors import LinearSegmentedColormap
        import numpy as np
    except ImportError as err:
        raise ImportError("Error importing matplotlib and/or numpy. Plotting functions not available. If running from within a jupyter notebook, try calling '%matplotlib inline' beforehand.") from err

    if glow:
        # widest and faintest layer first, so the sharp line ends up last in the list
        layers = []
        for width, dimming in ((6, 0.25), (2, 0.5), (1, 1.)):
            kwargs["lw"] = width
            layers.append(fading_line(x, y, color, alpha_initial*dimming, alpha_final*dimming, glow=False, **kwargs))
        return layers

    color = get_color(color)
    # constant RGB, alpha ramping linearly from alpha_initial to alpha_final
    cdict = {'red': ((0.,color[0],color[0]),(1.,color[0],color[0])),
             'green': ((0.,color[1],color[1]),(1.,color[1],color[1])),
             'blue': ((0.,color[2],color[2]),(1.,color[2],color[2])),
             'alpha': ((0.,alpha_initial, alpha_initial), (1., alpha_final, alpha_final))}

    Npts = len(x)
    if len(y) != Npts:
        raise AttributeError("x and y must have same dimension.")

    # segment k joins point k to point k+1; fewer than two points gives an
    # empty (0, 2, 2) array instead of an indexing error
    points = np.column_stack((np.asarray(x, dtype=float), np.asarray(y, dtype=float)))
    segments = np.stack((points[:-1], points[1:]), axis=1)

    individual_cm = LinearSegmentedColormap('indv1', cdict)
    lc = LineCollection(segments, cmap=individual_cm, **kwargs)
    # position along the line (0..1) selects the alpha from the colormap
    lc.set_array(np.linspace(0.,1.,len(segments)))
    return lc


def _add_fading_line(ax, x, y, color, **line_kwargs):
    """Build a fading_line through (x, y) and add it to ax; with glow a list of layers comes back and each is added."""
    lc = fading_line(x, y, color, **line_kwargs)
    for collection in (lc if isinstance(lc, list) else [lc]):
        ax.add_collection(collection)


def OrbitPlotOneSlice(sim, ax, lim=None, limz=None, Narc=100, color=False, periastron=False, trails=False, show_orbit=True, lw=1., axes="xy", plotparticles=[], primary=None, glow=False, fancy=False):
    """
    Draw one 2-D projection of the particles in `sim` and their orbits onto `ax`.

    Parameters
    ----------
    sim     : simulation; `sim.particles` and `sim.N_real` are read
    ax      : matplotlib axes to draw on
    lim     : half-width of the plot for the x/y directions. If None it is
              1.15 times the largest a*(1+e) (or current distance for orbits with a <= 0).
    limz    : half-width for the z direction. If None it is twice the largest z,
              replaced by `lim` when that is larger than `lim` or not positive.
    Narc    : orbits are sampled at Narc+1 points
    color   : False -> black (or a single preset colour when fancy);
              a colour name / RGB tuple -> that colour for every orbit;
              a list of colours -> cycled through;
              True (or any other truthy value) -> built-in palette
    periastron : if True, mark each orbit's pericentre and join it to the primary with a dotted line
    trails  : if True, orbits fade out along their length
    show_orbit : orbits are only drawn when this is exactly True
    lw      : line width; marker sizes scale with it
    axes    : two letters out of 'x', 'y', 'z' choosing the horizontal and vertical coordinate
    plotparticles : indices into sim.particles to draw; empty means 1 .. sim.N_real-1
    primary : body the orbits are computed around; None uses each particle's jacobi_com
    glow    : passed on to fading_line
    fancy   : dark background with a glowing central body and a fixed random star field

    Returns
    -------
    The axes `ax`.
    """
    from itertools import cycle
    import numpy as np
    import random

    p_orb_pairs = []
    if not plotparticles:
        plotparticles = range(1, sim.N_real)
    for i in plotparticles:
        p = sim.particles[i]
        p_orb_pairs.append((p, p.calculate_orbit(primary=primary)))

    if lim is None:
        lim = 0.
        for p, o in p_orb_pairs:
            # apocentre distance for bound orbits, current distance otherwise
            r = (1.+o.e)*o.a if o.a>0. else o.d
            if r>lim:
                lim = r
        lim *= 1.15
    if limz is None:
        z = [p.z for p,o in p_orb_pairs]
        limz = 2.0*max(z) if z else 0. # nothing to plot: fall through to lim below
        if limz > lim or limz <= 0.:
            limz = lim

    if axes[0]=="z":
        ax.set_xlim([-limz,limz])
    else:
        ax.set_xlim([-lim,lim])
    if axes[1]=="z":
        ax.set_ylim([-limz,limz])
    else:
        ax.set_ylim([-lim,lim])

    if fancy:
        ax.set_facecolor((0.,0.,0.))
        for pos in ['top', 'bottom', 'right', 'left']:
            ax.spines[pos].set_edgecolor((0.3,0.3,0.3))

    if color:
        if isinstance(color, (str, tuple)):
            colors = [get_color(color)]
        elif isinstance(color, list):
            colors = [get_color(c) for c in color]
        else:
            # color=True (or another truthy flag): built-in palette
            colors = [(1.,0.,0.),(0.,0.75,0.75),(0.75,0.,0.75),(0.75, 0.75, 0,),(0., 0., 0.),(0., 0., 1.),(0., 0.5, 0.)]
    else:
        if fancy:
            colors = [(181./206.,66./206.,191./206.)]
            glow = True
        else:
            colors = ["black"]
    coloriterator = cycle(colors)

    # fail with KeyError if `axes` names something other than x, y or z
    for axis_name in (axes[0], axes[1]):
        if axis_name not in ('x', 'y', 'z'):
            raise KeyError(axis_name)

    prim = sim.particles[0] if primary is None else primary
    if fancy:
        # central body: many translucent discs of shrinking size give a soft glow
        sun = (256./256.,256./256.,190./256.)
        opa = 0.020
        size = 6000.
        for i in range(256):
            ax.scatter(getattr(prim,axes[0]),getattr(prim,axes[1]), alpha=opa, s=size*lw, facecolor=sun, edgecolor=None, zorder=3)
            size *= 0.95

        starcolor = (1.,1.,1.)
        mi, ma = ax.get_xlim()
        prestate = random.getstate() # restored below so the caller's random stream is untouched
        random.seed(1) #always same stars
        x, y = [], []
        #small stars
        for i in range(64):
            x.append(random.uniform(mi,ma))
            y.append(random.uniform(mi,ma))
        ax.scatter(x,y, alpha=0.05, s=8*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        ax.scatter(x,y, alpha=0.1, s=4*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        ax.scatter(x,y, alpha=0.2, s=0.5*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        #medium stars
        x, y = [], []
        for i in range(16):
            x.append(random.uniform(mi,ma))
            y.append(random.uniform(mi,ma))
        ax.scatter(x,y, alpha=0.1, s=15*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        ax.scatter(x,y, alpha=0.1, s=5*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        ax.scatter(x,y, alpha=0.5, s=2*lw, facecolor=starcolor, edgecolor=None, zorder=3)
        random.setstate(prestate)

    else:
        ax.scatter(getattr(prim,axes[0]),getattr(prim,axes[1]), marker="*", s=35*lw, facecolor="black", edgecolor=None, zorder=3)

    proj = {}
    for p, o in p_orb_pairs:
        prim = p.jacobi_com if primary is None else primary
        colori = next(coloriterator)

        # fancy plots draw the body in its own orbit colour (a one-element list,
        # so an RGB tuple is not mistaken for three separate values)
        marker_color = [colori] if fancy else "black"
        ax.scatter(getattr(p,axes[0]), getattr(p,axes[1]), s=25*lw, facecolor=marker_color, edgecolor=None, zorder=3)

        if show_orbit is True:
            alpha_final = 0. if trails is True else 1. # fade to 0 with trails

            hyperbolic = o.a < 0. # Boolean for whether orbit is hyperbolic
            if not hyperbolic:
                pts = np.array(p.sample_orbit(Npts=Narc+1, primary=prim))
                proj['x'],proj['y'],proj['z'] = [pts[:,i] for i in range(3)]
                _add_fading_line(ax, proj[axes[0]], proj[axes[1]], colori, alpha_final=alpha_final, lw=lw, glow=glow)

            else:
                # true anomaly stays close to limiting value and switches quickly at pericenter for hyperbolic orbit, so use mean anomaly
                pts = np.array(p.sample_orbit(Npts=Narc+1, primary=prim, useTrueAnomaly=False))
                proj['x'],proj['y'],proj['z'] = [pts[:,i] for i in range(3)]
                _add_fading_line(ax, proj[axes[0]], proj[axes[1]], colori, alpha_final=alpha_final, lw=lw, glow=glow)

                # second branch of the hyperbola (sampled with trailing=False), drawn at constant alpha
                alpha = 0.2 if trails is True else 1.
                pts = np.array(p.sample_orbit(Npts=Narc+1, primary=prim, trailing=False, useTrueAnomaly=False))
                proj['x'],proj['y'],proj['z'] = [pts[:,i] for i in range(3)]
                _add_fading_line(ax, proj[axes[0]], proj[axes[1]], colori, alpha_initial=alpha, alpha_final=alpha, lw=lw, glow=glow)

        if periastron:
            # `Particle` is not defined anywhere in this notebook, so take the
            # class from the particle being plotted (assumed to be the
            # simulation's particle class). The new particle has the same orbit
            # with true anomaly f=0.
            newp = type(p)(a=o.a, f=0., inc=o.inc, omega=o.omega, Omega=o.Omega, e=o.e, m=p.m, primary=prim, simulation=sim)
            ax.plot([getattr(prim,axes[0]), getattr(newp,axes[0])], [getattr(prim,axes[1]), getattr(newp,axes[1])], linestyle="dotted", c=colori, zorder=1, lw=lw)
            ax.scatter([getattr(newp,axes[0])],[getattr(newp,axes[1])], marker="o", s=5.*lw, facecolor="none", edgecolor=colori, zorder=1)
    return ax

### Couple of options for making Rebound animations

Two approaches are provided: the first uses Rebound's built-in widget, and the second builds frames that are animated with `matplotlib.animation.ArtistAnimation`.

In [0]:
def integrate_with_widget(sim, t_final, N_steps, pause=0.1):
    """
    Integrate `sim` to `t_final` in `N_steps` equal steps while showing its widget.

    Parameters
    ----------
    sim     : simulation providing `t`, `getWidget` and `integrate`
    t_final : time to integrate to
    N_steps : number of equal steps the remaining time is divided into
    pause   : value handed to plt.pause after every step
    """
    step = (t_final-sim.t)/N_steps

    widget = sim.getWidget(scale=3,size=(500,500),autorefresh=True)
    display(widget)

    for _ in range(N_steps):
        # move the stopping time one step further and integrate up to it
        sim.integrate(sim.t + step)
        plt.pause(pause)
    


In [0]:
def make_rebound_frame(ax, sim):
    """
    Draw the current state of `sim` onto `ax` and return the artists that were added.

    The orbits are drawn with OrbitPlotOneSlice (coloured, with trails), the axes
    are labelled and the simulation time is written above the plot.

    Returns
    -------
    Array of the children of `ax` that were not present before this call, i.e.
    the artists that make up this one animation frame.
    """
    orig_children=ax.get_children()

    ax = OrbitPlotOneSlice(sim, ax, color=True,trails=True)
    ax.set_xlabel("x [AU]")
    ax.set_ylabel("y [AU]")
    # Put the time label on `ax` itself: the pyplot-level call targets whatever
    # axes is current, which need not be `ax`, and then the label would be
    # missing from this frame.
    ax.text(0.5, 1.01, str(round(sim.t,2))+' years', horizontalalignment='center', verticalalignment='bottom', transform=ax.transAxes)

    # Only the newly added artists belong to this frame; returning everything
    # would keep the earlier frames visible too.
    ax_list = setdiff1d(ax.get_children(),orig_children, assume_unique=True)

    return ax_list