In [282]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
import nibabel as nib
import numpy as np
import random

In [289]:
# Set constant variables
data_path = '../ds000201-download/'

In [587]:
# Load data and preprocess
def load_subject_and_compress(subject_name, task_name, data_shape, data_path):
    subject_path = data_path + subject_name
    
    ses_1_data_path = subject_path + f'/ses-1/func/{subject_name}_ses-1_{task_name}.nii.gz'
    ses_2_data_path = subject_path + f'/ses-2/func/{subject_name}_ses-2_{task_name}.nii.gz'
    
    print(ses_1_data_path)
    try:
        ses_1_img = nib.load(ses_1_data_path)
        ses_2_img = nib.load(ses_2_data_path)
    except(FileNotFoundError):
        return None
    
    ses_1_data = ses_1_img.get_fdata()
    ses_2_data = ses_2_img.get_fdata()
    
    #Average along time axis to compress data
    ses_1 = np.mean(ses_1_data, axis=-1)
    ses_1 = np.pad(ses_1, [(0, data_shape[0] - ses_1_data.shape[0]), 
                           (0, data_shape[1] - ses_1_data.shape[1]), 
                           (0, data_shape[2] - ses_1_data.shape[2])])
    
    ses_2 = np.mean(ses_2_data, axis=-1)
    ses_2 = np.pad(ses_2, [(0, data_shape[0] - ses_2_data.shape[0]), 
                           (0, data_shape[1] - ses_2_data.shape[1]), 
                           (0, data_shape[2] - ses_2_data.shape[2])])
    
    # Normalize
    
    ses_1 = (ses_1 - np.mean(ses_1)) / np.std(ses_1)
    ses_2 = (ses_2 - np.mean(ses_2)) / np.std(ses_2)
    
    return ses_1, ses_2

In [588]:
# Get the compressed subject data for all valid subjects
def retrieve_data(data_path, task_name, data_shape):
    ses_1_data = []
    ses_2_data = []
    
    for i in range(9001, 9101):
        compressed_subject_data = load_subject_and_compress(f'sub-{i}', task_name, data_shape, data_path)
        if (compressed_subject_data != None):
            print(compressed_subject_data[0].shape)
            print(compressed_subject_data[1].shape)
            ses_1_data.append(compressed_subject_data[0])
            ses_2_data.append(compressed_subject_data[1])
            
#     for root, dirs, files in os.walk(data_path):
#         for folder_name in dirs:
#             print(folder_name)
#             if folder_name[:3] == 'sub' and len(folder_name) == 8:
#                 compressed_subject_data = load_subject_and_compress(folder_name, data_path)
#                 if (compressed_subject_data != None):
#                     ses_1_data.append(compressed_subject_data[0])
#                     ses_2_data.append(compressed_subject_data[1]) 
    
    return np.array(ses_1_data), np.array(ses_2_data)

# Split data into training and test sets
def split_training_test(ses_1_data, ses_2_data):
    # 0 label means rested, 1 label means sleepy
    assert(ses_1_data.shape == ses_2_data.shape)
    
    N = ses_1_data.shape[0]
    
    shuffler = np.random.permutation(N)
    ses_1_shuffle = ses_1_data[shuffler]
    ses_2_shuffle = ses_2_data[shuffler]
    
    num_train = int(.8 * N)
    num_test = N - num_train
    
    #Split into training and test sets, keeping subjects consistent within sets
    ses_1_train = ses_1_shuffle[:num_train]
    ses_1_test = ses_1_shuffle[-num_test:]
    ses_2_train = ses_2_shuffle[:num_train]
    ses_2_test = ses_2_shuffle[-num_test:]
    
    #Define labels for training and test sets
    ses_1_train_labels = [0 for i in range(num_train)]
    ses_1_test_labels = [0 for i in range(num_test)]
    ses_2_train_labels = [1 for i in range(num_train)]
    ses_2_test_labels = [1 for i in range(num_train)]
    
    #zip labels up
    ses_1_train_set = zip(ses_1_train, ses_1_train_labels)
    ses_2_train_set = zip(ses_2_train, ses_2_train_labels)
    ses_1_test_set = zip(ses_1_test, ses_1_test_labels)
    ses_2_test_set =zip(ses_2_test, ses_2_test_labels)
    
    #put all training and test fMRI images into the same train set, with labels included now
    train_set = list(ses_1_train_set) + list(ses_2_train_set)
    test_set = list(ses_1_test_set) + list(ses_2_test_set)
    
    return train_set, test_set

