In [None]:
import sys
import torch

sys.path.append('./CMR-AI/mmaction/')
sys.path.append('./CMR-AI/')

from abc import ABCMeta, abstractmethod
from swinTransformer3D_origin import SwinTransformer3D
from mutil_class_loss import FocalLoss, cal_auc, get_alpha, kd_loss
from weighted_auc_f1 import get_weighted_auc_f1
from load_dataset import ACDC
from utilsss import generate_mask_matrix, pruning_mask, row_softmax
from Policy import Policy, train_agent

import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import os
from PIL import Image
import torch
from torchvision import transforms
import pandas as pd
from skimage import transform
import numpy as np
from torch import nn
import SimpleITK as sitk
from torch.utils.data import DataLoader
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Names of the execution phases understood by the rest of the notebook.
TRAIN, TEST, FINE_TUNE = 'train', 'test', 'fine_tune'

# Phase used for this run; later cells branch on this value.
phase = TRAIN

In [None]:
## Task Split
# Task id -> class names handled by that task's model. Task 0 covers every
# class; tasks 1 and 2 each cover a two-class subset.
num_class = {
    0: ['HCM', 'RV', 'DCM', 'MINF'],
    1: ['HCM', 'RV'],
    2: ['DCM', 'MINF'],
}

# Task id -> sub-directory name used when checkpoints are written.
save_models = {0: 'full', 1: '1', 2: '2'}
# Full label set (task 0), used for filtering and one-hot encoding the data.
dummy_labels = num_class[0]

# One network per task, with as many outputs as the task has classes.
models = [SwinTransformer3D(num_class=len(task_labels)) for task_labels in num_class.values()]
# Module iterators, one per model (each is a generator from `.modules()`).
models_modules = [model.modules() for model in models]
print(f'Total model is {len(models)}')

In [None]:
## RL agent
# The agent's input width is taken from the last dimension of the
# second-to-last parameter of models[0]; it stays 0 if that slice is empty.
teacher_num = len(models)
input_size = 0
second_to_last_params = list(models[0].named_parameters())[-2:-1]
for name, param in second_to_last_params:
    print(f"layer anem: {name} | size: {param.size()}")
    input_size = param.shape[-1]
agent = Policy(input_size=input_size, teacher_num=teacher_num)

In [None]:
def _strip_checkpoint_prefix(key):
    """Return `key` without a leading 'module.' or 'cls_head.' prefix.

    Only the leading prefix is removed ('module.' is checked first, and at
    most one prefix is stripped). Later occurrences inside the key are left
    alone, so a name such as 'backbone.submodule.weight' is not mangled.
    """
    for prefix in ('module.', 'cls_head.'):
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


# Pick the checkpoint for the current phase.
if phase == TRAIN:
    checkpoint_path = r'./organmnist3d_250.pth'
elif phase in (TEST, FINE_TUNE):
    # NOTE(review): FINE_TUNE previously had no branch here and crashed with a
    # NameError. It is treated like TEST because the mask-matrix cell also
    # handles both phases identically -- confirm this is the intended start
    # point for fine-tuning.
    print(f'train test phase: load weightd from local file pth.')
    checkpoint_path = f'./MTRL-MKD-SKD/full/best_model.pth'
else:
    raise ValueError(f'Unknown phase: {phase!r}')

# Read the file once (it is shared by every model) and map tensors to CPU so
# loading does not depend on the device the checkpoint was saved from.
_pretrained_dict = torch.load(checkpoint_path, map_location='cpu')
renamed_dict = {_strip_checkpoint_prefix(k): v for k, v in _pretrained_dict.items()}

for i in range(len(models)):
    model_dict = models[i].state_dict()
    # Keep only entries whose name and shape match this model; heads with a
    # different number of classes are therefore skipped.
    pretrained_dict = {
        k: v for k, v in renamed_dict.items()
        if k in model_dict and v.size() == model_dict[k].size()
    }
    # strict=False: parameters without a matching entry keep their initial values.
    models[i].load_state_dict(pretrained_dict, strict=False)
    print(f'No.{i+1} model load pretrained weighted end.')

