# Alzheimer's Disease Classification

### This code performs stratified five-fold cross-validation at the subject level using a VGG16 binary classifier. It also employs a nested four-fold cross-validation to optimize the number of epochs and the learning rate.

In [None]:
# Packages to start with 

import os 
import imageio 
import matplotlib.pyplot as plt 
import random
import numpy as np 
from PIL import Image


# 1. Data Splitting and Resizing



## 1.1 Initialize a seed for the random data split

In [None]:
# Seed passed to random.seed() before the subject IDs are shuffled; it is also
# embedded in the experiment folder name and in the saved model file names.
seed_number = 1000

## 1.2  Designate "1" for AD, and "0" for NC. We will run the data split for each label.

Split the data by subject ID, distributing subjects with two images and subjects with a single image evenly across the folds.
- 6180 refers to 61AD, 80NC subjects
- The data split here is hard-coded specifically with respect to 100 / 5 = 20 images, hence 20 images per fold. 

In [None]:
# Split the AD images into five subject-level folds, resizing each to 224x224.
img_dir = '/DATA/charlie/fundus/AutoMorph_6180/AD/binary_vessel/resize/'
outer_dir = '/DATA/charlie/fundus/AutoMorph_6180/Experiment/'
flag = '1' # 1 here for AD

name_of_experiment = 'AutoMorph_6180_' + str(seed_number)
images = os.listdir(img_dir)

# Fold layout. The cap of 20 images per fold is the hard-coded
# 100 / 5 = 20 split described in the text above.
num_folds = 5
images_per_fold = 20

# The subject IDs are parsed from the image file name,
# e.g. 1234567_21015_0_0 where 1234567 is the subject ID.
# The slice (0:7 here) must cover exactly the subject ID characters.
# File names are grouped by the parsed ID so that each image belongs to
# exactly one subject (a substring search could match unrelated digits).
subject_ids = []
images_by_subject = {}
for img in images:
    eid = int(img[0:7])  # built-in int; the np.int alias was removed from NumPy
    subject_ids.append(eid)
    images_by_subject.setdefault(eid, []).append(img)

single_img_id = []
double_img_id = []

# np.unique returns the IDs sorted, so the lists below have a deterministic
# order before they are shuffled with the fixed seed.
unique, counts = np.unique(subject_ids, return_counts = True)
result = np.column_stack((unique, counts))

# Subjects with any other number of images are not assigned to a fold.
for subject, n_images in zip(unique, counts):
    if n_images == 2:
        double_img_id.append(int(subject))
    elif n_images == 1:
        single_img_id.append(int(subject))

random.seed(seed_number)
random.shuffle(double_img_id)
random.shuffle(single_img_id)

experiment_dir = outer_dir + name_of_experiment
if os.path.exists(experiment_dir):
    print('Folder already exists')

# exist_ok=True creates whatever is missing, including individual fold/label
# folders when the experiment folder itself is already present.
for fold in range(1, num_folds + 1):
    for label in ('1', '0'):
        os.makedirs(experiment_dir + '/f' + str(fold) + '/' + label + '/', exist_ok = True)

# The flag selects the label sub-folder ('1' for AD, '0' for CN) inside each fold.
fold_dirs = [experiment_dir + '/f' + str(fold) + '/' + flag + '/'
             for fold in range(1, num_folds + 1)]
dir_1, dir_2, dir_3, dir_4, dir_5 = fold_dirs

# Two-image subjects are dealt round-robin; both images of a subject go to the
# same fold so that no subject is shared between folds.
for position, eid in enumerate(double_img_id):
    target_dir = fold_dirs[position % num_folds]
    for img in images_by_subject[eid]:
        with Image.open(img_dir + img) as X:
            X.resize((224, 224)).save(target_dir + img)

# Single-image subjects top up the folds in order, up to the per-fold cap.
skipped_images = []
for eid in single_img_id:
    for img in images_by_subject[eid]:
        target_dir = next(
            (d for d in fold_dirs if len(os.listdir(d)) < images_per_fold), None)
        if target_dir is None:
            # Every fold is full: report it below instead of dropping silently.
            skipped_images.append(img)
            continue
        with Image.open(img_dir + img) as X:
            X.resize((224, 224)).save(target_dir + img)

if skipped_images:
    print('WARNING: ' + str(len(skipped_images)) + ' image(s) were not saved because every fold already holds '
          + str(images_per_fold) + ' images')
