In [None]:
import os
import sys
import random
import warnings
import glob
import datetime

import numpy as np
import pandas as pd
import pickle
import shutil
import socket
import GPUtil as GPU
import imageio
import tensorflow as tf
import scipy.misc
import matplotlib.pyplot as plt
import cv2
import elasticdeform

from PIL import Image
from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from PIL import ImageFile
from sklearn.model_selection import train_test_split
from pathlib import Path
from skimage import exposure


from tensorflow import ones_like, equal
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dropout, Lambda
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K
from skimage import color, io, img_as_float

In [None]:
# Show the installed TensorFlow version in the cell output (useful when reproducing results).
tf.__version__

In [None]:
# Name of this notebook; not read by any cell below.
noteBookName = "u_net_uterus_seg.ipynb"

In [None]:
# Set some parameters

DATA_DIR = Path('../data/michael/clean')

IMG_WIDTH = 240  # 180 # 240
IMG_HEIGHT = 240 # 180 # 240
IMG_CHANNELS = 1
NUM_CLASSES = 2

# NOTE(review): NUM_FOLDS and VAL_SPLIT are not read by any cell in this notebook;
# the train/validation split below is a hard-coded 80/20 split over patients.
NUM_FOLDS = 10
VAL_SPLIT = 0.1
BATCH_SIZE = 8

warnings.filterwarnings('ignore', category=UserWarning, module='skimage')

# Seed every RNG the notebook draws from: Python's `random` (frame sampling, augmentation
# parameters, patient shuffle), NumPy (presumably used by elasticdeform's random grid -- verify)
# and TensorFlow. These must be *calls*: assigning `random.seed = seed` would replace the
# function with an int and seed nothing.
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
# get first available gpu
# On the local workstation (hostname contains "fabian") take the least-loaded GPU; on shared
# servers only accept a GPU that is essentially idle (GPUtil thresholds are fractions:
# load <= 1e-6, memory used <= 10 %). GPUtil raises if no GPU satisfies the thresholds.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if TensorFlow has not initialised its
# GPUs yet -- presumably true here since no earlier cell places ops on a device; verify.
on_local_host = "fabian" in str(socket.gethostname())
if on_local_host:
    gpu_candidates = GPU.getFirstAvailable(order="load")
    location = "local"
else:
    gpu_candidates = GPU.getFirstAvailable(order="load", maxLoad=10**-6, maxMemory=10**-1)
    location = "server"

gpu_str = str(gpu_candidates[0])
print(location + " gpu: " + gpu_str)
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

In [None]:
def get_video_frame(filename, frame_idx=None):
    """Return a single frame of a video.

    Parameters
    ----------
    filename :
        Passed unchanged to ``get_echo_video``.
    frame_idx : int or None
        Frame to return. ``None`` picks a uniformly random frame. Negative values count
        from the end (Python-style); indices past either end are clamped to the
        first / last frame.

    Returns
    -------
    ``video[frame_idx, ...]`` -- all axes after the frame axis are returned as-is.

    Raises
    ------
    ValueError
        If the video contains no frames.

    NOTE(review): ``get_echo_video`` is not defined in this notebook and this function is not
    called here -- define/import it before use (``get_video_as_array`` may be the intended loader).
    """
    video = get_echo_video(filename)
    n_frames = video.shape[0]
    if n_frames == 0:
        raise ValueError(f"video {filename!r} contains no frames")

    if frame_idx is None:
        frame_idx = random.randrange(n_frames)
    else:
        if frame_idx < 0:
            frame_idx += n_frames
        # clamp into the valid range instead of raising IndexError
        frame_idx = min(max(frame_idx, 0), n_frames - 1)

    # The full frame (all channels) is returned; other cells take `[..., 0]` themselves.
    return video[frame_idx, ...]

