Install relevant packages and modules.

In [3]:
!apt-get install -y dcm2niix # convert dicom files to nii
!apt-get install -y parallel # allows for utilization of multiple cores at once. Useful for actual implementation, just for demonstration here

# dependencies for freesurfer
!apt-get install -y wget
!apt-get install -y grep
!apt-get install -y tcsh
!apt-get install -y bc

!pip install nibabel # helps to deal with nii data in a format which you can actually work with
!pip install pydicom # helps with reading dycom headers, useful for scraping metadata i.e. age, sex, etc.

import os
import glob

import numpy as np
import pandas as pd
import scipy

import nibabel as nib
import pydicom

!pip install torch-geometric

import torch
import torch.nn as nn
import torch.nn.functional as F

# for the graph neural network portion
!pip install torch-geometric
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data, DataLoader
import torch.nn as nn

# install freesurfer (might take a few minutes)

# if you are having issues downloading freesurfer, it is likely due to the version selected here
!wget -O freesurfer.tar.gz https://freesurfer.net/pub/dist/freesurfer/7.4.1/freesurfer-linux-ubuntu22_amd64-7.4.1.tar.gz
!tar -xzf freesurfer.tar.gz

# set the relevant freesurfer directories
os.environ['FREESURFER_HOME'] = '/content/freesurfer'
os.environ['SUBJECTS_DIR'] = '/content/freesurfer_output'
os.environ['PATH'] += ':/content/freesurfer/bin'

!source /content/freesurfer/SetUpFreeSurfer.sh

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
dcm2niix is already the newest version (1.0.20211006-1build1).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
parallel is already the newest version (20210822+ds-2).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
wget is already the newest version (1.21.2-2ubuntu1.1).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
grep is already the newest version (3.7-1build1).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tcsh is already the newest version

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Add the FreeSurfer License path. This is free to get but due to their policy I cannot include it here. It is available at https://surfer.nmr.mgh.harvard.edu/registration.html.

In [2]:
from google.colab import files
license = files.upload()

# change this to what you named the license
your_file_name = 'freesurfer_license.txt'

# set the enviorment to use the license
os.environ['FS_LICENSE'] = f'/content/{your_file_name}'

Saving freesurfer_license.txt to freesurfer_license.txt


Get anonymized (T1 weighted) dicom files for analysis, and create temporary folders to store these and other files.

In [4]:
# clone the repository to get the raw dicoms
!git clone https://github.com/datalad/example-dicom-structural

# create folders for the NIfTI conversions and corresponding recons
!mkdir /content/nii_files/
!mkdir /content/freesurfer_output/

'''
Note: "T1 weighting" refers to a type of MRI image which is quite good at identifying tissue.
It is the most common form of imaging used in structural MRI.
''';

