# Import Libraries

In [None]:
import os
import math
import cv2
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from transformers import BertConfig, BertModel

import torch
import torch.nn as nn
import tqdm

import tensorflow as tf
from tensorflow.keras.layers import concatenate, Dropout, Conv2DTranspose, Input, Conv2D, MaxPooling2D
from keras.src.losses import SparseCategoricalCrossentropy
from sklearn.metrics import mean_squared_error

# Paths

In [None]:
# DataSet
# Root folder of the dataset. The paths below are built by plain string
# concatenation, so a non-empty root must end with a path separator.
_dataRootPath = ''
_videosPath = _dataRootPath + 'Videos'  # folder holding the '<FileName>.avi' videos
_fileNamesPath = _dataRootPath + 'FileList.csv'  # per-video metadata table
_volumeTracingPath = _dataRootPath + 'VolumeTracings.csv'  # per-frame tracing coordinates

# Loaded Videos
# Folder in which load_or_get_data() caches the pickled VideoData lists.
_loadedVideosPath = ''

# Transformer
_transformerModelPath = ''  # presumably the transformer checkpoint location — TODO confirm

# U-NET
_trueMasksPath = ''  # presumably the root of the generated frame/mask folders — TODO confirm
_ED_Model_Path = ''  # presumably the U-Net checkpoint for ED frames — TODO confirm
_ES_Model_Path = ''  # presumably the U-Net checkpoint for ES frames — TODO confirm

# DataModel

In [None]:
class LandMarks:
    """Hold the end-point coordinates of a set of landmark segments.

    Each segment runs from (X1[i], Y1[i]) to (X2[i], Y2[i]); the four
    attributes are stored exactly as passed in (typically parallel lists).
    """

    def __init__(self, X1, Y1, X2, Y2):
        # First end point of every segment, then the second end point.
        self.X1, self.Y1 = X1, Y1
        self.X2, self.Y2 = X2, Y2

    def displayInfo(self):
        """Print the four coordinate collections to stdout."""
        report = f"""
              land Marks are :
                    X1 is  {self.X1}
                    Y1 is {self.Y1}
                    X2 is {self.X2}
                    Y2 is {self.Y2}"""
        print(report)


class VideoData:
    """Plain record describing one video and its two annotated frames.

    The constructor stores every argument unchanged under an attribute of
    the same name: the file name, the EF/ED/ES values, the ED and ES frame
    indices, the data split, the two landmark objects, the frame count and
    the two frame images.
    """

    def __init__(self, fileName, EF_value, ED_value, ES_value, ED_frame, ES_frame, Split, ED_landMark, ES_landMark,
                 numberOfFrames,
                 ED_Frame_IMG, ES_Frame_IMG):
        # Identification and scalar measurements.
        self.fileName, self.EF_value = fileName, EF_value
        self.ED_value, self.ES_value = ED_value, ES_value
        # Annotated frame indices and the split the video belongs to.
        self.ED_frame, self.ES_frame, self.Split = ED_frame, ES_frame, Split
        # Landmarks traced on the two annotated frames.
        self.ED_landMark, self.ES_landMark = ED_landMark, ES_landMark
        self.numberOfFrames = numberOfFrames
        # Images of the two annotated frames.
        self.ED_Frame_IMG, self.ES_Frame_IMG = ED_Frame_IMG, ES_Frame_IMG

    def displayInfo(self):
        """Print the scalar fields of the record to stdout."""
        summary = f"""
        Video Information:
              File Name is  {self.fileName}
              EF Value is {self.EF_value}
              ES Value is {self.ES_value}
              ED Value is {self.ED_value}
              ED Frame is {self.ED_frame}
              ES Frame is {self.ES_frame}
              Split is {self.Split}
              numberOfFrames is {self.numberOfFrames}"""
        print(summary)

# Helper Functions

In [None]:
def _FilterNot_42rows(VolumeTracings, FileList, expected_rows=42,
                      excluded_videos=('0X4F8859C8AB4DA9CB.avi',)):
    """Keep only the videos whose tracing table has the expected row count.

    Rows containing NaN are dropped from both tables first. A video is then
    discarded when its number of tracing rows differs from ``expected_rows``
    or when its file name is listed in ``excluded_videos``. ``FileList`` is
    finally reduced to the videos that are still present in the tracings, so
    both returned tables describe the same set of videos.

    The input frames are not modified; filtered copies are returned.

    Args:
        VolumeTracings: DataFrame with a 'FileName' column holding names
            that include the '.avi' extension.
        FileList: DataFrame with a 'FileName' column holding names without
            the extension.
        expected_rows: Number of tracing rows a valid video must have.
        excluded_videos: File names (with extension) that are always removed.

    Returns:
        Tuple ``(VolumeTracings, FileList)`` of filtered DataFrames.
    """
    # Non-inplace dropna: the caller's DataFrames stay untouched.
    tracings = VolumeTracings.dropna()
    files = FileList.dropna()

    names, counts = np.unique(tracings['FileName'], return_counts=True)
    rejected = [name for name, count in zip(names, counts) if count != expected_rows]
    rejected.extend(excluded_videos)

    # Remove every rejected video before filtering the file list, otherwise
    # the file list would keep entries that have no tracing left.
    tracings = tracings[~tracings['FileName'].isin(rejected)]
    files = files[(files['FileName'] + ".avi").isin(tracings['FileName'])]

    return tracings, files


def _landmarksFromTracing(rows):
    """Build a LandMarks object from tracing rows (columns 1-4 are read as X1, Y1, X2, Y2)."""
    landmark = LandMarks([], [], [], [])
    for k in range(len(rows)):
        landmark.X1.append(rows.iloc[k, 1])
        landmark.Y1.append(rows.iloc[k, 2])
        landmark.X2.append(rows.iloc[k, 3])
        landmark.Y2.append(rows.iloc[k, 4])
    return landmark


def _loadAlldata(split_type):
    """Read the CSV tables and the videos and build one VideoData per usable video.

    The file list and the tracings are read from ``_fileNamesPath`` and
    ``_volumeTracingPath`` and filtered with ``_FilterNot_42rows``. For each
    remaining video of the requested split, the first traced frame is taken
    as ED and the second as ES; both must have exactly 21 tracing rows.
    The two frames are then read from the video file.

    A video is skipped (not added to the result) when it has fewer than two
    traced frames, when a frame does not have 21 rows, or when the video or
    one of the two frames cannot be read.

    Args:
        split_type: Value to match against the split column, or "ALL" to
            keep every video.

    Returns:
        List of VideoData objects.
    """
    FileList = pd.read_csv(_fileNamesPath)
    VolumeTracings = pd.read_csv(_volumeTracingPath)

    VolumeTracings, FileList = _FilterNot_42rows(VolumeTracings, FileList)

    # Rebind instead of dropna(inplace=True): the filtered frames may be
    # slices of the original tables.
    VolumeTracings = VolumeTracings.dropna()
    FileList = FileList.dropna()

    leftVentricle_list = []

    for i in range(len(FileList)):
        # Columns are addressed by position; the variable names reflect the
        # layout this code assumes (0 name, 1 EF, 2 ED, 3 ES, 7 frame count, 8 split).
        Split = FileList.iloc[i, 8]
        if split_type != "ALL" and split_type != Split:
            continue
        fileName = FileList.iloc[i, 0]

        VT = VolumeTracings[VolumeTracings['FileName'] == fileName + '.avi']
        unique_Frames = VT['Frame'].unique()

        # Two distinct traced frames are required (ED and ES); indexing the
        # second one would fail otherwise.
        if len(unique_Frames) < 2:
            continue

        ED_Frame = unique_Frames[0]
        ES_Frame = unique_Frames[1]
        ED_tmp = VT[VT['Frame'] == ED_Frame]
        ES_tmp = VT[VT['Frame'] == ES_Frame]

        if len(ED_tmp) != 21 or len(ES_tmp) != 21:
            continue

        ED_landMark = _landmarksFromTracing(ED_tmp)
        ES_landMark = _landmarksFromTracing(ES_tmp)

        EF_value = FileList.iloc[i, 1]
        ED_value = FileList.iloc[i, 2]
        ES_value = FileList.iloc[i, 3]
        numberOfFrames = FileList.iloc[i, 7]

        video_path = os.path.join(_videosPath, fileName + '.avi')

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Error opening video file: {video_path}")
                continue

            cap.set(cv2.CAP_PROP_POS_FRAMES, ED_Frame - 1)
            ed_ok, ED_Frame_IMG = cap.read()

            cap.set(cv2.CAP_PROP_POS_FRAMES, ES_Frame - 1)
            es_ok, ES_Frame_IMG = cap.read()
        finally:
            # Always free the capture, including on the skip paths above.
            cap.release()

        if not (ed_ok and es_ok):
            print(f"Error reading ED/ES frames from video file: {video_path}")
            continue

        obj = VideoData(fileName, EF_value, ED_value, ES_value, ED_Frame, ES_Frame, Split, ED_landMark,
                        ES_landMark, numberOfFrames, ED_Frame_IMG, ES_Frame_IMG)

        leftVentricle_list.append(obj)
    return leftVentricle_list


def load_or_get_data(spilt_type="ALL"):
    """Return the list of video objects for a split, using a pickle cache.

    The cache file lives in ``_loadedVideosPath`` (created when missing).
    When the cache file exists it is unpickled; otherwise the data is built
    with ``_loadAlldata`` and written to the cache before being returned.

    Args:
        spilt_type: One of 'TRAIN', 'TEST', 'VAL' or 'ALL'.

    Returns:
        The loaded data, or None (after printing an error) when the split
        name is not one of the accepted values.
    """
    if spilt_type not in ('TRAIN', 'TEST', 'VAL', 'ALL'):
        print('Error not valid split type')
        return None

    cache_dir_missing = not os.path.exists(_loadedVideosPath)
    if cache_dir_missing:
        os.makedirs(_loadedVideosPath)
        print(f'{_loadedVideosPath} created')

    file_path = f'{_loadedVideosPath}/Loaded_Videos_Objects_{spilt_type}.pkl'

    # Cache miss: build the data, store it, and hand it back.
    if not os.path.exists(file_path):
        data = _loadAlldata(spilt_type)
        with open(file_path, 'wb') as cache_file:
            pickle.dump(data, cache_file)
        return data

    # Cache hit: the pickle is trusted because this module wrote it itself.
    with open(file_path, 'rb') as cache_file:
        return pickle.load(cache_file)


