In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os

In [None]:
# Select the first CUDA GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
# Turn on cuDNN benchmark mode (auto-selects convolution algorithms per input shape).
torch.backends.cudnn.benchmark = True

In [None]:
# The first os.walk() entry describes the image root itself; its directory names
# are the per-run folders that the loading loop iterates over.
_images_root, runs_list, _loose_files = next(os.walk('../Data/Images'))

In [None]:
# Load the label table: the file's first row is skipped, only CSV columns 1 and 2
# are read, and they are exposed under the names 'Run' and 'AD'.
path = "../Data/Labels.csv"
header = ['Run','AD']
# index_col=False keeps pandas from promoting the first kept column to the index.
Labelsdf = pd.read_csv(path, names=header, usecols=[1,2], skiprows=1, index_col=False)

In [None]:
# Lookup table: run identifier -> AD label (later rows win on duplicate run ids).
labels_dict = {run_id: ad_flag for run_id, ad_flag in zip(Labelsdf.Run, Labelsdf.AD)}

In [None]:
import re

# BASELINE RGB
# Builds one 3-channel 256x256 tensor per run by collapsing slices 100..198 of the
# scan into three depth-averaged channels, then assigns the run to the
# train / valid / test split according to the subject id.
img_dir = "../Data/Images"
# NOTE(review): low_bound / high_bound / img_scale are not read by the loop below;
# the slice filter uses the literal range 100 <= n < 199 (99 slices = 3 * depth_scale).
low_bound = 100
high_bound = 200
img_scale = 1
depth_scale = 33
# Compiled once instead of once per run; captures the digits that follow "OAS".
subject_regex = re.compile("OAS(?P<order>[0-9]+)")
#list of tuples (image data, AD)
train_data, valid_data, test_data = [], [], []
for run in runs_list:
    print(run)
    run_dir = os.path.join(img_dir, run)

    # os.listdir() returns entries in arbitrary order, so collect the wanted files
    # with their slice number first and sort them. Otherwise the depth axis of the
    # stacked volume is scrambled before it is downsampled.
    numbered_files = []
    for filename in os.listdir(run_dir):
        try:
            # Slice number = the three characters in front of the file extension.
            img_num = int(filename[-7:-4])
        except ValueError:
            # Not a numbered slice file (e.g. a stray OS metadata file): ignore it.
            continue
        if (100 <= img_num < 199):
            numbered_files.append((img_num, filename))
    numbered_files.sort()

    run_imgs = []
    for img_num, filename in numbered_files:
        img_path = os.path.join(run_dir, filename)
        img_slice = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img_slice is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("Could not read image slice: {}".format(img_path))
        # Scale 8-bit intensities into [0, 1).
        img_slice = img_slice/256
        run_imgs.append(img_slice)

    # Skip runs with no usable slices or with slices that are not 256 x 176.
    # Checking per slice also covers runs whose slices differ in shape.
    if not run_imgs or any(img_slice.shape != (256, 176) for img_slice in run_imgs):
        continue

    # add black bars on the left/right: centre each 176-wide slice in a 256-wide frame
    temp_arr = np.zeros([len(run_imgs), 256, 256])
    temp_arr[:, :, 40:216] = np.array(run_imgs)

    # For every column y take the (slice, row) plane and shrink it along the slice
    # axis by depth_scale, so ~99 slices become 3 averaged rows.
    final_slices = []
    for y in range(temp_arr.shape[2]):
        xz_pane = temp_arr[:, :, y]
        scaled_xz = cv2.resize(xz_pane, (0, 0), fy=1/depth_scale, fx=1, interpolation=cv2.INTER_AREA)
        final_slices.append(scaled_xz)

    # Runs that did not reduce to exactly 3 depth channels are dropped.
    if len(final_slices[0]) != 3:
        continue

    # dstack of 256 arrays shaped (3, 256) -> tensor shaped (3, 256, 256).
    final_array = torch.from_numpy(np.dstack(final_slices)).float()
    run_tuple = (final_array, labels_dict[run])

    subject = subject_regex.search(run).group(1)
    # The tens digit of the subject id decides the split, so all runs of one
    # subject land in the same split: 0-5 train, 6-7 valid, 8-9 test.
    split_digit = int(subject[-2])

    if split_digit < 6:
        train_data.append(run_tuple)
    elif split_digit <= 7:
        valid_data.append(run_tuple)
    else:
        test_data.append(run_tuple)