In [None]:
# Outside of training the mask matrices are read back from disk.
if phase in (TEST, FINE_TUNE):
    print(f'load mask_matrix from local file.')
    import pickle
    # NOTE(review): despite the .npz suffix this file is a pickle. Unpickling
    # can execute arbitrary code, so only load files produced by this notebook.
    with open(f'./mask_matrix.npz', 'rb') as mask_file:
        mask_matrix_dict = pickle.load(mask_file)

In [None]:
# Hard-coded positions in the `models[0].modules()` enumeration.
# NOTE(review): the values depend on the exact architecture and must be
# regenerated if SwinTransformer3D changes. The names suggest the module type
# found at each position -- confirm against the model definition.

# Positions skipped when mask matrices are generated.
PatchMerging_module_idx = [38, 73, 348]

# Not referenced by any code visible in this notebook.
WindowAttention3D_module_idx = [44, 45, 47, 59, 60, 62, 79, 80, 82, 94, 95, 97, 99, 110, 112,  124, 125, 127, 139, 140, 142, 154, 155, 157, 169, 170, 172, \
                                184, 185, 187, 199, 200, 202, 214, 215, 217, 229, 230, 232, 244, 245, 247, 259, 260, 262, 274, 275, 277, 289, 290, 292, 304, 305, \
                                307, 319, 320, 321, 322, 334, 335, 337, 354, 355, 357, 369, 370, 372, ]

# Positions skipped when mask matrices are generated.
Mlp_module_idx = [18, 20, 33, 35, 53, 55, 68, 70, 88, 90, 103, 105, 118, 120, 133, 135, 148, 150, 163, 165, 178, 180, \
                  193, 195, 208, 210, 223, 225, 238, 240,253, 255,  268, 270, 283, 285, 298, 300, 313, 315, 328, 330, 343, \
                  345, 363, 365, 378, 380]
                   
if phase == TRAIN:
    print(f'start to generate mask_matrix.')
    # module index (position in models[0].modules()) -> mask matrix
    mask_matrix_dict = {}
    excluded_module_idx = set(PatchMerging_module_idx) | set(Mlp_module_idx)
    for module_idx, module in enumerate(models[0].modules()):
        if module_idx in excluded_module_idx:
            continue
        # Modules exposing `.qkv` are masked through that projection; any
        # other module is masked directly. Only nn.Linear layers qualify.
        candidate = module.qkv if hasattr(module, "qkv") else module
        if not isinstance(candidate, nn.Linear):
            continue
        # Layers with fewer than 50000 weights are left unmasked.
        if candidate.weight.data.numel() < 50000:
            continue
        mask = torch.zeros(candidate.weight.data.size(), dtype=torch.uint8, device='cpu')
        mask_matrix_dict[module_idx] = generate_mask_matrix(mask.numpy())

# Walk the three models in lock-step and, for every module that has a mask
# matrix, touch the weights of models[1] and models[2] at the positions whose
# mask value differs from that model's index. models[0] (mask_index 0) is skipped.
#
# NOTE(review): each assignment writes a detached copy of the selected weights
# back onto the same positions, so the stored values look unchanged, and
# `requires_grad_(False)` is applied to the temporary copy rather than to the
# parameter. As written this does not appear to freeze anything -- confirm the
# intended effect.
# NOTE(review): assumes `mask_matrix_dict[module_idx] != mask_index` yields
# something usable as an index into the weight tensor -- depends on what
# generate_mask_matrix returns; TODO confirm.
if phase == TRAIN:
    # Not used in this cell.
    mask_list = []
    for module_idx, modules in enumerate(zip(models[0].modules(), models[1].modules(), models[2].modules())):
        if module_idx not in mask_matrix_dict.keys():  continue
        for mask_index, module in enumerate(modules):
            if hasattr(module, "qkv"):
                # Module exposes a `.qkv` projection: operate on its weight.
                if mask_index == 0: 
                    pass
                else:
                    with torch.no_grad():
                        module.qkv.weight[mask_matrix_dict[module_idx] != mask_index] = \
                            module.qkv.weight[mask_matrix_dict[module_idx] != mask_index].detach().requires_grad_(False)
            else:
                # Otherwise operate on the module's own weight.
                if mask_index == 0: 
                    pass
                else:
                    with torch.no_grad():
                        module.weight[mask_matrix_dict[module_idx] != mask_index] = \
                            module.weight[mask_matrix_dict[module_idx] != mask_index].detach().requires_grad_(False)
                    