In [None]:
def get_video_as_array(file, do_resize=False):
    """Decode every frame of a video file into a single uint8 array.

    Parameters
    ----------
    file : str, os.PathLike or int
        Video file path; an int camera index is passed to OpenCV unchanged.
    do_resize : bool
        If True, each frame is resized to (IMG_HEIGHT, IMG_WIDTH) (extra axes such as
        colour channels are kept). The float result is cast with ``astype(np.uint8)``,
        which truncates rather than rounds.

    Returns
    -------
    np.ndarray (uint8)
        Frames stacked along axis 0, in the channel order OpenCV decodes (BGR for colour).
        If the file cannot be opened an error is printed and an empty array is returned
        (best-effort, no exception).
    """
    # OpenCV wants a str path; keep ints (camera indices) untouched.
    source = os.fspath(file) if isinstance(file, os.PathLike) else file
    cap = cv2.VideoCapture(source)

    if not cap.isOpened():
        print("Error opening video stream or file")

    frames = []
    try:
        # Read until the stream is exhausted (read() returns ret=False at the end).
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if do_resize:
                frame = resize(frame, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
            frames.append(frame)
    finally:
        # Release the decoder even if read()/resize() raises.
        cap.release()

    return np.asarray(frames).astype(np.uint8)

In [None]:
def get_frames(original_path, annotated_path, do_resize=True):
    """Pair each annotated frame with its original frame.

    Both videos are decoded with ``get_video_as_array``; only frames whose annotated
    version contains at least one "annotation" pixel (channels 0 and 1 below 10 and
    channel 2 above 150) are kept.

    Returns
    -------
    (original_frames, annotated_frames) : dict[str, np.ndarray], dict[str, np.ndarray]
        Keyed by the frame index as a string. Originals are channel 0 of the frame as uint8,
        annotations are 0/1 uint8 masks; both resized to (IMG_HEIGHT, IMG_WIDTH) when
        ``do_resize``. ``(None, None)`` if the two videos have different shapes.
    """
    original_video = get_video_as_array(original_path)
    annotated_video = get_video_as_array(annotated_path)

    if annotated_video.shape != original_video.shape:
        print("Shapes do not match!")
        print(original_video.shape, annotated_video.shape)
        return None, None

    kept_originals = {}
    kept_masks = {}
    for idx, annotated_rgb in enumerate(annotated_video):
        # Strongly saturated channel-2 pixels mark the annotation. OpenCV decodes BGR, so
        # channel 2 is presumably red in the source video -- verify against the annotation tool.
        mask = ((annotated_rgb[..., 0] < 10)
                & (annotated_rgb[..., 1] < 10)
                & (annotated_rgb[..., 2] > 150))
        if not mask.any():
            continue

        source = original_video[idx, ...]
        if do_resize:
            target_shape = (IMG_HEIGHT, IMG_WIDTH)
            source = resize(source, target_shape, mode='constant', preserve_range=True)
            mask = resize(mask, target_shape, mode='constant', preserve_range=True)

        key = str(idx)
        kept_originals[key] = source[..., 0].astype(np.uint8)
        kept_masks[key] = mask.astype(np.uint8)

    return kept_originals, kept_masks

In [None]:
def get_data(data_dir):
    """Collect the annotated frames of every patient under ``data_dir``.

    Layout checked below: one sub-directory per patient named ``case<ID>``, containing a
    file with ``<ID>`` and ``rendered.mp4`` in its name (original video) and one with
    ``<ID>`` and ``annotated.mp4`` (annotated video).

    Parameters
    ----------
    data_dir : str or pathlib.Path

    Returns
    -------
    list of dict
        Keys 'pat', 'orig_video_path', 'anno_video_path', 'original_frames',
        'annotated_frames' (the last two from ``get_frames``). Patients with missing files
        are reported and skipped; patients whose videos differ in shape are skipped.
    """
    data_dir = Path(data_dir)
    # Sorted so the patient order (and hence the later seeded shuffle/split) is reproducible;
    # os.listdir order is arbitrary. Stray files next to the patient folders are ignored.
    patients = sorted(entry.name for entry in data_dir.iterdir() if entry.is_dir())

    data = []
    for p in tqdm(patients):
        case_id = p.replace('case', '')
        pat_file_dic = {}
        for f in sorted(os.listdir(data_dir / p)):
            pat_file_dic['pat'] = p
            # check for correct files
            # NOTE(review): substring match -- ID "1" would also match e.g. "case10_..." files.
            if case_id in f:
                path = str(data_dir / p / f)
                if 'rendered.mp4' in f:
                    pat_file_dic['orig_video_path'] = path
                if 'annotated.mp4' in f:
                    pat_file_dic['anno_video_path'] = path

        if 'orig_video_path' not in pat_file_dic or 'anno_video_path' not in pat_file_dic:
            print(f"Problem with pat: {p}. Incorrect file name.")
            continue

        o_frames, a_frames = get_frames(pat_file_dic['orig_video_path'],
                                        pat_file_dic['anno_video_path'])
        if o_frames is None:
            # video shapes did not match
            continue

        pat_file_dic['original_frames'] = o_frames
        pat_file_dic['annotated_frames'] = a_frames
        data.append(pat_file_dic)

    return data

In [None]:
def to_one_hot(frame, n_classes=2):
    """One-hot encode an integer label map along a new last axis.

    Parameters
    ----------
    frame : array-like of non-negative ints
        Label map (e.g. a 0/1 mask).
    n_classes : int
        Minimum number of output channels. Larger label values widen the encoding to
        ``max(frame) + 1`` channels so no value is dropped.

    Returns
    -------
    np.ndarray (float64) of shape ``frame.shape + (max(n_classes, max(frame) + 1),)``.
    A mask with only zeros still gets ``n_classes`` channels.
    """
    frame = np.asarray(frame)
    if frame.size == 0:
        n_values = n_classes
    else:
        n_values = max(n_classes, int(frame.max()) + 1)
    return np.eye(n_values)[frame]

In [None]:
def augment_frame(frame, label, is_train=True):
    """Contrast-normalise a frame and, for training, randomly deform frame and label together.

    Parameters
    ----------
    frame : array-like
        Image. If its maximum exceeds 1 it is assumed to be 0-255 and rescaled to [0, 1].
    label : array-like
        Mask with the same spatial shape (deformed together with the frame when training).
    is_train : bool
        True: random CLAHE parameters, and with 50 % probability an elastic deformation with
        random rotation and zoom. False: fixed CLAHE parameters, no deformation.

    Returns
    -------
    (frame, label)
        The frame returned by ``exposure.equalize_adapthist`` (possibly deformed) and the
        label binarised at 0.5 as uint8.
    """
    frame = np.asarray(frame, dtype=np.float32)
    label = np.asarray(label, dtype=np.float32)
    if np.max(frame) > 1:
        # Out-of-place on purpose: np.asarray returns the caller's own array when it is
        # already float32, so an in-place `/=` would rescale the caller's data as well.
        frame = frame / 255.0

    if is_train:
        # contrast augmentation: CLAHE tile = frame size / k_scale (6..8), clip limit 0.02..0.04
        k_scale = random.randrange(6, 9, 1)
        clip_limit = random.randrange(2, 5, 1)/100

        # TODO: add brightness augmentation?

    else:
        k_scale = 8
        clip_limit = 0.03

    # normalize frame (CLAHE)
    frame = exposure.equalize_adapthist(frame,
                                        clip_limit=clip_limit,
                                        kernel_size=[int(d/k_scale) for d in frame.shape])

    if is_train:

        # augment 50% of the frames
        if random.random() > 0.5:
            # Frame and label are passed together so both get the same random grid.
            # randrange excludes its stop value: rotate in [-15, 14], zoom in [0.9, 1.3].
            frame, label = elasticdeform.deform_random_grid([frame, label],
                                                            sigma=10,
                                                            points=3,
                                                            rotate=random.randrange(-15,15, 1),
                                                            zoom=random.randrange(9,14, 1)/10)

    # Re-binarise the label (deformation presumably interpolates it to non-binary values).
    label = np.asarray(label > 0.5, dtype=np.uint8)

    return frame, label

In [None]:
def get_dataset(data, is_train=True, shuffle_buffer=48):
    """Wrap per-patient frame dicts (as returned by ``get_data``) in a tf.data pipeline.

    Every frame is passed through ``augment_frame`` (random augmentation when ``is_train``)
    and the label is one-hot encoded to NUM_CLASSES channels.

    Parameters
    ----------
    data : list of dict
        Each dict needs 'original_frames' and 'annotated_frames' with matching keys.
    is_train : bool
        True: repeats forever, shuffles, batches by BATCH_SIZE. False: one pass, batch size 1.
    shuffle_buffer : int
        Shuffle buffer size for training. Frames are generated patient by patient, so a
        small buffer only mixes nearby frames.

    Returns
    -------
    tf.data.Dataset of (frame, one-hot label) batches, shapes
    (B, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) and (B, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES).
    """

    def _yield_frames():
        for d in data:
            for k in d['original_frames'].keys():
                original_frame, annotated_frame = augment_frame(d['original_frames'][k],
                                                                d['annotated_frames'][k],
                                                                is_train=is_train)

                original_frame = original_frame[..., np.newaxis]
                annotated_frame = to_one_hot(annotated_frame, n_classes=NUM_CLASSES)

                # Skip frames whose encoding does not match the declared output signature.
                if annotated_frame.shape[-1] != NUM_CLASSES:
                    continue

                yield original_frame, annotated_frame

    output_signature = (
        tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=tf.float32),
        tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=tf.float32),
    )
    ds = tf.data.Dataset.from_generator(_yield_frames, output_signature=output_signature)

    if is_train:
        ds = ds.repeat().shuffle(shuffle_buffer).batch(BATCH_SIZE)
    else:
        ds = ds.batch(1)

    return ds.prefetch(tf.data.AUTOTUNE)