def _extractVideoFrames(path):
    """Decode every frame of a video into one uint8 array in RGB order.

    Args:
        path: Location of the video file (converted with ``str``).

    Returns:
        Array of shape (frame_count, frame_height, frame_width, 3).

    Raises:
        ValueError: If a frame announced by the container cannot be read.
    """
    capture = cv2.VideoCapture(str(path))
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Decoded frames are (height, width, channels); allocating the buffer
        # as (width, height) only worked for square videos.
        frames = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)

        for count in range(frame_count):
            ret, frame = capture.read()
            if not ret:
                raise ValueError("Failed to load frame #{} of {}.".format(count, path))
            frames[count] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        # Release the decoder even when a frame fails to load.
        capture.release()

    return frames


# Transformer Data
def _mirroringVideo(video_obj, desired_length=128):
    """Build a fixed-length, labelled clip by bouncing between the ED and ES frames.

    The frames between the two annotated indices (inclusive) are read from
    the video, each tagged "ED", "ES" or "Transition". The sequence is then
    played forwards and backwards repeatedly (without duplicating the turning
    frames) until ``desired_length`` items are available.

    Args:
        video_obj: Object exposing ``fileName``, ``ED_frame`` and ``ES_frame``.
        desired_length: Number of (frame, label) items to return.

    Returns:
        List of ``(frame, label)`` tuples of length ``desired_length``.

    Raises:
        FileNotFoundError: If the video file does not exist.
    """
    path = os.path.join(_videosPath, video_obj.fileName + '.avi')

    if not os.path.exists(path):
        raise FileNotFoundError(path)

    v = _extractVideoFrames(path)

    start = min(video_obj.ED_frame, video_obj.ES_frame)
    end = max(video_obj.ED_frame, video_obj.ES_frame) + 1

    segment = []
    for i in range(start, end):
        if video_obj.ED_frame == i:
            label = "ED"
        elif video_obj.ES_frame == i:
            label = "ES"
        else:
            label = "Transition"
        segment.append((v[i], label))

    # With ED and ES on the same frame there is nothing to mirror; the
    # previous doubling loop never grew the list and did not terminate.
    if len(segment) == 1:
        return segment * desired_length

    # Closed form of the forward/backward playback: the frame index follows
    # a triangle wave 0, 1, ..., last, last - 1, ..., 0, 1, ...
    last = len(segment) - 1
    period = 2 * last
    return [segment[last - abs((k % period) - last)] for k in range(desired_length)]


# U-Net Data
def _prepareDataToPolygon(landmark):
    """Turn the 21 landmark segments into an ordered list of 42 polygon vertices.

    The first 21 vertices are the (X1, Y1) end points, followed by the
    (X2, Y2) end points. Two conditional swaps (0 <-> 21 and 20 <-> 21)
    fix the order around the first segment, and the vertices from index 22
    onwards are reversed so the outline is walked in one direction.

    Args:
        landmark: Object with indexable ``X1``, ``Y1``, ``X2``, ``Y2`` holding
            at least 21 values each.

    Returns:
        List of 42 ``(x, y)`` tuples.
    """
    first_side = [(landmark.X1[k], landmark.Y1[k]) for k in range(21)]
    second_side = [(landmark.X2[k], landmark.Y2[k]) for k in range(21)]
    data = first_side + second_side

    # Put the end point with the smaller y first.
    if data[0][1] > data[21][1]:
        data[0], data[21] = data[21], data[0]

    # Keep vertex 21 from lying left of vertex 20.
    if data[21][0] < data[20][0]:
        data[20], data[21] = data[21], data[20]

    # Walk the second side backwards to close the outline.
    data[22:] = data[22:][::-1]

    return data


def _createBinaryMask(landmark):
    """Rasterise the landmark outline into a 112x112 mask of zeros and ones.

    Args:
        landmark: Landmark object accepted by ``_prepareDataToPolygon``.

    Returns:
        Float array of shape (112, 112); pixels inside the polygon are 1,
        all others 0.
    """
    vertices = _prepareDataToPolygon(landmark)

    # Start from an empty (all zero) image.
    mask = np.zeros((112, 112)).astype(float)

    # cv2.fillPoly only accepts 32-bit integer points; astype(int) yields
    # int64 on most platforms and is rejected by OpenCV.
    pts = np.round(np.array(vertices)).astype(np.int32)

    cv2.fillPoly(mask, [pts], color=(255, 255, 255))

    # Map the fill value to the class index 1.
    mask[mask == 255] = 1

    return mask


def _createImageAndMaskFolders(frameType, split, path):
    """Create the frame and mask folders of one frame type and split.

    The layout is ``<path>/Frames_<frameType>/<split>/`` and
    ``<path>/Masks_<frameType>/<split>/``. A message is printed for every
    folder saying whether it was created or already existed.

    Args:
        frameType: Frame type used in the folder names (e.g. 'ED' or 'ES').
        split: Name of the split sub-folder.
        path: Root folder; joined by string concatenation.

    Returns:
        Tuple ``(image_path, mask_path)``, both ending with '/'.

    Raises:
        OSError: If a folder cannot be created for a reason other than
            already existing (e.g. missing permissions).
    """

    def _ensure_dir(directory, label):
        # Only "already exists" is expected; any other failure must surface
        # instead of being reported as an existing folder.
        try:
            os.makedirs(directory)
            print(f'{label} create')
        except FileExistsError:
            print(f'{label} is exist')

    image_path = path + f'/Frames_{frameType}/'
    mask_path = path + f'/Masks_{frameType}/'
    _ensure_dir(image_path, f'Frames_{frameType}')
    _ensure_dir(mask_path, f'Masks_{frameType}')

    image_path += f'{split}/'
    mask_path += f'{split}/'
    _ensure_dir(image_path, f'Frames_{frameType}/{split}')
    _ensure_dir(mask_path, f'Masks_{frameType}/{split}')

    return image_path, mask_path


def _saveImageAndMask(frameType, split, trueMasksPath='', ):
    """Write the annotated frame and its binary mask of every video as PNG files.

    Files are written as ``<fileName>.png`` into the folders returned by
    ``_createImageAndMaskFolders``. Nothing is written when
    ``load_or_get_data`` rejects the split.

    Args:
        frameType: 'ES' or 'ED'; selects which frame and landmarks are used.
        split: Split passed to ``load_or_get_data``.
        trueMasksPath: Root folder for the generated folders.

    Raises:
        ValueError: If ``frameType`` is neither 'ES' nor 'ED'.
    """
    # Fail early: any other value used to leave the image unset and crash
    # inside OpenCV after the folders had already been created.
    if frameType not in ('ES', 'ED'):
        raise ValueError(f"frameType must be 'ES' or 'ED', got {frameType!r}")

    data_set = load_or_get_data(split)
    if data_set is None:
        return

    image_path, mask_path = _createImageAndMaskFolders(frameType, split, trueMasksPath)

    for obj in data_set:
        if frameType == 'ES':
            img, landmarks = obj.ES_Frame_IMG, obj.ES_landMark
        else:
            img, landmarks = obj.ED_Frame_IMG, obj.ED_landMark

        # A frame that could not be decoded is stored as None; skip it
        # rather than aborting the whole export.
        if img is None:
            print(f'Skipping {obj.fileName}: {frameType} frame was not loaded')
            continue

        # NOTE(review): cv2.imwrite interprets arrays as BGR, so converting to
        # RGB first stores the channels swapped. Presumably harmless for
        # gray-level frames — confirm this is intended.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = _createBinaryMask(landmarks)

        cv2.imwrite(mask_path + f'{obj.fileName}.png', mask)
        cv2.imwrite(image_path + f'{obj.fileName}.png', img)


def CreateAllMasks(trueMasksPath):
    """Export frames and masks for both frame types and all three splits.

    The ES frames are processed first, then the ED frames; within each
    frame type the splits are handled in the order TRAIN, TEST, VAL.

    Args:
        trueMasksPath: Root folder handed to ``_saveImageAndMask``.
    """
    for frame_type in ('ES', 'ED'):
        for subset in ('TRAIN', 'TEST', 'VAL'):
            _saveImageAndMask(frameType=frame_type, split=subset, trueMasksPath=trueMasksPath)


# Read Data
def _process_path(image_path, mask_path):
    """Load one image/mask pair from PNG files.

    The image is decoded with three channels and converted to float32 with
    ``tf.image.convert_image_dtype``. The mask is decoded with three
    channels and collapsed to a single channel by taking the maximum over
    the channel axis.

    Args:
        image_path: Path of the image PNG.
        mask_path: Path of the mask PNG.

    Returns:
        Tuple ``(img, mask)`` of tensors.
    """
    raw_image = tf.io.read_file(image_path)
    decoded_image = tf.image.decode_png(raw_image, channels=3)
    img = tf.image.convert_image_dtype(decoded_image, tf.float32)

    raw_mask = tf.io.read_file(mask_path)
    decoded_mask = tf.image.decode_png(raw_mask, channels=3)
    # Keep a trailing channel axis of size one.
    mask = tf.math.reduce_max(decoded_mask, axis=-1, keepdims=True)

    return img, mask


