In [1]:
import sys
import torch

sys.path.append('./CMR-AI/mmaction/')
sys.path.append('./CMR-AI/')

from abc import ABCMeta, abstractmethod
from swinTransformer3D_origin import SwinTransformer3D
from mutil_class_loss import FocalLoss, cal_auc
from weighted_auc_f1 import get_weighted_auc_f1
from load_dataset import ACDC

from sklearn.model_selection import train_test_split
import os
from PIL import Image
import torch
from torchvision import transforms
import pandas as pd
from skimage import transform
import numpy as np
from torch import nn
import SimpleITK as sitk
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import nibabel as nib
from sklearn.metrics import confusion_matrix

  check_for_updates()


In [2]:
# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Number of output classes; matches the four labels kept in dummy_labels.
num_class = 4
funsion_model = SwinTransformer3D(num_class=num_class)

# 'train' runs the training loop starting from the pretrained checkpoint;
# any other value loads the trained checkpoint and only evaluates.
phase = 'test'

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
def _strip_prefix(key, prefixes):
    """Return `key` with the first matching leading prefix from `prefixes` removed.

    Only a prefix at position 0 is stripped; str.replace would also delete
    matches occurring in the middle of a key.
    """
    for prefix in prefixes:
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


# Pick the checkpoint and the key prefixes to strip for this phase.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict then copies the tensors onto the model's own device.
if phase == 'train':
    # Pretrained weights: tensors live under 'state_dict' with keys
    # prefixed by 'backbone.' / 'cls_head.'.
    # _pretrained_dict = torch.load(r'../pretrained model/swin_base_patch244_window877_kinetics600_22k.pth')
    _pretrained_dict = torch.load(r'./ACDC/organmnist3d/organmnist3d_250.pth', map_location='cpu')
    raw_state_dict = _pretrained_dict['state_dict']
    key_prefixes = ('backbone.', 'cls_head.')
else:
    # Checkpoint written by the training loop (funsion_model.state_dict()); its keys
    # carry a 'module.' prefix when the model was wrapped in nn.DataParallel.
    _pretrained_dict = torch.load(r'./train_model/VST/epoch_250.pth', map_location='cpu')
    raw_state_dict = _pretrained_dict
    key_prefixes = ('module.', 'cls_head.')

remapped_dict = {_strip_prefix(k, key_prefixes): v for k, v in raw_state_dict.items()}

# Keep only tensors whose name AND shape match the model; everything else is dropped.
model_dict = funsion_model.state_dict()
pretrained_dict = {k: v for k, v in remapped_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}

# Report what was dropped instead of discarding it silently.
skipped_keys = sorted(set(remapped_dict) - set(pretrained_dict))
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint tensors without a name/shape match: {skipped_keys[:10]}")

# strict=False tolerates model parameters absent from pretrained_dict; the returned
# report (shown as the cell output) lists any missing / unexpected keys.
funsion_model.load_state_dict(pretrained_dict, strict=False)

<All keys matched successfully>

In [4]:
# Load the train/test index tables and keep only rows whose label is one of the
# classes used in this experiment.
train_data = pd.read_csv('./roi_processed/training/train_data.csv', encoding='GBK')
test_data = pd.read_csv('./roi_processed/testing/test_data.csv', encoding='GBK')
dummy_labels = ['HCM', 'RV', 'DCM', 'MINF']  # ['HCM', 'DCM', 'MINF', 'RV', 'NOR'] # taken from paper
# The model head is built with num_class outputs; both must agree.
assert len(dummy_labels) == num_class, "dummy_labels must have num_class entries"

# Explicit .copy() so the one-hot columns added in later cells are written to an
# independent frame, not to a slice of the unfiltered one (avoids SettingWithCopyWarning).
train_data = train_data[train_data['Finding Labels'].isin(dummy_labels)].copy()
test_data = test_data[test_data['Finding Labels'].isin(dummy_labels)].copy()

In [5]:
# One-hot encode 'Finding Labels' for the training split: one column per entry of
# dummy_labels, 1.0 when the label string occurs in the finding (substring test), else 0.
for label_name in dummy_labels:
    train_data[label_name] = train_data['Finding Labels'].map(
        lambda finding, target=label_name: 1.0 if target in finding else 0
    )

