In [2]:
import pathlib
import os

from addict import Dict
from scipy.ndimage.morphology import binary_fill_holes
from skimage import color
import scipy.ndimage
import numpy as np
import seaborn_image as isns
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import yaml

import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import Image

from segmentation.model.metrics import *
from segmentation.model.frame import Framework
import segmentation.model.functions as fn
import segmentation.data.slice as sl

isns.set_context("notebook")

def min_max(im):
    """Linearly rescale an array to the [0, 1] range.

    Args:
        im: numpy array of any shape.

    Returns:
        A float array of the same shape where the minimum value maps to 0 and
        the maximum value maps to 1. An empty or constant (zero-range) array
        maps to all zeros instead of raising / dividing by zero.
    """
    if im.size == 0:
        return np.zeros(im.shape)
    lo = im.min()
    value_range = im.max() - lo
    if value_range == 0:
        return np.zeros(im.shape)
    # Divide by the range (max - min), not by max alone: dividing by max only
    # reaches 1 when min == 0 and exceeds 1 when min < 0, which overflows a
    # subsequent `* 255` -> uint8 conversion.
    return (im - lo) / value_range

def mean_std(im):
    """Standardize an array to zero mean and unit standard deviation.

    Args:
        im: numpy array of any shape.

    Returns:
        `(im - im.mean()) / im.std()` as a float array of the same shape.
        A constant array (std == 0) maps to all zeros instead of NaN/inf.
    """
    std = im.std()
    if std == 0:
        # Avoid 0/0 -> NaN for constant inputs.
        return np.zeros(im.shape)
    return (im - im.mean()) / std

def get_tp_fp_fn(pred, true):
    """Count true positives, false positives and false negatives for numpy inputs.

    Both arrays are wrapped as torch tensors with `torch.from_numpy` (which
    shares memory with the arrays rather than copying) and handed to
    `tp_fp_fn`, presumably provided by the `segmentation.model.metrics`
    star import.

    Returns:
        The (tp, fp, fn) triple produced by `tp_fp_fn`.
    """
    pred_tensor = torch.from_numpy(pred)
    true_tensor = torch.from_numpy(true)
    true_pos, false_pos, false_neg = tp_fp_fn(pred_tensor, true_tensor)
    return true_pos, false_pos, false_neg

def get_precision_recall_iou(tp, fp, fn):
    """Return the (precision, recall, IoU) triple for the given confusion counts.

    Each value comes from the metric function of the same name (`precision`,
    `recall`, `IoU`), presumably provided by the `segmentation.model.metrics`
    star import; all three receive the same (tp, fp, fn) arguments.
    Note: inside this function the parameter `fn` shadows the module alias
    `fn` (segmentation.model.functions).
    """
    return tuple(metric(tp, fp, fn) for metric in (precision, recall, IoU))

#%% Data Preparation
# Load the prediction config; the context manager closes the file handle.
with open('./conf/unet_predict.yaml') as config_file:
    conf = Dict(yaml.safe_load(config_file))
# Model input/output widths follow from the selected channels and class list.
conf.model_opts_cleanice.args.inchannels = len(conf.use_channels_cleanice)
conf.model_opts_cleanice.args.outchannels = len(conf.class_names)
# Channel 10 is treated as the physics channel elsewhere in this notebook.
use_physics = 10 in conf.use_channels_cleanice

data_dir = pathlib.Path(conf.data_dir)
preds_dir = pathlib.Path(conf.out_processed_dir) / "preds" / conf.run_name
model_path = pathlib.Path(conf.folder_name) / conf.run_name / 'models' / 'model_best.pt'
loss_fn = fn.get_loss(conf.model_opts_cleanice.args.outchannels)
frame = Framework(
    loss_fn=loss_fn,
    model_opts=conf.model_opts_cleanice,
    optimizer_opts=conf.optim_opts,
    device=int(conf.gpu_rank),
)
# Remap checkpoint tensors onto the CPU when no GPU is available
# (map_location=None is torch.load's default behavior).
map_location = None if torch.cuda.is_available() else "cpu"
state_dict = torch.load(model_path, map_location=map_location)
frame.load_state_dict(state_dict)

# Training-set normalization statistics: rows 0/1 = mean/std, rows 2/3 = min/max.
arr = np.load(data_dir / "normalize_train.npy")
if conf.normalize == "mean-std":
    _mean, _std = arr[0], arr[1]
if conf.normalize == "min-max":
    _min, _max = arr[2], arr[3]

