In [1]:
import numpy as np

In [2]:
import numpy as np
import bokeh
from bokeh.plotting import figure, show
from bokeh.layouts import layout
from bokeh.models import Image, ColumnDataSource, Slider, CustomJS
from bokeh.layouts import row
from bokeh.util.hex import axial_to_cartesian
from bokeh.util.hex import hexbin
from bokeh.io import output_notebook
from bokeh.transform import linear_cmap
from PIL import Image
import matplotlib.colors
import time
from mike_code import LaguerreAmplitudes #, _n_m, _G_n
import matplotlib.cm as cm
from skimage.transform import resize_local_mean
from bokeh.models import LogColorMapper, LinearColorMapper
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from bokeh.models import BasicTicker, PrintfTickFormatter

output_notebook()

In [3]:
#getting min/max bounds- helpful for future plots
def find_vmin_vmax(min, max):
    if abs(min) < abs(max):
        min = -abs(max)
    elif abs(min)>abs(max):
        max = abs(min)
    return min, max
    
def beef_it(image_path, rescale_factor, rscale, recon_lims, save_loc = None, savestr = None, plot=False):
    #feed in a path to your image, the factor by which you want to scale up the image (to avoid issues with pixels, just makes image bigger)
    #the rscale for the expansion, the radius (in OG pixels) outside of which you don't want to reconstruct (a bit bigger than the galaxy)
    # and the save location

    #### need to build this out more but this is an ok start
    try: 
        im1 = np.asarray(Image.open(image_path))
    except:
        print('uhoh - no image to open, try a different path')
    if plot == True:
        fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(9,3))
        ax1.imshow(im1[:,:,0], cmap='binary')
        ax2.imshow(im1[:,:,1], cmap='binary')
        ax3.imshow(im1[:,:,2], cmap='binary')
        ax1.set(title='channel 0')
        ax2.set(title='channel 1')
        ax3.set(title='channel 2')

    #############################
    ####### rescaling
    #############################
    av_im1 = np.mean(im1, axis=2)
    s = rescale_factor #we want to size this image up by a factor of s
    av_im1 = resize_local_mean(av_im1, (int(s*av_im1.shape[0]),int(s*av_im1.shape[0])))
    
    #############################
    ########  BFE 
    #############################
    xp = np.linspace(-av_im1.shape[0]/2.,av_im1.shape[0]/2.,av_im1.shape[0]) #getting x, y array
    xpix,ypix = np.meshgrid(xp,xp)
    rr,pp = np.sqrt(xpix**2+ypix**2),np.arctan2(ypix,xpix) #transforming to r, phi
        
        
    rval = np.sqrt(xpix**2+ypix**2).reshape(-1,)
    phi  = np.arctan2(ypix,xpix).reshape(-1,)
    av_im1flat = av_im1.reshape(-1,).copy() #need a copy
    
    # pick a radius of the image where we ignore everything outside of that radius
    gvals = np.where(rval>recon_lims*s)
        
    rval[gvals]         = np.nan
    phi[gvals]          = np.nan
    av_im1flat[gvals] = np.nan
    
    #
    # pick orders for the expansion
    mmax = 8  
    nmax = 12 #twelve makes the mapping to notes easier #8, 12
        
    # pick a scalelength for the reconstruction
    rscl = rscale*s ### 5 is p good
    
    # make the expansion and compute the weights
    # input into LaguerreAmplitudes is rscl, mmax, nmax, R, phi, mass=1., velocity=1.
    LG = LaguerreAmplitudes(rscl,mmax,nmax,rval,phi,av_im1flat) #this computes everything
    cos_cos, sin_cos = LG.laguerre_amplitudes_returns() #this returns the cosine and sine amplitudes
          
    LG.laguerre_reconstruction(rr,pp) #reconstructing the image using the BEF, computing the brightness at each r, phi in the image
    #############################
    ######## result plots!
    #############################
    if plot == True:
        plt.figure()
        plt.title('zoom in on center region, where reconstruction tends to struggle')
        plt.imshow((LG.reconstruction)[150*s:250*s,150*s:250*s],vmin=np.nanmin(av_im1), vmax=np.nanmax(av_im1),cmap='rainbow')
        
        print('max, min, median of reconstruction:', np.max(LG.reconstruction), np.min(LG.reconstruction), np.median(LG.reconstruction))
        print('max, min, median of original image:', np.max(av_im1), np.min(av_im1), np.median(av_im1))
        # make a figure for the comparison of real to reconstruction
        fig = plt.figure(figsize=(20,10),facecolor='white')
        ax1 = fig.add_subplot(131)
        ax2 = fig.add_subplot(132)
        ax3 = fig.add_subplot(133)
        
        cval = np.linspace(-5.,1.,32)
        ax1.imshow((av_im1), cmap=plt.cm.magma)
        ax2.imshow((LG.reconstruction),vmin=np.nanmin(av_im1), vmax=np.nanmax(av_im1),cmap=cm.magma)
        
        # plot the relative uncertainty (maxed out at 100 percent)
        cb = ax3.imshow((LG.reconstruction-av_im1)/np.where(av_im1==0, np.nan, av_im1),vmin=-1,vmax=1,cmap=cm.bwr)
        
        cax = fig.add_axes([0.91, 0.25, 0.01, 0.5])
        fig.colorbar(cb, cax, orientation = 'vertical', label=r'$\frac{reconstruction - original}{original}$',extend='both') 
        
        
        ax1.set_title('image')
        ax2.set_title('reconstruction')
        ax3.set_title('relative uncertainty')
        if (savestr != None) and (save_loc != None):
            plt.savefig(str(save_loc)+str(savestr)+'.jpeg')
    else:
        print('no plots for you')
    print('returning reconstruction, cosine coefficients, and sine coefficients')
    return LG.reconstruction, cos_cos, sin_cos