In [589]:
# Retrieve rest data
ses_1_rest_task_data, ses_2_rest_task_data = retrieve_data(data_path, 'task-rest_bold', data_shape=(128,128,49))

../ds000201-download/sub-9001/ses-1/func/sub-9001_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9002/ses-1/func/sub-9002_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9003/ses-1/func/sub-9003_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9004/ses-1/func/sub-9004_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9005/ses-1/func/sub-9005_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9006/ses-1/func/sub-9006_ses-1_task-rest_bold.nii.gz
../ds000201-download/sub-9007/ses-1/func/sub-9007_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9008/ses-1/func/sub-9008_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9009/ses-1/func/sub-9009_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9010/ses-1/func/sub-9010_ses-1_ta

(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9082/ses-1/func/sub-9082_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9083/ses-1/func/sub-9083_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9084/ses-1/func/sub-9084_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9085/ses-1/func/sub-9085_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9086/ses-1/func/sub-9086_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9087/ses-1/func/sub-9087_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9088/ses-1/func/sub-9088_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9089/ses-1/func/sub-9089_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49)
../ds000201-download/sub-9090/ses-1/func/sub-9090_ses-1_task-rest_bold.nii.gz
(128, 128, 49)
(128, 128, 49

In [590]:
class SleepDeprivationNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 8, 7, stride = 2)
        self.conv2 = nn.Conv3d(1, 16, 5, stride = 2)
        self.conv3 = nn.Conv3d(1, 32, 3, stride = 2)
        
        self.linear1 = nn.Linear(25088,  2048)
        self.linear2 = nn.Linear(2048, 128)
        self.linear3 = nn.Linear(128, 2)
        
        self.dropout = nn.Dropout()
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.mean(x, dim=1)
        x = x[:, None, :]
        x = F.relu(self.conv2(x))
        x = torch.mean(x, dim=1)
        x = x[:, None, :]
        x = F.relu(self.conv3(x))
        
        x = self.dropout(F.relu(self.linear1(self.dropout(torch.flatten(x, start_dim=1)))))
        x = self.dropout(F.relu(self.linear2(x)))
        x = torch.sigmoid(self.linear3(x))
        
        return x

In [591]:
# Instantiate net
net = SleepDeprivationNet()
net = net.double()

device = torch.device("cuda:0")

if (torch.cuda.is_available()):
    net.to(device)

In [592]:
# Simple sanity checks 
test = torch.from_numpy(ses_1_train[0])
label = torch.tensor([0]).to(device)
print(test.shape)
test = test[None, None, :]
print(test.shape)

test = test.to(device)

torch.Size([128, 128, 49])
torch.Size([1, 1, 128, 128, 49])


In [593]:
#Overfit one sample for sanity

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3)
num_epochs = 100

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min = 1e-6)

net.train()

#Save model each iteration of training
for epoch in range(num_epochs):     
    optimizer.zero_grad()

    outputs = net(test)
    
    loss = criterion(outputs, label)
    loss.backward()
    optimizer.step()

    print("loss", loss)
    
    scheduler.step()

