In [None]:
import os
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from libs.pconv_model import PConvUnet
from libs.util import ImageChunker, MaskGenerator
import matplotlib.image as mpimg
import cv2
from matplotlib import cm
from mpl_toolkits import mplot3d
import tensorflow as tf
import itertools
try:
    import cPickle as pickle
except ImportError:  # python 3.x
    import pickle


# Select GPU 6 unless the caller already chose devices via the environment
# (setdefault never overrides an existing CUDA_VISIBLE_DEVICES).
# NOTE(review): this presumably must take effect before TensorFlow initialises
# its GPUs; tensorflow is imported above, so if the setting seems ignored,
# move this line above the imports.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")

def plot_pred(_pred, _mask, _orig):
    """Show masked input, model prediction and original image side by side.

    Args:
        _pred: predicted image, passed straight to ``imshow``.
        _mask: multiplicative mask applied to ``_orig`` for the first panel.
        _orig: original image.
    """
    _, axes = plt.subplots(1, 3, figsize=(20, 5))
    panels = [
        (_orig * _mask, 'Masked Image'),
        (_pred, 'Predicted Image'),
        (_orig, 'Original Image'),
    ]
    for ax, (picture, caption) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(caption)


def heatmap2d(x, y, z, _orig=None, fname=None):
    """Draw ``z`` over the ``(x, y)`` grid as a diverging (RdBu) heat map.

    The colour scale is symmetric around zero, ``[-max|z|, max|z|]``, so that
    positive and negative values get equal visual weight.

    Args:
        x, y: coordinate arrays for ``pcolormesh`` (e.g. from ``np.meshgrid``);
            their min/max set the axis limits.
        z: values to colour on that grid.
        _orig: optional image shown in a panel left of the heat map.
        fname: optional output path; when truthy the figure is saved there.
    """
    z_lim = np.abs(z).max()

    if _orig is None:
        fig, ax = plt.subplots()
    else:
        fig, axes = plt.subplots(1, 2, figsize=(20, 8))
        axes[0].imshow(_orig)
        axes[0].set_title('Original Image')
        ax = axes[1]
        ax.set_title('Diff. map')

    mesh = ax.pcolormesh(x, y, z, cmap='RdBu', vmin=-z_lim, vmax=z_lim)
    ax.axis([x.min(), x.max(), y.min(), y.max()])
    fig.colorbar(mesh, ax=ax)

    # Save BEFORE showing: plt.show() can close/flush the figure (e.g. inline
    # backend), after which plt.savefig would write an empty canvas.
    if fname:
        fig.savefig(fname)
    plt.show()


def diff_plot(_pred, _orig, quiet=False):
    """Average the per-pixel difference ``_pred - _orig`` over channels 0-2.

    Args:
        _pred, _orig: arrays indexed ``[row, col, channel]`` with at least
            three channels; only the first three are used.
        quiet: when False, also draw the difference as a 3-D surface.

    Returns:
        ``(X2, Y2, ch_diff)``. ``ch_diff[r, c]`` is the channel mean of the
        difference. ``X2``/``Y2`` hold the column/row index of each entry and
        have the same shape as ``ch_diff`` (also for non-square images), so
        they can be fed to ``plot_surface``/``pcolormesh`` directly.
    """
    pred = np.asarray(_pred)
    orig = np.asarray(_orig)
    # Subtract in floating point: for integer (e.g. uint8) images the
    # subtraction would otherwise wrap around (10 - 20 -> 246). Float inputs
    # keep their usual promotion (float32/float64 -> float64 etc.).
    work_dtype = np.result_type(pred.dtype, orig.dtype, np.float32)
    diff = pred[:, :, :3].astype(work_dtype) - orig[:, :, :3].astype(work_dtype)
    ch_diff = (diff[:, :, 0] + diff[:, :, 1] + diff[:, :, 2]) / 3

    # meshgrid's default 'xy' indexing returns arrays shaped (len(y), len(x)),
    # so pass (columns, rows) to match ch_diff's (rows, columns) layout.
    n_rows, n_cols = ch_diff.shape
    X2, Y2 = np.meshgrid(range(n_cols), range(n_rows))

    if not quiet:
        # Colour each face by its Z value; only needed when actually plotting.
        norm = plt.Normalize(ch_diff.min(), ch_diff.max())
        colors = cm.jet(norm(ch_diff))
        ax = plt.axes(projection='3d')
        surf = ax.plot_surface(X2, Y2, ch_diff, facecolors=colors, shade=False)
        surf.set_facecolor((0, 0, 0, 0))
    return X2, Y2, ch_diff

