In [15]:
!nvidia-smi

Sat Mar 18 20:52:31 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.67       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 3090    Off  | 00000000:65:00.0 Off |                  N/A |
| 64%   56C    P8    21W / 370W |   3404MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:B3:00.0 Off |                  N/A |
| 92%   78C    P2   332W / 370W |  23982MiB / 24268MiB |     90%      Defaul

In [16]:
import warnings
warnings.filterwarnings("ignore")

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import random
import torch
import os
import pickle
import gc
from sklearn.metrics import confusion_matrix, f1_score, roc_curve, auc
from torch.optim import Adam, lr_scheduler
from scipy.ndimage import zoom
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook as tqdm
from torchsummary import summary
from transformation_3d import *
from resnet3d import *
from Custom_losses_3d import *
from evaluation import *
from load_data import *
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

print(torch.__version__)

1.10.0+cu113


In [17]:
# Uncomment to expose only GPU index 1 to this process:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(torch.cuda.device_count())

# Use the first visible GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

2
cuda:0


In [4]:
img_shape = (96, 96, 96)

class ADNI_3D_Dataset(Dataset):
    """Dataset yielding ``(image, label, demo)`` triples from a metadata table.

    Parameters
    ----------
    csv : pandas.DataFrame
        Metadata rows. ``Filename``, ``Sex`` and ``Age`` are always read; the
        column read for the label depends on ``task`` (``Group``, ``Race``,
        ``Age`` or ``Sex``).
    path : str
        Prefix that is string-concatenated with ``Filename`` to locate a volume.
    transform : int
        When equal to ``1`` one of four random augmentations is applied to the
        volume before resampling; any other value disables augmentation.
    task : str
        One of ``'disease'``, ``'race'``, ``'age'`` or ``'sex'``.

    Each item is ``image`` (tensor with a leading channel axis added),
    ``label`` (0/1 tensor) and ``demo`` (tensor ``[sex_bit, age_bit]`` with
    ``sex_bit = 0`` for ``'F'`` and ``age_bit = 0`` for ``Age <= 75``).
    """

    # Spatial size every volume is resampled to (was a hard-coded 96 per axis).
    target_shape = (96, 96, 96)

    def __init__(self, csv, path, transform, task):
        self.csv = csv
        self.path = path
        self.transform = transform
        self.task = task

    def __len__(self):
        return len(self.csv)

    def _label(self, idx):
        """Return the binary label of row ``idx`` for ``self.task``.

        Raises
        ------
        ValueError
            If ``self.task`` is not a supported task name. Previously the
            dataset printed a message and returned ``None``, which only failed
            later inside the DataLoader collate step.
        """
        if self.task == 'disease':
            # 'CN' is the negative class; every other group counts as positive.
            return 0 if self.csv['Group'].values[idx] == 'CN' else 1
        if self.task == 'race':
            return 0 if self.csv['Race'].values[idx] == 0 else 1
        if self.task == 'age':
            # A NaN age fails the comparison and therefore maps to 1.
            return 0 if self.csv['Age'].values[idx] <= 75 else 1
        if self.task == 'sex':
            return 0 if self.csv['Sex'].values[idx] == 'F' else 1
        raise ValueError(
            "Wrong task! Expected 'disease', 'race', 'age' or 'sex', got %r" % (self.task,)
        )

    def __getitem__(self, idx):
        # Resolve the label first so an invalid task fails before any file I/O.
        label = self._label(idx)

        file_path = self.path + self.csv['Filename'].values[idx]
        image = self.load_img(file_path)

        if (self.transform == 1):
            # One draw in {0, 1, 2, 3} selects the augmentation.
            rand = np.random.randint(0, 4, 1)
            image = self.transformation(image, rand)

        # Resample each axis to the target size, then rescale intensities.
        factors = tuple(t / s for t, s in zip(self.target_shape, image.shape))
        image = zoom(image, factors)
        image = self.normalize(image)

        demo = torch.tensor([0 if self.csv['Sex'].values[idx] == 'F' else 1,
                             0 if self.csv['Age'].values[idx] <= 75 else 1])

        image = torch.unsqueeze(torch.from_numpy(image), 0)
        label = torch.tensor(label)

        return image, label, demo

    def load_img(self, file_path):
        """Load the file with nibabel and return its data as a numpy array."""
        data = nib.load(file_path)
        data = np.array(data.dataobj)
        return data

    def normalize(self, arr):
        """Min-max scale ``arr``; a constant array maps to all zeros.

        The constant case used to divide by zero and return NaNs.
        """
        arr_min = np.min(arr)
        arr_max = np.max(arr)
        span = arr_max - arr_min
        if span == 0:
            # Dividing the all-zero numerator by 1 keeps the usual output dtype.
            span = 1
        return (arr - arr_min) / span

    def transformation(self, img, rand):
        """Apply the augmentation selected by ``rand`` to a copy of ``img``.

        ``0`` -> ``shear(.., 6)``, ``1`` -> ``scale(.., 0.8)``,
        ``2`` -> ``fish(.., 0.4)``, anything else -> ``rotation(.., 10)``.
        """
        img_np = np.copy(img)

        if (rand == 0):
            img_np = shear(img_np, 6)
        elif (rand == 1):
            img_np = scale(img_np, 0.8)
        elif (rand == 2):
            img_np = fish(img_np, 0.4)
        else:
            img_np = rotation(img_np, 10)

        return img_np