print("Number of data points in train dataset: {}".format(len(train_data)))
print("Number of data points in valid dataset: {}".format(len(valid_data)))
print("Number of data points in test dataset: {}".format(len(test_data)))

    

In [None]:
# Quick visual check: channel index 1 of the first validation scan.
first_valid_scan, _first_valid_label = valid_data[0]
plt.imshow(first_valid_scan[1])

In [None]:
from collections import Counter

# Report the class balance (label 0 = non-AD, label 1 = AD) of every split.
# Each entry carries its own message templates so the printed text is unchanged.
split_reports = [
    (train_data,
     "Number of non-AD scans in train dataset: {}",
     "Number of AD scan in train datasets: {}"),
    (valid_data,
     "Number of non-AD scans in valid dataset: {}",
     "Number of AD scan in valid datasets: {}"),
    (test_data,
     "Number of non-AD scans in test dataset: {}",
     "Number of AD scan in test datasets: {}"),
]
for split_items, non_ad_msg, ad_msg in split_reports:
    counts = Counter(sample[1] for sample in split_items)
    print(non_ad_msg.format(counts[0]))
    print(ad_msg.format(counts[1]))
    print("\n")

In [None]:
# Oversample the AD class: every AD sample is appended four extra times
# (so it appears five times in total in the new training list).
old_train_data = train_data
train_AD = [sample for sample in train_data if sample[1] == 1]
train_data = old_train_data + 4 * train_AD

print("Number of data points in new train dataset: {}".format(len(train_data)))
counts = Counter(sample[1] for sample in train_data)
print("Number of non-AD scans in new train dataset: {}".format(counts[0]))
print("Number of AD scan in new train dataset: {}".format(counts[1]))

In [None]:
from torch.utils.data import Dataset, DataLoader

class T1Dataset(Dataset):
    """Dataset over an in-memory list of ``(scan, label)`` pairs.

    Args:
        data: indexable collection whose items are ``(scan, AD label)`` tuples.
        transform: optional callable applied to the scan (not the label) each
            time an item is fetched. ``None`` returns the stored scan as-is.
    """

    def __init__(self, data, transform=None):
        # list of tuples (3d image arrays, AD label)
        self.data = data
        self.transform = transform

    def __len__(self):
        """Number of ``(scan, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(scan, label)`` for ``index``, with ``transform`` applied to the scan."""
        scan = self.data[index][0]
        y = self.data[index][1]
        # Previously the transform was stored but silently ignored.
        if self.transform is not None:
            scan = self.transform(scan)
        return scan, y

In [None]:
# Wrap each split in a T1Dataset; no transform is used.
train_dataset, valid_dataset, test_dataset = (
    T1Dataset(split_items, None)
    for split_items in (train_data, valid_data, test_data)
)

In [None]:
import torchvision.models

alexnet = torchvision.models.alexnet(pretrained=True)
# Follow the notebook-wide `device` instead of a hard-coded .cuda(), so the cell
# also runs on CPU-only machines.
alexnet.to(device)
# AlexNet is only used as a fixed feature extractor here.
alexnet.eval()


def extract_alexnet_features(samples, batch_size=256):
    """Run every (image, label) pair in ``samples`` through ``alexnet.features``.

    All samples are processed in chunks of ``batch_size``. Previously only the
    first shuffled batch was kept, which silently dropped data from any split
    larger than ``batch_size``. Features are computed under ``torch.no_grad()``
    so they carry no autograd graph back into AlexNet.

    Args:
        samples: list of ``(image tensor, label)`` pairs.
        batch_size: number of images pushed through AlexNet at once.

    Returns:
        ``(features, labels, last_batch)``: features for all samples (on
        ``device``), the matching labels in the same order, and the last image
        batch that was processed (on ``device``).

    Raises:
        ValueError: if ``samples`` is empty.
    """
    loader = torch.utils.data.DataLoader(samples, batch_size=batch_size, shuffle=False)
    feature_chunks, label_chunks = [], []
    batch_imgs = None
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.float().to(device)
            feature_chunks.append(alexnet.features(batch_imgs))
            label_chunks.append(batch_labels)
    if not feature_chunks:
        raise ValueError("Cannot extract features from an empty dataset")
    return torch.cat(feature_chunks), torch.cat(label_chunks), batch_imgs