# Return a copy of the masked original in which the masked-out (near-zero) pixels are filled from the prediction
def apply_mask(_mask, _pred, _orig, threshold=0.01):
    """Composite a prediction into the holes of a masked image.

    The original is multiplied by ``_mask``. Every pixel whose values are ALL
    below ``threshold`` in that product counts as a hole and is replaced by
    the corresponding pixel of ``_pred``. Pixels that are genuinely
    near-black in the original are therefore filled from ``_pred`` as well.

    Args:
        _mask: multiplicative mask, broadcastable against ``_orig`` (0 = hole).
        _pred: model output with the same ``(rows, cols[, ...])`` layout as ``_orig``.
        _orig: original image (2-D or with trailing channel axes).
        threshold: value below which a masked pixel is considered a hole.

    Returns:
        A new array; ``_orig``, ``_mask`` and ``_pred`` are not modified.
    """
    masked = _orig * _mask
    # A pixel is a hole only if every channel is dark. Reduce over all axes
    # after (row, col); for a 2-D image axis=() leaves the test per-pixel.
    # (Calling .all() first and comparing the resulting bool with 0.01 would
    # instead flag any pixel that has a single zero channel.)
    channel_axes = tuple(range(2, masked.ndim))
    hole = (masked < threshold).all(axis=channel_axes)
    masked[hole] = np.asarray(_pred)[hole]
    return masked

def x_channel(in_rgb, channel='r'):
    """Return a copy of ``in_rgb`` with all colour planes but one zeroed.

    Planes are taken along the third axis as 0='r', 1='g', 2='b'. Any
    ``channel`` other than 'r' or 'g' is treated as 'b'. Planes beyond index 2
    are left untouched, and the input itself is never modified.
    """
    result = deepcopy(in_rgb)
    keep = 0 if channel == 'r' else (1 if channel == 'g' else 2)
    for plane in (0, 1, 2):
        if plane != keep:
            result[:, :, plane] = 0
    return result

def x_channel_ex(in_rgb, channel='r'):
    """Extract one colour plane of ``in_rgb`` as an independent 2-D array.

    Planes are taken along the third axis as 0='r', 1='g', 2='b'. Any
    ``channel`` other than 'r' or 'g' is treated as 'b'. The result does not
    share memory with ``in_rgb``.
    """
    plane = 0 if channel == 'r' else (1 if channel == 'g' else 2)
    # Copy only the requested plane. Deep-copying the whole image and then
    # slicing returned a view that kept the full copy alive in memory.
    return deepcopy(in_rgb[:, :, plane])
    
def RegularMaskGenerator(_width, _height, _xmin, _ymin, _sizex, _sizey):
    """Build a 3-channel uint8 mask: 1 everywhere except a rectangular hole of 0.

    Axis convention: the mask has shape ``(_width, _height, 3)``, so ``x`` /
    ``_width`` index the FIRST axis and ``y`` / ``_height`` the second.

    Returns:
        ``(mask, xset, yset)``. ``mask`` is 0 on every pixel with
        ``x in xset and y in yset`` (the rectangle clipped to the array) and 1
        elsewhere. ``xset = range(_xmin, _xmin + _sizex)`` and
        ``yset = range(_ymin, _ymin + _sizey)`` are returned unclipped.
    """
    _xset = range(_xmin, _xmin + _sizex)
    _yset = range(_ymin, _ymin + _sizey)

    _i_mask = np.ones((_width, _height, 3), dtype=np.uint8)
    # Clamp both ends at 0 so negative coordinates clip to the array edge
    # (matching membership in _xset/_yset) rather than wrapping around as
    # negative slice indices would. Stops past the array end are fine.
    x0, x1 = max(_xmin, 0), max(_xmin + _sizex, 0)
    y0, y1 = max(_ymin, 0), max(_ymin + _sizey, 0)
    _i_mask[x0:x1, y0:y1] = 0

    return _i_mask, _xset, _yset