In [64]:
#super helpful code from https://github.com/bokeh/bokeh/issues/2426, edited for my purpose
class RGBAColorMapper(object):
    """Maps floating point values to rgb values over a palette"""

    def __init__(self, low, high, palette):
        self.range = np.linspace(low, high, len(palette))
        color_arr = np.array([matplotlib.colors.to_rgb(i) for i in palette])
        self.r = color_arr[:,0]
        self.g = color_arr[:,1]
        self.b = color_arr[:,2]
        #self.r, self.g, self.b = np.array(zip(*[matplotlib.colors.to_rgb(i) for i in palette]))

    def color(self, data):
        """Maps your data values to the pallette with linear interpolation"""
        original = np.empty((data.shape[0], data.shape[1]),dtype=np.uint32)
        rgb_mapped_data = original.view(dtype=np.uint8).reshape((data.shape[0], data.shape[1], 4))
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                red = np.interp(data[i, j], self.range, self.r)
                blue = np.interp(data[i, j], self.range, self.b)
                green = np.interp(data[i, j], self.range, self.g)
                # Style plot to return a grey color when value is 'nan'
                if np.isnan(red) == True:
                    red = 240
                if np.isnan(blue) == True:
                    blue = 240
                if np.isnan(green) == True:
                    green = 240
                rgb_mapped_data[i, j, 0] = int(red*255)
                rgb_mapped_data[i, j, 1] = int(blue*255)
                rgb_mapped_data[i, j, 2] = int(green*255)
                rgb_mapped_data[i, j, 3] = 255 #make opaque
        return original

In [65]:
from scipy.special import eval_genlaguerre
xp = np.linspace(-100, 100, 200) #getting x, y array
xpix,ypix = np.meshgrid(xp,xp)
rr,pp = np.sqrt(xpix**2+ypix**2),np.arctan2(ypix,xpix) #transforming to r, phi

def gamma_n(nrange, rscl):
    """
    Calculate the Laguerre alpha=1 normalisation.
    Args:
        nrange (array-like): Range of order parameters.
        rscl (float): Scale parameter for the Laguerre basis.
    Returns:
        array-like: Laguerre alpha=1 normalisation values.
    """
    return (rscl / 2.) * np.sqrt(nrange + 1.)