In [5]:
def display_result(path, loss, auc, name):
    """Append a one-line test summary to the text file at ``path``.

    Parameters
    ----------
    path : str
        File to append to (created if missing).
    loss, auc : float
        Values written with three decimals.
    name : str
        Tag written in square brackets at the start of the line.
    """
    # The context manager closes the file even if formatting or writing fails.
    with open(path, 'a') as f:
        print('[%s] test_loss: %.3f // test_auc: %.3f ' % (name, loss, auc), file=f)

In [6]:
def _train_step_batch_loss(model, inputs, labels, demo, algo, lag_mult):
    """Run one forward pass and return the class-weighted mean loss of a batch.

    ``lag_mult`` is a dict holding the FairALM multipliers under the keys
    ``'gender'`` and ``'age'``; it is updated in place and ignored by the other
    algorithms. Reads the notebook-level globals ``weights`` and ``mode``.

    Raises ``ValueError`` for an unsupported ``algo``.
    """
    n = len(labels)
    targets = labels.float()

    if algo == 'Adv':
        # The adversarial model is indexed as (gender, age, disease) outputs.
        all_outputs = model(inputs)
        outputs_gender = torch.reshape(all_outputs[0], (n,))
        outputs_age = torch.reshape(all_outputs[1], (n,))
        outputs_disease = torch.reshape(all_outputs[2], (n,))
        loss_value_gender = reciprocal_BCE_loss(demo[:, 0].float(), outputs_gender)
        loss_value_age = reciprocal_BCE_loss(demo[:, 1].float(), outputs_age)
        loss_value_disease = ERM_Loss(targets, outputs_disease)
        loss_value = loss_value_gender + loss_value_age + loss_value_disease
    else:
        outputs = torch.reshape(model(inputs), (n,))
        if algo == 'ERM':
            loss_value = ERM_Loss(targets, outputs)
        elif algo == 'DistMatch':
            loss_value, penalty_g = DistMatch_Loss(targets, outputs, demo[:, 0].float(), mode, 'gender')
            loss_value, penalty_a = DistMatch_Loss(targets, outputs, demo[:, 1].float(), mode, 'age')
            loss_value = (loss_value + penalty_g + penalty_a)
        elif algo == 'FairALM':
            loss_value, penalty_g, lag_mult['gender'] = fairALM_loss(
                targets, outputs, demo[:, 0].float(), lag_mult['gender'], 'gender', False)
            loss_value, penalty_a, lag_mult['age'] = fairALM_loss(
                targets, outputs, demo[:, 1].float(), lag_mult['age'], 'age', False)
            # Was `loss_value += (loss_value + ...)`, which counted the base
            # loss twice; now mirrors the DistMatch branch.
            loss_value = (loss_value + penalty_g + penalty_a)
        else:
            raise ValueError('Wrong algo! Got %r' % (algo,))

    # Only positive samples are re-weighted; class 0 keeps weight 1.
    weight = torch.ones_like(loss_value)
    weight[labels == 1.] = weights[1]
    return (loss_value * weight).mean()


