# In [None]:
def kde_den(data, cov=0.25):
    """Return a Gaussian KDE of ``data`` with a fixed bandwidth factor.

    ``cov`` is passed to ``scipy.stats.gaussian_kde`` as ``bw_method``.
    A scalar value is used directly as the KDE's ``factor``. For 1-D data,
    this makes the kernel standard deviation ``cov`` times the sample
    standard deviation of ``data``.

    Parameters
    ----------
    data : array-like
        Sample points accepted by ``gaussian_kde``.
    cov : float, default 0.25
        Bandwidth factor.

    Returns
    -------
    scipy.stats.gaussian_kde
        The fitted estimator. Call it on evaluation points.
    """
    from scipy.stats import gaussian_kde
    # Public equivalent of overriding ``covariance_factor`` and then calling
    # the private ``_compute_covariance()``.
    return gaussian_kde(data, bw_method=cov)

def draw_kdes(dlM, dlm, dlo, dlt, ax
              , nevents
              , lw=1.5
              , excess=False):
    """Plot KDE curves of per-galaxy delta-lambda for each merger category.

    Parameters
    ----------
    dlM, dlm, dlo, dlt : array-like
        Samples for the major, minor, rest and total categories. Each is
        smoothed with ``kde_den`` using its default bandwidth.
    ax : matplotlib Axes
        Axes to draw on.
    nevents : sequence
        ``nevents[0]`` and ``nevents[1]`` appear in the major and minor
        legend labels. Any further entries are ignored.
    lw : float
        Line width of all curves.
    excess : bool
        If True, plot ``f(x) - f(|x|)`` instead of ``f(x)``. For ``x < 0``
        this is the density in excess of its mirror image ``f(-x)``; for
        ``x >= 0`` it is zero.

    Notes
    -----
    The major, minor and rest curves are scaled by ``len(sample) / len(dlt)``.
    The total curve is not scaled. The y-axis lower limit is pinned to 0,
    and the current upper limit is enlarged by 15%.
    """
    # Fit every density first so a bad sample fails before anything is drawn.
    dM, dm, do, dtot = (kde_den(d) for d in (dlM, dlm, dlo, dlt))

    xs = np.linspace(-0.7, 0.7, 51)

    def _curve(density):
        # Mirror on |xs| itself so every x is compared with exactly -x.
        # A separate hard-coded positive grid goes out of sync with xs.
        values = density(xs)
        if excess:
            values = values - density(np.abs(xs))
        return values

    nM, nm, no, ntot = len(dlM), len(dlm), len(dlo), len(dlt)

    Mlabel = "Major \n" + r"$N_{g}(N_{e})$" + " = {}({})".format(nM, nevents[0])
    mlabel = "Minor \n" + r"$N_{g}(N_{e})$" + " = {}({})".format(nm, nevents[1])
    olabel = "Rest  \n" + r"$N_{g}$" + " = {}".format(no)
    totlabel = "Total\n" + r"$N_{g}$" + " = {}".format(ntot)

    # Components are weighted by their fraction of the total sample.
    ax.plot(xs, _curve(dM) * nM / ntot, label=Mlabel, lw=lw, color="r")
    ax.plot(xs, _curve(dm) * nm / ntot, label=mlabel, lw=lw, color="g")
    ax.plot(xs, _curve(do) * no / ntot, label=olabel, lw=lw, color="b")
    ax.plot(xs, _curve(dtot), label=totlabel, lw=lw, color="black")

    ax.set_ylim([0, 1.15 * ax.get_ylim()[1]])
        

def _drop_zeros(values):
    """Return ``values`` as a NumPy array with exact-zero entries removed."""
    arr = np.asarray(values)
    return arr[arr != 0]


def _sum_merger_delta_lambda(merger, maj_ratio):
    """Sum a merger record's delta-lambda over major and minor events.

    Only events with ``dl > -1`` are counted. Presumably lower values are a
    "no measurement" sentinel (TODO confirm). An event is major when
    ``mr < maj_ratio`` and minor when ``mr > maj_ratio``. An event with
    ``mr == maj_ratio`` counts as neither.

    Returns ``(major_sum, minor_sum)``.
    """
    major = 0
    minor = 0
    for dl, mr in zip(merger.delta_l, merger.mr):
        if not dl > -1:
            continue
        if mr < maj_ratio:
            major += dl
        elif mr > maj_ratio:
            minor += dl
    return major, minor


