In [1]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import SimpleITK as sitk
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import monai.transforms as monai_transforms

from dataset import MRIDataset, get_loader
from models import C3D, generate_model, ResNet
from train import epoch_iter, add_metrics, save_checkpoint, get_metrics

# Preprocessing

In [2]:
"/data1/TBM/data_for_AI/data/spatial_normalized_pt_1/"
basepath = "/data1/TBM/data_for_AI/data/realigned_pt/"
csvpath = '/data1/TBM/data_for_AI/subjects_info/final_TBM_subjects_info.csv'

form = 'process'

# use for raw file
modality = "T2s"
use_file = "R2S.nii"

# use for process file
prefix = 'r'
suffix = '_fdt_paths.nii'

random_state = 65489132
torch.manual_seed(random_state)
random.seed(random_state)
np.random.seed(random_state)

In [3]:
df_data = pd.read_csv(csvpath)

filenames = []
labels_data = []

mean_data = []
var_data = []

def get_path_and_label(name, label, form = 'raw', prefix = 'wr', suffix = '_fdt_paths.nii'):
    
    if form == 'raw':
        if label.lower() == 'mci':
            category = 1
            fdt_paths_path = os.path.join(basepath, 'MCI',name, modality, use_file)
        elif label.lower() == 'normal':
            category = 0
            fdt_paths_path = os.path.join(basepath, 'Normal',name, modality, use_file)
        elif label.lower() == 'mmd':
            category = 2
            fdt_paths_path = os.path.join(basepath, 'AD',name, modality, use_file)
        else:
            raise ValueError(f"No label name {label}")

    else:
        if label.lower() == 'mci':
            category = 1
        elif label.lower() == 'normal':
            category = 0

        elif label.lower() == 'mmd':
            category = 2
        else:
            raise ValueError(f"No label name {label}")
        fdt_paths_path = os.path.join(basepath, prefix + name + suffix)
    return fdt_paths_path, category


first_shape_flag = False
first_shape = (0,0,0)

for name, label in zip(df_data.label_id, df_data['label']):

    fdt_paths_path, category = get_path_and_label(name, label, form = form, prefix = prefix, suffix = suffix)
        

    try:
        img = sitk.ReadImage(fdt_paths_path)
        img_array = sitk.GetArrayFromImage( img)

        if not first_shape_flag:
            first_shape = img_array.shape
            first_shape_flag = True

        elif img_array.shape[0] != first_shape[0] or img_array.shape[1] != first_shape[1] or img_array.shape[2] != first_shape[2]:
            print(img_array.shape)
            continue
    except:
        print(name)
        continue

    labels_data.append(category)
    filenames.append(fdt_paths_path)
    
    mean_img_array = img_array[img_array>0].mean()
    var_img_array = img_array[img_array>0].var()
    
    mean_data.append(mean_img_array)
    var_data.append(var_img_array)
    
assert len(labels_data) == len(filenames)

mean_data = np.mean(mean_data)
std_data = np.sqrt(np.mean(var_data))

batch_size = 8
img_size = 64
test_size = 0.1
val_size = 0.2
val_size = val_size/(1-test_size)


x_train, x_test, y_train,  y_test = train_test_split(filenames, labels_data,  test_size= test_size, random_state=random_state)
x_train, x_val, y_train,  y_val = train_test_split(x_train, y_train,  test_size= val_size, random_state=random_state)

def get_transform(mean, std, mode= 'train'):
    transforms_list = []
    
    def standard(img):
        return (img - mean)/std
    if mode == 'train':
        # transforms_list.append(monai_transforms.RandRotate90(prob = 0.3))
        transforms_list.append(
            monai_transforms.RandGaussianSmooth(
            sigma_x=(0.1, 0.5), 
            sigma_y=(0.1, 0.5), 
            sigma_z=(0.1, 0.5), prob=0.3))
        transforms_list.append(monai_transforms.RandAffine(prob = 0.3, translate_range  =[(-2,2), (-2,2), (-2,2)], padding_mode  = 'zeros'))


    transforms_list.append(torch.Tensor)
    transforms_list.append(standard)
    return monai_transforms.Compose(transforms_list)

train_transform = get_transform(mean_data, std_data, mode = 'train')
val_transform = get_transform(mean_data, std_data, mode = 'val')

