In [185]:
from hagelslag.evaluation.MetricPlotter import roc_curve, performance_diagram
from holodecml.models import load_model
from functools import partial
from argparse import ArgumentParser
import torch.multiprocessing as mp
import matplotlib.pyplot as plt
import numpy as np
import subprocess
import traceback
import logging
import signal
import joblib
import scipy
import sys
import yaml
import time
import tqdm
import glob
import gc
import os
import torch
import warnings
warnings.filterwarnings("ignore")
import pandas as pd

import torch
import torch.fft
import torch.nn.functional as F
import torchvision.models as models

if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

from holodecml.metrics import DistributedROC
from holodecml.transforms import LoadTransformations
from holodecml.propagation import InferencePropagator
from holodecml.data import save_sparse_csr, load_sparse_csr 

import matplotlib
import matplotlib.pyplot as plt
from scipy.signal import convolve2d

In [2]:
def pad(a, padding = 512, fx = False, fy = False):
    """Return bottom right padding."""
    zeros = np.zeros((padding, padding))
    if not fx and not fy:
        zeros[:a.shape[0], :a.shape[1]] = a
    if fx and not fy:
        zeros[-a.shape[0]:, :a.shape[1]] = a
    if fy and not fx:
        zeros[:a.shape[0], -a.shape[1]:] = a
    else:
        zeros[-a.shape[0]:, -a.shape[1]:] = a
    return zeros

In [110]:
def propagate(prop, image, zp):
    real_zp_im = prop.torch_holo_set(
        torch.from_numpy(image).to(device),
        torch.FloatTensor([zp*1e-6]).to(device)
    ).squeeze(0)
    prop_real = torch.abs(real_zp_im).cpu().numpy()
    return prop_real

def get_particle(prop, h_idx, max_coor = 0, repeat = 0, predicted_particles = None):
    indices = predicted_particles["h"] == h_idx
    predicted_particles = predicted_particles[indices].copy().sort_values("z_p")
    predicted_particles["x_t"] = predicted_particles["x_t"] * (2 * 7209 / 4872) - 7209
    predicted_particles["y_t"] = predicted_particles["y_t"] * (2 * 4806 / 3248) - 4806
    predicted_particles["x_p"] = predicted_particles["x_p"] * (2 * 7209 / 4872) - 7209
    predicted_particles["y_p"] = predicted_particles["y_p"] * (2 * 4806 / 3248) - 4806
    predicted_particles["d_t"] = 2.96 * predicted_particles["d_t"] 
    predicted_particles["d_p"] = 2.96 * predicted_particles["d_p"] 
    
    particle_idx = np.where(prop.h_ds['hid'].values == h_idx+1)
    c = np.isfinite(predicted_particles["x_p"]) & ~np.isfinite(predicted_particles["x_t"])
    
    print("Total examples that need labeled", c.sum())

    x_locations = predicted_particles[c]["x_p"]
    y_locations = predicted_particles[c]["y_p"]
    z_locations = predicted_particles[c]["z_p"]
    d_locations = predicted_particles[c]["d_p"]
        
#     indices = np.where(prop.h_ds["hid"] == h_idx + 1)
#     d_locations = prop.h_ds["d"].values[indices]
#     x_locations = prop.h_ds["x"].values[indices]
#     y_locations = prop.h_ds["y"].values[indices]
#     z_locations = prop.h_ds["z"].values[indices]
    xp = np.digitize(x_locations, 1e6 * prop.x_arr, right=True)
    yp = np.digitize(y_locations, 1e6 * prop.y_arr, right=True)
    zp = z_locations
    dp = d_locations
    
    for (x, y, z, d) in zip(xp, yp, zp, dp):
        yield x, y, z, d

In [85]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")

if is_cuda:
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

In [86]:
path = "../../results/style_transfer/"
config = f'{path}/best.yml'

with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [None]:
table = pd.read_csv(f"{path}/real_1020/prediction_table_0.001.csv")

In [87]:
n_nodes = conf["inference"]["n_nodes"]
n_gpus = conf["inference"]["gpus_per_node"]
threads_per_gpu = conf["inference"]["threads_per_gpu"]
workers = int(n_nodes * n_gpus * threads_per_gpu)
this_worker = 0

