In [4]:
%matplotlib widget
import numpy as np
import cupy as cp

import matplotlib.pyplot as plt 
from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.axes_grid1 import make_axes_locatable

import realtime_ptycho as sm
from realtime_ptycho.core import Sparse4DData, Metadata4D
from realtime_ptycho.util import get_qx_qy_1D, get_qx_qy_2D, disk_overlap_function, plotcxmosaic, sector_mask, plot, single_sideband_reconstruction, imsave, mosaic, sparse_to_dense_datacube, wavelength

from numpy.fft import fftshift

from pathlib import Path

from tifffile import  imwrite
import time

from ipywidgets import AppLayout, FloatSlider, GridspecLayout, VBox
import ipywidgets as widgets

import time

import cupy as cp
from cupyx.scipy.fft import fft2

from tqdm import trange
from skimage.filters import gaussian

import os

In [5]:
scan_number = 147

base_path = Path(os.getcwd())
adfpath = base_path
sparse_path = base_path  
results_path = base_path / 'results/'

if not results_path.exists():
    results_path.mkdir()
    
filename4d = sparse_path / f'data_scan{scan_number}_th4.0_electrons.h5'
filenameadf = adfpath / f'scan{scan_number}.dm4'

alpha_max_factor = 1.2
alpha_max_factor = 1.05

In [7]:
print('1: data loading')
d = Sparse4DData.from_4Dcamera_file(filename4d)
metadata = Metadata4D.from_dm4_file(filenameadf)

metadata.alpha_rad = 25e-3
metadata.rotation_deg = 0
metadata.wavelength =  wavelength(metadata.E_ev)  

center, radius = d.determine_center_and_radius(manual=False, size=200) 
print(f'center: {center}')
print(f'radius: {radius}')
print('2: cropping')
d.crop_symmetric_center_(center, radius*alpha_max_factor)
print('3: sum diffraction pattern')
s = d.sum_diffraction()
print('4: plotting')

f,ax = plt.subplots(1,2,figsize=(8,4))
imax = ax[0].imshow(s)
ax[0].set_title(f'Scan {scan_number} sum after cropping')
imax = ax[1].imshow(np.log10(s+1))
ax[1].set_title(f'Scan {scan_number} log10(sum) after cropping')
plt.colorbar(imax)
plt.tight_layout()


1: data loading
center: [292.20951103 160.73790153]
radius: 123.21053157858441
2: cropping
old frames frame_dimensions: [576 576]
new frames frame_dimensions: [258 258]
3: sum diffraction pattern
4: plotting


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
abf = d.virtual_annular_image(radius/2, radius, d.frame_dimensions/2)
bf = d.virtual_annular_image(0, radius/2, d.frame_dimensions/2)
eabf = abf - bf
adf = d.virtual_annular_image(radius, d.frame_dimensions[0]/2, d.frame_dimensions/2)

bf[bf==0] = bf.mean()
abf[abf==0] = abf.mean()

fig, ax = plt.subplots(dpi=150)
im = ax.imshow(abf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} ABF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()
fig, ax = plt.subplots(dpi=150)
im = ax.imshow(bf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} BF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()
fig, ax = plt.subplots(dpi=150)
im = ax.imshow(adf, cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} ADF')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.tight_layout()



Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
dwell_time = 1/87e3
detector_to_real_fluence_80kv = 1 

fluence = d.fluence(metadata.dr[0]) * detector_to_real_fluence_80kv
flux = d.flux(metadata.dr[0], dwell_time) * detector_to_real_fluence_80kv

print(f"E               = {metadata.E_ev/1e3}             keV")
print(f"λ               = {metadata.wavelength * 1e2:2.2}   pm")
print(f"dR              = {metadata.dr} Å")
print(f"scan       size = {d.scan_dimensions}")
print(f"detector   size = {d.frame_dimensions}")
print(f"scan       FOV  = {d.scan_dimensions*metadata.dr/10} nm")
print(f"fluence         ~ {fluence} e/Å^2")
print(f"flux            ~ {flux} e/Å^2/s")

E               = 80.0             keV
λ               = 4.2   pm
dR              = [0.31626087 0.31626087] Å
scan       size = [512 512]
detector   size = [258 258]
scan       FOV  = [16.19255676 16.19255676] nm
fluence         ~ 21622.471675663513 e/Å^2
flux            ~ 7176.036971217062 e/Å^2/s