In [6]:
# One-hot encode 'Finding Labels' for the test split: one column per entry of
# dummy_labels, 1.0 when the label string occurs in the finding (substring test), else 0.
for label_name in dummy_labels:
    test_data[label_name] = test_data['Finding Labels'].map(
        lambda finding, target=label_name: 1.0 if target in finding else 0
    )

In [7]:
def _row_target_vectors(frame):
    """Return a Series holding, per row of `frame`, the array of its dummy_labels values."""
    # Each row array is wrapped in a one-element list and unwrapped by .map afterwards —
    # presumably so apply(axis=1) yields a Series of arrays rather than expanding them.
    wrapped = frame.apply(lambda row: [row[dummy_labels].values], axis=1)
    return wrapped.map(lambda cell: cell[0])


train_data['target_vector'] = _row_target_vectors(train_data)

test_data['target_vector'] = _row_target_vectors(test_data)

In [8]:
# Number of training samples per class (summed one-hot columns), largest first.
clean_labels = (
    train_data[dummy_labels]
    .sum()
    .sort_values(ascending=False)
)
print('train size：', clean_labels, sep='\n')

train size：
HCM     20.0
RV      20.0
DCM     20.0
MINF    20.0
dtype: float64


In [9]:
print('test size：')
# Number of test samples per class (summed one-hot columns), largest first.
test_class_counts = test_data[dummy_labels].sum()
clean_labels = test_class_counts.sort_values(ascending=False)
print(clean_labels)

test size：
HCM     10.0
RV      10.0
DCM     10.0
MINF    10.0
dtype: float64


## 训练开始 (Training)

In [10]:
# Optimisation hyper-parameters.
base_lr = 0.0005
batch_size = 8
max_epoch = 600
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Wrap only once, so re-running this cell does not nest DataParallel(DataParallel(...)).
if torch.cuda.device_count() > 1 and not isinstance(funsion_model, nn.DataParallel):
    print(f"Using {torch.cuda.device_count()} GPUs!")
    funsion_model = nn.DataParallel(funsion_model)

# .to(device) instead of .cuda() so the notebook still runs on a CPU-only machine.
funsion_model = funsion_model.to(device)
fn_loss = FocalLoss(device=device, gamma=2.).to(device)
cross_loss = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(funsion_model.parameters(), lr=base_lr)

train_acdc_data = ACDC(data=train_data, phase='train', img_size=(224, 224))
train_data_loader = DataLoader(train_acdc_data, batch_size=batch_size, shuffle=True, num_workers=5)
test_acdc_data = ACDC(data=test_data, phase='test', img_size=(224, 224))
# Evaluation does not need shuffling; a fixed order keeps per-sample outputs aligned across runs.
test_data_loader = DataLoader(test_acdc_data, batch_size=batch_size, shuffle=False, num_workers=5)

In [11]:
# TensorBoard writer; the training loop logs per-epoch scalars into this directory.
from torch.utils.tensorboard import SummaryWriter

TB_LOG_DIR = './runs/VST-ACDC'
writer = SummaryWriter(log_dir=TB_LOG_DIR)

### Model training

