In [None]:
import csv
from scipy import ndimage
from scipy import linalg
import numpy as np
import nibabel as nib
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
from operator import itemgetter
import pickle

In [None]:
import random

# Deterministically shuffle the 817 sample indices, then cut them into
# train (572) / validation (82) / test (163) partitions.
index = np.arange(817)
random.seed(86)        # fixed seed so the split is identical on every run
random.shuffle(index)  # shuffles the array in place

train_index, val_index, test_index = np.split(index, [572, 654])

In [None]:
# Class label for every sample, relying on the fixed sample order:
#   rows   0-187 -> 1 (AD,  188 samples)
#   rows 188-415 -> 0 (CN,  228 samples)
#   rows 416-816 -> 2 (MCI, 401 samples)
lab = np.zeros((817, 1))
lab[:188, 0] = 1
lab[416:, 0] = 2

In [None]:
# Number of samples carrying each label in `lab` (1 = AD, 2 = MCI, 0 = CN).
# The counts are kept as floats, exactly as the accumulating loop produced them.
ad = float(np.count_nonzero(lab == 1))
mci = float(np.count_nonzero(lab == 2))
cn = float(np.count_nonzero(lab == 0))
total_num = mci + ad + cn
print(mci, ad, cn, total_num)

In [None]:
from torch.utils.data import Dataset, DataLoader

def ToTensor(np_array):
    """Convert a NumPy array into a float32 torch tensor.

    Args:
        np_array: array accepted by ``torch.from_numpy``.

    Returns:
        A tensor holding the array's values cast to ``torch.float32``.
    """
    return torch.from_numpy(np_array).float()

