In [30]:
import sys

import monai
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim

import pandas as pd
import numpy as np

from tqdm import tqdm

from MINIT.minit import MINiT

import os

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [11]:
adni_merge_1 = pd.read_csv('our_subj_adnimerge.csv').sort_values(['PTID', 'EXAMDATE']).reset_index(drop=True)
adni_merge_2 = pd.read_csv('ADNIMERGE.csv').sort_values(['PTID', 'EXAMDATE']).reset_index(drop=True)
adni_merge_1 = adni_merge_1[['PTID', 'VISCODE', 'DX_bl']]
adni_merge_2 = adni_merge_2[['PTID', 'VISCODE', 'DX_bl']]

  adni_merge_2 = pd.read_csv('ADNIMERGE.csv').sort_values(['PTID', 'EXAMDATE']).reset_index(drop=True)


In [13]:
# какое количество классов во всем ADNI
from collections import Counter
Counter(pd.read_csv('ADNIMERGE.csv')['DX_bl'].values)

  Counter(pd.read_csv('ADNIMERGE.csv')['DX_bl'].values)


Counter({'CN': 4512,
         'AD': 1667,
         'LMCI': 5037,
         'SMC': 1037,
         'EMCI': 2740,
         nan: 10})

In [14]:
# какое количество классов в нашем сэмпле ADNI
Counter(pd.read_csv('our_subj_adnimerge.csv')['DX_bl'].values)

Counter({'CN': 775, 'AD': 409, 'LMCI': 1430})

# Делаем dataset

In [26]:
BASE_PATH = '/home/druzhininapo/nfs/caps/subjects/'

In [32]:
def make_dataset(adni_merge, resize=64):
    resizer_64 = monai.transforms.Resize((resize, resize, resize))
    label2num = {
        'AD': 1,
        'CN': 0,
    }
    dataset = {}
    for subj in tqdm(np.unique(adni_merge_1[['PTID']].values)):
        subj_path = subj.replace('_', '')
        for ses in os.listdir(os.path.join(BASE_PATH, f"sub-ADNI{subj_path}")):
            subj_samples = adni_merge[adni_merge['PTID'] == subj]
            if subj_samples[subj_samples['VISCODE'] == ses.split('-')[1].lower()].shape[0] == 0:
                continue
            label_text = subj_samples[subj_samples['VISCODE'] == ses.split('-')[1].lower()]['DX_bl'].values[0]
            if label_text not in label2num:
                continue
            
            path_to_tensor = os.path.join(BASE_PATH, f"sub-ADNI{subj_path}", ses, 'deeplearning_prepare_data/image_based/t1_linear/')
            image_tensor = torch.load(os.path.join(path_to_tensor, os.listdir(path_to_tensor)[0]))
            
            if resize:
                image_tensor = resizer_64(image_tensor).get_array()

            if subj in dataset:
                dataset[subj].append((image_tensor, label2num[label_text]))
            else:
                dataset[subj] = [(image_tensor, label2num[label_text])]
    
    return dataset

In [33]:
dataset_dict = make_dataset(adni_merge_1)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 771/771 [03:12<00:00,  4.00it/s]


# Train

In [35]:
class MRIDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        'Initialization'
        self.labels = y
        self.X = X # Maps index to id

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.X)

    def __getitem__(self, index):
        'Generates one sample of data'
        X = torch.tensor(self.X[index])

        return {'tensor': X, 'label': self.labels[index]}

In [54]:
from sklearn.model_selection import KFold
from sklearn.metrics import precision_score, recall_score, roc_auc_score

