In [None]:
import datetime
import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray as xr

In [None]:
# Make the holodec-ml `library` directory importable (the next cell imports
# torch_optics_utils from it). Defaults to ~/Python/holodec-ml/library; set the
# HOLODEC_ML_LIB environment variable to point somewhere else.
# expanduser('~') is used instead of os.environ['HOME'] so a missing HOME
# variable does not raise KeyError.
dirP_str = os.environ.get(
    'HOLODEC_ML_LIB',
    os.path.join(os.path.expanduser('~'), 'Python', 'holodec-ml', 'library'),
)
if dirP_str not in sys.path:
    sys.path.append(dirP_str)

In [None]:
import torch_optics_utils as optics

In [None]:
is_cuda = torch.cuda.is_available()
if is_cuda:
    # use the currently selected GPU and turn on cuDNN benchmark mode
    device = torch.device(torch.cuda.current_device())
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f'Preparing to use device {device}')

In [None]:
dtype = torch.complex64  # complex single precision; the FFT-based propagation needs complex input

Load a simulated hologram dataset to use for the input images

In [None]:
# Directory holding the simulated netCDF data files. Set HOLODEC_DATA_DIR to use
# a different location than the original cluster path. Joining with '' guarantees
# a trailing separator, so plain string concatenation with a filename still works.
data_dir = os.path.join(
    os.environ.get('HOLODEC_DATA_DIR', '/glade/p/cisl/aiml/ai4ess_hackathon/holodec/'),
    '',
)

In [None]:
# List all the netCDF files in data_dir. glob's ordering depends on the filesystem,
# and the file is later picked by its index in this list, so sort it to make that
# index refer to the same file on every run and every machine.
file_list = sorted(glob.glob(os.path.join(data_dir, '*.nc')))
for f_idx, file_path in enumerate(file_list):
    print(f'{f_idx}.) ' + os.path.basename(file_path))

In [None]:
# pick the file by its index in the listing printed above
f_sel = 29  # selected file index
try:
    dataFile = file_list[f_sel]
except IndexError:
    # clearer than a bare "list index out of range", e.g. when data_dir matched no files
    raise IndexError(
        f'f_sel={f_sel} is out of range: {len(file_list)} .nc file(s) found in {data_dir}'
    ) from None

In [None]:
print(f"loading {dataFile.rsplit('/', 1)[-1]}")
h_ds = xr.open_dataset(dataFile)  # simulated hologram dataset: images plus metadata attrs

In [None]:
holo_idx = 0  # index (along hologram_number) of the hologram to reconstruct

# The selected hologram image as a complex tensor on `device`. The [None, :, :]
# index prepends a singleton leading axis.
E_input = torch.tensor(
    h_ds['image'].isel(hologram_number=holo_idx).values,
    dtype=dtype,
    device=device,
)[None, :, :]

These inputs are needed to perform the propagation/reconstruction

In [None]:
# grid and optics parameters stored as attributes of the simulated dataset
dx, dy = h_ds.attrs['dx'], h_ds.attrs['dy']            # pixel spacing: horizontal (x), vertical (y)
Nx, Ny = int(h_ds.attrs['Nx']), int(h_ds.attrs['Ny'])  # pixel counts: horizontal (x), vertical (y)
lam = h_ds.attrs['lambda']                             # laser wavelength

In [None]:
# Spatial-frequency axes for the reconstruction, built directly on `device`.
# fx runs along dim 1 and fy along dim 2, so they broadcast against (z, x, y) tensors.
fx = torch.fft.fftfreq(Nx, dx, device=device).view(1, -1, 1)
fy = torch.fft.fftfreq(Ny, dy, device=device).view(1, 1, -1)

These inputs define which z planes, and how many of them, are reconstructed in this demo

In [None]:
Nplanes = 1_000  # how many z planes to reconstruct, evenly spaced from zMin to zMax

In [None]:
# z extent of the sample volume, taken from the simulation metadata
zMin, zMax = h_ds.attrs['zMin'], h_ds.attrs['zMax']
# z position of the image (CCD) plane
zCCD = 0

In [None]:
# Nplanes evenly spaced z positions from zMin to zMax (both ends included), reshaped
# to (Nplanes, 1, 1) so each z value broadcasts over the two image axes
z_plane = torch.linspace(zMin, zMax, Nplanes, device=device).unsqueeze(-1).unsqueeze(-1)

Calculate the electric field at each requested plane (z_plane).
The output has dimensions (z,x,y) and is copied back to the CPU as a NumPy array.

In [None]:
# Wall-clock the full reconstruction. The blocking device-to-host copy in .cpu()
# means end_time is taken only after the result is available.
start_time = datetime.datetime.now()
Eres = (
    optics.torch_holo_set(E_input, fx, fy, z_plane, lam)
    .detach()
    .cpu()
    .numpy()
)
end_time = datetime.datetime.now()

In [None]:
# summarize the run: image size, number of planes, and average time per plane
exec_time = end_time - start_time
# z_plane is (Nplanes, 1, 1). On a torch tensor `.size` is a method, not a count,
# so take the plane count from the leading dimension instead.
n_planes = z_plane.shape[0]
# E_input has a leading singleton axis (added by [None, :, :]); report the two
# image dimensions that follow it
print(f'{E_input.shape[1]} x {E_input.shape[2]} image')
print(f'executed {n_planes} planes in {exec_time.total_seconds()} seconds')
print(f' for {exec_time.total_seconds()/n_planes} seconds per plane')