In [None]:
def weighted_binary_crossentropy(w1, w2):
    '''
    Build a class-weighted binary cross-entropy loss.

    w1 weights the log-loss of elements whose target is 1, w2 the log-loss of elements
    whose target is 0:
        loss = -(w1 * y * log(p) + w2 * (1 - y) * log(1 - p))
    averaged over the last axis (like Keras' built-in binary cross-entropy). For 0/1 targets
    this is exactly "w1 * -log(p) if y == 1 else w2 * -log(1 - p)"; soft targets interpolate.

    Use like so:
        model.compile(loss=weighted_binary_crossentropy(w1, w2), optimizer="adam", metrics=["accuracy"])
    '''

    def loss(y_true, y_pred):
        # avoid log(0)
        y_pred = tf.clip_by_value(y_pred, K.epsilon(), 1 - K.epsilon())
        y_true = tf.cast(y_true, y_pred.dtype)
        # Vectorised element-wise weighting (no per-sample map_fn / Python branching,
        # which cannot inspect tensor values).
        per_element = -(w1 * y_true * tf.math.log(y_pred)
                        + w2 * (1.0 - y_true) * tf.math.log(1.0 - y_pred))
        return tf.reduce_mean(per_element, axis=-1)

    return loss

In [None]:
def _unet_conv_block(x, filters, dropout_rate):
    """One U-Net stage: 3x3 conv + BN, dropout, 3x3 conv + BN (ReLU, he_normal, 'same')."""
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = BatchNormalization()(x)
    return x


