In [224]:
import ipywidgets as widgets
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import IPython

def electron_wavelength(ht):
    """Return the relativistic electron wavelength in Å.

    ``ht`` is the accelerating voltage in volts (callers in this notebook
    pass ``energy * 1000`` with ``energy`` in keV).  Uses numpy operations,
    so ``ht`` may also be a numpy array.
    """
    planck = 6.626E-34      # Planck constant (J s)
    charge = 1.602E-19      # elementary charge (C)
    light_speed = 3.000E8   # speed of light (m/s)
    rest_mass = 9.109E-31   # electron rest mass (kg)

    # Non-relativistic de Broglie wavelength (metres) ...
    classical = planck / np.sqrt(2 * rest_mass * charge * ht)
    # ... divided by the relativistic correction factor.
    correction = np.sqrt(1 + charge * ht / (2 * rest_mass * light_speed * light_speed))
    # metres -> Ångström
    return classical / correction * 1E10

def ctf_1d(box_size, pixel_size, c_s, energy, q0, dz, idx):
    """Evaluate the 1-D contrast transfer function at one Fourier index.

    Parameters
    ----------
    box_size : int
        Box edge length in pixels.
    pixel_size : float
        Pixel size in Å.
    c_s : float
        Spherical aberration in mm (scaled to Å by the 1E7 factor).
    energy : float
        Electron energy in keV (scaled to volts for electron_wavelength).
    q0 : float
        Amplitude-contrast fraction; expected in [0, 1].
    dz : float
        Defocus in Å.
    idx : int
        Fourier index; spatial frequency is idx / (box_size * pixel_size).

    Returns
    -------
    float
        -sqrt(1 - q0**2) * sin(gamma) + q0 * cos(gamma), where
        gamma = pi/2 * Cs * lambda**3 * k**4 - pi * lambda * dz * k**2.
    """
    wavelength = electron_wavelength(energy * 1000)
    pc = np.sqrt(1 - q0 * q0)                     # phase-contrast fraction
    K1 = np.pi / 2 * c_s * 1E7 * wavelength**3    # spherical-aberration term
    K2 = np.pi * wavelength                       # defocus term

    k = idx / (box_size * pixel_size)
    gamma = K1 * k**4 - K2 * k**2 * dz
    return -pc * np.sin(gamma) + q0 * np.cos(gamma)

