In [1]:
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime not to abort
# when it finds itself loaded more than once (presumably several packages here ship
# their own copy). It silences the check rather than removing the duplicate runtime.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [2]:
import matplotlib.pyplot as plt
import re
import numpy as np 
import cv2
import copy
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.models.resnet import resnet34
from torchvision.models.video import r3d_18
from torchsummary import summary
from torchstat import stat
from math import floor
import cmath
from scipy.special import expit
import math
from scipy.ndimage import label
import random
import pandas as pd
import gc

In [3]:
seed = 3407
random.seed(seed)  # Python's `random` module (used by the augmentation code)
# NOTE: assigning PYTHONHASHSEED here does NOT change hash randomisation of the
# already-running kernel; it is only inherited by Python processes started later.
# To make str hashing reproducible in this process, set it before starting Jupyter.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)  # NumPy's global RNG
torch.manual_seed(seed)  # torch RNG
torch.cuda.manual_seed(seed)  # current GPU
torch.cuda.manual_seed_all(seed)  # all GPUs (multi-GPU runs)
# Trade cuDNN autotuning speed for run-to-run reproducibility.
torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels between runs
torch.backends.cudnn.deterministic = True  # select deterministic cuDNN algorithms

In [None]:
import argparse
from utils.config import get_config
from models.model_dict import get_model


def _str2bool(value):
    """argparse converter turning 'true'/'false'-style strings into a real bool.

    ``type=bool`` is a trap: ``bool('False')`` is True, so any value passed on the
    command line would have switched the flag on.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Networks')
parser.add_argument('--modelname', default='SAME', type=str, help='type of model, e.g., SAM, SAMFull, MedSAM, MSA, SAMed, SAMUS...')
parser.add_argument('-encoder_input_size', type=int, default=256, help='the image size of the encoder input, 1024 in SAM and MSA, 512 in SAMed, 256 in SAMUS')
parser.add_argument('-low_image_size', type=int, default=128, help='the image embedding size, 256 in SAM and MSA, 128 in SAMed and SAMUS')
parser.add_argument('--task', default='US30K', help='task or dataset name')
parser.add_argument('--vit_name', type=str, default='vit_b', help='select the vit model for the image encoder of sam')
parser.add_argument('--sam_ckpt', type=str, default='sam_vit_b_01ec64.pth', help='Pretrained checkpoint of SAM')
# Boolean switches: the default None is not a string, so argparse leaves it unconverted.
parser.add_argument('--fcba_flag', type=_str2bool, default=None, help='FCBA')
parser.add_argument('--sgpa_flag', type=_str2bool, default=None, help='AGPA')
parser.add_argument('--bott_flag', type=_str2bool, default=None, help='BOTT')
# Jupyter passes "-f <kernel.json>" on the command line; accept and ignore it.
parser.add_argument("-f", dest = 'j_cfile', help = "jupyter config file", default = "file.json", type = str)
args = parser.parse_known_args()[0]
opt = get_config(args.task)
# BOTT is always disabled for this experiment, whatever was passed on the command line.
args.bott_flag = False

In [5]:
# Switch off the FCBA and SGPA components for this run (overrides any CLI value).
args.fcba_flag, args.sgpa_flag = False, False

In [6]:
# Optimisation schedule (epochs are counted from 0).
lr = 0.0002
NFrame = 10
epochs_fin = 60                     # total number of epochs
epochs_wm = 10                      # linear warm-up epochs
epochs_ot = epochs_fin - epochs_wm  # cosine-annealing epochs


def lambda1(epoch):
    """LambdaLR multiplier: linear warm-up ``epoch / epochs_wm``, then a half-cosine from 1 to 0.

    Note that epoch 0 yields a multiplier of 0, so the first epoch runs at zero LR.
    """
    if epoch >= epochs_wm:
        return 0.5 * (math.cos((epoch - epochs_wm)/(epochs_ot) * math.pi) + 1)
    return epoch / epochs_wm

In [7]:
class CamusIterator(Dataset):
    """CAMUS dataset returning paired 2CH / 4CH data for one patient per index.

    Every array in an item is stacked along a leading axis of size 2
    (index 0 = 2-chamber view, index 1 = 4-chamber view):

    * ``ED_LV_gt`` / ``ES_LV_gt``: LV masks read with ``cv2.imread(path, 0)``, as ``torch.uint8``.
    * ``ED_key_point`` / ``ES_key_point``: first three rows of the stored
      ``*_5_Points.npy`` arrays, as ``torch.uint8``.
    * ``sequence``: the stored ``*_sequence.npy`` array, as ``torch.float32``.
    * ``EV_para``: rounded (EDV, ESV); ``EF_para``: ejection fraction (from the ``.cfg`` file).
    * ``size``: ``(H, W)`` from the ``.cfg`` file, as a NumPy array.

    Args:
        data_list: patient identifiers substituted into the file names
            (e.g. zero-padded strings such as ``'0001'``).
        do_aug: stored on the instance; this class itself never reads it.
        data_dir: folder holding the preprocessed files; the default is the
            location the notebook was originally run with.
    """

    def __init__(self, data_list, do_aug=True, data_dir='D:/CAMUS_240307'):
        self.data_list = data_list
        self.do_aug = do_aug
        self.data_dir = data_dir

    def __read_image_gt(self, pat_i, label):
        """Read the ED and ES LV masks of view ``label`` (2 or 4) as grayscale images."""
        ED_LV_gt_file = '{}/patient{}_{}CH_ED_LV_gt.png'.format(self.data_dir, pat_i, label)
        ES_LV_gt_file = '{}/patient{}_{}CH_ES_LV_gt.png'.format(self.data_dir, pat_i, label)
        ED_LV_gt = cv2.imread(ED_LV_gt_file, 0)
        ES_LV_gt = cv2.imread(ES_LV_gt_file, 0)
        return ED_LV_gt, ES_LV_gt

    def __read_cfg(self, pat_i, label):
        """Parse image size and LV volumes / EF from the patient's ``.cfg`` file.

        Returns:
            im_size: ``np.array((H, W))`` parsed from the ``h:`` / ``w:`` entries.
            EV_para: ``np.around(np.array((EDV, ESV)))``.
            EF_para: the ``LVef`` value.
            If any of ``LVedv`` / ``LVesv`` / ``LVef`` is missing or unparsable,
            all three volume values fall back to 0.

        Raises:
            AttributeError / ValueError: if the ``w:`` or ``h:`` entry is missing
            or unparsable.
        """
        cfg_str = '{}/patient{}_Info_{}CH.cfg'.format(self.data_dir, pat_i, label)
        with open(cfg_str, 'r') as file:
            file_content = file.read()

        def _field(key):
            # re.search returns None for a missing key -> AttributeError on .group().
            return float(re.search(key + r':\s*([\d.]+)', file_content).group(1))

        W = _field('w')
        H = _field('h')
        im_size = np.array((H, W))
        try:
            EDV = _field('LVedv')
            ESV = _field('LVesv')
            EF_para = _field('LVef')
        except (AttributeError, ValueError):
            # Missing / unparsable volume annotations: use zeros instead of failing.
            EDV = ESV = 0
            EF_para = 0
        EV_para = np.around(np.array((EDV, ESV)))
        return im_size, EV_para, EF_para

    def __read_seq(self, pat_i, label):
        """Load the stored image sequence of view ``label``."""
        sequence_str = '{}/patient{}_{}CH_sequence.npy'.format(self.data_dir, pat_i, label)
        return np.load(sequence_str)

    def __read_keypoint(self, pat_i, label):
        """Load the ED / ES key points of view ``label``, keeping only the first three rows."""
        ED_key_point_str = '{}/patient{}_{}CH_ED_5_Points.npy'.format(self.data_dir, pat_i, label)
        ES_key_point_str = '{}/patient{}_{}CH_ES_5_Points.npy'.format(self.data_dir, pat_i, label)
        ED_key_point = np.load(ED_key_point_str)[0:3, :]
        ES_key_point = np.load(ES_key_point_str)[0:3, :]
        return ED_key_point, ES_key_point

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Load both views (2CH first, then 4CH) of patient ``data_list[index]``."""
        pat_i = self.data_list[index]

        ED_LV_gt_first, ES_LV_gt_first = self.__read_image_gt(pat_i, 2)
        ED_key_point_first, ES_key_point_first = self.__read_keypoint(pat_i, 2)
        im_size_first, EV_para_first, EF_para_first = self.__read_cfg(pat_i, 2)
        sequence_first = self.__read_seq(pat_i, 2)

        ED_LV_gt_last, ES_LV_gt_last = self.__read_image_gt(pat_i, 4)
        ED_key_point_last, ES_key_point_last = self.__read_keypoint(pat_i, 4)
        im_size_last, EV_para_last, EF_para_last = self.__read_cfg(pat_i, 4)
        sequence_last = self.__read_seq(pat_i, 4)

        im_size = np.array([im_size_first, im_size_last])

        ED_LV_gt = torch.tensor(np.array((ED_LV_gt_first, ED_LV_gt_last)), dtype=torch.uint8)
        ES_LV_gt = torch.tensor(np.array((ES_LV_gt_first, ES_LV_gt_last)), dtype=torch.uint8)
        # NOTE(review): uint8 only represents integer coordinates in [0, 255];
        # confirm the stored key points always fit that range.
        ED_key_point = torch.tensor(np.array((ED_key_point_first, ED_key_point_last)), dtype=torch.uint8)
        ES_key_point = torch.tensor(np.array((ES_key_point_first, ES_key_point_last)), dtype=torch.uint8)
        tensor_sequence = torch.tensor(np.array((sequence_first, sequence_last)), dtype=torch.float32)
        tensor_EV_para = torch.tensor(np.array((EV_para_first, EV_para_last)), dtype=torch.float32)
        tensor_EF_para = torch.tensor(np.array((EF_para_first, EF_para_last)), dtype=torch.float32)

        sample = {
            'index': index,
            'sequence': tensor_sequence,
            'size': im_size,
            'ED_LV_gt': ED_LV_gt,
            'ES_LV_gt': ES_LV_gt,
            'ED_key_point': ED_key_point,
            'ES_key_point': ES_key_point,
            'EV_para': tensor_EV_para,
            'EF_para': tensor_EF_para
        }
        return sample

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

