In [1]:
import sys
import torch

sys.path.append('./unetr_pp/')

from abc import ABCMeta, abstractmethod

from network_architecture.acdc.unetr_pp_acdc_cls import UNETR_PP

from mutil_class_loss import FocalLoss, cal_auc
from weighted_auc_f1 import get_weighted_auc_f1

from sklearn.model_selection import train_test_split
import os
from PIL import Image
import torch
from torchvision import transforms
import pandas as pd
from skimage import transform
import numpy as np
from torch import nn
import SimpleITK as sitk
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt



In [5]:
# Run on the first GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Class names used to filter the CSV rows and to name the one-hot columns.
dummy_labels = ['HCM', 'RV', 'DCM', 'MINF']
num_class = len(dummy_labels)

# UNETR++ (classification variant imported from unetr_pp_acdc_cls).
unter_pp = UNETR_PP(
    in_channels=1,
    out_channels=4,
    dims=[32, 64, 128, 256],
)

# 'train' selects the fine-tuning path; any other value selects the
# checkpoint-loading / evaluation path in the cells that follow.
phase = 'test'

In [6]:
# Initialise `unter_pp` with weights.
#   phase == 'train': warm-start from the ACDC checkpoint, copying every
#       parameter whose name and shape match (the first two parameters
#       are deliberately left at their fresh initialisation).
#   otherwise: load the fine-tuned checkpoint saved by the training loop.
if phase == 'train':
    checkpoint = torch.load("../unetr_plus_plus-main/Acdc_ckpt/model_final_checkpoint.model", map_location="cpu")
    model_dict = unter_pp.state_dict()
    # Keep only tensors that exist in the current model with an identical shape.
    pretrained_dict = {
        k: v for k, v in checkpoint['state_dict'].items()
        if k in model_dict and v.size() == model_dict[k].size()
    }
    for index, (name, param) in enumerate(unter_pp.named_parameters(), start=1):
        # Skip the first two parameters, as the original code did.
        # NOTE(review): presumably the input stem, whose channel count
        # differs from the pretrained model - confirm.
        if index <= 2:
            continue
        if name in pretrained_dict and pretrained_dict[name].size() == param.size():
            param.data.copy_(pretrained_dict[name])
else:
    # map_location keeps the load working on hosts without a GPU;
    # load_state_dict copies the tensors onto the model's own device.
    _pretrained_dict = torch.load('./train_model/UNTERPP/epoch_200.pth', map_location="cpu")
    pretrained_dict = {}
    for k, v in _pretrained_dict.items():
        # Strip a leading wrapper prefix only. str.replace would also remove
        # the same text from the middle of a key and corrupt it.
        if k.startswith('module.'):
            new_key = k[len('module.'):]
        elif k.startswith('cls_head.'):
            new_key = k[len('cls_head.'):]
        else:
            new_key = k

        pretrained_dict[new_key] = v
    model_dict = unter_pp.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
    # strict=False: keys dropped by the filter above keep their current values.
    unter_pp.load_state_dict(pretrained_dict, strict=False)
    # Surface a silent "nothing matched" load instead of evaluating random weights.
    print(f"loaded {len(pretrained_dict)}/{len(model_dict)} tensors from checkpoint")

加载数据

In [7]:
# Read the per-split metadata tables (GBK-encoded CSV files).
train_data = pd.read_csv('./roi_processed/training/train_data.csv', encoding='GBK')
test_data = pd.read_csv('./roi_processed/testing/test_data.csv', encoding='GBK')

# Keep only rows whose 'Finding Labels' value is one of the target classes.
# .copy() makes each filtered frame independent, so adding columns to it
# later does not raise pandas' SettingWithCopyWarning.
train_data = train_data[train_data['Finding Labels'].isin(dummy_labels)].copy()
test_data = test_data[test_data['Finding Labels'].isin(dummy_labels)].copy()

In [8]:
# One-hot encode 'Finding Labels': add one indicator column per entry of
# dummy_labels to the training frame, then to the test frame.
def _contains_label(finding, label):
    """Return 1.0 when `label` occurs inside `finding`, otherwise 0."""
    if label in finding:
        return 1.0
    return 0


