In [1]:
import os,sys
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
import datetime
import glob

In [2]:
# Make the holodec-ml "library" directory importable (torch_optics_utils is imported from it below).
# Assumes the repository is checked out at ~/Python/holodec-ml -- adjust the path if yours lives elsewhere.
# os.path.expanduser('~') is used rather than os.environ['HOME'], which raises KeyError
# on systems where HOME is not set (e.g. Windows).
dirP_str = os.path.join(os.path.expanduser('~'), 'Python', 'holodec-ml', 'library')
if dirP_str not in sys.path:
    sys.path.append(dirP_str)

In [3]:
import torch_optics_utils as optics

In [27]:
# Select the compute device: the current CUDA device when one is available, otherwise the CPU.
# (To force CPU execution, override afterwards with device = torch.device("cpu").)
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
    # enable cuDNN's autotuner on GPU runs
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f'Preparing to use device {device}')

Preparing to use device cpu


In [28]:
dtype = torch.cfloat  # single-precision complex (complex64): working dtype for the optical field tensors

Load a hologram dataset to use as the input images.

In [29]:
# Directory holding the HOLODEC netCDF files. Set the HOLODEC_DATA_DIR environment variable to
# point somewhere else; otherwise the original hard-coded path is used.
# os.path.join(..., '') guarantees a trailing separator, which the glob pattern built by
# concatenating data_dir + '*.nc' relies on.
data_dir = os.path.join(
    os.environ.get('HOLODEC_DATA_DIR', '/glade/p/cisl/aiml/ai4ess_hackathon/holodec/'),
    '',
)

In [30]:
# list all the netcdf files in data_dir
file_list = glob.glob(data_dir+'*.nc')
for f_idx,file in enumerate(file_list):
    print(f'{f_idx}.) '+file.split('/')[-1])

0.) synthetic_holograms_7particle_gamma_600x400_training.nc
1.) synthetic_holograms_10particle_gamma_512x512_validation_patches128x128.nc
2.) synthetic_holograms_50-100particle_bidisperse_test.nc
3.) synthetic_holograms_multiparticle_validation.nc
4.) synthetic_holograms_1particle_training_small.nc
5.) synthetic_holograms_6particle_gamma_600x400_test.nc
6.) synthetic_holograms_multiparticle_training.nc
7.) synthetic_holograms_50-100particle_gamma_private.nc
8.) synthetic_holograms_12-25particle_gamma_600x400_validation.nc
9.) synthetic_holograms_6particle_gamma_600x400_training.nc
10.) synthetic_holograms_50-100particle_gamma_training.nc
11.) synthetic_holograms_1particle_gamma_600x400_training.nc
12.) synthetic_holograms_4particle_gamma_600x400_validation.nc
13.) synthetic_holograms_3particle_validation.nc
14.) synthetic_holograms_10particle_gamma_600x400_test.nc
15.) synthetic_holograms_1particle_gamma_600x400_validation.nc
16.) synthetic_holograms_10particle_gamma_600x400_training.n

In [31]:
# pick the file
f_sel = 29  # selected file index (29 for real holograms)
dataFile = file_list[f_sel]

In [32]:
# Announce the selected file's base name, then open it as an xarray Dataset.
loaded_name = dataFile.rsplit('/', 1)[-1]
print('loading ' + loaded_name)
h_ds = xr.open_dataset(dataFile)

loading real_holograms_CSET_RF07_20150719_203600-203700.nc


In [33]:
holo_idx = 0  # index (along hologram_number) of the hologram to reconstruct

# Input field: the selected hologram image as a `dtype` tensor on `device`,
# with a leading singleton axis prepended.
holo_image = h_ds['image'].isel(hologram_number=holo_idx).values
E_input = torch.tensor(holo_image, device=device, dtype=dtype).unsqueeze(0)

These inputs are needed to perform the propagation/reconstruction.

In [34]:
# Optical/sensor parameters stored as global attributes of the dataset.
holo_attrs = h_ds.attrs
dx = holo_attrs['dx']          # horizontal resolution
dy = holo_attrs['dy']          # vertical resolution
Nx = int(holo_attrs['Nx'])     # number of horizontal pixels
Ny = int(holo_attrs['Ny'])     # number of vertical pixels
lam = holo_attrs['lambda']     # laser wavelength

In [35]:
# Spatial-frequency axes for the reconstruction, created directly on `device`.
# fx runs along axis 1 and fy along axis 2, matching E_input's (1, dim1, dim2) layout
# (assumes Nx/Ny correspond to E_input.shape[1]/[2] -- TODO confirm).
fx = torch.fft.fftfreq(Nx, d=dx, device=device).reshape(1, -1, 1)
fy = torch.fft.fftfreq(Ny, d=dy, device=device).reshape(1, 1, -1)

