In [1]:
import sys
import torch
sys.path.append('./R3D/models')

from abc import ABCMeta, abstractmethod
from resnet2p1d import ResNet, Bottleneck
from mutil_class_loss import FocalLoss, cal_auc
from weighted_auc_f1 import get_weighted_auc_f1

from sklearn.model_selection import train_test_split
import os
from PIL import Image
import torch
from torchvision import transforms
import pandas as pd
from skimage import transform
import numpy as np
from torch import nn
import SimpleITK as sitk
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from load_dataset import ACDC

  check_for_updates()


In [2]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Target classes; this order defines the order of the one-hot target columns.
dummy_labels = ['HCM', 'RV', 'DCM', 'MINF']
num_class = len(dummy_labels)

# R(2+1)D-50 backbone. n_classes=700 presumably matches the pretrained
# checkpoint's original head (Kinetics-700?); cls_num sizes our own classifier.
model = ResNet(
    block=Bottleneck,
    layers=[3, 4, 6, 3],
    block_inplanes=[64, 128, 256, 512],
    n_classes=700,
    cls_num=num_class,
)

# 'train' fine-tunes from the pretrained weights; any other value evaluates a saved checkpoint.
phase = 'test'

In [3]:
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    ``nn.DataParallel`` keeps the wrapped network under ``.module``, so a
    ``state_dict()`` taken from a wrapped model prefixes every key with
    ``module.``. Keys without the prefix are kept unchanged.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Load weights into the (still unwrapped) model: the pretrained R(2+1)D
# checkpoint when training, or one of our own checkpoints when evaluating.
if phase == 'train':
    checkpoint = torch.load(r'./r2p1d50_K_200ep.pth', map_location='cpu')
    state_dict = checkpoint['state_dict']
else:
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
    # the model is moved to `device` later (model.to(device)).
    checkpoint = torch.load(r'./train_model/R3D/epoch_100.pth', map_location='cpu')
    # Checkpoints saved from a multi-GPU (DataParallel) run carry 'module.' keys,
    # which strict=False would otherwise silently skip, leaving random weights.
    state_dict = _strip_module_prefix(checkpoint)

# strict=False tolerates key mismatches (presumably needed because the pretrained
# checkpoint lacks this model's classification head), but it also hides a load
# that matched nothing, so report what was skipped.
load_result = model.load_state_dict(state_dict, strict=False)
print(f'missing keys: {len(load_result.missing_keys)}, '
      f'unexpected keys: {len(load_result.unexpected_keys)}')

In [4]:
# Load the train/test metadata tables (GBK-encoded CSVs).
train_data = pd.read_csv('./roi_processed/training/train_data.csv', encoding='GBK')
test_data = pd.read_csv('./roi_processed/testing/test_data.csv', encoding='GBK')

# Keep only rows whose label is one of the target classes. .copy() makes these
# independent frames, so adding label columns afterwards does not raise
# SettingWithCopyWarning on a slice of the unfiltered table.
train_data = train_data[train_data['Finding Labels'].isin(dummy_labels)].copy()
test_data = test_data[test_data['Finding Labels'].isin(dummy_labels)].copy()

In [5]:
# One-hot encode 'Finding Labels' of the training split: one float column
# (1.0 / 0.0) per class in dummy_labels. Exact equality rather than substring
# `in`, so a label can never match a different label that happens to contain it.
for label in dummy_labels:
    train_data[label] = train_data['Finding Labels'].eq(label).astype(float)

In [6]:
# One-hot encode 'Finding Labels' of the test split: one float column
# (1.0 / 0.0) per class in dummy_labels. Exact equality rather than substring
# `in`, so a label can never match a different label that happens to contain it.
for label in dummy_labels:
    test_data[label] = test_data['Finding Labels'].eq(label).astype(float)