for frame in (train_data, test_data):
    for label in dummy_labels:
        # Bind the current label as a default so the callable does not
        # depend on the loop variable.
        frame[label] = frame['Finding Labels'].map(
            lambda finding, label=label: _contains_label(finding, label)
        )

In [9]:
def _boxed_label_values(row):
    """Return the row's class-indicator values wrapped in a one-element list.

    The wrapping list keeps DataFrame.apply from expanding the array into
    separate columns; it is unwrapped again right after.
    """
    return [row[dummy_labels].values]


# Store each row's one-hot vector (in dummy_labels order) as a single array
# in a new 'target_vector' column.
for frame in (train_data, test_data):
    boxed = frame.apply(_boxed_label_values, axis=1)
    frame['target_vector'] = boxed.map(lambda wrapped: wrapped[0])

In [10]:
# Print the number of samples per class for each split, largest first.
for header, frame in (('train size：', train_data), ('test size：', test_data)):
    print(header)
    per_class_totals = frame[dummy_labels].sum()
    clean_labels = per_class_totals.sort_values(ascending=False)
    print(clean_labels)

train size：
HCM     20.0
RV      20.0
DCM     20.0
MINF    20.0
dtype: float64
test size：
HCM     10.0
RV      10.0
DCM     10.0
MINF    10.0
dtype: float64


In [11]:
import torch
import os
from PIL import Image
import torch
from torchvision import transforms
import pandas as pd
from skimage import transform
import numpy as np
from torch import nn
import SimpleITK as sitk
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
import nibabel as nib

import albumentations as A
from albumentations.pytorch import ToTensorV2

class ACDC(torch.utils.data.Dataset):
    """Dataset that loads one NIfTI volume per metadata row.

    Each row of `data` must provide:
      * 'full_path'      - path to the NIfTI file (backslashes are accepted),
      * 'Finding Labels' - the class name, returned unchanged,
      * 'target_vector'  - the one-hot label array.

    Args:
        data: pandas DataFrame with the columns listed above.
        phase: split name; stored on the instance only.
            NOTE(review): the same random blur/rotation pipeline runs for
            every phase, so evaluation samples are augmented too - confirm
            that this is intended.
        img_size: stored on the instance; not used by the code in this class.
    """

    def __init__(self, data=None, phase = 'train', img_size=(224, 224)):
        self.img_size = img_size
        self.phase = phase  # previously accepted but silently discarded
        self.datas = data

        # Pad up to 160x160, then apply a random blur and a random rotation.
        self.seq = iaa.Sequential([
            iaa.PadToFixedSize(width=160, height=160, position="center"),
            iaa.GaussianBlur(sigma=(0.1, 0.25)),
            iaa.Affine(rotate=(-45, 45)),
        ], random_order=False)

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.datas)

    def pad_or_truncate_T(self, video_array, target_frames=25):
        """Force axis 1 of a 4-D array to exactly `target_frames` entries.

        Longer inputs keep their first `target_frames` frames. Shorter
        inputs are extended by cycling through the existing frames from
        the start.

        Args:
            video_array: array of shape (C, T, H, W).
            target_frames: required length of axis 1.

        Returns:
            Array of shape (C, target_frames, H, W); the input object itself
            when T already equals `target_frames`.

        Raises:
            ValueError: if axis 1 is empty and padding is required.
        """
        C, T, H, W = video_array.shape

        if T == target_frames:
            return video_array

        if T > target_frames:
            return video_array[:, :target_frames, :, :]

        if T == 0:
            # Without this guard the ceil division below raises ZeroDivisionError.
            raise ValueError("cannot pad a clip that has no frames")

        # Repeat the clip just often enough to cover the target, then cut.
        repeat_times = -(-target_frames // T)  # ceil(target_frames / T)
        tiled = np.tile(video_array, (1, repeat_times, 1, 1))
        return tiled[:, :target_frames, :, :]

    def __getitem__(self, index):
        """Load, augment and tensorise the sample at position `index`.

        Returns:
            (image_tensor, finding_label, label):
              * image_tensor - float tensor holding channel 1 of the
                augmented volume, with a leading singleton dimension and at
                most 16 entries along the frame axis,
              * finding_label - the row's 'Finding Labels' value,
              * label - the row's 'target_vector' as a float tensor.
        """
        row = self.datas.iloc[index]  # look the row up once
        fpath = row['full_path']
        linux_path = fpath.replace("\\", "/")
        image = nib.load(linux_path)
        image_array = image.get_fdata()
        # Reorder the five axes and keep index 0 of the new leading axis.
        # NOTE(review): assumes the stored volume is 5-D - confirm axis meaning.
        image_array = np.transpose(image_array, (3, 2, 0, -1, 1))[0]

        # Freeze the sampled augmentation before applying it to the batch.
        seq_det = self.seq.to_deterministic()
        image_list = seq_det(images=image_array.astype(np.float32))

        image_array = np.array(image_list)
        image_array = np.transpose(image_array, (3, 0, 2, 1))
        # Diagnostic only: report unexpected shapes but keep going.
        if image_array.shape[0] != 3 or image_array.shape[1] < 5 or image_array.shape[2] != 160 or image_array.shape[3] != 160:
            print(f"problem data : {fpath}")
            print(f'image_array {image_array.shape}')
        image_array = self.pad_or_truncate_T(image_array)

        # Keep the first 16 frames of the (padded) clip.
        image_array = image_array[:, :16, :, :]
        image_tensor = torch.from_numpy(image_array).float()
        label = row['target_vector']
        label = label.astype(np.int64)
        label = torch.from_numpy(label).float()

        # Keep channel 1 only, restoring a singleton channel dimension.
        image_tensor = image_tensor[1].unsqueeze(0)
        return image_tensor, row['Finding Labels'], label

  check_for_updates()


## Model training

In [12]:
# Training hyper-parameters.
base_lr = 0.0005
batch_size = 8
max_epoch = 600
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f'device {device}')
unter_pp = unter_pp.to(device)
# Replicate across all visible GPUs. The original wrapped an undefined name
# (`model`), which raised NameError on multi-GPU hosts. The isinstance check
# keeps the cell safe to re-run without wrapping twice.
if torch.cuda.device_count() > 1 and not isinstance(unter_pp, nn.DataParallel):
    print(f"Using {torch.cuda.device_count()} GPUs!")
    unter_pp = nn.DataParallel(unter_pp)

