In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
import nibabel as nib
import numpy as np
import random

import pandas as pd

In [2]:
# Set constant variables
data_path = '../ds000201-download/'

In [276]:
# Load data and preprocess
def load_subject_and_compress(subject_name, task_name, data_shape, data_path):
    subject_path = data_path + subject_name
    
    ses_1_data_path = subject_path + f'/ses-1/func/{subject_name}_ses-1_{task_name}.nii.gz'
    ses_2_data_path = subject_path + f'/ses-2/func/{subject_name}_ses-2_{task_name}.nii.gz'
    
    print(ses_1_data_path)
    try:
        ses_1_img = nib.load(ses_1_data_path)
        ses_2_img = nib.load(ses_2_data_path)
    except(FileNotFoundError):
        return None
    
    ses_1_data = ses_1_img.get_fdata()
    ses_2_data = ses_2_img.get_fdata()
    
    print(ses_1_data.shape)
    #Average along time axis to compress data
    ses_1 = np.mean(ses_1_data, axis=-1)
    ses_1 = np.pad(ses_1, [(0, data_shape[0] - ses_1_data.shape[0]), 
                           (0, data_shape[1] - ses_1_data.shape[1]), 
                           (0, data_shape[2] - ses_1_data.shape[2])])
    
    ses_2 = np.mean(ses_2_data, axis=-1)
    ses_2 = np.pad(ses_2, [(0, data_shape[0] - ses_2_data.shape[0]), 
                           (0, data_shape[1] - ses_2_data.shape[1]), 
                           (0, data_shape[2] - ses_2_data.shape[2])])
    
    # Normalize
    
    ses_1 = (ses_1 - np.mean(ses_1)) / np.std(ses_1)
    ses_2 = (ses_2 - np.mean(ses_2)) / np.std(ses_2)
    
    return ses_1, ses_2

In [277]:
# Get the compressed subject data for all valid subjects
def retrieve_data(data_path, task_name, data_shape):
    rested_data = []
    sleep_deprived_data = []
    
    participants = pd.read_csv('../ds000201-download/participants.tsv', sep='\t')
    rows = participants.iterrows()
    
    for idx, row in rows:
        pid = row['participant_id']
        
        compressed_subject_data = load_subject_and_compress(pid, task_name, data_shape, data_path)
        if (compressed_subject_data != None):
            print(compressed_subject_data[0].shape)
            print(compressed_subject_data[1].shape)
            
            sd_session_idx = row['Sl_cond'] - 1
            rested_session_idx = int(not sd_session_idx)
            
            rested_data.append(compressed_subject_data[rested_session_idx])
            sleep_deprived_data.append(compressed_subject_data[sd_session_idx])
            
#     for root, dirs, files in os.walk(data_path):
#         for folder_name in dirs:
#             print(folder_name)
#             if folder_name[:3] == 'sub' and len(folder_name) == 8:
#                 compressed_subject_data = load_subject_and_compress(folder_name, data_path)
#                 if (compressed_subject_data != None):
#                     ses_1_data.append(compressed_subject_data[0])
#                     ses_2_data.append(compressed_subject_data[1]) 
    
    return np.array(rested_data), np.array(sleep_deprived_data)

# Split data into training and test sets
def split_training_test(rested_data, sleep_deprived_data):
    # 0 label means rested, 1 label means sleepy
    assert(rested_data.shape == sleep_deprived_data.shape)
    
    N = rested_data.shape[0]
    
    shuffler = np.random.permutation(N)
    rested_shuffle = rested_data[shuffler]
    sleep_deprived_shuffle = sleep_deprived_data[shuffler]
    
    num_train = int(.8 * N)
    num_test = N - num_train
    
    #Split into training and test sets, keeping subjects consistent within sets
    rested_train = rested_shuffle[:num_train]
    rested_test = rested_shuffle[-num_test:]
    sleep_deprived_train = sleep_deprived_shuffle[:num_train]
    sleep_deprived_test = sleep_deprived_shuffle[-num_test:]
    
    #Define labels for training and test sets
    rested_train_labels = [0 for i in range(num_train)]
    rested_test_labels = [0 for i in range(num_test)]
    sleep_deprived_train_labels = [1 for i in range(num_train)]
    sleep_deprived_test_labels = [1 for i in range(num_train)]
    
    #zip labels up
    rested_train_set = zip(rested_train, rested_train_labels)
    sleep_deprived_train_set = zip(sleep_deprived_train, sleep_deprived_train_labels)
    rested_test_set = zip(rested_test, rested_test_labels)
    sleep_deprived_test_set =zip(sleep_deprived_test, sleep_deprived_test_labels)
    
    #put all training and test fMRI images into the same train set, with labels included now
    train_set = list(rested_train_set) + list(sleep_deprived_train_set)
    test_set = list(rested_test_set) + list(sleep_deprived_test_set)
    
    return train_set, test_set