loss tensor(0.7684, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.5547, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3156, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.5118, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3301, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_

loss tensor(0.3142, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3153, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_fn=<NllLossBackward0>)
loss tensor(0.3133, device='cuda:0', dtype=torch.float64,
       grad_

In [594]:
net(test)

tensor([[1.0000e+00, 3.6974e-08]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)

In [595]:
# Split training and test set
train_set, test_set = split_training_test(ses_1_rest_task_data, ses_2_rest_task_data)

#Shuffle train and test data
random.shuffle(train_set)
random.shuffle(test_set)

In [596]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3)
num_epochs = 10000

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min = 1e-6)
                                                                                                 
batch_size = 8
train_len = len(train_set)

In [None]:
# Train NN
net.train()

#Save model each iteration of training
for epoch in range(num_epochs):
    random.shuffle(train_set)
    
    running_loss = 0.0
    iters = 0
    for i in range(0, train_len, batch_size):
        batch = train_set[i:i+batch_size]
        
        inputs = np.array([j for j,k in batch])
        inputs = torch.from_numpy(inputs[:, None, :]).to(device)
        
        labels = torch.from_numpy(np.array([k for j,k in batch])).type(torch.LongTensor).to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss
        iters += 1
    
    scheduler.step()
    print(f'Epoch {epoch}, average loss {running_loss/iters}')
    
    if epoch % 10 == 0:
        torch.save(net.state_dict(), './sleep_deprivation_net.pt')
        

Epoch 0, average loss 0.6930998087534775
Epoch 1, average loss 0.6931362178908568
Epoch 2, average loss 0.692919148821814
Epoch 3, average loss 0.6928898914867372
Epoch 4, average loss 0.6933186761570876
Epoch 5, average loss 0.6929547467276119
Epoch 6, average loss 0.6929960583778279
Epoch 7, average loss 0.6928863101019779
Epoch 8, average loss 0.6933204120350769
Epoch 9, average loss 0.6932885091946059
Epoch 10, average loss 0.6930540472573655
Epoch 11, average loss 0.6937791561143308
Epoch 12, average loss 0.6932995912187816
Epoch 13, average loss 0.6934602418568468
Epoch 14, average loss 0.6935823071278352
Epoch 15, average loss 0.6932619190149473
Epoch 16, average loss 0.6932975734146806
Epoch 17, average loss 0.6932426482711141
Epoch 18, average loss 0.6932334794749014
Epoch 19, average loss 0.693505065846261
Epoch 20, average loss 0.6931633755661067
Epoch 21, average loss 0.693090069171525
Epoch 22, average loss 0.6928344940563419
Epoch 23, average loss 0.6931130301781496
Epoch

Epoch 194, average loss 0.6931830352088043
Epoch 195, average loss 0.6931699746325964
Epoch 196, average loss 0.6934530036789901
Epoch 197, average loss 0.6932026184413028
Epoch 198, average loss 0.6931205970198477
Epoch 199, average loss 0.6930529403441321
Epoch 200, average loss 0.6933320132877341
Epoch 201, average loss 0.6932484345568158
Epoch 202, average loss 0.6933941023738772
Epoch 203, average loss 0.6932218995384094
Epoch 204, average loss 0.6931728810871547
Epoch 205, average loss 0.6929419252223277
Epoch 206, average loss 0.69329990299087
Epoch 207, average loss 0.6928570332114283
Epoch 208, average loss 0.6930491970871123
Epoch 209, average loss 0.6938043152456953
Epoch 210, average loss 0.6929208529555346
Epoch 211, average loss 0.6931914978158373
Epoch 212, average loss 0.6927195743508917
Epoch 213, average loss 0.6929136515978345
Epoch 214, average loss 0.6932257518303979
Epoch 215, average loss 0.6936895200856665
Epoch 216, average loss 0.6934143308980663
Epoch 217, av

Epoch 385, average loss 0.6928141296802084
Epoch 386, average loss 0.69347266085767
Epoch 387, average loss 0.6933780650272743
Epoch 388, average loss 0.693467533394263
Epoch 389, average loss 0.6931670238608039
Epoch 390, average loss 0.6931482854844309
Epoch 391, average loss 0.6929968376758017
Epoch 392, average loss 0.6930704813230192
Epoch 393, average loss 0.6928188284850396
Epoch 394, average loss 0.6932483421409347
Epoch 395, average loss 0.6931969735168844
Epoch 396, average loss 0.6932106209167352
Epoch 397, average loss 0.6936034141729246
Epoch 398, average loss 0.6931705538993306
Epoch 399, average loss 0.693380243278784
Epoch 400, average loss 0.6928301047162959
Epoch 401, average loss 0.6934921213378848
Epoch 402, average loss 0.6929120348703901
Epoch 403, average loss 0.6932970164330188
Epoch 404, average loss 0.6935728718280313
Epoch 405, average loss 0.6932824007893212
Epoch 406, average loss 0.6931892669412236
Epoch 407, average loss 0.6930434784298029
Epoch 408, aver

In [582]:
def classify(test_data, model):
    outputs = model(test_data)
    
    return torch.argmax(outputs, dim=1)

def compute_accuracy(test_set, model):
    inputs = np.array([i for i,j in test_set])
    inputs = torch.from_numpy(inputs[:, None, :]).to(device)
    
    labels = torch.from_numpy(np.array([j for i,j in test_set])).to(device)
    
    outputs = classify(inputs, model)
    
    accuracy = (outputs == labels).sum()/outputs.shape[0]
    
    return accuracy

In [583]:
# Run trained NN on test data and obtain accuracies
net.eval()
accuracy = compute_accuracy(test_set, net)

In [584]:
print(accuracy)

tensor(0.5000, device='cuda:0')