In [11]:
dssb = d
metadata.k_max = metadata.alpha_rad * alpha_max_factor / metadata.wavelength
s = dssb.sum_diffraction()
f,ax = plt.subplots(figsize=(4,4))
imax = ax.imshow(s)
ax.set_title('Sum after cropping for SSB')
plt.colorbar(imax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.colorbar.Colorbar at 0x7efcdd14eb80>

In [12]:
slic = np.s_[:,:]
data = dssb.slice(slic)

ssb_size = np.array([15,15])
bin_factor = int(np.min(np.floor(data.frame_dimensions/ssb_size)))
radius2 = radius/bin_factor
meta = metadata
verbose = True

start = time.perf_counter()
dc = sparse_to_dense_datacube(data.indices, data.counts, data.scan_dimensions, data.frame_dimensions, data.frame_dimensions/2, data.frame_dimensions[0]/2, data.frame_dimensions[0]/2, binning=bin_factor, fftshift=False)
print(f"Bin by {bin_factor} for ssb took {time.perf_counter() - start}s")

rmax = dc.shape[-1] // 2
alpha_max = rmax / radius2 * meta.alpha_rad

r_min = meta.wavelength / (2 * alpha_max)
r_min = [r_min, r_min]
k_max = [alpha_max / meta.wavelength, alpha_max / meta.wavelength]
r_min1 = np.array(r_min)
dxy1 = np.array(meta.dr)

M = cp.array(dc).astype(cp.float32)
xp = cp.get_array_module(M)
ny, nx, nky, nkx = M.shape

Qx1d, Qy1d = get_qx_qy_1D([nx, ny], meta.dr, M.dtype, fft_shifted=False)

start = time.perf_counter()
G = fft2(M, axes=(0, 1), overwrite_x=True)
G /= cp.sqrt(np.prod(G.shape[:2]))
print(f"FFT along scan coordinate took {time.perf_counter() - start}s")

radius_data_int : 136 
radius_max_int  : 136 
dense frame size: 16x 16
Bin by 17 for ssb took 0.3138990880106576s
FFT along scan coordinate took 0.21235918899765238s


In [13]:
manual_frequencies = None  # [[20, 62, 490], [454, 12, 57]]

Gabs = xp.log10(xp.sum(xp.abs(G), (2, 3)))
sh = np.array(Gabs.shape)
mask = ~np.array(fftshift(sector_mask(sh, sh // 2, 5, (0, 360))))
mask[:,-1] = 0
mask[:,0] = 0
mask[:,1] = 0

gg = Gabs.get()
gg[~mask] = gg.mean()
fig, ax = plt.subplots(1,2,figsize=(10,5))
ax[0].imshow(fftshift(mask))
ax[0].set_title('FFT mask')
ax[1].imshow(fftshift(gg), cmap=plt.cm.get_cmap('inferno'))
ax[1].set_title('Masked absolute values of G')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Masked absolute values of G')

In [15]:
Gabs = xp.sum(xp.abs(G), (2, 3))
sh = np.array(Gabs.shape)

n_fit=16
meta.rotation_deg = 0.0
best_angle = meta.rotation_deg
aberrations = xp.zeros((12))
aberrations[2] = 0

gg = Gabs.get() * mask
gg[gg==0] = gg.mean()

inds = xp.argsort((gg).ravel()) 
strongest_object_frequencies = np.unravel_index(inds[-1 - n_fit:-1], G.shape[:2])
# print(strongest_object_frequencies[0])
# print(strongest_object_frequencies[1])
strongest_object_frequencies = strongest_object_frequencies
G_max = G[strongest_object_frequencies]

r_min1 = np.array(r_min)
dxy1 = np.array(meta.dr)

r_min1 *= 1
dxy1 *= 1.0
Kx, Ky = get_qx_qy_1D([nkx, nky], r_min1, G[0, 0, 0, 0].real.dtype, fft_shifted=True)
print(strongest_object_frequencies)
print(Kx)
print(Ky)
print([nx, ny], dxy1)
Qx1d, Qy1d = get_qx_qy_1D([nx, ny], dxy1, G[0, 0, 0, 0].real.dtype, fft_shifted=False)
Qy_max1d = Qy1d[strongest_object_frequencies[0]]
Qx_max1d = Qx1d[strongest_object_frequencies[1]]
print(Qx1d.max())
print(Qy1d.max())
print(Qy_max1d)
print(Qx_max1d)

Gamma = disk_overlap_function(Qx_max1d, Qy_max1d, Kx, Ky, aberrations, best_angle, meta.alpha_rad, meta.wavelength)

fig, ax = plt.subplots(1,3,figsize=(19,6))
im = ax[0].imshow(np.log10(fftshift(gg)+1), cmap= plt.cm.get_cmap('bone'))
ax[0].set_title(f'Scan {1} fft')
ax[0].set_xticks([])
ax[0].set_yticks([])

im = ax[1].imshow(imsave(mosaic(G_max.get() * Gamma.get())), cmap= plt.cm.get_cmap('bone'))
ax[1].set_title(f'Scan {1} double overlap')
ax[1].set_xticks([])
ax[1].set_yticks([])
divider = make_axes_locatable(ax[1])

im = ax[2].imshow(imsave(mosaic(G_max.get())), cmap= plt.cm.get_cmap('bone'))
ax[2].set_title(f'Scan {1} double overlap')
ax[2].set_xticks([])
ax[2].set_yticks([])
divider = make_axes_locatable(ax[1])

plt.tight_layout()
fig.savefig(results_path /f'scan{1}_fft.png')


(array([  7, 477,  35, 507,   5, 479,  33, 473,  39, 496,  16, 481,  31,
       507,   5, 473]), array([  3, 495,  17, 509,   3, 495,  17, 493,  19,  42, 470, 495,  17,
       510,   2, 495]))
[-0.66084564 -0.5782399  -0.4956342  -0.4130285  -0.33042282 -0.2478171
 -0.16521141 -0.0826057   0.          0.0826057   0.16521141  0.2478171
  0.33042282  0.4130285   0.4956342   0.5782399 ]
[-0.66084564 -0.5782399  -0.4956342  -0.4130285  -0.33042282 -0.2478171
 -0.16521141 -0.0826057   0.          0.0826057   0.16521141  0.2478171
  0.33042282  0.4130285   0.4956342   0.5782399 ]
[512, 512] [0.31626087 0.31626087]
1.5747976
1.5747976
[ 0.04322974 -0.21614869  0.21614869 -0.03087839  0.03087839 -0.20379734
  0.20379734 -0.2408514   0.2408514  -0.09881083  0.09881083 -0.19144599
  0.19144599 -0.03087839  0.03087839 -0.2408514 ]
[ 0.01852703 -0.10498651  0.10498651 -0.01852703  0.01852703 -0.10498651
  0.10498651 -0.11733786  0.11733786  0.25937843 -0.25937843 -0.10498651
  0.10498651 -0.012351

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:

plt.ioff()
C = xp.zeros((12,))
C_names = ['C1','C12a' ,'C12b','C21a','C21b','C23a','C23b','C3','C32a','C32b','C34a','C34b']
C_min = [-50,-20,-20,-50,-50,-50,-50,-20,-20,-20,-20,-20]
C_max = [50,20,20,50,50,50,50,20,20,20,20,20]
C_multiplier = [1e1,1e1,1e1,1e1,1e1,1e1,1e4,1e4,1e4,1e4,1e4]

gs = GridspecLayout(4,9)
Cslider_box = VBox(width=50)
scale_slider_box = VBox()
children= []
sliders =  []

text = widgets.HTML(
    value="1",
    placeholder='',
    description='',
)

overlaps_output = widgets.Output()
# recon_output = widgets.Output()

Psi_Qp = cp.zeros((ny, nx), dtype=G.dtype)
Psi_Qp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Rp = cp.zeros((ny, nx), dtype=G.dtype)
Psi_Rp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Rp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)

start = time.perf_counter()
eps = 1e-3
single_sideband_reconstruction(
    G,
    Qx1d,
    Qy1d,
    Kx,
    Ky,
    aberrations,
    best_angle,
    meta.alpha_rad,
    Psi_Qp,
    Psi_Qp_left_sb,
    Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

print(f"SSB took {time.perf_counter() - start}")
Psi_Rp[:] = ifft2(Psi_Qp, norm="ortho")
Psi_Rp_left_sb[:] = ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb[:] = ifft2(Psi_Qp_right_sb, norm="ortho")

Gamma = disk_overlap_function(Qx_max1d, Qy_max1d, Kx, Ky, aberrations, best_angle, meta.alpha_rad, meta.wavelength)
gg = Gamma * G_max

overlap_figure_axes = []
with overlaps_output:
    overlap_figure = plt.figure(constrained_layout=True,figsize=(7,7))
    gs1 = overlap_figure.add_gridspec(3, 3, wspace=0.05,hspace=0.05)
    for d, ggs in zip(gg[:9], gs1):
        f3_ax1 = overlap_figure.add_subplot(ggs)
        imax2 = f3_ax1.imshow(imsave(d.get()))
        f3_ax1.set_xticks([])
        f3_ax1.set_yticks([])
        overlap_figure_axes.append(imax2)

#     plt.show(fig3)

plot_box = VBox(children =[overlap_figure.canvas])    

recon_fig, recon_axes = plt.subplots(figsize=(9,9))
recon_img = recon_axes.imshow(np.angle(Psi_Rp.get()), cmap=plt.get_cmap('bone'))
recon_axes.set_xticks([])
recon_axes.set_yticks([])
scalebar = ScaleBar(meta.dr[0]/10,'nm') # 1 pixel = 0.2 meter
recon_axes.add_artist(scalebar)
plt.tight_layout()

def create_function(name, i, multiplier):
    def func1(change):
        C[i] = change['new']*multiplier
        Psi_Qp[:] = 0
        Psi_Qp_left_sb[:] = 0
        Psi_Qp_right_sb[:] = 0
        single_sideband_reconstruction(
            G,
            Qx1d,
            Qy1d,
            Kx,
            Ky,
            C,
            best_angle,
            meta.alpha_rad,
            Psi_Qp,
            Psi_Qp_left_sb,
            Psi_Qp_right_sb,
            eps,
            meta.wavelength,
        )
        m = 5

        Psi_Rp[:] = ifft2(Psi_Qp, norm="ortho")
        Psi_Rp_left_sb[:] = ifft2(Psi_Qp_left_sb, norm="ortho")
        Psi_Rp_right_sb[:] = ifft2(Psi_Qp_right_sb, norm="ortho")

        # Psi_Rp = fft.ifft2(Psi_Qp_left_sb, norm="ortho")
        img = np.angle(Psi_Rp_left_sb.get()[m:-m,m:-m])
        recon_img.set_data(img)
        recon_img.set_clim(img.min(),img.max())
        recon_fig.canvas.draw()
        recon_fig.canvas.flush_events()

        Gamma = disk_overlap_function(Qx_max1d, Qy_max1d, Kx, Ky, C, best_angle, meta.alpha_rad, meta.wavelength)
        gg = Gamma * G_max
        for ax, g in zip(overlap_figure_axes,gg):
            ax.set_data(imsave(g.get()))
        overlap_figure.canvas.draw()
        overlap_figure.canvas.flush_events()
        text.value = f'{C[i]}'
    func1.__name__ = name
    return func1

for i, (name, mins, maxs, multiplier) in enumerate(zip(C_names, C_min, C_max, C_multiplier)):
    s = FloatSlider(description=name,
                   min=mins, max = maxs)
    s.observe(create_function(f'slider_changed_{i}', i, multiplier), names='value')
    sliders.append(s)
    children.append(s)

Cslider_box.children = children + [text]

gs[:2,0] = Cslider_box
gs[2:,0] = scale_slider_box
gs[:,1:5] = plot_box
gs[:,5:9] = recon_fig.canvas

AppLayout(center=gs)



SSB took 0.21659923999686725


AppLayout(children=(GridspecLayout(children=(VBox(children=(FloatSlider(value=0.0, description='C1', max=50.0,…

In [19]:
data = dssb 
bright_field_radius = radius

ssb_size = np.array([32,32])
bin_factor = int(np.min(np.floor(data.frame_dimensions/ssb_size)))
meta = metadata

Psi_Qp = cp.zeros((ny, nx), dtype=G.dtype)
Psi_Qp_left_sb = cp.zeros((ny, nx), dtype=np.complex64)
Psi_Qp_right_sb = cp.zeros((ny, nx), dtype=np.complex64)

print(f'defocus: {C[0]}')

eps = 1e-3
single_sideband_reconstruction(
    G,
    Qx1d,
    Qy1d,
    Kx,
    Ky,
    C,
    best_angle,
    meta.alpha_rad,
    Psi_Qp,
    Psi_Qp_left_sb,
    Psi_Qp_right_sb,
    eps,
    meta.wavelength,
)

Psi_Rp_left_sb = ifft2(Psi_Qp_left_sb, norm="ortho")
Psi_Rp_right_sb = ifft2(Psi_Qp_right_sb, norm="ortho")
Psi_Rp = ifft2(Psi_Qp, norm="ortho")

ssb_defocal = Psi_Rp.get()
ssb_defocal_right = Psi_Rp_right_sb.get()
ssb_defocal_left = Psi_Rp_left_sb.get()

defocus: -53.0


In [21]:

th= 0.03
fig, ax = plt.subplots(dpi=300)
im1 = ax.imshow(np.angle(ssb_defocal_right), cmap= plt.cm.get_cmap('bone'))
ax.set_title(f'Scan {scan_number} SSB ptychography')
ax.set_xticks([])
ax.set_yticks([])
ax.add_artist(ScaleBar(metadata.dr[0]/10,'nm'))
fig.savefig(results_path / 'ptycho1.pdf')
AppLayout(center=fig.canvas)

AppLayout(children=(Canvas(layout=Layout(grid_area='center'), toolbar=Toolbar(toolitems=[('Home', 'Reset origi…

In [23]:
imwrite(results_path /f'scan{scan_number}_ssb_ptycho_best_right.tif',np.angle(ssb_defocal_right).astype('float32'), imagej=True, resolution=(1./(metadata.dr[0]/10), 1./(metadata.dr[1]/10)), metadata={'spacing': 1 / 10, 'unit': 'nm', 'axes': 'YX'})