def train_step(model, df, data_path, model_path, epochs, transform, task, algo):
    """Train ``model`` with early stopping and checkpoint the best epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train; expected to already live on the global ``device``.
    df : pandas.DataFrame
        Metadata table; rows with ``Split == 'train'`` / ``'val'`` are used.
    data_path : str
        Directory prefix handed to ``ADNI_3D_Dataset``.
    model_path : str
        Where ``model.state_dict()`` is saved whenever validation improves.
    epochs : int
        Maximum number of epochs.
    transform : int
        Augmentation flag forwarded to both the train and the val dataset.
    task : str
        Label definition forwarded to ``ADNI_3D_Dataset``.
    algo : str
        ``'ERM'``, ``'Adv'``, ``'DistMatch'`` or ``'FairALM'``.

    Returns
    -------
    (list, list)
        Per-epoch mean training and validation losses. These were previously
        collected and discarded; existing callers may keep ignoring them.

    Raises
    ------
    ValueError
        If ``algo`` is not supported. Previously this printed a message and
        broke out of the batch loop only, so the epoch still "validated" with
        a loss of 0 and overwrote the checkpoint.

    Notes
    -----
    Relies on the notebook globals ``batch_size``, ``optimizer``,
    ``scheduler``, ``device``, ``weights`` and, depending on ``algo``,
    ``mode``, ``group_type`` and ``surrogate_fns``.
    """
    if algo not in ('ERM', 'Adv', 'DistMatch', 'FairALM'):
        raise ValueError('Wrong algo! Got %r' % (algo,))

    num_worker = 16
    Val_Data_set = ADNI_3D_Dataset(csv = df.loc[df['Split']=='val'], path = data_path, transform = transform, task = task)
    Train_Data_set = ADNI_3D_Dataset(csv = df.loc[df['Split']=='train'], path = data_path, transform = transform, task = task)

    val_dataloader = DataLoader(Val_Data_set, batch_size = batch_size, shuffle = True, num_workers = num_worker)
    train_dataloader = DataLoader(Train_Data_set, batch_size = batch_size, shuffle = True, num_workers = num_worker)

    train_losses = []
    val_losses = []

    # Epochs elapsed since the last improvement (1 right after an improvement).
    count = 0
    best_val_loss = float("Inf")

    for epoch in range(epochs):

        # FairALM multipliers restart from zero every epoch, as before.
        lag_mult = {}
        if (algo == 'FairALM'):
            lag_mult['gender'] = torch.zeros(len(group_type['gender']) * len(surrogate_fns))
            lag_mult['age'] = torch.zeros(len(group_type['age']) * len(surrogate_fns))

        model.train()
        training_loss = 0.0
        loop = tqdm(enumerate(train_dataloader), total =len(train_dataloader))
        for step, data in loop:
            inputs, labels, demo = data
            inputs, labels, demo = inputs.to(device), labels.to(device), demo.to(device)

            optimizer.zero_grad()
            loss = _train_step_batch_loss(model, inputs, labels, demo, algo, lag_mult)
            loss.backward()
            optimizer.step()

            training_loss += loss.item()

            # Show the real epoch budget instead of a hard-coded 80.
            loop.set_description(f'Training epoch [{epoch}/{epochs}]')
            loop.set_postfix(loss = training_loss/(step+1))

        scheduler.step()
        train_losses.append(training_loss/len(train_dataloader))

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            loop = tqdm(enumerate(val_dataloader), total =len(val_dataloader))
            for step, data in loop:
                inputs, labels, demo = data
                inputs, labels, demo = inputs.to(device), labels.to(device), demo.to(device)

                # No optimizer call here: nothing accumulates under no_grad.
                loss = _train_step_batch_loss(model, inputs, labels, demo, algo, lag_mult)
                valid_loss += loss.item()

                loop.set_description(f'Validataion epoch [{epoch}/{epochs}]')
                loop.set_postfix(loss = valid_loss/(step+1))

        mean_val_loss = valid_loss / len(val_dataloader)
        val_losses.append(mean_val_loss)

        # Improvements are judged at three-decimal resolution.
        if(np.around(mean_val_loss, 3) < np.around(best_val_loss, 3)):
            count = 0
            best_val_loss = mean_val_loss
            torch.save(model.state_dict(), model_path)
            print("=========save model=========")

        count += 1

        if (count > 10):
            print("=========Early stopping=========")
            break

    return train_losses, val_losses


In [7]:
def plot_roc_curve(fper, tper):
    """Draw an ROC curve with the chance diagonal and show the figure.

    ``fper`` / ``tper`` are the false- and true-positive rates to plot on the
    x and y axes respectively.
    """
    diagonal = [0, 1]
    plt.plot(fper, tper, label='ROC', color='red')
    plt.plot(diagonal, diagonal, linestyle='--', color='green')
    plt.title('Receiver Operating Characteristic Curve')
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.legend()
    plt.show()
    