In [8]:
# Patient split sheet: rows with fold == 'Train' / 'Test', patient number in 'pat_i'.
df = pd.read_excel('CAMUS_split_fin.xlsx')

# Ids are zero-padded to four digits, e.g. 1 -> '0001', as used in the file names.
train_index_list, test_index_list = (
    [format(pat_i, "04") for pat_i in df.loc[df['fold'] == fold_name, 'pat_i']]
    for fold_name in ('Train', 'Test')
)

train_valid_iter = CamusIterator(data_list=train_index_list, do_aug=True)
test_iter = CamusIterator(data_list=test_index_list, do_aug=False)

In [9]:
def split_indices(n, val_pct, seed):
    """Shuffle ``range(n)`` and split it into (train, validation) index arrays.

    The first ``int(val_pct * n)`` shuffled indices form the validation part.
    Side effect: reseeds NumPy's *global* RNG with ``seed`` (for reproducibility).
    """
    np.random.seed(seed)
    shuffled = np.random.permutation(n)
    cut = int(val_pct * n)
    train_part, val_part = shuffled[cut:], shuffled[:cut]
    return train_part, val_part

# Hold out 1/9 of the training patients for validation.
val_pct = 1 / 9
train_indices, val_indices = split_indices(len(train_valid_iter), val_pct, seed)

# Both loaders read the same dataset object, each restricted to its own index subset.
train_sampler, val_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices)

traindata = DataLoader(train_valid_iter, sampler=train_sampler, batch_size=1)
valdata = DataLoader(train_valid_iter, sampler=val_sampler, batch_size=1)
testdata = DataLoader(test_iter, batch_size=1)

In [10]:
def key_point_2_onehot(gt_image, key_point):
    """Paint the first three key points as small intensity blobs on zero maps.

    Args:
        gt_image: array whose last two dimensions give the output size (H, W).
        key_point: array-like with at least three rows of (row, col) pixel coordinates.

    Returns:
        float64 array of shape (3, H, W). Map ``k`` holds 0.5 on rows/cols
        ``[p-2, p+2)``, 0.75 on ``[p-1, p+1)`` and 1 at the point itself; the
        windows are clipped at the image border.
    """
    # Signed ints: with uint8 input, ``p - 2`` / ``p + 2`` wrapped around near the
    # borders, and negative slice starts wrap to the far edge; both silently
    # erased the blob.
    points = np.asarray(key_point)[:, :2].astype(np.int64)
    onehot = np.zeros((3, gt_image.shape[-2], gt_image.shape[-1]))
    for k in range(onehot.shape[0]):
        y, x = points[k]
        onehot[k, max(y - 2, 0):y + 2, max(x - 2, 0):x + 2] = 0.5
        onehot[k, max(y - 1, 0):y + 1, max(x - 1, 0):x + 1] = 0.75
        onehot[k, y, x] = 1
    return onehot


def key_point_2_heatmap(gt_image, key_point, mask_r, ratio):
    """Render Gaussian heat maps for the first three key points at 1/``ratio`` scale.

    Args:
        gt_image: array whose last two dims (H, W) give the full-resolution size.
        key_point: array-like with at least three rows of (row, col) coordinates
            at full resolution.
        mask_r: Gaussian radius in full-resolution pixels (floor-divided by ``ratio``).
        ratio: down-sampling factor.

    Returns:
        Flattened float64 array of length ``3 * (H // ratio) * (W // ratio)``;
        map ``k`` is ``exp(-d^2 / (2 r^2))`` with ``d`` the distance to point ``k``.
    """
    # float64 arithmetic: with uint8 key points (as loaded by the dataset) the former
    # per-pixel ``(i - y) ** 2`` computation silently wrapped around modulo 256.
    points = np.asarray(key_point, dtype=np.float64)
    mask_r = mask_r // ratio
    key_point_y = points[:3, 0] // ratio
    key_point_x = points[:3, 1] // ratio
    height = gt_image.shape[-2] // ratio
    width = gt_image.shape[-1] // ratio
    # Broadcast (3, 1, 1) key points against (1, H, 1) rows and (1, 1, W) cols
    # instead of looping over every pixel in Python.
    rows = np.arange(height)[None, :, None]
    cols = np.arange(width)[None, None, :]
    distance_2 = (rows - key_point_y[:, None, None]) ** 2 + (cols - key_point_x[:, None, None]) ** 2
    heat_map = np.e ** (-1 * (distance_2 / (2 * mask_r ** 2)))
    return heat_map.flatten()

