In [133]:
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from matplotlib import colormaps


In [134]:
def genMaskF(shape,f):
    """Build a boolean mask that keeps every ``f``-th sample along both axes.

    Parameters
    ----------
    shape : sequence of int
        ``(n_rows, n_cols)`` of the mask to build; only the first two
        entries are read.
    f : int
        Sampling period. Element ``[i, j]`` is True exactly when both
        ``i`` and ``j`` are multiples of ``f``.

    Returns
    -------
    numpy.ndarray of bool, shape ``(shape[0], shape[1])``
    """
    # Per-axis "is on the coarse grid" flags, shaped so they broadcast
    # against each other into a 2-D array.
    row_on_grid = (np.arange(shape[0]) % f == 0)[:, None]
    col_on_grid = (np.arange(shape[1]) % f == 0)[None, :]

    # A pixel is kept only when its row AND its column are on the grid.
    return row_on_grid & col_on_grid

def generateSpiralCoordinates(Psi0downsampled):
    """Locate the sampled spatial frequencies and order them as a spiral.

    Parameters
    ----------
    Psi0downsampled : numpy.ndarray (2-D)
        Pupil array in FFT (un-shifted) order. Entries that are strictly
        positive after ``np.fft.fftshift`` are treated as sampled points.

    Returns
    -------
    xy : numpy.ndarray, shape ``(n, 2)``
        Index pairs of the selected points in the *shifted* array, in the
        row-major order produced by ``np.argwhere``.
    inds : numpy.ndarray, shape ``(n,)``
        Permutation of ``range(n)``; ``xy[inds]`` walks the points by
        increasing radius, using the angle as a small tie-breaker.
    """
    # Shift the array and list the indices of every positive entry.
    xy = np.argwhere(np.fft.fftshift(Psi0downsampled) > 0)

    # Measure every point relative to the centroid of the selected points.
    offsets = xy - xy.mean(axis=0)

    # Polar description of each point (angle measured from the axis-0 direction).
    radius = np.sqrt((offsets**2).sum(axis=1))
    angle = np.arctan2(offsets[:,1],offsets[:,0])

    # Sort key: radius plus a small angle-dependent term, so points on the
    # same ring come out ordered by angle.
    inds = np.argsort(0.1*angle + radius)
    return xy, inds

def genPlaneWaveStack(im_shape,coordinates):
    """Compute the plane wave produced by each single spatial frequency.

    For every coordinate a unit delta is placed in an otherwise empty
    ``(h, w)`` array, the array is moved from centred order back to FFT
    order, and its 2-D FFT is stored.

    Parameters
    ----------
    im_shape : tuple of int
        ``(h, w)`` of each plane wave.
    coordinates : numpy.ndarray of int, shape ``(n, 2)``
        ``(row, col)`` indices in the centred (``fftshift``-ed) frame,
        i.e. zero frequency sits at ``(h // 2, w // 2)``.

    Returns
    -------
    numpy.ndarray of complex128, shape ``(n, h, w)``
        ``result[i]`` is the plane wave for ``coordinates[i]``. An empty
        coordinate list yields an array of shape ``(0, h, w)``.
    """
    h,w = im_shape
    n_waves = np.shape(coordinates)[0]

    planeWaveStack = np.zeros((n_waves,h,w),np.complex128)
    delta = np.zeros((h,w))  # scratch array reused for every frequency

    for i in range(n_waves):
        delta.fill(0)
        row,col = coordinates[i,:]
        delta[row,col] = 1
        # The coordinates live in the fftshift-ed frame, so the inverse shift
        # is required to return to FFT order. fftshift and ifftshift coincide
        # for even sizes but differ by one sample for odd sizes.
        planeWaveStack[i,:,:] = np.fft.fft2(np.fft.ifftshift(delta))
    return planeWaveStack

In [135]:
# Initial parameters used to generate the probe.
# All lengths are in angstroms for computational simulations.
pixel_size = 0.1       # Real-space sampling interval
im_shape = (512,512)   # Pixel array size
probe_mrads = 20       # Max alpha (probe semi-angle, mrad)
wavl = 0.02            # Wavelength (angstroms) Corresponds to ~300keV