These inputs define which z planes, and how many of them, are reconstructed in this demo.

In [36]:
Nplanes = 1_000  # number of z planes to reconstruct, spaced evenly from zMin to zMax (inclusive)

In [37]:
# Simulation/hardware z definitions. The sample-volume bounds are fixed here rather than
# read from h_ds.attrs['zMin'] / h_ds.attrs['zMax'].
zMin, zMax = 0.014, 0.158  # minimum / maximum z of the sample volume
zCCD = 0                   # z position of the image plane

In [38]:
# z positions of the planes to reconstruct: Nplanes values evenly spaced over [zMin, zMax],
# shaped (Nplanes, 1, 1) so they broadcast against the (1, dim1, 1) / (1, 1, dim2) frequency axes.
z_plane = torch.linspace(zMin, zMax, steps=Nplanes, device=device).reshape(-1, 1, 1)

Calculate the electric field at each requested plane (`z_plane`).
The output tensor has dimensions (z, x, y).

In [16]:
# Time a single, un-batched propagation of every plane in z_plane (including the copy back to host).
# NOTE(review): with Nplanes = 1000 and a 4872 x 3248 image, a complex64 result alone is roughly
# 1000 * 4872 * 3248 * 8 bytes ~= 126 GB -- on most machines run this cell only with a small
# z_plane (the recorded output below was produced with 10 planes), and use the batched loop instead.
start_time = datetime.datetime.now()
Eres = (
    optics.torch_holo_set(E_input, fx, fy, z_plane, lam)
    .detach()
    .cpu()
    .numpy()
)
end_time = datetime.datetime.now()

In [18]:
# Report the timing of the un-batched run. The plane count is taken from Eres (the array that was
# actually timed; its first axis is z) rather than from z_plane, because z_plane may have been
# rebuilt after Eres was computed -- the recorded output (10 planes) predates Nplanes = 1000.
exec_time = end_time-start_time
n_planes_run = Eres.shape[0]
print(f'{E_input.shape[1]} x {E_input.shape[2]} image')
print(f'executed {n_planes_run} planes in {exec_time.total_seconds()} seconds')
print(f' for {exec_time.total_seconds()/n_planes_run} seconds per plane')

4872 x 3248 image
executed 10 planes in 0.75647 seconds
 for 0.07564699999999999 seconds per plane


In [39]:
batch_size = 10  # z planes passed to each torch_holo_set call in the batched loop (last batch may be smaller)

In [40]:
# Propagate the planes in chunks of `batch_size`, so each torch_holo_set call handles at most
# batch_size planes; each chunk's result is copied to host memory as a numpy array.
# torch.split derives the chunks from z_plane itself (the final chunk may be shorter), so the loop
# covers exactly the planes in z_plane even if Nplanes is stale -- a mismatch previously skipped
# planes or issued calls on empty slices.
# NOTE: plane_lst still holds every plane on the host, so total host memory is not reduced.
plane_lst = []
start_time = datetime.datetime.now()
for z_batch in torch.split(z_plane, batch_size):
    plane_lst.append(optics.torch_holo_set(E_input, fx, fy, z_batch, lam).detach().cpu().numpy())
end_time = datetime.datetime.now()

In [26]:
# Report the timing of the batched run. The plane count is summed from the batches actually
# stored in plane_lst (first axis of each is z), so it reflects what was timed even if z_plane
# or Nplanes changed afterwards.
exec_time = end_time-start_time
n_planes_run = sum(batch.shape[0] for batch in plane_lst)
print(f'{E_input.shape[1]} x {E_input.shape[2]} image')
print(f'executed {n_planes_run} planes in {exec_time.total_seconds()} seconds')
print(f' for {exec_time.total_seconds()/n_planes_run} seconds per plane')

4872 x 3248 image
executed 1000 planes in 70.8821 seconds
 for 0.07088209999999999 seconds per plane


In [41]:
# Report the timing of the batched run (re-run of the report cell above; the recorded outputs
# differ -- 70.9 s vs 784.2 s -- presumably from different sessions/devices).
# The plane count is summed from the batches actually stored in plane_lst (first axis of each
# is z), so it reflects what was timed even if z_plane or Nplanes changed afterwards.
exec_time = end_time-start_time
n_planes_run = sum(batch.shape[0] for batch in plane_lst)
print(f'{E_input.shape[1]} x {E_input.shape[2]} image')
print(f'executed {n_planes_run} planes in {exec_time.total_seconds()} seconds')
print(f' for {exec_time.total_seconds()/n_planes_run} seconds per plane')

4872 x 3248 image
executed 1000 planes in 784.245014 seconds
 for 0.784245014 seconds per plane
