<a href="https://colab.research.google.com/github/ashmipednekar/asd-diagosis-ml/blob/main/ASD_Diagnostic_CNNs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
from google.colab import drive
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display
from PIL import Image
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import random
import torch.nn as nn
from torch.optim import *
import os
import zipfile
from tqdm import tqdm
from sklearn.metrics import accuracy_score

# Mounting Drive

**NOTE**: Before running, replace the paths in the next cells with the location of `FMRIScans.zip` in your Google Drive and the directory where the raw data should be extracted.

In [None]:
# Mount Google Drive so the zipped scans under /content/drive/MyDrive can be read
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Unpack the scan archive from Drive onto the local disk. The context manager
# closes the archive even if extraction fails part-way.
with zipfile.ZipFile('/content/drive/MyDrive/FMRIScans.zip', 'r') as zip_ref:
    zip_ref.extractall('/FMRIScans')

In [None]:
# Root of the extracted dataset; the split cell reads its 'ASD' and 'Control' subfolders
fmri_scans_path = Path('/FMRIScans') / 'FMRIScans'

# Other Setup

## Loading Dataset

In [None]:
# Dataset of per-subject fMRI scans, one folder of 2D PNG slices per subject
class FMRIScansData(torch.utils.data.Dataset):
    """Per-subject fMRI scans stored as folders of PNG slices.

    Each entry of ``path_list`` is a subject folder containing files named
    ``<folder name>_<z>_<t>.png``. Indexing loads every requested (t, z) slice and
    returns ``(image, label)``:

    * ``image`` has shape ``(len(t_steps), len(z_slices), H, W)``. The reshape
      drops the channel axis, so it raises if the PNGs are not single-channel.
    * ``label`` is ``0.0`` when the subject folder sits directly inside a
      directory named ``Control`` and ``1.0`` otherwise (ASD).

    Args:
        path_list: sequence of ``pathlib.Path`` subject folders.
        z_slices: z indices to load, in order (default 0..60, i.e. 61 slices).
        t_steps: time indices to load, in order (default 0, 5, ..., 75, i.e. 16 steps).

    Raises:
        FileNotFoundError: from ``__getitem__`` when an expected slice file is missing.
    """

    def __init__(self, path_list, z_slices=range(61), t_steps=range(0, 80, 5)):
        self.path_list = path_list
        self.z_slices = list(z_slices)
        self.t_steps = list(t_steps)
        self.to_tensor = ToTensor()

    def __len__(self):
        return len(self.path_list)

    def _load_slice(self, path):
        """Load one PNG slice as a tensor, closing the file handle afterwards."""
        with Image.open(path) as img:
            return self.to_tensor(img)

    def __getitem__(self, idx):
        image_path = self.path_list[idx]
        # List the folder once. A set makes each existence check O(1) instead of
        # scanning a list for every one of the ~1000 slices. Stray files such as
        # '.DS_Store' are simply never looked up.
        available = {p.name for p in image_path.iterdir()}

        volumes = []
        for t in self.t_steps:
            slices = []
            for z in self.z_slices:
                slice_name = f'{image_path.name}_{z}_{t}.png'
                if slice_name not in available:
                    raise FileNotFoundError(
                        f'Missing fMRI slice image: {image_path / slice_name}'
                    )
                slices.append(self._load_slice(image_path / slice_name))
            volumes.append(torch.stack(slices))

        # Stacked shape is (T, Z, C, H, W); collapse the single channel axis.
        full_img = torch.stack(volumes)
        full_img = full_img.reshape(
            len(self.t_steps), len(self.z_slices), *full_img.shape[-2:]
        )
        # Label from the parent directory name rather than a substring of the
        # full path (which depended on the OS path separator).
        label = 0.0 if image_path.parent.name == 'Control' else 1.0
        return full_img, label

In [None]:
# Collect all ASD and Control subject folders, then make a seeded 80/20
# train/test split and wrap each part in a Dataset + DataLoader.
SPLIT_SEED = 314
TRAIN_FRACTION = 0.8