def heatmap_2_key_point(heatmap):
    """Return the (row, col) arg-max of each of the first three heat-map channels.

    Output is a float64 array of shape (3, 2); ties resolve to the first maximum
    in row-major order.
    """
    peaks = np.zeros((3, 2))
    for channel in range(3):
        plane = heatmap[channel]
        row, col = divmod(int(plane.argmax()), plane.shape[1])
        peaks[channel][0] = row
        peaks[channel][1] = col
    return peaks

In [11]:
def geo_aug_fin(aug_flag,ED_LV_gt_raw,ES_LV_gt_raw,ED_key_point_raw,ES_key_point_raw,sequence_raw,mask_r):
    """Build masks, key points, heat maps and frames for every (batch, view) pair.

    When ``aug_flag == True`` each view gets a random geometric augmentation
    (rotation, centre crop, optional perspective warp, resized crop) applied
    jointly to its masks, key-point blobs and frames; otherwise the raw masks,
    key points and frames are used directly.

    Args:
        aug_flag: enables the random augmentation (compared with ``== True``).
        ED_LV_gt_raw, ES_LV_gt_raw: masks indexed ``[batch_i][paired_i]``, each a
            2-D image.
        ED_key_point_raw, ES_key_point_raw: key points indexed
            ``[batch_i][paired_i]`` as rows of (row, col).
        sequence_raw: frames indexed ``[batch_i][paired_i][frame_i]``. The
            augmented path keeps only channels 8:10 of the stacked image, i.e. it
            assumes exactly two frames per view (TODO confirm with callers).
        mask_r: Gaussian radius forwarded to ``key_point_2_heatmap`` (ratio 2).

    Returns:
        Tuple of NumPy arrays ``(ED masks, ES masks, ED heat maps, ES heat maps,
        ED key points, ES key points, frames)``, each indexed ``[batch][view]``.
        Masks are binarised to 0/1 by thresholding at the midpoint of their own
        min and max; heat maps are the flattened output of ``key_point_2_heatmap``.
    """
    ED_LV_gt_new = []
    ES_LV_gt_new = []
    ED_LV_heatmap_new = []
    ES_LV_heatmap_new = []
    ED_LV_point_new = []
    ES_LV_point_new = []
    sequence_new = []

    for batch_i in range(len(sequence_raw)):
        ED_LV_gt_batch = []
        ES_LV_gt_batch = []
        ED_LV_heatmap_batch = []
        ES_LV_heatmap_batch = []
        sequence_batch = []
        ED_LV_point_batch = []
        ES_LV_point_batch = []

        for paired_i in range(len(sequence_raw[batch_i])):
            sequence_list = []
            if aug_flag == True:
                # Rasterise the key points as small blobs so they undergo exactly the
                # same geometric transform as the masks and frames.
                ED_heatmap_raw = key_point_2_onehot(ED_LV_gt_raw[batch_i][paired_i],ED_key_point_raw[batch_i][paired_i])
                ES_heatmap_raw = key_point_2_onehot(ES_LV_gt_raw[batch_i][paired_i],ES_key_point_raw[batch_i][paired_i])
                # Random parameters from Python's `random` module: rotation angle in
                # degrees, crop scale / aspect ratio, and a 0/1 value used below as
                # the perspective-warp probability.
                theta = random.uniform(-15,15)
                scale = random.uniform(0.95,1.05)
                ratio = random.uniform(0.95,1.05)
                # NOTE(review): distortion_scale is sampled (advancing the RNG) but never
                # used; RandomPerspective below uses a fixed 0.3. Confirm which was intended.
                distortion_scale = random.uniform(0,0.05)
                prob = random.randint(0,1)

                
                # ToTensor -> rotate by exactly theta (canvas expanded) -> centre-crop 256
                # -> perspective warp with probability prob -> crop with fixed
                # scale/ratio, resized to 256x256.
                geo_transforms = transforms.Compose([transforms.ToTensor(),transforms.RandomRotation((theta,theta),expand=True),transforms.CenterCrop(256),
                                                     transforms.RandomPerspective(distortion_scale=0.3,p=prob),transforms.RandomResizedCrop((256,256),scale=(scale,scale),ratio = (ratio,ratio))])


                # Channel layout of the stacked image: 0 = ED mask, 1 = ES mask,
                # 2:5 = ED key-point blobs, 5:8 = ES key-point blobs, 8: = frames.
                # Transposed to channels-last for ToTensor.
                concat_image = np.concatenate((np.expand_dims(ED_LV_gt_raw[batch_i][paired_i],axis=0),
                                               np.expand_dims(ES_LV_gt_raw[batch_i][paired_i],axis=0),
                                               ED_heatmap_raw,ES_heatmap_raw,
                                               sequence_raw[batch_i][paired_i])).transpose(1,2,0)
                concat_transform = geo_transforms(concat_image)


                ED_LV_gt_temp = concat_transform[0,:,:].numpy()
                ES_LV_gt_temp = concat_transform[1,:,:].numpy()
                
                ED_LV_heatmap_repack = concat_transform[2:5,:,:].numpy()
                ES_LV_heatmap_repack = concat_transform[5:8,:,:].numpy()

                sequence_temp= concat_transform[8:10,:,:].numpy()

                # Recover the transformed key points as the arg-max of each blob map.
                # NOTE(review): resolved at call time; later cells redefine
                # heatmap_2_key_point with extra parameters, so this one-argument call
                # depends on which definition is active.
                ED_key_point_num_temp = heatmap_2_key_point(ED_LV_heatmap_repack)
                ES_key_point_num_temp = heatmap_2_key_point(ES_LV_heatmap_repack)
                ED_LV_point_batch.append(ED_key_point_num_temp)
                ES_LV_point_batch.append(ES_key_point_num_temp)

                # Gaussian heat-map targets at half resolution (ratio 2).
                ED_LV_heatmap_temp = key_point_2_heatmap(ED_LV_gt_temp,ED_key_point_num_temp,mask_r,2)
                ES_LV_heatmap_temp = key_point_2_heatmap(ES_LV_gt_temp,ES_key_point_num_temp,mask_r,2)
            else:
                # No augmentation: raw masks and raw key points are used as given.
                ED_LV_gt_temp = ED_LV_gt_raw[batch_i][paired_i]
                ES_LV_gt_temp = ES_LV_gt_raw[batch_i][paired_i]
                ED_LV_heatmap_temp = key_point_2_heatmap(ED_LV_gt_raw[batch_i][paired_i],ED_key_point_raw[batch_i][paired_i],mask_r,2)
                ES_LV_heatmap_temp = key_point_2_heatmap(ES_LV_gt_raw[batch_i][paired_i],ES_key_point_raw[batch_i][paired_i],mask_r,2)
                # Rebuild the frame stack of this view as a new NumPy array.
                for frame_i in range(len(sequence_raw[batch_i][paired_i])):
                    sequence_list.append(sequence_raw[batch_i][paired_i][frame_i])
                sequence_temp= np.array(sequence_list)
                ED_LV_point_batch.append(ED_key_point_raw[batch_i][paired_i])
                ES_LV_point_batch.append(ES_key_point_raw[batch_i][paired_i])




            # Binarise each mask at the midpoint of its own value range (0/1 ints).
            ED_LV_gt_temp = (ED_LV_gt_temp>((np.min(ED_LV_gt_temp)+np.max(ED_LV_gt_temp))/2)).astype(int)
            ES_LV_gt_temp = (ES_LV_gt_temp>((np.min(ES_LV_gt_temp)+np.max(ES_LV_gt_temp))/2)).astype(int)

            ED_LV_gt_batch.append(ED_LV_gt_temp)
            ES_LV_gt_batch.append(ES_LV_gt_temp)
            ED_LV_heatmap_batch.append(ED_LV_heatmap_temp)
            ES_LV_heatmap_batch.append(ES_LV_heatmap_temp)
            sequence_batch.append(sequence_temp)

        ED_LV_gt_new.append(ED_LV_gt_batch)
        ES_LV_gt_new.append(ES_LV_gt_batch)
        ED_LV_heatmap_new.append(ED_LV_heatmap_batch)
        ES_LV_heatmap_new.append(ES_LV_heatmap_batch)
        ED_LV_point_new.append(ED_LV_point_batch)
        ES_LV_point_new.append(ES_LV_point_batch)
        sequence_new.append(sequence_batch)



    return np.array(ED_LV_gt_new),np.array(ES_LV_gt_new),np.array(ED_LV_heatmap_new),np.array(ES_LV_heatmap_new),np.array(ED_LV_point_new),np.array(ES_LV_point_new),np.array(sequence_new)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Move a tensor, or a (possibly nested) list / tuple / dict of tensors, to ``device``.

    Lists come back as lists, tuples as tuples and dicts as dicts; leaves are
    moved with ``.to(device, non_blocking=True)``.
    """
    if isinstance(data, list):
        return [to_device(x, device) for x in data]
    if isinstance(data, tuple):
        return tuple(to_device(x, device) for x in data)
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    return data.to(device, non_blocking=True)

In [13]:
def loss_epoch_end(outputs):
    """Average per-batch loss records into epoch-level numbers.

    Args:
        outputs: list of dicts with keys 'loss' (tensor), 'Seg_loss', 'HR_loss'
            and 'EB_loss' (values accepted by ``np.mean``).

    Returns:
        dict with 'epoch_loss' (Python float) plus 'epoch_Seg_loss',
        'epoch_HR_loss' and 'epoch_EB_loss' (``np.mean`` results).
    """
    names = ('loss', 'Seg_loss', 'HR_loss', 'EB_loss')
    columns = {name: [record[name] for record in outputs] for name in names}
    summary = {'epoch_loss': torch.stack(columns['loss']).mean().item()}
    for name in names[1:]:
        summary['epoch_' + name] = np.mean(columns[name])
    return summary
def result_epoch_end(epoch, epochs, train_result):
    """Print a one-line summary of an epoch's averaged losses (epoch shown 1-based)."""
    template = "Epoch [{}/{}], Train loss: {:.4f}, Seg loss: {:.4f}, HR loss: {:.4f}, EB loss: {:.4f} "
    values = [train_result[key] for key in ('epoch_loss', 'epoch_Seg_loss', 'epoch_HR_loss', 'epoch_EB_loss')]
    print(template.format(epoch + 1, epochs, *values))