save_arrays = conf["inference"]["save_arrays"]

n_bins = conf["data"]["n_bins"]
tile_size = conf["data"]["tile_size"]
step_size = conf["data"]["step_size"]
marker_size = conf["data"]["marker_size"]
transform_mode = "None" if "transform_mode" not in conf[
    "data"] else conf["data"]["transform_mode"]

model_loc = conf["save_loc"]
model_name = conf["model"]["name"]
color_dim = conf["model"]["in_channels"]

batch_size = conf["inference"]["batch_size"]
save_arrays = conf["inference"]["save_arrays"]
save_prob = conf["inference"]["save_probs"]
inference_mode = conf["inference"]["mode"]

if "probability_threshold" in conf["inference"]:
    probability_threshold = conf["inference"]["probability_threshold"]
else:
    probability_threshold = 0.5

verbose = conf["inference"]["verbose"]
data_set = conf["inference"]["data_set"]["path"]
data_set_name = conf["inference"]["data_set"]["name"]

prop_data_loc = os.path.join(model_loc, f"{data_set_name}/propagated")
roc_data_loc = os.path.join(model_loc, f"{data_set_name}/roc")
image_data_loc = os.path.join(model_loc, f"{data_set_name}/images")

for directory in [prop_data_loc, roc_data_loc, image_data_loc]:
    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)

In [6]:
h_conf = conf["inference"]["data_set"]["holograms"]
if isinstance(h_conf, dict):
    h_min = conf["inference"]["data_set"]["holograms"]["min"]
    h_max = conf["inference"]["data_set"]["holograms"]["max"]
    h_range = range(h_min, h_max)
elif isinstance(h_conf, list):
    h_range = h_conf
elif isinstance(h_conf, int) or isinstance(h_conf, float):
    h_range = [h_conf]
else:
    raise OSError(f"Unidentified h-range settings {h_conf}")

# Load the model
print(
    f"Worker {this_worker}: Loading and moving model to device {device}")
model = load_model(conf["model"]).to(device)