def get_dataloder(dataset_dict, batch_train_size=4, n_splits=5, verbose=True):

    def make_tensors(set_subj):
        X = []
        y = []
        for subj in set_subj:
            for tensor, label in dataset_dict[subj]:
                X.append(tensor)
                y.append(label)

        return X, y
    
    all_keys_subj = np.sort(list(dataset_dict.keys()))
    kfold = KFold(n_splits=N_SPLITS, random_state=29, shuffle=True)
    for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_subj)):
        
        if verbose:
            print(f"Number fold: {i+1}")
    
        X_train, y_train = make_tensors(all_keys_subj[train_idx])
        X_test, y_test = make_tensors(all_keys_subj[test_idx])

        mri_dataset_train = MRIDataset(X_train, y_train)
        mri_dataset_test = MRIDataset(X_test, y_test)

        mri_dataloader_train = torch.utils.data.DataLoader(mri_dataset_train, batch_size=batch_train_size, shuffle=True)
        mri_dataloader_test = torch.utils.data.DataLoader(mri_dataset_test, batch_size=1, shuffle=False)
    
        yield mri_dataloader_train, mri_dataloader_test

In [57]:
def train_vit(dataset_dict, epoch_num=3, thr=0.5, use_checkpoint=False):
    precision_folds = []
    recall_folds = []
    auc_folds = []
    
    for train_dataloader, test_dataloader in get_dataloder(dataset_dict):

        net = MINiT(
            block_size = 16,
            image_size = 64,
            patch_size = 8,
            num_classes = 1,
            channels = 1,
            dim = 256,
            depth = 6,
            heads = 4,
            mlp_dim = 309
        ).to(device)

        if use_checkpoint:
            checkpoint = torch.load("MINIT/minit.pt")
            state = checkpoint['model_state_dict']

        optimizer = optim.Adam(net.parameters(), lr=5e-4)
        loss = nn.CrossEntropyLoss()

        precision_folds = []
        recall_folds = []
        auc_folds = []

        for epoch in range(epoch_num):
            print(f"Train step")

            net.train()
            epoch_train_loss = 0.0
            for i, batch in enumerate(mri_dataloader_train):
                brain_img = batch['tensor'].to(device)
                labels = batch['label'].to(device)

                optimizer.zero_grad()
                outputs = net(brain_img)
                loss_calc = loss(torch.squeeze(outputs), labels.float())

                loss_calc.backward()
                optimizer.step()

                epoch_train_loss += loss_calc.item()
                if i % 24 == 0:
                    print(f"Epoch: {epoch+1}, step: {i+1} / {len(mri_dataloader_train)}, loss: {epoch_train_loss / (i+1)}")

            print(f"Test step")

            net.eval()
            labels_pred = []
            logits_pred = []
            for i, batch in enumerate(mri_dataloader_test):
                brain_img = batch['tensor'].to(device)
                labels = batch['label'].to(device)

                with torch.no_grad():
                    output_prob = torch.sigmoid(net(brain_img)).item()
                    logits_pred.append(output_prob)

                    if output_prob > thr:
                        labels_pred.append(1)
                    else:
                        labels_pred.append(0)

                        
            precision_score_fold = precision_score(y_test, labels_pred)
            recall_score_fold = recall_score(y_test, labels_pred)
            auc_fold = roc_auc_score(y_test, logits_pred)
            
            precision_folds.append(precision_score_fold)
            recall_folds.append(recall_score_fold)
            auc_folds.append(auc_fold)
            
            print(f"Precision: {precision_score_fold}")
            print(f"Recall: {recall_score_fold}")  
            print(f"AUC: {auc_fold}")  
            
    return precision_folds, recall_folds, auc_folds

In [None]:
precision_folds, recall_folds, auc_folds = train_vit(dataset_dict)

