In [1]:
import cv2
import numpy as np
import os
import pandas as ps
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
# This is for the progress bar.
from tqdm.auto import tqdm
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import save_image
from model import UNet, MyDataset
from torch.utils.data import random_split
from my_metrics import calculate_metrics
import csv

In [None]:
#设置设备，选择cuda
device = 'cuda' if torch.backends.mps.is_available() else 'cpu'

weight_path='params/unet.pth'
model = UNet().to(device)

if os.path.exists(weight_path):
    model.load_state_dict(torch.load(weight_path))
    print('successful load weight！\n')
else:
    print('model is not exist!\n')
    exit()

for i in range(1,6):
    data_folder_path = f'../data/FAZ/Domain{i}/test/imgs'
    mask_folder_path = f'../data/FAZ/Domain{i}/test/mask'
    save_path = f"result/Domain{i}"
    if not os.path.isdir('result'):
        os.mkdir('result')
    if not os.path.isdir(save_path):
        os.mkdir(save_path)

    data_files = os.listdir(data_folder_path)
    mask_files= os.listdir(mask_folder_path)
    del_count = 0 
    for j in range(len(data_files)):
        if data_files[j-del_count].endswith(".png"):
            data_files[j-del_count] = data_files[j-del_count]
        else:
            del data_files[j-del_count]
            del_count = del_count+1
    del_count = 0      
    for j in range(len(mask_files)):
        if mask_files[j-del_count].endswith(".png"):
            mask_files[j-del_count] = mask_files[j-del_count]
        else:
            del mask_files[j-del_count]
            del_count = del_count+1

    dataset = MyDataset(data_file=data_files,mask_files=mask_files,data_folder_path=data_folder_path,mask_folder_path=mask_folder_path,test_flag=True)
    datalen = len(dataset)
    testLoader = DataLoader(dataset,batch_size=8,shuffle=True)
    [all_dc,all_jc,all_hd,all_assd,all_sp,all_recall,all_pre] = [0,0,0,0,0,0,0]
    model.eval()
    print(f"start testing domain{i}\n")
    cur = 1
    
    for idx,(img,seg_img,fname) in enumerate(testLoader):
        img = img.to(device)
        seg_img = seg_img.to(device)
        with torch.no_grad():
             output = model(img)
        
        
        for pred,target,fname in zip(output, seg_img, fname):
            dice,jaccard,hd95_score,assd_score,sp_score,recall_score,pre_score = calculate_metrics(pred,target)
            all_dc += dice/datalen
            all_jc += jaccard/datalen
            all_hd += hd95_score/datalen
            all_assd += assd_score/datalen
            all_sp += sp_score/datalen
            all_recall += recall_score/datalen
            all_pre += pre_score/datalen

            origin_img = torch.tensor(cv2.imread(os.path.join(data_folder_path,fname),cv2.IMREAD_GRAYSCALE)/255).reshape(1,256,256)
            
            total=torch.stack([origin_img,target,pred],dim=0)
            save_image(total,f'{save_path}/{fname}.png')
    print(f"the metrics of model on Domain{i} are as below")
    print(all_dc.data.cpu().numpy(),all_jc,all_hd,all_assd,all_sp,all_recall,all_pre)
    print()

print("finish testing\n") 


successful load weight！

start testing domain1

the metrics of model on Domain1 are as below
0.62106675 0.7910184230180797 3.3514830254297117 0.3955501218346783 0.9964969960099316 0.9150526561274266 0.8601987401347958

start testing domain2

the metrics of model on Domain2 are as below
0.62489855 0.736247502089821 2.321065072179495 0.42308995802728394 0.9990789453877397 0.7509039973981944 0.9763236000382434

start testing domain3

the metrics of model on Domain3 are as below
0.64836144 0.6781137177734373 19.039196270588334 3.139345870228153 0.9908096059632602 0.903381890738109 0.7509016785723492

start testing domain4

