In [11]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from copy import deepcopy
import os 
import random
from PIL import Image
import torchvision
from pathlib import Path
from tqdm import tqdm
from skimage import io
import skimage
import cv2
import timm

from torch import nn
from torchinfo import summary
# from skorch import NeuralNetClassifier
# from sklearn.model_selection import GridSearchCV
from torchmetrics import Accuracy
from torchmetrics.classification import BinarySpecificity, BinaryPrecision, BinaryRecall, BinaryAccuracy, BinaryF1Score
from torchvision.models import vgg19
from PIL import Image
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc, brier_score_loss
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations
from albumentations.pytorch import ToTensorV2

from models import get_model
from dataset import *
from utils import *
from run_full_tbm import *

if __name__ == '__main__':
    torch.cuda.empty_cache() 
    all_model = [
        'vgg11',
        'resnet10',
        'resnet50',
        'resnet50_med3d',
        'densenet_201'
    ]

    checkpath = [
        'DTI/wdtifit_FA.nii',
        'DTI/wdtifit_MD.nii',
        'DTI/wfdt_paths.nii',
        'DTI/wnfdt_paths.nii',
        'DTI/wnodif.nii',
        'T2s/wR2S.nii',
        'T2s/wrealigne01.nii'
    ]
    filenames = [
        # 'DTI/wdtifit_FA.nii',
        # 'DTI/wdtifit_MD.nii',
        # 'DTI/wfdt_paths.nii',
        # 'DTI/wnfdt_paths.nii',
        # 'T2s/wR2S.nii',
        # 'T2s/wrealigne01.nii',
        'fmri_mirror/con_0001.nii',
        'fmri_mirror/con_0003.nii'
    ]
    should_normalize = [
        # False,
        # False,
        # True,
        # True,
        # True,
        # True,
        True,
        True

    ]
    
    checkpoints_dir = './checkpoints'
    tensorboard_dir = './runs'
    basepath = '/data1/TBM/TBM-AI_data/data_by_subject'
    excel_path = './subjects_info_TBM-AI.xlsx'
    checkpoints_suffix = "matchage"
    
    img_size = 64
    max_epochs = 100
    batch_size = 16
    num_workers = 8
    random_seed = 42
    lr = 1e-4
    torch.backends.cudnn.benchmark = True
    pin_memory = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: ", device)

    df = pd.DataFrame(columns = ['model', 'data','version', 'bacc', 'acc', 'precision', 'recall', 'f1', 'rocauc', 'specificity'])

    for filename, norm in zip(filenames, should_normalize):
        print(f"Start {filename}")
        file_list, labels, nodata = get_list_files_and_labels(basepath, excel_path, filename)
        # mask_list, _, _ = get_list_files_and_labels(basepath, excel_path, mask_filename)

        train_tmp_indices = np.load('train_tmp_indices.npy', allow_pickle=True)
        train_indices = np.load('train_indices.npy', allow_pickle=True)
        val_indices = np.load('val_indices.npy', allow_pickle=True)
        test_indices = np.load('test_indices.npy', allow_pickle=True)
        file_list = np.array(file_list)
        labels = np.array(labels)

        ###########################################################################################################
        for model_name in all_model:
            for i, (imp_index, train_index, val_index, test_index) in enumerate(zip(train_tmp_indices, train_indices, val_indices, test_indices)):
                try:

                    version = i
                    imp_index = np.array(imp_index).astype(int)
                    train_index = np.array(train_index).astype(int)
                    val_index = np.array(val_index).astype(int)
                    test_index = np.array(test_index).astype(int)
                    
                    checkpoint_name = f"{filename.split('/')[-1]}-{model_name}-{version}{checkpoints_suffix}"
                    tensorboard_suffix = f"{model_name}-{version}"
                    print(f"RUN: {checkpoint_name}")

                    X_train, X_val, y_train, y_val = file_list[imp_index], file_list[val_index], labels[imp_index], labels[val_index]
                    
                    if norm:
                        mean_data, std_data = get_normalization_param_nomask(X_train)
                    else:
                        mean_data, std_data = 0,1
                    train_aug = get_transform(mean_data, std_data)
                    valid_aug = get_transform(mean_data, std_data)

                    X_train, X_test, y_train, y_test = X_train[train_index], X_train[test_index], y_train[train_index], y_train[test_index]
  
                    train_loader = get_loader(X_train, y_train, train_aug, mode = 'train', batch_size = batch_size, pin_memory = pin_memory, num_workers = num_workers, img_size = img_size)
                    
                    val_loader = get_loader(X_val, y_val, valid_aug, mode = 'val', batch_size = batch_size, pin_memory = pin_memory, num_workers= num_workers, img_size = img_size)
                    test_loader = get_loader(X_test, y_test, valid_aug, mode = 'val', batch_size = batch_size, pin_memory = pin_memory, num_workers=num_workers, img_size = img_size)

                    model = get_model(model_name)
                    model = model.to(device)



                    model.load_state_dict(load_checkpoint(checkpoints_dir, checkpoint_name, 'bacc'))
                    model.eval()
                    test_outputs_all = []
                    test_labels_all = []
                    with torch.no_grad():
                        for test_images, test_labels in tqdm(test_loader):
                            test_images, test_labels = test_images.to(device), test_labels.to(device)
                            test_outputs = model(test_images)

                            test_outputs = test_outputs.squeeze(-1)
                            test_outputs_all.append(test_outputs.sigmoid().cpu().numpy())
                            test_labels_all.append(test_labels.cpu().numpy())
                    
                    test_outputs_all = np.concatenate(test_outputs_all)
                    test_labels_all = np.concatenate(test_labels_all)

                    test_outputs_all[test_outputs_all>=0.5] = 1
                    test_outputs_all[test_outputs_all<0.5] = 0

                    bacc = metrics.balanced_accuracy_score(test_labels_all, test_outputs_all)
                    acc = metrics.accuracy_score(test_labels_all, test_outputs_all)

                    results = cal_metrics_binary(test_labels_all, test_outputs_all)
                    tmp_df = pd.DataFrame({
                        'model': [model_name],
                        'version': [version],
                        'data': filename.split('/')[-1],
                        'bacc': [bacc],
                        'acc': [results['acc']],
                        'precision': [results['precision'][1]],
                        'recall': [results['recall'][1]],
                        'f1': [results['f1'][1]],
                        'rocauc': [results['rocauc']],
                        'specificity': [results['specificity']]
                    })
                    df = pd.concat([df, tmp_df])

                except Exception as e:
                    print(e)