Number fold: 1
Train step
Epoch: 1, step: 1 / 232, loss: 2.783334732055664
Epoch: 1, step: 25 / 232, loss: 2.2007543087005614
Epoch: 1, step: 49 / 232, loss: 2.1319573217508743
Epoch: 1, step: 73 / 232, loss: 2.030388921091001
Epoch: 1, step: 97 / 232, loss: 1.9958675497585965
Epoch: 1, step: 121 / 232, loss: 1.959923521546293
Epoch: 1, step: 145 / 232, loss: 1.9922145991489806
Epoch: 1, step: 169 / 232, loss: 2.0103664786152584
Epoch: 1, step: 193 / 232, loss: 1.9885836883684513
Epoch: 1, step: 217 / 232, loss: 2.0080937120359614
Test step
Precision: 0.2265625
Recall: 0.9666666666666667
AUC: 0.5659090909090908
Train step
Epoch: 2, step: 1 / 232, loss: 1.5438249111175537
Epoch: 2, step: 25 / 232, loss: 2.3925910663604735
Epoch: 2, step: 49 / 232, loss: 1.9885262992917274
Epoch: 2, step: 73 / 232, loss: 1.9302974800540977
Epoch: 2, step: 97 / 232, loss: 1.8899679380593841
Epoch: 2, step: 121 / 232, loss: 1.8538431790003107
Epoch: 2, step: 145 / 232, loss: 1.8560467157898277
Epoch: 2, st

  _warn_prf(average, modifier, msg_start, len(result))


Precision: 0.0
Recall: 0.0
AUC: 0.6626262626262626
Train step
Epoch: 3, step: 1 / 232, loss: 1.3412561416625977
Epoch: 3, step: 25 / 232, loss: 2.049325551986694
Epoch: 3, step: 49 / 232, loss: 1.9530858415730146
Epoch: 3, step: 73 / 232, loss: 1.9508703334690773
Epoch: 3, step: 97 / 232, loss: 1.9271789998123325
Epoch: 3, step: 121 / 232, loss: 1.890053813615121
Epoch: 3, step: 145 / 232, loss: 1.8553521908562758
Epoch: 3, step: 169 / 232, loss: 1.7850371360911068
Epoch: 3, step: 193 / 232, loss: 1.7678617848749296
Epoch: 3, step: 217 / 232, loss: 1.775182299951117
Test step
Precision: 1.0
Recall: 0.03333333333333333
AUC: 0.6312289562289563
Number fold: 2
Train step
Epoch: 1, step: 1 / 232, loss: 4.139474868774414
Epoch: 1, step: 25 / 232, loss: 1.527468945980072
Epoch: 1, step: 49 / 232, loss: 1.9085381286484855
Epoch: 1, step: 73 / 232, loss: 1.8276267043531758
Epoch: 1, step: 97 / 232, loss: 1.8923038047613556
Epoch: 1, step: 121 / 232, loss: 1.8316061043542278
Epoch: 1, step: 145 

  _warn_prf(average, modifier, msg_start, len(result))


Precision: 0.0
Recall: 0.0
AUC: 0.5017676767676768
Train step
Epoch: 2, step: 1 / 232, loss: 1.3590525388717651
Epoch: 2, step: 25 / 232, loss: 1.3724038290977478
Epoch: 2, step: 49 / 232, loss: 1.6351589426702382
Epoch: 2, step: 73 / 232, loss: 1.8877643232476222
Epoch: 2, step: 97 / 232, loss: 1.8773271994492442
Epoch: 2, step: 121 / 232, loss: 1.8045832630523966
Epoch: 2, step: 145 / 232, loss: 1.8390445423537287
Epoch: 2, step: 169 / 232, loss: 1.9015236011976322
Epoch: 2, step: 193 / 232, loss: 1.930279369953383
Epoch: 2, step: 217 / 232, loss: 1.9043542378783775
Test step
Precision: 0.3364485981308411
Recall: 0.6
AUC: 0.6622053872053872
Train step
Epoch: 3, step: 1 / 232, loss: 2.7793192863464355
Epoch: 3, step: 25 / 232, loss: 1.9847800016403199
Epoch: 3, step: 49 / 232, loss: 1.6903032088766292
Epoch: 3, step: 73 / 232, loss: 1.9176592173641676
Epoch: 3, step: 97 / 232, loss: 1.9401632793170889
Epoch: 3, step: 121 / 232, loss: 1.9340099480526507
Epoch: 3, step: 145 / 232, loss:

  _warn_prf(average, modifier, msg_start, len(result))


