In [1]:
import os
import tempfile
import warnings
import glob
import torch
import time
import utils
import SimpleITK as sitk
import numpy as np
import pystrum.pynd.ndutils as nd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import mean_squared_error
from my_unet import myUNet,teacher_student_model,UNet_distill
from utils import pkload
#from dataset import Seg_CT_MRI_Dataset
# MONAI imports
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from monai.config import print_config
from monai.transforms import apply_transform
from monai.transforms.post.array import AsDiscrete
from monai.metrics import DiceMetric,HausdorffDistanceMetric,SurfaceDistanceMetric,PSNRMetric,MSEMetric
from monai.metrics.regression import SSIMMetric

In [2]:
class Seg_CT_MRI_Dataset(Dataset):
    """Segmentation dataset of per-subject CT/MRI folders (BCH data).

    Every sub-folder of ``source_root_dir`` is expected to hold
    ``brain_small_norm.nii.gz`` (image) and ``label_small.nii.gz`` (label map).
    Raw label values are merged into a smaller set of classes via
    ``LABEL_MAPPING``.

    Split (sub-folder names are sorted so the split does not depend on the
    file system's directory-listing order):
      * ``is_Training=True``               -> first 70 sub-folders
      * ``is_Training=False, is_val=True`` -> sub-folders 70..87
      * otherwise (test)                   -> a fixed list of ``BCH_CTxxx`` ids

    ``__getitem__`` returns ``(image, label, subfolder_name, image_path)``;
    image and label get a leading channel axis of size 1.
    """

    # Raw label value -> merged class id. Values not listed are left unchanged.
    # (Earlier experiments used an identity-style mapping onto 0..26 instead.)
    LABEL_MAPPING = {
        0.0: 0, 1.0: 1, 2.0: 2, 3.0: 1, 4.0: 2, 5.0: 1, 6.0: 2, 7.0: 1,
        8.0: 2, 9.0: 4, 10.0: 3, 11.0: 4, 12.0: 3, 13.0: 0, 14.0: 1,
        15.0: 2, 16.0: 0, 17.0: 0, 18.0: 1, 19.0: 2, 20.0: 6, 21.0: 5,
        24.0: 6, 25.0: 5, 28.0: 0, 29.0: 1, 30.0: 2,
    }

    # Test subjects: these ids, followed by every id in 89..118 that is not
    # in _EXCLUDED_TEST_IDS.
    _BASE_TEST_IDS = (71, 72, 74, 75, 77, 79, 80, 85, 87)
    _EXCLUDED_TEST_IDS = frozenset({91, 92, 95, 96, 98, 101, 102, 103, 105,
                                    106, 107, 109, 111, 113, 115, 116})

    def __init__(self, source_root_dir, transform2img=None, transform2both=None, is_Training=True, is_val=False):
        self.source_root_dir = source_root_dir
        self.transform2img = transform2img      # applied to the image only
        self.transform2both = transform2both    # applied to {'image', 'label'} together

        if is_Training:
            self.source_subfolders = sorted(os.listdir(source_root_dir))[:70]
        elif is_val:
            self.source_subfolders = sorted(os.listdir(source_root_dir))[70:88]
        else:
            # Fixed test subjects; the directory listing is not consulted.
            test_ids = list(self._BASE_TEST_IDS) + [
                i for i in range(89, 119) if i not in self._EXCLUDED_TEST_IDS
            ]
            self.source_subfolders = [f"BCH_CT{idx:03}" for idx in test_ids]

    def __len__(self):
        return len(self.source_subfolders)

    def __getitem__(self, idx):
        source_subfolder = self.source_subfolders[idx]

        # Image and label file paths for this subject.
        source_path = os.path.join(self.source_root_dir, source_subfolder, "brain_small_norm.nii.gz")
        source_label_path = os.path.join(self.source_root_dir, source_subfolder, "label_small.nii.gz")
        if not os.path.exists(source_path):
            raise FileNotFoundError(f"Source file not found for subfolder {source_path}")
        if not os.path.exists(source_label_path):
            raise FileNotFoundError(f"Label file not found: {source_label_path}")

        source_image_np = sitk.GetArrayFromImage(sitk.ReadImage(source_path))
        source_label_np = sitk.GetArrayFromImage(sitk.ReadImage(source_label_path))

        # Remap into a copy while always matching against the ORIGINAL values,
        # so a freshly written class id can never be re-mapped by a later entry.
        remapped = source_label_np.copy()
        for old_label, new_label in self.LABEL_MAPPING.items():
            remapped[source_label_np == old_label] = new_label
        source_label_np = remapped.astype(int)

        source_image = torch.Tensor(source_image_np).unsqueeze(dim=0)
        source_label = torch.Tensor(source_label_np).unsqueeze(dim=0)

        # Image-only transform first, so its result feeds the joint transform
        # instead of being discarded.
        if self.transform2img:
            source_image = apply_transform(self.transform2img, source_image)
        if self.transform2both:
            trans_data = apply_transform(self.transform2both, {'image': source_image, "label": source_label})
            source_image = trans_data['image']
            source_label = trans_data['label']

        return source_image, source_label, source_subfolder, source_path