fn_loss  = FocalLoss(device = device, gamma = 2.).to(device)

optimizer = torch.optim.SGD(unter_pp.parameters(), lr=base_lr)

train_acdc_data = ACDC(data=train_data, phase = 'train', img_size=(224, 224))
train_data_loader = DataLoader(train_acdc_data, batch_size=batch_size, shuffle=True, num_workers=5)
test_acdc_data = ACDC(data=test_data, phase = 'test', img_size=(224, 224))
# Evaluation metrics do not depend on sample order, so shuffling is unnecessary.
test_data_loader = DataLoader(test_acdc_data, batch_size=batch_size, shuffle=False, num_workers=5)

device cuda:0


In [13]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event files for this run are written under this directory.
tensorboard_log_dir = './runs/UNTER++'
writer = SummaryWriter(log_dir=tensorboard_log_dir)

In [14]:
# Fine-tuning loop; runs only when phase == 'train'.
# Every epoch: one pass over the training loader, then TensorBoard logging.
# Every `save_interval` epochs a checkpoint is written, and every
# `test_interval` epochs AUC / accuracy are reported on the test split.
if phase == 'train':
    import time

    save_interval = 25  # epochs between checkpoints
    test_interval = 25  # epochs between evaluations on the test split
    checkpoint_dir = './train_model/UNTERPP'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_loader_nums = len(train_data_loader.dataset)
    test_loader_nums = len(test_data_loader.dataset)

    for epoch_num in range(0, max_epoch):
        print(f"--------> epoch_num: {epoch_num}")
        # Per-sample model outputs and one-hot targets, filled batch by batch.
        probs = np.zeros((train_loader_nums, num_class), dtype = np.float32)
        gt    = np.zeros((train_loader_nums, num_class), dtype = np.float32)
        k = 0
        start_time = time.time()
        total_train_loss = 0.0
        correct = 0.0
        unter_pp.train()
        for train_data_batch, batch_finding, train_labels_batch in train_data_loader:
            train_data_batch = train_data_batch.to(device)
            train_labels_batch = train_labels_batch.to(device)
            outputs = unter_pp(train_data_batch)
            outputs = outputs.reshape(outputs.shape[0], -1)
            train_labels_batch = train_labels_batch.reshape(train_labels_batch.shape[0], -1)
            batch_len = outputs.shape[0]
            probs[k: k + batch_len, :] = outputs.detach().cpu().numpy()
            gt[   k: k + batch_len, :] = train_labels_batch.detach().cpu().numpy()
            k += batch_len

            train_loss = fn_loss(outputs, train_labels_batch)
            # .item() accumulates a plain float; summing the tensor itself
            # would keep every batch's autograd history reachable.
            total_train_loss += train_loss.item()

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            predicted = torch.argmax(outputs, 1)
            labels = torch.argmax(train_labels_batch, 1)
            correct += (predicted == labels).sum().item()

        auc = cal_auc(gt, probs)
        print(f"epoch_num {epoch_num} av train loss {total_train_loss}  train auc {auc} train acc {correct/train_loader_nums}")

        writer.add_scalars('Training Metrics', {
            # NOTE(review): this scaling (sum / n_samples * 20) is kept from
            # the original; its rationale is not visible here.
            'Loss': total_train_loss / train_loader_nums*20,
            'Accuracy': correct / train_loader_nums,
            'AUC': auc,
        }, epoch_num)

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"程序运行时间：{elapsed_time} 秒")

        # NOTE(review): this is a constant, not a decay - the learning rate is
        # base_lr * 0.9991 for every epoch after the first. If a polynomial
        # or exponential schedule was intended it must depend on epoch_num.
        lr_ = base_lr*(1-0.0009)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(checkpoint_dir, 'epoch_' + str(epoch_num+1) + '.pth')
            torch.save(unter_pp.state_dict(), save_mode_path)
            print("save model to {}".format(save_mode_path))

        if (epoch_num + 1) % test_interval == 0:
            test_probs = np.zeros((test_loader_nums, num_class), dtype = np.float32)
            test_gt    = np.zeros((test_loader_nums, num_class), dtype = np.float32)
            test_k = 0
            unter_pp.eval()
            with torch.no_grad():
                for test_data_batch, _, test_label_batch in test_data_loader:
                    # Use the configured device, not a hard-coded "cuda:0",
                    # so evaluation also works on CPU-only hosts.
                    test_data_batch = test_data_batch.to(device)
                    test_label_batch = test_label_batch.to(device)
                    test_outputs = unter_pp(test_data_batch)
                    test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
                    test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
                    test_len = test_outputs.shape[0]
                    test_probs[test_k: test_k + test_len, :] = test_outputs.cpu().numpy()
                    test_gt[   test_k: test_k + test_len, :] = test_label_batch.cpu().numpy()
                    test_k += test_len
            test_label = np.argmax(test_gt, axis=1)
            test_pred = np.argmax(test_probs, axis=1)
            print(f"auc: {cal_auc(test_gt, test_probs)} | acc: {np.sum(test_label==test_pred)/test_loader_nums}")