def kde_sci(mpgs,
            mstar_cut_hard=5e9,
            mcut=1e10,
            fname="figs/test",
            wdir='./',
            nbins=21,
            kde=True,
            hist=True,
            shade=True,
            norm_hist=False,
            pallette="muted",
            ylim=None,
            per_event=True,
            per_galaxy=True,
            detected=True,
            maj_ratio=4,
            excess=True,
            img_scale=1.0):
    """Plot delta-lambda KDEs by merger type for all, massive and less massive galaxies.

    Only galaxies with ``gal.data["mstar"][0] > mstar_cut_hard`` are used.
    For each one:

    * The total delta-lambda is the mean of the first five
      ``gal.data['lambda_r']`` entries minus the mean of the last five.
    * If the galaxy has a non-None ``merger`` attribute, its per-event
      ``delta_l`` values are summed into major and minor contributions
      according to the mass ratio ``mr`` and ``maj_ratio``.
    * Whatever remains of the total is attributed to "other".

    Galaxies are split at ``mcut`` into a large and a small sample. Major
    and minor samples keep only galaxies with a non-zero contribution. A
    three-panel figure (all / large / small) is drawn with ``draw_kdes``
    and saved as ``fname + "<log10(mcut):.1f>"`` in PNG, PDF, EPS and SVG.

    Parameters
    ----------
    mpgs : iterable
        Galaxy objects exposing ``.data["mstar"]``, ``.data["lambda_r"]``
        and optionally ``.merger`` (with ``.delta_l`` and ``.mr``).
    mstar_cut_hard : float
        Galaxies at or below this stellar mass are ignored.
    mcut : float
        Mass dividing the large (``> mcut``) and small samples.
    fname : str
        Output path prefix.
    maj_ratio : float
        Mass-ratio threshold separating major from minor mergers.
    excess : bool
        Forwarded to ``draw_kdes``.
    img_scale : float
        Multiplier applied to all font sizes.
    wdir, nbins, kde, hist, shade, norm_hist, pallette, ylim, per_event, per_galaxy, detected
        Unused; kept for call-site compatibility.
    """
    from matplotlib.ticker import NullFormatter

    fontsize_ticks = 6 * img_scale
    fontsize_tick_label = 8 * img_scale
    fontsize_legend = 5 * img_scale

    # Per-event merger mass ratios and per-galaxy delta-lambda
    # (total / other / major / minor) for the large (l_) and small (s_) bins.
    l_mr_e, s_mr_e = [], []
    l_dlt_g, l_dlo_g, l_dlM_g, l_dlm_g = [], [], [], []
    s_dlt_g, s_dlo_g, s_dlM_g, s_dlm_g = [], [], [], []

    for gal in mpgs:
        mgal = gal.data["mstar"][0]
        if not mgal > mstar_cut_hard:
            continue
        if mgal > mcut:
            mr_e = l_mr_e
            dlt_g, dlo_g, dlM_g, dlm_g = l_dlt_g, l_dlo_g, l_dlM_g, l_dlm_g
        else:
            mr_e = s_mr_e
            dlt_g, dlo_g, dlM_g, dlm_g = s_dlt_g, s_dlo_g, s_dlM_g, s_dlm_g

        lambda_r = gal.data['lambda_r']
        delta_lambda_tot = np.average(lambda_r[:5]) - np.average(lambda_r[-5:])
        delta_lambda_major = 0
        delta_lambda_minor = 0

        merger = getattr(gal, "merger", None)
        if merger is not None:
            mr_e.extend(merger.mr)
            delta_lambda_major, delta_lambda_minor = _sum_merger_delta_lambda(merger, maj_ratio)

        # Both mass bins record every galaxy above the hard cut, whether or
        # not it carries merger information.
        dlt_g.append(delta_lambda_tot)
        dlo_g.append(delta_lambda_tot - delta_lambda_major - delta_lambda_minor)
        dlM_g.append(delta_lambda_major)
        dlm_g.append(delta_lambda_minor)

    l_dlt_g, l_dlo_g = np.array(l_dlt_g), np.array(l_dlo_g)
    s_dlt_g, s_dlo_g = np.array(s_dlt_g), np.array(s_dlo_g)
    # "Detected": only galaxies with a non-zero major / minor contribution.
    l_dlM_g, l_dlm_g = _drop_zeros(l_dlM_g), _drop_zeros(l_dlm_g)
    s_dlM_g, s_dlm_g = _drop_zeros(s_dlM_g), _drop_zeros(s_dlm_g)
    l_mr_e, s_mr_e = np.array(l_mr_e), np.array(s_mr_e)

    fig, axs = plt.subplots(3, sharex=True)
    fig.set_size_inches(4.75, 7)
    fig.subplots_adjust(hspace=0.01)

    # Per panel: (major, minor, other, total, merger mass ratios).
    panels = [
        (np.concatenate((l_dlM_g, s_dlM_g)),
         np.concatenate((l_dlm_g, s_dlm_g)),
         np.concatenate((l_dlo_g, s_dlo_g)),
         np.concatenate((l_dlt_g, s_dlt_g)),
         np.concatenate((l_mr_e, s_mr_e))),
        (l_dlM_g, l_dlm_g, l_dlo_g, l_dlt_g, l_mr_e),
        (s_dlM_g, s_dlm_g, s_dlo_g, s_dlt_g, s_mr_e),
    ]
    for ax, (dlM, dlm, dlo, dlt, mr_e) in zip(axs, panels):
        nevents = [np.count_nonzero(mr_e < maj_ratio),
                   np.count_nonzero(mr_e > maj_ratio),
                   len(dlo)]
        draw_kdes(dlM, dlm, dlo, dlt, ax, nevents, excess=excess)

    for ax in axs:
        ax.xaxis.grid()
        leg = ax.legend(fontsize=fontsize_legend)
        leg.get_frame().set_alpha(0.5)
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.set_ylabel("relative probability", fontsize=fontsize_tick_label)

    axs[2].set_xlabel(r"$\Delta \lambda_{R_{eff}}$", fontsize=fontsize_tick_label)
    axs[2].tick_params(labelsize=fontsize_ticks)
    # The panels share x, so this limit applies to all three.
    axs[2].set_xlim([-0.7, 0.6])

    log_mcut = "{:.1f}".format(np.log10(mcut))
    panel_tags = [("(A)", "All"),
                  ("(B) ", r"$log_{10}M_{\star} > $ " + log_mcut),
                  ("(C) ", r"$log_{10}M_{\star} < $ " + log_mcut)]
    for ax, (tag, desc) in zip(axs, panel_tags):
        ax.text(0.05, 0.87, tag, weight="bold", transform=ax.transAxes, fontsize=fontsize_ticks)
        ax.text(0.15, 0.87, desc, transform=ax.transAxes, fontsize=fontsize_ticks)

    out = fname + log_mcut
    fig.savefig(out + ".png", dpi=200, bbox_inches="tight")
    for ext in (".pdf", ".eps", ".svg"):  # eps does NOT support transparency!
        fig.savefig(out + ext, bbox_inches="tight")

    plt.close(fig)
    

