In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os

In [2]:
runs_list = next(os.walk('./Data/Images'))[1]

In [3]:
path = "./Data/Labels.csv"
header = ['Run','AD']
Labelsdf = pd.read_csv(path, names=header, usecols=[1,2], skiprows=1, index_col=False)

In [4]:
labels_dict = dict(zip(Labelsdf.Run, Labelsdf.AD))

In [5]:
labels_dict['sub-OAS30019_ses-d0376']

1

In [6]:
img_dir = "./Data/Images"
low_bound = 100
high_bound = 200
img_scale = 4
depth_scale = 4
#list of tuples (image data, AD)
all_data = []
for run in runs_list:
    run_dir = os.path.join(img_dir, run)
    run_imgs = []
    for filename in os.listdir(run_dir):
        img_num = int(filename[-7:-4])
        if (100 <= img_num < 200):
            img_slice = cv2.imread(os.path.join(run_dir, filename), cv2.IMREAD_GRAYSCALE)
            #print(img_slice)
            img_slice = cv2.resize(img_slice, (0,0), fx=1/img_scale, fy=1/img_scale, interpolation=cv2.INTER_AREA)
            run_imgs.append(img_slice)

    temp_arr = np.array(run_imgs)
    
    if temp_arr.size == 0:
        print(run)
        continue
    
    final_slices = []
    #print(temp_arr.shape[2])
    
    for y in range(temp_arr.shape[2]):
        xz_pane = temp_arr[:, :, y]
        scaled_xz = cv2.resize(xz_pane, (0, 0), fy=1/depth_scale, fx=1, interpolation=cv2.INTER_AREA)
        final_slices.append(scaled_xz)
    
    final_array = np.dstack(final_slices)
    #print(final_array.shape)
    run_tuple = (final_array, labels_dict[run])
    all_data.append(run_tuple)
    
print(len(all_data), len(all_data[0]))
    

KeyboardInterrupt: 

In [7]:
# BASELINE RGB
img_dir = "./Data/Images"
low_bound = 100
high_bound = 200
img_scale = 1
depth_scale = 33
#list of tuples (image data, AD)
all_data = []
for run in runs_list:
    run_dir = os.path.join(img_dir, run)
    run_imgs = []
    for filename in os.listdir(run_dir):
        img_num = int(filename[-7:-4])
        if (100 <= img_num < 199):
            img_slice = cv2.imread(os.path.join(run_dir, filename), cv2.IMREAD_GRAYSCALE)
            #print(img_slice)
            #img_slice = cv2.resize(img_slice, (0,0), fx=1/img_scale, fy=1/img_scale, interpolation=cv2.INTER_AREA)
            img_slice = img_slice[40:216,:]
            run_imgs.append(img_slice)

    temp_arr = np.array(run_imgs)
    
    if temp_arr.size == 0 or not temp_arr.shape[1] == 176 or not temp_arr.shape[2] == 176:
        print(run)
        print(temp_arr.shape)
        continue
    
    final_slices = []
    #print(temp_arr.shape[2])
    
    for y in range(temp_arr.shape[2]):
        xz_pane = temp_arr[:, :, y]
        scaled_xz = cv2.resize(xz_pane, (0, 0), fy=1/depth_scale, fx=1, interpolation=cv2.INTER_AREA)
        final_slices.append(scaled_xz)
    
    if not len(final_slices[0]) == 3:
        continue
    
    final_array = torch.from_numpy(np.dstack(final_slices)).float()
    #print(final_array.shape)
    run_tuple = (final_array, labels_dict[run])
    all_data.append(run_tuple)
    
print(len(all_data), len(all_data[0]))
    