def G_n(R, nrange, rscl):
    """
    Calculate the Laguerre basis.
    Args:
        R (array-like): Radial values.
        nrange (array-like): Range of order parameters.
        rscl (float): Scale parameter for the Laguerre basis.
    Returns:
        array-like: Laguerre basis values.
    """
    laguerrevalues = np.array([eval_genlaguerre(n, 1, 2 * R / rscl)/gamma_n(n, rscl) for n in nrange])
    return np.exp(-R / rscl) * laguerrevalues

def n_m(mmax):
    """
    Calculate the angular normalisation.
    Returns:
        array-like: Angular normalisation values.
    """
    deltam0 = np.zeros(mmax)
    deltam0[0] = 1.0
    return np.power((deltam0 + 1) * np.pi / 2., -0.5)
        
def recon_from_cos(rr, pp, cos_cos, sin_cos, mmax=8, nmax=12, rscl=5):
    nmvals = n_m(mmax)
    G_j = G_n(rr, np.arange(0, nmax, 1), 5)
    fftotal = 0.
    for m in range(0, mmax):
        for n in range(0, nmax):
            fftotal += cos_cos[m, n] * nmvals[m] * np.cos(m * pp) * G_j[n]
            fftotal += sin_cos[m, n] * nmvals[m] * np.sin(m * pp) * G_j[n]

    reconstruction = 0.5 * fftotal
    return reconstruction


In [362]:
#image_path, rescale_factor, rscale, recon_lims, save_loc = None, savestr = None, plot=False
recon, cos_cos, sin_cos = beef_it('./galaxyzoo_2_data/images/18.jpg', 3, 
                                  5, 100, save_loc = './', savestr = 'galaxy_18_recon')

returning coscoefs, sincoefs
no plots for you
returning reconstruction, cosine coefficients, and sine coefficients


In [330]:
from bokeh.layouts import column
from bokeh.models import (ColumnDataSource, DataTable, HoverTool, IntEditor,
                          NumberEditor, NumberFormatter, SelectEditor,
                          StringEditor, StringFormatter, TableColumn, CellEditor)
from bokeh.plotting import figure, show


In [382]:
#### getting min/max for colors (will need to have this update....?)
#### getting m, n arrays for grid...
from bokeh.embed import file_html
from bokeh.resources import CDN

from bokeh.io import output_file, show
from bokeh.plotting import figure

output_file("layout.html")


ms = np.arange(1,8+1e-3,1)
ns = np.arange(1,12+1e-3,1)

x_arr = np.array([])
y_arr = np.array([])
for m in ms:
    for n in ns:
        x_arr = np.append(x_arr, m)
        y_arr = np.append(y_arr, n)

src = ColumnDataSource(data={'M':x_arr,'N':y_arr,'sine_coefficients':sin_cos.reshape(-1), 
                             'cosine_coefficients':cos_cos.reshape(-1)})

columns = [
    TableColumn(field="N", title="N",
                formatter=StringFormatter(font_style="bold"),
               editor=CellEditor()),
    TableColumn(field="M", title="M",
               formatter=StringFormatter(font_style="bold"),
               editor=CellEditor()),
    TableColumn(field="cosine_coefficients", title="Cosine Coefficient",
                editor=NumberEditor()),
  TableColumn(field="sine_coefficients", title="Sine Coefficient",
                editor=NumberEditor())]
data_table = DataTable(source=src, columns=columns, editable=True, width=600,height=300)


tooltips = [
    ("Cosine Coefficient", "@cosine_coefficients"),
    ("Coefficient", "@sine_coefficients"),
]

min_max_cos = (cos_cos.copy()).flatten()
min_max_sin = (sin_cos.copy()).flatten()
#getting minimum, maximums for coefficients (helpful in setting color range)
min_cos, max_cos = np.min(min_max_cos), np.max(min_max_cos)
min_sin, max_sin = np.min(min_max_sin), np.max(min_max_sin)