train_dataset_list = list(train_dataset)
train_features, train_labels, imgs = extract_alexnet_features(train_dataset_list)

valid_dataset_list = list(valid_dataset)
valid_features, valid_labels, imgs = extract_alexnet_features(valid_dataset_list)

test_dataset_list = list(test_dataset)
test_features, test_labels, imgs = extract_alexnet_features(test_dataset_list)

# `imgs` now holds the last test batch; later cells inspect it.
print(imgs.shape)
print(len(test_dataset_list))

In [None]:
# Largest value in the first image of the batch currently held in `imgs`.
imgs[0].max()

In [None]:
# Pair every extracted feature map with its label so the lists can be fed
# straight into a DataLoader as (features, label) samples.
train_features_list = [
    (train_features[idx], train_labels[idx]) for idx in range(train_features.shape[0])
]

valid_features_list = [
    (valid_features[idx], valid_labels[idx]) for idx in range(valid_features.shape[0])
]

test_features_list = [
    (test_features[idx], test_labels[idx]) for idx in range(test_features.shape[0])
]

In [None]:
def get_model_name(name, batch_size, learning_rate, epoch):
    """Return the identifier string for a model/hyperparameter combination.

    Args:
        name: model name.
        batch_size: batch size used for training.
        learning_rate: learning rate used for training.
        epoch: epoch number.

    Returns:
        A string of the form ``model_<name>_bs<batch_size>_lr<learning_rate>_epoch<epoch>``.
    """
    template = "model_{0}_bs{1}_lr{2}_epoch{3}"
    fields = (name, batch_size, learning_rate, epoch)
    return template.format(*fields)

def get_loss(model, train=False):
    """Compute the BCE-with-logits loss of ``model`` over one whole split.

    Args:
        model: network mapping a feature batch to one logit per sample.
        train: if True evaluate on ``train_features_list``, otherwise on
            ``valid_features_list``.

    Returns:
        The mean loss over the split as a Python float.
    """
    # The non-train branch used to read test_features_list, although train()
    # records and plots this value as the *validation* loss and get_accuracy()
    # uses the validation split. Use the validation split here as well.
    if train:
        data = train_features_list
    else:
        data = valid_features_list

    # One batch spanning the whole split; order is irrelevant for a mean loss.
    loader = torch.utils.data.DataLoader(data, batch_size=len(data), shuffle=False)
    criterion = nn.BCEWithLogitsLoss()
    imgs, labels = next(iter(loader))
    imgs, labels = imgs.to(device), labels.to(device)
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        out = model(imgs)             # forward pass
        loss = criterion(out, labels.float()) # compute the total loss
    return loss.item()

def get_accuracy(model, train=False):
    """Fraction of samples whose thresholded logit matches the label.

    Args:
        model: network mapping a feature batch to one logit per sample.
        train: if True evaluate on ``train_features_list``, otherwise on
            ``valid_features_list``.

    Returns:
        ``correct / total`` as a float.

    Raises:
        ZeroDivisionError: if the selected split is empty.
    """
    if train:
        data = train_features_list
    else:
        data = valid_features_list

    correct = 0
    total = 0
    # Evaluation only: skip autograd bookkeeping. Batch order does not affect
    # the count, so no shuffling is needed.
    with torch.no_grad():
        for imgs, labels in torch.utils.data.DataLoader(data, batch_size=32, shuffle=False):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            # A logit above 0 corresponds to a sigmoid probability above 0.5.
            # Flatten both sides so the comparison is always element-wise.
            preds = (outputs > 0.0).reshape(-1).long()
            correct += int((preds == labels.reshape(-1)).sum())
            total += imgs.shape[0]
    return correct / total

In [None]:
from torch.autograd import Variable