In [3]:
class Seg_OASIS_Dataset(Dataset):
    """OASIS segmentation dataset: one sub-folder per subject.

    Each sub-folder is expected to hold ``aligned_norm.nii.gz`` (image) and
    ``aligned_seg35.nii.gz`` (label map, used without remapping).
    Sub-folder names are sorted (so the split does not depend on the file
    system's directory-listing order) and then split by position:
      * ``is_Training=True``               -> [0, TRAIN_END)
      * ``is_Training=False, is_val=True`` -> [TRAIN_END, VAL_END)
      * otherwise (test)                   -> [VAL_END, end)

    ``__getitem__`` returns ``(image, label, subfolder_name, image_path)``;
    image and label get a leading channel axis of size 1.
    """

    # Split boundaries (positions in the sorted sub-folder list).
    TRAIN_END = 335
    VAL_END = 375

    def __init__(self, source_root_dir, transform2img=None, transform2both=None, is_Training=True, is_val=False):
        self.source_root_dir = source_root_dir
        self.transform2img = transform2img      # applied to the image only
        self.transform2both = transform2both    # applied to {'image', 'label'} together

        subfolders = sorted(os.listdir(source_root_dir))
        if is_Training:
            self.source_subfolders = subfolders[:self.TRAIN_END]
        elif is_val:
            self.source_subfolders = subfolders[self.TRAIN_END:self.VAL_END]
        else:
            self.source_subfolders = subfolders[self.VAL_END:]

    def __len__(self):
        return len(self.source_subfolders)

    def __getitem__(self, idx):
        source_subfolder = self.source_subfolders[idx]

        # Image and label file paths for this subject.
        source_path = os.path.join(self.source_root_dir, source_subfolder, "aligned_norm.nii.gz")
        source_label_path = os.path.join(self.source_root_dir, source_subfolder, "aligned_seg35.nii.gz")
        if not os.path.exists(source_path):
            raise FileNotFoundError(f"Source file not found for subfolder {source_subfolder}")
        if not os.path.exists(source_label_path):
            raise FileNotFoundError(f"Label file not found: {source_label_path}")

        source_image_np = sitk.GetArrayFromImage(sitk.ReadImage(source_path))
        source_label_np = sitk.GetArrayFromImage(sitk.ReadImage(source_label_path))

        source_image = torch.Tensor(source_image_np).unsqueeze(dim=0)
        source_label = torch.Tensor(source_label_np).unsqueeze(dim=0)

        # Image-only transform first, so its result feeds the joint transform
        # instead of being discarded.
        if self.transform2img:
            source_image = apply_transform(self.transform2img, source_image)
        if self.transform2both:
            trans_data = apply_transform(self.transform2both, {'image': source_image, "label": source_label})
            source_image = trans_data['image']
            source_label = trans_data['label']

        return source_image, source_label, source_subfolder, source_path