# Persist the freshly generated mask matrices so the other phases can reuse them.
if phase == TRAIN:
    print(f'save mask_matrix to local file.')
    import pickle
    # The file is a pickle even though its name ends in .npz.
    with open(f'./mask_matrix.npz', 'wb') as mask_file:
        pickle.dump(mask_matrix_dict, mask_file)

In [None]:
# Display the module indices that have a mask matrix.
mask_matrix_dict.keys()

In [None]:
# Read the training and testing tables (GBK-encoded CSV files).
train_data = pd.read_csv('./sax_roi_processed/training/train_data.csv', encoding='GBK')
test_data = pd.read_csv('./sax_roi_processed/testing/test_data.csv', encoding='GBK')

# Keep only rows whose 'Finding Labels' value is one of the classes in use.
# `.copy()` makes each result an independent frame, so the columns added by
# later cells do not raise SettingWithCopyWarning.
train_data = train_data[train_data['Finding Labels'].isin(dummy_labels)].copy()
test_data = test_data[test_data['Finding Labels'].isin(dummy_labels)].copy()

In [None]:
# Add one indicator column per class to the training table: 1.0 when the
# class name occurs inside the row's 'Finding Labels' text, otherwise 0.
train_findings = train_data['Finding Labels']
for label in dummy_labels:
    train_data[label] = train_findings.map(lambda finding: 1.0 if label in finding else 0)

In [None]:
# Add one indicator column per class to the testing table: 1.0 when the
# class name occurs inside the row's 'Finding Labels' text, otherwise 0.
test_findings = test_data['Finding Labels']
for label in dummy_labels:
    test_data[label] = test_findings.map(lambda finding: 1.0 if label in finding else 0)

In [None]:
# Collect each row's indicator columns (in `dummy_labels` order) into a single
# array stored in a new 'target_vector' column, for both tables.
for split_frame in (train_data, test_data):
    # The array is boxed in a list inside `apply` and unboxed again by `map`.
    boxed_targets = split_frame.apply(lambda row: [row[dummy_labels].values], axis=1)
    split_frame['target_vector'] = boxed_targets.map(lambda boxed: boxed[0])

In [None]:
# Per-class sample counts of the training table, largest first.
train_label_totals = train_data[dummy_labels].sum()
clean_labels = train_label_totals.sort_values(ascending=False)
print(f'train data size：')
print(clean_labels)

In [None]:
# Per-class sample counts of the testing table, largest first.
# This overwrites `clean_labels` from the training-count cell.
test_label_totals = test_data[dummy_labels].sum()
clean_labels = test_label_totals.sort_values(ascending=False)
print(f'test size：')
print(clean_labels)

In [None]:
# Per-class counts from `clean_labels`, re-ordered to follow `dummy_labels`.
dataset_list = [clean_labels[label_] for label_ in dummy_labels]
dataset_list

## 训练开始

In [None]:
# Training hyper-parameters.
# NOTE(review): the right-hand sides are missing, so this cell is not valid
# Python as it stands. Fill in the values before running.
# Learning rate passed to every SGD optimizer.
base_lr = 
# Batch size of the train and test DataLoaders.
batch_size = 
# Number of training epochs.
max_epoch = 
# Blend factor used when teacher weights are merged into the student in the
# training cell (it is not passed to the SGD optimizers).
momentum = 
# Passed to kd_loss and multiplied by (1-0.0009) after every epoch.
T = 

# Put the models on multiple GPUs when more than one is available.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    models[:] = [nn.DataParallel(model) for model in models]

# Move every model (wrapped or not) to the GPU.
models[:] = [model.cuda() for model in models]

# Loss functions. Only `fn_loss` is referenced by the training cell.
fn_loss  = FocalLoss(device = 'cuda:0', gamma = 2.).to('cuda:0')
kl_loss = torch.nn.KLDivLoss(reduction='batchmean')

cross_loss = nn.CrossEntropyLoss()

# One plain SGD optimizer per model, plus one for the RL agent.
optimizers = [torch.optim.SGD(model.parameters(), lr=base_lr) for model in models]

agent_optimizer = torch.optim.SGD(agent.parameters(), lr=base_lr)