# os.listdir order is arbitrary; sort so the file dropdown is deterministic.
files = sorted(os.listdir(data_dir / "test"))
inputs = [name for name in files if "tiff" in name]

# Visualizing Data
Use this widget to visualize the test-set images and labels, the model predictions, and exactly where those predictions differ from the labeled ground truth.

For the disagreements:
* If the color is **RED** that means the label says those pixels are the given class but the model predicted otherwise.
* If the color is **BLUE** that means the model says those pixels are the given class but the labeled ground truth says otherwise.
* If you enable the input "visualize_agreement" then **GREEN** will be used to represent the pixels where both the model and label agree.

Here is an explanation of all the inputs:

> x_fname
* Filename in the format "tiff_X_slice_Y.npy" where X=Cell Number (0 to 201) and Y=Slice Number

> Channel 1, 2, and 3
* The channels used to create the false color image; Channel 1, 2, and 3 are displayed as red, green, and blue respectively.

> color_map
* The map used to determine how to color the grayscale images https://matplotlib.org/stable/gallery/color/colormap_reference.html

> figure_size
* Scale factor for the size of the plotted figures.

In [100]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, interactive
from IPython.display import Image


# Names of the input channels in channel-index order (entry i labels channel i).
labels = ["B1", "B2", "B3", "B4", "B5", "B6_VCID1", "B6_VCID2", "B7", "elevation", "slope"]
# Dropdown options as (display name, channel index) pairs.
labels_with_idx = list(zip(labels, range(len(labels))))

# One dropdown per color slot of the false-color image; defaults pick channels 2, 1, 0.
ch1, ch2, ch3 = (
    widgets.Dropdown(description=description, options=labels_with_idx, value=default)
    for description, default in (('Channel 1', 2), ('Channel 2', 1), ('Channel 3', 0))
)

def _normalize_input(x):
    """Return a normalized float64 copy of an (H, W, C) input tile.

    Applies the training-set normalization selected by `conf.normalize`:
    "mean-std" standardizes every channel (all but the last one when the
    physics channel is used), "min-max" rescales every channel with the
    stored min/max. Any other setting returns an unnormalized copy.
    """
    # Work on a float copy so the in-place channel update below cannot
    # truncate (in case the raw tiles are stored as integers).
    x = x.astype(np.float64)
    if conf.normalize == "mean-std":
        if use_physics:
            # The physics channel is assumed to be last and is left raw.
            x[:, :, :-1] = (x[:, :, :-1] - _mean[:-1]) / _std[:-1]
        else:
            x = (x - _mean) / _std
    elif conf.normalize == "min-max":
        x = (x - _min) / (_max - _min)
    return x


def _predict_classes(x, mask):
    """Run the model on a normalized (H, W, C) tile and return a uint8 class map.

    Class codes match the label plots: 0=Mask, 1=BG, 2=CI, 3=Debris.
    """
    model_input = torch.from_numpy(np.expand_dims(x[:, :, conf.use_channels_cleanice], axis=0)).float()
    logits = frame.infer(model_input)
    # Softmax over dim 3: the model output is treated as channels-last
    # (batch, H, W, classes).
    probs = torch.softmax(logits, dim=3)
    probs = np.squeeze(probs.detach().cpu().numpy())

    # Every pixel defaults to background (1), so the background probability
    # map never changes the result and is not thresholded.
    classes = np.ones(probs.shape[:2], dtype=np.uint8)
    # Threshold and fill holes per class; debris is written last and wins
    # where it overlaps clean ice.
    for class_code, channel in ((2, 1), (3, 2)):
        class_pixels = binary_fill_holes(probs[:, :, channel] >= conf.threshold[channel])
        classes[class_pixels] = class_code
    classes[mask] = 0
    return classes


def _disagreement_overlay(y_true, y_pred, class_label, show_agreement):
    """Build an RGBA overlay comparing label and prediction for one class.

    Red   = label is `class_label` but the prediction is not.
    Blue  = prediction is `class_label` but the label is not.
    Green = both are `class_label` (only when `show_agreement` is True).
    Colored pixels get alpha 0.5; all other pixels are fully transparent.
    """
    in_true = y_true == class_label
    in_pred = y_pred == class_label
    rgb = np.zeros(y_true.shape + (3,), dtype=np.float32)
    alpha = np.zeros(y_true.shape, dtype=np.float32)

    missed = in_true & ~in_pred
    extra = in_pred & ~in_true
    rgb[missed] = [1, 0, 0]
    rgb[extra] = [0, 0, 1]
    alpha[missed | extra] = 0.5
    if show_agreement:
        agree = in_true & in_pred
        rgb[agree] = [0, 1, 0]
        alpha[agree] = 0.5
    return np.dstack((rgb, alpha))