In [4]:
class Seg_IXI_Dataset(Dataset):
    """IXI segmentation dataset: one pickled ``(image, label)`` file per subject.

    There is no index-based split. Every file in ``source_root_dir`` is used
    (sorted by name), whatever ``is_Training``/``is_val`` say. The split is
    selected by pointing ``source_root_dir`` at a different directory. The two
    flags are kept so the signature matches the sibling datasets.

    ``__getitem__`` returns ``(image, label, filename, file_path)``; image and
    label get a leading channel axis of size 1.
    """

    # Raw label values mapped to background (0).
    _IGNORED_LABELS = (0, 5, 26, 30, 44, 58, 62, 72, 77, 80, 85,
                       251, 252, 253, 254, 255)
    # Raw label values kept, mapped in order onto classes 1..30.
    _KEPT_LABELS = (2, 3, 4, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 28,
                    31, 41, 42, 43, 46, 47, 49, 50, 51, 52, 53, 54, 60, 63)
    # Raw label value -> class id. Values not listed are left unchanged.
    LABEL_MAPPING = {
        **dict.fromkeys(_IGNORED_LABELS, 0),
        **dict(zip(_KEPT_LABELS, range(1, len(_KEPT_LABELS) + 1))),
    }

    def __init__(self, source_root_dir, transform2img=None, transform2both=None, is_Training=True, is_val=False):
        self.source_root_dir = source_root_dir
        self.transform2img = transform2img      # applied to the image only
        self.transform2both = transform2both    # applied to {'image', 'label'} together
        self.source_files = sorted(os.listdir(source_root_dir))

    def __len__(self):
        return len(self.source_files)

    def __getitem__(self, idx):
        source_filename = self.source_files[idx]
        target_path = os.path.join(self.source_root_dir, source_filename)
        # NOTE(review): pkload unpickles the file, which can execute arbitrary
        # code -- only point this dataset at trusted data.
        target_image_np, target_label_np = pkload(target_path)

        # Reverse the axis order (presumably to match the array layout that
        # SimpleITK gives the other datasets -- TODO confirm) and make the
        # arrays contiguous so torch can wrap them.
        target_image_np = np.ascontiguousarray(np.transpose(target_image_np, (2, 1, 0)))
        target_label_np = np.ascontiguousarray(np.transpose(target_label_np, (2, 1, 0)))

        # Remap into a copy while always matching against the ORIGINAL values,
        # so the result does not depend on the mapping's iteration order.
        remapped = target_label_np.copy()
        for old_label, new_label in self.LABEL_MAPPING.items():
            remapped[target_label_np == old_label] = new_label
        target_label_np = remapped.astype(int)

        target_image = torch.Tensor(target_image_np).unsqueeze(dim=0)
        target_label = torch.Tensor(target_label_np).unsqueeze(dim=0)

        # Apply the transforms the constructor accepts, in the same order as
        # the sibling datasets (image-only first, then joint).
        if self.transform2img:
            target_image = apply_transform(self.transform2img, target_image)
        if self.transform2both:
            trans_data = apply_transform(self.transform2both, {'image': target_image, "label": target_label})
            target_image = trans_data['image']
            target_label = trans_data['label']

        return target_image, target_label, source_filename, target_path

# Metrics

In [5]:
class AverageMeter(object):
    """Running tracker of the latest value, the weighted mean and the spread.

    ``avg`` is the ``n``-weighted mean of every value passed to ``update``.
    ``std`` is ``np.std`` (population, unweighted) of the raw values passed to
    ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0
        self.vals = []
        self.std = 0

    def update(self, val, n=1):
        """Record ``val`` as standing for ``n`` samples and refresh the stats."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        self.vals.append(val)
        self.std = np.std(self.vals)

def dice_val(y_pred, y_true, num_clus):
    """Mean foreground Dice between two integer label tensors.

    Both inputs are one-hot encoded with ``num_clus`` classes, squeezed on
    dim 1 and permuted so that classes become dim 1. The indexing therefore
    assumes inputs shaped (B, 1, D, H, W) with int64 class ids -- confirm at
    call sites. Class 0 (background) is dropped. Per-class Dice is
    ``2|A∩B| / (|A|+|B|+1e-5)``, averaged over classes and then over the batch.
    A class absent from both inputs scores 0 and pulls the mean down.
    """
    def _channels_first_one_hot(labels):
        # (B, 1, D, H, W) -> (B, 1, D, H, W, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        encoded = torch.squeeze(nn.functional.one_hot(labels, num_classes=num_clus), 1)
        return encoded.permute(0, 4, 1, 2, 3).contiguous()

    pred_fg = _channels_first_one_hot(y_pred)[:, 1:]
    true_fg = _channels_first_one_hot(y_true)[:, 1:]
    spatial_dims = [2, 3, 4]
    overlap = (pred_fg * true_fg).sum(dim=spatial_dims)
    total = pred_fg.sum(dim=spatial_dims) + true_fg.sum(dim=spatial_dims)
    per_class = (2. * overlap) / (total + 1e-5)
    return per_class.mean(dim=1).mean()