def getImageAndMasks(frameType='', split='', trueMasksPath=''):
    """Build a tf.data dataset of (image, mask) pairs for one frame type and split.

    Files are read from ``<trueMasksPath>/Frames_<frameType>/<split>/`` and
    ``<trueMasksPath>/Masks_<frameType>/<split>/``. Both listings are sorted
    so that the i-th image is paired with the i-th mask; this relies on an
    image and its mask sharing the same file name.

    Args:
        frameType: Frame type used in the folder names.
        split: Name of the split sub-folder.
        trueMasksPath: Root folder (created when missing).

    Returns:
        Dataset whose elements are produced by ``_process_path``.
    """
    if not os.path.exists(trueMasksPath):
        os.makedirs(trueMasksPath)
        print(f'{trueMasksPath} created')
    image_path = os.path.join(trueMasksPath, f"Frames_{frameType}/{split}/")
    mask_path = os.path.join(trueMasksPath, f"Masks_{frameType}/{split}/")

    # os.listdir returns names in arbitrary order; without sorting, the two
    # independent listings could pair an image with another video's mask.
    image_list = [image_path + name for name in sorted(os.listdir(image_path))]
    mask_list = [mask_path + name for name in sorted(os.listdir(mask_path))]

    image_filenames = tf.constant(image_list)
    masks_filenames = tf.constant(mask_list)

    dataset = tf.data.Dataset.from_tensor_slices((image_filenames, masks_filenames))
    print(split, len(mask_list))

    image_ds = dataset.map(_process_path)

    return image_ds


# Predict LandMarks From Mask
def _getHorizontalLabel(mask):
    """Locate the two end points of the main axis of a 112x112 mask.

    The mask is scanned row by row. Three extreme foreground pixels are
    tracked (an upper point and two lower points); the midpoint of the two
    lower points is then pushed down its column for as long as it stays on
    foreground pixels.

    Args:
        mask: 2-D indexable of size 112x112 where 0 means background.

    Returns:
        Tuple ``((mid_x, mid_y), (up_x, up_y))`` in (column, row) order. When
        both points share the same column, 0.1 is added to ``mid_x`` so the
        axis is never exactly vertical.
    """
    size = 112
    apex = (size, 0)
    base_right = (0, 0)
    base_left = (0, size)

    for row in range(size):
        for col in range(size):
            if mask[row][col] == 0:
                continue

            # The branch order matters: each pixel updates at most one point.
            if row < apex[0] and col > apex[1]:
                apex = (row, col)
            elif row >= base_right[0] and col >= base_right[1]:
                base_right = (row, col)
            elif row > base_left[0] or col < base_left[1]:
                base_left = (row, col)

    mid_row = (base_left[0] + base_right[0]) // 2
    mid_col = (base_left[1] + base_right[1]) // 2

    # Slide the midpoint down until the next row is outside the image or
    # falls on background.
    while mid_row + 1 != size and mask[mid_row + 1][mid_col] != 0:
        mid_row += 1

    # Nudge the column to avoid an exactly vertical axis.
    if mid_col == apex[1]:
        mid_col = mid_col + 0.1

    return (mid_col, mid_row), (apex[1], apex[0])


def _perpendicular_points(x1, y1, x2, y2, distance):
    """Return two points on the perpendicular of a line, one unit either side of it.

    The perpendicular crosses the line through (x1, y1) and (x2, y2) at the
    point lying ``distance`` away from (x1, y1) in the direction of (x2, y2).

    The direction is obtained by rotating the unit vector of the line, which
    also covers horizontal lines (previously a division by zero) and
    vertical lines (previously NaN coordinates). For every other line the
    result equals the slope-based formula up to floating-point rounding: the
    first returned point is the one with the larger x.

    Args:
        x1, y1: First point of the line.
        x2, y2: Second point of the line.
        distance: Offset of the crossing point from (x1, y1) along the line.

    Returns:
        Tuple of two ``(x, y)`` points.

    Raises:
        ValueError: If both points of the line coincide.
    """
    delta_x = x2 - x1
    delta_y = y2 - y1
    magnitude = (delta_x ** 2 + delta_y ** 2) ** 0.5
    if magnitude == 0:
        raise ValueError("Cannot build a perpendicular: both points of the line coincide")

    # Unit vector along the original line.
    unit_vector_x = delta_x / magnitude
    unit_vector_y = delta_y / magnitude

    # Point of the original line crossed by the perpendicular.
    new_mid_x = x1 + unit_vector_x * distance
    new_mid_y = y1 + unit_vector_y * distance

    # Unit normal (uy, -ux), oriented so that its x component is never
    # negative; a horizontal line (uy == 0) uses the orientation (0, -ux).
    orientation = 1.0 if unit_vector_y >= 0 else -1.0
    dx = orientation * unit_vector_y
    dy = -orientation * unit_vector_x

    perpendicular_point1 = (new_mid_x + dx, new_mid_y + dy)
    perpendicular_point2 = (new_mid_x - dx, new_mid_y - dy)

    return perpendicular_point1, perpendicular_point2