In [14]:
def chMask(image):
    """Split a batch of label masks into a class-index map and a one-hot map.

    The two smallest distinct values found in the *whole* batch are treated as
    background and foreground; any other value is left as 0 in both outputs.

    Args:
        image: array-like of shape (N, H, W).

    Returns:
        new_arr: float64 (N, H, W), 1 where the pixel equals the foreground value
            (class-index layout, e.g. for cross-entropy).
        dice_arr: float64 (N, 2, H, W) one-hot map, channel 0 = background,
            channel 1 = foreground (layout used by the Dice loss).
    """
    arr = np.array(image)
    new_arr = np.zeros((arr.shape[0], arr.shape[1], arr.shape[2]))
    dice_arr = np.zeros((arr.shape[0], 2, arr.shape[1], arr.shape[2]))
    ele = np.unique(arr)  # sorted distinct values
    # Vectorised equivalent of the former per-pixel triple loop.
    if ele.size > 0:
        background = arr == ele[0]
        dice_arr[:, 0][background] = 1  # dice_arr[:, 0] is a view, so this writes through
    if ele.size > 1:
        foreground = arr == ele[1]
        new_arr[foreground] = 1
        dice_arr[:, 1][foreground] = 1
    return new_arr, dice_arr

In [15]:
def dice_loss(pred, target):
    """Soft Dice loss over the whole batch: ``1 - (2|P*T| + s) / (|P| + |T| + s)``.

    Args:
        pred: tensor of probabilities; its first dim is treated as the batch dim.
        target: array-like or tensor with the same number of elements per sample
            (e.g. a one-hot (N, 2, H, W) mask).

    Returns:
        0-dim tensor on ``pred``'s device. Sums run over all samples jointly (no
        per-sample averaging); ``s = 1e-4`` gives loss 0 when both are empty.
    """
    # as_tensor avoids the copy + warning torch.tensor() emits for tensor inputs,
    # and placing target on pred's device removes the dependency on a global.
    target = torch.as_tensor(target, device=pred.device)
    smooth = 0.0001
    num = pred.size(0)
    m1 = pred.contiguous().view(num, -1)  # flatten per sample
    m2 = target.contiguous().view(num, -1)
    intersection = (m1 * m2).sum()
    return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [16]:
def dice_coeff(pred, target):
    """Dice coefficient of channel 1 (foreground) between ``pred`` and ``target``.

    Both inputs are indexed ``[:, 1]``; sums run over the whole batch, and a
    1e-4 smoothing term keeps the ratio defined for empty masks.
    """
    eps = 0.0001
    fg_pred, fg_true = pred[:, 1], target[:, 1]
    overlap = (fg_pred * fg_true).sum()
    return (2. * overlap + eps) / (fg_pred.sum() + fg_true.sum() + eps)