# In [None]:
def plot_density_map(axmain, x, y, xmin, xmax, ymin, ymax,
                     levels=None, color=True, cmap="winter",
                     surf=False, bw_method="silverman",
                     d_alpha=1.0, npts=100):
    """Draw a 2-D Gaussian-KDE density map of ``(x, y)`` on ``axmain``.

    A ``scipy.stats.gaussian_kde`` fit (with ``bw_method``) is evaluated on
    an ``npts`` x ``npts`` grid spanning [xmin, xmax] x [ymin, ymax]. The
    result is scaled so its maximum on the grid is 1 before contouring.

    Parameters
    ----------
    axmain : matplotlib Axes
        Target axes.
    x, y : array-like
        Sample coordinates.
    xmin, xmax, ymin, ymax : float
        Extent of the evaluation grid.
    levels :
        Forwarded to ``contour`` / ``contourf``.
    color : bool
        Unused; kept for call-site compatibility.
    cmap :
        Colormap forwarded to matplotlib.
    surf : bool
        If True, draw filled contours (``contourf``). Otherwise draw contour
        lines (``contour``) with line width 0.6.
    bw_method :
        Bandwidth selector passed to ``gaussian_kde``.
    d_alpha : float
        Contour transparency.
    npts : int, default 100
        Number of grid points per axis.

    Returns
    -------
    The object returned by ``axmain.contourf`` or ``axmain.contour``.
    """
    import scipy.stats as st

    # A complex step in mgrid means "this many points, inclusive of both ends".
    xx, yy = np.mgrid[xmin:xmax:complex(0, npts), ymin:ymax:complex(0, npts)]
    positions = np.vstack([xx.ravel(), yy.ravel()])

    kernel = st.gaussian_kde(np.vstack([x, y]), bw_method=bw_method)
    f = kernel(positions).reshape(xx.shape)
    f /= f.max()

    if surf:
        return axmain.contourf(xx, yy, f,
                               levels=levels,
                               cmap=cmap,
                               alpha=d_alpha)
    return axmain.contour(xx, yy, f,
                          levels=levels,
                          cmap=cmap,
                          alpha=d_alpha,
                          linewidths=0.6)