vmin_cosine, vmax_cosine = find_vmin_vmax(min_cos, max_cos)
vmin_sine, vmax_sine = find_vmin_vmax(min_sin, max_sin)


#####    cosine
p1 = figure(title="Cosine Heatmap", width=300, height=400, tools="wheel_zoom,xbox_select,reset")
r1 = p1.rect(x="M", y="N", width=1, height=1, source=src,
           fill_color=linear_cmap("cosine_coefficients", bokeh.palettes.BuRd9, low=vmin_cosine, high=vmax_cosine),
           line_color=None)
p1.add_layout(r1.construct_color_bar(
    major_label_text_font_size="7px",
   label_standoff=6,
    border_line_color=None,
    padding=5,
   ), 'below')

#### sine
p2 = figure(title="Sine Heatmap", width=300, height=400, tools="wheel_zoom,xbox_select,reset")
r2 = p2.rect(x="M", y="N", width=1, height=1, source=src,
           fill_color=linear_cmap("sine_coefficients", bokeh.palettes.BuRd9, low=vmin_sine, high=vmax_sine),
           line_color=None)
p2.add_layout(r2.construct_color_bar(
    major_label_text_font_size="7px",
   label_standoff=6,
    border_line_color=None,
    padding=5,
   ), 'below')


p1.grid.grid_line_color = p2.grid.grid_line_color= None
p1.axis.axis_line_color = p2.axis.axis_line_color = None
p1.axis.major_tick_line_color = p2.axis.major_tick_line_color = 'black'
p1.axis.minor_tick_line_alpha = p2.axis.minor_tick_line_alpha = 0
p1.axis.major_label_text_font_size = p2.axis.major_label_text_font_size = "24px"
p1.axis.major_label_standoff = p2.axis.major_label_standoff = 0
p1.xaxis.ticker = p2.xaxis.ticker = [1, 2, 3, 4, 5, 6, 7, 8]
p1.yaxis.ticker = p2.yaxis.ticker = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
p1.xaxis.major_label_orientation = p2.xaxis.major_label_orientation = np.pi / 3

#initialize a column datasource and assign first image into it

xp = np.linspace(-100, 100, 200) #getting x, y array
xpix,ypix = np.meshgrid(xp,xp)
rr,pp = np.sqrt(xpix**2+ypix**2),np.arctan2(ypix,xpix) #transforming to r, phi

recon = recon_from_cos(rr, pp, src.data['cosine_coefficients'].reshape(8, 12), 
                       src.data['sine_coefficients'].reshape(8, 12), 
                       mmax=8, nmax=12, rscl=5)
colormap = RGBAColorMapper(np.min(recon), np.max(recon), bokeh.palettes.Greys256)
rgba_img = colormap.color(recon)

p3 = figure(title="Galaxy Reconstruction", width=600, height=600)
#src = ColumnDataSource(data={'x':[0],'y':[0],'dw':[10],'dh':[10],'im':[recon]})
#p3.image(image='im', x='x', y='y', dw='dw', dh='dh', palette="TolRainbow23", source=src)
p3.image_rgba(image=[rgba_img],  x=[0], y=[0], dw=[10], dh=[10])
p3.grid.grid_line_width = 0.0
p3.xaxis.visible = False
p3.yaxis.visible = False



show(layout([
    [p1, p2],
    [data_table],
    [p3],
    ]))
#html = file_html(layout([
#    [p1, p2],
#    [data_table],
#    [p3],
#    ]), CDN, "test")


https://docs.bokeh.org/en/3.0.1/docs/user_guide/interaction/linking.html
https://discourse.bokeh.org/t/dynamically-updating-multiple-2d-numpy-arrays-in-image-glyph/8799/2
https://docs.bokeh.org/en/latest/docs/user_guide/output/embed.html


