In [18]:
from functools import partial
import numpy as np
import random
import time
import tqdm
import xarray as xr
import torch
import yaml
import sys
import os
from holodecml.models import load_model
from holodecml.propagation import InferencePropagator
from sklearn.model_selection import train_test_split
import gc
import logging
import warnings
warnings.filterwarnings("ignore")

import torch.nn.functional as F

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

In [19]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [20]:
class OGPropagator(InferencePropagator):
    
    def get_sub_images_labeled(self,
                               image_tnsr,
                               z_sub_set,
                               z_counter,
                               xp, yp, zp, dp,
                               infocus_mask,
                               z_part_bin_idx,
                               batch_size=32,
                               return_arrays=False,
                               return_metrics=False,
                               thresholds=None,
                               obs_threshold=None):

        with torch.no_grad():

            # build the torch tensor for reconstruction
            z_plane = torch.tensor(
                z_sub_set*1e-6, device=self.device).unsqueeze(-1).unsqueeze(-1)

            # reconstruct the selected planes
            E_out = self.torch_holo_set(image_tnsr, z_plane)

            if self.color_dim == 2:
                stacked_image = torch.cat([
                    torch.abs(E_out).unsqueeze(1), torch.angle(E_out).unsqueeze(1)], 1)
            elif self.color_dim == 1:
                stacked_image = torch.abs(E_out).unsqueeze(1)
            else:
                raise OSError(f"Unrecognized color dimension {self.color_dim}")
            stacked_image = self.apply_transforms(
                stacked_image.squeeze(0)).unsqueeze(0)

            size = (E_out.shape[1], E_out.shape[2])
            true_output = torch.zeros(size).to(self.device)
            pred_proba = torch.zeros(size).to(self.device)
            counter = torch.zeros(size).to(self.device)

            chunked = np.array_split(
                list(self.idx2slice.items()),
                int(np.ceil(len(self.idx2slice) / batch_size))
            )

            inputs, masks, preds = [], [], []
            for z_idx in range(E_out.shape[0]):

                unet_mask = torch.zeros(E_out.shape[1:]).to(
                    self.device)  # initialize the UNET mask
                # locate all particles in this plane
                part_in_plane_idx = np.where(
                    z_part_bin_idx == z_idx+z_counter)[0]

                # build the UNET mask for this z plane
                for part_idx in part_in_plane_idx:
                    unet_mask += torch.from_numpy(
                        (self.y_arr[None, :]*1e6-yp[part_idx])**2 +
                        (self.x_arr[:, None]*1e6-xp[part_idx]
                         )**2 < (dp[part_idx]/2)**2
                    ).float().to(self.device)

                worker = partial(
                    self.collate_masks,
                    image=stacked_image[z_idx, :].float(),
                    mask=unet_mask
                )

                for chunk in chunked:
                    slices, x, true_mask_tile = worker(chunk)
                    pred_proba_tile = self.model(x).squeeze(1)

                    for k, ((row_idx, col_idx), (row_slice, col_slice)) in enumerate(slices):
                        counter[row_slice, col_slice] += 1
                        true_output[row_slice,
                                    col_slice] += true_mask_tile[k]
                        pred_proba[row_slice,
                                   col_slice] += pred_proba_tile[k]
            
                #print(pred_proba)
                                
        return 1

In [21]:
class CustomPropagator(InferencePropagator):
    
    def get_sub_images_labeled(self,
                               image_tnsr,
                               z_sub_set,
                               z_counter,
                               xp, yp, zp, dp,
                               infocus_mask,
                               z_part_bin_idx,
                               batch_size=32,
                               return_arrays=False,
                               return_metrics=False,
                               thresholds=None,
                               obs_threshold=None):

        with torch.no_grad():

            # build the torch tensor for reconstruction
            z_plane = torch.tensor(
                z_sub_set*1e-6, device=self.device).unsqueeze(-1).unsqueeze(-1)

            # reconstruct the selected planes
            E_out = self.torch_holo_set(image_tnsr, z_plane)

            if self.color_dim == 2:
                image = torch.cat([
                    torch.abs(E_out).unsqueeze(1), torch.angle(E_out).unsqueeze(1)], 1)
            elif self.color_dim == 1:
                image = torch.abs(E_out).unsqueeze(1)
            else:
                raise OSError(f"Unrecognized color dimension {self.color_dim}")
            stacked_image = self.apply_transforms(
                image.squeeze(0)).unsqueeze(0)

            size = (E_out.shape[1], E_out.shape[2])
            for z_idx in range(E_out.shape[0]):

                unet_mask = torch.zeros(E_out.shape[1:]).to(
                    self.device)  # initialize the UNET mask
                # locate all particles in this plane
                part_in_plane_idx = np.where(
                    z_part_bin_idx == z_idx+z_counter)[0]

                # build the UNET mask for this z plane
                for part_idx in part_in_plane_idx:
                    unet_mask += torch.from_numpy(
                        (self.y_arr[None, :]*1e6-yp[part_idx])**2 +
                        (self.x_arr[:, None]*1e6-xp[part_idx]
                         )**2 < (dp[part_idx]/2)**2
                    ).float().to(self.device)
                    
                # Calculate the expected number of tiles
                Nx_tile = self.Nx//self.step_size
                Ny_tile = self.Ny//self.step_size

                # For a complex field E_out, this generalizes the functional 
                # operations that define the color channels
                # the result E_out_tile has dimensions
                # (z, tile, color, x, y)
                
                
                func_lst = [torch.abs, torch.angle]
                image = torch.stack([
                    F.unfold(fnc(image), (self.tile_size, self.tile_size), stride=self.step_size).permute(1,2,0).reshape(
                        self.tile_size, self.tile_size, image.shape[0], -1) for fnc in func_lst
                ], dim = -1).permute(2,3,4,0,1).squeeze(0)[:, 0:1, :, :]
                
                # Make mask tiles