In [5]:
# Retrieve rest data
ses_1_rest_task_data, ses_2_rest_task_data = retrieve_data(data_path, 'task-rest_bold', data_shape=(128,128,49))

../ds000201-download/sub-9001/ses-1/func/sub-9001_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9002/ses-1/func/sub-9002_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9003/ses-1/func/sub-9003_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9004/ses-1/func/sub-9004_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9005/ses-1/func/sub-9005_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9007/ses-1/func/sub-9007_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9008/ses-1/func/sub-9008_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9009/ses-1/func/sub-9009_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9011/ses-1/func/sub-9011_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-901

(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9091/ses-1/func/sub-9091_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9092/ses-1/func/sub-9092_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9093/ses-1/func/sub-9093_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9094/ses-1/func/sub-9094_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9095/ses-1/func/sub-9095_ses-1_task-rest_bold.nii.gz
../ds000201-download/sub-9096/ses-1/func/sub-9096_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9098/ses-1/func/sub-9098_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9100/ses-1/func/sub-9100_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)


In [428]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 8, 7, stride = 2)
        
        self.linear1 = nn.Linear(4608, 128)
        self.linear2 = nn.Linear(128, 2)
        
        self.max_pool = nn.MaxPool3d(5, stride = 5)
        self.relu = nn.LeakyReLU()
        self.softmax = nn.Softmax(dim = 1)
        
        self.dropout = nn.Dropout()
        
    def forward(self, x):
        x = self.relu(self.max_pool(self.conv1(x)))
        
        x = self.dropout(self.relu(self.linear1(torch.flatten(x, start_dim=1))))
        x = self.softmax(self.linear2(x))
        
        return x

In [429]:
class SleepDeprivationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 8, 5, stride = 2)
        self.conv2 = nn.Conv3d(8, 16, 3, stride = 2)
        
        self.linear1 = nn.Linear(1568,  128)
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(128, 2)
        
        self.dropout = nn.Dropout()
        
        self.max_pool = nn.MaxPool3d(2, stride = 2)
    def forward(self, x):
        x = self.max_pool(self.conv1(x))
        x = self.max_pool(self.conv2(x))
        
        x = self.dropout(F.relu(self.linear1(torch.flatten(x, start_dim=1))))
        #x = F.relu(self.linear2(x))
        x = torch.sigmoid(self.linear3(x))
        
        return x

In [430]:
# Instantiate net
is_simple_net = True

if is_simple_net:
    net = SimpleNet()
else:
    net = SleepDeprivationNet()
net = net.double()

device = torch.device("cuda:0")

if (torch.cuda.is_available()):
    net.to(device)

In [431]:
# Simple sanity checks 
test = torch.from_numpy(ses_1_rest_task_data[0]).to(device)
label = torch.tensor([0]).to(device)
print(test.shape)
test = test[None,None, :]
print(test.shape)

test = test.to(device)

torch.Size([128, 128, 49])
torch.Size([1, 1, 128, 128, 49])


In [432]:
#Overfit one sample for sanity

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3)
num_epochs = 100

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min = 1e-6)

net.train()

#Save model each iteration of training
for epoch in range(num_epochs):     
    optimizer.zero_grad()

    outputs = net(test)
    
    loss = criterion(outputs, label)
    loss.backward()
    optimizer.step()

    print("loss", loss)
    
    scheduler.step()

loss tensor(0.6178, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.6161, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.5057, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4705, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4143, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4582, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4696, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4101, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4112, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.4124, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3809, device='cuda:0', dtype=torch.float64,
       grad_

loss tensor(0.3190, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3212, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3210, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3229, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3190, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3223, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3404, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3547, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3333, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3280, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)


In [433]:
net(test)

tensor([[0.9947, 0.0053]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SoftmaxBackward0>)

In [434]:
# Split training and test set
train_set, test_set = split_training_test(ses_1_rest_task_data, ses_2_rest_task_data)

#Shuffle train and test data
random.shuffle(train_set)
random.shuffle(test_set)

In [435]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-5)
num_epochs = 300

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min = 1e-6)
                                                                                                 