def eval_step(model, df, data_path, model_path, val, transform, task):
    """Evaluate ``model`` on the val or test split and persist the results.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate; it is NOT loaded from ``model_path`` here.
    df : pandas.DataFrame
        Metadata table with a ``Split`` column.
    data_path : str
        Directory prefix handed to ``ADNI_3D_Dataset``.
    model_path : str
        Only its file stem is used, to name the output files.
    val : bool
        True: evaluate the ``'val'`` split (with ``transform``) and write the
        F1-optimal threshold to ``thresh/<stem>_thresh.txt``.
        False: evaluate the ``'test'`` split without augmentation and pickle
        the scores to ``predictions/<stem>_on_original``.
    transform : int
        Augmentation flag, honoured only when ``val`` is true.
    task : str
        Label definition forwarded to ``ADNI_3D_Dataset``.

    Raises
    ------
    ValueError
        If the global ``algo`` is unsupported (previously the loop was left
        with a print and ``roc_curve`` then failed on empty lists).

    Notes
    -----
    Reads the notebook globals ``algo``, ``batch_size``, ``device``,
    ``weights`` and, depending on ``algo``, ``mode``, ``group_type`` and
    ``surrogate_fns``. A global ``optimizer`` is no longer required: the old
    ``optimizer.zero_grad()`` call did nothing under ``torch.no_grad()``.
    Prints the AUC and returns ``None``.
    """
    if algo not in ('ERM', 'Adv', 'DistMatch', 'FairALM'):
        raise ValueError('Wrong algo! Got %r' % (algo,))

    num_worker = 1
    split = 'val' if val else 'test'
    # The test split is always evaluated without augmentation.
    split_transform = transform if val else 0
    Data_set = ADNI_3D_Dataset(csv = df.loc[df['Split']==split], path = data_path, transform = split_transform, task = task)
    dataloader = DataLoader(Data_set, batch_size = batch_size, shuffle = False, num_workers = num_worker)

    model.eval()

    test_loss = 0.0
    target=[]
    prob=[]

    if (algo == 'FairALM'):
        lag_mult_g = torch.zeros(len(group_type['gender']) * len(surrogate_fns))
        lag_mult_a = torch.zeros(len(group_type['age']) * len(surrogate_fns))

    with torch.no_grad():
        loop = tqdm(enumerate(dataloader), total =len(dataloader))
        for step, data in loop:
            inputs, labels, demo = data
            inputs, labels, demo = inputs.to(device), labels.to(device), demo.to(device)

            if (algo == 'Adv'):
                all_outputs = model(inputs)
                outputs_gender = torch.reshape(all_outputs[0], (len(labels), ))
                outputs_age = torch.reshape(all_outputs[1], (len(labels), ))
                outputs_disease = torch.reshape(all_outputs[2], (len(labels), ))
                loss_value_gender = reciprocal_BCE_loss(demo[:, 0].float(), outputs_gender)
                loss_value_age = reciprocal_BCE_loss(demo[:, 1].float(), outputs_age)
                loss_value_disease = ERM_Loss(labels.float(), outputs_disease)
                loss_value = loss_value_gender + loss_value_age + loss_value_disease
                scores = outputs_disease
            else:
                outputs = torch.reshape(model(inputs), (len(labels),))
                scores = outputs
                if (algo == 'ERM'):
                    loss_value = ERM_Loss(labels.float(), outputs)
                elif (algo == 'DistMatch'):
                    loss_value, penalty_g = DistMatch_Loss(labels.float(), outputs, demo[:, 0].float(), mode, 'gender')
                    loss_value, penalty_a = DistMatch_Loss(labels.float(), outputs, demo[:, 1].float(), mode, 'age')
                    loss_value = (loss_value + penalty_g + penalty_a)
                else:  # 'FairALM'
                    loss_value, penalty_g, lag_mult_g = fairALM_loss(labels.float(), outputs, demo[:, 0].float(), lag_mult_g, 'gender', False)
                    loss_value, penalty_a, lag_mult_a = fairALM_loss(labels.float(), outputs, demo[:, 1].float(), lag_mult_a, 'age', False)
                    # Was `+=`, which counted the base loss twice.
                    loss_value = (loss_value + penalty_g + penalty_a)

            # Only positive samples are re-weighted; class 0 keeps weight 1.
            weight = torch.ones_like(loss_value)
            weight[labels==1.] = weights[1]
            loss = (loss_value * weight).mean()

            test_loss += loss.item()

            target.extend(np.array(labels.cpu()))
            prob.extend(scores.detach().cpu().numpy())

            loop.set_description(f'Test epoch')
            loop.set_postfix(step=(step+1), loss = test_loss/(step+1))

    fpr, tpr, threshold = roc_curve(target, prob)
    auc_curve = auc(fpr, tpr)
    print("AUC: ", auc_curve)

    # File stem of the checkpoint. Equals the former `model_path[12:-4]` for
    # paths of the form 'checkpoints/<name>.pth' but works for any location.
    run_name = os.path.splitext(os.path.basename(model_path))[0]

    if (val):
        best_thresh = cal_best_thresh(np.array(target), np.array(prob))

        os.makedirs('thresh', exist_ok=True)
        np.savetxt('thresh/{i}_thresh.txt'.format(i=run_name), [best_thresh])
    else:
        os.makedirs('predictions', exist_ok=True)
        with open('predictions/'+run_name+'_on_original', "wb") as fp:
            pickle.dump(np.array(prob), fp)

In [8]:
def get_embeddings(model, df, data_path, split, num_workers=16):
    """Return one model output per sample of ``split`` as a list of arrays.

    Parameters
    ----------
    model : torch.nn.Module
        Network whose forward output is collected.
    df : pandas.DataFrame
        Metadata table with a ``Split`` column.
    data_path : str
        Directory prefix handed to ``ADNI_3D_Dataset``.
    split : str
        Value of ``Split`` selecting the rows to embed.
    num_workers : int, optional
        DataLoader worker count (default 16, the former hard-coded value).

    Returns
    -------
    list of numpy.ndarray
        ``model(x)[0]`` for each sample, in dataset order.

    Notes
    -----
    Changes from the earlier version:

    * The model is switched to eval mode for the pass and its previous mode is
      restored afterwards. Before, a model left in training mode normalised
      every single-sample batch with its own statistics and updated its
      BatchNorm running statistics while "only" extracting features.
    * Inputs are moved to the device the model's parameters live on and
      outputs are copied back to the CPU, so it also works for a CUDA model.
    """
    X_embeddings = []

    Data_set = ADNI_3D_Dataset(csv = df.loc[df['Split']==split], path = data_path, transform = 0, task = 'disease')
    dataloader = DataLoader(Data_set, batch_size = 1, shuffle = False, num_workers = num_workers)

    # A parameter-free model has no device of its own; fall back to the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            loop = tqdm(enumerate(dataloader), total =len(dataloader))
            for step, data in loop:
                inputs, labels, demo = data

                outputs = model(inputs.to(model_device))

                # batch_size is 1, so index 0 is the sample's embedding.
                X_embeddings.append(outputs.cpu().numpy()[0])
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return X_embeddings