In [14]:
# Pack each row's one-hot class columns (in dummy_labels order) into a single
# per-sample array stored in the 'target_vector' column.
# The array is wrapped in a one-element list inside apply(..., 1) and unwrapped
# again by .map(...), presumably so DataFrame.apply keeps one object per row
# instead of expanding the array into separate columns.
# NOTE(review): `target` is a full row that also holds string columns such as
# 'Finding Labels', so `.values` is likely an object-dtype array; confirm the
# ACDC dataset converts it to a numeric tensor.
train_data['target_vector'] = train_data.apply(lambda target: [target[dummy_labels].values], 1).map(lambda target: target[0])
test_data['target_vector'] = test_data.apply(lambda target: [target[dummy_labels].values], 1).map(lambda target: target[0])

In [15]:
# Number of training samples per class, most frequent class first.
clean_labels = (
    train_data[dummy_labels]
    .sum()
    .sort_values(ascending=False)
)
print('train size：')
print(clean_labels)

train size：
HCM     20.0
RV      20.0
DCM     20.0
MINF    20.0
dtype: float64


In [16]:
# Number of test samples per class, most frequent class first.
# (Header label previously misspelled as "zise".)
print('test size：')
clean_labels = test_data[dummy_labels].sum().sort_values(ascending=False)
print(clean_labels)

test zise：
HCM     10.0
RV      10.0
DCM     10.0
MINF    10.0
dtype: float64


In [17]:
# Training hyper-parameters.
base_lr = 0.0005
batch_size = 8
max_epoch = 200
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'device {device}')
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    # Note: state_dict() keys of a DataParallel-wrapped model gain a 'module.' prefix.
    model = nn.DataParallel(model)

# Module.to() moves the parameters and returns the module; one call is enough.
model = model.to(device)
fn_loss = FocalLoss(device=device, gamma=2.).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