# Spatial-frequency axes in FFT order; fftfreq scales them by the inverse of
# the sample spacing passed in.
kx = np.fft.fftfreq(im_shape[0], d=pixel_size)
ky = np.fft.fftfreq(im_shape[1], d=pixel_size)

# Squared and plain radial frequency on the 2-D grid (kx along axis 0,
# ky along axis 1).
kr2 = np.add.outer(kx**2, ky**2)
kr = kr2**0.5

# Cut-off spatial frequency: semi-angle converted from mrad to rad, then
# divided by the wavelength.
k_max = probe_mrads/1000/wavl

# Circular aperture of radius k_max. The edge is a linear ramp one frequency
# step wide, clipped to the range [0, 1].
Psi0 = np.clip(0.5+(k_max-kr)/(kx[1]-kx[0]),0,1)

# Real-space version of the probe
psi0 = np.fft.ifft2(Psi0)


In [136]:
# Down-sample the pupil and build everything the plots need.
f = 8  # keep every f-th spatial frequency along each axis
Psi0downsampled = genMaskF(Psi0.shape, f)*Psi0  # Apply f sampling mask

# Coordinates of the surviving frequencies plus the order in which to visit them
xy, inds = generateSpiralCoordinates(Psi0downsampled)

# Re-order the coordinates into spiral order and build the matching plane waves
spiralCoords = np.take(xy, inds, axis=0)
planeWaveStack = genPlaneWaveStack(im_shape, spiralCoords)

# Placeholder arrays that will be used to update the figures
probeTotal = np.zeros(Psi0downsampled.shape, dtype=np.complex128)
pupilTotal = np.zeros(Psi0downsampled.shape)

In [None]:
# Main visualization
plt.close('all')
plt.ioff()  # interactive mode off; the canvas is embedded in a widget container further down

# Initialize the three-panel figure used by the interactive demo
dpi = 72
fig, ax = plt.subplots(1,3,figsize=(900/dpi, 300/dpi), dpi=dpi)

# --- Panel 0: pupil plane -------------------------------------------------
# Start from a fresh array so re-running this cell never inherits a NaN marker
# from a previous run. NaN is drawn with the colormap's "bad" colour (red) and
# marks the frequency currently being added.
pupilTotal = np.zeros(im_shape)
pupilTotal[spiralCoords[0][0],spiralCoords[0][1]] = np.nan
cmap = colormaps.get_cmap('gray')
cmap.set_bad("red")
ax[0].clear()
ax[0].set_title(f"Pupil Plane f = {f}")  # report the sampling factor actually in use
# imshow puts array axis 1 on x and axis 0 on y, so take the extents accordingly
ax[0].set_xlim(im_shape[1]/2-64,im_shape[1]/2+64)
ax[0].set_ylim(im_shape[0]/2-64,im_shape[0]/2+64)
ax[0].set_ylabel("$k_{y}$")
ax[0].set_xlabel("$k_{x}$")
ax[0].tick_params(left = False, right = False , labelleft = False , labelbottom = False, bottom = False)
im_pupil = ax[0].imshow(pupilTotal, cmap=cmap,vmin=0,vmax=1)

# --- Panel 1: plane wave of the frequency being added ----------------------
ax[1].clear()
ax[1].set_title("Plane Wave of Spatial Frequency Added")
ax[1].set_xlim(0,im_shape[1])
ax[1].set_ylim(0,im_shape[0])
ax[1].tick_params(left = False, right = False , labelleft = False , labelbottom = False, bottom = False)
# Each plane wave is the FFT of a unit delta, so its real part spans [-1, 1];
# use the full range so the negative half is not clipped.
im_plane_wave = ax[1].imshow(np.real(planeWaveStack[0]),vmin=-1,vmax=1,cmap='turbo')