In [9]:
def cal_best_thresh(y_test, y_preds):
    """Return the ROC threshold that gives the highest binary F1 score.

    Every threshold reported by ``get_threshes`` is tried in order; the first
    one reaching the maximum F1 wins. If no threshold scores above 0 the
    function returns 0.
    """
    _, _, candidates = get_threshes(y_test, y_preds)

    top_score, chosen = 0, 0
    for cutoff in candidates:
        hard_preds = np.where(y_preds >= cutoff, 1, 0)
        current = f1_score(y_test, hard_preds, average='binary')
        # Strict comparison keeps the earliest threshold on ties.
        if current > top_score:
            top_score, chosen = current, cutoff

    return chosen

def get_threshes(y_test, preds):
    """Return ``(tpr, fpr, thresholds)`` of the ROC curve, keeping every point.

    Note the order: true-positive rate first, i.e. swapped relative to what
    ``roc_curve`` itself returns.
    """
    curve = roc_curve(y_test, preds, drop_intermediate=False)
    false_pos_rate, true_pos_rate, cutoffs = curve
    return true_pos_rate, false_pos_rate, cutoffs

def get_tpr(y_test, preds, thresh):
    """Return the true-positive rate of ``preds >= thresh`` against ``y_test``.

    Parameters
    ----------
    y_test : array-like of {0, 1}
        Ground-truth labels.
    preds : array-like of float
        Scores; a sample is predicted positive when its score is ``>= thresh``.
    thresh : float
        Decision threshold.

    Returns
    -------
    float
        ``tp / (tp + fn)``, or NaN when ``y_test`` contains no positives.

    Notes
    -----
    Counts are taken directly with numpy. The previous version unpacked
    ``confusion_matrix(...).ravel()`` into four values, which raises when
    labels and predictions together contain a single class (1x1 matrix).
    """
    actual_pos = np.asarray(y_test) == 1
    predicted_pos = np.asarray(preds) >= thresh

    tp = np.sum(actual_pos & predicted_pos)
    fn = np.sum(actual_pos & ~predicted_pos)

    if tp + fn == 0:
        return float('nan')
    return tp / (tp + fn)

In [13]:
# Load the metadata table and drop the MCI group (CN vs. AD only).
df = pd.read_csv('data_new.csv')
data_path = '../../../mnt/usb/kuopc/ADNI_B1/MPR__GradWarp__B1_Correction_crop/'

df = df.loc[df['Group'] != 'MCI']

# Alternative class counts for the age / sex tasks (currently disabled):
# class_0 = int((df['Age'] <= 75).sum())
# class_1 = int((df['Age'] > 75).sum())

# class_0 = int((df['Sex'] == 'F').sum())
# class_1 = int((df['Sex'] == 'M').sum())

class_0 = int((df['Group'] == 'CN').sum())
class_1 = int((df['Group'] == 'AD').sum())

total = class_0 + class_1

# Balanced class weights: each class contributes half of the total weight.
weight_for_0 = (1 / class_0) * (total / 2.0)
weight_for_1 = (1 / class_1) * (total / 2.0)

weights = torch.tensor([weight_for_0, weight_for_1], dtype=torch.float32)

print(weights)
print(type(weights))

tensor([0.7956, 1.3457])
<class 'torch.Tensor'>


In [14]:
def load_img(file_path):
    """Load the file at ``file_path`` with nibabel and return its data as a numpy array."""
    return np.array(nib.load(file_path).dataobj)

def normalize(arr):
    """Min-max scale ``arr`` to the range [0, 1].

    A constant array (max == min) maps to all zeros; it used to divide by
    zero and return NaNs.
    """
    arr_min = np.min(arr)
    arr_max = np.max(arr)
    span = arr_max - arr_min
    if span == 0:
        # Dividing the all-zero numerator by 1 keeps the usual output dtype.
        span = 1
    return (arr - arr_min) / span

def _resample_and_normalize_96(volume):
    """Resample a 3-D volume to 96 voxels per axis and min-max normalise it."""
    factors = (96/volume.shape[0], 96/volume.shape[1], 96/volume.shape[2])
    return normalize(zoom(volume, factors))