In [17]:
def PCK(true_HR_list, pred_HR_list, threshold=12.8):
    """Percentage of Correct Keypoints.

    Each ``true_HR_list[i]`` / ``pred_HR_list[i]`` is indexed as ``[0]`` and ``[1]``
    for the two coordinate components, each holding one value per key point
    (presumably shape (2, 3); TODO confirm against callers). A key point is
    correct when its Euclidean error is ``<= threshold`` pixels (default 12.8,
    presumably 5% of a 256-px image).

    Returns:
        correct / (3 * len(true_HR_list)); the denominator assumes three key
        points per sample.

    Raises:
        ValueError: if ``true_HR_list`` is empty (the ratio is undefined).
    """
    if len(true_HR_list) == 0:
        raise ValueError('PCK needs at least one sample')
    correct = 0
    for i in range(len(true_HR_list)):
        true_HR = true_HR_list[i]
        pred_HR = pred_HR_list[i]
        dis = ((true_HR[0] - pred_HR[0]) ** 2 + (true_HR[1] - pred_HR[1]) ** 2) ** (1 / 2)
        correct += sum(1 for d in dis if d <= threshold)
    # Normalisation hoisted out of the loop (it was recomputed every iteration).
    return correct / (len(true_HR_list) * 3)

In [18]:
def random_click(mask_list, point_num=4, class_id=1):
    """Sample random foreground point prompts from each mask.

    Args:
        mask_list: iterable of 2-D label masks.
        point_num: number of points drawn (with replacement) per mask.
        class_id: mask value treated as foreground; also used as the label.

    Returns:
        (points, labels): ``points`` of shape (B, point_num, 2) in (x, y) order and
        ``labels`` of shape (B, point_num) filled with ``class_id``.

    Raises:
        ValueError: from ``np.random.randint`` when a mask has no ``class_id`` pixel.
    """
    pt_list = []
    pt_label_list = []
    for mask in mask_list:
        indices = np.argwhere(mask == class_id)
        indices[:, [0, 1]] = indices[:, [1, 0]]  # (row, col) -> (x, y)
        pt = indices[np.random.randint(len(indices), size=point_num)]
        pt_list.append(pt)
        # One label per sampled point (previously hard-coded to 4 labels).
        pt_label_list.append([class_id] * point_num)
    return np.array(pt_list), np.array(pt_label_list)

In [19]:
def pos_neg_clicks(mask_list, class_id=1, pos_prompt_number=2, neg_prompt_number=2):
    """Sample positive (== ``class_id``) and negative (!= ``class_id``) point prompts per mask.

    Points are drawn with replacement via ``np.random.randint`` (positives first,
    then negatives) and returned in (x, y) order.

    Returns:
        (points, labels): for 2-D masks, ``points`` of shape (B, pos + neg, 2) and
        ``labels`` of shape (B, pos + neg) with 1s for positives followed by 0s.
    """
    def draw(region, count):
        # Coordinates of the True pixels, swapped from (row, col) to (x, y).
        coords = np.argwhere(region)
        coords[:, [0, 1]] = coords[:, [1, 0]]
        return coords[np.random.randint(len(coords), size=count)]

    points = []
    labels = []
    for mask in mask_list:
        positives = draw(mask == class_id, pos_prompt_number)
        negatives = draw(mask != class_id, neg_prompt_number)
        points.append(np.vstack((positives, negatives)))
        labels.append(np.hstack((np.repeat(1, pos_prompt_number), np.repeat(0, neg_prompt_number))))
    return np.array(points), np.array(labels)

In [20]:
def anatomic_click(mask_list, key_point_list, class_id=1):
    """Turn anatomical key points into point prompts that lie on the foreground.

    Args:
        mask_list: sequence of 2-D label masks.
        key_point_list: sequence (same length) of (K, 2) key-point arrays in
            (row, col) order; they are copied, never modified.
        class_id: mask value treated as foreground; also used as the prompt label.

    Returns:
        (points, labels): ``points`` of shape (B, K, 2) in (x, y) order and
        ``labels`` of shape (B, K) filled with ``class_id``. A key point that is
        not itself a foreground pixel is replaced by the foreground pixel with the
        smallest L1 distance; masks without foreground keep the key points as-is.
    """
    assert len(key_point_list) == len(mask_list)
    pt_list = []
    pt_label_list = []
    for b_i in range(len(key_point_list)):
        key_points = np.array(key_point_list[b_i])  # copy
        key_points[:, [0, 1]] = key_points[:, [1, 0]]  # (row, col) -> (x, y)
        indices = np.argwhere(mask_list[b_i] == class_id)
        indices[:, [0, 1]] = indices[:, [1, 0]]
        pt = list(key_points)
        if len(indices) > 0:
            for k, point in enumerate(pt):
                # Row-wise membership; `point in indices` matched any single coordinate.
                if not np.all(indices == point, axis=1).any():
                    nearest = np.argmin(np.sum(np.abs(indices - point), axis=1))
                    # Store into the list (rebinding the loop variable had no effect).
                    pt[k] = indices[nearest].astype(key_points.dtype)
        pt_list.append(pt)
        pt_label_list.append([class_id] * len(pt))
    return np.array(pt_list), np.array(pt_label_list)

In [21]:
def fixed_bbox(mask_list, class_id = 1, img_size=256):
    """Tight ``[x_min, y_min, x_max, y_max]`` box around the ``class_id`` pixels of each mask.

    Returns:
        int array of shape (B, 4). A mask without any ``class_id`` pixel gets the
        full-image fallback box ``[-1, -1, img_size, img_size]``.
    """
    bbox_list = []
    for mask in mask_list:
        indices = np.argwhere(mask == class_id)  # columns: (row, col) = (y, x)
        if indices.shape[0] == 0:
            # Previously this returned one 1-D box for the whole batch and
            # dropped every other mask; now only this mask gets the fallback.
            bbox_list.append(np.array([-1, -1, img_size, img_size]))
            continue
        ys, xs = indices[:, 0], indices[:, 1]
        bbox_list.append(np.array([np.min(xs), np.min(ys), np.max(xs), np.max(ys)]))
    return np.array(bbox_list)

In [22]:
def _ev_two_channel_seg(model_out):
    """Sigmoid of ``model_out['seg_masks']`` stacked as [1 - p, p] along dim 1 (Dice layout)."""
    prob = model_out['seg_masks'].sigmoid()
    return torch.concat([(1 - prob), prob], dim=1)


def _ev_flat_heatmap(model_out):
    """Sigmoid of ``model_out['hr_low_res']`` flattened to (batch, -1) like the heat-map targets."""
    hr = model_out['hr_low_res'].sigmoid()
    return hr.reshape(hr.shape[0], -1)


def _ev_se_alignment_loss(model_out, cosine_loss_func):
    """Sum of cosine-embedding losses pulling '<x>_se_apg' toward '<x>_se' for x in (seg, hr)."""
    total = 0
    for prefix in ('seg', 'hr'):
        se = model_out[prefix + '_se']
        B, C, N = se.shape
        se_apg = model_out[prefix + '_se_apg'].reshape(B * C, N)
        total = total + cosine_loss_func(se_apg, se.reshape(B * C, N), torch.ones(B * C).to(device))
    return total