def _subject_dirs(class_dir):
    """Sorted subject folders in ``class_dir``.

    Only directories are kept, which skips stray files such as '.DS_Store'.
    Sorting matters because ``iterdir()`` order depends on the filesystem; without
    it the seeded shuffle below would not give the same split on every machine.
    """
    return sorted(p for p in class_dir.iterdir() if p.is_dir())


asd_files = _subject_dirs(fmri_scans_path / 'ASD')
control_files = _subject_dirs(fmri_scans_path / 'Control')

# A local RNG gives the same permutation as seeding the global one, without
# changing global `random` state for later cells.
total_files = asd_files + control_files
random.Random(SPLIT_SEED).shuffle(total_files)
train_index = int(len(total_files) * TRAIN_FRACTION)
training_files = total_files[:train_index]
testing_files = total_files[train_index:]

train_data = FMRIScansData(training_files)
test_data = FMRIScansData(testing_files)
# Seed the training shuffle for reproducibility. Evaluation order does not
# affect accuracy, so the test loader is not shuffled.
train_dl = DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_dl = DataLoader(test_data, batch_size=16, shuffle=False)

# CNN Building and Testing

In [None]:
# Example class, adapted from here: https://towardsdatascience.com/pytorch-step-by-step-implementation-3d-convolution-neural-network-8bf38c70e8b3
class Example3DCNN(nn.Module):
    """3D CNN binary classifier returning one probability per sample.

    Input is a ``(batch, in_channels, D, H, W)`` tensor. With ``FMRIScansData``
    batches, the 16 time steps act as channels and each volume is 61 x 61 x 73.
    Three conv blocks each halve every spatial dim (floor), so the flattened
    feature size is ``128 * (D // 8) * (H // 8) * (W // 8)``. The output is a
    ``(batch, 1)`` tensor of sigmoid probabilities, suitable for ``nn.BCELoss``.

    NOTE(review): with the default dims, fc1 alone holds 56448 x 6272 (about 354M)
    weights; consider a much smaller hidden size.

    Args:
        in_channels: number of input channels (default 16).
        input_dims: spatial size ``(D, H, W)`` of each input (default ``(61, 61, 73)``).
    """

    def __init__(self, in_channels=16, input_dims=(61, 61, 73)):
        super().__init__()
        self.conv_layer_1 = self.conv_layer_set(in_channels, 32)
        self.conv_layer_2 = self.conv_layer_set(32, 64)
        self.conv_layer_3 = self.conv_layer_set(64, 128)
        self.flatten = nn.Flatten()
        # Three MaxPool3d(2) layers floor-divide each spatial dim by 8 overall.
        flat_features = 128
        for dim in input_dims:
            flat_features *= dim // 8
        self.fc1 = nn.Linear(flat_features, 128 * 7 * 7)
        self.fc2 = nn.Linear(128 * 7 * 7, 128 * 7)
        self.fc3 = nn.Linear(128 * 7, 1)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout(p=0.15)

    def conv_layer_set(self, in_c, out_c):
        """Conv3d (3x3x3, 'same' padding) -> LeakyReLU -> 2x2x2 max-pool."""
        conv_layer = nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=(3, 3, 3), padding='same'),
            nn.LeakyReLU(),
            nn.MaxPool3d((2, 2, 2)),
        )
        return conv_layer

    def forward(self, x):
        out = self.conv_layer_1(x)
        out = self.conv_layer_2(out)
        out = self.conv_layer_3(out)
        out = self.flatten(out)
        out = self.drop(self.relu(self.fc1(out)))
        out = self.relu(self.fc2(out))
        # No activation between fc3 and the sigmoid. A LeakyReLU here (slope 0.01)
        # shrank negative logits 100x, pinning class-0 probabilities near 0.5.
        return torch.sigmoid(self.fc3(out))