def compute_aug_predictions(model, model_path, df, data_path):
    """Average the model output over 12 augmented copies of every volume.

    For each ``Filename`` in ``df`` the volume is augmented three times with
    each of shear(6), rotation(10), fish(0.4) and scale(0.8), resampled to
    96^3 and normalised; the 12 copies form one batch and the mean model
    output is recorded. Results are pickled, as a numpy array of floats, to
    ``predictions/<checkpoint stem>_on_aug``.

    Parameters
    ----------
    model : torch.nn.Module
        Network on the global ``device``. The caller chooses train/eval mode.
    model_path : str
        Only its file stem is used, to name the output file.
    df : pandas.DataFrame
        Rows to predict; only the ``Filename`` column is read.
    data_path : str
        Prefix string-concatenated with each file name.

    Notes
    -----
    * Gradient tracking is disabled inside the function. The former
      ``with torch.no_grad():`` wrapped the ``def`` statement, which is only
      active while the function is being defined, not while it runs.
    * If an augmentation (or resampling its result) raises, the un-augmented
      volume is used for that slot, as before, but only ``Exception``
      subclasses are swallowed so Ctrl-C still interrupts the loop.
    * The per-file ``print(i)`` was dropped; the tqdm bar reports progress.
    * Reseeds the global numpy RNG with 2021 on every call.
    """
    seed = 2021
    np.random.seed(seed)

    # (function, strength) pairs, applied in this order within each repetition.
    augmentations = ((shear, 6), (rotation, 10), (fish, 0.4), (scale, 0.8))

    y_preds = []

    with torch.no_grad():
        loop = tqdm(enumerate(df['Filename']), total =len(df['Filename']))
        for i, file in loop:

            file_path = data_path + file
            image = load_img(file_path)

            aug_inputs = []
            for _ in range(3):
                for augment, strength in augmentations:
                    try:
                        aug_inputs.append(_resample_and_normalize_96(augment(image, strength)))
                    except Exception:
                        # Best effort: fall back to the original volume.
                        aug_inputs.append(_resample_and_normalize_96(image))

            # Stack the copies into one batch and add the channel axis.
            aug_inputs = torch.unsqueeze(torch.from_numpy(np.array(aug_inputs)), 1).to(device)

            temp_predict = model(aug_inputs)

            # Store a plain float so the list converts to numpy regardless of
            # the device the prediction lives on.
            y_preds.append(torch.mean(temp_predict).item())

            del aug_inputs, image, temp_predict
            gc.collect()
            torch.cuda.empty_cache()

    # File stem of the checkpoint. Equals the former `model_path[12:-4]` for
    # paths of the form 'checkpoints/<name>.pth' but works for any location.
    run_name = os.path.splitext(os.path.basename(model_path))[0]
    os.makedirs('predictions', exist_ok=True)
    with open('predictions/'+run_name+'_on_aug', "wb") as fp:
        pickle.dump(np.array(y_preds), fp)

In [12]:
# Run configuration for fine-tuning from the pretrained checkpoint.
# `batch_size` and `epochs` are read by train_step / eval_step as globals.
img_shape = (1, 96, 96, 96)
batch_size = 16

# Seed numpy and torch; Python's `random` module is not seeded here.
np.random.seed(2021)
torch.manual_seed(2021)

epochs = 80

# Checkpoint the weights are loaded from.
model_path = 'checkpoints/3D_CNN_AD_CN_proposed.pth'

model = resnet3d_model(18, adv=False)
model.load_state_dict(torch.load(model_path, map_location=device))

model.to(device)

# Freeze every parameter that comes before the first one whose name contains
# 'layer4'. The loop stops at that parameter, so it and everything after it
# keep their current requires_grad setting.
# NOTE(review): this depends on the order named_parameters() yields names in;
# presumably layer4 and the head come last — confirm against resnet3d.
for name, param in model.named_parameters():
    if not ('layer4' in name):
        param.requires_grad = False
    else:
        break
#         param.requires_grad = True
        


In [13]:
# Print a layer-by-layer summary of the model for a (1, 96, 96, 96) input.
summary(model, (1, 96, 96, 96))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 96, 48, 48]          21,952
       BatchNorm3d-2       [-1, 64, 96, 48, 48]             128
              ReLU-3       [-1, 64, 96, 48, 48]               0
         MaxPool3d-4       [-1, 64, 48, 24, 24]               0
            Conv3d-5       [-1, 64, 48, 24, 24]         110,592
       BatchNorm3d-6       [-1, 64, 48, 24, 24]             128
              ReLU-7       [-1, 64, 48, 24, 24]               0
            Conv3d-8       [-1, 64, 48, 24, 24]         110,592
       BatchNorm3d-9       [-1, 64, 48, 24, 24]             128
             ReLU-10       [-1, 64, 48, 24, 24]               0
       BasicBlock-11       [-1, 64, 48, 24, 24]               0
           Conv3d-12       [-1, 64, 48, 24, 24]         110,592
      BatchNorm3d-13       [-1, 64, 48, 24, 24]             128
             ReLU-14       [-1, 64, 48,

In [14]:
# Path train_step writes the best-validation state dict to.
model_path = 'checkpoints/3D_CNN_AD_CN_proposed_task_transfer_age_2.pth'

# Both objects are read as globals inside train_step. All parameters are
# handed to Adam, including any that had requires_grad switched off.
optimizer = Adam(model.parameters(),lr=1e-5, weight_decay=1e-6) 
scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) 