In [377]:
def full_shebang(rr, pp, cos_cos, sin_cos):
    recon = recon_from_cos(rr, pp, cos_cos, 
                       sin_cos, 
                       mmax=8, nmax=12, rscl=5)
    colormap = RGBAColorMapper(np.min(recon), np.max(recon), bokeh.palettes.Greys256)
    rgba_img = colormap.color(recon)
    return rgba_img

In [381]:
#### getting min/max for colors (will need to have this update....?)
#### getting m, n arrays for grid...
from bokeh.embed import file_html
from bokeh.resources import CDN

from bokeh.io import output_file, show
from bokeh.plotting import figure

output_file("layout.html")


ms = np.arange(1,8+1e-3,1)
ns = np.arange(1,12+1e-3,1)

x_arr = np.array([])
y_arr = np.array([])
for m in ms:
    for n in ns:
        x_arr = np.append(x_arr, m)
        y_arr = np.append(y_arr, n)


src = ColumnDataSource(data={'M':x_arr,'N':y_arr,'sine_coefficients':sin_cos.reshape(-1), 
                             'cosine_coefficients':cos_cos.reshape(-1)})

columns = [
    TableColumn(field="N", title="N",
                formatter=StringFormatter(font_style="bold"),
               editor=CellEditor()),
    TableColumn(field="M", title="M",
               formatter=StringFormatter(font_style="bold"),
               editor=CellEditor()),
    TableColumn(field="cosine_coefficients", title="Cosine Coefficient",
                editor=NumberEditor()),
  TableColumn(field="sine_coefficients", title="Sine Coefficient",
                editor=NumberEditor())]
data_table = DataTable(source=src, columns=columns, editable=True, width=600,height=300)


tooltips = [
    ("Cosine Coefficient", "@cosine_coefficients"),
    ("Coefficient", "@sine_coefficients"),
]

min_max_cos = (cos_cos.copy()).flatten()
min_max_sin = (sin_cos.copy()).flatten()
#getting minimum, maximums for coefficients (helpful in setting color range)
min_cos, max_cos = np.min(min_max_cos), np.max(min_max_cos)
min_sin, max_sin = np.min(min_max_sin), np.max(min_max_sin)

vmin_cosine, vmax_cosine = find_vmin_vmax(min_cos, max_cos)
vmin_sine, vmax_sine = find_vmin_vmax(min_sin, max_sin)


#####    cosine
p1 = figure(title="Cosine Heatmap", width=300, height=400, tools="wheel_zoom,xbox_select,reset")
r1 = p1.rect(x="M", y="N", width=1, height=1, source=src,
           fill_color=linear_cmap("cosine_coefficients", bokeh.palettes.BuRd9, low=vmin_cosine, high=vmax_cosine),
           line_color=None)
p1.add_layout(r1.construct_color_bar(
    major_label_text_font_size="7px",
   label_standoff=6,
    border_line_color=None,
    padding=5,
   ), 'below')

#### sine
p2 = figure(title="Sine Heatmap", width=300, height=400, tools="wheel_zoom,xbox_select,reset")
r2 = p2.rect(x="M", y="N", width=1, height=1, source=src,
           fill_color=linear_cmap("sine_coefficients", bokeh.palettes.BuRd9, low=vmin_sine, high=vmax_sine),
           line_color=None)
p2.add_layout(r2.construct_color_bar(
    major_label_text_font_size="7px",
   label_standoff=6,
    border_line_color=None,
    padding=5,
   ), 'below')


p1.grid.grid_line_color = p2.grid.grid_line_color= None
p1.axis.axis_line_color = p2.axis.axis_line_color = None
p1.axis.major_tick_line_color = p2.axis.major_tick_line_color = 'black'
p1.axis.minor_tick_line_alpha = p2.axis.minor_tick_line_alpha = 0
p1.axis.major_label_text_font_size = p2.axis.major_label_text_font_size = "24px"
p1.axis.major_label_standoff = p2.axis.major_label_standoff = 0
p1.xaxis.ticker = p2.xaxis.ticker = [1, 2, 3, 4, 5, 6, 7, 8]
p1.yaxis.ticker = p2.yaxis.ticker = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
p1.xaxis.major_label_orientation = p2.xaxis.major_label_orientation = np.pi / 3

