In [None]:
import sys
import datetime
import numpy as np
import xarray as xr
import dask.array as da

sys.path.append('../../holodecml/')

import FourierOpticsLib as FO

In [None]:
# Particle-diameter histogram: 5-unit-wide bins on [0, 195], plus their midpoints.
histogram_edges = np.arange(0, 200, 5)
histogram_centers = histogram_edges[:-1] + np.diff(histogram_edges) * 0.5

# Timestamp used to tag the output files of this run.
run_date_str = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')

# Input / output directories and the dataset splits to process.
path_data = '/glade/scratch/mhayman/holodec/holodec-ml-data/'
path_save = '/glade/p/cisl/aiml/ggantos/holodec/ft_rad_bidis/'
data_file_list = [
    f'synthetic_holograms_50-100particle_bidisperse_{split}.nc'
    for split in ('training', 'validation', 'test')
]

max_hist_count = 5000     # upper bound on holograms processed per file
FourierTransform = True   # store the radially averaged |FFT| as the model input
log_hist = False          # store log(histogram) instead of raw counts
log_in = False            # store log-scaled FT profile instead of FT/255


In [None]:
# Quick look at each split: confirm the file opens and report its hologram count.
# NOTE(review): `ds` stays bound to the last file's dataset after the `with`
# block exits, and later exploratory cells keep using it.
for fn in data_file_list:
    holo_path = path_data + fn
    with xr.open_dataset(holo_path, chunks={'hologram_number': 1}) as ds:
        print(type(ds))
        hologram_count = ds['hologram_number'].values.size
        print(hologram_count)

In [None]:
# Centred pixel-index grids: rows follow the 'xsize' dimension, columns 'ysize'.
n_rows = ds.coords['xsize'].size
n_cols = ds.coords['ysize'].size
ypix = np.arange(n_rows) - n_rows // 2
xpix = np.arange(n_cols) - n_cols // 2
# Distance of every pixel from the grid centre, shape (n_rows, n_cols).
rpix = np.sqrt(ypix[:, np.newaxis] ** 2 + xpix[np.newaxis, :] ** 2)

In [None]:
def x(a):
    """Scratch helper: return `a` offset by 10."""
    return a + 10


print(x(5))


In [None]:
def avg_rad(r):
    """Mean of |image_ft0| over the one-pixel-wide ring r-0.5 <= rpix < r+0.5.

    Reads the globals `image_ft0` and `rpix` at call time, so `image_ft0`
    must be assigned before this function is evaluated.
    """
    ring = (rpix >= r - .5) & (rpix < r + .5)
    return np.abs(image_ft0[ring]).mean()