# `algo` is also read as a global by eval_step.
# NOTE(review): the global `weights` is derived from the CN/AD counts while
# the task here is 'age' — confirm that is the intended class weighting.
task = 'age'
algo = 'ERM'

# The sixth argument (0) disables augmentation.
train_step(model, df, data_path, model_path,  epochs, 0, task, algo)
# Threshold calibration on the val split; not needed for task transfer:
# eval_step(model, df, data_path, model_path, True, 0, task)     calculate threshold, no needed for task transfer 


# eval_step does not load weights from model_path (it only uses the name for
# the output file), so this scores the in-memory model as train_step left it,
# i.e. the last epoch rather than the saved best checkpoint.
eval_step(model, df, data_path, model_path, False, 0, task)  # for original data

# Optional: predictions on augmented copies of the test split.
# with torch.no_grad():
#     model.eval()
#     compute_aug_predictions(model, model_path, df.loc[df['Split'] == 'test'], data_path)



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]



  0%|          | 0/16 [00:00<?, ?it/s]

AUC:  0.6318002984669651


In [16]:
# Evaluation-only cell: load a saved checkpoint and score it on the test
# split for the 'age' task, without training.

model_path = 'checkpoints/3D_CNN_AD_CN_task_transfer_age_2.pth'

batch_size = 16

# `algo` and `batch_size` are read as globals by eval_step.
task = 'age'
algo = 'ERM'

model = resnet3d_model(18, adv=False)

model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# Rebuilt for the freshly loaded model; no training step runs in this cell.
optimizer = Adam(model.parameters(),lr=1e-5, weight_decay=1e-6) 
scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) 


# val=False selects the test split; the scores are also pickled under
# predictions/ by eval_step.
eval_step(model, df, data_path, model_path, False, 0, task)  # for original data

  0%|          | 0/16 [00:00<?, ?it/s]

AUC:  0.7331701346389227


In [16]:
# Evaluation-only cell: load a saved checkpoint and score it on the test
# split for the 'sex' task, without training.

model_path = 'checkpoints/3D_CNN_AD_CN_task_transfer_gender.pth'

batch_size = 16

# `algo` and `batch_size` are read as globals by eval_step.
task = 'sex'
algo = 'ERM'

model = resnet3d_model(18, adv=False)
# Alternative head, disabled for this checkpoint:
# model.fc = net()

model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# Rebuilt for the freshly loaded model; no training step runs in this cell.
optimizer = Adam(model.parameters(),lr=1e-5, weight_decay=1e-6) 
scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) 


# val=False selects the test split; the scores are also pickled under
# predictions/ by eval_step.
eval_step(model, df, data_path, model_path, False, 0, task)  # for original data

  0%|          | 0/16 [00:00<?, ?it/s]

AUC:  0.7690738474092207


In [23]:
# Reload the metadata and keep only the non-MCI test rows, with the
# demographic columns binarised for the task-transfer scoring cells.
# NOTE: this rebinds the global `df`. ADNI_3D_Dataset compares `Sex` with 'F'
# and `Group` with 'CN', so cells that build datasets from `df` must run
# before this one (or reload the table).
df = pd.read_csv('data_new.csv')
data_path = '../../../mnt/usb/kuopc/ADNI_B1/MPR__GradWarp__B1_Correction_crop/'

# Take an explicit copy so the column assignments below write to an
# independent frame instead of a filtered slice of the loaded table.
df = df.loc[(df['Group'] != 'MCI') & (df['Split'] == 'test')].copy()

df['Group'] = df['Group'].replace(['CN', 'AD'], [0, 1])
df['Sex'] = df['Sex'].replace(['F', 'M'], [0, 1])
df['Age'] = np.where(df['Age'] <= 75, 0, 1)
df['Race'] = np.where(df['Race'] < 1, 0, 1)

In [None]:
# Score the stored test-split predictions against the binarised `Sex` column.
file_name = 'predictions/3D_CNN_AD_CN_task_transfer_gender_on_original'

# NOTE: unpickling can execute arbitrary code; only load prediction files
# produced by this notebook. The `with` block closes the file on exit, so no
# explicit close() is needed.
with open(file_name, "rb") as fp:
    y_preds = CPU_Unpickler(fp).load()

all_mean_score, all_lower, all_upper = task_transfer_test(y_preds, df['Sex'].values)