train_loader = get_loader(x_train, y_train, train_transform, mode = 'train', batch_size =batch_size, img_size = img_size)
val_loader = get_loader(x_val, y_val,val_transform, mode = 'val', batch_size =batch_size, img_size = img_size)
test_loader = get_loader(x_test, y_test,val_transform, mode = 'test', batch_size =batch_size, img_size = img_size)


normal_0766
normal_0768
normal_0769
mci_0116
mci_0120


# Train

In [13]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = generate_model(10, [64, 128, 256, 512], n_input_channels = 1,conv1_t_size = 5, conv1_t_stride=1, n_classes = 2)
# model = C3D(num_classes = 2)
model = model.to(device)

loss_function = nn.CrossEntropyLoss()


optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
# optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)


In [15]:

writer = SummaryWriter()

current_epoch = 0
epochs = 30
best_f1 = 0
best_auc = 0

savepath = "resnet10_realigned"

for i in range(current_epoch , epochs+1):

    train_loss, preds_prob, labels = epoch_iter(train_loader, model, loss_function, optimizer, device)
    print(f"train_loss: {train_loss};")
    writer.add_scalar('train_loss', train_loss, i)
    _, _ = add_metrics(writer, preds_prob, labels, train_loss, i, mode = 'train')

    val_loss, preds_prob, labels = epoch_iter(val_loader, model, loss_function, optimizer, device, mode = 'val')
    print(f"val_loss: {val_loss};")
    writer.add_scalar('val_loss', val_loss, i)
    f1, rocauc = add_metrics(writer, preds_prob, labels, val_loss, i, mode = 'val')
    if f1>best_f1:
        best_f1 = f1
        save_checkpoint(savepath, model, optimizer, 'f1', i)
    if rocauc>best_auc:
        best_auc = rocauc
        save_checkpoint(savepath, model, optimizer, 'rocauc', i)
    

**************train*************


100%|██████████| 85/85 [00:18<00:00,  4.67it/s]


train_loss: 0.6157857109518612;
loss_train: 0.616;
acc_train: 0.638; bacc_train: 0.638; precision_train: 0.638; recall_train: 0.638; f1_train: 0.638; rocauc_train: 0.707;
**************val*************


100%|██████████| 25/25 [00:02<00:00,  8.79it/s]


val_loss: 0.3561427307128906;
loss_val: 0.356;
acc_val: 0.871; bacc_val: 0.552; precision_val: 0.739; recall_val: 0.552; f1_val: 0.562; rocauc_val: 0.739;
**************train*************


100%|██████████| 85/85 [00:18<00:00,  4.60it/s]


train_loss: 0.49788225959329047;
loss_train: 0.498;
acc_train: 0.752; bacc_train: 0.752; precision_train: 0.752; recall_train: 0.752; f1_train: 0.752; rocauc_train: 0.835;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.57it/s]


val_loss: 1.2436859893798828;
loss_val: 1.244;
acc_val: 0.428; bacc_val: 0.637; precision_val: 0.574; recall_val: 0.637; f1_val: 0.409; rocauc_val: 0.75;
**************train*************


100%|██████████| 85/85 [00:20<00:00,  4.24it/s]


train_loss: 0.37941665649414064;
loss_train: 0.379;
acc_train: 0.843; bacc_train: 0.843; precision_train: 0.844; recall_train: 0.843; f1_train: 0.843; rocauc_train: 0.911;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.45it/s]


val_loss: 1.63440185546875;
loss_val: 1.634;
acc_val: 0.448; bacc_val: 0.6; precision_val: 0.55; recall_val: 0.6; f1_val: 0.417; rocauc_val: 0.735;
**************train*************


100%|██████████| 85/85 [00:19<00:00,  4.44it/s]


train_loss: 0.3685286802404067;
loss_train: 0.369;
acc_train: 0.833; bacc_train: 0.832; precision_train: 0.833; recall_train: 0.832; f1_train: 0.832; rocauc_train: 0.914;
**************val*************


100%|██████████| 25/25 [00:02<00:00,  9.59it/s]


val_loss: 4.562681274414063;
loss_val: 4.563;
acc_val: 0.139; bacc_val: 0.503; precision_val: 0.567; recall_val: 0.503; f1_val: 0.125; rocauc_val: 0.716;
**************train*************


100%|██████████| 85/85 [00:15<00:00,  5.43it/s]