def train(model, features_list, num_epochs=10, batch_size=32, learning_rate=1e-4):
    """Train ``model`` on precomputed features and plot loss/accuracy curves.

    Args:
        model: network mapping a feature batch to one logit per sample.
        features_list: list of ``(features, label)`` pairs to train on.
        num_epochs: number of passes over ``features_list``.
        batch_size: minibatch size.
        learning_rate: SGD learning rate.

    Per-epoch metrics come from ``get_loss`` / ``get_accuracy``, which read the
    notebook-level feature lists rather than ``features_list``.
    """
    criterion = nn.BCEWithLogitsLoss()
    # SGD with momentum (an earlier comment mentioned Adam, but SGD is what is used).
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    epochs, train_losses, valid_losses, train_acc, valid_acc = [], [], [], [], []

    train_loader = torch.utils.data.DataLoader(features_list, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            # The features are fixed inputs: only `model` is optimised. Detaching
            # them cuts any autograd history they carry from the feature
            # extractor, so each step has its own fresh graph and a plain
            # backward() works without retain_graph=True.
            inputs, labels = inputs.detach().to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()

        epochs.append(epoch)
        train_losses.append(get_loss(model, train=True))
        valid_losses.append(get_loss(model, train=False))
        train_acc.append(get_accuracy(model, train=True))
        valid_acc.append(get_accuracy(model, train=False))

        # Report the same epoch-level train loss that is plotted below, not the
        # loss of whichever minibatch happened to come last.
        print("Epoch %d; Train Loss %f; Val Loss %f; Train Acc %f; Val Acc %f" % (
              epoch+1, train_losses[-1], valid_losses[-1], train_acc[-1], valid_acc[-1]))

    plt.title("Training Curve")
    plt.plot(epochs, train_losses, label="Train")
    plt.plot(epochs, valid_losses, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.show()

    plt.title("Training Curve")
    plt.plot(epochs, train_acc, label="Train")
    plt.plot(epochs, valid_acc, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()

    # With num_epochs=0 there is nothing to summarise.
    if epochs:
        print("Final Training Accuracy: {}".format(train_acc[-1]))
        print("Final Validation Accuracy: {}".format(valid_acc[-1]))

In [None]:
class CNN2(nn.Module):
    """Three-layer fully connected classifier head over 256x7x7 feature maps.

    ``forward`` returns one raw logit per sample (shape ``[batch_size]``),
    suitable for ``nn.BCEWithLogitsLoss`` and for thresholding at 0.
    """

    def __init__(self):
        super(CNN2, self).__init__()
        self.name = "CNN2"
        # Input feature maps are 256 x 7 x 7, flattened in forward().
        self.fc1 = nn.Linear(256*7*7, 512)
        # xavier_uniform_ is the in-place initialiser; the un-suffixed name is deprecated.
        nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(512, 32)
        nn.init.xavier_uniform_(self.fc2.weight)
        self.fc3 = nn.Linear(32, 1)
        nn.init.xavier_uniform_(self.fc3.weight)

    def forward(self, x):
        """Map a batch of feature maps to a 1-D tensor of logits."""
        x = x.view(-1, 256*7*7)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: a ReLU here would clamp the logit
        # to >= 0, so the sigmoid inside BCEWithLogitsLoss could never drop
        # below 0.5 and the negative class could not be learned.
        x = self.fc3(x)
        x = x.squeeze(1) # Flatten to [batch_size]
        return x

In [None]:
# Number of (features, label) pairs available for training the classifier head.
len(train_features_list)

In [None]:
# Classifier head trained on the precomputed AlexNet features.
alexnetCNN2 = CNN2()
# Use the notebook-wide `device` rather than a hard-coded .cuda(), which raises
# on machines without a GPU even though `device` already falls back to the CPU.
alexnetCNN2.to(device)
train(alexnetCNN2, train_features_list, num_epochs=50, batch_size=32, learning_rate=1e-4)

In [None]:
# Print the summed gradient of every parameter tensor as left by the last backward pass.
for weight in alexnetCNN2.parameters():
    grad_total = weight.grad.data.sum()
    print(grad_total)

In [None]:
# Accuracy of the trained head on the non-train split used by get_accuracy
# (valid_features_list).
get_accuracy(alexnetCNN2, train=False)