In [1]:
import random
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
import pickle
from tqdm import tqdm
import nibabel as nib

%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import transforms

from torch.utils.data import Dataset, TensorDataset, random_split, SubsetRandomSampler, ConcatDataset
from sklearn.model_selection import KFold

from skimage.transform import resize 

from sklearn.metrics import mean_absolute_error as mae

In [3]:
# Use the GPU if you have one
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
#single scale
class fc_layer(nn.Module):
    def __init__(self, in_layer, out_layer):
        super(fc_layer, self).__init__()
        self.fc = nn.Linear(in_layer, out_layer)
        #if in_features=5 and out_features=10 and the input tensor x 
        #has dimensions 2-3-5, then the output tensor will have dimensions 2-3-10???
        #

    def forward(self, x):
        return self.fc(x)

class avg_pool(nn.Module):
    def __init__(self):
        super(avg_pool, self).__init__()
        self.avgp = nn.AvgPool3d(2)
        
    def forward(self, x):
        return self.avgp(x)

class relu_act(nn.Module):
    def __init__(self):
        super(relu_act, self).__init__()
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.relu(x)

class softmax_layer(nn.Module):
    def __init__(self):
        super(softmax_layer, self).__init__()
        self.softmax = nn.Softmax()
    
    def forward(self, x):
        return self.softmax(x)

# Age prediction

In [5]:
class age_pred(nn.Module):
    def __init__(self, in_layer, out_layer):
        super(age_pred, self).__init__()
        self.in_layer = in_layer
        self.out_layer = out_layer
        self.layer1 = fc_layer(562500, out_layer)
        self.avg = avg_pool()
        self.relu = relu_act()
    
    def forward(self, x):
        x0 = self.avg(x)
        x0 = torch.flatten(x0, start_dim = 1, end_dim = -1)
        x1 = self.layer1(x0)
        x2 = self.relu(x1)
        return x1

In [6]:
#loss for regression
class get_mae(nn.Module):
    def __init__(self):
        super(get_mae, self).__init__()
    
    def forward(self, pred, y_test):
        mae_out =  mae(y_pred=pred, y_true=y_test)
        return mae_out

In [5]:
def customToTensor(img):
    if isinstance(img, np.ndarray):
        img1 = torch.from_numpy(img)
        img1 = resize_image(img, (150, 150, 200))
        # backward compatibility
        return img1.astype(np.float32)

def resize_image(img_array, trg_size):
    res = resize(img_array, trg_size, mode='reflect', preserve_range=True, anti_aliasing=False)
    # type check
    if type(res) != np.ndarray:
        raise "type error!"
    return res