batch_size = 8
train_len = len(train_set)

In [436]:
# Train NN
net.train()

#Save model each iteration of training
for epoch in range(num_epochs):
    random.shuffle(train_set)
    
    running_loss = 0.0
    iters = 0
    for i in range(0, train_len, batch_size):
        batch = train_set[i:i+batch_size]
        
        inputs = np.array([j for j,k in batch])
        inputs = torch.from_numpy(inputs[:, None, :]).to(device)
        
        labels = torch.from_numpy(np.array([k for j,k in batch])).type(torch.LongTensor).to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss
        iters += 1
    
    scheduler.step()
    print(f'Epoch {epoch}, average loss {running_loss/iters}')
    
    if epoch % 10 == 0:
        torch.save(net.state_dict(), './sleep_deprivation_net.pt')
        

Epoch 0, average loss 0.8032173511910429
Epoch 1, average loss 0.80031716141364
Epoch 2, average loss 0.8034415282570657
Epoch 3, average loss 0.7999709536854148
Epoch 4, average loss 0.7955011822826267
Epoch 5, average loss 0.79531875013193
Epoch 6, average loss 0.7926683046987656
Epoch 7, average loss 0.7939729375349459
Epoch 8, average loss 0.7968384791241452
Epoch 9, average loss 0.7973933456821808
Epoch 10, average loss 0.7967315111765081
Epoch 11, average loss 0.8003306691785655
Epoch 12, average loss 0.7981726208723278
Epoch 13, average loss 0.8014327933266036
Epoch 14, average loss 0.7952169328596264
Epoch 15, average loss 0.7977170901780056
Epoch 16, average loss 0.7950333522080988
Epoch 17, average loss 0.7981245441831237
Epoch 18, average loss 0.7988770497556574
Epoch 19, average loss 0.7952206834391033
Epoch 20, average loss 0.7997744621946948
Epoch 21, average loss 0.7985757286427375
Epoch 22, average loss 0.7981583045758058
Epoch 23, average loss 0.8035331098637064
Epoch 

Epoch 194, average loss 0.7832257090784779
Epoch 195, average loss 0.786331079627466
Epoch 196, average loss 0.7938841522817406
Epoch 197, average loss 0.7845998237139133
Epoch 198, average loss 0.7803206183252858
Epoch 199, average loss 0.7897082345053781
Epoch 200, average loss 0.7929055820761808
Epoch 201, average loss 0.7894570654789108
Epoch 202, average loss 0.7891576716635792
Epoch 203, average loss 0.7808711559930075
Epoch 204, average loss 0.7894891536860409
Epoch 205, average loss 0.7885004435390722
Epoch 206, average loss 0.7876092165136849
Epoch 207, average loss 0.7956322102916077
Epoch 208, average loss 0.7897322671523145
Epoch 209, average loss 0.7888535578645524
Epoch 210, average loss 0.7886886607961009
Epoch 211, average loss 0.7885704605074007
Epoch 212, average loss 0.7819584751378315
Epoch 213, average loss 0.7871884845330475
Epoch 214, average loss 0.7872915072495374
Epoch 215, average loss 0.7854169476886261
Epoch 216, average loss 0.778942772075635
Epoch 217, av

KeyboardInterrupt: 

In [437]:
def classify(test_data, model):
    outputs = model(test_data)
    
    return torch.argmax(outputs, dim=1)

def compute_accuracy(inputs, labels, model):
    outputs = classify(inputs, model)
    
    accuracy = (outputs == labels).sum()/outputs.shape[0]
    
    return accuracy

def print_accuracy(test_set, model):
    inputs = np.array([i for i,j in test_set])
    inputs = torch.from_numpy(inputs[:, None, :]).to(device)
    
    labels = torch.from_numpy(np.array([j for i,j in test_set])).to(device)
    
    accuracy = compute_accuracy(inputs, labels, model)
    
    print(accuracy)

In [438]:
# Run trained NN on test data and obtain accuracies
net.eval()
print_accuracy(test_set, net)

tensor(0.5000, device='cuda:0')