def _ev_batch_losses(model, batch, aug_flag, mask_r, fuse_apg, mse_loss_func, cosine_loss_func):
    """Run the ED and ES forward passes for one loader batch.

    Returns:
        (seg_loss, hr_loss, se_loss) tensors, each summed over the ED and ES passes.
    """
    # Only the first and last frame of each stored sequence are used.
    sequence = batch['sequence'].numpy()[:, :, [0, -1], :, :]
    (ED_gt, ES_gt, ED_heatmap, ES_heatmap,
     ED_point, ES_point, sequence) = geo_aug_fin(
        aug_flag, batch['ED_LV_gt'].numpy(), batch['ES_LV_gt'].numpy(),
        batch['ED_key_point'].numpy(), batch['ES_key_point'].numpy(), sequence, mask_r)
    ED_gt = np.squeeze(ED_gt)
    ES_gt = np.squeeze(ES_gt)
    ED_point = np.squeeze(ED_point)
    ES_point = np.squeeze(ES_point)

    _, ED_dice_mask = chMask(ED_gt)
    _, ES_dice_mask = chMask(ES_gt)
    ED_heatmap = torch.tensor(np.squeeze(ED_heatmap), dtype=torch.float32).to(device)
    ES_heatmap = torch.tensor(np.squeeze(ES_heatmap), dtype=torch.float32).to(device)

    # With batch_size=1 the squeeze(0) leaves the two views as the model's batch dim.
    ED_image = torch.tensor(sequence[:, :, 0, :, :], dtype=torch.float32).squeeze(0).unsqueeze(1).to(device)
    ES_image = torch.tensor(sequence[:, :, -1, :, :], dtype=torch.float32).squeeze(0).unsqueeze(1).to(device)

    ED_pt, ED_pt_label = anatomic_click(ED_gt, ED_point, class_id=1)
    ES_pt, ES_pt_label = anatomic_click(ES_gt, ES_point, class_id=1)
    ED_bbox = fixed_bbox(ED_gt, class_id=1)
    ES_bbox = fixed_bbox(ES_gt, class_id=1)  # was computed from the ED mask

    ED_out = model(ED_image, (torch.tensor(ED_pt).to(device), torch.tensor(ED_pt_label).to(device)),
                   torch.tensor(ED_bbox).to(device), fuse_apg=fuse_apg)
    ES_out = model(ES_image, (torch.tensor(ES_pt).to(device), torch.tensor(ES_pt_label).to(device)),
                   torch.tensor(ES_bbox).to(device), fuse_apg=fuse_apg)

    seg_loss = (dice_loss(_ev_two_channel_seg(ED_out), ED_dice_mask)
                + dice_loss(_ev_two_channel_seg(ES_out), ES_dice_mask))
    hr_loss = (mse_loss_func(_ev_flat_heatmap(ED_out), ED_heatmap)
               + mse_loss_func(_ev_flat_heatmap(ES_out), ES_heatmap))
    se_loss = (_ev_se_alignment_loss(ED_out, cosine_loss_func)
               + _ev_se_alignment_loss(ES_out, cosine_loss_func))
    return seg_loss, hr_loss, se_loss


def train_EV_model_fin(model, NFrame, train_dl, val_dl, epochs, epochs_wm, opt, scheduler, sgpa_flag):
    """Train the ED/ES segmentation + key-point model and keep the best-validation weights.

    Per epoch:
      * ``gamma`` (weight of the embedding-alignment loss) decays linearly from 100
        to 0 over the first ``epochs_wm`` epochs when ``sgpa_flag == True``; else 0.
      * ``fuse_apg`` ramps linearly from 0 to 1 over epochs 0..``epochs_wm``.
      * The heat-map Gaussian radius ``mask_r`` is 20 for ``epoch <= epochs_wm``,
        then shrinks by 0.25 per epoch while it is above 10.
      * Training batches are augmented (``geo_aug_fin(True, ...)``); validation
        batches are not, and are evaluated without gradient tracking.

    Total loss = 100 * Dice + 2000 * heat-map MSE + gamma * cosine alignment.

    Args:
        model: called as ``model(image, (points, labels), bbox, fuse_apg=...)`` and
            returning a dict with 'seg_masks', 'hr_low_res', 'seg_se', 'hr_se',
            'seg_se_apg' and 'hr_se_apg'.
        NFrame: unused; kept for call compatibility.
        train_dl, val_dl: loaders yielding dicts with 'ED_LV_gt', 'ES_LV_gt',
            'ED_key_point', 'ES_key_point' and 'sequence' tensors.
        epochs, epochs_wm: total and warm-up epoch counts (``epochs_wm`` must be > 0).
        opt, scheduler: optimizer and LR scheduler (scheduler stepped once per epoch).
        sgpa_flag: enables the alignment loss.

    Returns:
        (model, history): ``model`` with the best-validation weights loaded, and one
        dict per epoch with its mean training / validation loss.
    """
    MSE_loss_func = nn.MSELoss(reduction='mean')
    cosine_loss_func = nn.CosineEmbeddingLoss()
    alpha = 100   # Dice loss weight
    beta = 2000   # heat-map MSE weight
    history = []
    best_epoch = 0
    best_model_params = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    mask_r = 20
    for epoch in range(epochs):
        gamma_flag = np.clip(1.0 - 1.0 * epoch / epochs_wm, 0, 1)
        gamma = 100 * gamma_flag if sgpa_flag == True else 0

        if epoch <= epochs_wm:
            temp_fuse_apg = 1.0 * epoch / epochs_wm
            mask_r = 20
        else:
            temp_fuse_apg = 1.0
            if mask_r > 10:
                mask_r = mask_r - 0.25

        train_stats = []
        for train_dl_i in tqdm(train_dl):
            seg_loss, hr_loss, se_loss = _ev_batch_losses(
                model, train_dl_i, True, mask_r, temp_fuse_apg, MSE_loss_func, cosine_loss_func)
            loss = alpha * seg_loss + beta * hr_loss + gamma * se_loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_stats.append([loss.item(), alpha * seg_loss.item(), beta * hr_loss.item(), gamma * se_loss.item()])

        # NOTE(review): the model is never switched to eval(), so dropout / batch-norm
        # layers (if any) stay in training mode during validation; confirm intended.
        val_stats = []
        with torch.no_grad():  # validation losses are only logged; no graph needed
            for val_dl_i in tqdm(val_dl):
                seg_loss, hr_loss, se_loss = _ev_batch_losses(
                    model, val_dl_i, False, mask_r, temp_fuse_apg, MSE_loss_func, cosine_loss_func)
                seg_v, hr_v, se_v = seg_loss.item(), hr_loss.item(), se_loss.item()
                val_stats.append([alpha * seg_v + beta * hr_v + gamma * se_v,
                                  alpha * seg_v, beta * hr_v, gamma * se_v])

        scheduler.step()

        train_arr = np.array(train_stats)
        val_loss = np.mean(np.array(val_stats)[:, 0])
        print('Epoch: ', epoch, 'Train Loss:', '%.4f' % np.mean(train_arr[:, 0]), '%.4f' % np.mean(train_arr[:, 1]),
              '%.4f' % np.mean(train_arr[:, 2]), '%.4f' % np.mean(train_arr[:, 3]), 'Valid Loss:', '%.4f' % val_loss)
        history.append({'epoch': epoch + 1,
                        'train_loss': float(np.mean(train_arr[:, 0])),
                        'val_loss': float(val_loss)})
        if val_loss < best_loss:
            best_epoch = epoch + 1
            best_model_params = copy.deepcopy(model.state_dict())
            best_loss = val_loss

    model.load_state_dict(best_model_params)
    print('Best Validation Loss:{:.4f} at Epoch:{}'.format(best_loss, best_epoch))
    return model, history