In [None]:
# Training hyperparameters
num_epochs = 50
learning_rate = 1e-6
# NOTE(review): in the training log further down, the loss stays at ~0.693
# (ln 2, chance level) throughout; this rate may be far too small for plain SGD.

# Model, binary cross-entropy on the model's sigmoid output, and the optimizer
model = Example3DCNN()
error = nn.BCELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# Sample counts for train/test, then batch counts for their loaders, one per line
print(len(train_data), len(test_data), len(train_dl), len(test_dl), sep='\n')

874
219
55
14


In [None]:
# Training Model Loop
def evaluate_accuracy(model, loader, device):
    """Percentage of samples in ``loader`` that ``model`` classifies correctly.

    Args:
        model: binary classifier returning one probability per sample.
        loader: iterable of ``(images, labels)`` batches (labels expected to be 0.0/1.0).
        device: device the images are moved to before the forward pass.

    Returns:
        Accuracy in percent as a Python float, or NaN if the loader yields no samples.

    Predictions are thresholded at 0.5. Evaluation runs in eval mode (dropout off)
    without autograd, and the model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch_images, batch_labels in loader:
                probs = model(batch_images.to(device)).flatten().cpu()
                predicted = (probs > 0.5).float()
                correct += (predicted == batch_labels.float()).sum().item()
                total += len(batch_labels)
    finally:
        model.train(was_training)
    return 100.0 * correct / total if total else float('nan')


eval_every = 50  # training iterations between test-set evaluations
device = next(model.parameters()).device  # send batches wherever the model lives

count = 0
loss_list = []
iteration_list = []
accuracy_list = []
accuracy = float('nan')
accuracy_iter = None

model.train()
for epoch in range(num_epochs):
    for images, labels in train_dl:
        images = images.to(device)
        labels = torch.unsqueeze(labels, -1).float().to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = error(outputs, labels)
        loss.backward()
        optimizer.step()

        if count % eval_every == 0:
            accuracy = evaluate_accuracy(model, test_dl, device)
            accuracy_iter = count
            # Store plain floats: the lists then hold no tensors/autograd objects
            # and can be passed straight to matplotlib.
            loss_list.append(loss.item())
            iteration_list.append(count)
            accuracy_list.append(accuracy)

        # Logged once per epoch. The accuracy is the most recent evaluation, which
        # can come from an earlier iteration because the two schedules differ.
        if count % len(train_dl) == 0:
            print(f'Iteration: {count} Loss: {loss.item()} Accuracy: {accuracy} % '
                  f'(evaluated at iteration {accuracy_iter})')

        count += 1


Iteration: 0 Loss: 0.6931473016738892 Accuracy: 47.94520568847656 %
Iteration: 55 Loss: 0.6931207180023193 Accuracy: 51.14155197143555 %
Iteration: 110 Loss: 0.6930587887763977 Accuracy: 47.488582611083984 %
Iteration: 165 Loss: 0.6931777000427246 Accuracy: 45.20547866821289 %
Iteration: 220 Loss: 0.693226158618927 Accuracy: 53.42465591430664 %
Iteration: 275 Loss: 0.6932247281074524 Accuracy: 43.83561706542969 %
Iteration: 330 Loss: 0.6931149959564209 Accuracy: 51.14155197143555 %
Iteration: 385 Loss: 0.6929469108581543 Accuracy: 50.684932708740234 %
Iteration: 440 Loss: 0.6931084990501404 Accuracy: 52.511417388916016 %
Iteration: 495 Loss: 0.6931896805763245 Accuracy: 49.315067291259766 %
Iteration: 550 Loss: 0.6931754946708679 Accuracy: 54.79452133178711 %
Iteration: 605 Loss: 0.6931832432746887 Accuracy: 48.401824951171875 %
Iteration: 660 Loss: 0.6933375000953674 Accuracy: 48.401824951171875 %
Iteration: 715 Loss: 0.6931467652320862 Accuracy: 49.771690368652344 %
Iteration: 770 Lo