In [1]:
import os
import tempfile
import warnings
import glob
import torch
import utils
import SimpleITK as sitk
import numpy as np
import pystrum.pynd.ndutils as nd
import matplotlib.pyplot as plt
import torch.nn as nn
from model.teacher_student_model import teacher_student_model

# Metrics

In [2]:
class AverageMeter:
    """Running tracker for a scalar metric.

    Attributes maintained:
        val   -- the most recently recorded value
        sum   -- running sum of ``val * n`` over all updates
        count -- running sum of the weights ``n``
        avg   -- ``sum / count`` (weighted running mean)
        vals  -- every value passed to ``update``, in order
        std   -- population standard deviation of ``vals`` (each value counted
                 once, independent of its weight ``n``)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every statistic to its initial zero / empty state."""
        self.val = self.avg = self.sum = self.count = 0
        self.vals = []
        self.std = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the derived statistics."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.vals.append(val)
        # np.std uses ddof=0 by default, i.e. the population standard deviation.
        self.std = np.std(self.vals)

def multi_class_dice_coefficient(label1, label2,num_classes, include_background=True):
    dice_coefficients = []

    for class_idx in range(0 if include_background else 1, num_classes):
        
        label1_binary = (label1 == class_idx).astype(int)
        label2_binary = (label2 == class_idx).astype(int)
        intersection = np.sum(label1_binary * label2_binary)
        union = np.sum(label1_binary) + np.sum(label2_binary)

        dice_coefficient = (2.0 * intersection) / (union + 1e-8) 
        dice_coefficients.append(dice_coefficient)
    
    # 过滤掉值为0的项
    filtered_dice_coefficients = [x for x in dice_coefficients if x != 0]
    # 计算均值
    avg_dice = round(np.mean(filtered_dice_coefficients), 3)   
    #avg_dice = round(np.mean(dice_coefficients), 3)           
    return dice_coefficients,avg_dice

# Parameters and input data

In [3]:
# Evaluation configuration -- edit these for your own model and data.
dir_save = "your model path"  # folder holding the trained weights
batch_size = 1
num_classes = 36  # 36 for OASIS, 27 for CT, 31 for IXI
folder_names = os.path.join(dir_save, "test_output")  # where predictions are written
load_pretrained_model_weights = True
pth_filename = "OS-MedSeg_pre-trained_model_on_OASIS*"  # glob pattern for the pre-trained weight file

# Input volume and its ground-truth segmentation (replace with your own paths).
source_image_path = "D:/datasets/OASIS/OASIS_OAS1_0001_MR1/aligned_norm.nii.gz"
source_label_path = "D:/datasets/OASIS/OASIS_OAS1_0001_MR1/aligned_seg35.nii.gz"
source_image_sitk, source_label_sitk = (
    sitk.ReadImage(source_image_path),
    sitk.ReadImage(source_label_path),
)

# SimpleITK image -> numpy array
source_image_np = sitk.GetArrayFromImage(source_image_sitk)
source_label_np = sitk.GetArrayFromImage(source_label_sitk)

# torch.Tensor converts to the default float dtype; [None, None] prepends the
# batch and channel axes (same as two unsqueeze(0) calls).
source_image = torch.Tensor(source_image_np)[None, None]
source_label = torch.Tensor(source_label_np)[None, None]
print(source_image.shape)

torch.Size([1, 1, 224, 192, 160])


# Evaluate segmentation performance of the trained network
## Load the pretrained model, segment the source image and compute per-class Dice

In [6]:
# ==============================================
# Test
# ==============================================

# Build the network unconditionally so `model` exists even when pretrained
# weights are not loaded. The number of output classes follows `num_classes`
# from the parameter cell.
model = teacher_student_model(
    in_channel=1,
    num_class=num_classes,
    channel_list=(16, 32, 64, 128, 256),
    residual=True,
    vae=False,
    device='cuda'
)

if load_pretrained_model_weights:
    dir_load = dir_save  # folder where network weights are stored
    # sorted() makes the choice deterministic if the pattern matches several files.
    candidate_weights = sorted(glob.glob(os.path.join(dir_load, pth_filename)))
    if not candidate_weights:
        raise FileNotFoundError(
            f"No weight file matching {pth_filename!r} found in {dir_load!r}")
    filename_best_model = candidate_weights[0]
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load(filename_best_model))
    print('Best model: {}'.format(filename_best_model))

model.eval()
eval_dsc = AverageMeter()
with torch.no_grad():
    x = source_image.cuda()

    # A single inference pass; the same image is passed for both model inputs.
    pred_logits, student_features = model(x, x, is_Training=False)
    # Assumes logits are (batch, class, ...) -- TODO confirm; argmax over dim 1 picks the class id.
    pred_label_np = torch.argmax(pred_logits, dim=1).squeeze().cpu().numpy()

    # The ground truth only feeds a numpy metric, so it stays on the CPU.
    gt_label_np = source_label.squeeze().numpy()
    dice_coefficients, avg_dice = multi_class_dice_coefficient(
        pred_label_np, gt_label_np, num_classes,
        include_background=False)  # num_classes: 36 for OASIS, 27 for CT, 31 for IXI
    eval_dsc.update(avg_dice)

    # Save the prediction with the input image's geometry so it overlays the source.
    # Using folder_names directly (not './' + folder_names) also works for absolute paths.
    os.makedirs(folder_names, exist_ok=True)
    pred_label = sitk.GetImageFromArray(pred_label_np.astype(np.int16))
    pred_label.SetSpacing(source_image_sitk.GetSpacing())
    pred_label.SetDirection(source_image_sitk.GetDirection())
    pred_label.SetOrigin(source_image_sitk.GetOrigin())
    sitk.WriteImage(pred_label, os.path.join(folder_names, 'pred_label.nii.gz'))

# metrics
print(f'Dice | {eval_dsc.avg:.3f}')
# Background (class 0) was excluded above, so scores start at class 1.
for class_id, class_dice in enumerate(dice_coefficients, start=1):
    print(f"Dice Coefficient for class {class_id}: {class_dice:.3f}")


    

Best model: .\experiments\OASIS\syn_distill_model_hint_1_seg_1_recon_1_lr=1e-3\distill_model_kpt_best_dice_69_0.8459.pth
Dice | 0.851
Dice Coefficient for class 1: 0.939
Dice Coefficient for class 2: 0.896
Dice Coefficient for class 3: 0.944
Dice Coefficient for class 4: 0.536
Dice Coefficient for class 5: 0.924
Dice Coefficient for class 6: 0.938
Dice Coefficient for class 7: 0.917
Dice Coefficient for class 8: 0.893
Dice Coefficient for class 9: 0.926
Dice Coefficient for class 10: 0.911
Dice Coefficient for class 11: 0.895
Dice Coefficient for class 12: 0.900
Dice Coefficient for class 13: 0.954
Dice Coefficient for class 14: 0.863
Dice Coefficient for class 15: 0.904
Dice Coefficient for class 16: 0.749
Dice Coefficient for class 17: 0.878
Dice Coefficient for class 18: 0.826
Dice Coefficient for class 19: 0.336
Dice Coefficient for class 20: 0.951
Dice Coefficient for class 21: 0.905
Dice Coefficient for class 22: 0.950
Dice Coefficient for class 23: 0.755
Dice Coefficient for cla