print('Data has been saved')

In [None]:
# Split the CN images into five subject-level folds, resizing each to 224x224.
img_dir = '/DATA/charlie/fundus/AutoMorph_6180/CN/binary_vessel/resize/'
outer_dir = '/DATA/charlie/fundus/AutoMorph_6180/Experiment/'
flag = '0' # 0 here for CN

name_of_experiment = 'AutoMorph_6180_' + str(seed_number)
images = os.listdir(img_dir)

# Fold layout. The cap of 20 images per fold is the hard-coded split size
# described in the text above.
num_folds = 5
images_per_fold = 20

# The subject ID is parsed from the image file name; for the CN images it
# sits at characters 3:10 (for the AD images it is 0:7).
# File names are grouped by the parsed ID so that each image belongs to
# exactly one subject (a substring search could match unrelated digits).
subject_ids = []
images_by_subject = {}
for img in images:
    eid = int(img[3:10])  # built-in int; the np.int alias was removed from NumPy
    subject_ids.append(eid)
    images_by_subject.setdefault(eid, []).append(img)

single_img_id = []
double_img_id = []

# np.unique returns the IDs sorted, so the lists below have a deterministic
# order before they are shuffled with the fixed seed.
unique, counts = np.unique(subject_ids, return_counts = True)
result = np.column_stack((unique, counts))

# Subjects with any other number of images are not assigned to a fold.
for subject, n_images in zip(unique, counts):
    if n_images == 2:
        double_img_id.append(int(subject))
    elif n_images == 1:
        single_img_id.append(int(subject))

random.seed(seed_number)
random.shuffle(double_img_id)
random.shuffle(single_img_id)

experiment_dir = outer_dir + name_of_experiment
if os.path.exists(experiment_dir):
    print('Folder already exists')

# exist_ok=True creates whatever is missing, including individual fold/label
# folders when the experiment folder itself is already present.
for fold in range(1, num_folds + 1):
    for label in ('1', '0'):
        os.makedirs(experiment_dir + '/f' + str(fold) + '/' + label + '/', exist_ok = True)

# The flag selects the label sub-folder ('1' for AD, '0' for CN) inside each fold.
fold_dirs = [experiment_dir + '/f' + str(fold) + '/' + flag + '/'
             for fold in range(1, num_folds + 1)]
dir_1, dir_2, dir_3, dir_4, dir_5 = fold_dirs

# Two-image subjects are dealt round-robin; both images of a subject go to the
# same fold so that no subject is shared between folds.
for position, eid in enumerate(double_img_id):
    target_dir = fold_dirs[position % num_folds]
    for img in images_by_subject[eid]:
        with Image.open(img_dir + img) as X:
            X.resize((224, 224)).save(target_dir + img)

# Single-image subjects top up the folds in order, up to the per-fold cap.
skipped_images = []
for eid in single_img_id:
    for img in images_by_subject[eid]:
        target_dir = next(
            (d for d in fold_dirs if len(os.listdir(d)) < images_per_fold), None)
        if target_dir is None:
            # Every fold is full: report it below instead of dropping silently.
            skipped_images.append(img)
            continue
        with Image.open(img_dir + img) as X:
            X.resize((224, 224)).save(target_dir + img)

if skipped_images:
    print('WARNING: ' + str(len(skipped_images)) + ' image(s) were not saved because every fold already holds '
          + str(images_per_fold) + ' images')
print('Data has been saved')


# 2. Model Cross Validation

# 2.1 Packages

The traditional PyTorch (compatibility tested at 1.7.1) and sklearn packages are required.

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
import torchvision.datasets as datasets
import os
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torchvision.models as models
import torch.nn as nn
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch.nn.functional as F
from sklearn.metrics import classification_report

# 2.2 Utilities

In [None]:
def computeAUC(dataGT, dataPRED):
    """Return the ROC AUC of ``dataPRED`` scored against ``dataGT``.

    Both arguments are tensors; they are copied to the CPU and converted to
    NumPy arrays before being handed to ``roc_auc_score``.
    """
    y_true = dataGT.cpu().numpy()
    y_score = dataPRED.cpu().numpy()
    return roc_auc_score(y_true, y_score)