In [8]:
class ADNI_Dataset_regression(Dataset):
    def __init__(self, root_dir, data_file):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
        """
        self.root_dir = root_dir
        self.data_file = data_file
    
    def __len__(self):
        return sum(1 for line in open(self.data_file))
    
    def __getitem__(self, idx):
        df = open(self.data_file)
        lines = df.readlines()
        lst = lines[idx].split(',')
        img_name = lst[0].strip('\"')
        img_label = float(lst[4].strip('\"'))
        image_path = os.path.join(self.root_dir, img_name) + '.nii'
        image = nib.load(image_path)
        a = (image.get_fdata()) #convert to np array
        a = customToTensor(a)
        
        sample = {'image': a, 'label': img_label}
        
        return sample

In [9]:
# image = nib.load(r"C:\Users\pbhav\Desktop\NYU\ivp\project\ADNI\I30968.nii")
# a = (image.get_fdata()) #convert to np array
# a = customToTensor(a)
# plt.imshow(a[:, :, 100])
# # print(type(a))
# df = open("C:/Users/pbhav/Downloads/ADNI1_Annual_2_Yr_3T_4_23_2022.csv")
# lines = df.readlines()
# lst = lines[1].split(',')
# img_name = lst[0].strip('\"')
# img_label = float(lst[4].strip('\"'))
# print(type(img_label))
# print(img_label)

In [6]:
NUM_EPOCH = 50
BATCH_SIZE = 20
LR = 0.0001
SAVE_PATH_AP = r'C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP' #path to save age prediction model
SAVE_PATH_BC = r'C:\Users\pbhav\Desktop\NYU\ivp\project\model\BC' #path to save disease classification madel

In [11]:
def train_epoch(net, data_loader, optimizer, criterion, epoch):
    net.train()
    loss_stat = []
    for i, img_label in enumerate(data_loader):
        img = img_label.get('image')
        label = img_label.get('label')
        img = img.to(device=device)
        label = torch.as_tensor(label)
        label_amount = len(label)
        label.resize_(label_amount, 1)
#         print(label.shape)
        pred = net(img)
#         print(pred.shape)
#         pred = pred.detach().numpy()
#         print(pred)
        loss = criterion(label, pred)
#         print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         print(loss.item())
        loss_stat += [loss.item()]
    
    print ("Epoch {}: [{}/{}] Loss: {:.3f}".format(epoch, len(data_loader), len(data_loader), np.mean(loss_stat))) 
    
    return np.mean(loss_stat)

In [12]:
def valid_epoch(net, data_loader, criterion, epoch):
    net.eval()

    val_loss_stat = []
    for i, img_label in enumerate(data_loader):
        img = img_label.get('image')
        label = img_label.get('label')
        img = img.to(device=device, dtype=torch.float32)
        label = torch.as_tensor(label)
        label_amount = len(label)
        label.resize_(label_amount, 1)
        with torch.no_grad():
            pred = net(img)
            val_loss = criterion(label, pred)
      
        val_loss_stat += [val_loss.item()]
        
    print ("Val Loss: {:.3f} ".format(np.mean(val_loss_stat)))
    
    return np.mean(val_loss_stat)

In [13]:
regression_data = ADNI_Dataset_regression(r"C:/Users/pbhav/Desktop/NYU/ivp/project/ADNI/", "C:/Users/pbhav/Downloads/ADNI1_Annual_2_Yr_3T_4_23_2022.csv")

In [14]:
in_size = 200
#size after the transformer (B, T, 3, C)
#T is shared size in xyz, number of slices in each direction, C is number of hiddens for each slice. Concat along C (feature map)

net = age_pred(in_size, 1)
# net.to(device)  # run net.to(device) if using GPU
print(net)

n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Number of parameters in network: ', n_params)

age_pred(
  (layer1): fc_layer(
    (fc): Linear(in_features=562500, out_features=1, bias=True)
  )
  (avg): avg_pool(
    (avgp): AvgPool3d(kernel_size=2, stride=2, padding=0)
  )
  (relu): relu_act(
    (relu): ReLU()
  )
)
Number of parameters in network:  562501


In [36]:
#kfold validation for k = 5
k=5
splits=KFold(n_splits=k,shuffle=True,random_state=42)

criterion = nn.L1Loss()

In [37]:
# def k_fold(model, dataset, NUM_EPOCH, LR, BATCH_SIZE, SAVE_PATH, criterion):
foldperf={}
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(regression_data)))):
    train_idx = np.delete(train_idx, 0)
    val_idx = np.delete(val_idx, 0)
    print('Fold {}'.format(fold + 1))
    
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(val_idx)
    train_loader = torch.utils.data.DataLoader(regression_data, batch_size=BATCH_SIZE, sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(regression_data, batch_size=BATCH_SIZE, sampler=test_sampler)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=LR)

    history = {'train_loss': [], 'test_loss': []}

    for epoch in range(NUM_EPOCH):
        train_loss =train_epoch(net, train_loader, optimizer, criterion, epoch)
        test_loss =valid_epoch(net, test_loader, criterion, epoch)
        train_loss = train_loss / len(train_loader.sampler)
        test_loss = test_loss / len(test_loader.sampler)

        print("Epoch:{}/{} AVG Training Loss:{:.3f} AVG Test Loss:{:.3f}".format(epoch + 1, NUM_EPOCH, train_loss, test_loss))
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)

#         Save the model after each epoch
        if os.path.isdir(SAVE_PATH_AP):
            torch.save(net.state_dict(),SAVE_PATH_AP + '\epoch{}.pth'.format(epoch + 1))
        else:
            os.makedirs(model_save_path, exist_ok=True)
            torch.save(net.state_dict(),SAVE_PATH_AP + '\epoch{}.pth'.format(epoch + 1))
        print('Checkpoint {} saved to {}'.format(epoch + 1, SAVE_PATH_AP + '\epoch{}.pth'.format(epoch + 1)))   
    foldperf['fold{}'.format(fold+1)] = history  

torch.save(net,'k_cross.pt')
# return foldperf

Fold 1
Epoch 0: [12/12] Loss: 337.898
Val Loss: 191.526 
Epoch:1/100 AVG Training Loss:1.426 AVG Test Loss:3.246
Checkpoint 1 saved to C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP\epoch1.pth
Epoch 1: [12/12] Loss: 211.199
Val Loss: 168.085 
Epoch:2/100 AVG Training Loss:0.891 AVG Test Loss:2.849
Checkpoint 2 saved to C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP\epoch2.pth
Epoch 2: [12/12] Loss: 155.215
Val Loss: 140.738 
Epoch:3/100 AVG Training Loss:0.655 AVG Test Loss:2.385
Checkpoint 3 saved to C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP\epoch3.pth
Epoch 3: [12/12] Loss: 135.657
Val Loss: 128.029 
Epoch:4/100 AVG Training Loss:0.572 AVG Test Loss:2.170
Checkpoint 4 saved to C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP\epoch4.pth
Epoch 4: [12/12] Loss: 118.174
Val Loss: 94.280 
Epoch:5/100 AVG Training Loss:0.499 AVG Test Loss:1.598
Checkpoint 5 saved to C:\Users\pbhav\Desktop\NYU\ivp\project\model\AP\epoch5.pth
Epoch 5: [12/12] Loss: 78.851
Val Loss: 59.720 
Epoch:6/100

KeyboardInterrupt: 

In [None]:
testl_f, tl_f = [], []

for f in range(1,k+1):
    tl_f.append(np.mean(foldperf['fold{}'.format(f)]['train_loss']))
    testl_f.append(np.mean(foldperf['fold{}'.format(f)]['test_loss']))

print('Performance of {} fold cross validation'.format(k))
print("Average Training Loss: {:.3f} \t Average Test Loss: {:.3f}".format(np.mean(tl_f), np.mean(testl_f)))     