# Load the weights from the training location
checkpoint = torch.load(
    os.path.join(model_loc, "best.pt"),
    map_location=lambda storage, loc: storage
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
print(
    f"Worker {this_worker}: There are {total_params} total model parameters")

# Load the image transformations
if "inference" in conf["transforms"]:
    print(f"Worker {this_worker}: Loading image transformations")
    if "Normalize" in conf["transforms"]["training"]:
        conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    tile_transforms = LoadTransformations(conf["transforms"]["inference"])
else:
    tile_transforms = None

Worker 0: Loading and moving model to device cuda:0
Worker 0: There are 10483137 total model parameters
Worker 0: Loading image transformations


In [7]:
print(f"Worker {this_worker}: Loading an inference wave-prop generator")

prop = InferencePropagator(
    data_set,
    n_bins=n_bins,
    color_dim=color_dim,
    tile_size=tile_size,
    step_size=step_size,
    marker_size=marker_size,
    transform_mode=transform_mode,
    device=device,
    model=model,
    mode=inference_mode,
    probability_threshold=probability_threshold,
    transforms=tile_transforms
)

# Create a list of z-values to propagate to
z_list = prop.create_z_plane_lst(planes_per_call=1)
z_list = np.array_split(z_list, workers)[this_worker]

Worker 0: Loading an inference wave-prop generator


In [196]:
for h_idx in range(11, 20):
    k = 0 
    prop.model.eval()
    real_image = prop.h_ds["image"][h_idx].values
    for (x, y, z, d) in get_particle(prop, h_idx, predicted_particles = table):

        prop_image = propagate(prop, real_image, z)

        dz = prop.z_centers[1] - prop.z_centers[0]
        prop_image_m1 = propagate(prop, real_image, z - dz)
        prop_image_p1 = propagate(prop, real_image, z + dz)

        print(h_idx, x, y, z, d)

        x, y, z = int(x), int(y), int(z)

        ### 512 x 512 image
        fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(nrows=1, ncols=5, sharey='col', figsize=(12,5))

        _x = x if x < 256 else 256
        _y = y if y < 256 else 256

        image = pad(prop_image[(x-_x):(x+_x), (y-_y):(y+_y)].transpose((1,0)), 
            padding = 512,
            fx = _x < 256,
            fy = _y < 256)
        ax1.imshow(image, origin ='lower', cmap = "gray")
        ax1.set_title("Input image")

        X = torch.from_numpy(prop.apply_transforms(prop_image[(x-_x):(x+_x), (y-_y):(y+_y)])).unsqueeze(0).unsqueeze(0)
        with torch.no_grad():
            _, _, a, b = X.shape
            if a != b:
                _X = torch.zeros(1, 1, 512, 512)
                _X[:, :, :a, :b] = X
                X = _X
            mask = (prop.model(X.to(device)).detach().cpu().squeeze(0).squeeze(0).numpy() > 0.5).astype(float)

        ### 32 x 32 image 
        _x = x if x < 32 else 32
        _y = y if y < 32 else 32
        image = pad(prop_image_m1[(x-_x):(x+_x), (y-_y):(y+_y)].transpose((1,0)), 
            padding = 64,
            fx = _x < 32,
            fy = _y < 32)

        ax2.imshow(image, origin ='lower', cmap = "gray")
        ax2.set_title("z(i-1)")

        ### 32 x 32 image 
        _x = x if x < 32 else 32
        _y = y if y < 32 else 32
        image = pad(prop_image[(x-_x):(x+_x), (y-_y):(y+_y)].transpose((1,0)), 
            padding = 64,
            fx = _x < 32,
            fy = _y < 32)

        ax3.imshow(image, origin ='lower', cmap = "gray")
        ax3.set_title("z(i)")

        ### 32 x 32 image 
        _x = x if x < 32 else 32
        _y = y if y < 32 else 32
        image = pad(prop_image_p1[(x-_x):(x+_x), (y-_y):(y+_y)].transpose((1,0)), 
            padding = 64,
            fx = _x < 32,
            fy = _y < 32)

        ax4.imshow(image, origin ='lower', cmap = "gray")
        ax4.set_title("z(i+1)")

        ### Mask prediction
        image = mask.transpose((1,0))
    #     image = pad(image, 
    #         padding = 64,
    #         fx = _x < 32,
    #         fy = _y < 32)
        ax5.imshow(image, origin ='lower', cmap = "gray")
        ax5.set_xlim(256-32, 256+32)
        ax5.set_ylim(256-32, 256+32)
        ax5.set_title("Predicted mask")

        plt.tight_layout()
        #plt.show()

        k += 1
        
        name = f"{h_idx} {int(x)} {int(y)} {int(z)} {np.round(d, 2)}"
        fig.savefig(f'{path}/labeling/{name}.png', dpi = 300, bbox_inches = "tight")
        
        fig.clear()
        del fig
        gc.collect()

Total examples that need labeled 37
11 2567 2467 15368.0 23.68
11 3604 916 16376.0 2.96
11 1615 2775 18104.0 13.32
11 4197 1262 18680.0 5.328
11 604 1526 18721.14285714286 25.37142857142857
11 508 1751 31496.0 40.45333333333333
11 4246 2398 35792.0 15.786666666666665
11 1905 2546 38148.8 26.048000000000002
11 4175 2214 38840.0 38.48
11 217 3112 41576.0 8.879999999999999
11 1808 862 46472.0 8.879999999999999
11 4579 2008 46640.0 14.306666666666665
11 2429 1995 52520.0 29.6
11 1702 1994 60224.0 14.8
11 1787 2921 61941.71428571428 25.79428571428571
11 777 1074 63368.0 24.666666666666668
11 1604 2280 67442.0 29.6
11 1325 1997 69704.0 23.68
11 1324 1996 72536.0 41.44
11 3040 645 72536.0 17.759999999999998
11 2823 1110 72656.0 17.759999999999998
11 1861 7 79412.0 6.66
11 1965 1987 80240.0 21.46
11 1272 2087 89816.0 17.759999999999998
11 1273 2086 90420.8 24.864
11 159 900 90680.0 2.96
11 2793 2822 93905.6 18.944
11 2313 2683 101019.2 24.272
11 691 903 102104.0 23.68
11 3816 3007 103424.0 10.

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>

<Figure size 864x360 with 0 Axes>