# Integer radii 0 .. max(half-height, half-width) - 1 at which the ring mean is taken.
rad = np.arange(max(ypix.size // 2, xpix.size // 2))

In [None]:
# Hologram ids visited by the preprocessing loop (at most the first `max_hist_count`).
ds['hologram_number'].values[:max_hist_count]

In [None]:
# Debug: list the particle indices belonging to the first two holograms.
first_ids = ds['hologram_number'].values[:max_hist_count][:2]
for im in first_ids:
    # particle `hid` values are 1-based (per the original note), hence the +1;
    # the result indexes the flat per-particle arrays
    particle_index = np.nonzero(ds['hid'].values == im + 1)[0]
    print(im)
    print('\t', particle_index)

In [None]:
# Moments 0..6 of the particle radius (d/2). Uses `particle_index` left over
# from the previous cell, i.e. the last hologram that loop visited.
radii = ds['d'].values[particle_index] / 2
h_moments = np.array([np.sum(radii ** order) for order in range(7)])

In [None]:
# display the radius moments computed in the previous cell
h_moments

In [None]:
# display the histogram bin edges defined in the configuration cell
histogram_edges

In [None]:
if FourierTransform:
    # Debug walkthrough of one hologram's FT radial profile, printing each step.
    # NOTE(review): `im` is whatever an earlier loop left behind, and `avg_rad`
    # reads the global `image_ft0`, which is assigned just before it is used.
    image0 = ds['image'].sel(hologram_number=im)
    print(image0.shape)
    print(image0.values)

    image_ft0 = FO.OpticsFFT(image0)
    print(image_ft0.shape)
    print(image_ft0)

    # radial mean of |FT|, then divide the r=0 (DC) entry by the number of radii
    image_ft_r_mean = np.vectorize(avg_rad)(rad)
    print(image_ft_r_mean.shape)
    image_ft_r_mean[0] /= image_ft_r_mean.size
    print(image_ft_r_mean[0].shape)
    print(image_ft_r_mean[0])

    image_ft_list = [image_ft_r_mean[np.newaxis, :] / 255.0]
    print(len(image_ft_list), image_ft_list[0].shape)

    stacked = np.concatenate(image_ft_list, axis=0)
    image_ft = da.array(stacked[np.newaxis, ...])
    print(image_ft.shape)

In [None]:
if FourierTransform:
    # Prototype of the per-hologram FT step used by the main preprocessing loop.
    # NOTE(review): relies on `im`, `ds`, `avg_rad`, `rad` (and, when im != 0, an
    # existing `image_ft`) from earlier cells, so it is not fresh-kernel safe.

    # FT the selected hologram image
    image0 = ds['image'].sel(hologram_number=im)  # select the hologram image
    image_ft0 = FO.OpticsFFT(image0)  # FFT the image; avg_rad reads this global

    # calculate the radial mean of the FT
    image_ft_r_mean = np.vectorize(avg_rad)(rad)
    image_ft_r_mean[0] = image_ft_r_mean[0]/(image_ft_r_mean.size) # rescale DC term

    # scale the radial profile for storage (optionally log-scaled); the
    # previous `mage_ft_list` typo left the list empty when log_in was set
    if log_in:
        image_ft_list = [np.log(1e-12+image_ft_r_mean)[np.newaxis,...]/np.log(255.0)]
    else:
        image_ft_list = [image_ft_r_mean[np.newaxis,...]/255.0]

    # start a new (hologram, channel, radius) stack on the first hologram,
    # otherwise append this hologram along axis 0 (the else-branch was
    # commented out, which made the cell a SyntaxError)
    if im == 0:
        image_ft = da.array(np.concatenate(image_ft_list,axis=0)[np.newaxis,...])
    else:
        image_ft = da.concatenate([image_ft,np.concatenate(image_ft_list,axis=0)[np.newaxis,...]],axis=0)


In [None]:
# Preprocess every split: for each hologram build the particle-diameter
# histogram, the particle-radius moments and (optionally) the radially
# averaged |FFT| input profile, then write one NetCDF file per split.

# Orders of the radius moments stored in `histogram_moments`.
MOMENT_ORDERS = [0, 1, 2, 3, 4, 5, 6]


def radial_mean_abs(field, radial_bin, n_radii):
    """Mean of |field| over unit-width annuli around the grid centre.

    For each r in range(n_radii) this equals
    ``np.abs(field[(rpix >= r-.5) & (rpix < r+.5)]).mean()`` when
    ``radial_bin = floor(rpix + 0.5)``. It is computed in one pass with
    ``np.bincount`` instead of two full-image boolean masks per radius.

    Parameters
    ----------
    field : array-like, 2-D
        Image to average (here the hologram FFT); same shape as `radial_bin`.
    radial_bin : ndarray of non-negative int, 2-D
        Annulus index of every pixel.
    n_radii : int
        Number of annuli returned (radii 0 .. n_radii-1).

    Returns
    -------
    ndarray of float, shape (n_radii,)
        Mean magnitude per annulus; NaN for an annulus with no pixels.

    Raises
    ------
    ValueError
        If `field` and `radial_bin` have different shapes.
    """
    magnitude = np.abs(np.asarray(field))
    if magnitude.shape != radial_bin.shape:
        raise ValueError(f'field shape {magnitude.shape} does not match '
                         f'radial grid shape {radial_bin.shape}')
    bins = radial_bin.ravel()
    weights = magnitude.ravel()
    in_range = bins < n_radii  # pixels beyond the last radius are ignored
    counts = np.bincount(bins[in_range], minlength=n_radii)
    sums = np.bincount(bins[in_range], weights=weights[in_range], minlength=n_radii)
    with np.errstate(invalid='ignore'):  # empty annulus -> 0/0 -> NaN
        return sums / counts


for fn in data_file_list:

    # load the dataset file
    with xr.open_dataset(path_data+fn, chunks={'hologram_number':1}) as ds:

        hologram_count = ds['hologram_number'].values.size # 5000, 1000, 1000

        if 'training' in fn:
            file_use = '_training_'
        elif 'validation' in fn:
            file_use = '_validation_'
        elif 'test' in fn:
            file_use = '_test_'
        else:
            file_use = '_'

        file_base = f'histogram{file_use}data_{hologram_count}count{run_date_str}'

        # The particle tables are flat over all holograms: load them once
        # instead of re-reading them from the file for every hologram.
        hid = ds['hid'].values      # hologram id of each particle (1-based)
        diameters = ds['d'].values  # particle diameter (the histogrammed quantity)

        print("   histogram bins: ")
        print("      "+str(histogram_centers.size))
        print("       ["+str(histogram_centers[0])+', '+str(histogram_centers[-1])+']')
        print()

        # report the range of the diameters being histogrammed ('z' was
        # printed here before, which is not the quantity binned below)
        print('   max particle size: %d'%diameters.max())
        print('   min particle size: %d'%diameters.min())
        print()

        # Centred pixel grids: rows follow 'xsize' and columns 'ysize', so
        # rpix has shape (xsize, ysize) like each hologram image (the original
        # notes give 1200 x 800, i.e. rows -600..599 and columns -400..399).
        ypix = np.arange(ds.coords['xsize'].size)-ds.coords['xsize'].size//2
        xpix = np.arange(ds.coords['ysize'].size)-ds.coords['ysize'].size//2
        rpix = np.sqrt(xpix[np.newaxis,:]**2+ypix[:,np.newaxis]**2)

        # integer radii at which the radial mean is evaluated
        rad = np.arange(max(ypix.size//2, xpix.size//2))
        # floor(rpix + 0.5) == r  <=>  r-0.5 <= rpix < r+0.5: the same unit-width
        # rings the former per-radius boolean masks selected
        radial_bin = np.floor(rpix + 0.5).astype(np.intp)

        # process at most max_hist_count holograms from this file
        hologram_ids = ds['hologram_number'].values[:max_hist_count]
        n_holograms = hologram_ids.size
        if n_holograms == 0:
            print(f'no holograms in {fn}; skipping')
            continue

        # Collect per-hologram rows in lists and stack once at the end; growing
        # dask arrays with da.concatenate in the loop was quadratic, and keying
        # the first row on `im == 0` broke for files whose ids do not start at 0.
        hist_rows, moment_rows, ft_rows = [], [], []

        print("Performing Fourier Transform")
        ft_start_time = datetime.datetime.now()
        for idx, im in enumerate(hologram_ids):
            # particles of this hologram (hid is 1-based, hologram_number 0-based)
            particle_index = np.nonzero(hid == im + 1)[0]
            particle_diam = diameters[particle_index]

            # moments of the particle radius d/2
            radii = particle_diam / 2
            moment_rows.append(np.array([np.sum(radii**m) for m in MOMENT_ORDERS]))

            # particle-diameter histogram for this hologram
            hist0 = np.histogram(particle_diam, bins=histogram_edges)[0]
            if log_hist:
                hist0 = np.log(hist0+1e-12)
            hist_rows.append(hist0)

            if FourierTransform:
                image0 = ds['image'].sel(hologram_number=im)  # select the hologram image
                image_ft0 = FO.OpticsFFT(image0)  # FFT the image

                # radial mean of |FT|, then rescale the DC term
                image_ft_r_mean = radial_mean_abs(image_ft0, radial_bin, rad.size)
                image_ft_r_mean[0] = image_ft_r_mean[0]/(image_ft_r_mean.size)

                # scale for storage (the old `mage_ft_list` typo dropped the log branch)
                if log_in:
                    ft_profile = np.log(1e-12+image_ft_r_mean)/np.log(255.0)
                else:
                    ft_profile = image_ft_r_mean/255.0
                ft_rows.append(ft_profile[np.newaxis, :])  # (input_channels=1, rsize)

            # periodic progress instead of two prints per hologram
            if (idx + 1) % 100 == 0 or idx + 1 == n_holograms:
                print(f'completed hologram {im} of {hologram_count}')
        ft_stop_time = datetime.datetime.now()
        print(f'   processing time: {ft_stop_time - ft_start_time}')

        histogram = np.stack(hist_rows)            # (hologram, bin)
        histogram_moments = np.stack(moment_rows)  # (hologram, moment)

        print('histogram shape:')
        print(histogram.shape)

        holo_num = ds.coords['hologram_number'].copy()
        image_dims = ds['image'].dims
        print('image dimensions')
        print(image_dims)
        print('image shape')
        print(ds['image'].shape)
        in_chan = ['abs']

    # coordinate for the holograms actually processed (max_hist_count may truncate)
    holo_coord = holo_num[:n_holograms]

    hist_bin_cent = xr.DataArray(histogram_centers,
                                 coords={'histogram_bin_centers':histogram_centers},
                                 dims=('histogram_bin_centers',))

    hist_bin_edges = xr.DataArray(histogram_edges,
                                  coords={'histogram_bin_edges':histogram_edges},
                                  dims=('histogram_bin_edges',))

    histogram_moments_da = xr.DataArray(histogram_moments,
                                        dims=('hologram_number','moments'),
                                        coords={'hologram_number':holo_coord,
                                                'moments':MOMENT_ORDERS})

    histogram = histogram[...,np.newaxis]
    print('histogram shape')
    print(histogram.shape)
    histogram_da = xr.DataArray(histogram,
                dims=('hologram_number','histogram_bin_centers','output_channels'),
                coords={'hologram_number':holo_coord,
                        'histogram_bin_centers':hist_bin_cent,
                        'output_channels':['hist']})

    data_vars = {'histogram':histogram_da,
                 'histogram_bin_centers':hist_bin_cent,
                 'histogram_bin_edges':hist_bin_edges}
    # the FT input only exists when FourierTransform is enabled (previously a
    # NameError on `image_ft` otherwise)
    if FourierTransform:
        image_ft = np.stack(ft_rows)  # (hologram, input_channels, rsize)
        data_vars['input_image'] = xr.DataArray(image_ft,
                                    coords={'hologram_number':holo_coord,
                                            'input_channels':in_chan,
                                            'rsize':rad},
                                    dims=[image_dims[0]]
                                        +['input_channels','rsize'])
    data_vars['histogram_moments'] = histogram_moments_da

    # `settings` is never defined in this notebook (this raised NameError);
    # record the source file name instead.
    # NOTE(review): confirm downstream readers expect just the file name here.
    preproc_ds = xr.Dataset(data_vars, attrs={'data_file':fn})

    print("Writing to netcdf")
    print(path_save+file_base+".nc")
    preproc_ds.to_netcdf(path_save+file_base+".nc")

# TODO: also persist the run configuration next to each output file (a JSON
# dump was sketched here, but it referenced `settings`/`paths` dicts that this
# notebook never defines).

print('write complete')


In [4]:
# Run the script version of this preprocessing (presumably preprocess_FT_radavg.py
# mirrors the cells above; TODO confirm).
# NOTE(review): its output below reports 19 histogram bins spanning 5.0-185.0,
# not the 39 bins (2.5-192.5) configured in this notebook, so the script's
# settings differ from the ones here.
!python preprocess_FT_radavg.py

   histogram bins: 
      19
       [5.0, 185.0]

   max particle size: 232
   min particle size: 0

Performing Fourier Transform
58
completed hologram 0 of 5000
51
completed hologram 1 of 5000
76
completed hologram 2 of 5000
89
completed hologram 3 of 5000
65
completed hologram 4 of 5000
95
completed hologram 5 of 5000
59
completed hologram 6 of 5000
82
completed hologram 7 of 5000
99
completed hologram 8 of 5000
50
completed hologram 9 of 5000
54
completed hologram 10 of 5000
80
completed hologram 11 of 5000
80
completed hologram 12 of 5000
59
completed hologram 13 of 5000
66
completed hologram 14 of 5000
55
completed hologram 15 of 5000
98
completed hologram 16 of 5000
61
completed hologram 17 of 5000
54
completed hologram 18 of 5000
86
completed hologram 19 of 5000
68
completed hologram 20 of 5000
68
completed hologram 21 of 5000
92
completed hologram 22 of 5000
55
completed hologram 23 of 5000
78
completed hologram 24 of 5000
52
completed hologram 25 of 5000
95
completed hologram 