### Model evaluation

In [15]:
# Evaluate the model on the test split several times and collect, per run,
# the per-class AUROC / accuracy and their sample-count-weighted averages.
total_acc_list = []
total_auroc_list = []

total_weight_auroc_list = []
total_weight_acc_list = []

n_eval_runs = 10
n_classes = len(dummy_labels)
test_loader_nums = len(test_data_loader.dataset)

unter_pp.eval()
for run_idx in range(n_eval_runs):
    test_probs = np.zeros((test_loader_nums, n_classes), dtype = np.float32)
    test_gt    = np.zeros((test_loader_nums, n_classes), dtype = np.float32)
    test_k = 0
    with torch.no_grad():
        for test_data_batch, _, test_label_batch in test_data_loader:
            # Follow the configured device; .cuda() fails on CPU-only hosts.
            test_data_batch = test_data_batch.to(device)
            test_outputs = unter_pp(test_data_batch)
            test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
            test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
            test_len = test_outputs.shape[0]
            test_probs[test_k: test_k + test_len, :] = test_outputs.cpu().numpy()
            test_gt[   test_k: test_k + test_len, :] = test_label_batch.cpu().numpy()
            test_k += test_len

    test_label = np.argmax(test_gt, axis=1)
    test_pred = np.argmax(test_probs, axis=1)
    weight_auc, auc_list = get_weighted_auc_f1(test_probs, test_pred, test_label)

    # Fixing `labels` keeps the matrix n_classes x n_classes even when a class
    # is absent from both labels and predictions, so cm[c][c] always refers
    # to class c.
    cm = confusion_matrix(test_label, test_pred, labels=list(range(n_classes)))
    # Per-class sample counts are the row sums of the confusion matrix
    # (previously hard-coded as [10, 10, 10, 10]).
    dataset_list = cm.sum(axis=1).tolist()
    total_samples = sum(dataset_list)
    acc_list = []
    weighted_acc = 0.0
    for class_idx in range(n_classes):  # distinct name: do not shadow the run index
        class_count = dataset_list[class_idx]
        # A class with no test samples contributes accuracy 0 with weight 0.
        acc = float(cm[class_idx][class_idx]) / class_count if class_count else 0.0
        weight = class_count / total_samples if total_samples else 0.0
        acc_list.append(acc)
        weighted_acc += weight*acc

    total_auroc_list.append(auc_list)
    total_acc_list.append(acc_list)
    total_weight_auroc_list.append(weight_auc)
    total_weight_acc_list.append(weighted_acc)

