使用训练好的模型评估在数据集上的效果

In [1]:
import torch
import os
from DataSet import BrainDataSet
from DataSet import get_data
from torchvision import transforms
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.model_selection import GroupKFold
from OurModel import OurModel
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
path = "E:/DataSet/Alzheimer/mcad_crop/"
# get data information(image name, image label and image center)
mcad_info = pd.read_excel("E:/DataSet/Alzheimer/mcad_info_809.xlsx")
mcad_info = mcad_info[['Subj_t1', 'Group', 'center']]
print(mcad_info)

     Subj_t1  Group  center
0     1_S021      1       3
1     1_S030      1       3
2     1_S035      1       3
3     1_S052      1       3
4     1_S053      1       3
..       ...    ...     ...
804  3_AD039      3       7
805  3_AD040      3       7
806  3_AD041      3       7
807  3_AD042      3       7
808  3_AD043      3       7

[809 rows x 3 columns]


In [3]:
pathes, labels, sites = get_data(path=path, mcad_info=mcad_info)
# print(sites)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])
])

for site in range(7):
    # for each trained model, evaluate its performence on all sites
    group_kf = GroupKFold(n_splits=7)

    net = OurModel(
            block_num_list = [2,2,2,2],
            channel_list=[1,64,128,256,512],
            num_class=2,
            conv_pooling=True,
            domain_knowledge=False,
            domain_knowledge_len=11,
            train_sites_num=[0 for i in range(6)],
            mlp_ratio=1
        )

    net = torch.load('model/chan=1-512 domain=F testsite={} best_network.pth'.format(site))
    net = net.cuda()
    print("model trained without site{0}".format(site))

    for train_index, test_index in group_kf.split(pathes, labels, sites):
        # use test_site to evaluate the model's performence
        
        # print(test_site)

        test_pathes = pathes[test_index]
        test_labels = labels[test_index]
        test_sites = sites[test_index]

        dataset = BrainDataSet(img_pathes=test_pathes, 
                               labels=test_labels, 
                               sites = test_sites)
        
        dataloader = DataLoader(dataset=dataset,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False)
        acc = 0.
        for img, label, site in dataloader:
            img = img.unsqueeze(dim = 1).cuda()
            img = (img - torch.min(img)) / (torch.max(img) - torch.min(img))
            # print(img.shape)
            label = label.cuda()
            site = site.cuda()
            pre_label = net(img, train = False, site = site, gender = None, age = None, MMSE = None)
            pre_label = torch.argmax(F.softmax(pre_label), dim = 1)
            acc += (pre_label == label).sum()

        print("site{0} acc:{1}".format(test_sites[0], acc / len(dataloader)))
   
    print("---------------------------------------------------------------------------------\n")

model trained without site0


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:1.0
site3 acc:0.8545454144477844
site0 acc:0.8705882430076599
site6 acc:0.9638553857803345
site2 acc:0.9672130346298218
site5 acc:0.9818181395530701
site1 acc:1.0
---------------------------------------------------------------------------------

model trained without site1


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:1.0
site3 acc:0.9545454382896423
site0 acc:1.0
site6 acc:0.9999999403953552
site2 acc:0.9999999403953552
site5 acc:0.9999999403953552
site1 acc:0.9111111164093018
---------------------------------------------------------------------------------

model trained without site2


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:0.991150438785553
site3 acc:0.7999999523162842
site0 acc:1.0
site6 acc:0.9879517555236816
site2 acc:0.7704917788505554
site5 acc:0.9818181395530701
site1 acc:1.0
---------------------------------------------------------------------------------

model trained without site3


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:1.0
site3 acc:0.7454545497894287
site0 acc:1.0
site6 acc:0.9759035706520081
site2 acc:0.9999999403953552
site5 acc:0.9999999403953552
site1 acc:1.0
---------------------------------------------------------------------------------

model trained without site4


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:0.9115044474601746
site3 acc:0.8363636136054993
site0 acc:1.0
site6 acc:0.9879517555236816
site2 acc:0.9180327653884888
site5 acc:0.9454545378684998
site1 acc:1.0
---------------------------------------------------------------------------------

model trained without site5


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:0.991150438785553
site3 acc:0.7363635897636414
site0 acc:0.9529411792755127
site6 acc:0.9638553857803345
site2 acc:0.9180327653884888
site5 acc:0.9090908765792847
site1 acc:1.0
---------------------------------------------------------------------------------

model trained without site6


  pre_label = torch.argmax(F.softmax(pre_label), dim = 1)


site4 acc:0.991150438785553
site3 acc:0.7545454502105713
site0 acc:1.0
site6 acc:0.8915662169456482
site2 acc:0.9180327653884888
site5 acc:0.9636363387107849
site1 acc:1.0
---------------------------------------------------------------------------------

