In [1]:
import random
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
import pickle
from tqdm import tqdm
import nibabel as nib

%matplotlib inline

In [146]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import transforms

from torch.utils.data import Dataset, TensorDataset, random_split, SubsetRandomSampler, ConcatDataset
from sklearn.model_selection import KFold

from skimage.transform import resize 

from sklearn.metrics import mean_absolute_error as mae
from torch.autograd import Variable

In [3]:
# Select the compute device: the first CUDA GPU ("cuda:0") when
# torch.cuda.is_available() reports one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [230]:
#single scale
class fc_layer(nn.Module):
    """Thin wrapper around a single fully connected ``nn.Linear`` layer.

    ``nn.Linear`` transforms only the last dimension of its input, so an
    input of shape ``(..., in_layer)`` produces an output of shape
    ``(..., out_layer)``. For example, with ``in_layer=5`` and
    ``out_layer=10`` a 2x3x5 input yields a 2x3x10 output.

    Args:
        in_layer (int): Size of the last dimension of the input.
        out_layer (int): Size of the last dimension of the output.
    """

    def __init__(self, in_layer, out_layer):
        super().__init__()
        self.fc = nn.Linear(in_features=in_layer, out_features=out_layer)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

class avg_pool(nn.Module):
    """3-D average pooling with a 2x2x2 window (stride defaults to the window size)."""

    def __init__(self):
        super().__init__()
        self.avgp = nn.AvgPool3d(kernel_size=2)

    def forward(self, x):
        """Return ``x`` average-pooled over its last three dimensions."""
        pooled = self.avgp(x)
        return pooled

class softmax_layer(nn.Module):
    """Softmax over dimension 1, so each row of a (batch, classes) input sums to one."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax of ``x`` taken along dimension 1."""
        probabilities = self.softmax(x)
        return probabilities

In [185]:
def customToTensor(img, target_size=(150, 150, 200)):
    """Resize an image volume and convert it to ``float32``.

    Despite its name, this returns a NumPy array, not a ``torch.Tensor``:
    the volume is resized with ``resize_image`` and cast with
    ``astype(np.float32)``.

    Args:
        img (np.ndarray): Image volume to preprocess.
        target_size (tuple): Output shape handed to ``resize_image``.
            Defaults to ``(150, 150, 200)``, the size this function
            previously hard-coded.

    Returns:
        np.ndarray: The resized volume as ``float32``.

    Raises:
        TypeError: If ``img`` is not a NumPy array. Previously such input
            silently produced ``None``.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError(
            "customToTensor expects a numpy.ndarray, got {}".format(type(img).__name__)
        )
    # The earlier torch.from_numpy(img) call was dropped: its result was
    # overwritten on the very next line and never used.
    resized = resize_image(img, target_size)
    return resized.astype(np.float32)

def resize_image(img_array, trg_size):
    """Resize ``img_array`` to ``trg_size`` with ``skimage.transform.resize``.

    The resize uses reflect padding at the borders, keeps the original
    intensity range (``preserve_range=True``) and applies no anti-aliasing
    filter.

    Args:
        img_array (np.ndarray): Image to resize.
        trg_size (tuple): Desired output shape.

    Returns:
        np.ndarray: The resized image.

    Raises:
        TypeError: If the resize call returns something other than a NumPy
            array.
    """
    res = resize(img_array, trg_size, mode='reflect', preserve_range=True, anti_aliasing=False)
    # Sanity check on the result type. Raising a bare string is not valid in
    # Python 3, so raise a real exception carrying the offending type.
    if not isinstance(res, np.ndarray):
        raise TypeError(
            "resize returned {} instead of numpy.ndarray".format(type(res).__name__)
        )
    return res

In [173]:
# Training hyper-parameters used by the k-fold training cell.
NUM_EPOCH = 50  # number of epochs run for every fold
BATCH_SIZE = 20  # batch size given to the DataLoaders
LR = 0.001  # learning rate given to optim.Adam
# NOTE(review): both checkpoint directories are absolute paths on one
# specific Windows machine; adjust them before running elsewhere.
SAVE_PATH_AP = r'C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP' # path to save the age prediction model
SAVE_PATH_BC = r'C:\Users\pbhav\Desktop\NYU\ivp\project\model\BC' # path to save the disease classification model

In [231]:
#data prep
class ADNI_Dataset_classification(Dataset):
    """Dataset of ADNI NIfTI volumes with a three-way diagnosis label.

    Every line of ``data_file`` is treated as one comma-separated record:
    column 0 holds the image name (the file ``<root_dir>/<name>.nii``) and
    column 2 holds the diagnosis string (``CN``, ``AD`` or ``MCI``).

    Indices refer to raw line numbers of ``data_file``; no line is skipped.
    NOTE(review): if the file starts with a header row, index 0 is not a
    valid sample and must be excluded by the caller.
    """

    # Diagnosis string found in the split file -> integer class index.
    _LABEL_TO_INDEX = {
        'CN': 0,   # Cognitive Normal
        'AD': 1,   # Alzheimer's
        'MCI': 2,  # Mild Cognitive Impairment
    }

    def __init__(self, root_dir, data_file):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
        """
        self.root_dir = root_dir
        self.data_file = data_file
        # Lines of data_file, read lazily on first use so the file is parsed
        # once instead of on every __len__/__getitem__ call.
        self._lines = None

    def _read_lines(self):
        """Return the cached lines of ``data_file``, reading the file on first use."""
        if self._lines is None:
            # The context manager guarantees the handle is closed.
            with open(self.data_file) as df:
                self._lines = df.readlines()
        return self._lines

    def __len__(self):
        """Return the number of lines in ``data_file``."""
        return len(self._read_lines())

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict: ``{'image': float32 array from customToTensor, 'label': int}``.

        Raises:
            ValueError: If column 2 of the line is not CN, AD or MCI.
        """
        lst = self._read_lines()[idx].split(',')
        # Strip surrounding whitespace (e.g. a trailing newline) before the quotes.
        img_name = lst[0].strip().strip('\"')
        img_label = lst[2].strip().strip('\"')

        # Validate the label before the expensive image load so a bad row
        # fails fast with a clear message instead of an unbound local.
        try:
            label = self._LABEL_TO_INDEX[img_label]
        except KeyError:
            raise ValueError(
                "Unknown label {!r} on line {} of {}; expected one of {}".format(
                    img_label, idx, self.data_file, sorted(self._LABEL_TO_INDEX))
            ) from None

        image_path = os.path.join(self.root_dir, img_name) + '.nii'
        image = nib.load(image_path)
        a = image.get_fdata()  # convert to np array
        a = customToTensor(a)

        return {'image': a, 'label': label}

In [232]:
#loss function for classification
def a_value(y_true, y_pred_prob, zero_label=0, one_label=1):
    """
    Approximates the AUC by the method described in Hand and Till 2001,
    equation 3.

    Only the samples labelled ``zero_label`` or ``one_label`` are used. They
    are sorted by increasing predicted probability of ``zero_label`` and the
    1-based ranks of the ``zero_label`` samples are summed.

    NB: The class labels should be in the set [0,n-1] where n = # of classes.
    The class probability should be at the index of its label in the predicted
    probability list.

    Args:
        y_true: actual labels, shape (m,); any array-like
        y_pred_prob: predicted class probabilities, shape (m, n); any array-like
        zero_label: class whose probability column is used for the ranking
        one_label: the class it is compared against
    Returns:
        The A-value as a float, or NaN when either class has no samples
        (the value is undefined in that case).
    """
    # Accept lists and other array-likes as well as NumPy arrays.
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)

    idx = np.isin(y_true, [zero_label, one_label])
    labels = y_true[idx]
    prob = y_pred_prob[idx, zero_label]
    sorted_labels = labels[np.argsort(prob)]

    is_zero = sorted_labels == zero_label
    n0 = int(np.count_nonzero(is_zero))
    n1 = int(np.count_nonzero(sorted_labels == one_label))
    if n0 == 0 or n1 == 0:
        # Previously this fell through to 0/0, which also produced NaN but
        # emitted a RuntimeWarning on the way.
        return float('nan')

    # Positions are 0-based, ranks are 1-based: adding n0 shifts every position by one.
    sum_ranks = int(np.flatnonzero(is_zero).sum()) + n0

    return (sum_ranks - (n0 * (n0 + 1) / 2.0)) / float(n0 * n1)  # Eqn 3

class mAUC(nn.Module):
    """Multi-class AUC (MAUC) of Hand and Till 2001, wrapped as an ``nn.Module``.

    The value is computed with NumPy, so the returned tensor is not connected
    to any autograd graph: it can be reported as a metric, but calling
    ``backward()`` on it cannot produce gradients for a network.
    """

    def __init__(self):
        super(mAUC, self).__init__()

    def forward(self, y_true, y_pred_prob, num_classes):
        """
        Calculates the MAUC over a set of multi-class probabilities and
        their labels. This is equation 7 in Hand and Till's 2001 paper.

        NB: The class labels should be in the set [0,n-1] where n = # of classes.
        The class probability should be at the index of its label in the
        probability list. I.e. with 3 classes the labels should be 0, 1, 2 and
        the probability of class '1' is found in column 1.

        Args:
            y_true: actual labels of the data, shape (m,)
            y_pred_prob: predicted class probabilities, shape (m, num_classes)
            num_classes (int): The number of classes in the dataset.

        Returns:
            A one-element float64 tensor holding the MAUC.

        Raises:
            ValueError: If ``num_classes`` is smaller than 2 (no class pair exists).
        """
        # Function-scope import, placed after the docstring so the docstring
        # is actually attached to the method.
        import itertools

        if num_classes < 2:
            raise ValueError(
                "mAUC needs at least 2 classes, got {}".format(num_classes))

        # Have to take average of A value with both classes acting as label 0 as this
        # gives different outputs for more than 2 classes
        sum_avals = 0
        for first, second in itertools.combinations(range(num_classes), 2):
            sum_avals += (a_value(y_true, y_pred_prob, zero_label=first, one_label=second) +
                          a_value(y_true, y_pred_prob, zero_label=second, one_label=first)) / 2.0

        mauc = sum_avals * (2 / float(num_classes * (num_classes - 1)))  # Eqn 7
        return torch.from_numpy(np.array([mauc]))

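# Reference: D. J. Hand and R. J. Till (2001), "A Simple Generalisation of the
# Area Under the ROC Curve for Multiple Class Classification Problems",
# Machine Learning 45, 171-186.
# Implementation adapted from
# https://github.com/pritomsaha/Multiclass_AUC/blob/master/multiclass_auc.ipynb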

In [233]:
# num_classes = 3
# y_true = np.array([0,1,1,0,2,2])
# y_pred_prob = np.array([[0.5, 0.1, 0.4], [0.5, 0.3, 0.2], [0.1, 0.8, 0.1], [0, 0.4, 0.6], [0.3, 0.2, 0.5], [0.5, 0.1, 0.4]])
# criterion = mAUC()
# mauc = criterion(y_true, y_pred_prob, num_classes)
# print(mauc)

In [234]:
class BD_classify(nn.Module):
    """Single-scale classifier: 3-D average pooling, flatten, one linear layer, softmax.

    Args:
        in_layer (int): Stored on the instance as ``self.in_layer``; it does
            not influence the size of any layer.
        out_layer (int): Number of output classes.
        flat_features (int): Number of features reaching the linear layer
            after pooling and flattening. The default 562500 equals
            75 * 75 * 100, i.e. a 150x150x200 volume halved along each axis
            by the 2x2x2 average pooling.
    """

    def __init__(self, in_layer, out_layer, flat_features=562500):
        super(BD_classify, self).__init__()
        self.in_layer = in_layer
        self.out_layer = out_layer
        self.avg = avg_pool()
        self.layer1 = fc_layer(flat_features, out_layer)
        self.softmax = softmax_layer()

    def forward(self, x):
        """Return class probabilities of shape (batch, out_layer) for a batch of volumes."""
        x0 = self.avg(x)
        # Keep dimension 0 and merge everything else into one feature vector per sample.
        x0 = torch.flatten(x0, start_dim=1, end_dim=-1)
        x1 = self.layer1(x0)
        x2 = self.softmax(x1)
        return x2

In [241]:
def train_epoch_classify(net, data_loader, optimizer, criterion, epoch):
    """Train ``net`` for one epoch and return the mean per-batch criterion value.

    ``criterion`` is called on NumPy arrays, so its result carries no
    gradient. Wrapping that result in a fresh leaf tensor and calling
    ``backward()`` on it never reaches the network, which left the weights
    untouched. The weights are therefore optimised with the negative
    log-likelihood of the predicted class probabilities, while ``criterion``
    is still evaluated on every batch and its mean is what gets returned.

    Args:
        net: Model mapping an image batch to class probabilities of shape
            (batch, num_classes).
        data_loader: Iterable of dicts with the keys ``'image'`` and ``'label'``.
        optimizer: Optimizer over ``net``'s parameters.
        criterion: Callable ``criterion(y_true, y_pred_prob, num_classes)``
            returning an object with ``.item()``.
        epoch: Unused; kept for interface compatibility.

    Returns:
        Mean of the per-batch criterion values.
    """
    net.train()

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(net.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    loss_stat = []
    for img_label in data_loader:
        img = img_label['image'].to(device=model_device, dtype=torch.float32)
        label = torch.as_tensor(img_label['label'])

        pred = net(img)

        # Differentiable objective. The clamp keeps log() finite when a
        # probability underflows to zero.
        target = label.to(device=pred.device, dtype=torch.long)
        train_loss = F.nll_loss(torch.log(pred.clamp_min(1e-12)), target)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Reported statistic. .cpu() is required before .numpy() when the
        # model lives on a GPU.
        score = criterion(label.cpu().numpy(), pred.detach().cpu().numpy(), pred.shape[1])
        loss_stat.append(score.item())

    return np.mean(loss_stat)

In [242]:
def valid_epoch_classify(net, data_loader, criterion, epoch):
    """Evaluate ``net`` on ``data_loader`` and return the mean per-batch criterion value.

    Args:
        net: Model mapping an image batch to class probabilities of shape
            (batch, num_classes).
        data_loader: Iterable of dicts with the keys ``'image'`` and ``'label'``.
        criterion: Callable ``criterion(y_true, y_pred_prob, num_classes)``
            returning an object with ``.item()``.
        epoch: Unused; kept for interface compatibility.

    Returns:
        Mean of the per-batch criterion values.
    """
    net.eval()

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(net.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    val_loss_stat = []
    with torch.no_grad():
        for img_label in data_loader:
            img = img_label['image'].to(device=model_device, dtype=torch.float32)
            label = torch.as_tensor(img_label['label'])

            # .cpu() is required before .numpy() when the model lives on a GPU.
            pred = net(img).cpu().numpy()
            val_loss = criterion(label.cpu().numpy(), pred, pred.shape[1])
            val_loss_stat.append(val_loss.item())

    return np.mean(val_loss_stat)

In [243]:
# Build the classification dataset: the first argument is the directory holding
# the .nii images, the second the CSV that lists image names and labels.
# NOTE(review): both paths are absolute and machine-specific.
classification_data = ADNI_Dataset_classification(r"C:/Users/pbhav/Desktop/NYU/ivp/project/ADNI/", "C:/Users/pbhav/Downloads/ADNI1_Annual_2_Yr_3T_4_23_2022.csv")

In [244]:
# Instantiate the 3-class classifier and place it on the selected device.
in_size = 200  # size after the transformer
net1 = BD_classify(in_size, 3).to(device)  # .to() moves the module in place and returns it
print(net1)

# Count only the parameters that will receive gradients.
n_params1 = sum(param.numel() for param in net1.parameters() if param.requires_grad)
print('Number of parameters in network: ', n_params1)

BD_classify(
  (avg): avg_pool(
    (avgp): AvgPool3d(kernel_size=2, stride=2, padding=0)
  )
  (layer1): fc_layer(
    (fc): Linear(in_features=562500, out_features=3, bias=True)
  )
  (softmax): softmax_layer(
    (softmax): Softmax(dim=1)
  )
)
Number of parameters in network:  1687503


In [245]:
# K-fold cross-validation with k = 5.
k=5
# Shuffled splitter with a fixed random_state so the folds are the same on every run.
splits=KFold(n_splits=k,shuffle=True,random_state=42)

# Multi-class AUC module handed to the train/validation epoch functions as `criterion`.
criterion = mAUC()

In [None]:
# K-fold cross-validation loop: trains net1 on every fold and records the
# per-epoch train/test statistics in foldperf1['fold<N>'].
foldperf1 = {}

# Loop-invariant setup, hoisted out of the fold loop.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net1.to(device)
# Create the checkpoint directory up front. The previous fallback referenced
# an undefined name (model_save_path) and raised NameError when it was missing.
os.makedirs(SAVE_PATH_BC, exist_ok=True)

for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(classification_data)))):
    # Exclude dataset index 0 from both splits.
    # NOTE(review): presumably row 0 of the CSV is the header line, which has
    # no image or label - confirm. np.delete(idx, 0) removed the first element
    # of each split, which also discarded one valid sample per fold.
    train_idx = train_idx[train_idx != 0]
    val_idx = val_idx[val_idx != 0]
    print('Fold {}'.format(fold + 1))

    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = torch.utils.data.DataLoader(classification_data, batch_size=BATCH_SIZE, sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(classification_data, batch_size=BATCH_SIZE, sampler=test_sampler)

    # Start every fold from freshly initialised weights. Otherwise the model
    # has already been trained on samples that are in this fold's validation
    # split, and the cross-validation estimate is contaminated.
    for module in net1.modules():
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()
    optimizer = optim.Adam(net1.parameters(), lr=LR)

    history = {'train_loss': [], 'test_loss': []}

    for epoch in range(NUM_EPOCH):
        train_loss = train_epoch_classify(net1, train_loader, optimizer, criterion, epoch)
        test_loss = valid_epoch_classify(net1, test_loader, criterion, epoch)

        print("Epoch:{}/{} AVG Training Loss:{:.3f} AVG Test Loss:{:.3f}".format(epoch + 1, NUM_EPOCH, train_loss, test_loss))
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)

        # Save the model after each epoch. os.path.join replaces the
        # '\epoch...' literal, whose '\e' is an invalid escape sequence.
        # NOTE(review): the file name does not include the fold, so each fold
        # overwrites the previous fold's checkpoints.
        checkpoint_path = os.path.join(SAVE_PATH_BC, 'epoch{}.pth'.format(epoch + 1))
        torch.save(net1.state_dict(), checkpoint_path)
        print('Checkpoint {} saved to {}'.format(epoch + 1, checkpoint_path))
    foldperf1['fold{}'.format(fold + 1)] = history

# Persist the whole module as left by the final fold.
torch.save(net1, 'k_cross1.pt')

Fold 1


In [None]:
# Summarise the cross-validation run: average each fold's per-epoch values,
# then average those fold means.
testl_f1, tl_f1 = [], []

# Folds are stored in foldperf1 under the keys 'fold1' .. 'fold<k>'.
for f in range(1,k+1):
    tl_f1.append(np.mean(foldperf1['fold{}'.format(f)]['train_loss']))
    testl_f1.append(np.mean(foldperf1['fold{}'.format(f)]['test_loss']))

print('Performance of {} fold cross validation'.format(k))
print("Average Training Loss: {:.3f} \t Average Test Loss: {:.3f}".format(np.mean(tl_f1), np.mean(testl_f1)))     


#  Segmentation