In [1]:
from holodecml.models import load_model
from functools import partial
from argparse import ArgumentParser
import torch.multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
import subprocess
import traceback
import random
import logging
import signal
import joblib
import scipy
import sys
import yaml
import time
import tqdm
import glob
import os
import warnings
warnings.filterwarnings("ignore")
import xarray as xr

import torch, yaml, os
import torch.fft
import torch.nn.functional as F
import torchvision.models as models

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    # Deterministic kernels and the cuDNN autotuner pull in opposite directions:
    # with benchmark=True cuDNN times candidate algorithms and keeps the
    # fastest, and that choice can change between runs. Keep it off so the
    # seeded data generation below is reproducible.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

from holodecml.metrics import DistributedROC
from holodecml.transforms import LoadTransformations
from holodecml.propagation import InferencePropagator
from holodecml.data import save_sparse_csr
from torch.utils.data import Dataset

import matplotlib.pyplot as plt
import pandas as pd

from sklearn.model_selection import train_test_split

In [2]:
def seed_everything(seed=42):
    """Seed every random-number source this notebook draws from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    the current CUDA device), exports ``PYTHONHASHSEED``, and puts cuDNN into
    deterministic, non-autotuning mode.

    Args:
        seed: integer seed shared by every generator (default 42).
    """
    random.seed(seed)
    # Has no effect on the running interpreter's str hashing (fixed at
    # startup); it is inherited by child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The import cell turns the cuDNN autotuner on; its per-run algorithm
    # choice defeats deterministic mode, so switch it off here as well.
    torch.backends.cudnn.benchmark = False

In [3]:
class CustomPropagator(InferencePropagator):
    """InferencePropagator variant that emits labelled tiles for training.

    ``get_sub_images_labeled`` reconstructs the requested z-planes, rasterises
    the particles binned into each plane as filled discs, and returns every
    tile listed in ``self.idx2slice`` together with the summed mask in it.
    """

    def get_sub_images_labeled(self,
                               image_tnsr,
                               z_sub_set,
                               z_counter,
                               xp, yp, zp, dp,
                               infocus_mask,
                               z_part_bin_idx,
                               batch_size=32,
                               return_arrays=False,
                               return_metrics=False,
                               thresholds=None,
                               obs_threshold=None):
        """Reconstruct ``z_sub_set`` and cut every plane into labelled tiles.

        Args:
            image_tnsr: hologram tensor passed to ``self.torch_holo_set``.
            z_sub_set: plane positions, scaled by 1e-6 before reconstruction
                (presumably microns -> metres; TODO confirm).
            z_counter: global index of the first plane in ``z_sub_set``; plane
                ``i`` collects the particles with ``z_part_bin_idx == i + z_counter``.
            xp, yp, dp: particle x / y positions and diameters, compared with
                ``self.x_arr`` / ``self.y_arr`` scaled by 1e6.
            zp, infocus_mask: not used by this override.
            z_part_bin_idx: plane-bin index of every particle.
            batch_size: approximate number of tiles per ``collate_masks`` call.
            return_arrays, return_metrics, thresholds, obs_threshold: not used
                by this override; accepted so callers can pass them through.

        Returns:
            dict with
              ``"inputs"``: tile images. Single-channel tiles are joined with
                ``np.vstack`` (tile channel axis merged into the leading axis,
                the layout the sampling loop in this notebook expects);
                multi-channel tiles are joined with ``np.stack`` so
                ``inputs[i]`` stays aligned with ``masks[i]``.
              ``"masks"``: 1-D array, the sum of the particle mask in each tile.

        Raises:
            OSError: if ``self.color_dim`` is neither 1 nor 2.
        """
        with torch.no_grad():

            # build the torch tensor for reconstruction
            z_plane = torch.tensor(
                z_sub_set*1e-6, device=self.device).unsqueeze(-1).unsqueeze(-1)

            # reconstruct the selected planes
            E_out = self.torch_holo_set(image_tnsr, z_plane)

            if self.color_dim == 2:
                stacked_image = torch.cat([
                    torch.abs(E_out).unsqueeze(1), torch.angle(E_out).unsqueeze(1)], 1)
            elif self.color_dim == 1:
                stacked_image = torch.abs(E_out).unsqueeze(1)
            else:
                raise OSError(f"Unrecognized color dimension {self.color_dim}")
            stacked_image = self.apply_transforms(
                stacked_image.squeeze(0)).unsqueeze(0)

            # Group the (tile index, slice) pairs into batches of ~batch_size
            chunked = np.array_split(
                list(self.idx2slice.items()),
                int(np.ceil(len(self.idx2slice) / batch_size))
            )

            inputs, masks = [], []
            for z_idx in range(E_out.shape[0]):

                # Rasterise every particle binned into this plane as a filled disc
                unet_mask = torch.zeros(E_out.shape[1:]).to(self.device)
                part_in_plane_idx = np.where(
                    z_part_bin_idx == z_idx + z_counter)[0]
                for part_idx in part_in_plane_idx:
                    unet_mask += torch.from_numpy(
                        (self.y_arr[None, :]*1e6-yp[part_idx])**2 +
                        (self.x_arr[:, None]*1e6-xp[part_idx]
                         )**2 < (dp[part_idx]/2)**2
                    ).float().to(self.device)

                worker = partial(
                    self.collate_masks,
                    image=stacked_image[z_idx, :].float(),
                    mask=unet_mask
                )

                for chunk in chunked:
                    slices, x, true_mask_tile = worker(chunk)
                    # One device -> host copy per batch instead of one per tile
                    x_host = x.cpu().numpy()
                    mask_host = true_mask_tile.cpu().numpy()
                    for k in range(len(slices)):
                        inputs.append(x_host[k])
                        masks.append(mask_host[k].sum())

            # vstack concatenates along each tile's leading (channel) axis, which
            # only yields one row per tile when there is a single channel; with
            # several channels it would interleave them and misalign the inputs
            # with the masks.
            join = np.vstack if self.color_dim == 1 else np.stack
            return_dict = {"inputs": join(inputs), "masks": np.array(masks)}

        return return_dict