Precision: 0.0
Recall: 0.0
AUC: 0.6099326599326599
Number fold: 4
Train step
Epoch: 1, step: 1 / 232, loss: 4.156723499298096
Epoch: 1, step: 25 / 232, loss: 2.415364990234375
Epoch: 1, step: 49 / 232, loss: 2.4104506361241245
Epoch: 1, step: 73 / 232, loss: 2.373163540069371
Epoch: 1, step: 97 / 232, loss: 2.192198198480704
Epoch: 1, step: 121 / 232, loss: 2.1660333797951377
Epoch: 1, step: 145 / 232, loss: 2.1821791430999493
Epoch: 1, step: 169 / 232, loss: 2.155454390500424
Epoch: 1, step: 193 / 232, loss: 2.1221232695283048
Epoch: 1, step: 217 / 232, loss: 2.0579833998108787
Test step
Precision: 0.3088235294117647
Recall: 0.35
AUC: 0.6324915824915825
Train step
Epoch: 2, step: 1 / 232, loss: 1.6811665296554565
Epoch: 2, step: 25 / 232, loss: 2.390007758140564
Epoch: 2, step: 49 / 232, loss: 2.0251454625810896
Epoch: 2, step: 73 / 232, loss: 1.9436180922266555
Epoch: 2, step: 97 / 232, loss: 1.9692222141113478
Epoch: 2, step: 121 / 232, loss: 1.8530534350674999
Epoch: 2, step: 145 /

# OLD

In [243]:
# non-pretrain + finetune
from sklearn.metrics import precision_score, recall_score

for epoch in range(3):
    net.train()
    epoch_train_loss = 0.0
    for i, batch in enumerate(mri_dataloader_train):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = net(brain_img)
        loss_calc = loss(torch.squeeze(outputs), labels.float())

        loss_calc.backward()
        optimizer.step()
        
        epoch_train_loss += loss_calc.item()
        if i % 8 == 0:
            print(f"Epoch: {epoch}, step: {i+1} / {len(mri_dataloader_train)}, loss: {epoch_train_loss / (i+1)}")

    net.eval()
    labels_pred = []
    for i, batch in enumerate(mri_dataloader_test):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)

        with torch.no_grad():
            output_prob = torch.sigmoid(net(brain_img)).item()
            if output_prob > 0.5:
                labels_pred.append(1)
            else:
                labels_pred.append(0)

    print(precision_score(y_test, labels_pred))
    print(recall_score(y_test, labels_pred))

Epoch: 0, step: 1 / 281, loss: 0.0
Epoch: 0, step: 9 / 281, loss: 2.394197834862603
Epoch: 0, step: 17 / 281, loss: 2.000979598830728
Epoch: 0, step: 25 / 281, loss: 2.072249717712402
Epoch: 0, step: 33 / 281, loss: 1.838512633786057
Epoch: 0, step: 41 / 281, loss: 1.8390611381065556
Epoch: 0, step: 49 / 281, loss: 1.676484385315253
Epoch: 0, step: 57 / 281, loss: 1.7525349311661302
Epoch: 0, step: 65 / 281, loss: 1.7216973625696623
Epoch: 0, step: 73 / 281, loss: 1.7372005336905179
Epoch: 0, step: 81 / 281, loss: 1.7519290660634452
Epoch: 0, step: 89 / 281, loss: 1.7796537909614907
Epoch: 0, step: 97 / 281, loss: 1.7823393240417402
Epoch: 0, step: 105 / 281, loss: 1.7946016629536947
Epoch: 0, step: 113 / 281, loss: 1.804586623622253
Epoch: 0, step: 121 / 281, loss: 1.7872062537295759
Epoch: 0, step: 129 / 281, loss: 1.7722801626190658
Epoch: 0, step: 137 / 281, loss: 1.7688390703967019
Epoch: 0, step: 145 / 281, loss: 1.7726228578337309
Epoch: 0, step: 153 / 281, loss: 1.7720017296816

  _warn_prf(average, modifier, msg_start, len(result))