def plot_sami(ax, data, contour=True, scatter=False):
    """Overlay the sample in ``data`` on ``ax``.

    ``data['ellp']`` is used as x and ``data['r1']`` as y. The density map
    is evaluated over [-0.05, 0.8] on both axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    data : mapping / structured array / DataFrame
        Must provide ``'ellp'`` and ``'r1'`` columns.
    contour : bool
        Draw KDE contour lines via ``plot_density_map`` using the "spring"
        colormap.
    scatter : bool
        Also scatter-plot the raw points.
    """
    x = data['ellp']
    y = data['r1']

    xmin, xmax = -0.05, 0.8
    ymin, ymax = -0.05, 0.8

    if contour:
        # levels=None lets matplotlib choose the contour levels.
        plot_density_map(ax, x, y, xmin, xmax, ymin, ymax,
                         levels=None, color=False, cmap="spring")
    if scatter:
        ax.scatter(x, y)

def density_map(x, y, sort=True):
    """Evaluate a 2-D Gaussian KDE of ``(x, y)`` at the sample points.

    Parameters
    ----------
    x, y : array-like
        Sample coordinates of equal length. They are converted to NumPy
        arrays, so reordering is always positional. This also holds for
        pandas Series with a non-default index.
    sort : bool, default True
        If True, reorder all outputs by increasing density. In a scatter plot
        this draws the densest points last. If False, keep the input order.

    Returns
    -------
    xx, yy, z : numpy.ndarray
        Coordinates and their KDE density, scaled so ``z.max() == 1``.
    """
    from scipy.stats import gaussian_kde
    x = np.asarray(x)
    y = np.asarray(y)
    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    z /= z.max()

    if not sort:
        return x, y, z

    idx = z.argsort()
    return x[idx], y[idx], z[idx]