def multi_class_dice_coefficient(label1, label2, delete_list,num_classes, include_background=True):
    """Per-class Dice between two integer label arrays, plus their mean.

    Args:
        label1, label2: NumPy arrays of the same shape holding class ids
            (e.g. prediction and ground truth).
        delete_list: class ids to skip entirely. Skipped classes are NOT
            represented in the returned list, so list positions only line up
            with ``class_id - first_class`` when ``delete_list`` is empty.
        num_classes: classes evaluated are ``range(first_class, num_classes)``.
        include_background: if False, class 0 is skipped (``first_class = 1``).

    Returns:
        ``(dice_coefficients, avg_dice)``. ``dice_coefficients`` holds one Dice
        per evaluated class; a class absent from both arrays gets 0.0.
        ``avg_dice`` is the mean over classes present in at least one array,
        rounded to 3 decimals, or NaN if no evaluated class is present.
    """
    excluded = set(delete_list)
    dice_coefficients = []
    present_scores = []

    for class_idx in range(0 if include_background else 1, num_classes):
        if class_idx in excluded:
            continue

        in1 = label1 == class_idx
        in2 = label2 == class_idx
        intersection = np.count_nonzero(in1 & in2)
        union = np.count_nonzero(in1) + np.count_nonzero(in2)

        dice_coefficient = (2.0 * intersection) / (union + 1e-8)
        dice_coefficients.append(dice_coefficient)
        # Exclude from the mean only classes absent from BOTH arrays (Dice is
        # undefined there). A class that is present but completely missed
        # (Dice 0) is a genuine failure and must count toward the mean.
        if union > 0:
            present_scores.append(dice_coefficient)

    avg_dice = round(np.mean(present_scores), 3) if present_scores else float('nan')
    return dice_coefficients, avg_dice
# MONAI metric objects. They are only referenced by disabled code in the
# evaluation cell; the active pipeline uses multi_class_dice_coefficient.
# Number of label classes including background. The model is built with
# num_class=36 and the Dice helpers are called with 36, so the MONAI Dice
# metrics must use the same count.
NUM_SEG_CLASSES = 36
# HD/HD95: percentile=None gives the full (maximum) Hausdorff distance;
# pass percentile=95 for HD95.
HD_metric_before = HausdorffDistanceMetric(percentile=None)
HD_metric_after = HausdorffDistanceMetric(percentile=None)
# Dice over foreground classes; ignore_empty=True yields NaN for cases with
# empty ground truth instead of counting them.
Dice_metric_before = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, ignore_empty=True, num_classes=NUM_SEG_CLASSES)
Dice_metric_after = DiceMetric(include_background=False, reduction="mean", get_not_nans=False, ignore_empty=True, num_classes=NUM_SEG_CLASSES)

# Parameters

In [6]:
# Experiment folder holding the checkpoint (and receiving the result files).
# Alternatives used previously:
#   ".\\experiments\\BCH\\unet_supervised\\"
#   "./experiments/IXI/unet_supervised_lr=1e-3//"
dir_save = ".\\experiments\\OASIS\\Ablation\\real_image+distill\\"
folder_names = os.path.join(dir_save, "test_output")  # optional output folder
batch_size = 1
load_pretrained_model_weights = True

# Test data root. Alternatives (pair with the matching dataset class):
#   "D:/datasets/IXI_data/Test/"      -> Seg_IXI_Dataset
#   "D:/datasets/20230423_pairs/CT/"  -> Seg_CT_MRI_Dataset
source_data_root = "D:/datasets/OASIS/"

# Class ids excluded from the Dice evaluation (previously tried: [18, 19, 34, 35]).
delete_list = []

test_set = Seg_OASIS_Dataset(source_data_root, is_Training=False)
test_set_length = len(test_set)
print('Length of test set: ', test_set_length)

test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

Length of test set:  39


# Evaluate segmentation performance of the trained network
## Load the pretrained model and run inference on the test set

In [7]:
# ==============================================
# Test: evaluate the best checkpoint on the test set
# ==============================================

n_classes = 36  # number of model output classes, including background