def _find_previous_or_next_points(point1=None, point2=None, t='n'):
    """Step one unit along the line through two points.

    With ``t='n'`` the result lies one unit beyond ``point2`` (moving away
    from ``point1``); with ``t='p'`` it lies one unit before ``point1``
    (moving away from ``point2``). The same unit step is used for every
    direction; vertical lines used to jump by the whole point distance and
    always towards smaller y.

    Args:
        point1: ``(x, y)`` of the first point.
        point2: ``(x, y)`` of the second point.
        t: 'n' for the next point, 'p' for the previous point.

    Returns:
        ``(x, y)`` of the new point. When the two points coincide there is no
        direction to follow, so the reference point itself is returned. Any
        other value of ``t`` yields ``(0, 0)``.
    """
    x1, y1 = point1
    x2, y2 = point2

    if t not in ('n', 'p'):
        return 0, 0

    distance = ((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5
    if distance == 0:
        return (x2, y2) if t == 'n' else (x1, y1)

    # Unit step pointing from point1 towards point2.
    step_x = (x2 - x1) / distance
    step_y = (y2 - y1) / distance

    if t == 'n':
        return x2 + step_x, y2 + step_y
    return x1 - step_x, y1 - step_y


def _getPointsInMask(point1, point2, mask):
    """Walk along a line in both directions and return the last points still on the mask.

    Starting from ``point2`` the walk moves away from ``point1`` one unit at
    a time; starting from ``point1`` it moves away from ``point2``. Each walk
    stops at the first position that is outside the mask array or lies on a
    background (zero) pixel, and is limited to 112 steps.

    Points are ``(x, y)`` with x used as the column and y as the row of the
    mask; coordinates are rounded to the nearest pixel for the lookup.

    Args:
        point1: First ``(x, y)`` point of the line.
        point2: Second ``(x, y)`` point of the line.
        mask: 2-D indexable where 0 means background.

    Returns:
        Tuple ``(pn, pp)``: the last foreground point beyond ``point2`` (or
        the first stepped point when even that one is outside) and the last
        foreground point before ``point1`` (or ``point1`` itself).
    """
    n_rows = len(mask)
    n_cols = len(mask[0]) if n_rows else 0
    max_steps = 112

    def _is_foreground(point):
        # Check both bounds explicitly: a column equal to the width raised an
        # IndexError and negative indices silently wrapped around the mask.
        row = int(np.round(point[1]))
        col = int(np.round(point[0]))
        return 0 <= row < n_rows and 0 <= col < n_cols and mask[row][col] != 0

    # Forward walk, beyond point2.
    candidate = _find_previous_or_next_points(point1, point2, t='n')
    pn = candidate
    for _ in range(max_steps):
        if not _is_foreground(candidate):
            break
        pn = candidate
        candidate = _find_previous_or_next_points(point1, candidate, t='n')

    # Backward walk, before point1.
    candidate = _find_previous_or_next_points(point1, point2, t='p')
    pp = point1
    for _ in range(max_steps):
        if not _is_foreground(candidate):
            break
        pp = candidate
        candidate = _find_previous_or_next_points(candidate, point2, t='p')

    return pn, pp


def _GetPointOfSegmentMask(mask=None):
    """Derive 21 landmark segments (main axis plus 20 chords) from a mask.

    The main axis comes from ``_getHorizontalLabel``. Twenty perpendiculars
    are placed along it at equal spacing, starting at its first point, and
    each one is extended to the mask border with ``_getPointsInMask``. The
    segments are finally arranged like the dataset annotations: the two end
    points of the main axis are exchanged and the 20 chords are stored in
    reverse order.

    Args:
        mask: Segmentation mask passed on to the helper functions.

    Returns:
        LandMarks object with 21 entries per coordinate list.
    """
    long_axis = _getHorizontalLabel(mask)
    x1, y1 = long_axis[0]
    x2, y2 = long_axis[1]

    numOfLines = 20
    spacing = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) / numOfLines

    # Seed points of every perpendicular first, then the border search.
    seeds = [_perpendicular_points(x1, y1, x2, y2, spacing * k) for k in range(numOfLines)]
    chords = [_getPointsInMask(seed_a, seed_b, mask) for seed_a, seed_b in seeds]

    segments = [long_axis] + chords
    starts = [segment[0] for segment in segments]
    ends = [segment[1] for segment in segments]

    # Index 0 holds the main axis with its end points exchanged; indices
    # 1..20 hold the chords from the last one to the first one.
    reversed_starts = starts[:0:-1]
    reversed_ends = ends[:0:-1]

    return LandMarks(
        [ends[0][0]] + [point[0] for point in reversed_starts],
        [ends[0][1]] + [point[1] for point in reversed_starts],
        [starts[0][0]] + [point[0] for point in reversed_ends],
        [starts[0][1]] + [point[1] for point in reversed_ends],
    )


# Calc Volume and EF
def _calculate_volume(landmarks):
    """Estimate a volume by stacking 20 discs along the main axis.

    Segment 0 of the landmarks is the main axis; its length divided by 20
    is the height of every disc. Segments 1..20 provide the disc diameters,
    and each disc contributes ``pi * diameter**2 * height / 4``.

    Args:
        landmarks: Object with indexable ``X1``, ``Y1``, ``X2``, ``Y2`` holding
            at least 21 values each.

    Returns:
        Sum of the 20 disc volumes.
    """
    xs_a, ys_a = landmarks.X1, landmarks.Y1
    xs_b, ys_b = landmarks.X2, landmarks.Y2

    axis_length = math.sqrt((xs_b[0] - xs_a[0]) ** 2 + (ys_b[0] - ys_a[0]) ** 2)
    disc_height = axis_length / 20

    total = 0
    for k in range(1, 21):
        diameter_squared = (xs_b[k] - xs_a[k]) ** 2 + (ys_b[k] - ys_a[k]) ** 2
        total += (math.pi * diameter_squared * disc_height) / 4.0

    return total


def calculate_EF(ED_volume, ES_volume):
    """Return the ejection fraction in percent.

    The absolute difference between the magnitudes of the two volumes is
    divided by ``ED_volume`` and scaled by 100.

    Args:
        ED_volume: End-diastolic volume (the divisor; must not be zero).
        ES_volume: End-systolic volume.

    Returns:
        Ejection fraction as a percentage.
    """
    stroke_volume = abs(abs(ED_volume) - abs(ES_volume))
    fraction = stroke_volume / ED_volume
    return fraction * 100


def get_LV_volume(mask):
    """Compute the volume described by a segmentation mask.

    Args:
        mask: Mask handed to ``_GetPointOfSegmentMask``.

    Returns:
        Tuple ``(volume, landmarks)``: the value from ``_calculate_volume``
        and the landmarks it was computed from.
    """
    predicted_landmarks = _GetPointOfSegmentMask(mask)
    return _calculate_volume(predicted_landmarks), predicted_landmarks

# DataSet Model

In [None]:
class VideoDataSetForModel(torch.utils.data.Dataset):
    """Torch dataset yielding padded frame stacks with per-frame phase labels.

    Every item is a pair ``(frames, labels)``. ``frames`` is a float32 array
    laid out as (frame, channel, dim1, dim2), scaled by 1/255 and zero-padded
    by 8 pixels on each side of the last two axes. ``labels`` is an int8
    vector with 0 for transition frames, 1 for ES and 2 for ED.
    """

    def __init__(self, dataSet=None, fullVideo=False):
        """Store the video objects and the loading mode.

        Args:
            dataSet: Indexable collection of video objects.
            fullVideo: When True whole videos are loaded; otherwise a
                128-frame mirrored clip is built for each video.
        """
        self.dataSet = dataSet
        self.fullVideo = fullVideo
        self.frame_width = self.frame_height = 128

    def __len__(self):
        """Return the number of videos."""
        return len(self.dataSet)

    def __getitem__(self, index):
        """Load the frames and labels of the video at ``index``."""
        obj = self.dataSet[index]

        if self.fullVideo:
            # Whole video: only the two annotated frames carry a label.
            frame_count = obj.numberOfFrames
            frames = _extractVideoFrames(os.path.join(_videosPath, obj.fileName + '.avi'))
            labels = np.zeros(frame_count, np.int8)
            labels[obj.ES_frame] = 1
            labels[obj.ED_frame] = 2
        else:
            # Fixed-length clip of (frame, tag) pairs.
            frame_count = 128
            clip = _mirroringVideo(obj)
            frames = np.zeros((frame_count, 112, 112, 3), np.float32)
            labels = np.zeros(frame_count, np.int8)
            # 0 TR , 1 ES, 2 ED
            label_codes = {'ES': 1, 'ED': 2}
            for position in range(frame_count):
                tag = clip[position][1]
                frames[position] = clip[position][0]
                labels[position] = label_codes.get(tag, 0)

        # Channels-last to (frame, channel, dim1, dim2) in a single step,
        # then scale pixel values from 0-255 to 0-1.
        frames = frames.transpose((0, 3, 1, 2)).astype(np.float32)
        frames /= 255.0

        # Zero-pad the two image axes by 8 pixels on every side.
        border = 8
        frames = np.pad(frames, ((0, 0), (0, 0), (border, border), (border, border)),
                        mode='constant', constant_values=0)

        return frames, labels

# Transformer

ResNet Model

In [None]:
class ResidualBlock(torch.nn.Module):
    """Two Conv-BatchNorm-ReLU stages wrapped by an identity skip connection.

    The block returns ``x + residual_block(x)``, so the convolution stack
    has to preserve the shape of its input (same channel count, and a
    stride/padding combination that keeps the spatial size, as with the
    defaults).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1)):
        super(ResidualBlock, self).__init__()

        self.residual_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                            padding=1),
            # Normalises the output of the first convolution, so it must be
            # sized with out_channels (it was sized with in_channels).
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                            padding=1),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the input plus the output of the convolution stack."""
        return x + self.residual_block(x)


class ResNetEncoder(torch.nn.Module):
    """Convolutional encoder built from residual stacks and stride-2 convolutions.

    An input convolution lifts the image to 8 channels. Each of the
    ``n_levels`` levels then applies a stack of ``n_ResidualBlock`` residual
    blocks followed by a 2x2, stride-2 convolution that doubles the channel
    count. An output convolution finally maps ``max_filters`` channels to
    ``z_dim`` channels.

    Args:
        n_ResidualBlock: Number of residual blocks per level.
        n_levels: Number of levels (each one applies a stride-2 convolution).
        input_ch: Channels of the input image.
        z_dim: Channels of the encoder output.
        bUseMultiResSkips: When True, every level also feeds a skip branch
            whose output is added to the final feature map.
    """

    def __init__(self,
                 n_ResidualBlock=8,
                 n_levels=4,
                 input_ch=3,
                 z_dim=10,
                 bUseMultiResSkips=True):

        super(ResNetEncoder, self).__init__()

        # Channel count after the last level: 8 * 2 ** n_levels.
        self.max_filters = 2 ** (n_levels + 3)
        self.n_levels = n_levels
        self.bUseMultiResSkips = bUseMultiResSkips

        self.conv_list = torch.nn.ModuleList()
        self.res_blk_list = torch.nn.ModuleList()
        self.multi_res_skip_list = torch.nn.ModuleList()

        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=input_ch, out_channels=8,
                            kernel_size=(3, 3), stride=(1, 1), padding=1),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(inplace=True),
        )

        for i in range(n_levels):
            n_filters_1 = 2 ** (i + 3)  # channels entering level i
            n_filters_2 = 2 ** (i + 4)  # channels leaving level i
            # Kernel/stride of the skip branch: 2 ** (number of stride-2
            # convolutions from level i onwards).
            ks = 2 ** (n_levels - i)

            self.res_blk_list.append(
                torch.nn.Sequential(*[ResidualBlock(n_filters_1, n_filters_1)
                                      for _ in range(n_ResidualBlock)])
            )

            # Stride-2 convolution of the level; doubles the channels.
            self.conv_list.append(
                torch.nn.Sequential(
                    torch.nn.Conv2d(n_filters_1, n_filters_2,
                                    kernel_size=(2, 2), stride=(2, 2), padding=0),
                    torch.nn.BatchNorm2d(n_filters_2),
                    torch.nn.ReLU(inplace=True),
                )
            )

            if bUseMultiResSkips:
                # Maps the features of level i to max_filters channels.
                self.multi_res_skip_list.append(
                    torch.nn.Sequential(
                        torch.nn.Conv2d(in_channels=n_filters_1, out_channels=self.max_filters, kernel_size=(ks, ks),
                                        stride=(ks, ks), padding=0),
                        torch.nn.BatchNorm2d(self.max_filters),
                        torch.nn.ReLU(inplace=True),
                    )
                )

        self.output_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=self.max_filters, out_channels=z_dim,
                            kernel_size=(3, 3), stride=(1, 1), padding=1),
            torch.nn.BatchNorm2d(z_dim),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode ``x`` and return the feature map produced by the output convolution."""

        x = self.input_conv(x)

        skips = []
        for i in range(self.n_levels):
            x = self.res_blk_list[i](x)
            if self.bUseMultiResSkips:
                # Skip branch taken after the residual stack, before the
                # stride-2 convolution of the level.
                skips.append(self.multi_res_skip_list[i](x))
            x = self.conv_list[i](x)

        if self.bUseMultiResSkips:
            # Add every skip output to the final feature map.
            x = sum([x] + skips)

        x = self.output_conv(x)

        return x


class ResNetDecoder(torch.nn.Module):
    """Convolutional decoder built from transposed convolutions and residual stacks.

    An input convolution lifts the latent map from ``z_dim`` to
    ``max_filters`` channels. Each of the ``n_levels`` levels then applies a
    2x2, stride-2 transposed convolution that halves the channel count,
    followed by a stack of ``n_ResidualBlock`` residual blocks. An output
    convolution maps the remaining channels to ``output_channels``; no
    activation is applied after it.

    Args:
        n_ResidualBlock: Number of residual blocks per level.
        n_levels: Number of levels (each one applies a stride-2 transposed convolution).
        z_dim: Channels of the latent input.
        output_channels: Channels of the decoder output.
        bUseMultiResSkips: When True, every level also receives a skip
            branch computed from the output of the input convolution.
    """

    def __init__(self,
                 n_ResidualBlock=8,
                 n_levels=4,
                 z_dim=10,
                 output_channels=3,
                 bUseMultiResSkips=True):

        super(ResNetDecoder, self).__init__()

        # Channel count at the top of the decoder: 8 * 2 ** n_levels.
        self.max_filters = 2 ** (n_levels + 3)
        self.n_levels = n_levels
        self.bUseMultiResSkips = bUseMultiResSkips

        self.conv_list = torch.nn.ModuleList()
        self.res_blk_list = torch.nn.ModuleList()
        self.multi_res_skip_list = torch.nn.ModuleList()

        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=z_dim, out_channels=self.max_filters,
                            kernel_size=(3, 3), stride=(1, 1), padding=1),
            torch.nn.BatchNorm2d(self.max_filters),
            torch.nn.ReLU(inplace=True),
        )

        # Defined before the loop so output_conv below still has a value to
        # use when n_levels is 0 and the loop body never runs.
        n_filters_1 = 2 ** (self.n_levels - 0 + 2)
        for i in range(n_levels):
            n_filters_0 = 2 ** (self.n_levels - i + 3)  # channels entering level i
            n_filters_1 = 2 ** (self.n_levels - i + 2)  # channels leaving level i
            # Kernel/stride of the skip branch: 2 ** (number of stride-2
            # transposed convolutions applied up to and including level i).
            ks = 2 ** (i + 1)

            self.res_blk_list.append(
                torch.nn.Sequential(*[ResidualBlock(n_filters_1, n_filters_1)
                                      for _ in range(n_ResidualBlock)])
            )

            # Stride-2 transposed convolution of the level; halves the channels.
            self.conv_list.append(
                torch.nn.Sequential(
                    torch.nn.ConvTranspose2d(n_filters_0, n_filters_1,
                                             kernel_size=(2, 2), stride=(2, 2), padding=0),
                    torch.nn.BatchNorm2d(n_filters_1),
                    torch.nn.ReLU(inplace=True),
                )
            )

            if bUseMultiResSkips:
                # Maps the top feature map to the channel count of level i.
                self.multi_res_skip_list.append(
                    torch.nn.Sequential(
                        torch.nn.ConvTranspose2d(in_channels=self.max_filters, out_channels=n_filters_1,
                                                 kernel_size=(ks, ks), stride=(ks, ks), padding=0),
                        torch.nn.BatchNorm2d(n_filters_1),
                        torch.nn.ReLU(inplace=True),
                    )
                )

        # n_filters_1 holds the channel count of the last level here.
        self.output_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=n_filters_1, out_channels=output_channels,
                            kernel_size=(3, 3), stride=(1, 1), padding=1),
            # torch.nn.BatchNorm2d(output_channels),
            # torch.nn.ReLU(inplace=True),
        )

    def forward(self, z):
        """Decode the latent map ``z`` and return the output of the final convolution."""

        # z_top keeps the output of the input convolution for the skip branches.
        z = z_top = self.input_conv(z)

        for i in range(self.n_levels):
            z = self.conv_list[i](z)
            z = self.res_blk_list[i](z)
            if self.bUseMultiResSkips:
                # Add the skip branch computed from the top feature map.
                z += self.multi_res_skip_list[i](z_top)

        z = self.output_conv(z)

        return z


class ResNetAE(torch.nn.Module):
    """Autoencoder pairing ResNetEncoder/ResNetDecoder with a linear bottleneck.

    The encoder output is flattened and projected to ``bottleneck_dim``
    values (squashed with tanh); decoding projects the code back, reshapes
    it to a feature map and applies a sigmoid to the decoder output.

    Args:
        input_shape: ``(height, width, channels)``; height must equal width.
        n_ResidualBlock: Residual blocks per level in encoder and decoder.
        n_levels: Number of levels in encoder and decoder.
        z_dim: Channels of the encoder output feature map.
        bottleneck_dim: Size of the flat latent code.
        bUseMultiResSkips: Forwarded to encoder and decoder.
    """

    def __init__(self,
                 input_shape=(256, 256, 3),
                 n_ResidualBlock=8,
                 n_levels=4,
                 z_dim=128,
                 bottleneck_dim=128,
                 bUseMultiResSkips=True):
        super().__init__()

        assert input_shape[0] == input_shape[1]
        image_channels = input_shape[2]
        self.z_dim = z_dim
        # Side length of the encoder output feature map.
        self.img_latent_dim = input_shape[0] // (2 ** n_levels)

        shared_settings = dict(n_ResidualBlock=n_ResidualBlock, n_levels=n_levels,
                               z_dim=z_dim, bUseMultiResSkips=bUseMultiResSkips)
        self.encoder = ResNetEncoder(input_ch=image_channels, **shared_settings)
        self.decoder = ResNetDecoder(output_channels=image_channels, **shared_settings)

        flat_features = self.z_dim * self.img_latent_dim * self.img_latent_dim
        self.fc1 = torch.nn.Linear(flat_features, bottleneck_dim)
        self.fc2 = torch.nn.Linear(bottleneck_dim, flat_features)

    def encode(self, x):
        """Return the tanh-squashed latent code of a batch of images."""
        flat_features = self.z_dim * self.img_latent_dim * self.img_latent_dim
        feature_map = self.encoder(x)
        flattened = feature_map.view(x.shape[0], flat_features)
        return torch.tanh(self.fc1(flattened))

    def decode(self, z):
        """Reconstruct images (sigmoid output) from latent codes."""
        feature_map = self.fc2(z).view(-1, self.z_dim, self.img_latent_dim, self.img_latent_dim)
        return torch.sigmoid(self.decoder(feature_map))

    def forward(self, x):
        """Encode and immediately decode ``x``."""
        code = self.encode(x)
        return self.decode(code)


def reParameterize(mu, log):
    """Draw a sample from a Gaussian given its mean and log-variance.

    Args:
        mu: Tensor of means.
        log: Tensor of log-variances; the standard deviation is ``exp(0.5 * log)``.

    Returns:
        ``mu + eps * std`` where ``eps`` is standard normal noise shaped like ``std``.
    """
    sigma = (0.5 * log).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


class ResNetVAE(torch.nn.Module):
    """Variational autoencoder built on ResNetEncoder and ResNetDecoder.

    The flattened encoder output is projected by two linear layers to a mean
    and a log-variance of size ``bottleneck_dim``. A sample drawn with
    ``reParameterize`` is projected back, reshaped and decoded; the decoder
    output goes through a sigmoid.

    Args:
        input_shape: ``(height, width, channels)``; height must equal width.
        n_ResidualBlock: Residual blocks per level in encoder and decoder.
        n_levels: Number of levels in encoder and decoder.
        z_dim: Channels of the encoder output feature map.
        bottleneck_dim: Size of the latent mean and log-variance vectors.
        bUseMultiResSkips: Forwarded to encoder and decoder.
    """

    def __init__(self,
                 input_shape=(256, 256, 3),
                 n_ResidualBlock=8,
                 n_levels=4,
                 z_dim=128,
                 bottleneck_dim=128,
                 bUseMultiResSkips=True):
        super().__init__()

        assert input_shape[0] == input_shape[1]
        image_channels = input_shape[2]
        self.z_dim = z_dim
        # Side length of the encoder output, derived from input_shape.
        self.img_latent_dim = input_shape[0] // (2 ** n_levels)

        shared_settings = dict(n_ResidualBlock=n_ResidualBlock, n_levels=n_levels,
                               z_dim=z_dim, bUseMultiResSkips=bUseMultiResSkips)
        self.encoder = ResNetEncoder(input_ch=image_channels, **shared_settings)
        self.decoder = ResNetDecoder(output_channels=image_channels, **shared_settings)

        flat_features = self.z_dim * self.img_latent_dim * self.img_latent_dim
        self.fc21 = torch.nn.Linear(flat_features, bottleneck_dim)  # mean head
        self.fc22 = torch.nn.Linear(flat_features, bottleneck_dim)  # log-variance head
        self.fc3 = torch.nn.Linear(bottleneck_dim, flat_features)

    def encode(self, x):
        """Return the pair ``(mu, log)`` computed from a batch of images."""
        flat_features = self.z_dim * self.img_latent_dim * self.img_latent_dim
        flattened = self.encoder(x).view(-1, flat_features)
        return self.fc21(flattened), self.fc22(flattened)

    def decode(self, z):
        """Reconstruct images (sigmoid output) from latent samples."""
        feature_map = self.fc3(z).view(-1, self.z_dim, self.img_latent_dim, self.img_latent_dim)
        return torch.sigmoid(self.decoder(feature_map))

    def forward(self, x):
        """Return ``(reconstruction, mu, log)`` for a batch of images."""
        mu, log = self.encode(x)
        sample = reParameterize(mu, log)
        return self.decode(sample), mu, log

Transformer Model

In [None]:
class TransformerModel(nn.Module):
    """
    Per-frame 3-class scorer for a video clip

    Pipeline: every frame is embedded by the encoder half of a ResNetAE,
    the sequence of embeddings goes through the encoder stack of a BertModel,
    the returned hidden states are averaged, and a small MLP ending in Tanh
    turns each frame embedding into 3 scores.

    Arguments:
        embedding_dim -- Size of one frame embedding (ResNetAE bottleneck and Bert hidden size)
        num_hidden_layers -- Number of Bert encoder layers
        attention_heads -- Number of attention heads per Bert layer
        intermediate_size -- Width of the Bert feed-forward layers
        input_shape -- Shape handed to ResNetAE
    """

    def __init__(self, embedding_dim=1024, num_hidden_layers=16, attention_heads=16, intermediate_size=8192,
                 input_shape=(128, 128, 3)):
        super().__init__()

        # Frame embedder: only the encoding path is used, so the decoder and fc2 are dropped
        frame_encoder = ResNetAE(input_shape=input_shape, n_ResidualBlock=8, n_levels=4,
                                 bottleneck_dim=embedding_dim)
        frame_encoder.decoder = None
        frame_encoder.fc2 = None
        self.model_AE = frame_encoder

        # Sequence model: only the encoder stack of BertModel is kept below,
        # so the inputs are embeddings and no token vocabulary is used here
        configuration = BertConfig(
            vocab_size=1,
            hidden_size=embedding_dim,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=attention_heads,
            intermediate_size=intermediate_size,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            gradient_checkpointing=False,
            position_embedding_type='absolute',
            use_cache=True)
        configuration.num_labels = 3

        self.model_Bert = BertModel(configuration).encoder

        self.embedding_dim = embedding_dim

        # Classification head: Emb -> Emb/2 -> Emb/4 -> 3, squashed by Tanh
        half_dim = embedding_dim // 2
        quarter_dim = embedding_dim // 4
        last_features = 3
        self.extremas = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=half_dim, bias=True),
            nn.LayerNorm(half_dim),
            nn.LeakyReLU(negative_slope=0.05, inplace=True),

            nn.Linear(in_features=half_dim, out_features=quarter_dim, bias=True),
            nn.LayerNorm(quarter_dim),
            nn.LeakyReLU(negative_slope=0.05, inplace=True),

            nn.Linear(in_features=quarter_dim, out_features=last_features, bias=True),
            nn.Tanh()
        )

    def forward(self, frames, nF):
        """
        Score every frame of the clips

        Arguments:
            frames -- Frames with batch and frame axes merged, (B*F) x C x H x W
            nF -- Number of frames per clip, used to restore the B x F layout
        Returns:
            Tensor of 3 scores per frame, B x F x 3
        """
        # (B*F) x C x H x W => (B*F) x Emb (squeeze drops singleton axes of the encoder output)
        embeddings = self.model_AE.encode(frames).squeeze()
        # (B*F) x Emb => B x F x Emb
        sequences = embeddings.view(-1, nF, self.embedding_dim)
        encoded = self.model_Bert(sequences, output_hidden_states=True)
        # L x B x F x Emb => B x F x Emb, averaging over the L returned hidden states
        # (presumably L = num_hidden_layers + 1 -- confirm against the transformers docs)
        pooled = torch.stack(encoded.hidden_states).mean(dim=0)
        # B x F x Emb => B x F x 3
        return self.extremas(pooled)

Transformer Function

In [None]:
# Train
def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group

    Arguments:
        optimizer -- Object exposing a param_groups sequence of dicts
    Returns:
        The 'lr' entry of the first group, or None when there are no groups
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']


def run_epoch(model, dataloader, optim, device):
    """
    Train the model for one pass over the dataloader

    Arguments:
        model -- Module called as model(frames, nF) that returns 3 scores per frame
        dataloader -- Yields (frames, label) with frames shaped B x F x C x H x W
        optim -- Optimizer stepped once per batch
        device -- Device the batches and the class weights are moved to
    Returns:
        mean_loss, loss_hist -- Mean loss per batch (0.0 when the dataloader is empty)
                                and a NumPy array with the loss of every batch
    """
    total = 0.
    n = 0
    loss_hist = []

    model.train(True)
    print("Learning rate:", get_lr(optim))

    # Classes 1 and 2 are weighted 5x relative to class 0
    weighting = torch.tensor([1., 5., 5.]).to(device)
    loss_fct1 = nn.CrossEntropyLoss(weight=weighting, reduction='mean')

    with tqdm.tqdm(total=len(dataloader)) as pbar:
        for (frames, label) in dataloader:
            nB, nF, nC, nH, nW = frames.shape

            # Merge batch and frames dimension: B x F x C x H x W => (B*F) x C x H x W
            frames = frames.view(nB * nF, nC, nH, nW)
            frames = frames.to(device, dtype=torch.float32)

            class_vec = model(frames, nF)

            label = label.to(device, dtype=torch.long)

            # One 3-way classification per frame
            loss = loss_fct1(class_vec.view(-1, 3), label.view(-1))

            # Take gradient step
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Accumulate losses
            batch_loss = loss.item()
            total += batch_loss
            n += 1
            loss_hist.append(batch_loss)

            # Running average over the last (up to) 10 batches
            avg = np.mean(loss_hist[-10:])

            # Progress bar: epoch mean / current batch / recent average
            pbar.set_postfix_str("{:.4f} / {:.4f} / {:.4f}".format(total / n, batch_loss, avg))
            pbar.update()
    loss_hist = np.array(loss_hist)

    # Report the per-batch mean (the same quantity shown in the progress bar),
    # not the raw sum; max() keeps an empty dataloader from dividing by zero
    return (total / max(n, 1)), loss_hist


def trainTransformer(train_dataSet, num_epochs, batch_size, parallel):
    """
    Train a TransformerModel and write a checkpoint after every epoch

    Arguments:
        train_dataSet -- Data handed to VideoDataSetForModel(dataSet=..., fullVideo=False)
        num_epochs -- Number of epochs to run
        batch_size -- DataLoader batch size (incomplete last batches are dropped)
        parallel -- Wrap the model in nn.DataParallel when True
    Side effects:
        Writes "best.pt" whenever the epoch loss improves, otherwise
        "checkpoint_<epoch>.pt", in the current working directory.
    """
    np.random.seed(0)
    torch.manual_seed(0)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Model
    model = TransformerModel(embedding_dim=1024, num_hidden_layers=16, attention_heads=16, intermediate_size=8192,
                             input_shape=(128, 128, 3))

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print('model name', model.__class__.__name__, "contains", pytorch_total_params, "parameters.")

    if parallel:
        # NOTE: the saved state_dict keys then carry a "module." prefix
        model = nn.DataParallel(model)

    model.to(device)

    # DATA SETUP
    train_dataSet = VideoDataSetForModel(dataSet=train_dataSet, fullVideo=False)

    train_dataloader = torch.utils.data.DataLoader(train_dataSet, batch_size=batch_size, num_workers=0,
                                                   shuffle=True,
                                                   pin_memory=(device.type == "cuda"),
                                                   drop_last=True)

    # Set up optimizer
    lr = 1e-5
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    # NOTE(review): StepLR with a period of 1 decays the learning rate after every
    # epoch using its default gamma -- confirm this schedule is intended.
    lr_step_period = 1
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    bestLoss = float("inf")

    for epoch in range(1, num_epochs + 1):
        print("Epoch {} / {}".format(epoch, num_epochs), flush=True)
        print("Running on", 'train')
        loss, _ = run_epoch(model, train_dataloader, optim, device)
        print('Loss =', loss)
        print()
        scheduler.step()

        # Update the running best first so the checkpoint records the up-to-date value
        is_best = loss < bestLoss
        if is_best:
            bestLoss = loss

        # Save checkpoint
        save = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'best_loss': bestLoss,
            'loss': loss,
            'opt_dict': optim.state_dict(),
            'scheduler_dict': scheduler.state_dict(),
        }
        if is_best:
            print('new Best Version')
            torch.save(save, "best.pt")
        else:
            torch.save(save, "checkpoint_" + str(epoch) + ".pt")


# Test
def show_graph(label, predict):
    """
    Plot the predictions against the labels

    Arguments:
        label -- Label values, plotted as they are
        predict -- Predicted values; a copy shifted up by 0.5 is plotted so
                   that the two curves do not overlap
    """
    shifted_predict = predict.copy() + 0.5
    for series, legend_name in ((shifted_predict, 'predict'), (label, 'Label')):
        plt.plot(series, label=legend_name)
    plt.legend()
    plt.show()


def smooth(vec, window=5, rep=1):
    """
    Smooth a 1-D tensor with a zero-padded moving average

    Arguments:
        vec -- 1-D floating point tensor
        window -- Width of the averaging window; the length is preserved for odd widths
        rep -- Number of times the filter is applied
    Returns:
        1-D tensor holding the smoothed signal (vec itself when rep is 0)
    """
    # Symmetric zero padding: values near both ends are pulled towards 0
    pad = int((window - 1) / 2)
    # Build the kernel with the input's dtype and device so that float64 or
    # non-CPU inputs work as well
    weight = torch.ones((1, 1, window), dtype=vec.dtype, device=vec.device) / window
    for _ in range(rep):
        # conv1d expects (batch, channels, length)
        vec = torch.nn.functional.conv1d(vec.reshape(1, 1, -1), weight, padding=pad)
        # Back to 1-D; unlike squeeze() this keeps a length-1 result 1-D
        vec = vec.reshape(-1)
    return vec


def loadTransformerModel(path):
    """
    Load a trained TransformerModel checkpoint on the CPU

    Arguments:
        path -- Checkpoint file holding a dict with a 'state_dict' entry
    Returns:
        model -- The network wrapped in torch.nn.DataParallel, in eval mode
    """
    best = torch.load(path, map_location="cpu")
    state_dict = best['state_dict']

    network = TransformerModel(embedding_dim=1024, num_hidden_layers=16, attention_heads=16, intermediate_size=8192,
                               input_shape=(128, 128, 3))
    model = torch.nn.DataParallel(network)

    # A checkpoint saved from a DataParallel model prefixes every key with
    # "module."; one saved from the bare model does not. Load into whichever
    # object matches -- both share the same parameters.
    saved_from_parallel = all(key.startswith('module.') for key in state_dict)
    target = model if saved_from_parallel else network
    target.load_state_dict(state_dict)

    model.eval()
    return model


def GetLengthOfEachBet(predict, deleteLastHalfBet=False):
    """
    Pair up the non-zero entries of predict and measure each pair

    Non-zero entries are taken alternately as the first and the last frame
    of a beat, in order of appearance.

    Arguments:
        predict -- Mutable sequence of per-frame values, 0 meaning "no event"
        deleteLastHalfBet -- When True, a trailing first frame without a matching
                             last frame is dropped AND zeroed in predict (in place)
    Returns:
        lengthBet -- last - first for every complete pair
        firstFrames -- Index of the first frame of each beat; when
                       deleteLastHalfBet is False it may hold one extra,
                       unpaired index at the end
        lastFrames -- Index of the last frame of each complete beat
    """
    firstFrames = []
    lastFrames = []
    expecting_first = True

    for idx, value in enumerate(predict):
        if value != 0:
            (firstFrames if expecting_first else lastFrames).append(idx)
            expecting_first = not expecting_first

    # Frames alternate first/last, so only firstFrames can be one entry longer
    if deleteLastHalfBet and len(firstFrames) > len(lastFrames):
        predict[firstFrames.pop()] = 0

    # zip stops at the shorter list, so an unpaired first frame is ignored
    # instead of raising IndexError
    lengthBet = [last - first for first, last in zip(firstFrames, lastFrames)]

    return lengthBet, firstFrames, lastFrames


def testForOneVideo(model, frames, device):
    """
    Predict the key frames of one video

    Arguments:
        model -- Module called as model(frames, nF) that returns 3 scores per frame
        frames -- Tensor shaped B x F x C x H x W; the post-processing assumes B == 1
        device -- Device the frames are moved to before the forward pass
    Returns:
        predict -- int8 NumPy array with one entry per frame: 1 at the kept local
                   minima of (score[2] - score[1]), 2 at the kept local maxima,
                   0 everywhere else
    """
    nB, nF, nC, nH, nW = frames.shape
    # Merge batch and frames dimension
    frames = frames.reshape(nB * nF, nC, nH, nW)
    frames = frames.to(device, dtype=torch.float)

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        class_vec = model(frames, nF).squeeze()

    # Move to the CPU before smoothing / converting to NumPy
    class_diff = (class_vec[:, 2] - class_vec[:, 1]).detach().cpu()

    smooth_vec = smooth(class_diff, window=5, rep=3).detach().numpy()

    # Get Peaks: strict local extrema; the first and last frame are never marked
    predict = np.zeros((len(smooth_vec)), np.int8)
    center = smooth_vec[1:-1]
    before = smooth_vec[:-2]
    after = smooth_vec[2:]
    inner = predict[1:-1]
    inner[(center < after) & (center < before)] = 1
    inner[(center > after) & (center > before)] = 2

    # Get length of each bet (this also zeroes a trailing unpaired peak in predict)
    lengthBet, firstFrames, lastFrames = GetLengthOfEachBet(predict, True)

    # Apply Threshold: drop beats shorter than 35% of the longest one.
    # Without any complete beat there is nothing to threshold.
    if lengthBet:
        thr = max(lengthBet) * 0.35
        for length, first, last in zip(lengthBet, firstFrames, lastFrames):
            if thr > length:
                predict[first] = 0
                predict[last] = 0

    return predict


def testTransformer(transformer_path, dataSet):
    """
    Evaluate a saved transformer on a data set and print frame accuracies

    Arguments:
        transformer_path -- Checkpoint handed to loadTransformerModel
        dataSet -- Data handed to VideoDataSetForModel(dataSet=..., fullVideo=True)

    The ES & ED, ED and ES percentages are divided by the number of videos
    (twice that for ES & ED), i.e. they assume exactly one label 1 and one
    label 2 per video.
    """
    device = torch.device('cpu')
    model = loadTransformerModel(transformer_path)

    video_set = VideoDataSetForModel(dataSet=dataSet, fullVideo=True)
    dataloader = torch.utils.data.DataLoader(video_set, batch_size=1, shuffle=False)
    n_videos = len(dataloader)

    totalFrames = 0
    hits_all = 0          # prediction equals label
    hits_non_zero = 0     # ... and the label is not 0
    hits_zero = 0         # ... and the label is 0 (reported as "Transition")
    hits_one = 0          # ... and the label is 1 (reported as "ED")
    hits_two = 0          # ... and the label is 2 (reported as "ES")
    with tqdm.tqdm(total=n_videos) as pbar:
        for frames, label in dataloader:
            predict = testForOneVideo(model, frames, device)
            label = label.squeeze().detach().numpy()

            totalFrames += len(label)
            for i in range(len(label)):
                truth = label[i]
                if predict[i] != truth:
                    continue

                hits_all += 1
                if truth == 0:
                    hits_zero += 1
                elif truth == 1:
                    hits_one += 1
                elif truth == 2:
                    hits_two += 1

                if truth != 0:
                    hits_non_zero += 1

            pbar.update()

    accuracy = (hits_all / totalFrames) * 100
    print('Accuracy: ', accuracy)

    accuracy2 = (hits_non_zero / (2 * n_videos)) * 100
    print('Accuracy ES & ED: ', accuracy2)

    accuracyED = (hits_one / n_videos) * 100
    print('Accuracy ED: ', accuracyED)

    accuracyES = (hits_two / n_videos) * 100
    print('Accuracy ES: ', accuracyES)

    accuracyTransition = (hits_zero / (totalFrames - n_videos * 2)) * 100
    print('Accuracy Transition: ', accuracyTransition)


# Detect ES & ED Frame
def Detect_ESED_Frame(video_path, transformerModel, labels=None):
    """
    Pick the ES and ED frame of a video from the transformer predictions

    The longest detected beat is used; its two bounding frames are returned.

    Arguments:
        video_path -- Path handed to _extractVideoFrames
        transformerModel -- Model handed to testForOneVideo
        labels -- Optional per-frame labels; when given, the predicted frames are
                  plotted next to the frames labelled 1 ("ES True") and 2 ("ED True")
    Returns:
        ES_Frame_IMG, ED_Frame_IMG -- NumPy images (H x W x C, values scaled by 1/255)
                                      with the padding removed again
    Raises:
        ValueError -- When no complete beat is detected in the video
    """
    device = torch.device('cpu')
    # Border added on every side of each frame before inference and cropped again below
    p = 8

    # Prepare Video
    # assumes _extractVideoFrames returns a channels-last array (F, ., ., C) -- TODO confirm
    frames = _extractVideoFrames(video_path)
    # Channels-last => channels-first: (F, ., ., C) => (F, C, ., .)
    frames = frames.transpose((0, 3, 1, 2))
    # Scale pixel values from 0-255 to 0-1
    frames = frames.astype(np.float32) / 255.0

    frames = np.pad(frames, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant', constant_values=0)

    # Add the batch axis expected by testForOneVideo
    frames = torch.from_numpy(frames).unsqueeze(0)

    predict = testForOneVideo(transformerModel, frames, device)

    lengthBet, firstFrames, lastFrames = GetLengthOfEachBet(predict, False)
    if not lengthBet:
        raise ValueError('No complete beat detected in ' + str(video_path))

    # Longest beat; on ties the latest one wins
    maxIDX = max(range(len(lengthBet)), key=lambda i: (lengthBet[i], i))

    video = frames.squeeze(0)
    first_img = video[firstFrames[maxIDX]].permute(1, 2, 0)
    last_img = video[lastFrames[maxIDX]].permute(1, 2, 0)
    # NOTE(review): class 1 is treated as ES here, while testTransformer counts
    # class 1 as ED -- confirm which mapping matches the training labels.
    if predict[firstFrames[maxIDX]] == 1:
        ES_Frame_IMG, ED_Frame_IMG = first_img, last_img
    else:
        ED_Frame_IMG, ES_Frame_IMG = first_img, last_img

    # Show 4 Frames
    if labels is not None:
        print(firstFrames[maxIDX], lastFrames[maxIDX])
        TrueES_Frame = None
        TrueED_Frame = None
        for i, frame_label in enumerate(labels):
            if frame_label == 1:
                TrueES_Frame = video[i].permute(1, 2, 0)
            elif frame_label == 2:
                TrueED_Frame = video[i].permute(1, 2, 0)

        # Draw a single 2x2 figure once the labelled frames are known
        panels = (('ES Pred', ES_Frame_IMG), ('ES True', TrueES_Frame),
                  ('ED Pred', ED_Frame_IMG), ('ED True', TrueED_Frame))
        fig, axes = plt.subplots(2, 2)
        for ax, (title, image) in zip(axes.flat, panels):
            # A panel stays empty when the labels contain no frame of that kind
            if image is not None:
                ax.imshow(image.numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    # Crop the padding added above
    ES_Frame_IMG = ES_Frame_IMG.numpy()[p:-p, p:-p, :]
    ED_Frame_IMG = ED_Frame_IMG.numpy()[p:-p, p:-p, :]

    return ES_Frame_IMG, ED_Frame_IMG

# Unet

Unet Model

In [None]:
# UNQ_C1
# GRADED FUNCTION: conv_block
def conv_block(inputs=None, n_filters=32, dropout_prob=0, max_pooling=True):
    """
    Convolutional downsampling block

    Two 3x3 ReLU convolutions, an optional Dropout and an optional 2x2 max pooling.

    Arguments:
        inputs -- Input tensor
        n_filters -- Number of filters for the convolutional layers
        dropout_prob -- Dropout probability; no Dropout layer is added unless it is > 0
        max_pooling -- Use MaxPooling2D to reduce the spatial dimensions of the output volume
    Returns:
        next_layer, skip_connection -- Next layer (pooled when max_pooling is set)
                                       and skip connection (taken before the pooling)
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(n_filters,
                          3,
                          activation='relu',
                          padding='same',
                          kernel_initializer='he_normal')(features)

    if dropout_prob > 0:
        features = Dropout(dropout_prob)(features)

    # The skip connection always carries the un-pooled features
    pooled = MaxPooling2D(2, strides=2)(features) if max_pooling else features

    return pooled, features