@interact_manual
def show_data(x_fname = inputs, ch1=ch1, ch2=ch2, ch3=ch3, color_map=plt.colormaps(), figure_size=(1, 25), visualize_agreement=False):
    """Plot one test tile: inputs, labels, model prediction and per-class disagreements.

    Args:
        x_fname: input file name inside `data_dir / "test"`; its label file
            is found by replacing "tiff" with "mask" in the name.
        ch1, ch2, ch3: channel indices shown as red, green and blue in the
            false-color image.
        color_map: matplotlib colormap name for the single-channel plots.
        figure_size: scale factor for the figure sizes.
        visualize_agreement: also shade pixels where label and prediction
            agree (green).
    """
    # Load the raw input tile (x) and build the false-color image from it.
    x_raw = np.load(data_dir / "test" / x_fname)
    im = (min_max(x_raw[:, :, [ch1, ch2, ch3]]) * 255).astype(np.uint8)

    # Pixels whose channels sum to zero are treated as no-data (label 0).
    # Detect them on the RAW tile: normalization shifts zero pixels away from
    # zero (e.g. (0 - mean) / std), so checking afterwards almost never
    # finds them.
    mask = np.sum(x_raw, axis=2) == 0
    has_mask = bool(mask.any())
    x = _normalize_input(x_raw)

    # Load label (y), shifted by +1 so 0 is free for the mask.
    y_fname = x_fname.replace("tiff", "mask")
    y_true = (np.load(data_dir / "test" / y_fname) + 1).astype(np.uint8)
    y_true[mask] = 0
    ticks = [0, 1, 2, 3] if has_mask else [1, 2, 3]

    # Overview: false color, labels, physics channel, mask.
    fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(2*figure_size, 2*figure_size))
    isns.imgplot(ax=ax[0], data=im)
    ax[0].set_title('False Color Image')
    isns.imgplot(ax=ax[1], data=y_true, cmap=color_map, cbar_ticks=ticks)
    ax[1].set_title('Labels (0=Mask, 1=BG, 2=CI, 3=Debris)')
    # Channel 10 is plotted as the physics channel (assumes >= 11 channels).
    isns.imgplot(ax=ax[2], data=min_max(x[:, :, 10]), cmap=color_map)
    ax[2].set_title('Physics Channel')
    isns.imgplot(ax=ax[3], data=mask, cmap=color_map)
    ax[3].set_title('Mask')

    # One panel per named band (first 10 channels).
    isns.ImageGrid(x, cmap=color_map, stop=10, col_wrap=5, height=figure_size, cbar_label=labels, despine=True)

    #%% ********** PREDICTION TIME! **********
    y_pred = _predict_classes(x, mask)

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(3*figure_size, 2*figure_size))
    isns.imgplot(ax=ax[1, 0], data=y_true, cmap=color_map, cbar_ticks=ticks)
    ax[1, 0].set_title('Labels (0=Mask, 1=BG, 2=CI, 3=Debris)')
    isns.imgplot(ax=ax[1, 1], data=im)
    ax[1, 1].set_title('False Color Image')
    isns.imgplot(ax=ax[1, 2], data=y_pred, cmap=color_map, cbar_ticks=ticks)
    ax[1, 2].set_title('Prediction (0=Mask, 1=BG, 2=CI, 3=Debris)')

    # Top row: false-color image with a per-class disagreement overlay.
    for idx, (class_name, class_label) in enumerate([('BG', 1), ('CleanIce', 2), ('Debris', 3)]):
        overlay = _disagreement_overlay(y_true, y_pred, class_label, visualize_agreement)
        isns.imgplot(ax=ax[0, idx], data=im)
        isns.imgplot(ax=ax[0, idx], data=overlay)
        title = (f'Red pixels -> TrueLabel={class_name} but model disagrees\n'
                 f'Blue pixels -> Prediction={class_name} but label disagrees')
        if visualize_agreement:
            title += '\nGreen pixels -> Agreement'
        ax[0, idx].set_title(title)

interactive(children=(Dropdown(description='x_fname', options=('tiff_7_slice_11.npy', 'tiff_14_slice_0.npy', '…