df.to_csv('results_final_con.csv', index= False)
# run 886

Device:  cuda
Start fmri_mirror/con_0001.nii
RUN: con_0001.nii-vgg11-0matchage


100%|██████████| 7/7 [00:01<00:00,  5.19it/s]


RUN: con_0001.nii-vgg11-1matchage


100%|██████████| 7/7 [00:01<00:00,  5.56it/s]


RUN: con_0001.nii-vgg11-2matchage


100%|██████████| 7/7 [00:01<00:00,  5.57it/s]


RUN: con_0001.nii-vgg11-3matchage


100%|██████████| 7/7 [00:01<00:00,  5.15it/s]


RUN: con_0001.nii-vgg11-4matchage


100%|██████████| 7/7 [00:01<00:00,  4.78it/s]


RUN: con_0001.nii-resnet10-0matchage


100%|██████████| 7/7 [00:01<00:00,  5.05it/s]


RUN: con_0001.nii-resnet10-1matchage


100%|██████████| 7/7 [00:01<00:00,  4.86it/s]


RUN: con_0001.nii-resnet10-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.87it/s]


RUN: con_0001.nii-resnet10-3matchage


100%|██████████| 7/7 [00:01<00:00,  4.93it/s]


RUN: con_0001.nii-resnet10-4matchage


100%|██████████| 7/7 [00:01<00:00,  5.20it/s]


RUN: con_0001.nii-resnet50-0matchage


100%|██████████| 7/7 [00:01<00:00,  4.74it/s]


RUN: con_0001.nii-resnet50-1matchage


100%|██████████| 7/7 [00:01<00:00,  4.61it/s]


RUN: con_0001.nii-resnet50-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.65it/s]


RUN: con_0001.nii-resnet50-3matchage


100%|██████████| 7/7 [00:01<00:00,  4.64it/s]


RUN: con_0001.nii-resnet50-4matchage


100%|██████████| 7/7 [00:01<00:00,  4.70it/s]


RUN: con_0001.nii-resnet50_med3d-0matchage


100%|██████████| 7/7 [00:01<00:00,  4.89it/s]


RUN: con_0001.nii-resnet50_med3d-1matchage


100%|██████████| 7/7 [00:01<00:00,  5.10it/s]


RUN: con_0001.nii-resnet50_med3d-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.74it/s]