In [4]:
# Use the first CUDA device when one is present, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda else "cpu")

In [15]:
# YAML experiment configuration, relative to the notebook's working directory.
fn_name = "/".join(["..", "config", "model_segmentation.yml"])

In [16]:
# Parse the experiment configuration into ``conf``.
with open(fn_name) as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

In [17]:
seed = conf["seed"]

# Tiling / sampling parameters from the "data" section
data_conf = conf["data"]
n_bins = data_conf["n_bins"]
tile_size = data_conf["tile_size"]
step_size = data_conf["step_size"]
marker_size = data_conf["marker_size"]
total_positive, total_negative, total_examples = (
    int(data_conf[key]) for key in ("total_positive", "total_negative", "total_training")
)

# Do not load the image transformations
transform_mode = "None"
tile_transforms = None
color_dim = conf["model"]["in_channels"]
inference_conf = conf["inference"]
batch_size = inference_conf["batch_size"]
inference_mode = inference_conf["mode"]

# Tag identifying this tiling / sampling configuration in output file names
name_tag = "_".join(
    str(part) for part in
    (tile_size, step_size, total_positive, total_negative, total_examples, transform_mode)
)

# Build the model on the selected device and switch it to inference mode
model = load_model(conf["model"])
model = model.to(device).eval()

In [18]:
# Settings for the raw hologram dataset the tiles are sampled from
raw_conf = conf["style"]["raw"]
data_set = raw_conf["path"]
tiles_per_reconstruction = raw_conf["tiles_per_reconstruction"]
reconstruction_per_hologram = raw_conf["reconstruction_per_hologram"]
save_path = raw_conf["save_path"]
sampler = raw_conf["sampler"]

# Rebuild the tag from its components instead of prepending to ``name_tag``:
# prepending stacks one more sampler prefix every time this cell is re-run.
# Keep the component order in sync with the base tag built in the config cell.
name_tag = f"{sampler}_{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"

In [9]:
# Seed all RNGs with the configured seed before building the propagator
seed_everything(seed=seed)

In [10]:
# Propagator that reconstructs hologram planes and cuts them into labelled tiles
propagator_kwargs = dict(
    n_bins=n_bins,
    color_dim=color_dim,
    tile_size=tile_size,
    step_size=step_size,
    marker_size=marker_size,
    transform_mode=transform_mode,
    device=device,
    model=model,
    mode=inference_mode,
    probability_threshold=0.5,
    transforms=tile_transforms,
)
prop = CustomPropagator(data_set, **propagator_kwargs)