if load_pretrained_model_weights:
    dir_load = dir_save  # folder where network weights are stored
    # Distillation model. The plain-UNet baseline used instead:
    #   myUNet(in_channel=1, num_class=7, channel_list=(16, 32, 64, 128, 256),
    #          residual=True, device='cuda')
    model = teacher_student_model(
        in_channel=1,
        num_class=n_classes,
        channel_list=(16, 32, 64, 128, 256),
        residual=True,
        vae=False,
        device='cuda'
    )

    # Other checkpoint patterns used before: "distill_model_kpt_total_loss_best*",
    # "unet_kpt_best_dice*". Sorting makes the choice deterministic when
    # several files match.
    checkpoint_pattern = "distill_model_kpt_best_dice*"
    candidates = sorted(glob.glob(os.path.join(dir_load, checkpoint_pattern)))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint matching {checkpoint_pattern!r} in {dir_load}")
    filename_best_model = candidates[0]
    model.load_state_dict(torch.load(filename_best_model))
    print('Best model: {}'.format(filename_best_model))
elif 'model' not in globals():
    raise RuntimeError("load_pretrained_model_weights is False and no `model` is defined")

model.eval()
best_dsc = 0
best_dsc_filename = ''
best_dsc_path = ''
eval_dsc = AverageMeter()       # per-image mean Dice
all_dice_coefficients = []      # per image: list of per-class Dice
all_image_avg_dice = []         # per image: mean Dice
with torch.no_grad():
    for image_batch, label_batch, moving_path, image_paths in test_loader:
        # moving_path / image_paths are batches of subject names / image paths.
        if moving_path[0] == 'BCH_CT094':
            # Subject excluded from evaluation (only relevant for the BCH CT data).
            continue
        x = image_batch.cuda()
        x_seg = label_batch.cuda()

        # One forward pass; the reconstruction target is x itself.
        # NOTE(review): the evaluated prediction comes from the
        # is_Training=True branch, which also returns teacher/student features
        # and the reconstruction. Confirm this (rather than is_Training=False)
        # is the intended test-time path.
        teacher_features, recon, pred_label, student_features = model(x, x, is_Training=True)
        pred_label = torch.argmax(pred_label, dim=1)

        dice_coefficients, avg_dice = multi_class_dice_coefficient(
            pred_label.squeeze().cpu().numpy(),
            x_seg.squeeze().cpu().numpy(),
            delete_list, n_classes, include_background=False)
        all_dice_coefficients.append(dice_coefficients)
        all_image_avg_dice.append(avg_dice)
        eval_dsc.update(avg_dice)

        if avg_dice > best_dsc:
            best_dsc_filename = moving_path[0]
            best_dsc_path = image_paths[0]
            best_dsc = avg_dice

# Per-class statistics across images. Note the value printed after "±" is the
# variance (not the standard deviation).
average_dice_coefficients = np.mean(all_dice_coefficients, axis=0)
variance_dice_coefficients = np.var(all_dice_coefficients, axis=0)
print(f'Dice | {eval_dsc.avg:.3f}, std: {eval_dsc.std:.3f}')
print('Best dsc image: ', best_dsc_filename)

# Class ids in the same order as the per-class Dice lists:
# multi_class_dice_coefficient skips classes in delete_list, so skip them here too.
id_list = [i for i in range(1, n_classes) if i not in delete_list]
for class_id, avg_dice, var_dice in zip(id_list, average_dice_coefficients, variance_dice_coefficients):
    print(f"Dice Coefficient for class {class_id}: {avg_dice:.3f}±{var_dice:.3f}")


    

Best model: .\experiments\OASIS\Ablation\real_image+distill\distill_model_kpt_best_dice_55_0.8107.pth
Dice | 0.817, std: 0.018
Best dsc image:  ('OASIS_OAS1_0438_MR1',)
Dice Coefficient for class 1: 0.890±0.000
Dice Coefficient for class 2: 0.772±0.001
Dice Coefficient for class 3: 0.914±0.001
Dice Coefficient for class 4: 0.459±0.013
Dice Coefficient for class 5: 0.891±0.001
Dice Coefficient for class 6: 0.907±0.000
Dice Coefficient for class 7: 0.929±0.000
Dice Coefficient for class 8: 0.877±0.002
Dice Coefficient for class 9: 0.914±0.000
Dice Coefficient for class 10: 0.912±0.000
Dice Coefficient for class 11: 0.851±0.003
Dice Coefficient for class 12: 0.821±0.002
Dice Coefficient for class 13: 0.947±0.000
Dice Coefficient for class 14: 0.823±0.007
Dice Coefficient for class 15: 0.847±0.000
Dice Coefficient for class 16: 0.834±0.001
Dice Coefficient for class 17: 0.893±0.000
Dice Coefficient for class 18: 0.499±0.032
Dice Coefficient for class 19: 0.567±0.005
Dice Coefficient for cl