In [23]:
# EV_model_fin =  get_model(args.modelname, args=args, opt=opt)
# EV_model_fin = EV_model_fin.to(device)
# EV_opt_fin = torch.optim.Adam(EV_model_fin.parameters(), lr=lr)
# scheduler =  torch.optim.lr_scheduler.LambdaLR(EV_opt_fin, lr_lambda=lambda1)
# EV_model_fin,EV_history_fin = train_EV_model_fin(EV_model_fin,NFrame,traindata,valdata,epochs_fin,epochs_wm,EV_opt_fin,scheduler,args.sgpa_flag)

In [24]:
def heatmap_2_key_point(heatmap,temp_seg,ratio):
    """Locate 3 keypoints as the masked argmax of each heatmap channel.

    Each channel ``heatmap[idx]`` (idx = 0, 1, 2) is resized to the spatial
    size of ``temp_seg``, multiplied by ``temp_seg`` so peaks outside the mask
    are suppressed, and its argmax is converted to (row, col) coordinates.

    Args:
        heatmap: array indexable as heatmap[0..2], each a 2-D map.
        temp_seg: 2-D mask whose (rows, cols) define the output grid.
        ratio: unused; kept for interface compatibility.

    Returns:
        (3, 2) float array, one (row, col) pair per keypoint.
    """
    rows, cols = temp_seg.shape[0], temp_seg.shape[1]
    raw_key_point = np.zeros((3, 2))

    for idx in range(3):
        # cv2.resize expects dsize as (width, height) == (cols, rows); passing
        # (rows, cols) transposes the grid and breaks the product below for
        # non-square masks.
        img = cv2.resize(heatmap[idx], (cols, rows))
        img = img * temp_seg
        raw_key_point[idx] = np.unravel_index(img.argmax(), img.shape)
    return raw_key_point

