#### Libraries, Devices

In [1]:
import warnings

warnings.filterwarnings("ignore")

import itertools
import os
import random
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
import xarray as xr
from torch.utils.data import Dataset
from tqdm import tqdm

In [2]:
# Ask PyTorch to release any cached, currently unused GPU memory before starting.
torch.cuda.empty_cache()

In [3]:
# Pick the compute device: the current CUDA device when one is available,
# otherwise the CPU.
is_cuda = torch.cuda.is_available()
print(is_cuda)

if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")
print(device)

if is_cuda:
    # Let cuDNN benchmark its algorithms for the input sizes it sees.
    # (cudnn.deterministic is deliberately left at its default here.)
    torch.backends.cudnn.benchmark = True

True
cuda:0


#### Load Data

In [4]:
# real_path: HOLODEC sensor data
# synthetic_path: synthetic sensor data with no noise or processing
# NOTE(review): both are absolute paths on one specific filesystem; the
# notebook will not run elsewhere without editing them.
real_path = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/real_holograms_CSET_RF07_20150719_200000-210000.nc"
synthetic_path = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/synthetic_holograms_500particle_gamma_4872x3248_training.nc"

In [5]:
# Open both files with xarray for inspection in the cells below.
synth_data = xr.open_dataset(synthetic_path)
real_data = xr.open_dataset(real_path)

In [6]:
synth_data
# Observations from the dataset summary:
# - 50,000 particles across 100 holograms
# - each image is 4872 x 3248
# - the synthetic data deliberately has 500 particles per hologram

In [7]:
real_data
# Observations from the dataset summary:
# - 158,981 particles in 559 holograms
# - same resolution as the synthetic data, 4872 x 3248

#### WavePropagator Class

In [8]:
# wave propagator class, from propagator.py