In [8]:
# Disabled feature-map visualisation. The whole cell is a single string
# literal, so running it only echoes the text. To re-enable, turn the string
# back into code. It reads `teacher_features` from the evaluation cell's last
# processed image, slices features[0, :, 56, :, :] (index 56 along dim 2) and
# plots channel 0 of that slice with the 'viridis' colormap.
'''
import matplotlib.pyplot as plt
import numpy as np

# 假设 features 是 torch.Size([1, 32, 112, 96, 80]) 的张量
features = teacher_features[0]
print(features.shape)
# 选择要可视化的切片索引
slice_index = 56  # 选择切片的索引

# 选择在维度上的切片
slice_data = features[0, :,slice_index, :, :]

# 转换为NumPy数组
slice_data_np = slice_data.cpu().numpy()

# 可视化
plt.imshow(slice_data_np[0, :, :], cmap='viridis')  # 使用 viridis 颜色图
plt.title('Visualization of Feature Slice')
plt.colorbar()
plt.show()
'''

"\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# 假设 features 是 torch.Size([1, 32, 112, 96, 80]) 的张量\nfeatures = teacher_features[0]\nprint(features.shape)\n# 选择要可视化的切片索引\nslice_index = 56  # 选择切片的索引\n\n# 选择在维度上的切片\nslice_data = features[0, :,slice_index, :, :]\n\n# 转换为NumPy数组\nslice_data_np = slice_data.cpu().numpy()\n\n# 可视化\nplt.imshow(slice_data_np[0, :, :], cmap='viridis')  # 使用 viridis 颜色图\nplt.title('Visualization of Feature Slice')\nplt.colorbar()\nplt.show()\n"

In [9]:
# Number of per-class Dice scores recorded for each image (one per evaluated class).
structure_count = len(all_dice_coefficients[0])
print(structure_count)

35


In [10]:

import pickle

# Persist the evaluation results next to the checkpoint:
#   file_path  <- per image, the list of per-class Dice scores
#   file_path2 <- per image, the mean Dice score
file_path = os.path.join(dir_save, 'all_images_all_struct_dice_coefficients.pkl')
file_path2 = os.path.join(dir_save, 'all_images_dice_coefficients.pkl')
for target_file, payload in ((file_path, all_dice_coefficients), (file_path2, all_image_avg_dice)):
    with open(target_file, 'wb') as handle:
        pickle.dump(payload, handle)
print(f'all_images_all_struct_dice_coefficients已保存到文件: {file_path}')
print(f'all_images_dice_coefficients已保存到文件: {file_path2}')


all_images_all_struct_dice_coefficients已保存到文件: .\experiments\OASIS\Ablation\real_image+distill\all_images_all_struct_dice_coefficients.pkl
all_images_dice_coefficients已保存到文件: .\experiments\OASIS\Ablation\real_image+distill\all_images_dice_coefficients.pkl


In [11]:
# Disabled export of `comb_all_dice_coefficients` (per-image Dice scores
# averaged over pairs of structures). The cell is a single string literal, so
# running it only echoes the text. Re-enabling it also requires the code that
# fills that list, which is currently disabled in the evaluation cell.
'''
import pickle
#dir_save = '.\\experiments\\ICML_baselines\\OASIS_ICML_baselines\\Ants-Affine\\'
# 指定保存文件的路径和名称
file_path = os.path.join(dir_save,'comb_all_images_all_struct_dice_coefficients.pkl')
# 使用pickle将列表保存到文件中
with open(file_path, 'wb') as file:
    pickle.dump(comb_all_dice_coefficients, file)
print(f'all_images_all_struct_dice_coefficients已保存到文件: {file_path}')
'''

"\nimport pickle\n#dir_save = '.\\experiments\\ICML_baselines\\OASIS_ICML_baselines\\Ants-Affine\\'\n# 指定保存文件的路径和名称\nfile_path = os.path.join(dir_save,'comb_all_images_all_struct_dice_coefficients.pkl')\n# 使用pickle将列表保存到文件中\nwith open(file_path, 'wb') as file:\n    pickle.dump(comb_all_dice_coefficients, file)\nprint(f'all_images_all_struct_dice_coefficients已保存到文件: {file_path}')\n"