# UNQ_C2
# GRADED FUNCTION: upsampling_block
def upsampling_block(expansive_input, contractive_input, n_filters=32):
    """
    Convolutional upsampling block

    A stride-2 transposed convolution, concatenation with the skip tensor
    along axis 3, then two 3x3 ReLU convolutions.

    Arguments:
        expansive_input -- Input tensor from previous layer
        contractive_input -- Input tensor from previous skip layer
        n_filters -- Number of filters for the convolutional layers
    Returns:
        conv -- Tensor output
    """
    upsampled = Conv2DTranspose(n_filters, 3, strides=2, padding='same')(expansive_input)

    # Merge the upsampled tensor with the skip connection
    features = concatenate([upsampled, contractive_input], axis=3)

    for _ in range(2):
        features = Conv2D(n_filters,
                          3,
                          activation='relu',
                          padding='same',
                          kernel_initializer='he_normal')(features)

    return features


# UNQ_C3
# GRADED FUNCTION: unet_model
def unetModel(input_size, n_filters, n_classes):
    """
    Unet model

    Four pooled conv_blocks (n_filters x 1, 2, 4, 8), an un-pooled bottleneck
    (x 16), four upsampling_blocks (x 8, 4, 2, 1), one more 3x3 convolution and
    a final 1x1 convolution with n_classes filters and no activation.

    Arguments:
        input_size -- Input shape
        n_filters -- Number of filters for the convolutional layers
        n_classes -- Number of output classes
    Returns:
        model -- tf.keras.Model
    """
    inputs = Input(input_size)

    # Contracting Path (encoding): the filters double at every step and only the
    # deepest pooled block uses dropout
    skips = []
    features = inputs
    for multiplier, dropout in ((1, 0), (2, 0), (4, 0), (8, 0.3)):
        features, skip = conv_block(inputs=features, n_filters=n_filters * multiplier, dropout_prob=dropout)
        skips.append(skip)

    # Bottleneck: dropout, no pooling
    features, _ = conv_block(inputs=features, n_filters=n_filters * 16, dropout_prob=0.3, max_pooling=False)

    # Expanding Path (decoding): the filters halve at every step and each block is
    # paired with the matching un-pooled encoder output, deepest first
    for multiplier in (8, 4, 2, 1):
        features = upsampling_block(features, skips.pop(), n_filters * multiplier)

    features = Conv2D(n_filters,
                      3,
                      activation='relu',
                      padding='same',
                      kernel_initializer='he_normal')(features)

    # One score per class for every pixel
    outputs = Conv2D(n_classes, 1, padding='same')(features)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