RUN: con_0001.nii-resnet50_med3d-3matchage


100%|██████████| 7/7 [00:01<00:00,  5.24it/s]


RUN: con_0001.nii-resnet50_med3d-4matchage


100%|██████████| 7/7 [00:01<00:00,  5.01it/s]


RUN: con_0001.nii-densenet_201-0matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.62it/s]


RUN: con_0001.nii-densenet_201-1matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.40it/s]


RUN: con_0001.nii-densenet_201-2matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.66it/s]


RUN: con_0001.nii-densenet_201-3matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.50it/s]


RUN: con_0001.nii-densenet_201-4matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.65it/s]


Start fmri_mirror/con_0003.nii
RUN: con_0003.nii-vgg11-0matchage


100%|██████████| 7/7 [00:01<00:00,  5.14it/s]


RUN: con_0003.nii-vgg11-1matchage


100%|██████████| 7/7 [00:01<00:00,  5.28it/s]


RUN: con_0003.nii-vgg11-2matchage


100%|██████████| 7/7 [00:01<00:00,  5.65it/s]


RUN: con_0003.nii-vgg11-3matchage


100%|██████████| 7/7 [00:01<00:00,  5.58it/s]


RUN: con_0003.nii-vgg11-4matchage


100%|██████████| 7/7 [00:01<00:00,  5.84it/s]


RUN: con_0003.nii-resnet10-0matchage


100%|██████████| 7/7 [00:01<00:00,  4.74it/s]


RUN: con_0003.nii-resnet10-1matchage


100%|██████████| 7/7 [00:01<00:00,  4.93it/s]


RUN: con_0003.nii-resnet10-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.82it/s]


RUN: con_0003.nii-resnet10-3matchage


100%|██████████| 7/7 [00:01<00:00,  3.87it/s]


RUN: con_0003.nii-resnet10-4matchage


100%|██████████| 7/7 [00:01<00:00,  5.02it/s]


RUN: con_0003.nii-resnet50-0matchage


100%|██████████| 7/7 [00:01<00:00,  4.76it/s]


RUN: con_0003.nii-resnet50-1matchage


100%|██████████| 7/7 [00:01<00:00,  4.96it/s]


RUN: con_0003.nii-resnet50-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.85it/s]


RUN: con_0003.nii-resnet50-3matchage


100%|██████████| 7/7 [00:01<00:00,  4.87it/s]


RUN: con_0003.nii-resnet50-4matchage


100%|██████████| 7/7 [00:01<00:00,  4.88it/s]


RUN: con_0003.nii-resnet50_med3d-0matchage


100%|██████████| 7/7 [00:01<00:00,  5.05it/s]


RUN: con_0003.nii-resnet50_med3d-1matchage


100%|██████████| 7/7 [00:01<00:00,  4.93it/s]


RUN: con_0003.nii-resnet50_med3d-2matchage


100%|██████████| 7/7 [00:01<00:00,  4.97it/s]


RUN: con_0003.nii-resnet50_med3d-3matchage


100%|██████████| 7/7 [00:01<00:00,  4.86it/s]


RUN: con_0003.nii-resnet50_med3d-4matchage


100%|██████████| 7/7 [00:01<00:00,  4.85it/s]


RUN: con_0003.nii-densenet_201-0matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.18it/s]


RUN: con_0003.nii-densenet_201-1matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.65it/s]


RUN: con_0003.nii-densenet_201-2matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.58it/s]


RUN: con_0003.nii-densenet_201-3matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.75it/s]


RUN: con_0003.nii-densenet_201-4matchage


nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
100%|██████████| 7/7 [00:01<00:00,  4.38it/s]


In [10]:
df.groupby(['model', 'data']).mean().round(4)

The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.


Unnamed: 0_level_0,Unnamed: 1_level_0,bacc,acc,precision,recall,f1,rocauc,specificity
model,data,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
densenet_201,wR2S.nii,0.5252,0.561,0.3855,0.4111,0.3879,0.5252,0.6393
densenet_201,wdtifit_FA.nii,0.6508,0.6355,0.477,0.7,0.5651,0.6508,0.6016
densenet_201,wdtifit_MD.nii,0.5798,0.6031,0.4388,0.5056,0.4597,0.5798,0.654
densenet_201,wfdt_paths.nii,0.5895,0.6298,0.4554,0.4611,0.456,0.5895,0.7179
densenet_201,wnfdt_paths.nii,0.6426,0.6353,0.4874,0.6667,0.551,0.6426,0.6186
densenet_201,wrealigne01.nii,0.5596,0.5764,0.4071,0.5056,0.4397,0.5596,0.6137
resnet10,wR2S.nii,0.5668,0.5667,0.4204,0.5667,0.4619,0.5668,0.567
resnet10,wdtifit_FA.nii,0.6219,0.6373,0.4758,0.5722,0.5176,0.6219,0.6715
resnet10,wdtifit_MD.nii,0.5865,0.6012,0.44,0.5389,0.4819,0.5865,0.634
resnet10,wfdt_paths.nii,0.5559,0.5612,0.4074,0.5389,0.4555,0.5559,0.5728