Cloning into 'example-dicom-structural'...
remote: Enumerating objects: 393, done.[K
remote: Total 393 (delta 0), reused 0 (delta 0), pack-reused 393 (from 1)[K
Receiving objects: 100% (393/393), 15.45 MiB | 4.94 MiB/s, done.
Resolving deltas: 100% (223/223), done.


Convert the DICOMs to NIfTIs.

In [None]:
# prepare paths
INPUT_PATH = '/content/example-dicom-structural/dicoms/'
OUTPUT_PATH = '/content/nii_files/'

# grep to ignore some warnings regarding the manufacturer (since we're using sample dicoms)
!dcm2niix -o '{OUTPUT_PATH}' '{INPUT_PATH}' | grep -v "Unknown manufacturer"

Chris Rorden's dcm2niiX version v1.0.20211006  (JP2:OpenJPEG) GCC11.2.0 x86-64 (64-bit Linux)
Found 384 DICOM file(s)
Convert 384 DICOM as /content/nii_files/dicoms_anat-T1w_20130717141500_401 (274x384x384x1)
Conversion required 0.412129 seconds (0.389445 for core code).


Reconstruct (recon) the NIfTI files.


Parallel allows you to recon multiple subjects simultaneously by recruiting n CPU cores (determined by --jobs n). It is irrelevant here, but is good to be aware of. The code below uses parallel, but only for a single process (n=1), making it effectively the same as not using it.

Recons take a long time (4 hours+ per nii) and are quite large (~100Mb). For this reason, I have included a sample finished recon within this repository, with only the essential files kept.

In [None]:
# prepare paths
INPUT_PATH = '/content/freesurfer_output'
nii_paths = glob.glob('/content/nii_files/*.nii') # list of paths
ALL_NII = ' '.join(nii_paths)  # format that parallel wants

# use parallel to execute recon-all on each NIfTI file
!parallel --jobs 1 recon-all -i {} -s {/.} -all ::: /content/nii_files/*.nii

Load in the sample recon.

In [23]:
!git clone https://github.com/SamAndTheSun/sMRI_BrainAge_Tutorial.git
recon_path = '/content/sample_recon'

Cloning into 'sMRI_BrainAge_Tutorial'...
^C


--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Deep Learning Analysis Example 1: CNN using brain.mgz


The brain.mgz file represents the combination of each individual "slice" of the brain stitched together to form a single cohesive volume. We can use this file to construct a CNN, trained on 3D images, to predict brain age.

In [21]:
# load in the brain.mgz files
brain_files = glob.glob('/content/*_recon')
print(brain_files)

[]


First, we downsample the brain files. This isn't strictly necessary but it often helps to reduce noise and make it easier for the model the train. We then save the downsampled files as numpy files, allowing for further manipulation.

In [18]:
# get the converted files
for idx, subj_brain in enumerate(brain_files):

    # load the mgz file in, get the volume, then downsample
    brain_data = mgh.load(subj_brain)
    affine_space_brain = brain_data.affine
    vol = zoom(affine_space_brain.get_fdata(), (0.5, 0.5, 0.5))

    print(vol.shape)

    # save the downsampled files in numpy format
    np.save(f"/content/subj_{idx}", vol)

Now lets define a rudimentary CNN for us to train.


In [11]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional layer 1
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        # Convolutional layer 2
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)

        # Fully connected layer
        self.fc1 = nn.Linear(64 * 31 * 31 * 31, 128)  # Adjust dimensions after pooling
        self.fc2 = nn.Linear(128, 1)  # Example: 10 output classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model
model = CNN()

Our next step is to derive our data of interest, in this case it's age, so let's generate sample age values for us to use. Sometimes demographic data will be difficult to find; I'll show one common work around in the next model example.

After this, we can define our training loop.

In [26]:
# fake data; ingenious, I know
y = torch.tensor(49)

# load in the brain files, and convert them into a single tensor
paths = [x for x in sorted(glob.glob("/content/subj_*"))]
tensor_list = [torch.from_numpy(np.load(path)) for path in paths]
X = torch.stack(tensor_list)

'''
note 1: glob is NOT sorted. The order of files will vary every time you use it.
Because of this, it's important if working with numpy files containing ages, with corresponding data files,
that you use sorted(glob.glob()) for both.

note 2: if dealing with too much data to simply load it in, you can also load in small chunks of these
numpy files, convert them, train with them, then unload these files and repeat
'''

# very important for training efficiency if you have an Nvidia GPU
# we didn't compile torch with cuda though, so this won't do anything
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(dataloader), correct / total

RuntimeError: stack expects a non-empty TensorList

Now let's train the model.

In [None]:
criterion = nn.L1Loss() # This is the same as MAE, and is standard for brain age
optimizer = optim.Adam(model.parameters(), lr=0.001) # ubiquitious basically everywhere

train_loader = DataLoader(np.load(), batch_size=1, shuffle=True)

num_epochs = 10

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device) # recall that device was the whole bit about cuda before
    test_loss, test_acc = test(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}\n")

Unsuprisingly, given we are using a single piece of fake data, it doesn't do very well. And we haven't even tested it! Regardless, these are the basic steps for developing a CNN for brain age.

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Deep Learning Analysis Example 2: graphSAGE using the cortex files

The .pial files (one for each hemisphere) represent the geometric vertices and faces (i.e. the connections between vertices) for each individual's cortex. Freesurfer also provides other files, such as the cortical thickness (.thickness) and white-grey matter intensity ratio (.w-g.pct.mgh) which provide information on each of these vertices.

We can use these files in conjunction to construct a graph of each individual's brain, with vertices serving as nodes, faces serving as edges, and attributes serving as features.

In [None]:
# get each subject (only one here)
subj_files = glob.glob(f'{recon_path}')

# load in the pial files
pial_files = glob.glob(f'{recon_path}/surf/*h.pial') # *h because we want both hemispheres

# load in the thickness files
thickness_files = glob.glob(f'{recon_path}/surf/*h.thickness')

# load in the white-grey matter intensity ratio files
wg_ratio_files = glob.glob(f'{recon_path}/surf/*h.w-g.pct.mgh')

# glob uses a random sort order, so we sort alphabetically to match everything
subj_files = sorted(subj_files)
pial_files = sorted(pial_files)
thickness_files = sorted(thickness_files)
wg_ratio_files = sorted(wg_ratio_files)

# make a dictionary for the training data (we will make it structured as: subjects -> nodes -> features/edges)
training_data = {}

# make a dictionary for the edge indices
edge_indices = {}

# loop through the subjects to construct the desired dictionaries
for i, subj in enumerate(subj_files): # since we only have on subject, this only runs once

  # get the last part of the path
  subj_id = subj.split('/')[-1]

  # get the vertice and face data for the subject
  lh_vertices, lh_faces = nib.freesurfer.read_geometry(pial_files[i]) # we know lh is before rh because we sorted alphabetically
  rh_vertices, rh_faces = nib.freesurfer.read_geometry(pial_files[i+1])

  # combine them into a single array
  vertices = np.vstack((lh_vertices, rh_vertices+(np.max(lh_vertices)+1))) # vertices uses relative node index (min=0), so we need to account for this

  # do the same for every other file type
  lh_thickness = nib.freesurfer.io.read_morph_data(thickness_files[i]) # be mindful of which nib reading varient to use
  rh_thickness = nib.freesurfer.io.read_morph_data(thickness_files[i+1])
  thickness = np.hstack((lh_thickness, rh_thickness)) # not relative, notice the use of h-stack for single-dimension variables
  #
  lh_ratio = nib.load(wg_ratio_files[i]).get_fdata()
  rh_ratio = nib.load(wg_ratio_files[i+1]).get_fdata()
  ratio = np.vstack((lh_ratio, rh_ratio)) # not relative
  ratio = ratio.squeeze() # this has dimensions (n_nodes, 1, 1) otherwise

  # create a node for each vertice and a seperate list for the edges
  training_data[subj_id] = [[] for _ in range(vertices.shape[0])]

  # for each node add the corresponding features
  for n, node in enumerate(training_data[subj_id]):
    node.extend(vertices[n, :])
    node.append(thickness[n])
    node.append(ratio[n])

  # additionally, get the faces of the each subject as an edge index
  faces = np.vstack((lh_faces, rh_faces+(np.max(lh_faces)+1))) # recall that we got the faces from the pial files, which use relative indexing
  edge_index = []

  # loop through the faces and create 2-dimensional representation
  # these edges are undirected and thus should include both directions
  for face in faces:
      edges = [
          (face[i], face[j])
          for i in range(3)
          for j in range(i + 1, 3)
      ]
      edge_index.extend(edges)
  edge_index = np.array(edge_index).T
  edge_indices[subj_id] = edge_index

# get the size of each sub-structure
num_subjects = len(training_data)
num_nodes = len(next(iter(training_data.values())))
num_values = len(next(iter(next(iter(training_data.values())))))
num_edges = edge_indices[subj_id].shape[1]

print("Number of subjects:", num_subjects)
print("Number of nodes (in subject 1):", num_nodes)
print("Number of features:", num_values) # recall that spatial position is 3 features: x, y, and z
print("Number of num_edges (in subject 1):", num_edges)

Number of subjects: 1
Number of nodes (in subject 1): 320845
Number of features: 5
Number of num_edges (in subject 1): 1925046


The next step is to get the data regarding participant ages, or whatever it is that we want to predict. Many datasets will include "demographic.csv" or "metadata.csv" files, but sometimes they won't. In these cases, we need to extract the metadata from the original DICOM files.

We can reasonably expect that every DICOM for a given subject will have the same demographic information within its metadata, so we only need to look at any random DICOM file for each subject. Let's take a look at the metadata.

In [None]:
# select all subject folders (only one in this case)
dicom_paths = glob.glob('/content/example-dicom-structural/*/')

for path in dicom_paths:

  # select an arbitray DICOM file within the subject folder
  subj_dicoms = glob.glob(f'{path}/*')
  target_dicom = subj_dicoms[0] # 0 is arbitrary

  # get the metadata
  metadata = pydicom.dcmread(target_dicom)

  # print out all of the metadata
  for elem in metadata.iterall():
    print(elem)

(0008,0008) Image Type                          CS: ['DERIVED', 'SECONDARY']
(0008,0016) SOP Class UID                       UI: MR Image Storage
(0008,0018) SOP Instance UID                    UI: 1.2.826.0.1.3680043.2.1143.7980170295326065434086375780975261994
(0008,0020) Study Date                          DA: '20130717'
(0008,0021) Series Date                         DA: '20130717'
(0008,0022) Acquisition Date                    DA: '20130717'
(0008,0023) Content Date                        DA: '20130717'
(0008,0030) Study Time                          TM: '141500'
(0008,0031) Series Time                         TM: '142035.93000'
(0008,0032) Acquisition Time                    TM: '132518'
(0008,0033) Content Time                        TM: '142035.93'
(0008,0050) Accession Number                    SH: ''
(0008,0060) Modality                            CS: 'MR'
(0008,0070) Manufacturer                        LO: 'BIOLAB'
(0008,0080) Institution Name                    LO: ''
(000

We can see that patient age is present. Our next step is to construct a loop that assembles the patient ages to line up correctly with the feature data.

In [None]:
# select all subject folders (only one in this case)
dicom_paths = glob.glob('/content/example-dicom-structural/*/')

# SORT the dicom paths. This is essential to making sure that the data is aligned.
# By sorting across all usages of glob we can make sure our results are consistent.
dicom_paths = sorted(dicom_paths)

# create a list for all of the subjects
ages = []

for path in dicom_paths:

  # select an arbitray DICOM file within the subject folder
  subj_dicoms = glob.glob(f'{path}/*')
  target_dicom = subj_dicoms[0] # 0 is arbitrary

  # get the metadata
  metadata = pydicom.dcmread(target_dicom)

  # get the age and add it to the list for all subjects
  age = metadata[(0x0010, 0x1010)].value

  # typically, age will be in the format str('55Y'),
  # but it varies by dataset. Here, it is simply str('55')
  ages.append(int(age))

print(ages)

[42]


Now lets get the relevant dependencies for the next steps (formatting the data then training a model using graphSAGE).

In [None]:
!pip install torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data, DataLoader



Format the Data and Batch objects in accordance with PyTorch Geometric's specifications. We do this by creating a unique Data object for each subject and adding this to a list, with this list then being used to create a Batch object for training.

In [None]:
# create an empty list to store the subject data
data_list = []

for subject, features in training_data.items():
    # convert the node features to a tensor
    x = torch.tensor(features, dtype=torch.float)

    # convert the edge indices array directly to a tensor
    edge_index = torch.tensor(edge_indices[subject], dtype=torch.long)

    # create the Data object and add it to the list
    data = Data(x=x, edge_index=edge_index)
    data_list.append(data)

# create the batch object
batch = Batch.from_data_list(data_list)
batch

DataBatch(x=[320845, 5], edge_index=[2, 1925046], batch=[320845], ptr=[2])

Now lets create a rudimentary GNN using graphSAGE. Certain architectures, such as a GAT, would require us to project the vertices of the subjects to a common space. This can be done using freesurfer's mri_surf2surf.

In [None]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

Now lets create a basic training loop for our model


In [None]:
def train(model, data_loader, epochs, lr=0.01, device='cuda'):

    optimizer = optim.Adam(model.parameters(), lr=lr) # essentially ubiquitious
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for data in data_loader:
            data = data.to(device)  # Move the data to the correct device
            optimizer.zero_grad()

            # Forward pass
            out = model(data.x, data.edge_index)  # Assuming data.x is node features and data.edge_index is adjacency list
            loss = model.loss(out, data.y)  # Assuming the model has a loss method for supervised tasks

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(data_loader)}")