# Datasets and loaders. Both loaders share the same settings.
# NOTE(review): the test loader shuffles too; the metrics computed later do
# not depend on sample order.
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=5)
train_acdc_data = ACDC(data=train_data, phase = 'train', img_size=(224, 224))
train_data_loader = DataLoader(train_acdc_data, **loader_options)
test_acdc_data = ACDC(data=test_data, phase = 'test', img_size=(224, 224))
test_data_loader = DataLoader(test_acdc_data, **loader_options)

In [None]:
# Disabled debugging aid: shows every slice of channel 0 for each test sample.
if False:
    for batch_idx, (batch_data, batch_finding, batch_label) in enumerate(test_data_loader):
        for i, sample_volume in enumerate(batch_data):
            print(f'--------- {i} -----------')
            # sample_volume[0] is channel 0; iterating walks its first axis.
            for slice_2d in sample_volume[0]:
                img_ = slice_2d.detach().numpy()
                plt.imshow(img_, cmap='gray')
                plt.show()

In [None]:
# TensorBoard logging is only set up for the training phase.
if phase == TRAIN:
    from torch.utils.tensorboard import SummaryWriter
    # Directory the event files are written to.
    tensorboard_log_dir = './runs/VSTMTRL-ACDC'
    writer = SummaryWriter(log_dir=tensorboard_log_dir)