print(all_mean_score, all_lower, all_upper)


In [22]:
# Score the stored test-split predictions against the binarised `Sex` column.
file_name = 'predictions/3D_CNN_AD_CN_proposed_task_transfer_gender_on_original'

# NOTE: unpickling can execute arbitrary code; only load prediction files
# produced by this notebook. The `with` block closes the file on exit, so no
# explicit close() is needed.
with open(file_name, "rb") as fp:
    y_preds = CPU_Unpickler(fp).load()

all_mean_score, all_lower, all_upper = task_transfer_test(y_preds, df['Sex'].values)

print(all_mean_score, all_lower, all_upper)


0.5785567836623515 0.49534198934779367 0.6617715779769094


In [18]:
# Run configuration; `batch_size` and `epochs` are read elsewhere as globals.
img_shape = (1, 96, 96, 96)
batch_size = 16

# Seed numpy and torch.
np.random.seed(2021)
torch.manual_seed(2021)

epochs = 80

# Load the pretrained checkpoint; the model is left on the CPU in this cell.
model_path = 'checkpoints/3D_CNN_AD_CN_proposed.pth'

model = resnet3d_model(18, adv=False)
model.load_state_dict(torch.load(model_path, map_location=device))

# model.fc = net()

# model.to(device)

# Only parameters whose name contains 'fc' stay trainable.
for name, param in model.named_parameters():
    param.requires_grad = ('fc' in name)
        


In [19]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
    
# Replace the `fc` attribute with a pass-through module.
# NOTE(review): presumably this makes model(x) return the features that used
# to feed the final layer — confirm against resnet3d's forward().
model.fc = Identity()

In [20]:
# `decomposition` is only needed by the disabled PCA lines below.
from sklearn import decomposition
# Collect one model output per training-split sample.
# NOTE(review): an earlier cell in file order narrows `df` to the test split;
# this call asks for 'train' rows, so it relies on `df` still holding the full
# table when it runs — confirm the intended execution order.
X_embeddings = get_embeddings(model, df, data_path, 'train')
# Optional dimensionality reduction (disabled):
# pca = decomposition.PCA(n_components=100, whiten=False, svd_solver='full')
# X_embeds_pca = pca.fit_transform(X_embeddings)

# Pickle the list of embeddings; the 'embeddings/' directory must already exist.
directory = 'embeddings/'
with open(directory+'training_3dresnet_proposed_model', "wb") as fp:
    pickle.dump(X_embeddings, fp)

  0%|          | 0/765 [00:00<?, ?it/s]

In [13]:
# Evaluation-only cell: load a saved checkpoint and score it on the test
# split for the 'sex' task, without training.
# NOTE(review): this cell appears twice verbatim in the notebook with
# different recorded AUCs; presumably the checkpoint file changed between the
# two runs — confirm which result is current.

model_path = 'checkpoints/3D_CNN_AD_CN_proposed_task_transfer_gender.pth'

batch_size = 16


# `algo` and `batch_size` are read as globals by eval_step.
task = 'sex'
algo = 'ERM'

model = resnet3d_model(18, adv=False)
# Swap the head before loading so the state dict keys match this checkpoint.
# NOTE(review): `net` is not defined in the visible cells; presumably it comes
# from one of the star imports — confirm.
model.fc = net()

model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# Rebuilt for the freshly loaded model; no training step runs in this cell.
optimizer = Adam(model.parameters(),lr=1e-5, weight_decay=1e-6) 
scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) 


# val=False selects the test split; the scores are also pickled under
# predictions/ by eval_step.
eval_step(model, df, data_path, model_path, False, 0, task)  # for original data

  0%|          | 0/16 [00:00<?, ?it/s]

AUC:  0.42356861145110836


In [16]:
# Evaluation-only cell: load a saved checkpoint and score it on the test
# split for the 'sex' task, without training.
# NOTE(review): this cell appears twice verbatim in the notebook with
# different recorded AUCs; presumably the checkpoint file changed between the
# two runs — confirm which result is current.

model_path = 'checkpoints/3D_CNN_AD_CN_proposed_task_transfer_gender.pth'

batch_size = 16


# `algo` and `batch_size` are read as globals by eval_step.
task = 'sex'
algo = 'ERM'

model = resnet3d_model(18, adv=False)
# Swap the head before loading so the state dict keys match this checkpoint.
# NOTE(review): `net` is not defined in the visible cells; presumably it comes
# from one of the star imports — confirm.
model.fc = net()

model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

# Rebuilt for the freshly loaded model; no training step runs in this cell.
optimizer = Adam(model.parameters(),lr=1e-5, weight_decay=1e-6) 
scheduler = lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1) 


# val=False selects the test split; the scores are also pickled under
# predictions/ by eval_step.
eval_step(model, df, data_path, model_path, False, 0, task)  # for original data

  0%|          | 0/16 [00:00<?, ?it/s]

AUC:  0.5083639330885353