In [12]:
if phase == 'train':
    import time

    # Checkpoint / evaluation cadence and LR schedule for the loop below.
    save_dir = './train_model/VST'
    save_interval = 25
    test_interval = 25  # int(max_epoch/6)
    lr_decay_rate = 0.0009  # exponential decay: lr = base_lr * (1 - rate) ** (epoch_num + 1)
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories

    for epoch_num in range(0, max_epoch):
        print(f"--------> epoch_num: {epoch_num}")
        train_loader_nums = len(train_data_loader.dataset)
        # Per-sample model outputs and one-hot targets, collected for the epoch-level AUC.
        probs = np.zeros((train_loader_nums, num_class), dtype=np.float32)
        gt = np.zeros((train_loader_nums, num_class), dtype=np.float32)
        k = 0  # samples seen so far this epoch
        n_batches = 0
        start_time = time.time()
        total_train_loss = 0.0
        correct = 0.0
        funsion_model.train()
        for train_data_batch, _, train_labels_batch in train_data_loader:
            train_data_batch = train_data_batch.to(device)
            train_labels_batch = train_labels_batch.to(device)
            outputs, _ = funsion_model(train_data_batch)
            outputs = outputs.reshape(outputs.shape[0], -1)
            train_labels_batch = train_labels_batch.reshape(train_labels_batch.shape[0], -1)

            # NOTE(review): these are the raw model outputs (no softmax is applied here),
            # so cal_auc ranks by those scores — confirm this is what cal_auc expects.
            probs[k: k + outputs.shape[0], :] = outputs.detach().cpu().numpy()
            gt[k: k + outputs.shape[0], :] = train_labels_batch.detach().cpu().numpy()
            k += outputs.shape[0]

            train_loss = cross_loss(outputs, train_labels_batch)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            # .item() yields a Python float; accumulating the tensor itself would keep
            # every batch's autograd graph alive until the end of the epoch.
            total_train_loss += train_loss.item()
            n_batches += 1

            predicted = torch.argmax(outputs, 1)
            labels = torch.argmax(train_labels_batch, 1)
            correct += (predicted == labels).sum().item()

        avg_train_loss = total_train_loss / max(n_batches, 1)
        train_acc = correct / max(k, 1)
        auc = cal_auc(gt, probs)
        print(f"epoch_num {epoch_num} av train loss {avg_train_loss}  train auc {auc} train acc {train_acc}")

        writer.add_scalars('Training Metrics', {
            'Loss': avg_train_loss,
            'Accuracy': train_acc,
            'AUC': auc,
        }, epoch_num)

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"程序运行时间：{elapsed_time} 秒")

        # Decay relative to the epoch count; a constant base_lr*(1-rate) would never decay
        # past the first epoch. Epoch 0 still gets base_lr*(1-rate).
        lr_ = base_lr * (1 - lr_decay_rate) ** (epoch_num + 1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(save_dir, 'epoch_' + str(epoch_num + 1) + '.pth')
            torch.save(funsion_model.state_dict(), save_mode_path)
            print("save model to {}".format(save_mode_path))

        if (epoch_num + 1) % test_interval == 0:
            test_loader_nums = len(test_data_loader.dataset)
            test_probs = np.zeros((test_loader_nums, num_class), dtype=np.float32)
            test_gt = np.zeros((test_loader_nums, num_class), dtype=np.float32)
            test_k = 0
            funsion_model.eval()  # train() is re-enabled at the start of the next epoch
            with torch.no_grad():
                for test_data_batch, _, test_label_batch in test_data_loader:
                    test_data_batch = test_data_batch.to(device)
                    test_label_batch = test_label_batch.to(device)
                    test_outputs, _ = funsion_model(test_data_batch)
                    test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
                    test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
                    # Store model outputs and targets for metric evaluation.
                    test_probs[test_k: test_k + test_outputs.shape[0], :] = test_outputs.cpu().numpy()
                    test_gt[test_k: test_k + test_outputs.shape[0], :] = test_label_batch.cpu().numpy()
                    test_k += test_outputs.shape[0]
            test_label = np.argmax(test_gt, axis=1)
            test_pred = np.argmax(test_probs, axis=1)
            print(f"auc: {cal_auc(test_gt, test_probs)} | acc: {np.sum(test_label==test_pred)/max(test_k, 1)}")

### Model evaluation

In [13]:
# Evaluate on the test split several times and collect per-run metrics.
# NOTE(review): the recorded outputs differ between runs although the model is in eval
# mode under no_grad — presumably the ACDC test pipeline is stochastic; TODO confirm.
N_EVAL_RUNS = 10

total_acc_list = []           # per run: list of per-class accuracies
total_auroc_list = []         # per run: per-class AUROC list returned by get_weighted_auc_f1
total_weight_auroc_list = []  # per run: weighted AUROC returned by get_weighted_auc_f1
total_weight_acc_list = []    # per run: class-count-weighted accuracy

class_ids = list(range(len(dummy_labels)))

### Model evaluation
for run_idx in range(N_EVAL_RUNS):
    test_loader_nums = len(test_data_loader.dataset)
    test_probs = np.zeros((test_loader_nums, len(dummy_labels)), dtype=np.float32)
    test_gt = np.zeros((test_loader_nums, len(dummy_labels)), dtype=np.float32)
    test_k = 0
    funsion_model.eval()
    with torch.no_grad():
        for test_data_batch, _, test_label_batch in test_data_loader:
            test_data_batch = test_data_batch.to(device)
            test_label_batch = test_label_batch.to(device)
            test_outputs, _ = funsion_model(test_data_batch)
            test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
            test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
            test_probs[test_k: test_k + test_outputs.shape[0], :] = test_outputs.cpu().numpy()
            test_gt[test_k: test_k + test_outputs.shape[0], :] = test_label_batch.cpu().numpy()
            test_k += test_outputs.shape[0]

    test_label = np.argmax(test_gt, axis=1)
    test_pred = np.argmax(test_probs, axis=1)
    weight_auc, auc_list = get_weighted_auc_f1(test_probs, test_pred, test_label)

    # labels= fixes one row/column per class even when a class is absent from both
    # test_label and test_pred (otherwise cm[c][c] could be out of range).
    cm = confusion_matrix(test_label, test_pred, labels=class_ids)
    # Per-class sample counts are the confusion-matrix row sums rather than a hard-coded list.
    class_counts = [int(n) for n in cm.sum(axis=1)]
    total_count = sum(class_counts)
    acc_list = []
    weighted_acc = 0.0
    for cls_idx, count in enumerate(class_counts):
        weight = count / total_count if total_count else 0.0
        acc = float(cm[cls_idx][cls_idx]) / count if count else 0.0
        acc_list.append(acc)
        weighted_acc += weight * acc

    total_auroc_list.append(auc_list)
    total_acc_list.append(acc_list)
    total_weight_auroc_list.append(weight_auc)
    total_weight_acc_list.append(weighted_acc)
    # Running summary after each run, in the format of the recorded cell output.
    print(f"weight auc: {total_weight_auroc_list} ; weight acc : {total_weight_acc_list}")

--------------------------------------------------
auc_list : [0.9400000000000001, 0.9933333333333333, 0.9566666666666668, 0.9366666666666668]
weighted_auroc:  0.9566666666666668
weighted_F1:  0.8519387236206458
weight auc: [0.9566666666666668] ; weight acc : [0.85]
--------------------------------------------------
auc_list : [0.9666666666666668, 0.9833333333333334, 0.9633333333333334, 0.9266666666666667]
weighted_auroc:  0.9600000000000001
weighted_F1:  0.8519387236206458
weight auc: [0.9566666666666668, 0.9600000000000001] ; weight acc : [0.85, 0.85]
--------------------------------------------------
auc_list : [0.9633333333333333, 0.9866666666666667, 0.9666666666666668, 0.9500000000000001]
weighted_auroc:  0.9666666666666668
weighted_F1:  0.853250773993808
weight auc: [0.9566666666666668, 0.9600000000000001, 0.9666666666666668] ; weight acc : [0.85, 0.85, 0.8500000000000001]
--------------------------------------------------
auc_list : [0.9733333333333334, 0.9866666666666667, 0.936

In [14]:
# Per-class AUROC over the repeated evaluation runs: one row per run, one column per class.
auc_arr = np.array(total_auroc_list)
print(auc_arr.shape)
# Iterating the transpose visits one class column at a time; print its mean and std.
for class_auc_values in auc_arr.T:
    print(np.mean(class_auc_values), np.std(class_auc_values))

(10, 4)
0.9629999999999999 0.014640127503998509
0.9893333333333334 0.004163331998932223
0.9603333333333334 0.012151817422372141
0.9436666666666668 0.014564034849968988


In [15]:
# Per-class accuracy over the repeated evaluation runs: one row per run, one column per class.
acc_arr = np.array(total_acc_list)
print(acc_arr.shape)
# Loop over acc_arr's own class axis, so this cell does not depend on auc_arr
# from the previous cell.
for cls_idx in range(acc_arr.shape[-1]):
    acc_arr_cls = acc_arr[:, cls_idx]
    mean = np.mean(acc_arr_cls)
    std = np.std(acc_arr_cls)
    print(mean, std)

(10, 4)
0.9 0.0
0.8799999999999999 0.039999999999999994
0.71 0.030000000000000023
0.93 0.04582575694955839


### end