def computeACC(dataGT, dataCLASS):
    """Return the accuracy of the predicted classes against the ground truth.

    Both arguments are tensors; they are copied to the CPU and converted to
    NumPy arrays before being handed to ``accuracy_score``.
    """
    y_true, y_pred = (t.cpu().numpy() for t in (dataGT, dataCLASS))
    return accuracy_score(y_true, y_pred)

def classreport(dataGT, predCLASS):
    """Print sklearn's classification report (three decimals) for the predictions.

    Both arguments are tensors; they are copied to the CPU and converted to
    NumPy arrays first. Nothing is returned.
    """
    y_true = dataGT.cpu().numpy()
    y_pred = predCLASS.cpu().numpy()
    report = classification_report(y_true, y_pred, digits = 3)
    print(report)
    
def train(model, dataloader, optimizer, criterion):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network to optimise; it must already live on the device to train on.
    dataloader : iterable of (images, labels)
        Batches of inputs and integer class labels.
    optimizer : torch.optim.Optimizer
        Optimiser bound to ``model``'s parameters.
    criterion : callable
        Loss taking ``(logits, labels)``, e.g. ``nn.CrossEntropyLoss``.

    Returns
    -------
    float
        Average of the per-batch loss values (computed before each update).
    """
    model.train()
    # Follow the model's own device rather than assuming CUDA, so the batches
    # always end up next to the weights.
    device = next(model.parameters()).device
    tr_loss = []
    for images, labels in dataloader:
        optimizer.zero_grad()
        images = images.float().to(device)
        labels = labels.long().to(device)
        output = model(images)
        # CrossEntropyLoss applies log-softmax itself, so it must be given the
        # raw logits; a softmax here would squash the loss and its gradients.
        loss = criterion(output, labels)
        tr_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    return np.average(tr_loss, axis = 0)

def validation(model, dataloader, criterion = None):
    """Evaluate ``model`` on ``dataloader`` and return ``(mean loss, accuracy)``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate; it is switched to eval mode and moved to the GPU.
    dataloader : iterable of (images, labels)
        Batches of inputs and integer class labels.
    criterion : callable, optional
        Loss taking ``(logits, labels)``. Defaults to ``nn.CrossEntropyLoss()``
        so the function no longer depends on a global named ``criterion``;
        pass the training criterion explicitly if it is configured differently.

    Returns
    -------
    tuple
        Average of the per-batch losses and the accuracy over all samples.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    model.eval()
    model.cuda()
    outGT = torch.FloatTensor().cuda()
    outClass = torch.FloatTensor().cuda()
    val_loss = []
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.float().cuda()
            labels = labels.long().cuda()
            output = model(images)
            # CrossEntropyLoss expects raw logits (it applies log-softmax itself).
            loss = criterion(output, labels)
            val_loss.append(loss.item())
            outGT = torch.cat((outGT, labels), 0)
            # Softmax is only used to obtain the predicted class.
            softmax_preds = F.softmax(output, dim = 1)
            class_predictions = torch.argmax(softmax_preds, dim = 1)
            outClass = torch.cat((outClass, class_predictions), 0)
    acc_test = computeACC(outGT, outClass)
    return np.average(val_loss, axis = 0), acc_test


def test(model, dataloader):
    """Evaluate ``model`` on ``dataloader``, print a class report, return accuracy.

    The model is put in eval mode and moved to the GPU. Labels, softmax
    outputs and argmax class predictions are accumulated over all batches;
    the classification report is printed via ``classreport`` and the accuracy
    from ``computeACC`` is returned.
    """
    model.eval()
    model.cuda()
    ground_truth = torch.FloatTensor().cuda()
    probabilities = torch.FloatTensor().cuda()
    predicted = torch.FloatTensor().cuda()
    with torch.no_grad():
        for batch in dataloader:
            images, labels = batch
            images = images.float().cuda()
            labels = labels.float().cuda()
            logits = model(images)
            ground_truth = torch.cat((ground_truth, labels), 0)
            batch_probs = F.softmax(logits, dim = 1)
            probabilities = torch.cat((probabilities, batch_probs), 0)
            batch_classes = torch.argmax(batch_probs, dim = 1)
            predicted = torch.cat((predicted, batch_classes), 0)

    classreport(ground_truth, predicted)
    return computeACC(ground_truth, predicted)