img_size = (224, 224)
num_workers = 5
train_acdc_data = ACDC(data=train_data, phase='train', img_size=img_size)
train_data_loader = DataLoader(train_acdc_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_acdc_data = ACDC(data=test_data, phase='test', img_size=img_size)
# The AUC/accuracy computed on the test split do not depend on sample order, so
# there is no reason to shuffle it; a fixed order keeps per-sample outputs comparable.
test_data_loader = DataLoader(test_acdc_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

device cuda:0


In [18]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; the training cell logs its per-epoch metrics here.
tensorboard_dir = './runs/R3D'
writer = SummaryWriter(log_dir=tensorboard_dir)

In [19]:
import time
import warnings

# Silence the UserWarning with this exact message (emitted by nn.DataParallel
# when it gathers scalar outputs from several devices).
warnings.filterwarnings("ignore", category=UserWarning, message="Was asked to gather along dimension 0, but all input tensors were scalars")

if phase == 'train':
    save_interval = 25  # save a checkpoint every N epochs
    test_interval = 25  # evaluate on the test split every N epochs
    lr_decay = 0.0009   # per-epoch multiplicative learning-rate decay
    save_dir = './train_model/R3D'
    os.makedirs(save_dir, exist_ok=True)  # torch.save fails if the directory is missing

    train_loader_nums = len(train_data_loader.dataset)
    test_loader_nums = len(test_data_loader.dataset)
    num_train_batches = len(train_data_loader)

    for epoch_num in range(max_epoch):
        print(f"--------> epoch_num: {epoch_num}")
        # Raw model outputs and one-hot targets for every training sample (for AUC).
        probs = np.zeros((train_loader_nums, num_class), dtype=np.float32)
        gt = np.zeros((train_loader_nums, num_class), dtype=np.float32)
        k = 0
        start_time = time.time()
        total_train_loss = 0.0
        correct = 0.0
        model.train()
        for train_data_batch, _, train_labels_batch in train_data_loader:
            train_data_batch = train_data_batch.to(device)
            train_labels_batch = train_labels_batch.to(device)
            outputs = model(train_data_batch)
            outputs = outputs.reshape(outputs.shape[0], -1)
            train_labels_batch = train_labels_batch.reshape(train_labels_batch.shape[0], -1)
            n_batch = outputs.shape[0]
            probs[k: k + n_batch, :] = outputs.detach().cpu().numpy()
            gt[k: k + n_batch, :] = train_labels_batch.detach().cpu().numpy()
            k += n_batch

            train_loss = fn_loss(outputs, train_labels_batch)
            # Accumulate a Python float: summing the loss tensor itself would chain
            # every batch's autograd history onto total_train_loss.
            total_train_loss += train_loss.item()

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            predicted = torch.argmax(outputs, 1)
            labels = torch.argmax(train_labels_batch, 1)
            correct += (predicted == labels).sum().item()

        # Mean of the per-batch loss values, so the printed "av" value is a real average.
        avg_train_loss = total_train_loss / max(num_train_batches, 1)
        train_acc = correct / train_loader_nums
        auc = cal_auc(gt, probs)
        print(f"epoch_num {epoch_num} av train loss {avg_train_loss}  train auc {auc} train acc {train_acc}")

        writer.add_scalars('Training Metrics', {
            'Loss': avg_train_loss,
            'Accuracy': train_acc,
            'AUC': auc,
        }, epoch_num)

        elapsed_time = time.time() - start_time
        # Wall-clock time of this epoch in seconds ("程序运行时间" = "elapsed time").
        print(f"程序运行时间：{elapsed_time} 秒")

        # Exponential decay that compounds per epoch. The factor must depend on
        # epoch_num; a constant base_lr * (1 - lr_decay) freezes the LR after epoch 0.
        # NOTE(review): confirm this is the intended schedule (poly decay may have been meant).
        lr_ = base_lr * (1 - lr_decay) ** (epoch_num + 1)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(save_dir, f'epoch_{epoch_num + 1}.pth')
            # Save the unwrapped network so checkpoint keys carry no DataParallel
            # 'module.' prefix and load directly into a bare ResNet.
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(model_to_save.state_dict(), save_mode_path)
            print("save model to {}".format(save_mode_path))

        if (epoch_num + 1) % test_interval == 0:
            test_probs = np.zeros((test_loader_nums, num_class), dtype=np.float32)
            test_gt = np.zeros((test_loader_nums, num_class), dtype=np.float32)
            test_k = 0
            model.eval()
            with torch.no_grad():
                for test_data_batch, _, test_label_batch in test_data_loader:
                    # Use the configured device (not a hard-coded "cuda:0") so this also runs on CPU.
                    test_data_batch = test_data_batch.to(device)
                    test_label_batch = test_label_batch.to(device)
                    test_outputs = model(test_data_batch)
                    test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
                    test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
                    n_batch = test_outputs.shape[0]
                    test_probs[test_k: test_k + n_batch, :] = test_outputs.cpu().numpy()
                    test_gt[test_k: test_k + n_batch, :] = test_label_batch.cpu().numpy()
                    test_k += n_batch
            test_label = np.argmax(test_gt, axis=1)
            test_pred = np.argmax(test_probs, axis=1)
            print(f"auc: {cal_auc(test_gt, test_probs)} | acc: {np.sum(test_label == test_pred) / test_loader_nums}")

内存：11102MiB\
5.2 s

In [20]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Model evaluation: run the loaded model over the test split several times and
# collect per-class and weighted AUC / accuracy for each run.
# NOTE(review): the recorded outputs differ between runs, so something in the
# test pipeline is presumably random (e.g. sampling inside ACDC); confirm.
n_eval_runs = 10
num_labels = len(dummy_labels)
test_loader_nums = len(test_data_loader.dataset)

total_acc_list = []
total_auroc_list = []

total_weight_auroc_list = []
total_weight_acc_list = []

model.eval()
for run_idx in range(n_eval_runs):
    test_probs = np.zeros((test_loader_nums, num_labels), dtype=np.float32)
    test_gt = np.zeros((test_loader_nums, num_labels), dtype=np.float32)
    test_k = 0
    with torch.no_grad():
        for test_data_batch, _, test_label_batch in test_data_loader:
            # Use the configured device rather than .cuda() so this also runs on CPU.
            test_data_batch = test_data_batch.to(device)
            test_label_batch = test_label_batch.to(device)
            test_outputs = model(test_data_batch)
            test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
            test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
            n_batch = test_outputs.shape[0]
            test_probs[test_k: test_k + n_batch, :] = test_outputs.cpu().numpy()
            test_gt[test_k: test_k + n_batch, :] = test_label_batch.cpu().numpy()
            test_k += n_batch
    test_label = np.argmax(test_gt, axis=1)
    test_pred = np.argmax(test_probs, axis=1)
    weight_auc, auc_list = get_weighted_auc_f1(test_probs, test_pred, test_label)

    # Pin the label set so the matrix is always num_labels x num_labels, even when
    # a class is absent from both the labels and the predictions.
    cm = confusion_matrix(test_label, test_pred, labels=list(range(num_labels)))
    # Per-class support = row sums (true-class counts) instead of a hard-coded
    # [10, 10, 10, 10] that silently breaks when the test split changes.
    dataset_list = cm.sum(axis=1).tolist()
    total_support = sum(dataset_list)
    acc_list = []
    weighted_acc = 0.0
    for cls_idx, support in enumerate(dataset_list):
        if support == 0:
            # No samples of this class: accuracy is undefined and it carries no weight.
            acc_list.append(float('nan'))
            continue
        weight = support / total_support
        acc = float(cm[cls_idx][cls_idx]) / support
        acc_list.append(acc)
        weighted_acc += weight * acc

    total_auroc_list.append(auc_list)
    total_acc_list.append(acc_list)
    total_weight_auroc_list.append(weight_auc)
    total_weight_acc_list.append(weighted_acc)
    print(f'weight auc: {total_weight_auroc_list} ; weight acc : {total_weight_acc_list}')

--------------------------------------------------
auc_list : [0.9333333333333333, 0.9166666666666667, 0.9633333333333334, 0.8066666666666666]
weighted_auroc:  0.905
weighted_F1:  0.825139842359522
weight auc: [0.905] ; weight acc : [0.825]
--------------------------------------------------
auc_list : [0.9533333333333334, 0.91, 0.9566666666666668, 0.8400000000000001]
weighted_auroc:  0.915
weighted_F1:  0.7234299516908211
weight auc: [0.905, 0.915] ; weight acc : [0.825, 0.7249999999999999]
--------------------------------------------------
auc_list : [0.9666666666666667, 0.9333333333333333, 0.9199999999999999, 0.7666666666666666]
weighted_auroc:  0.8966666666666666
weighted_F1:  0.7813131313131314
weight auc: [0.905, 0.915, 0.8966666666666666] ; weight acc : [0.825, 0.7249999999999999, 0.7750000000000001]


In [None]:
# Per-class AUC across evaluation runs (one row per run, one column per class):
# print the mean and standard deviation of each column.
auc_arr = np.array(total_auroc_list)
print(auc_arr.shape)
for cls_idx in range(auc_arr.shape[-1]):
    cls_auc = auc_arr[:, cls_idx]
    print(np.mean(cls_auc), np.std(cls_auc))

In [None]:
# Per-class accuracy across evaluation runs (one row per run, one column per
# class): print the mean and standard deviation of each column.
acc_arr = np.array(total_acc_list)
print(acc_arr.shape)
# Iterate over acc_arr's own class axis; using auc_arr.shape here would make this
# cell depend on the AUC cell having run with an array of the same width.
for cls_idx in range(acc_arr.shape[-1]):
    cls_acc = acc_arr[:, cls_idx]
    print(np.mean(cls_acc), np.std(cls_acc))

### end