--------------------------------------------------
auc_list : [0.9533333333333334, 0.8133333333333332, 0.7466666666666666, 0.54]
weighted_auroc:  0.7633333333333333
weighted_F1:  nan
--------------------------------------------------
auc_list : [0.9666666666666667, 0.74, 0.72, 0.5766666666666668]
weighted_auroc:  0.7508333333333334
weighted_F1:  nan
--------------------------------------------------
auc_list : [0.9633333333333334, 0.7566666666666666, 0.7, 0.6266666666666666]
weighted_auroc:  0.7616666666666666
weighted_F1:  0.4041478129713424
--------------------------------------------------
auc_list : [0.9566666666666666, 0.84, 0.7066666666666667, 0.55]
weighted_auroc:  0.7633333333333332
weighted_F1:  0.4257970647773279
--------------------------------------------------
auc_list : [0.9466666666666668, 0.6866666666666666, 0.7133333333333334, 0.5266666666666666]
weighted_auroc:  0.7183333333333333
weighted_F1:  0.3973039215686275
--------------------------------------------------
auc_

In [16]:
# Summarise AUROC over the evaluation runs: one "mean std" line per column
# of the stacked per-run results.
auc_arr = np.array(total_auroc_list)
print(auc_arr.shape)
n_auc_columns = auc_arr.shape[-1]
for column in range(n_auc_columns):
    column_values = auc_arr[:, column]
    print(np.mean(column_values), np.std(column_values))

(10, 4)
0.9550000000000001 0.010027739304327554
0.7656666666666666 0.048489403195154095
0.7183333333333334 0.018929694486000907
0.5606666666666668 0.0347626875319565


In [17]:
# Summarise accuracy over the evaluation runs: one "mean std" line per
# column of the stacked per-run results.
acc_arr = np.array(total_acc_list)
print(acc_arr.shape)
# Iterate over acc_arr's own width. The original used auc_arr.shape, which
# made this cell depend on a variable from the previous cell.
for i in range(acc_arr.shape[-1]):
    acc_arr_cls = acc_arr[:, i]
    mean = np.mean(acc_arr_cls)
    std = np.std(acc_arr_cls)
    print(mean, std)

(10, 4)
0.72 0.0871779788708135
0.05 0.05000000000000001
0.9000000000000001 0.06324555320336757
0.21000000000000002 0.05385164807134503


END