In [None]:
import numpy as np
from numpy import save
import rebound
import random
from multiprocessing import Pool 
from myfunctions import Plotting 

#Rebound Orbital Elements


In [None]:
def Simulation(par):
    """Integrate test particles around a binary and return their mean survival time.

    Two stars are set up (Star1 at the origin with m1 = 1, Star2 on an orbit with
    semimajor axis a_b = 1 and eccentricity e_b; with mu = 0.5 the two masses are
    equal), then Np massless test particles are added on circular orbits of
    semimajor axis a_p. All initial true anomalies are drawn from ``np.random``.

    The system is integrated with WHFast to 100 evenly spaced output times between
    0 and 1e4 * 2*pi. After each output, any test particle whose planar distance
    from the centre of mass exceeds ``max_dist`` (100 * a_b) is removed and the
    current output time is recorded as its survival time. Particles that are never
    removed are assigned the last output time that was reached.

    Parameters
    ----------
    par : sequence
        ``(e_b, a_p, Np)``: binary eccentricity, test-particle semimajor axis and
        number of test particles.

    Returns
    -------
    float
        Mean of the Np survival times (``np.mean`` of an empty array if Np is 0).
    """
    e_b, a_p, Np = par[0], par[1], par[2]

    sim = rebound.Simulation()
    sim.integrator = "whfast"

    #******************STAR1*****************
    a_b = 1.
    m1 = 1.
    sim.add(m=m1, hash="Star1")

    #******************STAR2*****************
    mu = 0.5
    m2 = (m1*mu)/(1-mu)
    f_b = np.random.rand()*2.*np.pi
    sim.add(m=m2, a=a_b, e=e_b, f=f_b, hash="Star2")

    #****************TEST PARTICLES************
    for _ in range(Np):
        sim.add(m=0., a=a_p, e=0, f=np.random.rand()*2.*np.pi)

    #*************RUN SIMULATION & ARCHIVE************
    sim.move_to_com()
    #sim.automateSimulationArchive("archive_eb{:.3f}_ap{:.3f}_Np{:.2f}.bin".format(e_b,a_p,Np),interval = 1e3, deletefile = True)

    # Ejection radius; compared squared so no square root is needed per particle.
    max_dist = 100*a_b
    max_dist_sq = max_dist**2

    # NOTE(review): Torb is 2*pi time units; whether that equals the binary period
    # depends on the unit system and the total mass -- confirm the intended unit.
    Torb = 2.*np.pi
    Norb_max = 1e4
    Tmax = Norb_max*Torb
    Noutputs = 100
    times = np.linspace(0, Tmax, Noutputs)

    survtime = np.zeros(Np)  # survival time of each test particle, by original order

    # alive[k] is the original test-particle number of the particle currently stored
    # at sim.particles[k + 2]. Removing a particle shifts the later ones down by one
    # (assumes sim.remove keeps the remaining particles in order, its default), so
    # the current index alone cannot be used to address survtime.
    alive = list(range(Np))

    t_out = times[0]
    for t_out in times:
        sim.integrate(t_out, exact_finish_time=0)

        # Walk backwards so a removal never changes an index still to be visited.
        for j in reversed(range(2, sim.N)):
            p = sim.particles[j]
            if (p.x**2 + p.y**2) > max_dist_sq:
                survtime[alive.pop(j-2)] = t_out
                sim.remove(j)

        if sim.N == 2:
            # Only the two stars are left; nothing more to track.
            break

    # Whatever is still in the simulation survived up to the last output reached.
    for k in alive:
        survtime[k] = t_out

    print('simulation finished, {} planets remianing'.format(sim.N-2))

    return np.mean(survtime)
       

    
  
    
        



In [None]:
if __name__ == "__main__":
    # Driver: run Simulation over a grid of (binary eccentricity, test-particle
    # semimajor axis) pairs in a worker pool, then plot the mean survival time as a
    # heat map with the critical-semimajor-axis curve drawn on top.
    import time
    import itertools

    # Plain-Python form of the `%matplotlib inline` magic. Outside IPython
    # get_ipython is undefined, and the backend selection is simply skipped.
    try:
        get_ipython().run_line_magic("matplotlib", "inline")
    except NameError:
        pass
    import matplotlib.pyplot as plt
    from matplotlib import ticker
    from matplotlib.colors import LogNorm
    import matplotlib

    # Simulation draws its random phases from np.random, so numpy's generator has to
    # be seeded as well as the stdlib one.
    # NOTE(review): whether each worker process ends up with an independent stream
    # depends on how the pool starts its workers -- confirm if that matters here.
    random.seed(1)
    np.random.seed(1)

    Ne, Na, Np = 2,2,10 #Np = number of test particles per (eb,ap)tuple
    ab = 1
    ebs = np.linspace(0.,0.7,Ne)
    aps = ab*np.linspace(1.,5.,Na)

    # eb varies slowest, so the flat result list reshapes to [Ne, Na] below.
    params = [(eb,ap,Np) for eb in ebs for ap in aps]

    # add number of processors same as number requested on Sunnyvale
    # The with-block shuts the workers down even if a simulation raises.
    with rebound.InterruptiblePool(processes = 16) as pool:
        # TO TIME ONE CALL TO THE METHOD
        start = time.time()
        num = 1
        for _ in itertools.repeat(None,num):
            stime = pool.map(Simulation,params) #survival times
        end = time.time()
    print("Time elapsed is",end-start)

    # Rows become a_p and columns e_b after the transpose, matching the image axes.
    stime = np.array(stime).reshape([Ne,Na])
    stime = np.nan_to_num(stime)
    stime = stime.T
    #save('stime1.npy', stime)
    #Times = np.load('stime1.npy')

#*********************************PLOTTING**********************************#
    fig,ax = plt.subplots(1,1,figsize=(7,5))
    extent=[ebs.min(), ebs.max(), aps.min(), aps.max()]

    ax.set_xlim(extent[0], extent[1])
    ax.set_ylim(extent[2], extent[3])
    ax.set_xlabel("Binary Eccentricity $e_b$ ")
    ax.set_ylabel("Test particle semimajor axis $a_p$")
    im = ax.imshow(stime, aspect='auto', origin="lower", interpolation='nearest', cmap="viridis",extent=extent)

    # Critical semimajor axis evaluated at every eb. Computed directly from ebs so
    # its length always matches the x values, whatever Ne and Na are.
    ab_s = 2.278 + 3.824*ebs - 1.71*(ebs**2)

    # The curve is identified through the legend, so the axis labels set above stay
    # in place rather than being overwritten.
    ax.plot(ebs,ab_s,'c', marker = "^",markersize = 7, label='Critical semimajor axis $a_c$')
    ax.set_title('Critical semimajor axis $a_c$ as a function of eccentricity $e_b$')
    ax.legend()

    cb = plt.colorbar(im, ax=ax)
    cb.solids.set_rasterized(True)
    cb.set_label("Particle Survival Times")

    plt.show()
    #plt.savefig("Classic_results.pdf")

#*********************************PLOTTING*************************************#