In [14]:
df

Unnamed: 0,model,data,version,bacc,acc,precision,recall,f1,rocauc,specificity
0,vgg11,con_0001.nii,0,0.501634,0.519231,0.347826,0.444444,0.390244,0.501634,0.558824
0,vgg11,con_0001.nii,1,0.529589,0.6,0.392857,0.305556,0.34375,0.529589,0.753623
0,vgg11,con_0001.nii,2,0.550725,0.619048,0.428571,0.333333,0.375,0.550725,0.768116
0,vgg11,con_0001.nii,3,0.512681,0.542857,0.357143,0.416667,0.384615,0.512681,0.608696
0,vgg11,con_0001.nii,4,0.562198,0.590476,0.414634,0.472222,0.441558,0.562198,0.652174
0,resnet10,con_0001.nii,0,0.485294,0.557692,0.321429,0.25,0.28125,0.485294,0.720588
0,resnet10,con_0001.nii,1,0.588164,0.580952,0.423077,0.611111,0.5,0.588164,0.565217
0,resnet10,con_0001.nii,2,0.490338,0.609524,0.307692,0.111111,0.163265,0.490338,0.869565
0,resnet10,con_0001.nii,3,0.625604,0.647619,0.487805,0.555556,0.519481,0.625604,0.695652
0,resnet10,con_0001.nii,4,0.530797,0.619048,0.409091,0.25,0.310345,0.530797,0.811594


In [9]:
df.to_csv('results_final_nocon.csv', index= False)

In [16]:
df = pd.read_csv('results_final_nocon.csv')

In [17]:
df

Unnamed: 0,model,data,version,bacc,acc,precision,recall,f1,rocauc,specificity
0,vgg11,wdtifit_FA.nii,0,0.568627,0.538462,0.400000,0.666667,0.500000,0.568627,0.470588
1,vgg11,wdtifit_FA.nii,1,0.528986,0.485714,0.363636,0.666667,0.470588,0.528986,0.391304
2,vgg11,wdtifit_FA.nii,2,0.606280,0.552381,0.417910,0.777778,0.543689,0.606280,0.434783
3,vgg11,wdtifit_FA.nii,3,0.576691,0.609524,0.435897,0.472222,0.453333,0.576691,0.681159
4,vgg11,wdtifit_FA.nii,4,0.673913,0.676190,0.521739,0.666667,0.585366,0.673913,0.681159
...,...,...,...,...,...,...,...,...,...,...
145,densenet_201,wrealigne01.nii,0,0.526961,0.586538,0.387097,0.333333,0.358209,0.526961,0.720588
146,densenet_201,wrealigne01.nii,1,0.567029,0.561905,0.403846,0.583333,0.477273,0.567029,0.550725
147,densenet_201,wrealigne01.nii,2,0.617150,0.619048,0.458333,0.611111,0.523810,0.617150,0.623188
148,densenet_201,wrealigne01.nii,3,0.550121,0.504762,0.378788,0.694444,0.490196,0.550121,0.405797


In [5]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from copy import deepcopy
import os 
import random
from PIL import Image
import torchvision
from pathlib import Path
from tqdm import tqdm
from skimage import io
import skimage
import cv2
import timm

from torch import nn
from torchinfo import summary
# from skorch import NeuralNetClassifier
# from sklearn.model_selection import GridSearchCV
from torchmetrics import Accuracy
from torchmetrics.classification import BinarySpecificity, BinaryPrecision, BinaryRecall, BinaryAccuracy, BinaryF1Score
from torchvision.models import vgg19
from PIL import Image
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc, brier_score_loss
from sklearn.model_selection import train_test_split, StratifiedKFold
import albumentations
from albumentations.pytorch import ToTensorV2

from models import get_model
from dataset import *
from utils import *
from run import cal_metrics_binary


