In [1]:
# Standard
import os
import sys
import datetime
import numpy as np
import matplotlib.pyplot as plt

# Utils
import h5py
import SimpleITK as sitk
import cv2
from medpy.io.save import save

# Deep Learning
from sklearn.model_selection import train_test_split
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.models import load_model
import tensorflow as tf

# User defined
sys.path.insert(0, "../")
from architectures.laddernet import LadderNet
from architectures.unet import UNet
from metrics.multiclass_dice import multiclass_dice, dice_lv, dice_la, dice_myo

In [2]:
# Load the normalized training frames into memory, then close the file.
# Slicing an h5py dataset returns an in-memory NumPy array, so the arrays below
# stay valid after the handle is closed. Nothing later needs the open file, and
# closing it avoids holding it open for the life of the kernel.
with h5py.File("../../data/image_dataset_normalized.hdf5", "r") as f:
    frames2ch = f["train 2ch frames"][:, :, :, :]
    frames4ch = f["train 4ch frames"][:, :, :, :]

In [3]:
# Custom metric functions the saved models reference by name; load_model needs
# them in custom_objects to deserialize the files.
dice_custom_objects = {
    "multiclass_dice": multiclass_dice,
    "dice_lv": dice_lv,
    "dice_la": dice_la,
    "dice_myo": dice_myo,
}

# Each call gets its own copy, mirroring the two separate dict literals used before.
model2ch = load_model("../models/laddernet088_2ch.h5", custom_objects=dict(dice_custom_objects))
model4ch = load_model("../models/laddernet090_4ch.h5", custom_objects=dict(dice_custom_objects))

In [4]:
# Run each view's model over its frames and take the argmax over axis 3. This
# gives one integer label per pixel; presumably it maps (N, H, W, classes) to
# (N, H, W) — confirm against the model output.
seg2ch = np.argmax(model2ch.predict(frames2ch), axis=3)
seg4ch = np.argmax(model4ch.predict(frames4ch), axis=3)

In [5]:
def mhd_to_array(path):
    """Read an image with SimpleITK as float32 and return its pixels as a NumPy array.

    Args:
        path: path of the image file to read (``.mhd`` headers in this notebook).

    Returns:
        The array from ``sitk.GetArrayFromImage``. Its axes are in reverse of
        SimpleITK's (x, y, z) index order; a single echo frame comes back as
        ``(1, height, width)``.
    """
    image = sitk.ReadImage(path, sitk.sitkFloat32)
    return sitk.GetArrayFromImage(image)

In [6]:
# Root of the training data. Listings are sorted so their order is stable and can
# be lined up with the order of the predicted frames.
train_path = "../../data/training/"
train_2ch_frames_list = sorted(os.listdir(os.path.join(train_path, "2ch", "frames")))
train_4ch_frames_list = sorted(os.listdir(os.path.join(train_path, "4ch", "frames")))

In [10]:
# Original (width, height) of every training frame, in sorted file-listing order.
# cv2.resize takes dsize as (width, height). The SimpleITK arrays are
# (1, height, width), so width = shape[2] and height = shape[1].
# NOTE(review): these lists are later indexed by HDF5 frame position. That assumes
# the HDF5 datasets were written in this same sorted .mhd order — confirm.
echo_sizes2ch = []
echo_sizes4ch = []

for view_dir, frame_names, view_sizes in (
    ("2ch/frames", train_2ch_frames_list, echo_sizes2ch),
    ("4ch/frames", train_4ch_frames_list, echo_sizes4ch),
):
    for i in frame_names:
        # Match the extension, not a substring: `"mhd" in name` would also pick
        # up unrelated files such as "frame.mhd.bak" and shift every index.
        if i.endswith(".mhd"):
            # `x` is kept as a module-level name because a later cell inspects it.
            x = mhd_to_array(os.path.join(train_path, view_dir, i))
            view_sizes.append((x.shape[2], x.shape[1]))

In [None]:
PREDICTIONS_DIR = "../predictions/"


def save_predictions(segmentations, echo_sizes, view, out_dir=PREDICTIONS_DIR):
    """Resize each predicted label map to its echo's original size and save it as .mhd.

    Frames are assumed to alternate ED, ES, ED, ES, ... per patient, starting at
    patient 1. So frame ``idx`` belongs to patient ``idx // 2 + 1`` and is ED when
    ``idx`` is even. NOTE(review): this relies on the HDF5 frame order matching
    the sorted file listing — confirm.

    Args:
        segmentations: sequence of label maps, one per frame.
        echo_sizes: matching sequence of ``(width, height)`` tuples, in the
            ``dsize`` order expected by ``cv2.resize``.
        view: view tag used in the file name, e.g. ``"2CH"`` or ``"4CH"``.
        out_dir: directory the ``.mhd`` files are written to; created if missing.

    Raises:
        ValueError: if the number of label maps and sizes differ. That means the
            predictions and the file listing are out of sync, and frames would
            be paired with the wrong original size.
    """
    if len(segmentations) != len(echo_sizes):
        raise ValueError(
            "{} has {} predictions but {} echo sizes".format(
                view, len(segmentations), len(echo_sizes)
            )
        )
    os.makedirs(out_dir, exist_ok=True)
    for idx, (pred, size) in enumerate(zip(segmentations, echo_sizes)):
        patient = idx // 2 + 1
        phase = "ED" if idx % 2 == 0 else "ES"
        # Nearest-neighbour keeps the resized map on the original integer labels
        # instead of interpolating values between classes.
        resized = cv2.resize(pred, size, interpolation=cv2.INTER_NEAREST)
        # NOTE(review): medpy's save may expect (x, y) axis order rather than the
        # (rows, cols) array returned by cv2.resize. Confirm the written image is
        # not transposed relative to the source echo.
        filename = "patient{:04d}_{}_{}.mhd".format(patient, view, phase)
        save(resized, os.path.join(out_dir, filename))


save_predictions(seg2ch, echo_sizes2ch, "2CH")
# The 4-chamber masks used to be written with the "_2CH_" suffix, which
# overwrote the 2-chamber predictions saved just above.
save_predictions(seg4ch, echo_sizes4ch, "4CH")

In [9]:
# (width, height) of the first 2-chamber training frame. The bare name
# `echo_sizes` is never defined here (NameError on a fresh kernel); the
# per-view list is `echo_sizes2ch`.
echo_sizes2ch[0]

(549, 778)

In [10]:
# Array shape of the last 4-chamber .mhd frame, recomputed explicitly. This avoids
# relying on the loop variable `x` leaked from the size-collection cell.
last_4ch_mhd = [name for name in train_4ch_frames_list if name.endswith(".mhd")][-1]
mhd_to_array(os.path.join(train_path, "4ch/frames", last_4ch_mhd)).shape

(1, 1232, 869)