def nested_data_order(data, tr_transform, val_transform, N):
    """Build the train/validation/test datasets for inner fold ``N``.

    Parameters
    ----------
    data : sequence of str
        Five fold directories; ``data[0:4]`` are the inner folds and
        ``data[4]`` is the held-out test fold.
    tr_transform : callable
        Transform applied to the three training folds.
    val_transform : callable
        Transform applied to the validation fold.
    N : int
        Inner fold number, 1 to 4. Fold ``N`` validates on ``data[4 - N]``
        and trains on the remaining three inner folds in ascending order.

    Returns
    -------
    tuple
        ``(TRAIN_DATA, VAL_DATA, TEST_DATA)``.

    Raises
    ------
    ValueError
        If ``N`` is not 1, 2, 3 or 4.
    """
    if N not in (1, 2, 3, 4):
        raise ValueError('N must be 1, 2, 3 or 4, got ' + repr(N))

    # The folds come from the ``data`` argument, not from a global.
    # NOTE(review): the test set is built without a transform, as before;
    # confirm whether val_transform should be applied here as well.
    TEST_DATA = datasets.ImageFolder(root = data[4])

    val_index = 4 - N
    train_folds = [datasets.ImageFolder(root = data[k], transform = tr_transform)
                   for k in range(4) if k != val_index]
    TRAIN_DATA = ConcatDataset(train_folds)
    VAL_DATA = datasets.ImageFolder(root = data[val_index], transform = val_transform)

    return TRAIN_DATA, VAL_DATA, TEST_DATA

def generate_train_test_set(data, train_transform, test_transform):
    """Build the outer-fold training and test datasets.

    Parameters
    ----------
    data : sequence of str
        Five fold directories; ``data[0:4]`` are concatenated into the
        training set and ``data[4]`` is the test fold.
    train_transform : callable
        Transform applied to the four training folds.
    test_transform : callable
        Transform applied to the test fold.

    Returns
    -------
    tuple
        ``(TRAIN_DATA, TEST_DATA)``.
    """
    # The folds come from the ``data`` argument, not from a global.
    train_folds = [datasets.ImageFolder(root = data[k], transform = train_transform)
                   for k in range(4)]
    TRAIN_DATA = ConcatDataset(train_folds)
    TEST_DATA = datasets.ImageFolder(root = data[4], transform = test_transform)

    return TRAIN_DATA, TEST_DATA

# 2.3: 5 Fold Cross Validation

We will do a five fold cross validation. There will be a four fold inner cross validation to optimize the learning rate and epochs. 

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
import torchvision.datasets as datasets
import os
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torchvision.models as models
import torch.nn as nn
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch.nn.functional as F
from sklearn.metrics import classification_report


# Nested cross-validation driver: for each of the five test folds, an inner
# four-fold search picks the learning rate and number of epochs, then a fresh
# VGG16 is trained on the four remaining folds and evaluated on the test fold.

#primary_directory = '/blue/ruogu.fang/charlietran/UKB/code/AutoMorph/Experiment/AutoMorph_6180_' + str(seed_number)
primary_directory = outer_dir + name_of_experiment
fold1_img_dir = primary_directory + '/f1/'
fold2_img_dir = primary_directory + '/f2/'
fold3_img_dir = primary_directory + '/f3/'
fold4_img_dir = primary_directory + '/f4/'
fold5_img_dir = primary_directory + '/f5/'

# Each row lists the four training/validation folds followed by the test fold.
DATA = [[fold1_img_dir, fold2_img_dir, fold3_img_dir, fold4_img_dir, fold5_img_dir],
       [fold1_img_dir, fold2_img_dir, fold3_img_dir, fold5_img_dir, fold4_img_dir],
       [fold1_img_dir, fold2_img_dir, fold4_img_dir, fold5_img_dir, fold3_img_dir],
       [fold1_img_dir, fold3_img_dir, fold4_img_dir, fold5_img_dir, fold2_img_dir],
       [fold2_img_dir, fold3_img_dir, fold4_img_dir, fold5_img_dir, fold1_img_dir]]

# Settings shared by every outer fold.
num_classes = 2
learning_rate = [1e-4, 1e-5]   # grid searched by the inner cross-validation
N = 4                          # number of inner folds
num_epochs = 50                # epochs evaluated per inner fold
bsz = 64

rotations = transforms.RandomRotation(15)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.RandomVerticalFlip(p = 0.5),
    transforms.RandomApply([rotations], p = 0.5),
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    normalize,
])