# Pick the z-planes (one plane per call) to sample from every hologram; the
# shuffle draws from the global ``random`` stream seeded by seed_everything.
z_list = prop.create_z_plane_lst(planes_per_call=1)
random.shuffle(z_list)
z_list = z_list[:reconstruction_per_hologram]

In [11]:
# Holograms 10-19 are reserved for the test split. The remaining ids are split
# 80 / 11 / 9 % (train / valid / test) and the reserved ids are then appended
# to test.
h_range = prop.h_ds.hologram_number.values
available = set(h_range)
holdout = set(range(10, 20))
# sorted() gives a well-defined input order for the split
h_range_prime = sorted(available - holdout)

# split into train/test/valid. An explicit random_state keeps the split
# identical when this cell is re-run, rather than depending on how far the
# global NumPy RNG has advanced since seed_everything().
h_train, rest_data = train_test_split(h_range_prime, train_size=0.8, random_state=seed)
h_valid, h_test = train_test_split(rest_data, test_size=0.45, random_state=seed)
# Only append reserved ids that actually exist in the dataset
h_test += sorted(h for h in holdout if h in available)

In [12]:
# Number of holograms in train / valid / test
print(*(len(ids) for ids in (h_train, h_valid, h_test)))

439 60 60


In [12]:
h_splits = [h_train, h_valid, h_test]
split_names = ["train", "valid", "test"]


def _collect_split(h_split):
    """Sample labelled tiles from every hologram id in ``h_split``.

    Every batch yielded by ``prop.get_next_z_planes_labeled`` contributes
    ``tiles_per_reconstruction`` tiles, drawn without replacement with the
    global ``random`` module. Sampling stops once
    ``len(h_split) * tiles_per_reconstruction * reconstruction_per_hologram``
    tiles have been collected.

    Returns:
        (X, Y): tile images and per-tile mask sums, trimmed to the number of
        tiles actually collected. When ``color_dim == 1`` a channel axis is
        inserted at position 1 of ``X``. ``Y`` has shape (n, 1).
    """
    total = len(h_split) * tiles_per_reconstruction * reconstruction_per_hologram
    X = None  # allocated from the first batch so the tile shape is not hard-coded
    Y = np.zeros((total, 1))
    c = 0
    for h_idx in tqdm.tqdm(h_split):
        inference_generator = prop.get_next_z_planes_labeled(
            h_idx,
            z_list,
            batch_size=batch_size,
            thresholds=[0.5],
            return_arrays=False,
            return_metrics=False,
            obs_threshold=0.5,
            # NOTE(review): the z counter starts at n_bins, as in the original
            # run — confirm this offset is intended.
            start_z_counter=n_bins
        )
        for results_dict in inference_generator:
            inputs = results_dict["inputs"]
            if X is None:
                X = np.zeros((total,) + inputs.shape[1:])
            idx = random.sample(range(inputs.shape[0]), k=tiles_per_reconstruction)
            X[c:c + tiles_per_reconstruction] = inputs[idx]
            Y[c:c + tiles_per_reconstruction] = results_dict["masks"][idx].reshape(
                (tiles_per_reconstruction, 1))
            c += tiles_per_reconstruction
            if c >= total:
                break
        if c >= total:
            break

    if X is None:
        # Nothing was generated: fall back to an empty single-channel buffer
        X = np.zeros((0, tile_size, tile_size))
    X, Y = X[:c], Y[:c]
    if color_dim == 1:
        X = np.expand_dims(X, axis=1)
    return X, Y


# Main loop: sample each split and save it as a netCDF file
for split, h_split in zip(split_names, h_splits):
    X, Y = _collect_split(h_split)
    # Shape of what is actually written (trimmed to the collected tiles)
    print(split, X.shape)

    df = xr.Dataset(data_vars=dict(var_x=(['n', 'd', 'x', 'y'], X),
                                   var_y=(['n', 'z'], Y)))

    df.to_netcdf(f"{save_path}/{split}_{name_tag}.nc")

446it [44:31,  5.99s/it]


train (11175, 1, 512, 512)


55it [05:36,  6.12s/it]


valid (1400, 1, 512, 512)


55it [05:34,  6.09s/it]


test (1400, 1, 512, 512)