# Build U-Net model
def build_model(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, num_classes=None):
    """Build a 4-level U-Net with a per-pixel softmax output.

    Parameters
    ----------
    IMG_HEIGHT, IMG_WIDTH : int or None
        Input size. Must be divisible by 16 (four 2x2 poolings followed by four 2x
        up-convolutions whose outputs are concatenated with the skip connections).
    IMG_CHANNELS : int
        Input channels.
    num_classes : int, optional
        Output channels; defaults to the global NUM_CLASSES.

    Returns
    -------
    tf.keras.Model mapping (H, W, IMG_CHANNELS) -> (H, W, num_classes) class probabilities.

    Raises
    ------
    ValueError
        If a fixed input size is not divisible by 16.
    """
    factor = 2 ** 4
    for dim_name, size in (("IMG_HEIGHT", IMG_HEIGHT), ("IMG_WIDTH", IMG_WIDTH)):
        if size is not None and size % factor:
            raise ValueError(f"{dim_name}={size} must be divisible by {factor} for this U-Net")
    if num_classes is None:
        num_classes = NUM_CLASSES

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Encoder: (filters, dropout) per level; the pre-pooling output is kept as skip connection.
    skips = []
    x = inputs
    for filters, rate in ((32, 0.1), (64, 0.1), (128, 0.2), (256, 0.2)):
        x = _unet_conv_block(x, filters, rate)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _unet_conv_block(x, 512, 0.3)

    # Decoder: (up-conv filters, conv filters, dropout), paired with skips deepest-first.
    decoder_specs = ((128, 256, 0.2), (64, 128, 0.2), (32, 64, 0.1), (16, 32, 0.1))
    for (up_filters, filters, rate), skip in zip(decoder_specs, reversed(skips)):
        x = Conv2DTranspose(up_filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = _unet_conv_block(x, filters, rate)

    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
def get_dice(im1, im2, empty_score=1.0):
    """
    Computes the Dice coefficient, a measure of set similarity.

    Parameters
    ----------
    im1 : array-like, bool
        Any array of arbitrary size. If not boolean, will be converted.
    im2 : array-like, bool
        Any other array of identical size. If not boolean, will be converted.
    empty_score : float
        Value returned when both masks are empty (Dice is undefined there: 0/0).

    Returns
    -------
    dice : float
        Dice coefficient as a float on range [0,1].
        Maximum similarity = 1
        No similarity = 0

    Raises
    ------
    ValueError
        If the two arrays have different shapes.

    Notes
    -----
    The order of inputs for `dice` is irrelevant. The result will be
    identical if `im1` and `im2` are switched.
    """
    # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)

    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")

    total = im1.sum() + im2.sum()
    if total == 0:
        return float(empty_score)

    # Compute Dice coefficient
    intersection = np.logical_and(im1, im2)

    return 2. * intersection.sum() / total

In [None]:
def get_iou(target, prediction, empty_score=1.0):
    """Intersection over union (Jaccard index) of two binary masks.

    Parameters
    ----------
    target, prediction : array-like
        Masks of identical shape; non-boolean input is converted (non-zero = True).
    empty_score : float
        Value returned when both masks are empty (IoU is undefined there: 0/0).

    Returns
    -------
    float in [0, 1].

    Raises
    ------
    ValueError
        If the two arrays have different shapes.
    """
    # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
    target = np.asarray(target).astype(bool)
    prediction = np.asarray(prediction).astype(bool)

    if target.shape != prediction.shape:
        raise ValueError("Shape mismatch: target and prediction must have the same shape.")

    union = np.logical_or(target, prediction).sum()
    if union == 0:
        return float(empty_score)

    intersection = np.logical_and(target, prediction).sum()
    return intersection / union

In [None]:
def overlay_imgs(frame, prediction=None, groundtruth=None, frame_weight=0.75):
    """Overlay prediction and ground-truth masks on a grey-scale frame as an RGB image.

    Parameters
    ----------
    frame : np.ndarray, shape (H, W)
        Grey-scale image, expected in [0, 1].
    prediction, groundtruth : np.ndarray of shape (H, W), optional
        Masks added to channel 0 (prediction) and channel 2 (ground truth). If only one is
        given it is used for both channels; if neither is given the frame is shown plain.
    frame_weight : float
        Factor applied to the frame before the masks are added (dims it so masks stand out).

    Returns
    -------
    np.ndarray, shape (H, W, 3), clipped to [0, 1].
    """
    if prediction is None and groundtruth is None:
        prediction = groundtruth = np.zeros_like(frame)
    elif prediction is None:
        prediction = groundtruth
    elif groundtruth is None:
        groundtruth = prediction

    dimmed = frame_weight * frame
    overlayed = np.clip(np.dstack([dimmed + prediction,
                                   dimmed,
                                   dimmed + groundtruth]), a_min=0, a_max=1)

    return overlayed

In [None]:
# Channel count comes from the config cell (IMG_CHANNELS) rather than a hard-coded literal.
model = build_model(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)

In [None]:
# Render the architecture graph with tensor shapes (Keras needs pydot and graphviz installed for this).
tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# Text summary: layers, output shapes and parameter counts.
model.summary()

In [None]:
# Decode every patient's original + annotated video and keep the annotated frames (can take a while).
data = get_data(DATA_DIR)

### Some visualizations

In [None]:
# Frame keys (stringified frame indices) of the first patient's annotated frames.
annotated_frames_indices = [*data[0]['original_frames']]

In [None]:
# Deterministic preprocessing only (is_train=False: fixed CLAHE, no deformation) on the first annotated frame.
preview_key = annotated_frames_indices[0]
frame, label = augment_frame(data[0]['original_frames'][preview_key],
                             data[0]['annotated_frames'][preview_key], is_train=False)
fig, ax = plt.subplots()
ax.imshow(frame, cmap='gray')
ax.set_title(f"{data[0]['pat']}, frame {preview_key}: is_train=False")
ax.axis('off');

In [None]:
# Training augmentation (is_train=True: random CLAHE, 50 % chance of elastic deformation); re-run to see new draws.
preview_key = annotated_frames_indices[0]
frame, label = augment_frame(data[0]['original_frames'][preview_key],
                             data[0]['annotated_frames'][preview_key], is_train=True)
fig, ax = plt.subplots()
ax.imshow(frame, cmap='gray')
ax.set_title(f"{data[0]['pat']}, frame {preview_key}: is_train=True")
ax.axis('off');

In [None]:
# shuffle data
# Split by patient (no patient appears in both sets). Sorting before a seeded shuffle makes
# the cell idempotent: re-running it reproduces the same order instead of reshuffling an
# already-shuffled list. `data` is reordered in place because later cells slice it again.
TRAIN_FRACTION = 0.8
data.sort(key=lambda d: d['pat'])
random.Random(seed).shuffle(data)
n_train = int(TRAIN_FRACTION * len(data))
train_dataset = get_dataset(data[:n_train])
val_dataset = get_dataset(data[n_train:], is_train=False)

## Fit model (TODO: visualize with Weights & Biases — logging is not wired up yet)

In [None]:
epochs = 100

# Keep the checkpoint with the best validation loss on disk, and stop after 15 epochs
# without improvement (restoring the best weights in memory).
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("seg_model_best.h5", save_best_only=True)
early_stop_cb = tf.keras.callbacks.EarlyStopping(patience=15, mode="auto", restore_best_weights=True, verbose=True)
callbacks = [checkpoint_cb, early_stop_cb]

model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# The training dataset repeats forever, so steps_per_epoch defines an "epoch" (100 batches).
fit_kwargs = dict(epochs=epochs,
                  callbacks=callbacks,
                  validation_data=val_dataset,
                  verbose=True,
                  steps_per_epoch=100)
history = model.fit(train_dataset, **fit_kwargs)

In [None]:
# Reload the checkpoint saved by ModelCheckpoint(save_best_only=True) during training (monitor defaults to val_loss).
model = load_model('seg_model_best.h5')

In [None]:
# Pull 10 batches from the (augmented, shuffled) training pipeline for a visual sanity check.
inputs = list(train_dataset.take(10).as_numpy_iterator())
frames = np.concatenate([batch_frames for batch_frames, _ in inputs], axis=0)
labels = np.concatenate([batch_labels for _, batch_labels in inputs], axis=0)
predictions = model.predict(frames)

In [None]:
# Per frame: input with label overlay | predicted probability | absolute error.
# Channel 1 is the annotated class (to_one_hot maps mask value 1 to column 1). With two softmax
# classes the error map equals the one computed from channel 0.
for f, l, p in zip(frames, labels, predictions):

    fig, axs = plt.subplots(1, 3, figsize=(16, 6))

    axs[0].imshow(f[..., 0], cmap='gray')
    axs[0].imshow(l[..., 1], alpha=0.1)
    axs[0].set_title("input + label")
    # fixed 0-1 colour range so panels are comparable across frames
    axs[1].imshow(p[..., 1], vmin=0, vmax=1)
    axs[1].set_title("predicted probability (class 1)")
    axs[2].imshow(np.abs(l[..., 1] - p[..., 1]), vmin=0, vmax=1)
    axs[2].set_title("|label - prediction|")

    # Show and close each figure now instead of keeping 10 * BATCH_SIZE figures open.
    plt.show()
    plt.close(fig)

In [None]:
# TODO: make video (prediction and prediction overlayed on real) in real size of test data

In [None]:
# Patients after the 80 % cut of `data`, i.e. the slice used for val_dataset (as long as `data` was not reordered since).
n_val_start = int(0.8 * len(data))
validation_data = data[n_val_start:]

In [None]:
# Run the model over every frame of each validation patient's full video and export
# original / binary prediction / overlay videos at the original resolution.
SAVE_DIR = 'videos'
VIDEO_FPS = 25
os.makedirs(SAVE_DIR, exist_ok=True)

for patient in validation_data:

    pat_name = patient['pat']
    print(pat_name)

    video = get_video_as_array(patient['orig_video_path'], do_resize=False)
    if len(video) == 0:
        print(f"No frames decoded for {pat_name}, skipping.")
        continue
    orig_height, orig_width = video.shape[1:3]

    # first channel of every frame, as used for training
    video_frames = [f[..., 0] for f in video]

    # Preprocess every frame as in evaluation mode, then predict in one batched call.
    model_inputs = []
    for orig_frame in video_frames:
        frame = resize(orig_frame, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
        # augment_frame requires a label; with is_train=False it is only thresholded and discarded here.
        frame, _ = augment_frame(frame, orig_frame, is_train=False)
        model_inputs.append(frame)
    probabilities = model.predict(np.stack(model_inputs)[..., np.newaxis],
                                  batch_size=BATCH_SIZE, verbose=False)

    video_predictions = []
    video_overlays = []
    for orig_frame, prob in zip(video_frames, probabilities):
        # threshold the class-1 probability, then resize back to the original resolution
        mask = prob[..., 1] > 0.5
        mask = resize(mask, (orig_height, orig_width), mode='constant', preserve_range=True)
        video_predictions.append(mask)

        overlayed = overlay_imgs(orig_frame/255.0, prediction=mask*1.0, groundtruth=None)
        video_overlays.append(overlayed*255)

    # save as .avi videos (imageio picks the writer from the extension; needs an ffmpeg-capable plugin)
    imageio.mimsave(os.path.join(SAVE_DIR, pat_name + "_original.avi"),
                    [np.asarray(f, dtype=np.uint8) for f in video_frames], fps=VIDEO_FPS)
    imageio.mimsave(os.path.join(SAVE_DIR, pat_name + "_prediction.avi"),
                    [np.asarray(m*255, dtype=np.uint8) for m in video_predictions], fps=VIDEO_FPS)
    imageio.mimsave(os.path.join(SAVE_DIR, pat_name + "_overlayed.avi"),
                    [np.asarray(o, dtype=np.uint8) for o in video_overlays], fps=VIDEO_FPS)