In [25]:
def HR_output_2_key_point(key_point_true,HR_output,temp_seg,ratio=2):
    """Decode flattened 3-channel heatmaps into keypoints and score them.

    Args:
        key_point_true: per-sample ground truth; ``key_point_true[i]`` is
            subtracted from the (3, 2) prediction for sample i.
        HR_output: 2-D array with one row per sample; each row holds 3 square
            heatmaps flattened back to back (row length 3 * side**2).
        temp_seg: per-sample segmentations, forwarded to heatmap_2_key_point.
        ratio: forwarded unchanged to heatmap_2_key_point.

    Returns:
        (dis_array, kp_array): per-keypoint Euclidean distances and the
        predicted keypoints, stacked over samples (shapes (N, 3) and
        (N, 3, 2) assuming heatmap_2_key_point returns (3, 2) arrays).
    """
    side = round(np.sqrt(HR_output.shape[1] // 3))
    distances = []
    predictions = []
    for sample, flat_maps in enumerate(HR_output):
        predicted = heatmap_2_key_point(np.reshape(flat_maps, (3, side, side)), temp_seg[sample], ratio)
        offset = predicted - key_point_true[sample]
        distances.append(np.sqrt((offset ** 2).sum(axis=1)))
        predictions.append(predicted)
    return np.array(distances), np.array(predictions)

In [26]:
def close_dropout(m):
    """Switch a dropout module to eval mode; leave every other module untouched.

    Intended for ``model.apply(close_dropout)`` so dropout is disabled at test
    time without calling ``model.eval()`` on the whole network.

    ``isinstance`` is used instead of an exact ``type(m) == nn.Dropout`` check
    so every dropout variant (Dropout1d/2d/3d, AlphaDropout,
    FeatureAlphaDropout) and user subclasses of them are covered as well.
    """
    dropout_types = tuple(
        getattr(nn, name)
        for name in ('Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
                     'AlphaDropout', 'FeatureAlphaDropout')
        if hasattr(nn, name)  # Dropout1d only exists in newer torch releases
    )
    if isinstance(m, dropout_types):
        m.eval()

In [27]:
def heatmap_2_key_point(heatmap,temp_seg,ratio):
    """Locate 3 keypoints as the argmax of each (resized) heatmap channel.

    Unlike a masked variant, ``temp_seg`` is used only for its spatial size:
    each channel ``heatmap[idx]`` (idx = 0, 1, 2) is resized to that size and
    its argmax is converted to (row, col) coordinates.

    Args:
        heatmap: array indexable as heatmap[0..2], each a 2-D map.
        temp_seg: 2-D array whose (rows, cols) define the output grid.
        ratio: unused; kept for interface compatibility.

    Returns:
        (3, 2) float array, one (row, col) pair per keypoint.
    """
    rows, cols = temp_seg.shape[0], temp_seg.shape[1]
    raw_key_point = np.zeros((3, 2))

    for idx in range(3):
        # cv2.resize expects dsize as (width, height) == (cols, rows); passing
        # (rows, cols) would return coordinates on a transposed grid for
        # non-square masks.
        img = cv2.resize(heatmap[idx], (cols, rows))
        raw_key_point[idx] = np.unravel_index(img.argmax(), img.shape)

    return raw_key_point

def HR_output_2_key_point(key_point_true,HR_output,temp_seg,ratio=2):
    """Turn each sample's flattened heatmaps into keypoints and their errors.

    Every row of ``HR_output`` is reshaped into 3 square heatmaps, decoded by
    heatmap_2_key_point (with that sample's ``temp_seg`` and ``ratio``), and
    compared with ``key_point_true`` of the same sample by Euclidean distance
    per keypoint.

    Returns:
        (dis_array, kp_array): stacked distances and predicted keypoints, in
        sample order.
    """
    grid = round(np.sqrt(HR_output.shape[1] // 3))

    def _decode(i):
        # Decode sample i and measure its per-keypoint error.
        kp = heatmap_2_key_point(np.reshape(HR_output[i], (3, grid, grid)), temp_seg[i], ratio)
        err = np.sqrt(np.sum(np.square(kp - key_point_true[i]), axis=1))
        return err, kp

    decoded = [_decode(i) for i in range(HR_output.shape[0])]
    dis_array = np.array([err for err, _ in decoded])
    kp_array = np.array([kp for _, kp in decoded])
    return dis_array, kp_array

In [28]:
def _keep_largest_components(image):
    """Zero every connected component of each positive label except the largest.

    Modifies ``image`` in place and returns it. Components come from
    ``scipy.ndimage.label`` with its default structuring element; when several
    components tie for the largest size, all of them are kept.
    """
    for class_value in np.unique(image):
        if not class_value > 0:  # background (and non-positive values) is never touched
            continue
        lmap, num_objects = label((image == class_value).astype(int))
        if num_objects == 0:
            continue
        sizes = np.bincount(lmap.ravel())[1:]  # sizes[k] = pixel count of component k + 1
        smaller_ids = np.flatnonzero(sizes != sizes.max()) + 1
        if smaller_ids.size:
            # lmap is non-zero only inside this class's mask, so only its pixels are cleared.
            image[np.isin(lmap, smaller_ids)] = 0
    return image


def _fill_holes_inverted(image):
    """Flood-fill the background from its first zero pixel and return the uint8 NOT.

    Output encoding (same as the previous cv2.floodFill + cv2.bitwise_not code):
        0       -- zero pixels 4-connected to the seed (outer background),
        255     -- other zero pixels (holes),
        255 - v -- any other pixel value v (e.g. 254 for foreground 1).
    """
    as_int = image.astype(int)
    canvas = as_int.astype(np.uint8)
    zero_pixels = np.argwhere(as_int == 0)  # row-major order, like a row-by-row scan
    if zero_pixels.size:
        row, col = zero_pixels[0]
        # Equivalent to cv2.floodFill(canvas, mask, (col, row), 255) with zero tolerance and
        # 4-connectivity. Note OpenCV seeds are (x=col, y=row); the old code passed (row, col).
        zero_regions, _ = label(canvas == 0)
        canvas[zero_regions == zero_regions[row, col]] = 255
    return 255 - canvas  # identical to cv2.bitwise_not for uint8


def post_processing(batch_im):
    """Clean a batch of label masks: keep the largest blob per label, then fill holes.

    Args:
        batch_im: array of shape (N, H, W). A single (H, W) image (e.g. a batch
            of one after ``.squeeze()``) is also accepted and returned as (H, W).

    Returns:
        Array with the same shape and dtype as ``batch_im``: 0 for background
        connected to the first zero pixel, 255 for enclosed holes, and 255 - v
        for pixels of label v (254 for a binary foreground). The input array
        is not modified.
    """
    batch_im = np.asarray(batch_im)
    if batch_im.ndim == 2:
        return post_processing(batch_im[np.newaxis])[0]

    batch_im_new = np.zeros_like(batch_im)
    for image_i in range(batch_im.shape[0]):
        # Copy so the caller's array is left untouched.
        image = _keep_largest_components(batch_im[image_i].copy())
        batch_im_new[image_i] = _fill_holes_inverted(image)
    return batch_im_new

In [29]:
def _eval_phase_fin(model, frames, lv_gt, key_point):
    """Evaluate one cardiac phase (ED or ES) of a single test sample.

    The point/label and bbox prompts are derived from this phase's own ground
    truth, so the prompt always matches the image it is fed with.

    Args:
        model: network called as ``model(image, (points, labels), bbox, fuse_apg=1.0)``
            and returning a dict with 'seg_masks' and 'hr_low_res'.
        frames: numpy array of this phase's frames, indexed ``[:, H, W]``; a
            channel axis is inserted at dim 1 before the forward pass.
        lv_gt: ground-truth LV mask for this phase.
        key_point: ground-truth keypoints for this phase.

    Returns:
        (dice, key_point_dis): Dice between the post-processed prediction and
        the ground truth, and the keypoint distances from HR_output_2_key_point.
    """
    _, gt_dice_mask = chMask(lv_gt)

    image = torch.tensor(frames, dtype=torch.float32).unsqueeze(1).to(device)
    pt, pt_label = anatomic_click(lv_gt, key_point, class_id=1)
    bbox = fixed_bbox(lv_gt, class_id=1)
    pt = torch.tensor(pt).to(device)
    pt_label = torch.tensor(pt_label).to(device)
    bbox = torch.tensor(bbox).to(device)

    model_out = model(image, (pt, pt_label), bbox, fuse_apg=1.0)
    seg_output = model_out['seg_masks'].sigmoid() > 0.5
    hr_output = model_out['hr_low_res'].sigmoid()
    hr_output = hr_output.reshape(hr_output.shape[0], -1)

    seg_np = post_processing(seg_output.squeeze().cpu().numpy().astype(float))
    _, seg_dice_mask = chMask(seg_np)
    dice = dice_coeff(seg_dice_mask, gt_dice_mask)

    key_point_dis, _ = HR_output_2_key_point(key_point, hr_output.squeeze().cpu().numpy(),
                                             ratio=2, temp_seg=seg_np)
    return dice, key_point_dis


def test_EV_model_fin(model,test_dl,pck_thresh=12.8):
    """Evaluate the fine-tuned model on a test loader: mean LV Dice and keypoint PCK.

    For each sample, the first and last frames of ``sequence`` (along dim 1,
    after squeezing dim 0) are evaluated as ED and ES respectively.

    Args:
        model: network to evaluate; only its dropout modules are switched to
            eval mode (via close_dropout).
        test_dl: iterable of dicts with keys 'ED_LV_gt', 'ES_LV_gt',
            'ED_key_point', 'ES_key_point' and 'sequence', each a tensor whose
            dim 0 is squeezed away (presumably a loader batch size of 1).
        pck_thresh: a predicted keypoint counts as correct when its distance to
            the ground truth is strictly below this value (default 12.8).

    Returns:
        (df_Seg_para, df_HR_para): one-row DataFrames holding the mean Dice
        (column 'Dice') and PCK (column 'PCK'), each averaged over ED and ES.
    """
    seg_ed_dice, seg_es_dice = [], []
    hr_ed_dis, hr_es_dis = [], []
    # NOTE(review): only dropout is put in eval mode; any BatchNorm layers stay
    # in whatever mode the model is currently in. Confirm this is intended
    # (model.eval() would switch every layer).
    model.apply(close_dropout)
    with torch.no_grad():
        for sample in tqdm(test_dl):
            ed_lv_gt = sample['ED_LV_gt'].squeeze(0).numpy()
            es_lv_gt = sample['ES_LV_gt'].squeeze(0).numpy()
            ed_key_point = sample['ED_key_point'].squeeze(0).numpy()
            es_key_point = sample['ES_key_point'].squeeze(0).numpy()
            sequence = sample['sequence'].squeeze(0).numpy()[:, [0, -1], :, :]

            ed_dice, ed_dis = _eval_phase_fin(model, sequence[:, 0, :, :], ed_lv_gt, ed_key_point)
            es_dice, es_dis = _eval_phase_fin(model, sequence[:, -1, :, :], es_lv_gt, es_key_point)

            seg_ed_dice.append([ed_dice])
            seg_es_dice.append([es_dice])
            hr_ed_dis.append(ed_dis)
            hr_es_dis.append(es_dis)

    mean_dice = np.mean((np.mean(np.array(seg_ed_dice)[:, 0]), np.mean(np.array(seg_es_dice)[:, 0])))
    pck = np.mean((np.mean(np.array(hr_ed_dis) < pck_thresh), np.mean(np.array(hr_es_dis) < pck_thresh)))
    df_Seg_para = pd.DataFrame([[mean_dice]], columns=['Dice'])
    df_HR_para = pd.DataFrame([[pck]], columns=['PCK'])
    return df_Seg_para, df_HR_para

In [30]:
# df_Seg_para,df_HR_para = test_EV_model_fin(EV_model_fin,testdata)