In [None]:
# Training phase: per batch, train each teacher (models[1:]) on its own class
# subset, train the student (models[0]) with kd_loss against the teachers'
# softened outputs, and periodically let the RL agent decide how much of each
# teacher's masked weights is blended into the student. After every epoch the
# student is evaluated on the test loader and the best checkpoint is saved.
if phase == TRAIN:
    import time

    def _unwrap(model):
        """Return the wrapped network if `model` is an nn.DataParallel, else `model`."""
        return model.module if isinstance(model, nn.DataParallel) else model

    num_teachers = len(models) - 1
    # Best-so-far and current-epoch training accuracy of every teacher.
    pre_mutil_teacher_acc = [0] * num_teachers
    now_mutil_teacher_acc = [0] * num_teachers
    # Buffers handed to train_agent at the end of each epoch.
    logit_actions_list = []
    agent_rewars_list  = []
    student_infos_list = []
    bast_acc = 0.0
    # The RL step runs about three times per epoch. The interval is clamped to
    # at least 1 so a small dataset cannot cause a modulo-by-zero.
    rl_interval = max(1, len(train_data_loader.dataset) // batch_size // 3)

    for epoch_num in range(0, max_epoch):
        alpha = get_alpha(epoch_num, max_epoch)
        print(f"--------> epoch_num: {epoch_num}")
        train_loader_nums = len(train_data_loader.dataset)
        start_time = time.time()
        total_train_loss = 0.0
        correct = 0.0
        for model in models:
            model.train()
        # Per-teacher counters for this epoch: correct predictions / samples seen.
        mutil_teacher_correct = [0] * num_teachers
        mutil_teacher_num = [0] * num_teachers

        for batch_idx, (batch_data, batch_finding, batch_label) in enumerate(train_data_loader):
            weighted_list = []
            # Teacher soft labels, written into each teacher's label columns.
            mutil_teacher_label = torch.zeros_like(batch_label)
            # Column offset of the current teacher's classes inside batch_label.
            pre_lables = 0
            s_train_loss = torch.tensor(0.0, device='cuda')
            for i, k in enumerate(num_class.keys()):
                if i == 0:
                    continue  # task 0 is the student, handled below

                ## train teacher.
                # Boolean mask of the samples whose finding belongs to this teacher.
                teacher_train_data_index = pd.Series(batch_finding).isin(num_class[k]).to_numpy()
                weighted_list.append(np.sum(teacher_train_data_index > 0))
                mutil_teacher_num[i-1] += weighted_list[i-1]
                t_train_data = batch_data[teacher_train_data_index]
                t_train_label = batch_label[teacher_train_data_index][:, pre_lables:pre_lables+len(num_class[k])]
                if np.sum(teacher_train_data_index > 0) == 0:
                    # No sample for this teacher in the batch: only advance the offset.
                    pre_lables += len(num_class[k])
                    continue
                t_output, _ = models[i](t_train_data.cuda())

                mutil_teacher_label[teacher_train_data_index, pre_lables:pre_lables+len(num_class[k])] = \
                    row_softmax(t_output.cpu().detach())
                t_output = t_output.reshape(t_output.shape[0], -1)
                t_train_label = t_train_label.reshape(t_train_label.shape[0], -1).cuda()
                t_train_loss = fn_loss(t_output, t_train_label)
                optimizers[i].zero_grad()
                t_train_loss.backward()
                optimizers[i].step()
                predicted_ = torch.argmax(t_output, 1)
                labels_ = torch.argmax(t_train_label.cuda(), 1)
                correct_ = (predicted_ == labels_).sum().item()
                mutil_teacher_correct[i-1] += correct_

                ## train student
                # The student sees the same samples as the teacher and is
                # scored with kd_loss against the teacher's soft labels and the
                # ground truth. The loss is accumulated over all teachers.
                s_output, _ = models[0](t_train_data.cuda())
                s_output = s_output.reshape(s_output.shape[0], -1)
                s_train_loss += kd_loss(s_output.cuda(), mutil_teacher_label[teacher_train_data_index].cuda(), batch_label[teacher_train_data_index].cuda(), T, alpha)
                pre_lables += len(num_class[k])

            optimizers[0].zero_grad()
            s_train_loss.backward()
            optimizers[0].step()
            # .item() stores a plain float instead of a graph-attached CUDA tensor.
            total_train_loss += s_train_loss.item()

            with torch.no_grad():
                outputs, student_info = models[0](batch_data.cuda())
                predicted = torch.argmax(outputs, 1)
                labels = torch.argmax(batch_label.cuda(), 1)
                correct += (predicted == labels).sum().item()

            ## RL
            # Skipped during the first epoch.
            if epoch_num != 0 and (batch_idx+1) % rl_interval == 0:
                student_infos_list.append(student_info.detach().cpu())
                with torch.no_grad():
                    logit_action = agent(student_info.detach().cpu())
                logit_actions_list.append(logit_action.detach().cpu())

                # Reward: negative per-sample cross-entropy of the student,
                # standardised over the batch and clamped to [0, 1].
                # NOTE(review): rewards_std is 0/NaN for constant or
                # single-sample batches -- confirm this cannot occur.
                agent_reward = -(F.cross_entropy(outputs.cuda() ,batch_label.cuda(), reduction='none'))
                rewards_mean = agent_reward.mean()
                rewards_std = agent_reward.std()
                normalized_reward = (agent_reward - rewards_mean) / rewards_std
                normalized_reward = normalized_reward.detach().cpu()
                normalized_reward = torch.clamp(normalized_reward, min=0, max=1)
                agent_rewars_list.append(normalized_reward)

                # Average the agent's actions over dim 0, keeping a leading axis.
                mean_logit_action = torch.mean(logit_action, dim=0, keepdim=True)
                print(f'RMD weights : {mean_logit_action}')

                # Enumerate the *unwrapped* networks: mask_matrix_dict is keyed
                # by positions in the bare model's .modules() order, and an
                # nn.DataParallel wrapper would shift every index by one.
                aligned_modules = zip(*(_unwrap(model).modules() for model in models))
                for module_index, modules in enumerate(aligned_modules):
                    if module_index not in mask_matrix_dict:
                        continue
                    # Only plain nn.Linear layers of sufficient size are blended.
                    if not isinstance(modules[0], nn.Linear):
                        continue
                    if modules[0].weight.data.numel() < 50000:
                        continue

                    # mask_list[j] is the mask for models[j]; slot 0 (student) is a placeholder.
                    mask_list = [[]]
                    for mask_index, module in enumerate(modules):
                        if mask_index == 0:
                            continue
                        weights = module.weight.data
                        rl_weights = mean_logit_action[0][mask_index-1]
                        mask = pruning_mask(weights.cpu(), mask_matrix_dict[module_index], mask_index, k=rl_weights.cpu().numpy())
                        mask_list.append(mask.cuda())

                    all_weights_mask = torch.ones_like(modules[0].weight.data)
                    for i in range(1, len(models)):
                        unique_weights_mask_ = mask_list[i]
                        unique_weights_mask = (unique_weights_mask_ >= 1).int()
                        # Zero the student's gradient at the selected positions.
                        # A boolean mask is required for this to act in place;
                        # the optimizer step for this batch has already run.
                        modules[0].weight.grad[unique_weights_mask.bool()] = 0
                        unique_weights = unique_weights_mask * modules[i].weight.data
                        # Unselected positions keep the student weight; selected
                        # positions become momentum*student + (1-momentum)*teacher.
                        modules[0].weight.data = (modules[0].weight.data)*(all_weights_mask-unique_weights_mask) + \
                                                 momentum*(unique_weights_mask * modules[0].weight.data) + (1-momentum)*unique_weights

            # On the last batch of the epoch, train the agent on the collected samples.
            if batch_idx != 0 and batch_idx % (len(train_data_loader)-1) == 0:
                train_agent(epoch=epoch_num, agent=agent, student_infos=student_infos_list, agent_rewards=agent_rewars_list, logits_agent_actions=logit_actions_list, agent_optimizer=agent_optimizer)
                logit_actions_list.clear()
                agent_rewars_list.clear()
                student_infos_list.clear()

        logit_actions_list.clear()
        agent_rewars_list.clear()
        student_infos_list.clear()

        for i in range(len(now_mutil_teacher_acc)):
            seen = mutil_teacher_num[i]
            # A teacher that saw no samples gets accuracy 0.0 instead of NaN.
            now_mutil_teacher_acc[i] = mutil_teacher_correct[i]/seen if seen else 0.0

        print(f'mutil-teacher acc is : {now_mutil_teacher_acc}')
        print(f"epoch_num {epoch_num} train loss {total_train_loss} ")

        writer.add_scalars('Training Metrics', {
            'Loss': total_train_loss,
            'Accuracy': correct / train_loader_nums,
        }, epoch_num)

        # NOTE(review): lr_ is derived from base_lr every epoch, so it stays
        # constant at base_lr*(1-0.0009), while T decays cumulatively --
        # confirm whether the learning rate was meant to decay as well.
        lr_ = base_lr*(1-0.0009)
        T = T*(1-0.0009)
        for i in range(len(optimizers)):
            for param_group in optimizers[i].param_groups:
                param_group['lr'] = lr_

        # Record a new best accuracy for every teacher that improved this epoch.
        mutil_teacher_compare_result = (np.array(now_mutil_teacher_acc) > np.array(pre_mutil_teacher_acc))
        for i in range(mutil_teacher_compare_result.shape[0]):
            if mutil_teacher_compare_result[i]:
                pre_mutil_teacher_acc[i] = now_mutil_teacher_acc[i]

        for i in range(1, len(models)):
            if mutil_teacher_compare_result[i-1] == False:
                print(f'teacher model {i} dont surpressed pre batch. dont update agein.')

        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"程序运行时间：{elapsed_time} 秒")

        # Evaluate the student on the test loader every `test_interval` epochs.
        test_interval = 1
        if (epoch_num + 1) % test_interval == 0:
            test_loader_nums = len(test_data_loader.dataset)
            test_probs = np.zeros((test_loader_nums, len(num_class[0])), dtype = np.float32)
            test_gt    = np.zeros((test_loader_nums, len(num_class[0])), dtype = np.float32)
            test_k  = 0
            models[0].eval()
            with torch.no_grad():
                for test_data_batch, _, test_label_batch in test_data_loader:
                    test_data_batch = test_data_batch.cuda()
                    test_label_batch = test_label_batch.cuda()
                    test_outputs, _ = models[0](test_data_batch)
                    test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
                    test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
                    # Store raw model outputs and labels for metric evaluation.
                    test_probs[test_k: test_k + test_outputs.shape[0], :] = test_outputs.cpu().detach().numpy()
                    test_gt[   test_k: test_k + test_outputs.shape[0], :] = test_label_batch.cpu().detach().numpy()
                    test_k += test_outputs.shape[0]
                test_label = np.argmax(test_gt, axis=1)
                test_pred = np.argmax(test_probs, axis=1)
                test_acc = np.sum(test_label==test_pred)/test_k
                print(f"auc: {cal_auc(test_gt, test_probs)} | acc: {test_acc}")
                # Save all models whenever the student matches or beats its best accuracy.
                if test_acc >= bast_acc:
                    bast_acc = test_acc
                    os.makedirs(f'./train_model/MTRL-MKD-SKD/', exist_ok=True)
                    for i in range(len(models)):
                        os.makedirs(f'./train_model/MTRL-MKD-SKD/{save_models[i]}', exist_ok=True)
                        save_mode_path = os.path.join(f'./train_model/MTRL-MKD-SKD/{save_models[i]}', 'best_model.pth')
                        torch.save(models[i].state_dict(), save_mode_path)
                        print("save model to {}".format(save_mode_path))