0.0
0.0
Epoch: 2, step: 1 / 281, loss: 0.9903404712677002
Epoch: 2, step: 9 / 281, loss: 1.2091296911239624
Epoch: 2, step: 17 / 281, loss: 1.2393223608241362
Epoch: 2, step: 25 / 281, loss: 1.4782209968566895
Epoch: 2, step: 33 / 281, loss: 1.306702633247231
Epoch: 2, step: 41 / 281, loss: 1.8242759279361584
Epoch: 2, step: 49 / 281, loss: 1.6870434719080827
Epoch: 2, step: 57 / 281, loss: 1.7964918945442165
Epoch: 2, step: 65 / 281, loss: 1.912631487617126
Epoch: 2, step: 73 / 281, loss: 1.8508511142779702
Epoch: 2, step: 81 / 281, loss: 1.8840244005859634
Epoch: 2, step: 89 / 281, loss: 1.8810518161299523
Epoch: 2, step: 97 / 281, loss: 1.924387781736777
Epoch: 2, step: 105 / 281, loss: 1.8459877461194991
Epoch: 2, step: 113 / 281, loss: 1.8139593773974783
Epoch: 2, step: 121 / 281, loss: 1.7698975948016507
Epoch: 2, step: 129 / 281, loss: 1.7190615355737449
Epoch: 2, step: 137 / 281, loss: 1.7104708726932532
Epoch: 2, step: 145 / 281, loss: 1.6974194407976906
Epoch: 2, step: 153 / 

In [168]:
# non-pretrain + finetune
from sklearn.metrics import precision_score, recall_score

for epoch in range(3):
    net.train()
    epoch_train_loss = 0.0
    for i, batch in enumerate(mri_dataloader_train):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = net(brain_img)
        loss_calc = loss(torch.squeeze(outputs), labels.float())

        loss_calc.backward()
        optimizer.step()
        
        epoch_train_loss += loss_calc.item()
        if i % 8 == 0:
            print(f"Epoch: {epoch}, step: {i+1} / {len(mri_dataloader_train)}, loss: {epoch_train_loss / (i+1)}")

    net.eval()
    labels_pred = []
    for i, batch in enumerate(mri_dataloader_test):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)

        with torch.no_grad():
            output_prob = torch.sigmoid(net(brain_img)).item()
            if output_prob > 0.5:
                labels_pred.append(1)
            else:
                labels_pred.append(0)

    print(precision_score(y_test, labels_pred))
    print(recall_score(y_test, labels_pred))

Epoch: 0, step: 1 / 281, loss: 2.7592720985412598
Epoch: 0, step: 9 / 281, loss: 2.0186812612745495
Epoch: 0, step: 17 / 281, loss: 2.0448199720943676
Epoch: 0, step: 25 / 281, loss: 1.946508674621582
Epoch: 0, step: 33 / 281, loss: 1.8931593822710442
Epoch: 0, step: 41 / 281, loss: 1.893749219615285
Epoch: 0, step: 49 / 281, loss: 1.892821985848096
Epoch: 0, step: 57 / 281, loss: 1.9896142169048912
Epoch: 0, step: 65 / 281, loss: 1.8508722672095665
Epoch: 0, step: 73 / 281, loss: 1.9114093437586746
Epoch: 0, step: 81 / 281, loss: 1.9593559150342588
Epoch: 0, step: 89 / 281, loss: 1.9997915217045987
Epoch: 0, step: 97 / 281, loss: 1.9489044049351485
Epoch: 0, step: 105 / 281, loss: 1.9180229425430297
Epoch: 0, step: 113 / 281, loss: 1.9272019588841802
Epoch: 0, step: 121 / 281, loss: 1.9498930420757326
Epoch: 0, step: 129 / 281, loss: 1.9260694749595584
Epoch: 0, step: 137 / 281, loss: 1.9041858955021322
Epoch: 0, step: 145 / 281, loss: 1.9132682200135855
Epoch: 0, step: 153 / 281, los

  _warn_prf(average, modifier, msg_start, len(result))