def plotScans(_res, _image):
    """Plot one result entry as a 3x3 grid: scan | image channel | mask.

    ``_res[_image]`` must provide 'scans' (indexable by 0..2, each passed to
    ``imshow``), 'image' and 'mask' (both indexed ``[:, :, c]``). Row ``i``
    shows ``scans[i]``, image channel ``(2, 0, 1)[i]`` and mask channel 0.
    Only the top-left scan gets a colour bar. Every panel uses
    ``origin='lower'``.
    """
    record = _res[_image]
    scans = record['scans']
    picture = record['image']
    mask = record["mask"]

    fig, axes = plt.subplots(3, 3, figsize=(20, 20))
    for row, image_chan in enumerate((2, 0, 1)):
        scan_img = axes[row, 0].imshow(scans[row], origin='lower')
        axes[row, 1].imshow(picture[:, :, image_chan], origin='lower')
        axes[row, 2].imshow(mask[:, :, 0], origin='lower')
        if row == 0:
            fig.colorbar(scan_img, ax=axes[0, 0])

In [None]:
SAMPLE_IMAGE = 'images/PGBM-017_08-21-1997-MR_RCBV_SEQUENCE-73885_7_False.png'

# cv2.imread reports a missing/unreadable file by returning None instead of
# raising; fail here with a clear message rather than inside cvtColor.
ori_bgr = cv2.imread(SAMPLE_IMAGE, cv2.IMREAD_UNCHANGED)
if ori_bgr is None:
    raise FileNotFoundError("Could not read image: {}".format(SAMPLE_IMAGE))
# OpenCV returns channels in BGR order; convert to RGB for plotting.
ori = cv2.cvtColor(ori_bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Weight files for the inpainting model.
# NOTE(review): VGG_WEIGHTS_PATH is an absolute, machine-specific path;
# consider making it configurable.
VGG_WEIGHTS_PATH = "/DATA/sb/pytorch_to_keras_vgg16.h5"
CHECKPOINT_PATH = "data/logs/brain_tp_dt1_phase3/weights.291-0.22.h5"

model = PConvUnet(vgg_weights=VGG_WEIGHTS_PATH, inference_only=True)
model.load(CHECKPOINT_PATH, train_bn=False)

In [None]:
# Seeded generator so the mask is presumably reproducible across runs
# (seeding behaviour lives in libs.util.MaskGenerator — confirm).
MASK_SEED = 42
mask_gen = MaskGenerator(512, 512, rand_seed=MASK_SEED)
# NOTE(review): _generate_mask is a private helper; check whether
# MaskGenerator exposes a public sampling method instead.
mask_1 = mask_gen._generate_mask()

In [None]:
# NOTE(review): dimension_preprocess/dimension_postprocess presumably tile the
# image into model-sized chunks and stitch the prediction back; see libs.util.
chunker = ImageChunker(512, 512, 0)

# Scale to [0, 1] only while `ori` is still the integer array returned by
# cv2.imread, so re-running this cell does not divide by 255 a second time.
# Assumes 8-bit intensities.
if np.issubdtype(ori.dtype, np.integer):
    ori = ori / 255.

prepro_img = chunker.dimension_preprocess(deepcopy(ori))
prepro_mask = chunker.dimension_preprocess(deepcopy(mask_1))
pred = model.predict([prepro_img, prepro_mask])
reconstructed_image = chunker.dimension_postprocess(pred, ori)

In [None]:
# 3x3 comparison grid: rows are model output / original / masked input,
# columns are the three image channels (FLAIR, ADC, dT1).
fig, axes = plt.subplots(3, 3, figsize=(20, 20))
channel_names = ("FLAIR", "ADC", "dT1")
row_sources = (
    ("Model", reconstructed_image),
    ("Original", ori),
    ("Masked", ori * mask_1),
)
for row_idx, (row_label, row_img) in enumerate(row_sources):
    for col_idx, chan_name in enumerate(channel_names):
        panel = axes[row_idx, col_idx]
        panel.imshow(row_img[:, :, col_idx], origin='lower')
        panel.set_title("{} ({})".format(row_label, chan_name))
plt.savefig("comparison.png")