3 2

time 17.5   \
11301MiB

In [None]:
# Evaluate the student (models[0]) on the test loader ten times and collect
# per-class and weighted AUROC / accuracy for every run.
total_acc_list = []
total_auroc_list = []

total_weight_auroc_list = []
total_weight_acc_list = []
# Passing fixed class ids keeps the confusion matrix one row per class even
# when a class is missing from both the labels and the predictions.
class_ids = list(range(len(dummy_labels)))
### eval.
for run_idx in range(10):
    test_loader_nums = len(test_data_loader.dataset)
    test_probs = np.zeros((test_loader_nums, len(dummy_labels)), dtype = np.float32)
    test_gt    = np.zeros((test_loader_nums, len(dummy_labels)), dtype = np.float32)
    test_k  =0
    models[0].eval()
    with torch.no_grad():
        for test_data_batch, _, test_label_batch in test_data_loader:
            test_data_batch = test_data_batch.cuda()
            test_label_batch = test_label_batch.cuda()
            test_outputs, _ = models[0](test_data_batch)
            test_outputs = test_outputs.reshape(test_outputs.shape[0], -1)
            test_label_batch = test_label_batch.reshape(test_outputs.shape[0], -1)
            test_probs[test_k: test_k + test_outputs.shape[0], :] = test_outputs.cpu().detach().numpy()
            test_gt[   test_k: test_k + test_outputs.shape[0], :] = test_label_batch.cpu().detach().numpy()
            test_k += test_outputs.shape[0]
        test_label = np.argmax(test_gt, axis=1)
        test_pred = np.argmax(test_probs, axis=1)
        weight_auc, auc_list = get_weighted_auc_f1(test_probs, test_pred, test_label)

        cm = confusion_matrix(test_label, test_pred, labels=class_ids)
        # True sample count per class, taken from the confusion-matrix rows
        # instead of a hard-coded [10, 10, 10, 10].
        dataset_list = cm.sum(axis=1).tolist()
        total_samples = sum(dataset_list)
        acc_list = []
        weighted_acc = 0.0
        for class_idx in class_ids:
            class_count = dataset_list[class_idx]
            # Per-class accuracy; a class without samples contributes 0.
            acc = float(cm[class_idx][class_idx]) / class_count if class_count else 0.0
            weight = class_count / total_samples if total_samples else 0.0
            acc_list.append(acc)
            weighted_acc += weight*acc

        total_auroc_list.append(auc_list)
        total_acc_list.append(acc_list)
        total_weight_auroc_list.append(weight_auc)
        total_weight_acc_list.append(weighted_acc)