#                 mask_tile = F.unfold(
#                     unet_mask.unsqueeze(0).unsqueeze(0), 
#                     (self.tile_size, self.tile_size), 
#                     stride=self.step_size
#                 )
#                 mask_tile = mask_tile.reshape(
#                     mask_tile.shape[2], self.tile_size, self.tile_size
#                 )
                
                # predict then fold the predictions back into full size         
                pred_masks = torch.cat([
                    self.model(tile.float())
                    for tile in np.array_split(image, image.shape[0] // batch_size)
                ], axis = 0)
                # (no tiles, 1, tile size, tile size

                pred_masks = pred_masks.permute(2,3,0,1).reshape(
                    self.tile_size*self.tile_size, pred_masks.shape[1], pred_masks.shape[0]
                ).permute(1, 0, 2)
                # (1, tile size * tile size, no of tiles)
                
                pred_masks = F.fold(pred_masks,
                    output_size = (Nx_tile*self.step_size, Ny_tile*self.step_size), 
                    kernel_size = self.tile_size, 
                    stride = self.step_size
                )
                # (1, 1, 4864, 3200)
                #print(pred_masks)
                
        return 1

In [22]:
is_cuda = torch.cuda.is_available()
device = torch.device("cpu") if not is_cuda else torch.device("cuda:0")

In [23]:
fn_name = "../config/model_segmentation.yml"

In [24]:
with open(fn_name) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [25]:
seed = conf["seed"]
n_bins = conf["data"]["n_bins"]
tile_size = conf["data"]["tile_size"]
step_size = conf["data"]["step_size"]
marker_size = conf["data"]["marker_size"]
total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

# Do not load the image transformations
transform_mode = "None"
tile_transforms = None
color_dim = conf["model"]["in_channels"]
batch_size = conf["inference"]["batch_size"]
inference_mode = conf["inference"]["mode"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"

# Load the model
model = load_model(conf["model"]).to(device).eval()

In [26]:
data_set = conf["style"]["raw"]["path"]
#holograms_per_dataset = style_conf["data"]["raw"]["holograms_per_dataset"]
tiles_per_reconstruction = conf["style"]["raw"]["tiles_per_reconstruction"]
reconstruction_per_hologram = conf["style"]["raw"]["reconstruction_per_hologram"]
save_path = conf["style"]["raw"]["save_path"]
sampler = conf["style"]["raw"]["sampler"]
name_tag = f"{sampler}_{name_tag}"

In [27]:
seed_everything(seed)

In [28]:
prop = CustomPropagator(
    data_set,
    n_bins=n_bins,
    color_dim=color_dim,
    tile_size=tile_size,
    step_size=step_size,
    marker_size=marker_size,
    transform_mode=transform_mode,
    device=device,
    model=model,
    mode=inference_mode,
    probability_threshold=0.5,
    transforms=tile_transforms
)

og_prop = OGPropagator(
    data_set,
    n_bins=n_bins,
    color_dim=color_dim,
    tile_size=tile_size,
    step_size=step_size,
    marker_size=marker_size,
    transform_mode=transform_mode,
    device=device,
    model=model,
    mode=inference_mode,
    probability_threshold=0.5,
    transforms=tile_transforms
)

# Create a list of z-values to propagate to
z_list = prop.create_z_plane_lst(planes_per_call=1)
random.shuffle(z_list)
z_list = z_list[:reconstruction_per_hologram]

In [29]:
h_range = prop.h_ds.hologram_number.values
h_range_prime = list(set(h_range) - set(range(10, 20)))

# split into train/test/valid
h_train, rest_data = train_test_split(h_range_prime, train_size=0.8)
h_valid, h_test = train_test_split(rest_data, test_size=0.45)
h_test += list(range(10, 20))

In [30]:
h_splits = [h_train, h_valid, h_test]
split_names = ["train", "valid", "test"]

for split, h_split in zip(split_names, h_splits):

    total = len(h_split) * tiles_per_reconstruction * reconstruction_per_hologram
    X = np.zeros((total, 512, 512))
    Y = np.zeros((total, 1))

    c = 0
    # Main loop to call the generator, predict with the model, and aggregate and save the results
    for nc, h_idx in enumerate(h_split):
        
        # Create a list of z-values to propagate to
        z_list = prop.create_z_plane_lst(planes_per_call=1)
        random.shuffle(z_list)
        z_list = z_list[:reconstruction_per_hologram]
        
        planes_processed = n_bins
        inference_generator = prop.get_next_z_planes_labeled(
            h_idx,
            z_list,
            batch_size=batch_size,
            thresholds=[0.5],
            return_arrays=False,
            return_metrics=False,
            obs_threshold=0.5,
            start_z_counter=planes_processed
        )

        t0 = time.time()
        count = 0
        for z_idx, results_dict in enumerate(inference_generator):
            count += 1 
            if count == 10:
                break
        print(time.time() - t0)
            
            
        inference_generator = og_prop.get_next_z_planes_labeled(
            h_idx,
            z_list,
            batch_size=batch_size,
            thresholds=[0.5],
            return_arrays=False,
            return_metrics=False,
            obs_threshold=0.5,
            start_z_counter=n_bins
        )
        
        t0 = time.time()
        count = 0
        for z_idx, results_dict in enumerate(inference_generator):
            count += 1 
            if count == 10:
                break
           
        print(time.time() - t0)
        break
    break

55.23858022689819
55.917946100234985


In [31]:
torch.cuda.empty_cache()
del prop
gc.collect()

89