train_loss: 0.3752615535960478;
loss_train: 0.375;
acc_train: 0.839; bacc_train: 0.838; precision_train: 0.84; recall_train: 0.838; f1_train: 0.838; rocauc_train: 0.914;
**************val*************


100%|██████████| 25/25 [00:02<00:00,  9.27it/s]


val_loss: 0.44058448791503907;
loss_val: 0.441;
acc_val: 0.825; bacc_val: 0.541; precision_val: 0.563; recall_val: 0.541; f1_val: 0.546; rocauc_val: 0.694;
**************train*************


100%|██████████| 85/85 [00:21<00:00,  3.95it/s]


train_loss: 0.27945074193617875;
loss_train: 0.279;
acc_train: 0.883; bacc_train: 0.882; precision_train: 0.883; recall_train: 0.882; f1_train: 0.882; rocauc_train: 0.952;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  6.32it/s]
Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


val_loss: 0.7614437866210938;
loss_val: 0.761;
acc_val: 0.866; bacc_val: 0.5; precision_val: 0.433; recall_val: 0.5; f1_val: 0.464; rocauc_val: 0.681;
**************train*************


100%|██████████| 85/85 [00:21<00:00,  3.95it/s]


train_loss: 0.2775578442741843;
loss_train: 0.278;
acc_train: 0.891; bacc_train: 0.888; precision_train: 0.89; recall_train: 0.888; f1_train: 0.889; rocauc_train: 0.952;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.75it/s]


val_loss: 0.5219845581054687;
loss_val: 0.522;
acc_val: 0.866; bacc_val: 0.533; precision_val: 0.687; recall_val: 0.533; f1_val: 0.53; rocauc_val: 0.695;
**************train*************


100%|██████████| 85/85 [00:16<00:00,  5.29it/s]


train_loss: 0.2962338167078355;
loss_train: 0.296;
acc_train: 0.879; bacc_train: 0.879; precision_train: 0.879; recall_train: 0.879; f1_train: 0.879; rocauc_train: 0.946;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.79it/s]


val_loss: 1.7348272705078125;
loss_val: 1.735;
acc_val: 0.433; bacc_val: 0.64; precision_val: 0.575; recall_val: 0.64; f1_val: 0.413; rocauc_val: 0.624;
**************train*************


100%|██████████| 85/85 [00:16<00:00,  5.09it/s]


train_loss: 0.2895456875071806;
loss_train: 0.29;
acc_train: 0.879; bacc_train: 0.879; precision_train: 0.879; recall_train: 0.879; f1_train: 0.879; rocauc_train: 0.949;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.52it/s]


val_loss: 0.7965220642089844;
loss_val: 0.797;
acc_val: 0.686; bacc_val: 0.688; precision_val: 0.594; recall_val: 0.688; f1_val: 0.581; rocauc_val: 0.703;
**************train*************


100%|██████████| 85/85 [00:17<00:00,  4.85it/s]


train_loss: 0.27412080203785616;
loss_train: 0.274;
acc_train: 0.891; bacc_train: 0.891; precision_train: 0.891; recall_train: 0.891; f1_train: 0.891; rocauc_train: 0.953;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.73it/s]


val_loss: 1.013781280517578;
loss_val: 1.014;
acc_val: 0.66; bacc_val: 0.69; precision_val: 0.592; recall_val: 0.69; f1_val: 0.566; rocauc_val: 0.725;
**************train*************


100%|██████████| 85/85 [00:16<00:00,  5.15it/s]


train_loss: 0.22108728745404413;
loss_train: 0.221;
acc_train: 0.914; bacc_train: 0.914; precision_train: 0.915; recall_train: 0.914; f1_train: 0.914; rocauc_train: 0.97;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.80it/s]


val_loss: 0.5126157760620117;
loss_val: 0.513;
acc_val: 0.84; bacc_val: 0.583; precision_val: 0.62; recall_val: 0.583; f1_val: 0.595; rocauc_val: 0.74;
**************train*************


100%|██████████| 85/85 [00:17<00:00,  4.80it/s]


train_loss: 0.21440517201143153;
loss_train: 0.214;
acc_train: 0.919; bacc_train: 0.918; precision_train: 0.916; recall_train: 0.918; f1_train: 0.917; rocauc_train: 0.971;
**************val*************