sub-OAS30059_ses-d0230_run-01
(29, 176, 256)
sub-OAS30059_ses-d0230_run-02
(29, 176, 256)
sub-OAS30059_ses-d0230_run-03
(29, 176, 256)
sub-OAS30059_ses-d0230_run-04
(29, 176, 256)
sub-OAS30059_ses-d1188
(99, 176, 160)
sub-OAS30109_ses-d0270
(99, 176, 160)
sub-OAS30119_ses-d1209_run-01
(29, 176, 256)
sub-OAS30119_ses-d1209_run-02
(29, 176, 256)
sub-OAS30119_ses-d1209_run-03
(29, 176, 256)
sub-OAS30249_ses-d0091_run-01
(29, 176, 256)
sub-OAS30249_ses-d0091_run-02
(29, 176, 256)
sub-OAS30249_ses-d0091_run-03
(29, 176, 256)
sub-OAS30249_ses-d0749
(99, 176, 160)
sub-OAS30259_ses-d0000_run-01
(29, 176, 256)
sub-OAS30259_ses-d0000_run-02
(29, 176, 256)
sub-OAS30259_ses-d0000_run-03
(29, 176, 256)
sub-OAS30259_ses-d0679_run-01
(29, 176, 256)
sub-OAS30259_ses-d0679_run-02
(29, 176, 256)
sub-OAS30259_ses-d0679_run-03
(29, 176, 256)
sub-OAS30369_ses-d2819_run-01
(29, 176, 256)
sub-OAS30369_ses-d2819_run-02
(29, 176, 256)
sub-OAS30369_ses-d2819_run-03
(29, 176, 256)
sub-OAS30379_ses-d1169_run-01
(

In [113]:
temp_arr.shape[1]

176

In [8]:
from torch.utils.data import Dataset, DataLoader

class T1Dataset(Dataset):
    def __init__(self, data, transform=None):
        # list of tuples (3d image arrays, AD label)
        self.data = data
        # labels.csv
        #self.target = torch.from_numpy(target).long()
        #self.transform = transforms.Compose([transforms.ToTensor()])
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        #scan = torch.from_numpy(self.data[index][0]).float()
        scan = self.data[index][0]
        y = self.data[index][1]
        return scan, y

In [9]:
scan_dataset = T1Dataset(all_data, None)

In [16]:
def train(model, dataset, num_epochs=10, batch_size=32, learning_rate=1e-4):
    criterion = nn.BCEWithLogitsLoss()
    # use Adam for CNN
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    epochs, train_losses, val_losses, train_acc, val_acc = [], [], [], [], []
    
    train_loader = DataLoader(dataset, batch_size, shuffle=True)
    
    loss = 0
    
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()

        epochs.append(epoch)
        train_losses.append(loss)
        
        print(f"Epoch: {epoch} Training Loss: {train_losses[-1]}")
    
    plt.title("Training Curve")
    plt.plot(epochs, train_losses, label="Train")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.show()
    
    print("Final Training Loss: {}".format(train_losses[-1]))

In [11]:
import torchvision.models

alexnet = torchvision.models.alexnet(pretrained=True)
#scan_dataset = T1Dataset(all_data, None)

dataset_list = list(scan_dataset)

imgs, labels = next(iter(torch.utils.data.DataLoader(dataset_list, batch_size=len(dataset_list), shuffle=True)))
imgs = imgs.float()
features = alexnet.features(imgs)

In [12]:
features_list = []
for i in range(features.shape[0]):
    features_list.append((features[i], labels[i]))

In [13]:
print(features_list[0][0].shape)

torch.Size([256, 4, 4])


In [14]:
class CNN2(nn.Module):
    def __init__(self):
        super(CNN2, self).__init__()
        self.name = "CNN2"
        # 256x6x6
        self.fc1 = nn.Linear(256*4*4, 1024)
        nn.init.xavier_uniform(self.fc1.weight)
        #self.fc1 = nn.Linear(3*224*224, 512)
        self.fc2 = nn.Linear(1024, 512)
        nn.init.xavier_uniform(self.fc2.weight)
        self.fc3 = nn.Linear(512, 64)
        nn.init.xavier_uniform(self.fc3.weight)
        self.fc4 = nn.Linear(64, 1)
        nn.init.xavier_uniform(self.fc4.weight)
        #self.softmax = nn.Softmax()

    def forward(self, x):
        x = x.view(-1, 256*4*4)
        #x = x.view(-1, 3*224*224)
        x = F.relu(self.fc1(x))
        #print(x.shape)
        x = F.relu(self.fc2(x))
        #print(x.shape)
        x = F.relu(self.fc3(x))
        #print(x.shape)
        x = self.fc4(x)
        #x = self.softmax(x)
        x = x.squeeze(1) # Flatten to [batch_size]
        return x

In [17]:
alexnetCNN2 = CNN2()
train(alexnetCNN2, features_list, num_epochs=10, batch_size=32, learning_rate=1e-4)

  import sys
  # Remove the CWD from sys.path while we load stuff.
  if sys.path[0] == '':
  


RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.