In [None]:
# Show the weighted AUROC and weighted accuracy of every evaluation run.
for per_run_values in (total_weight_auroc_list, total_weight_acc_list):
    print(per_run_values)

In [None]:
# Mean and standard deviation of the AUROC across runs, one line per column
# of the stacked per-run AUROC lists.
auc_arr = np.asarray(total_auroc_list)
print(auc_arr.shape)
auc_column_count = auc_arr.shape[-1]
for i in range(auc_column_count):
    auc_arr_cls = auc_arr[:, i]
    mean, std = np.mean(auc_arr_cls), np.std(auc_arr_cls)
    print(mean, std)

In [None]:
# Mean and standard deviation of the per-class accuracy across runs.
acc_arr = np.array(total_acc_list)
print(acc_arr.shape)
# Iterate over this array's own columns (the bound previously came from the
# AUROC array, which is only correct when both have the same width).
for i in range(acc_arr.shape[-1]):
    acc_arr_cls = acc_arr[:, i]
    mean = np.mean(acc_arr_cls)
    std = np.std(acc_arr_cls)
    print(mean, std)

In [None]:
# Mean and standard deviation of the weighted accuracy over all runs.
weighted_acc_runs = np.asarray(total_weight_acc_list)
np.mean(weighted_acc_runs), np.std(weighted_acc_runs)

In [None]:
# Mean and standard deviation of the weighted AUROC over all runs.
weighted_auroc_runs = np.asarray(total_weight_auroc_list)
np.mean(weighted_auroc_runs), np.std(weighted_auroc_runs)

END