100%|██████████| 25/25 [00:02<00:00,  9.29it/s]


val_loss: 0.6653351593017578;
loss_val: 0.665;
acc_val: 0.737; bacc_val: 0.621; precision_val: 0.574; recall_val: 0.621; f1_val: 0.579; rocauc_val: 0.701;
**************train*************


100%|██████████| 85/85 [00:17<00:00,  6.01it/s]

# Test

In [16]:
savepath = "resnet10_realigned"
checkpoints_dir = f"checkpoints/{savepath}"
checkpoint_type = 'f1'
model_path = os.path.join(checkpoints_dir, f'model-{checkpoint_type}.ckpt')

model_state_dict = torch.load(model_path, map_location=torch.device(device))
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [17]:

mode = 'test'
val_loss, preds_prob, labels = epoch_iter(test_loader, model, loss_function, optimizer, device, mode = 'val')
print(f"val_loss: {val_loss};")
acc, bacc, precision, recall, f1,rocauc = get_metrics(preds_prob, labels)
print(
        f"acc_{mode}: {acc}; bacc_{mode}: {bacc};" +
        f" precision_{mode}: {precision}; recall_{mode}: {recall}; f1_{mode}: {f1};" +
        f" rocauc_{mode}: {rocauc};"
        )

mode = 'val'
val_loss, preds_prob, labels = epoch_iter(val_loader, model, loss_function, optimizer, device, mode = 'val')
print(f"val_loss: {val_loss};")
acc, bacc, precision, recall, f1,rocauc = get_metrics(preds_prob, labels)
print(
        f"acc_{mode}: {acc}; bacc_{mode}: {bacc};" +
        f" precision_{mode}: {precision}; recall_{mode}: {recall}; f1_{mode}: {f1};" +
        f" rocauc_{mode}: {rocauc};"
        )

**************val*************


100%|██████████| 13/13 [00:02<00:00,  6.22it/s]


val_loss: 0.8945473891038161;
acc_test: 0.845; bacc_test: 0.699; precision_test: 0.781; recall_test: 0.699; f1_test: 0.726; rocauc_test: 0.839;
**************val*************


100%|██████████| 25/25 [00:03<00:00,  7.54it/s]

val_loss: 0.692416000366211;
acc_val: 0.84; bacc_val: 0.615; precision_val: 0.638; recall_val: 0.615; f1_val: 0.625; rocauc_val: 0.713;





### Resnet10 top rocauc

val_loss: 0.8527595079862155;
acc_test: 0.619; bacc_test: 0.741; precision_test: 0.661; recall_test: 0.741; f1_test: 0.598; rocauc_test: 0.861;

val_loss: 1.2436859893798828;
acc_val: 0.428; bacc_val: 0.637; precision_val: 0.574; recall_val: 0.637; f1_val: 0.409; rocauc_val: 0.75;

### F1

val_loss: 0.8945473891038161;
acc_test: 0.845; bacc_test: 0.699; precision_test: 0.781; recall_test: 0.699; f1_test: 0.726; rocauc_test: 0.839;

val_loss: 0.692416000366211;
acc_val: 0.84; bacc_val: 0.615; precision_val: 0.638; recall_val: 0.615; f1_val: 0.625; rocauc_val: 0.713;

---

### simple CNN top rocauc

val_loss: 0.7132892608642578;
acc_test: 0.835; bacc_test: 0.693; precision_test: 0.755; recall_test: 0.693; f1_test: 0.715; rocauc_test: 0.824;

val_loss: 0.5662627029418945;
acc_val: 0.835; bacc_val: 0.612; precision_val: 0.629; recall_val: 0.612; f1_val: 0.62; rocauc_val: 0.754;

### F1

val_loss: 1.0085984743558443;
acc_test: 0.825; bacc_test: 0.612; precision_test: 0.774; recall_test: 0.612; f1_test: 0.634; rocauc_test: 0.821;

val_loss: 0.6464328002929688;
acc_val: 0.845; bacc_val: 0.618; precision_val: 0.648; recall_val: 0.618; f1_val: 0.63; rocauc_val: 0.71;

---


In [7]:
### simple CNN top rocauc

val_loss: 0.7132892608642578;
acc_test: 0.835; bacc_test: 0.693; precision_test: 0.755; recall_test: 0.693; f1_test: 0.715; rocauc_test: 0.824;