def ctf_thresh(box_size, pixel_size, c_s, energy, q0, dz):
    """Return the value plotted as the 'Resolution Threshold for aliasing' (Å).

    Builds the polynomial in x

        0.5 * Cs * lambda**3 * x**4 - lambda * dz * x**2
            + pixel_size * (box_size // 2) * x + phi = 0,

    with phi = arctan(q0 / sqrt(1 - q0**2)), and returns the reciprocal of
    the root at position 2 of ``np.roots``' output when that reciprocal is
    real and positive.  Returns 0 otherwise, including when the polynomial
    has fewer than three roots.

    Units follow ctf_1d: pixel_size and dz in Å, c_s in mm, energy in keV.
    """
    wavelength = electron_wavelength(energy * 1000)
    # Same angle as arctan(q0 / sqrt(1 - q0**2)), without dividing by zero
    # when q0 == 1 (the amplitude-contrast widget allows 1).
    phase_shift = np.arctan2(q0, np.sqrt(1 - q0**2))
    poly = [wavelength**3 * c_s * 1E7 * 0.5,
            0,
            wavelength * -dz,
            pixel_size * (box_size // 2),
            phase_shift]

    roots = np.roots(poly)
    # np.roots discards leading zero coefficients, so e.g. c_s == 0 (allowed
    # by the Cs widget) yields a lower-degree polynomial with fewer than three
    # roots; indexing [2] would raise IndexError.  Treat as "no threshold".
    if len(roots) < 3:
        return 0

    # NOTE(review): np.roots does not document the order of its output, so the
    # choice of index 2 is fragile — confirm this selects the intended root.
    root = 1 / roots[2]
    if np.real(root) > 0 and np.imag(root) == 0:
        return np.real(root)
    return 0
    
def min_box(pixel_size, c_s, energy, q0, dz):
    """Return the value plotted as the 'Minimum Boxsize to avoid aliasing' (px).

    Computes

        lambda * dz / p**2 - lambda**3 * Cs / (8 * p**4)
            - (4 / pi) * arctan(q0 / sqrt(1 - q0**2))

    with p = pixel_size (Å), Cs = c_s (mm) * 1E7, dz in Å and lambda from
    electron_wavelength(energy * 1000) with energy in keV.
    """
    wavelength = electron_wavelength(energy * 1000)
    # Same angle as arctan(q0 / sqrt(1 - q0**2)), without dividing by zero
    # when q0 == 1 (the amplitude-contrast widget allows 1).
    phase_shift = np.arctan2(q0, np.sqrt(1 - q0**2))
    return ((wavelength * dz / pixel_size**2) -
            (wavelength**3 * c_s * 1E7 / (8 * pixel_size**4)) -
            ((4 / np.pi) * phase_shift))

def deloc(box_size, pixel_size, c_s, energy, dz, idx):
    """Return the delocalisation, in pixels, at one Fourier index.

    With k = idx / (box_size * pixel_size) (1/Å) and lambda the electron
    wavelength (Å), this is

        lambda * k * (|dz| + lambda**2 * k**2 * Cs) / pixel_size

    where Cs = c_s (mm) * 1E7 and dz is in Å; energy is in keV.
    """
    lam = electron_wavelength(energy * 1000)
    freq = idx / (box_size * pixel_size)

    spread = abs(dz) + (lam**2 * freq**2 * c_s * 1E7)
    shift_angstrom = (lam * freq) * spread
    return shift_angstrom / pixel_size

# Shared style for the text inputs: a fixed-width description column so
# the labels line up.
bT_style = {'description_width' : '150px'}

# Box size in pixels: a bounded text box plus a slider with the same
# bounds/step, kept in sync by widgets.jslink below.
box_size_ = widgets.BoundedIntText(
    value = 128,
    min = 32,
    max = 4096,
    step = 2,
    description = 'Box size (px): ',
    style = bT_style,
    disabled = False
)

box_size__ = widgets.IntSlider(
    value = 128,
    min = 32,
    max = 4096,
    step = 2,
    description = '',
    disabled = False,
    continuous_update = True,
    orientation = 'horizontal',
    readout = False
)

box_size_ui = widgets.HBox([box_size_, box_size__])
box_size_link = widgets.jslink((box_size_, 'value'),
                               (box_size__, 'value'))

# Pixel size in Å: text box + slider pair, linked like the box size.
pixel_size_ = widgets.BoundedFloatText(
    value = 1.0,
    min = 0.01,
    max = 5,
    step = 0.01,
    description = 'Pixel size (Å): ',
    style = bT_style,
    disabled = False
)

pixel_size__ = widgets.FloatSlider(
    value = 1.0,
    min = 0.01,
    max = 5,
    step = 0.01,
    description = '',
    disabled = False,
    continuous_update = True,
    orientation = 'horizontal',
    readout = False
)

pixel_size_ui = widgets.HBox([pixel_size_, pixel_size__])
pixel_size_link = widgets.jslink((pixel_size_, 'value'),
                                 (pixel_size__, 'value'))

# Spherical aberration in mm.  No explicit max is given, so the widget's
# default upper bound applies.
c_s_ = widgets.BoundedFloatText(
    value=2.7,
    min=0,
    step = 0.1,
    description = 'Cs (mm): ',
    style = bT_style,
    disabled = False
)

# Electron energy in keV.
energy_ = widgets.BoundedFloatText(
    value = 300.0,
    min = 1.0,
    max = 1000,
    step = 1,
    description = '{0:20s}'.format('Energy (keV): '),
    style = bT_style,
    disabled = False
)

# Amplitude-contrast fraction, 0..1.
q0_ = widgets.BoundedFloatText(
    value = 0.1,
    min = 0,
    max = 1,
    step = 0.01,
    description = 'Amplitude Contrast: ',
    style = bT_style,
    disabled = False
)

# Target resolution in Å: text box + slider pair, linked like the others.
res_ = widgets.BoundedFloatText(
    value = 4.0,
    min = 0.02,
    max = 20,
    step = 0.01,
    description = 'Resolution (Å): ',
    style = bT_style,
    disabled = False
)

res__ = widgets.FloatSlider(
    value = 4.0,
    min = 0.02,
    max = 20,
    step = 0.01,
    description = '',
    disabled = False,
    continuous_update = True,
    orientation = 'horizontal',
    readout = False
)
res_ui = widgets.HBox([res_, res__])
res_link = widgets.jslink((res_, 'value'),
                          (res__, 'value'))

# 4 x 5 grid of defocus inputs (Å).  Cell [0, 0] starts at 10000 Å; every
# other cell starts at -1, the value the plotting code treats as an unused
# slot (its traces are hidden).
def_grid = widgets.GridspecLayout(4, 5, width='600px')

for i in range(4):
    for j in range(5):
        # `readout` is a slider-only option, so it is not passed to the
        # text boxes.
        def_grid[i, j] = widgets.BoundedFloatText(
            value = 10000.0 if (i, j) == (0, 0) else -1.0,
            min = -1000000,
            max = 1000000,
            step = 1,
            disabled = False,
            layout = widgets.Layout(width = '100px')
        )

# Assemble the parameter panel and the Run / Reset buttons.
def_label = widgets.Label(value='Defocuses (Å): ')
ui = widgets.VBox([box_size_ui, pixel_size_ui, c_s_, energy_, q0_, res_ui])
run_button = widgets.Button(description='Run')
reset_button = widgets.Button(description='Reset')
buttons = widgets.HBox([run_button, reset_button])

# Snapshot of the initial widget values, used to build the start-up plots.
# run_clicked re-reads the widgets itself on every click.
box_size = box_size_.value
# Number of Fourier indices from 0 up to Nyquist.
k_max = box_size // 2 + 1
pixel_size = pixel_size_.value
c_s = c_s_.value
energy = energy_.value
q0 = q0_.value
res = res_.value

# Clamp the requested resolution to at least 2 x pixel size (Nyquist) and
# push the clamped value back into the widget.
if res < pixel_size * 2:
    res = pixel_size * 2
    res_.value = pixel_size * 2
    
# Defoci in descending order; unused grid cells hold -1.
def_list = sorted([ x.value for x in def_grid.children ], reverse = True)

# Number of samples along the x axis of the threshold and min-box plots.
n_step_thresh = 1024
n_step_minbox = 1024
    
# Spatial frequency (1/Å) of Fourier indices 0 .. k_max - 1.
ctf_x = [ x / (box_size * pixel_size) for x in range(k_max)]
# Box sizes evenly spanning the box-size widget's [min, max].
thresh_x = [box_size_.min +
            (box_size_.max - box_size_.min) / (n_step_thresh - 1) *
            x for x in range(n_step_thresh)]

# Pixel sizes evenly spanning the pixel-size widget's [min, max].
minbox_x = [pixel_size_.min +
            (pixel_size_.max - pixel_size_.min) / (n_step_minbox - 1) *
            x for x in range(n_step_minbox)]

template = 'ggplot2'
# colorway[0] marks the selected point and its guide lines; get_color
# cycles the remaining entries over the per-defocus traces.
colorway = ['#49006A',
            '#7A0177',
            '#AE017E',
            '#DD3497',
            '#F768A1',
            '#FA9FB5',
            '#FCC5C0']

def get_color(idx):
    """Return the line colour for the idx-th defocus trace.

    Cycles through colorway[1:]; colorway[0] is kept for the 'selected'
    markers and guide lines.
    """
    palette = colorway[1:]
    return palette[idx % len(palette)]

# Per-defocus CTF traces are appended below, one per def_grid cell.
ctf_plots = []
# Trace 0 of the threshold plot: open circle at the selected
# (box_size, res) point.  Per-defocus traces are appended after it.
thresh_plots = [go.Scatter(
    x = [box_size],
    y = [res],
    name = 'selected',
    showlegend = False,
    mode = 'markers',
    marker = dict(
        symbol = 'circle-open',
        size = 18,
        color = colorway[0],
        line = dict(width = 4)))]

# Trace 0 of the min-box plot: the selected (pixel_size, box_size) point.
minbox_plots = [go.Scatter(
    x = [pixel_size],
    y = [box_size],
    name = 'selected',
    showlegend = False,
    mode = 'markers',
    marker = dict(
        symbol = 'circle-open',
        size = 18,
        color = colorway[0],
        line = dict(width = 4)))]

# Fourier index whose frequency idx / (box_size * pixel_size) is nearest
# 1 / res, and the delocalisation there for the largest |defocus| entered.
deloc_pt = round(box_size * pixel_size / res)
deloc_pt_x = ctf_x[deloc_pt]
deloc_pt_y = deloc(box_size, pixel_size, c_s, energy,
                   np.max(np.abs(def_list)), deloc_pt)

# Trace 0 of the delocalisation plot: the selected point.
deloc_plots = [go.Scatter(
    x = [deloc_pt_x],
    y = [deloc_pt_y],
    name = 'selected',
    showlegend = False,
    mode = 'markers',
    marker = dict(
        symbol = 'circle-open',
        size = 18,
        color = colorway[0],
        line = dict(width = 4)))]

# Build one trace per def_grid cell for each of the four per-defocus plots,
# so run_clicked can later overwrite them in place by index.  ctf2_sum
# accumulates CTF**2 over the active defoci and ctf2_norm counts them.
ctf2_sum = [0 for x in range(k_max)]
ctf2_norm = 0

for idx, dz in enumerate(def_list):
    color = get_color(idx)
    # Defocus is entered in Å; 1 µm = 10000 Å.  Every plot uses the same label.
    label = '{0:.2f}um'.format(dz / 10000)
    # -1 marks an unused grid cell: keep a hidden all-zero placeholder trace
    # so trace indices stay aligned with the grid.
    active = dz != -1

    if active:
        ctf_y = [ctf_1d(box_size, pixel_size, c_s,
                        energy, q0, dz, x) for x in range(k_max)]
        ctf2_sum = [x[0] + x[1]**2 for x in zip(ctf2_sum, ctf_y)]
        ctf2_norm = ctf2_norm + 1

        thresh_y = [ctf_thresh(x, pixel_size, c_s,
                               energy, q0, dz) for x in thresh_x]
        minbox_y = [min_box(x, c_s, energy, q0, dz) for x in minbox_x]
        deloc_y = [deloc(box_size, pixel_size, c_s,
                         energy, dz, x) for x in range(k_max)]
    else:
        ctf_y = [0 for x in range(k_max)]
        thresh_y = [0 for x in thresh_x]
        minbox_y = [0 for x in minbox_x]
        deloc_y = [0 for x in range(k_max)]

    ctf_plots.append(go.Scatter(
        x = ctf_x,
        y = ctf_y,
        name = label,
        hoverinfo = 'name+y',
        line_shape = 'linear',
        line = dict(color = color),
        visible = active))

    thresh_plots.append(go.Scatter(
        x = thresh_x,
        y = thresh_y,
        name = label,
        hoverinfo = 'name+x+y',
        line_shape = 'linear',
        line = dict(color = color),
        visible = active))

    minbox_plots.append(go.Scatter(
        x = minbox_x,
        y = minbox_y,
        name = label,
        hoverinfo = 'name+x+y',
        line_shape = 'linear',
        line = dict(color = color),
        visible = active))

    deloc_plots.append(go.Scatter(
        x = ctf_x,
        y = deloc_y,
        name = label,
        hoverinfo = 'name+y',
        line_shape = 'linear',
        line = dict(color = color),
        visible = active))

# Tick layout shared by the frequency-axis plots: n_xticks evenly spaced
# frequencies, labelled with the matching resolution 1/k ('Inf' at k = 0).
n_xticks = 20
n_yticks = 10
ctf_x_tickvals = [ctf_x[-1] / (n_xticks - 1) * x for x in range(n_xticks)]
ctf_x_ticktext = ['Inf'] + ['{0:.2f}'.format(1 / x) for x in ctf_x_tickvals[1:]]
# CTF y ticks from -1 to 1.
ctf_y_tickvals = [ -1 + (2 / (n_yticks) * x) for x in range(n_yticks + 1)]
ctf_y_ticktext = ['{0:.2f}'.format(x) for x in ctf_y_tickvals]
ctf_fig = go.FigureWidget(
    data = ctf_plots[:],
    layout = dict(
        template = template,
        title = 'Contrast Transfer Function',
        xaxis_title = 'Resolution (Å)',
        xaxis = dict(
            tickmode = 'array',
            tickvals = ctf_x_tickvals,
            ticktext = ctf_x_ticktext),
        yaxis_title = 'CTF (a.u.)',
        yaxis = dict(
            tickmode = 'array',
            tickvals = ctf_y_tickvals,
            ticktext = ctf_y_ticktext)))

# Average CTF**2 over the active defoci; avoid dividing by zero when none
# are active.
if ctf2_norm == 0:
    ctf2_norm = 1
    
ctf2_sum = [x / ctf2_norm for x in ctf2_sum]
# y ticks from 0 to 2.
ctf2_y_tickvals = [ (2 / (n_yticks) * x) for x in range(n_yticks + 1)]
ctf2_y_ticktext = ['{0:.2f}'.format(x) for x in ctf2_y_tickvals]
ctf2_fig = go.FigureWidget(
    data = go.Scatter(
        x = ctf_x,
        y = ctf2_sum,
        name = 'cumulative CTF sq.',
        hoverinfo = 'name+y',
        line_shape = 'linear',
        line = dict(color = colorway[1]),
        visible = True),
    layout = dict(
        template = template,
        title = 'Average Cumulative Contrast Transfer Function Squared',
        xaxis_title = 'Resolution (Å)',
        xaxis = dict(
            tickmode = 'array',
            tickvals = ctf_x_tickvals,
            ticktext = ctf_x_ticktext),
        yaxis_title = 'Average Sum CTF2 (a.u.)',
        yaxis = dict(
            tickmode = 'array',
            tickvals = ctf2_y_tickvals,
            ticktext = ctf2_y_ticktext)))

# Threshold plot x axis runs from the box-size minimum to
# 2 * box_size - min, which centres it on the current box size; capped at
# the widget maximum.
thresh_range_x = 2 * box_size - box_size_.min 
if thresh_range_x > box_size_.max:
    thresh_range_x = box_size_.max
    
# y axis from 0 to twice the requested resolution, capped at the widget max.
thresh_range_y = res * 2
if thresh_range_y > res_.max:
    thresh_range_y = res_.max

thresh_x_tickvals = [box_size_.min +
                     (thresh_range_x - box_size_.min) /
                     (n_xticks - 1) * x for x in range(n_xticks)]

thresh_x_ticktext = ['{0:d}'.format(round(x)) for x in thresh_x_tickvals]
thresh_y_tickvals = [thresh_range_y / (n_yticks - 1) * x for x in range(n_yticks)]
thresh_y_ticktext = ['{0:.2f}'.format(x) for x in thresh_y_tickvals]

thresh_fig = go.FigureWidget(
    data = thresh_plots[:],
    layout = dict(
        template = template,
        title = 'Resolution Threshold for aliasing',
        xaxis_title = 'Box size (px)',
        xaxis = dict(
            range = [box_size_.min, thresh_range_x],
            tickmode = 'array',
            tickvals = thresh_x_tickvals,
            ticktext = thresh_x_ticktext),
        yaxis_title = 'Resolution (Å)',
        yaxis = dict(
            range = [0, thresh_range_y],
            tickmode = 'array',
            tickvals = thresh_y_tickvals,
            ticktext = thresh_y_ticktext)))

# Dashed guide lines to the selected (box_size, res) point:
# shapes[0] is vertical and shapes[1] horizontal; run_clicked updates them
# by these indices.
thresh_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = box_size,
    y0 = 0,
    x1 = box_size,
    y1 = res,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

thresh_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = 0,
    y0 = res,
    x1 = box_size,
    y1 = res,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

# Min-box plot x axis: 0 to twice the pixel size, capped at the widget max.
minbox_range_x = pixel_size * 2
if minbox_range_x > pixel_size_.max:
    minbox_range_x = pixel_size_.max

# y axis reuses the threshold plot's box-size range.
minbox_range_y = thresh_range_x
minbox_x_tickvals = [minbox_range_x / (n_xticks - 1) * x for x in range(n_xticks)]
minbox_x_ticktext = ['{0:.2f}'.format(x) for x in minbox_x_tickvals]
minbox_y_tickvals = [box_size_.min +
                     (minbox_range_y - box_size_.min) /
                     (n_yticks - 1) * x for x in range(n_yticks)]

minbox_y_ticktext = ['{0:.0f}'.format(round(x)) for x in minbox_y_tickvals]

minbox_fig = go.FigureWidget(
    data = minbox_plots[:],
    layout = dict(
        template = template,
        title = 'Minimum Boxsize to avoid aliasing',
        xaxis_title = 'Pixel size (Å)',
        xaxis = dict(
            range = [0, minbox_range_x],
            tickmode = 'array',
            tickvals = minbox_x_tickvals,
            ticktext = minbox_x_ticktext),
        yaxis_title = 'Box size (px)',
        yaxis = dict(
            range = [box_size_.min, minbox_range_y],
            tickmode = 'array',
            tickvals = minbox_y_tickvals,
            ticktext = minbox_y_ticktext)))

# Dashed guide lines to the selected (pixel_size, box_size) point:
# shapes[0] vertical, shapes[1] horizontal (indexed by run_clicked).
minbox_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = pixel_size,
    y0 = 0,
    x1 = pixel_size,
    y1 = box_size,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

minbox_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = 0,
    y0 = box_size,
    x1 = pixel_size,
    y1 = box_size,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

# Delocalisation plot y axis: 0 to twice the value at the selected point,
# capped at the box-size widget maximum.
deloc_range_y = (deloc_pt_y) * 2
if deloc_range_y > box_size_.max:
    deloc_range_y = box_size_.max

deloc_y_tickvals = [deloc_range_y / (n_yticks - 1) * x for x in range(n_yticks)]
deloc_y_ticktext = ['{0:.0f}'.format(round(x)) for x in deloc_y_tickvals]

deloc_fig = go.FigureWidget(
    data = deloc_plots[:],
    layout = dict(
        template = template,
        title = 'Delocalisation Due to Contrast Transfer Function',
        xaxis_title = 'Resolution (Å)',
        xaxis = dict(
            tickmode = 'array',
            tickvals = ctf_x_tickvals,
            ticktext = ctf_x_ticktext),
        yaxis_title = 'Delocalisation (px)',
        yaxis = dict(
            range = [0, deloc_range_y],
            tickmode = 'array',
            tickvals = deloc_y_tickvals,
            ticktext = deloc_y_ticktext)))
    
# Dashed guide lines to the selected point: shapes[0] vertical,
# shapes[1] horizontal (indexed by run_clicked).
deloc_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = deloc_pt_x,
    y0 = 0,
    x1 = deloc_pt_x,
    y1 = deloc_pt_y,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

deloc_fig.add_shape(
    type = "line",
    xref = "x",
    yref = "y",
    x0 = 0,
    y0 = deloc_pt_y,
    x1 = deloc_pt_x,
    y1 = deloc_pt_y,
    line = dict(
        color = colorway[0],
        width = 2,
        dash = 'dash'))

# All plots stacked vertically for display.
figures = widgets.VBox([ctf_fig,
                        ctf2_fig,
                        thresh_fig,
                        minbox_fig,
                        deloc_fig])

def _set_trace(trace, x, y, name, color, visible):
    """Overwrite one plotly trace's data, legend name, line colour and visibility."""
    trace.x = x
    trace.y = y
    trace.name = name
    trace.line.color = color
    trace.visible = visible


def run_clicked(b):
    """Recompute every plot from the current widget values.

    Callback for ``run_button``; ``b`` (the clicked button) is not used.
    The traces created at start-up are updated in place: trace ``idx`` of
    ``ctf_fig`` and trace ``idx + 1`` of ``thresh_fig``, ``minbox_fig`` and
    ``deloc_fig`` (whose trace 0 is the 'selected' marker) belong to the
    idx-th defocus in descending order.  A defocus of -1 marks an unused
    grid cell, and its traces are hidden.
    """
    box_size = box_size_.value
    k_max = box_size // 2 + 1
    pixel_size = pixel_size_.value
    c_s = c_s_.value
    energy = energy_.value
    q0 = q0_.value
    res = res_.value

    # Clamp the requested resolution to at least 2 x pixel size (Nyquist)
    # and push the clamped value back into the widget.
    if res < pixel_size * 2:
        res = pixel_size * 2
        res_.value = pixel_size * 2

    def_list = sorted([x.value for x in def_grid.children], reverse = True)

    ctf_x = [x / (box_size * pixel_size) for x in range(k_max)]
    ctf2_sum = [0 for x in range(k_max)]
    ctf2_norm = 0

    with ctf_fig.batch_update(), thresh_fig.batch_update(), minbox_fig.batch_update(), deloc_fig.batch_update():
        for idx, dz in enumerate(def_list):
            jdx = idx + 1    # offset past the 'selected' marker trace
            color = get_color(idx)
            # Same Å -> µm label the traces were created with.
            label = '{0:.2f}um'.format(dz / 10000)
            active = dz != -1

            if active:
                ctf_y = [ctf_1d(box_size, pixel_size, c_s,
                                energy, q0, dz, x) for x in range(k_max)]
                ctf2_sum = [x[0] + x[1]**2 for x in zip(ctf2_sum, ctf_y)]
                ctf2_norm = ctf2_norm + 1

                thresh_y = [ctf_thresh(x, pixel_size, c_s,
                                       energy, q0, dz) for x in thresh_x]
                minbox_y = [min_box(x, c_s, energy, q0, dz) for x in minbox_x]
                deloc_y = [deloc(box_size, pixel_size, c_s,
                                 energy, dz, x) for x in range(k_max)]
            else:
                ctf_y = [0 for x in range(k_max)]
                thresh_y = [0 for x in thresh_x]
                minbox_y = [0 for x in minbox_x]
                deloc_y = [0 for x in range(k_max)]

            _set_trace(ctf_fig.data[idx], ctf_x, ctf_y, label, color, active)
            _set_trace(thresh_fig.data[jdx], thresh_x, thresh_y, label, color, active)
            _set_trace(minbox_fig.data[jdx], minbox_x, minbox_y, label, color, active)
            _set_trace(deloc_fig.data[jdx], ctf_x, deloc_y, label, color, active)

    # Average CTF**2 over the active defoci; avoid dividing by zero when
    # none are active.
    if ctf2_norm == 0:
        ctf2_norm = 1
    ctf2_sum = [x / ctf2_norm for x in ctf2_sum]

    # Frequency-axis ticks labelled with the matching resolution 1/k.
    ctf_x_tickvals = [ctf_x[-1] / (n_xticks - 1) * x for x in range(n_xticks)]
    ctf_x_ticktext = ['Inf'] + ['{0:.2f}'.format(1 / x) for x in ctf_x_tickvals[1:]]

    with ctf_fig.batch_update():
        ctf_fig.layout.xaxis.tickvals = ctf_x_tickvals
        ctf_fig.layout.xaxis.ticktext = ctf_x_ticktext

    with ctf2_fig.batch_update():
        ctf2_fig.data[0].x = ctf_x
        ctf2_fig.data[0].y = ctf2_sum
        ctf2_fig.layout.xaxis.tickvals = ctf_x_tickvals
        ctf2_fig.layout.xaxis.ticktext = ctf_x_ticktext

    # Threshold plot: x centred on the current box size, y up to 2 x res,
    # both capped at the widget maxima.
    thresh_range_x = min(2 * box_size - box_size_.min, box_size_.max)
    thresh_range_y = min(res * 2, res_.max)

    thresh_x_tickvals = [
        box_size_.min +
        (thresh_range_x - box_size_.min) /
        (n_xticks - 1) * x for x in range(n_xticks)]
    thresh_x_ticktext = ['{0:d}'.format(round(x)) for x in thresh_x_tickvals]
    thresh_y_tickvals = [thresh_range_y / (n_yticks - 1) * x for x in range(n_yticks)]
    thresh_y_ticktext = ['{0:.2f}'.format(x) for x in thresh_y_tickvals]

    with thresh_fig.batch_update():
        # Selected point and its dashed guide lines (shapes[0] vertical,
        # shapes[1] horizontal).
        thresh_fig.data[0].x = [box_size]
        thresh_fig.data[0].y = [res]
        thresh_fig.data[0].marker.color = colorway[0]
        thresh_fig.layout.shapes[0].x0 = box_size
        thresh_fig.layout.shapes[0].x1 = box_size
        thresh_fig.layout.shapes[0].y1 = res

        thresh_fig.layout.shapes[1].y0 = res
        thresh_fig.layout.shapes[1].x1 = box_size
        thresh_fig.layout.shapes[1].y1 = res

        thresh_fig.layout.xaxis.range = [box_size_.min, thresh_range_x]
        thresh_fig.layout.xaxis.tickvals = thresh_x_tickvals
        thresh_fig.layout.xaxis.ticktext = thresh_x_ticktext
        thresh_fig.layout.yaxis.range = [0, thresh_range_y]
        thresh_fig.layout.yaxis.tickvals = thresh_y_tickvals
        thresh_fig.layout.yaxis.ticktext = thresh_y_ticktext

    # Min-box plot: x from 0 to 2 x pixel size (capped); y reuses the
    # threshold plot's box-size range.
    minbox_range_x = min(pixel_size * 2, pixel_size_.max)
    minbox_range_y = thresh_range_x
    minbox_x_tickvals = [minbox_range_x / (n_xticks - 1) * x for x in range(n_xticks)]
    minbox_x_ticktext = ['{0:.2f}'.format(x) for x in minbox_x_tickvals]
    minbox_y_tickvals = [
        box_size_.min +
        (minbox_range_y - box_size_.min) /
        (n_yticks - 1) * x for x in range(n_yticks)]
    minbox_y_ticktext = ['{0:.0f}'.format(round(x)) for x in minbox_y_tickvals]

    with minbox_fig.batch_update():
        minbox_fig.data[0].x = [pixel_size]
        minbox_fig.data[0].y = [box_size]
        minbox_fig.data[0].marker.color = colorway[0]
        minbox_fig.layout.shapes[0].x0 = pixel_size
        minbox_fig.layout.shapes[0].x1 = pixel_size
        minbox_fig.layout.shapes[0].y1 = box_size

        minbox_fig.layout.shapes[1].y0 = box_size
        minbox_fig.layout.shapes[1].x1 = pixel_size
        minbox_fig.layout.shapes[1].y1 = box_size

        # Start at 0 like the initial layout, so the first tick (0) is shown.
        minbox_fig.layout.xaxis.range = [0, minbox_range_x]
        minbox_fig.layout.xaxis.tickvals = minbox_x_tickvals
        minbox_fig.layout.xaxis.ticktext = minbox_x_ticktext
        minbox_fig.layout.yaxis.range = [box_size_.min, minbox_range_y]
        minbox_fig.layout.yaxis.tickvals = minbox_y_tickvals
        minbox_fig.layout.yaxis.ticktext = minbox_y_ticktext

    # Fourier index nearest 1 / res, clamped to the last valid index: for an
    # odd box size (e.g. 35 with res == 2 * pixel_size) round() gives k_max.
    deloc_pt = min(round(box_size * pixel_size / res), k_max - 1)
    deloc_pt_x = ctf_x[deloc_pt]
    deloc_pt_y = deloc(box_size, pixel_size, c_s, energy,
                       np.max(np.abs(def_list)), deloc_pt)

    deloc_range_y = min(deloc_pt_y * 2, box_size_.max)
    deloc_y_tickvals = [deloc_range_y / (n_yticks - 1) * x for x in range(n_yticks)]
    deloc_y_ticktext = ['{0:.0f}'.format(round(x)) for x in deloc_y_tickvals]

    with deloc_fig.batch_update():
        deloc_fig.data[0].x = [deloc_pt_x]
        deloc_fig.data[0].y = [deloc_pt_y]
        deloc_fig.data[0].marker.color = colorway[0]
        deloc_fig.layout.shapes[0].x0 = deloc_pt_x
        deloc_fig.layout.shapes[0].x1 = deloc_pt_x
        deloc_fig.layout.shapes[0].y1 = deloc_pt_y

        deloc_fig.layout.shapes[1].y0 = deloc_pt_y
        deloc_fig.layout.shapes[1].x1 = deloc_pt_x
        deloc_fig.layout.shapes[1].y1 = deloc_pt_y

        deloc_fig.layout.xaxis.tickvals = ctf_x_tickvals
        deloc_fig.layout.xaxis.ticktext = ctf_x_ticktext
        deloc_fig.layout.yaxis.range = [0, deloc_range_y]
        deloc_fig.layout.yaxis.tickvals = deloc_y_tickvals
        deloc_fig.layout.yaxis.ticktext = deloc_y_ticktext

def reset_clicked(b):
    """Restore the start-up inputs and redraw.

    Callback for ``reset_button``.  It puts every parameter widget back to
    its initial value and marks all defocus cells unused (-1) except the
    first, which gets 10000 Å.  It then calls run_clicked with the same
    button argument.
    """
    defaults = ((box_size_, 128),
                (pixel_size_, 1.0),
                (c_s_, 2.7),
                (energy_, 300.0),
                (q0_, 0.1),
                (res_, 4.0))
    for widget, value in defaults:
        widget.value = value

    for cell in def_grid.children:
        cell.value = -1
    def_grid[0, 0].value = 10000.0

    run_clicked(b)

# Attach the button callbacks, then render the controls followed by the plots.
run_button.on_click(run_clicked)
reset_button.on_click(reset_clicked)
IPython.display.display(ui, def_label, def_grid, buttons, figures)

VBox(children=(HBox(children=(BoundedIntText(value=128, description='Box size (px): ', max=4096, min=32, step=…

Label(value='Defocuses (Å): ')

GridspecLayout(children=(BoundedFloatText(value=10000.0, layout=Layout(grid_area='widget001', width='100px'), …

HBox(children=(Button(description='Run', style=ButtonStyle()), Button(description='Reset', style=ButtonStyle()…

VBox(children=(FigureWidget({
    'data': [{'hoverinfo': 'name+y',
              'line': {'color': '#7A0177', …