def do_plot(x, y,
            do_scatter=True,
            contour_label=False,
            surf=False,
            img_scale=1.0,
            twocolors=('#4c72b0', '#c44e52'),
            den_cmap="PuBu",
            levels=None,
            fname_vs_e="./figs/lambda_vs_e_z0",
            d_alpha=1.0
            ):
    """Plot lambda_Re against epsilon_Re with density shading and the ATLAS3D sample.

    The figure layers, in order:

    * the dashed demarcation line ``0.31 * sqrt(eps)``;
    * ``(x, y)`` coloured by their own KDE density (``density_map``);
    * optionally a plain scatter of ``(x, y)``;
    * KDE contour lines (``plot_density_map``, "winter" colormap);
    * the module-level ``atlas`` comparison sample.

    The figure is saved as ``<fname_vs_e>[_srf|_cl]_<nlevels>[_sct].png``.

    Parameters
    ----------
    x, y : array-like
        Ellipticity and lambda values of this work's sample.
    do_scatter : bool
        Add a plain scatter of the points and the "_sct" filename suffix.
    contour_label : bool
        Label contour lines. Adds "_cl" to the filename unless ``surf``.
    surf : bool
        Only adds the "_srf" filename suffix; the contours are always drawn
        as lines.
    img_scale : float
        Multiplier for figure size and font sizes.
    twocolors : sequence of two colours
        Plain-scatter colour and ATLAS3D colour.
    den_cmap :
        Colormap for the density-coloured scatter.
    levels :
        Contour levels. Its length (or "default") goes into the filename.
    fname_vs_e : str
        Output path prefix.
    d_alpha : float
        Contour transparency.
    """
    fontsize_ticks = 6 * img_scale
    fontsize_tick_label = 8 * img_scale
    fontsize_legend = 5 * img_scale
    img_size_single_column = 2.25 * img_scale

    xmin = ymin = -0.05
    xmax = ymax = 0.9

    fig, axmain = plt.subplots(1)
    fig.set_size_inches(img_size_single_column, img_size_single_column)

    axmain.set_xlim(xmin, xmax)
    axmain.set_ylim(ymin, ymax)
    # arange excludes its end point, which suppresses the last tick.
    axmain.set_xticks(np.arange(0, xmax, 0.1))
    axmain.set_yticks(np.arange(0, ymax, 0.1))
    axmain.set_xlabel(r"$\epsilon_{R_{e}}$", fontsize=fontsize_tick_label)
    axmain.set_ylabel(r"$\lambda_{R_{e}}$", fontsize=fontsize_tick_label)
    axmain.tick_params(axis='both', which='major', labelsize=fontsize_ticks)

    # S/R demarcation line: simple demarcation (Emsellem 2011).
    sr_line_xx = np.arange(90) * 0.01
    axmain.plot(sr_line_xx, 0.31 * np.sqrt(sr_line_xx), '--', lw=1, color='black')

    if surf:
        fname_vs_e += "_srf"
    elif contour_label:
        fname_vs_e += "_cl"
    try:
        nlevels = str(len(levels))
    except TypeError:  # levels is None or a plain number of levels
        nlevels = "default"
    fname_vs_e += "_" + nlevels

    # Points coloured by their KDE density, densest drawn last.
    xx, yy, z = density_map(x, y)
    axmain.scatter(xx, yy, c=z, s=15, edgecolor='none',
                   cmap=den_cmap, rasterized=True,
                   alpha=1.0, label="This work")

    if do_scatter:
        axmain.scatter(x, y, s=7,
                       facecolor=twocolors[0],
                       edgecolor='none',
                       alpha=0.7,
                       label="This work")
        fname_vs_e += "_sct"

    cfset = plot_density_map(axmain, x, y, xmin, xmax, ymin, ymax,
                             levels=levels,
                             cmap="winter",
                             surf=False,
                             d_alpha=d_alpha)
    if contour_label:
        axmain.clabel(cfset, inline=1, fontsize=7)

    # ATLAS3D comparison sample. `atlas` is a module-level array defined
    # outside this function; presumably column 0 is x and column 1 is y
    # (TODO confirm).
    axmain.scatter(atlas[:, 0], atlas[:, 1],
                   s=20,
                   color=twocolors[1],
                   marker=".",
                   lw=1,
                   alpha=0.8,
                   label=r"$ATLAS^{3D}$")

    axmain.legend(loc=2,
                  borderaxespad=0.,
                  labelspacing=1.2,
                  fontsize=fontsize_legend)

    fig.savefig(fname_vs_e + ".png", bbox_inches='tight', dpi=200)
    plt.close(fig)

def truncate_colormap(cmap_name, minval=0.0, maxval=1.0, n=100):
    """Build a new colormap from the ``[minval, maxval]`` slice of an existing one.

    The source colormap (looked up with ``plt.get_cmap``) is sampled at
    ``n`` evenly spaced positions between ``minval`` and ``maxval``. The
    samples become a ``LinearSegmentedColormap`` named
    ``trunc(<source name>,<minval>,<maxval>)``, with both bounds shown to
    two decimals.
    """
    from matplotlib.colors import LinearSegmentedColormap

    source = plt.get_cmap(cmap_name)
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=source.name, a=minval, b=maxval)
    sampled = source(np.linspace(minval, maxval, n))
    return LinearSegmentedColormap.from_list(label, sampled)