class ADNI_Dataset(Dataset):
    """One split ('train' / 'val' / 'test') of the AD / CN / MCI image set.

    The images are read from ``GRAY_AD_CN_MCI_Jan27.npy`` and the labels are
    rebuilt from the fixed sample order (rows 0-187 -> 1 (AD), rows 188-415 ->
    0 (CN), rows 416-816 -> 2 (MCI)). Both are then restricted to the
    module-level ``train_index`` / ``val_index`` / ``test_index`` array that
    matches ``ptype``.

    Args:
        ptype: 'train', 'val' or 'test'. Any other value leaves the complete,
            unsplit data set in place.
        transform: optional callable that receives the sample dict built by
            ``__getitem__`` and returns the (possibly modified) sample.
            ``None`` (the default) returns the sample untouched.
    """

    def __init__(self, ptype, transform=None):
        self.ptype = ptype
        # The argument used to be accepted and then dropped; keep it so that
        # __getitem__ can actually apply it.
        self.transform = transform

        # NOTE(review): presumably 817 samples of 3 x 121 x 121 - confirm against the file.
        self.img = np.load('GRAY_AD_CN_MCI_Jan27.npy')
        self.lab = np.zeros((817, 1))
        self.lab[0:188, 0] = 1    # AD
        self.lab[416:817, 0] = 2  # MCI; rows 188-415 stay 0 (CN)

        # Keep only the rows that belong to the requested split.
        if self.ptype == 'train':
            self.img = self.img[train_index]
            self.lab = self.lab[train_index]
        elif self.ptype == 'val':
            self.img = self.img[val_index]
            self.lab = self.lab[val_index]
        elif self.ptype == 'test':
            self.img = self.img[test_index]
            self.lab = self.lab[test_index]

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'labels': ...}`` for sample ``idx``.

        The dict is passed through ``self.transform`` when one was supplied.
        """
        sample = {'image': self.img[idx], 'labels': self.lab[idx, 0]}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [None]:
# One dataset object per split.
ADNI_Data_train = ADNI_Dataset('train')
ADNI_Data_val = ADNI_Dataset('val')
ADNI_Data_test = ADNI_Dataset('test')

# Only the training loader shuffles. Shuffling the validation / test loaders
# does not change the loss or the accuracy, but it makes the printed
# predictions come out in a different order on every run.
dataloader_train = DataLoader(ADNI_Data_train, batch_size=400,
                              shuffle=True, num_workers=4)
dataloader_val = DataLoader(ADNI_Data_val, batch_size=400,
                            shuffle=False, num_workers=4)
dataloader_test = DataLoader(ADNI_Data_test, batch_size=400,
                             shuffle=False, num_workers=4)


In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Feature-map size (H * W * channels) at the input and after each conv + pool stage:
# 121 * 121 * 3   (input)
#60 * 60 * 16
#30 * 30 * 32
#15 * 15 * 64
#7 * 7 * 128
#3 * 3 * 256

#Linear: 3 * 3 * 256
#3 * 3 * 256(2304)--->1024 --->256 ---> 64 --->3 
class Net_2D(nn.Module):
    """2D CNN classifier that emits 3 class scores per image.

    Five stages of 5x5 convolution (padding 2) + ReLU + 2x2 max-pooling are
    followed by four fully connected layers. The flatten step expects a
    3 x 3 x 256 feature map after the last pooling, which is what a
    121 x 121 input produces (121 -> 60 -> 30 -> 15 -> 7 -> 3).
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk: 3 -> 16 -> 32 -> 64 -> 128 -> 256 channels.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=5, padding=2)
        # A single pooling module is shared by all five stages.
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Classifier head: 2304 -> 1024 -> 256 -> 64 -> 3.
        self.fc1 = nn.Linear(3 * 3 * 256, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, 3)

    def forward(self, x):
        """Map a (batch, 3, H, W) image batch to (batch, 3) raw class scores."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.pool(F.relu(conv(x)))  # each stage halves H and W
        x = x.view(-1, 3 * 3 * 256)  # flatten to (batch, 2304)
        for fc in (self.fc1, self.fc2, self.fc3):
            x = F.relu(fc(x))
        return self.fc4(x)  # no activation: the loss expects raw scores

# Build the classifier, then move its parameters onto the GPU.
net_2d = Net_2D()
net_2d = net_2d.cuda()

In [None]:
import torch.optim as optim
import torch 
# Inverse-frequency class weights, ordered by class id: 0 = CN, 1 = AD, 2 = MCI.
Weight = torch.tensor([total_num / n_class for n_class in (cn, ad, mci)])
criterion = nn.CrossEntropyLoss(weight=Weight.cuda())
optimizer = optim.SGD(net_2d.parameters(), momentum=0.9, lr=0.001)

In [None]:
import numpy as np
import torch

def ToTensor(np_array):
    """Return the contents of ``np_array`` as a ``torch.float32`` tensor."""
    as_float = torch.from_numpy(np_array).float()
    return as_float

# Training loop with validation after every epoch and checkpointing of the
# model that reaches the lowest summed validation loss.
Validation_loss = []  # one entry per validation batch
Training_loss = []    # one entry per training batch
epoch_actual = 0      # number of epochs started so far

loss_val_min = 1000000  # lowest summed validation loss seen so far
threshold = 100         # stop once `count` exceeds this value

epoch = 50000           # upper bound on the number of epochs
min_total_loss = 1000000  # not read anywhere in this cell

# Initialised up front so the stopping test below never depends on which
# branch of the improvement check happened to run first.
count = 0

for i in range(epoch):
    epoch_actual += 1

    # ---- training pass ----
    net_2d.train()
    for i_batch, sample_batched in enumerate(dataloader_train):
        labels = sample_batched['labels'].cuda()
        images = sample_batched['image'].float().cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net_2d(images)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        # print statistics
        print('[%d] loss: %.3f' %
                (i_batch + 1, loss.item()))
        Training_loss.append(loss.item())

    print('$$Finished Training This Epoch')

    # ---- validation pass ----
    net_2d.eval()
    total_val_loss = 0
    print('Start a new val_epoch')
    # No gradients are needed for validation; skipping them avoids building
    # an autograd graph for every validation batch.
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(dataloader_val):
            labels = sample_batched['labels'].cuda()
            images = sample_batched['image'].float().cuda()

            outputs = net_2d(images)
            loss = criterion(outputs, labels.long())
            print('Validation')
            print('[%d] loss: %.3f' %
                    (i_batch + 1, loss.item()))
            Validation_loss.append(loss.item())
            print('\n')

            total_val_loss = total_val_loss + loss.item()

    # Written as "strictly better" so that a NaN loss counts as no improvement
    # instead of overwriting the best checkpoint.
    if total_val_loss < loss_val_min:
        loss_val_min = total_val_loss
        count = 0
        torch.save(net_2d, 'best_model3_Feb10_NOISE_weighted_3type_plot_val_loss.pkl')
    elif total_val_loss < 0.5:
        # NOTE(review): epochs without improvement are only counted while the
        # validation loss is below 0.5, so early stopping cannot trigger
        # above that level - confirm this is intended.
        count += 1

    if(count > threshold):
        break

print("we're done training")

In [None]:
import matplotlib.pyplot as plt 

def _mean_per_epoch(batch_losses, batches_per_epoch):
    """Average consecutive groups of ``batches_per_epoch`` losses.

    A trailing group that is shorter than ``batches_per_epoch`` (an epoch that
    did not finish) is dropped. Returns a list with one mean per complete epoch.
    """
    if batches_per_epoch <= 0:
        return []
    n_complete = len(batch_losses) // batches_per_epoch
    return [
        sum(batch_losses[k * batches_per_epoch:(k + 1) * batches_per_epoch])
        / float(batches_per_epoch)
        for k in range(n_complete)
    ]


# The losses were logged once per batch; reduce them to one value per epoch.
# The batch counts come from the loaders rather than being hard-coded.
Average_Training_loss = _mean_per_epoch(Training_loss, len(dataloader_train))
Average_Validation_loss = _mean_per_epoch(Validation_loss, len(dataloader_val))

x1 = range(len(Average_Training_loss))
y1 = Average_Training_loss

x2 = range(len(Average_Validation_loss))
y2 = Average_Validation_loss

plt.plot(x1, y1, label = "Average Training loss")
plt.plot(x2, y2, label = "Average Validation loss")
plt.title('classifer of AD, MCI, and CN')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.legend()
plt.show()

In [None]:
import torch
#test after finish all epochs

# Load the best checkpoint written by the training cell.
# NOTE(review): the file holds a fully pickled module and torch.load unpickles
# it, so only load checkpoint files that you created yourself.
final_net = torch.load('best_model3_Feb10_NOISE_weighted_3type_plot_val_loss.pkl')
final_net.eval()

# Collect predictions and labels from every test batch, so the statistics
# below cover the whole test set and not just the last batch.
batch_predictions = []
batch_labels = []
with torch.no_grad():  # inference only
    for i_batch, sample_batched in enumerate(dataloader_test):
        labels = sample_batched['labels'].cuda()
        images = sample_batched['image'].float().cuda()

        outputs = final_net(images)#n * 3
        results_type = torch.from_numpy(
            np.argmax(outputs.cpu().numpy(), axis=1)).float()

        labels = labels.cpu()
        print("predicted:\n", results_type, "\nactual:\n", labels)

        batch_predictions.append(results_type)
        batch_labels.append(labels)

results_type = torch.cat(batch_predictions)
labels = torch.cat(batch_labels)

######
# Confusion matrix: row = actual class, column = predicted class,
# with class ids 0 = CN, 1 = AD, 2 = MCI.
confusion = np.zeros((3, 3), dtype=int)
for actual, predicted in zip(labels.long().tolist(), results_type.long().tolist()):
    confusion[actual, predicted] += 1

# Number of samples predicted as each class (column sums).
AD_num = int(confusion[:, 1].sum())
CN_num = int(confusion[:, 0].sum())
MCI_num = int(confusion[:, 2].sum())
print("labels", AD_num, CN_num, MCI_num)

# Totals stay floats, as they were when accumulated from 0. in a loop.
correct_sum = float(np.trace(confusion))
total_sum = float(confusion.sum())

# Per-class breakdown: correct_sum_<actual>[_<predicted>].
correct_sum_AD = int(confusion[1, 1])
correct_sum_AD_CN = int(confusion[1, 0])
correct_sum_AD_MCI = int(confusion[1, 2])
correct_sum_CN = int(confusion[0, 0])
correct_sum_CN_AD = int(confusion[0, 1])
correct_sum_CN_MCI = int(confusion[0, 2])
correct_sum_MCI = int(confusion[2, 2])
correct_sum_MCI_AD = int(confusion[2, 1])
correct_sum_MCI_CN = int(confusion[2, 0])

# Each row below is printed in predicted order AD, CN, MCI.
print("correct_sum", correct_sum)
print("correct_sum_AD", correct_sum_AD, correct_sum_AD_CN, correct_sum_AD_MCI)
print("correct_sum_CN", correct_sum_CN_AD, correct_sum_CN, correct_sum_CN_MCI)
print("correct_sum_MCI", correct_sum_MCI_AD, correct_sum_MCI_CN, correct_sum_MCI)
print(total_sum)