# --- Panel 2: running sum of plane waves (the probe) -----------------------
ax[2].clear()
ax[2].set_title("Probe on Sample")
ax[2].tick_params(left = False, right = False , labelleft = False , labelbottom = False, bottom = False)
ax[2].autoscale()
im_probe_sum = ax[2].imshow(np.real(planeWaveStack[0]),cmap='turbo')


In [138]:
# Rebind the module-level pupil image to a clean, all-zero float array
# (drops the NaN marker written while the figure was being set up).
pupilTotal = np.zeros(im_shape, dtype=np.float64)

In [147]:
# Callback run whenever the play widget's value changes
def update_display(change):
    """Redraw all three panels for the frame stored in ``change['new']``.

    Parameters
    ----------
    change : mapping
        Change notification from the observed widget; only the ``'new'``
        entry (the frame index) is read.

    Returns
    -------
    None
    """
    frame = change['new']

    # Pupil panel: frequencies already added are 1, the one being added
    # right now is NaN, everything else is 0.
    already_added = spiralCoords[:frame]
    current_row, current_col = spiralCoords[frame]
    pupil = np.zeros(im_shape)
    pupil[already_added[:,0],already_added[:,1]] = 1
    pupil[current_row,current_col] = np.nan
    im_pupil.set_data(pupil)

    # Plane-wave panel: the wave contributed by the current frequency
    im_plane_wave.set_data(planeWaveStack[frame].real)

    # Probe panel: sum of every plane wave up to and including this frame,
    # with the colour limits rescaled to the new data.
    running_sum = planeWaveStack[:frame+1].sum(axis=0)
    im_probe_sum.set_data(running_sum.real)
    im_probe_sum.autoscale()

    # Request a redraw of the figure
    fig.canvas.draw_idle()
    
# NOTE(review): zoomArray is not referenced by update_zoom below; kept in case
# other cells use it.
zoomArray = [1,2,4,8,16]
# Callback run whenever the zoom slider's value changes
def update_zoom(change):
    """Zoom the probe panel (``ax[2]``) about the image centre.

    The visible window is centred on the middle of the image and its
    half-width shrinks by the zoom factor. The extents are derived from
    ``im_shape`` instead of being fixed for a 512-pixel image; for a
    512x512 image the limits are the same as with the previous constant
    of 255.

    Parameters
    ----------
    change : mapping
        Change notification from the observed widget; only the ``'new'``
        entry (the zoom factor) is read.

    Returns
    -------
    None
    """
    zoom = change['new']

    # imshow puts array axis 1 on x and axis 0 on y.
    center_x = im_shape[1]/2
    center_y = im_shape[0]/2
    # One pixel short of half the image, i.e. 255 for a 512-pixel axis.
    half_x = (im_shape[1]/2-1)/zoom
    half_y = (im_shape[0]/2-1)/zoom

    ax[2].set_xlim(center_x-half_x,center_x+half_x)
    ax[2].set_ylim(center_y-half_y,center_y+half_y)
    fig.canvas.draw_idle()
    return None
    
    
    
    
    
    

In [None]:
# Playback parameters
fps = 2
num_frames = len(planeWaveStack)  # one frame per plane wave in the stack

# Zoom slider; redraws only when the slider is released
zoom_slider = widgets.IntSlider(
    description='Zoom:',
    value=1,
    min=1,
    max=8,
    continuous_update=False,
    layout=widgets.Layout(width='500px'),
    style={'description_width': 'initial'},
)
zoom_slider.observe(update_zoom, names='value')

# Play control that steps through the frames at `fps` frames per second
play_button = widgets.Play(
    description="Play",
    value=0,
    min=0,
    max=num_frames-1,
    step=1,
    interval=1000//fps,  # milliseconds between frames
)
play_button.observe(update_display, names="value")

# Controls in one row, with the figure canvas stacked underneath
controls = widgets.HBox([play_button,zoom_slider])
player = widgets.VBox([controls,fig.canvas])


In [150]:
# | label: app:time_evolution
# Time Evolution: show the assembled `player` widget (controls row + figure canvas)
display(player)

VBox(children=(HBox(children=(Play(value=103, description='Play', interval=500, max=136, playing=True), IntSli…