#initialize a column datasource and assign first image into it

xp = np.linspace(-100, 100, 200) #getting x, y array
xpix,ypix = np.meshgrid(xp,xp)
rr,pp = np.sqrt(xpix**2+ypix**2),np.arctan2(ypix,xpix) #transforming to r, phi

src2 = ColumnDataSource(data={'sine_coefficients':data_table.source.data['sine_coefficients'], 
                             'cosine_coefficients':data_table.source.data['cosine_coefficients']})
recon = recon_from_cos(rr, pp, src.data['cosine_coefficients'].reshape(8, 12), 
                       src.data['sine_coefficients'].reshape(8, 12), 
                       mmax=8, nmax=12, rscl=5)
colormap = RGBAColorMapper(np.min(recon), np.max(recon), bokeh.palettes.Greys256)
rgba_img = colormap.color(recon)

#could try making new src with updated data from table?
p3 = figure(title="Galaxy Reconstruction", width=600, height=600)
#src = ColumnDataSource(data={'x':[0],'y':[0],'dw':[10],'dh':[10],'im':[recon]})
#p3.image(image='im', x='x', y='y', dw='dw', dh='dh', palette="TolRainbow23", source=src)
p3.image_rgba(image=[rgba_img],  x=[0], y=[0], dw=[10], dh=[10])
p3.grid.grid_line_width = 0.0
p3.xaxis.visible = False
p3.yaxis.visible = False



show(layout([
    [p1, p2],
    [data_table],
    [p3],
    ]))
#html = file_html(layout([
#    [p1, p2],
#    [data_table],
#    [p3],
#    ]), CDN, "test")

In [348]:
data_table.columns[3]

In [371]:
data_table.source.data['sine_coefficients'].reshape(8, 12)

array([[    0.        ,     0.        ,     0.        ,     0.        ,
            0.        ,     0.        ,     0.        ,     0.        ,
            0.        ,     0.        ,     0.        ,     0.        ],
       [  413.12548021,  -410.63519058,   154.05010524,  -365.20190529,
          745.1176104 ,  -870.50059827,   782.69017079,  -395.45973277,
         -202.56083815,   145.90443168,   409.83471918,   -69.86785596],
       [ -107.83927979,   185.30401374,  -240.80028605,   617.98443056,
        -1148.29205433,   855.207136  ,   -78.62271578,   133.63511539,
         -256.56506465,  -173.73148867,     6.49203997,   290.07704589],
       [  -52.63903702,    79.60423073,   -10.24368639,   153.56205513,
         -521.54109306,   489.22634882,    -5.35883042,  -233.03694196,
          152.48530554,  -108.12594865,    21.96826942,    85.26845264],
       [   30.31154289,   -26.46461977,   -50.78587423,    49.09305693,
          295.5616861 ,  -819.70719238,   686.195934  ,   -9

In [373]:
src.data

{'M': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 5., 5., 5.,
        5., 5., 5., 5., 5., 5., 5., 5., 5., 6., 6., 6., 6., 6., 6., 6., 6.,
        6., 6., 6., 6., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 8.,
        8., 8., 8., 8., 8., 8., 8., 8., 8., 8., 8.]),
 'N': array([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,  1.,
         2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,  1.,  2.,
         3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,  1.,  2.,  3.,
         4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,  1.,  2.,  3.,  4.,
         5.,  6.,  7.,  8.,  9., 10., 11., 12.,  1.,  2.,  3.,  4.,  5.,
         6.,  7.,  8.,  9., 10., 11., 12.,  1.,  2.,  3.,  4.,  5.,  6.,
         7.,  8.,  9., 10., 11., 12.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,
         8.,  9., 10., 11., 12.]),
 'sine_coe