class WavePropagator(object):
    """Propagate holograms from a dataset file to chosen z-planes.

    The constructor opens the dataset at ``data_path``, reads the optical
    parameters stored in its attributes (``dx``, ``dy``, ``Nx``, ``Ny``,
    ``lambda`` and, when present, ``zMin``/``zMax``) and precomputes the
    z-bin edges/centers, the pixel coordinate axes and the FFT frequency axes
    used by :meth:`torch_holo_set`.

    Parameters
    ----------
    data_path : str
        Path passed to ``xr.open_dataset``.
    n_bins : int
        Number of z bins between ``zMin`` and ``zMax``.
    tile_size : int
        Size of tiled images in pixels.
    step_size : int
        Shift (in pixels) between consecutive tiles; must not exceed
        ``tile_size``.
    marker_size : float
        UNET gaussian marker width (standard deviation) in um.
    transform_mode : {None, "standard", "min-max"}
        Optional normalisation applied to the input field before propagation.
    device : str or torch.device
        Device on which the frequency axes are created.
    """

    def __init__(self,
                 data_path,
                 n_bins=1000,
                 tile_size=512,
                 step_size=128,
                 marker_size=10,
                 transform_mode=None,
                 device="cpu"):

        self.h_ds = xr.open_dataset(data_path)

        if 'zMin' in self.h_ds.attrs:
            self.zMin = self.h_ds.attrs['zMin']  # minimum z in sample volume
            self.zMax = self.h_ds.attrs['zMax']  # maximum z in sample volume
        else:  # some of the raw data does not have this parameter
            warnings.warn(
                "Dataset has no 'zMin' attribute; falling back to "
                "zMin=0.014, zMax=0.158.")
            self.zMin = 0.014
            self.zMax = 0.158

        self.n_bins = n_bins
        # Histogram bin edges: the attribute values scaled by 1e6
        # (presumably metres -> micrometres; confirm against the dataset).
        self.z_bins = np.linspace(
            self.zMin, self.zMax, n_bins+1)*1e6
        # Histogram bin centers (midpoint of each pair of edges).
        self.z_centers = self.z_bins[:-1] + 0.5 * \
            np.diff(self.z_bins)

        self.tile_size = tile_size  # size of tiled images in pixels
        self.step_size = step_size  # amount that we shift the tile to make a new tile
        # UNET gaussian marker width (standard deviation) in um
        self.marker_size = marker_size
        self.device = device

        # step_size is not allowed to be larger than the tile_size
        assert self.tile_size >= self.step_size, \
            "step_size must not exceed tile_size"

        self.dx = self.h_ds.attrs['dx']      # horizontal resolution
        self.dy = self.h_ds.attrs['dy']      # vertical resolution
        self.Nx = int(self.h_ds.attrs['Nx'])  # number of horizontal pixels
        self.Ny = int(self.h_ds.attrs['Ny'])  # number of vertical pixels
        self.lam = self.h_ds.attrs['lambda']  # wavelength
        self.image_norm = 255.0
        self.transform_mode = transform_mode
        # Pixel-center coordinates, centered on the middle of the image.
        self.x_arr = np.arange(-self.Nx//2, self.Nx//2)*self.dx
        self.y_arr = np.arange(-self.Ny//2, self.Ny//2)*self.dy

        # Tile origins spaced step_size pixels apart, scaled by 1e6 like z_bins.
        self.tile_x_bins = np.arange(-self.Nx//2,
                                     self.Nx//2, self.step_size)*self.dx*1e6
        self.tile_y_bins = np.arange(-self.Ny//2,
                                     self.Ny//2, self.step_size)*self.dy*1e6

        # Frequency axes shaped (1, Nx, 1) and (1, 1, Ny) so they broadcast
        # against a (Nz, 1, 1) tensor of propagation distances.
        self.fx = torch.fft.fftfreq(
            self.Nx, self.dx, device=self.device).unsqueeze(0).unsqueeze(2)
        self.fy = torch.fft.fftfreq(
            self.Ny, self.dy, device=self.device).unsqueeze(0).unsqueeze(1)

    def torch_holo_set(self,
                       Ein: torch.Tensor,
                       z_tnsr: torch.Tensor):
        """
        Propagate an electric field to one or more distances z.

        Ein: torch.Tensor
        - input field with trailing dims (Nx, Ny); it is never modified
          in place, even when a transform_mode is active

        z_tnsr: torch.Tensor
        - distances to propagate the wave Ein; expected to have dims
          (Nz, 1, 1) where Nz is the number of z planes (a tensor of shape
          (1,) also broadcasts and yields a single plane)

        The frequency axes (self.fx, self.fy) and the wavelength (self.lam)
        are taken from the instance.

        returns: complex torch.Tensor with dims (Nz, Nx, Ny)

        Note the torch.fft library uses dtype=torch.complex64 for float32
        input. This may be an issue for GPU implementation.
        """

        # Normalise out of place: `.float()` returns the very same tensor when
        # the input is already float32, so in-place arithmetic here would
        # silently modify the caller's data.
        if self.transform_mode == "standard":
            Ein = Ein.float()
            Ein = (Ein - torch.mean(Ein)) / torch.std(Ein)

        elif self.transform_mode == "min-max":
            Ein = Ein.float()
            Ein = Ein - torch.min(Ein)
            Ein = Ein / torch.max(Ein)

        Etfft = torch.fft.fft2(Ein)
        # Multiply by the propagation phase factor for every requested z.
        # NOTE(review): if lam**2 * (fx**2 + fy**2) exceeds 1 the real-valued
        # sqrt yields NaN; that case is not handled here.
        Eofft = Etfft*torch.exp(1j*2*np.pi*z_tnsr/self.lam *
                                torch.sqrt(1-self.lam**2*(self.fx**2+self.fy**2)))

        # It might be helpful if we could omit this step.  It would save an inverse fft.
        Eout = torch.fft.ifft2(Eofft)

        return Eout

#### Frame Lookahead

In [9]:
def create_mask(prop, h_idx, z_idx):
    """Rasterise the particles of one hologram that fall in one z bin.

    Particles are selected where ``prop.h_ds["hid"] == h_idx + 1`` (the file
    appears to use 1-based hologram ids) and where the z bin given by
    ``np.digitize(z, prop.z_bins) - 1`` equals ``z_idx``. Every selected
    particle adds 1.0 to each pixel strictly inside the circle of diameter
    ``d`` centered on (x, y), so overlapping particles sum to values above 1.

    Returns
    -------
    (torch.Tensor, int)
        The mask with shape (1, len(prop.x_arr), len(prop.y_arr)) and the
        number of particles drawn into it.
    """
    dataset = prop.h_ds
    in_hologram = dataset["hid"] == (h_idx + 1)

    # Keep only the particles that belong to the requested hologram.
    xs, ys, zs, diameters = (
        dataset[name].values[in_hologram] for name in ("x", "y", "z", "d")
    )

    bin_of_particle = np.digitize(zs, prop.z_bins) - 1
    canvas = np.zeros((prop.x_arr.shape[0], prop.y_arr.shape[0]))

    # Positions of the particles whose z bin is the requested plane.
    selected = np.nonzero(z_idx == bin_of_particle)[0]

    for x_c, y_c, diam in zip(xs[selected], ys[selected], diameters[selected]):
        dy = prop.y_arr[None, :] * 1e6 - y_c
        dx = prop.x_arr[:, None] * 1e6 - x_c
        radius_sq = (diam / 2)**2
        canvas += ((dy**2 + dx**2) < radius_sq).astype(float)

    return torch.from_numpy(canvas).unsqueeze(0), len(selected)


class FullSizeHolograms(Dataset):
    """Map-style dataset of full-size holograms propagated to single z-planes.

    The dataset has one item per (hologram, z-plane) pair, ordered so that
    ``idx = hologram_idx * n_bins + plane_idx``. Each item holds the target
    plane followed by ``lookahead`` subsequent planes as context.

    Parameters
    ----------
    file_path : str
        Dataset file handed to ``WavePropagator``.
    n_bins : int
        Number of z-planes per hologram.
    shuffle : bool
        When True, ``__getitem__`` ignores the requested index and draws a
        random one instead.
    device : str or torch.device
        Device on which the propagation is computed.
    lookahead : int
        Number of extra planes after the target plane to include.
    """

    def __init__(self, file_path, n_bins = 1000, shuffle = False, device = "cpu", lookahead = 0):

        # num of waveprop windows
        self.n_bins = n_bins
        # device used
        self.device = device
        # shuffle frames
        self.shuffle = shuffle
        # num of frames to look ahead
        self.lookahead = lookahead
        # wavepropagator object on device
        self.propagator = WavePropagator(file_path, n_bins = n_bins, device = device)

    def __len__(self):
        """Number of (hologram, z-plane) pairs."""
        return len(self.propagator.h_ds.hologram_number) * self.n_bins

    def __getitem__(self, idx):
        """Return ``(amplitude, phase, mask)`` arrays for one item.

        Each array has shape (lookahead + 1, Nx, Ny). Frame 0 is the target
        plane and the rest are the following planes. Near the end of the
        z-range the lookahead indices are clamped to the last plane, so the
        last plane is repeated and the channel count stays constant.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)

        if self.shuffle:
            idx = random.randrange(n_items)

        idx = int(idx)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")

        # idx enumerates planes fastest: hologram = idx // n_bins and
        # plane = idx % n_bins, so every (hologram, plane) pair is reachable.
        hologram_idx, plane_idx = divmod(idx, self.n_bins)

        # Target plane plus lookahead context, clamped at the last plane.
        plane_indices = np.minimum(
            np.arange(plane_idx, plane_idx + self.lookahead + 1),
            self.n_bins - 1,
        )
        z_props = self.propagator.z_centers[plane_indices]

        # select hologram and move it to the device once for all planes
        image = self.propagator.h_ds["image"].isel(hologram_number=hologram_idx).values
        hologram = torch.from_numpy(image).to(self.device)

        prop_synths = []
        prop_phases = []
        masks = []
        for z_prop, z_ind in zip(z_props, plane_indices):
            # z_centers carry a 1e6 scale factor; undo it for the propagator.
            image_prop = self.propagator.torch_holo_set(
                hologram,
                torch.tensor([z_prop*1e-6], dtype=torch.float32, device=self.device)
            )
            # ABS (x-input)
            prop_synths.append(torch.abs(image_prop))
            # Phase (x-input)
            prop_phases.append(torch.angle(image_prop))
            # Mask (y-label); the particle count is not needed here
            mask, _ = create_mask(self.propagator, hologram_idx, z_ind)
            masks.append(mask)

        # cat target frames with lookahead context frames, convert to ndarrays
        synth_window = torch.cat(prop_synths, dim = 0).cpu().numpy()
        phases_window = torch.cat(prop_phases, dim = 0).cpu().numpy()
        masks_window = torch.cat(masks, dim = 0).cpu().numpy()

        return synth_window, phases_window, masks_window

In [10]:
# Take in a stack of frames and zero-pad it on the far side so that the
# spatial dimensions are evenly divisible by 32, as required by the smp models.
def pad_reshape(image, padded_height=154 * 32, padded_width=102 * 32):
    """Zero-pad a (C, H, W) stack on the high-index side of H and W.

    Parameters
    ----------
    image : array-like, shape (C, H, W)
        Frames to pad.
    padded_height, padded_width : int
        Output spatial size. The defaults (4928 x 3264) are multiples of 32
        that hold the 4872 x 3248 holograms used in this notebook.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (C, padded_height, padded_width) with the
        input copied into the top-left corner and zeros elsewhere.

    Raises
    ------
    ValueError
        If ``image`` is not 3-D or is larger than the padded size.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(f"expected a (C, H, W) array, got shape {image.shape}")

    n_frames, height, width = image.shape
    if height > padded_height or width > padded_width:
        raise ValueError(
            f"image of size {height}x{width} does not fit into "
            f"{padded_height}x{padded_width}")

    # Allocate float32 directly: the result is float32 anyway, and this avoids
    # a full-size float64 intermediate plus an extra copy.
    padded = np.zeros((n_frames, padded_height, padded_width), dtype=np.float32)
    padded[:, :height, :width] = image
    return torch.from_numpy(padded)

# Pad two frame stacks with pad_reshape() and stack them along the channel
# axis so they can be fed to the model as a single batch.
def reshape_concat(z, phase):
    """Build a one-sample model input from two (C, H, W) frame stacks.

    Both stacks (e.g. amplitude and phase) are padded with ``pad_reshape``,
    concatenated along dim 0 (``z`` first), and given a leading batch
    dimension of size 1.

    Returns a torch.Tensor of shape (1, C_z + C_phase, H_pad, W_pad).
    """
    first_padded = pad_reshape(z)
    second_padded = pad_reshape(phase)
    channels = torch.cat((first_padded, second_padded), dim=0)
    return channels.unsqueeze(0)

In [11]:
# Number of extra z-planes after the target plane included as context.
lookahead = 1

# Dataset over the synthetic holograms; propagation runs on `device`.
train_dataset = FullSizeHolograms(synthetic_path, shuffle = False, device = device, n_bins = 1000, lookahead = lookahead)

In [12]:
# Fetch one item (amplitude window, phase window, mask window) as a sanity check.
sw, pw, m = train_dataset[200]
print(sw.shape)

(2, 4872, 3248)


In [13]:
# Wrap the dataset in a DataLoader. batch_size and num_workers are left at
# their defaults; shuffling is done here rather than inside the dataset
# (train_dataset was created with shuffle=False).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    #batch_size=8,
    #num_workers=1,
    pin_memory=True,
    shuffle=True,
)

In [14]:
# Pull a single batch to check shapes: x is the amplitude window, xp the
# phase window and y the mask window returned by the dataset.
for k, (x, xp, y) in enumerate(train_loader):
    print(x.shape, y.shape)
    break

torch.Size([1, 2, 4872, 3248]) torch.Size([1, 2, 4872, 3248])


#### Select and Load Model
To select a model, convert its cell from Raw to Code (Python) and run it.

In [15]:
# Free cached GPU memory before building the model.
torch.cuda.empty_cache()
model = smp.PSPNet(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels= 2 * (lookahead + 1),     # amplitude + phase channels for the target frame and each lookahead frame
    classes=1,                      # model output channels (number of classes in your dataset)
    activation = "sigmoid"
)

# Author's note: after loading data, estimated size of around 13753 MiB
model = model.to(device)

In [16]:
# Shell command: report the name and free memory of each GPU.
!nvidia-smi --query-gpu=gpu_name,memory.free --format=csv,noheader

Tesla V100-SXM2-32GB, 32507 MiB
Tesla V100-SXM2-32GB, 17355 MiB


In [17]:
# One sample: z = amplitude window, a = phase window, m = mask window.
z, a, m = train_dataset[200]

In [18]:
# Pad and stack the amplitude and phase windows into a single model input.
data_sample = reshape_concat(z,a)
#print(z.shape, a.shape)
print(data_sample.shape)

torch.Size([1, 4, 4928, 3264])


In [19]:
# Smoke-test one forward pass. reshape_concat pads the frames and stacks
# amplitude + phase into a (1, 2 * (lookahead + 1), 4928, 3264) batch.

print(z.shape, a.shape)
img = reshape_concat(z,a)
print(img.shape)

# NOTE(review): this forward pass runs with autograd enabled, so activations
# are kept on the GPU; wrap it in torch.no_grad() if preds is only inspected.
preds = model(img.to(device))


(2, 4872, 3248) (2, 4872, 3248)
torch.Size([1, 4, 4928, 3264])


In [20]:
# Shell command: report the name and free memory of each GPU (after the forward pass).
!nvidia-smi --query-gpu=gpu_name,memory.free --format=csv,noheader

Tesla V100-SXM2-32GB, 32507 MiB
Tesla V100-SXM2-32GB, 5003 MiB