for j in range(5):
    torch.cuda.empty_cache()
    print('STARTING PROCEDURE FOR TEST FOLD' + str(j + 1))
    DATA_SETS = DATA[j]

    optimal_accuracy_list = []
    optimal_epochs_list = []
    for learn_rate in learning_rate:
        # One list of per-epoch validation accuracies per inner fold.
        inner_accuracy = []

        for inner_n in range(1, N + 1):
            print(f'INNERFOLD {inner_n}', ' LEARNING RATE', learn_rate)
            print('--------------------------------')
            train_ncv_data, val_cv_data, test_data = nested_data_order(DATA_SETS, train_transform, test_transform, inner_n)
            train_nest_loader = DataLoader(train_ncv_data, batch_size = bsz, shuffle = True, drop_last = True)
            val_nest_loader = DataLoader(val_cv_data, batch_size = bsz, shuffle = True)

            model = models.vgg16(pretrained = True)
            model.classifier[6] = nn.Linear(in_features = 4096, out_features = num_classes, bias = True)
            model.cuda()

            criterion = nn.CrossEntropyLoss().cuda()
            optimizer = optim.Adam(model.parameters(), lr = learn_rate)
            #optimizer = optim.Adam(model.parameters(), lr = learn_rate, weight_decay = 1e-6)
            print('INITIALIZING THE VGG16 MODEL TO IMAGENET')

            # Loss curves are kept so they can be plotted after a run.
            epochs_list = []
            validation_loss_set = []
            training_loss_set = []

            nest_val_acc = []
            for epoch in range(num_epochs):
                training_loss = train(model, train_nest_loader, optimizer, criterion)
                validation_loss, val_accuracy = validation(model, val_nest_loader)
                epochs_list.append(epoch)
                validation_loss_set.append(validation_loss)
                training_loss_set.append(training_loss)
                nest_val_acc.append(val_accuracy)

            # nest_val_acc now holds one accuracy per epoch for this inner fold.
            inner_accuracy.append(nest_val_acc)

        # Average over the inner folds, epoch by epoch.
        mean_accuracy = np.average(inner_accuracy, axis = 0)
        optimal_accuracy_list.append(max(mean_accuracy))
        # Index e holds the accuracy after e + 1 epochs of training, so the
        # best number of epochs is the argmax index plus one.
        optimal_epochs_list.append(int(np.argmax(mean_accuracy)) + 1)

    best_setting = int(np.argmax(optimal_accuracy_list, axis = 0))
    optimal_learning_rate = learning_rate[best_setting]
    optimal_epochs = optimal_epochs_list[best_setting]
    print('Optimal Accuracy List is', optimal_accuracy_list, 'Optimal Epochs list is', optimal_epochs_list)
    print('The optimal learning rate is', optimal_learning_rate, ' the optimal epochs is ', optimal_epochs)

    train_data, test_data = generate_train_test_set(DATA_SETS, train_transform, test_transform)
    train_loader = DataLoader(train_data, batch_size = bsz, shuffle = True, drop_last = True)
    test_loader = DataLoader(test_data, batch_size = bsz, shuffle = False) # Shuffle does not really  matter
    model = models.vgg16(pretrained = True)
    model.classifier[6] = nn.Linear(in_features = 4096, out_features = num_classes, bias = True)
    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    # Train the final model with the selected learning rate, not with the
    # last value left over from the grid-search loop.
    optimizer = optim.Adam(model.parameters(), lr = optimal_learning_rate)
    #optimizer = optim.Adam(model.parameters(), lr = optimal_learning_rate, weight_decay = 1e-6)

    epochs = []
    epoch_training_loss = []
    epoch_val_loss = []
    for epoch in range(optimal_epochs):
        training_loss = train(model, train_loader, optimizer, criterion)
        epoch_training_loss.append(training_loss)
        epochs.append(epoch)
        print("Epoch", epoch, 'Training Loss', "{:.4f}".format(training_loss))

    acc_test = test(model, test_loader)

    output_save_model_name = '/home/charlietran/model_auto/VGG16_AutoMorph_TESTFOLD_' + str(j+1) + '_6180_' + str(seed_number) + '.pth'
    # Make sure the target folder exists so the trained weights are not lost
    # to a missing-directory error at the very end of a fold.
    os.makedirs(os.path.dirname(output_save_model_name), exist_ok = True)
    torch.save({'model_state_dict': model.state_dict(),
        }, output_save_model_name)