Unet Function

In [None]:
def predictLVForEDESFrames(ES_Frame_IMG, ED_Frame_IMG, ED_model, ES_model):
    """
    Predict the EF value from one ED frame and one ES frame

    Arguments:
        ES_Frame_IMG -- ES frame image (a batch axis is added here)
        ED_Frame_IMG -- ED frame image (a batch axis is added here)
        ED_model -- Model used to predict the ED mask
        ES_model -- Model used to predict the ES mask
    Returns:
        ef_pred -- Result of calculate_EF on the ED and ES volumes
    """
    # Mask
    ED_mask = predictMask(ED_model, np.expand_dims(ED_Frame_IMG, axis=0))
    ES_mask = predictMask(ES_model, np.expand_dims(ES_Frame_IMG, axis=0))

    # Volume (the landmarks returned alongside are not needed here)
    ED_volume, _ = get_LV_volume(ED_mask)
    ES_volume, _ = get_LV_volume(ES_mask)

    # EF
    return calculate_EF(ED_volume, ES_volume)


def calculate_mean_mse(ground_truth_masks, predicted_masks):
    """
    Print the mean of the per-mask mean squared errors, multiplied by 100

    Arguments:
        ground_truth_masks -- Indexable collection of true masks
        predicted_masks -- Collection of predicted masks; its length decides how
                           many pairs are compared
    Returns:
        None -- The value is only printed
    """
    per_mask_mse = [
        mean_squared_error(ground_truth_masks[idx].flatten(), predicted.flatten())
        for idx, predicted in enumerate(predicted_masks)
    ]

    mean_mse = np.mean(per_mask_mse)
    print(f"Mean MSE: {mean_mse * 100}")