torch.cuda.empty_cache() 
all_model = [
    'vgg11',
]
basepath = '/data1/TBM/TBM-AI_data/data_by_subject'
excel_path = '/data1/TBM/TBM-AI_data/subj_info.csv'
filenames = [
        'DTI/wdtifit_FA.nii',
        'DTI/wdtifit_MD.nii',
        'DTI/wfdt_paths.nii',
        'DTI/wnfdt_paths.nii',
        'T2s/wR2S.nii',
        'T2s/wrealigne01.nii',
    ]
mask_filename = 'DTI/wnodif_brain_mask.nii'
checkpoints_dir = './checkpoints'
tensorboard_dir = './runs'
checkpoints_suffix = ""

img_size = 80
max_epochs = 70
batch_size = 16
num_workers = 0
random_seed = 42
lr = 4e-5
torch.backends.cudnn.benchmark = True
pin_memory = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

df = pd.DataFrame(columns = ['model', 'data','version', 'bacc', 'acc', 'precision', 'recall', 'f1', 'rocauc', 'specificity'])
for filename in filenames:
    print(f"Start {filename}")
    file_list, labels, nodata = get_list_files_and_labels(basepath, excel_path, filename)
    mask_list, _, _ = get_list_files_and_labels(basepath, excel_path, mask_filename)

    train_tmp_indices = np.load('train_tmp_indices.npy', allow_pickle=True)
    train_indices = np.load('train_indices.npy', allow_pickle=True)
    val_indices = np.load('val_indices.npy', allow_pickle=True)
    test_indices = np.load('test_indices.npy', allow_pickle=True)
    file_list = np.array(file_list)
    labels = np.array(labels)

    ###########################################################################################################
    for model_name in all_model:
        for i, (imp_index, train_index, val_index, test_index) in enumerate(zip(train_tmp_indices, train_indices, val_indices, test_indices)):


            version = i
            imp_index = np.array(imp_index).astype(int)
            train_index = np.array(train_index).astype(int)
            val_index = np.array(val_index).astype(int)
            test_index = np.array(test_index).astype(int)
            
            data_name = filename.split('/')[-1]
            checkpoint_name = f"{data_name}-{model_name}-{version}{checkpoints_suffix}"
            tensorboard_suffix = f"{model_name}-{version}"
            print(f"RUN: {checkpoint_name}")

            X_train, X_val, y_train, y_val = file_list[imp_index], file_list[val_index], labels[imp_index], labels[val_index]

            img = sitk.ReadImage(X_train[0])
            img = sitk.GetArrayFromImage(img)
        
            print(img.max(), img.min())
# run 12590

Device:  cuda
Start DTI/wdtifit_FA.nii
RUN: wdtifit_FA.nii-vgg11-0
1.2212887 0.0
RUN: wdtifit_FA.nii-vgg11-1
1.2212887 0.0
RUN: wdtifit_FA.nii-vgg11-2
1.2212887 0.0
RUN: wdtifit_FA.nii-vgg11-3
1.1591011 0.0
RUN: wdtifit_FA.nii-vgg11-4
1.2212887 0.0
Start DTI/wdtifit_MD.nii
RUN: wdtifit_MD.nii-vgg11-0
0.0047819577 -0.0033672685
RUN: wdtifit_MD.nii-vgg11-1
0.0047819577 -0.0033672685
RUN: wdtifit_MD.nii-vgg11-2
0.0047819577 -0.0033672685
RUN: wdtifit_MD.nii-vgg11-3
0.004727119 -0.0029651276
RUN: wdtifit_MD.nii-vgg11-4
0.0047819577 -0.0033672685
Start DTI/wfdt_paths.nii
RUN: wfdt_paths.nii-vgg11-0
1175528.0 0.0
RUN: wfdt_paths.nii-vgg11-1
1175528.0 0.0
RUN: wfdt_paths.nii-vgg11-2
1175528.0 0.0
RUN: wfdt_paths.nii-vgg11-3
1757722.4 0.0
RUN: wfdt_paths.nii-vgg11-4
1175528.0 0.0
Start DTI/wnfdt_paths.nii
RUN: wnfdt_paths.nii-vgg11-0
5.7161307 0.0
RUN: wnfdt_paths.nii-vgg11-1
5.7161307 0.0
RUN: wnfdt_paths.nii-vgg11-2
5.7161307 0.0
RUN: wnfdt_paths.nii-vgg11-3
7.664164 0.0
RUN: wnfdt_paths.nii