0.0
0.0
Epoch: 1, step: 1 / 281, loss: 2.830392599105835
Epoch: 1, step: 9 / 281, loss: 1.83727731804053
Epoch: 1, step: 17 / 281, loss: 1.517961924566942
Epoch: 1, step: 25 / 281, loss: 1.7089318430423737
Epoch: 1, step: 33 / 281, loss: 1.7376287480195363
Epoch: 1, step: 41 / 281, loss: 1.6466180776677481
Epoch: 1, step: 49 / 281, loss: 1.5208125582763128
Epoch: 1, step: 57 / 281, loss: 1.5473935023734444
Epoch: 1, step: 65 / 281, loss: 1.5701666919084696
Epoch: 1, step: 73 / 281, loss: 1.5076581325433025
Epoch: 1, step: 81 / 281, loss: 1.4787089151364785
Epoch: 1, step: 89 / 281, loss: 1.4306895120090313
Epoch: 1, step: 97 / 281, loss: 1.4951415943730737
Epoch: 1, step: 105 / 281, loss: 1.4685016708714622
Epoch: 1, step: 113 / 281, loss: 1.4720258272327154
Epoch: 1, step: 121 / 281, loss: 1.4305070916975826
Epoch: 1, step: 129 / 281, loss: 1.4437457232974296
Epoch: 1, step: 137 / 281, loss: 1.465366628483264
Epoch: 1, step: 145 / 281, loss: 1.5115189178236599
Epoch: 1, step: 153 / 28

In [159]:
# pretrain + finetune
from sklearn.metrics import precision_score, recall_score

for epoch in range(2):
    net.train()
    epoch_train_loss = 0.0
    for i, batch in enumerate(mri_dataloader_train):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = net(brain_img)
        loss_calc = loss(torch.squeeze(outputs), labels.float())

        loss_calc.backward()
        optimizer.step()
        
        epoch_train_loss += loss_calc.item()
        if i % 8 == 0:
            print(f"Epoch: {epoch}, step: {i+1} / {len(mri_dataloader_train)}, loss: {epoch_train_loss / (i+1)}")

    net.eval()
    labels_pred = []
    for i, batch in enumerate(mri_dataloader_test):
        brain_img = batch['tensor'].to(device)
        labels = batch['label'].to(device)

        with torch.no_grad():
            output_prob = torch.sigmoid(net(brain_img)).item()
            if output_prob > 0.5:
                labels_pred.append(1)
            else:
                labels_pred.append(0)

    print(precision_score(y_test, labels_pred))
    print(recall_score(y_test, labels_pred))

Epoch: 0, step: 1 / 281, loss: 1.7101575136184692
Epoch: 0, step: 5 / 281, loss: 1.4812070608139039
Epoch: 0, step: 9 / 281, loss: 1.3079716364542644
Epoch: 0, step: 13 / 281, loss: 1.823077495281513
Epoch: 0, step: 17 / 281, loss: 1.6630467527052935
Epoch: 0, step: 21 / 281, loss: 1.8416676975431896
Epoch: 0, step: 25 / 281, loss: 1.9523288106918335
Epoch: 0, step: 29 / 281, loss: 2.014886194262011
Epoch: 0, step: 33 / 281, loss: 1.8870762586593628
Epoch: 0, step: 37 / 281, loss: 1.8546467729516931
Epoch: 0, step: 41 / 281, loss: 1.9063204148920572
Epoch: 0, step: 45 / 281, loss: 1.9453119436899822
Epoch: 0, step: 49 / 281, loss: 2.035807441691963
Epoch: 0, step: 53 / 281, loss: 1.9134045974263605
Epoch: 0, step: 57 / 281, loss: 1.9582718485280086
Epoch: 0, step: 61 / 281, loss: 1.9239012882357738
Epoch: 0, step: 65 / 281, loss: 1.876353223507221
Epoch: 0, step: 69 / 281, loss: 1.8896263388619907
Epoch: 0, step: 73 / 281, loss: 1.8255659619422808
Epoch: 0, step: 77 / 281, loss: 1.8008

In [94]:
torch.cuda.is_available()

True