def _createPredictedMask(pred_mask):
    """
    Turn a batch of per-class scores into the class-index mask of its first item

    Arguments:
        pred_mask -- Scores with the classes on the last axis
    Returns:
        Mask of the first batch item with a trailing axis of size 1
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    with_channel_axis = class_ids[..., tf.newaxis]
    first_mask = with_channel_axis[0]
    return first_mask


def predictMask(model, image):
    """
    Predict the mask of an image batch as a NumPy array

    Arguments:
        model -- Object whose predict(image) returns per-class scores
        image -- Image batch handed to model.predict
    Returns:
        Class-index mask of the first batch item with singleton axes removed
    """
    predicted = _createPredictedMask(model.predict(image))
    return np.squeeze(predicted.numpy())


def loadUnetModel(path):
    """
    Rebuild the U-Net, load its weights and compile it

    Arguments:
        path -- Weights file handed to load_weights
    Returns:
        The compiled model (112x112x3 input, 32 base filters, 2 classes)
    """
    compile_options = dict(optimizer='adam',
                           loss=SparseCategoricalCrossentropy(from_logits=True),
                           metrics=['accuracy'])

    unet = unetModel(input_size=(112, 112, 3), n_filters=32, n_classes=2)
    unet.load_weights(path)
    unet.compile(**compile_options)

    return unet


def trainUnet(dataSet, epochs=5, batchSize=32, modelPath='', name=''):
    """
    Train a U-Net and save its weights

    Arguments:
        dataSet -- Un-batched dataset exposing cache(), shuffle() and batch()
        epochs -- Number of training epochs
        batchSize -- Batch size used for training
        modelPath -- Directory of the weights file; '' means the current directory
        name -- File name without the '.weights.h5' suffix
    """
    unet = unetModel(input_size=(112, 112, 3), n_filters=32, n_classes=2)
    unet.compile(optimizer='adam', loss=SparseCategoricalCrossentropy(from_logits=True),
                 metrics=['accuracy'])

    BUFFER_SIZE = 500
    # batch() returns a new dataset, so the batched pipeline has to be kept
    dataSet = dataSet.cache().shuffle(BUFFER_SIZE).batch(batchSize)

    unet.fit(dataSet, epochs=epochs)

    # os.path.join keeps an empty modelPath relative instead of producing "/<name>..."
    unet.save_weights(os.path.join(modelPath, f'{name}.weights.h5'))


def testUnet(dataSet, path=''):
    """
    Evaluate a saved U-Net and print its accuracy

    Arguments:
        dataSet -- Un-batched dataset exposing cache() and batch()
        path -- Weights file handed to loadUnetModel
    """
    # batch() returns a new dataset, so the batched pipeline has to be kept
    dataSet = dataSet.cache().batch(1)

    loaded_model = loadUnetModel(path)

    # evaluate() returns [loss, accuracy] for the metrics set in loadUnetModel
    evaluation_result = loaded_model.evaluate(dataSet)
    print("Test Accuracy:", evaluation_result[1])

# Train & Test

In [None]:
# Read Data
# Load the three splits once; the cells below use train_dataSet and test_dataSet.
train_dataSet = load_or_get_data(spilt_type='TRAIN')
print('TRAIN =', len(train_dataSet))

test_dataSet = load_or_get_data(spilt_type='TEST')
print('TEST =', len(test_dataSet))

val_dataSet = load_or_get_data(spilt_type='VAL')
print('VAL =', len(val_dataSet))

In [None]:
# Transformer Model
# Train Note: Use 2 GPU to get fast train
# Transformer.train(train_dataSet, num_epochs=7, batch_size=2, parallel=True)
# Train Note: Use CPU
# NOTE(review): parallel=True is still passed in the CPU configuration below -- confirm this is intended.
trainTransformer(train_dataSet, num_epochs=7, batch_size=1, parallel=True)

# Evaluate the checkpoint stored at _transformerModelPath on the TEST split.
# NOTE(review): trainTransformer writes "best.pt" / "checkpoint_<epoch>.pt" to the working
# directory, so _transformerModelPath has to point at the checkpoint to evaluate.
testTransformer(_transformerModelPath, test_dataSet)

In [None]:
# Prepare Data For U-NET
# presumably writes the ground-truth masks under _trueMasksPath -- verify against CreateAllMasks
CreateAllMasks(_trueMasksPath)

frameType = 'ED'
unet_dataset_train = getImageAndMasks(frameType=frameType, split='TRAIN', trueMasksPath=_trueMasksPath)
unet_dataset_test = getImageAndMasks(frameType=frameType, split='TEST', trueMasksPath=_trueMasksPath)
# NOTE(review): the ED validation split is loaded but not used before it is overwritten
# below; only the ES data gets the train/test re-split -- confirm this asymmetry is intended.
unet_dataset_val = getImageAndMasks(frameType=frameType, split='VAL', trueMasksPath=_trueMasksPath)

# ED U-NET Model
# NOTE(review): the weights are saved with modelPath='' while the test below loads
# _ED_Model_Path -- make sure both refer to the same file.
trainUnet(unet_dataset_train, epochs=5, batchSize=32, modelPath='', name=f'{frameType}_U_NET_Model')
testUnet(unet_dataset_test, path=_ED_Model_Path)

frameType = 'ES'
unet_dataset_train = getImageAndMasks(frameType=frameType, split='TRAIN', trueMasksPath=_trueMasksPath)
unet_dataset_test = getImageAndMasks(frameType=frameType, split='TEST', trueMasksPath=_trueMasksPath)
unet_dataset_val = getImageAndMasks(frameType=frameType, split='VAL', trueMasksPath=_trueMasksPath)

# 80,20 %
# Re-split: the first N validation items are appended to the training set and the
# remaining validation items are appended to the test set.
N = 550
unet_dataset_train = unet_dataset_train.concatenate(unet_dataset_val.take(N))
unet_dataset_val = unet_dataset_val.skip(N)
unet_dataset_test = unet_dataset_test.concatenate(unet_dataset_val)
print(len(unet_dataset_train))
print(len(unet_dataset_test))

# ES U-NET Model
# NOTE(review): same path caveat as for the ED model (modelPath='' vs. _ES_Model_Path).
trainUnet(unet_dataset_train, epochs=5, batchSize=32, modelPath='', name=f'{frameType}_U_NET_Model')
testUnet(unet_dataset_test, path=_ES_Model_Path)

# MAIN

In [None]:
# Data
# Only the TEST split is needed for the end-to-end demo below.
test_dataSet = load_or_get_data('TEST')
print('TEST =', len(test_dataSet))

In [None]:
# Load Model
# One transformer for picking the ES/ED frames and one U-Net per frame type for the masks.
transformerModel = loadTransformerModel(_transformerModelPath)
ED_Model = loadUnetModel(_ED_Model_Path)
ES_Model = loadUnetModel(_ES_Model_Path)

In [None]:

# End-to-end demo on the first test video: detect the ES/ED frames, predict the EF
# from them, and print it next to the reference EF stored for that video.
for obj in test_dataSet[0:1]:
    print(obj.fileName)
    # The video file is looked up as <fileName>.avi under _videosPath
    name = obj.fileName + '.avi'
    videoPath = os.path.join(_videosPath, name)

    ES_Frame_IMG, ED_Frame_IMG = Detect_ESED_Frame(videoPath, transformerModel)

    efPred = predictLVForEDESFrames(ES_Frame_IMG, ED_Frame_IMG, ED_Model, ES_Model)

    print('EF Predicted =', efPred)
    print('EF =', obj.EF_value)
    print('-----------------------')