# Brain GNN

Understanding how certain brain regions relate to a specific neurological disorder or cognitive stimuli has been an important area of neuroimaging research. We propose BrainGNN, a graph neural network (GNN) framework to analyze functional magnetic resonance images (fMRI) and discover neurological biomarkers.
We construct weighted graphs from fMRI and apply a GNN to these brain graphs. Considering the special properties of brain graphs, we design novel ROI-aware graph convolutional (Ra-GConv) layers that leverage the topological and functional information of fMRI.

In [2]:
import os
import numpy as np
import time
import copy
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
from imports_data.ABIDEDataset import ABIDEDataset
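try:
    from torch_geometric.loader import DataLoader  # PyG >= 2.0
except ImportError:
    from torch_geometric.data import DataLoader  # older PyG releases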
from net.braingnn import Network
from imports_data.utils import train_val_test_split
from sklearn.metrics import classification_report, confusion_matrix

In [3]:
torch.manual_seed(123)
EPS = 1e-10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data
ABIDE dataset.
The Autism Brain Imaging Data Exchange (ABIDE) initiative has aggregated functional and structural brain imaging data collected from laboratories around the world to accelerate our understanding of the neural bases of autism.

To create these graphs, nodes are defined as brain regions of interest (ROIs), and edges are defined as the functional connectivity between those ROIs, computed as the pairwise correlations of the fMRI time series.
Additionally, because fMRI data are high-dimensional, ROIs are usually clustered into highly connected communities to reduce dimensionality; features are then extracted from these smaller communities for further analysis.
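The edge definition above can be sketched in NumPy. This is a minimal illustration, assuming the ROI time series are arranged as a (T, N) array with one column per ROI; the function name `correlation_graph` and the optional `threshold` are hypothetical, and the `ABIDEDataset` preprocessing may transform the correlations further. A zeroed diagonal, as produced here, is consistent with the printed feature matrix `X` further down.

```python
import numpy as np

def correlation_graph(ts, threshold=None):
    """Pearson functional-connectivity matrix from ROI time series.

    ts: array of shape (T, N), one column per ROI and one row per time point.
    Returns an (N, N) symmetric matrix with a zeroed diagonal (no self-edges).
    """
    corr = np.corrcoef(ts, rowvar=False)  # pairwise Pearson r between ROI columns
    np.fill_diagonal(corr, 0.0)           # drop self-correlations
    if threshold is not None:
        # hypothetical sparsification: keep only strong connections
        corr[np.abs(corr) < threshold] = 0.0
    return corr

t = np.arange(10.0)
ts = np.stack([t, 2 * t + 1, -t], axis=1)  # ROI 1 tracks ROI 0, ROI 2 is anti-correlated
print(np.round(correlation_graph(ts), 3))
```

ROIs 0 and 1 receive an edge weight of 1, and both receive -1 with ROI 2.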

In [4]:
# root directory of the dataset
name = 'ABIDE'
# raw string, so the Windows backslashes are not read as escape sequences
dataroot = r'D:\EE\ETH\project\BrainGNN\data\ABIDE_pcp\cpac\filt_noglobal'

## Notation and Problem Definition

First, we parcellate the brain into $N$ regions of interest (ROIs) based on its T1 structural MRI.
We define the ROIs as graph nodes $V = \{v_1, \dots, v_N\}$, and the nodes are pre-ordered.
We define an undirected weighted graph as $G = (V, \mathcal{E})$, where $\mathcal{E}$ is the edge set, i.e., a collection of edges $(v_i, v_j)$ linking vertex $v_i$ to $v_j$.
In our setting, $G$ has an associated node feature set $H = \{h_1, \dots, h_N\}$, where $h_i$ is the feature vector associated with node $v_i$.
For every edge $(v_i, v_j) \in \mathcal{E}$ connecting two nodes, we have its strength $e_{ij} \in \mathbb{R}$ with $e_{ij} > 0$. We also define $e_{ij} = 0$ for $(v_i, v_j) \notin \mathcal{E}$, so the adjacency matrix $E = [e_{ij}] \in \mathbb{R}^{N \times N}$ is well defined.
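In PyTorch Geometric, which this notebook uses, a weighted graph is stored as an edge list (`edge_index`) plus edge weights (`edge_attr`) rather than as the dense matrix $E$. The conversion can be sketched as below; the helper name is hypothetical, and it assumes every entry with $e_{ij} > 0$ is an edge, so each undirected edge appears once per direction.

```python
import numpy as np

def dense_to_edge_list(E):
    """Convert a weighted adjacency matrix E = [e_ij] to COO edge format.

    Entries with e_ij > 0 become edges; e_ij = 0 means "no edge".
    Returns edge_index with shape (2, num_edges) and edge_attr with shape (num_edges,).
    """
    src, dst = np.nonzero(E > 0)       # positive entries in row-major order
    edge_index = np.stack([src, dst])  # an undirected edge appears as (i, j) and (j, i)
    edge_attr = E[src, dst]            # edge strengths e_ij
    return edge_index, edge_attr

E = np.array([[0.0, 0.5, 0.0],
              [0.5, 0.0, 0.2],
              [0.0, 0.2, 0.0]])
edge_index, edge_attr = dense_to_edge_list(E)
print(edge_index)  # [[0 1 1 2]
                   #  [1 0 2 1]]
print(edge_attr)   # [0.5 0.5 0.2 0.2]
```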

## Hyperparameters

In [5]:
# starting epoch
epoch = 0
# number of training epochs
n_epochs = 100
# batch size
batchSize = 100
# index of the cross-validation fold to train on
fold = 0
# learning rate
lr = 0.001
# scheduler step size (epochs between learning-rate decays)
stepsize = 20
# scheduler decay factor
gamma = 0.5
# weight decay (L2 regularization)
weightdecay = 5e-3
# classification loss weight
lamb0 = 1
# unit loss weight for the first pooling layer (w1)
lamb1 = 0
# unit loss weight for the second pooling layer (w2)
lamb2 = 0
# TopK pooling (TPK) loss weight on the first-layer scores s1
lamb3 = 0.1
# TopK pooling (TPK) loss weight on the second-layer scores s2
lamb4 = 0.1
# group-level consistency (GLC) loss weight
lamb5 = 0.1
# number of GNN layers
layer = 2
# pooling ratio
ratio = 0.5
# input feature dimension
indim = 200
# number of ROIs
nroi = 200
# number of classes
nclass = 2
# optimization method: SGD or Adam
optim = "Adam"


## Save model

In [6]:
load_model = False
save_model = True
# path to save model
save_path = './model/'
os.makedirs(save_path, exist_ok=True)
writer = SummaryWriter(os.path.join('./log', str(fold)))

## Define Dataloader

In [7]:
dataset = ABIDEDataset(dataroot, name)
dataset.data.y = dataset.data.y.squeeze()
dataset.data.x[dataset.data.x == float('inf')] = 0

tr_index, val_index, te_index = np.array(train_val_test_split(fold=fold), dtype=object)

train_dataset = [dataset[i] for i in tr_index]
val_dataset = [dataset[i] for i in val_index]
test_dataset = [dataset[i] for i in te_index]

train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batchSize, shuffle=False)
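`train_val_test_split` comes from the project's `imports_data.utils` module, whose implementation is not shown here. The split sizes printed below (621/207/207 out of 1035) are consistent with a 5-fold scheme in which one fold is held out for testing, one for validation, and the remaining three for training. The sketch below is a hypothetical stand-in that illustrates only that scheme; the real helper may, for example, stratify by label or assign folds differently.

```python
import numpy as np

def hypothetical_split(n_samples, fold, n_folds=5, seed=0):
    """Hypothetical k-fold train/val/test split (not the project's helper).

    Fold `fold` is the test set, the next fold is the validation set, and
    the remaining folds form the training set.
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_folds)
    val_fold = (fold + 1) % n_folds
    test_idx = folds[fold]
    val_idx = folds[val_fold]
    train_idx = np.concatenate(
        [f for i, f in enumerate(folds) if i not in (fold, val_fold)])
    return train_idx, val_idx, test_idx

tr, va, te = hypothetical_split(1035, fold=0)
print(len(tr), len(va), len(te))  # 621 207 207
```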



In [14]:
print("size of the dataset:")
print(len(dataset))
print("size of the train dataset:")
print(len(train_dataset))
print("size of the validation dataset:")
print(len(val_dataset))
print("size of the test dataset:")
print(len(test_dataset))
print("shape of the data:")
print("X:")
print(dataset.data.x.shape)
print(dataset.data.x)
print("Y:")
print(dataset.data.y.shape)
# x stacks the node features of all graphs, so its length is num_graphs * nroi
print("number of nodes over all graphs:")
print(len(dataset.data.x))

size of the dataset:
1035
size of the train dataset:
621
size of the validation dataset:
207
size of the test dataset:
207
shape of the data:
X:
torch.Size([207000, 200])
tensor([[0.0000, 0.6177, 0.7100,  ..., 0.3766, 0.5746, 0.3808],
        [0.6177, 0.0000, 0.6818,  ..., 0.2914, 0.4564, 0.5356],
        [0.7100, 0.6818, 0.0000,  ..., 0.5322, 0.8745, 0.7440],
        ...,
        [0.3508, 0.2820, 0.2556,  ..., 0.0000, 0.1831, 0.4874],
        [0.0966, 0.2417, 0.3996,  ..., 0.1831, 0.0000, 0.1292],
        [0.4726, 0.3215, 0.1351,  ..., 0.4874, 0.1292, 0.0000]])
Y:
torch.Size([1035])
number of nodes over all graphs:
207000
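The 207,000 rows of `X` are node features, not graphs: PyTorch Geometric concatenates the node-feature matrices of all graphs along the first axis, and each graph here has `nroi = 200` nodes, so 1035 × 200 = 207,000. Inside a mini-batch, the `batch` vector (passed to the model as `data.batch`) records which graph each node belongs to. A small NumPy sketch of both ideas, using hypothetical helper names:

```python
import numpy as np

def nodes_per_graph(batch):
    # batch[k] is the index of the graph that node k belongs to (PyG convention)
    return np.bincount(batch)

def split_node_features(x, n_roi):
    # All graphs share the same n_roi nodes, so the stacked (num_graphs * n_roi, F)
    # matrix can be viewed as one (n_roi, F) feature matrix per graph.
    if x.shape[0] % n_roi != 0:
        raise ValueError("row count is not a multiple of n_roi")
    return x.reshape(-1, n_roi, x.shape[1])

x = np.zeros((6, 4))                     # 2 toy graphs, 3 nodes and 4 features each
batch = np.array([0, 0, 0, 1, 1, 1])
print(nodes_per_graph(batch))            # [3 3]
print(split_node_features(x, 3).shape)   # (2, 3, 4)
```

Applied to the ABIDE features, `split_node_features` with `n_roi = 200` would give an array of shape (1035, 200, 200).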




## Define Graph Deep Learning Network

ROI-selection pooling layers (R-pool) highlight salient ROIs (nodes in the graph), so that we can infer which ROIs are important for prediction.
Regularization terms on the pooling results (unit loss, TopK pooling (TPK) loss, and group-level consistency (GLC) loss) encourage reasonable ROI selection and provide flexibility to preserve either individual- or group-level patterns.
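The R-pool step can be sketched for a single graph as follows. This is a simplified NumPy illustration, not the `TopKPooling` layer printed in the model summary: it assumes each node's score is its feature vector projected onto a learnable vector $w$ normalized to unit length and squashed by a sigmoid, and that the top `ratio` fraction of nodes is kept with their features scaled by their scores. The function name is hypothetical.

```python
import numpy as np

def topk_pool(x, w, ratio=0.5):
    """Simplified single-graph R-pool sketch.

    x: (N, d) node features; w: (d,) learnable projection vector.
    Returns the pooled features, the indices of kept nodes, and all node scores.
    """
    s = 1.0 / (1.0 + np.exp(-(x @ w) / np.linalg.norm(w)))  # sigmoid score per node
    k = int(np.ceil(ratio * x.shape[0]))                     # number of nodes to keep
    keep = np.argsort(-s)[:k]                                # top-k nodes, best first
    # In a differentiable implementation, gating by s is what lets w receive gradients.
    return x[keep] * s[keep, None], keep, s

x = np.eye(4)                        # 4 toy nodes with one-hot features
w = np.array([4.0, 3.0, 2.0, 1.0])
pooled, keep, s = topk_pool(x, w, ratio=0.5)
print(keep)                          # [0 1]
```

The indices in `keep` are what make the pooling interpretable: they name the ROIs that survive.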

In [8]:
model = Network(indim, ratio, nclass).to(device)
if optim == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weightdecay)
elif optim == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weightdecay,
                                nesterov=True)
else:
    raise ValueError("optim must be 'Adam' or 'SGD', got {!r}".format(optim))

scheduler = lr_scheduler.StepLR(optimizer, step_size=stepsize, gamma=gamma)
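With `stepsize = 20` and `gamma = 0.5`, `StepLR` halves the learning rate every 20 epochs, since `train()` calls `scheduler.step()` once per epoch. The closed form below (a plain-Python sketch, not PyTorch's implementation) shows the resulting schedule: 0.001 until epoch 19, 0.0005 from epoch 20, and 0.001 / 16 = 6.25e-5 from epoch 80 to the end of the 100 epochs.

```python
def step_lr(base_lr, epoch, step_size, gamma):
    # StepLR multiplies the learning rate by gamma once every step_size epochs.
    return base_lr * gamma ** (epoch // step_size)

for e in (0, 19, 20, 40, 60, 80, 99):
    print(e, step_lr(0.001, e, step_size=20, gamma=0.5))
```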

In [9]:
print(model)

Network(
  (n1): Sequential(
    (0): Linear(in_features=200, out_features=8, bias=False)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=6400, bias=True)
  )
  (conv1): MyNNConv(200, 32)
  (pool1): TopKPooling(32, ratio=0.5, multiplier=1)
  (n2): Sequential(
    (0): Linear(in_features=200, out_features=8, bias=False)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=1024, bias=True)
  )
  (conv2): MyNNConv(32, 32)
  (pool2): TopKPooling(32, ratio=0.5, multiplier=1)
  (fc1): Linear(in_features=128, out_features=32, bias=True)
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=32, out_features=512, bias=True)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=512, out_features=2, bias=True)
)
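The `n1`/`n2` blocks in the summary are what make the convolutions ROI-aware: `n1` maps a 200-dimensional input through 8 hidden units to 6400 = 200 × 32 outputs, i.e. one 200 × 32 kernel per node for `conv1`, and `n2` produces 1024 = 32 × 32 outputs for `conv2`. The sketch below illustrates this idea under stated assumptions: the input to `n1` is taken to be each node's one-hot ROI indicator (passed to the model as `data.pos`), and the aggregation is simplified (each node transformed by its own kernel, then summed over neighbours weighted by $e_{ij}$ plus a unit self-loop). The exact message passing in `MyNNConv` may differ, and both function names are hypothetical.

```python
import numpy as np

def roi_kernels(pos, theta1, theta2, bias, in_dim, out_dim):
    """One (in_dim, out_dim) convolution kernel per ROI, generated from its one-hot code.

    pos: (N, N) one-hot ROI indicators; theta1: (N, K); theta2: (K, in_dim * out_dim).
    """
    hidden = np.maximum(pos @ theta1, 0.0)  # Linear(N -> K, bias=False) + ReLU
    flat = hidden @ theta2 + bias           # Linear(K -> in_dim * out_dim)
    return flat.reshape(-1, in_dim, out_dim)

def ra_conv(x, adj, kernels):
    """Simplified ROI-aware convolution: h_i' = relu(sum_j a_ij * W_j^T h_j)."""
    xw = np.einsum('ni,nio->no', x, kernels)  # node j transformed by its own kernel W_j
    a = adj + np.eye(adj.shape[0])            # edge weights e_ij plus assumed self-loops
    return np.maximum(a @ xw, 0.0)

rng = np.random.default_rng(0)
N, F, O, K = 5, 4, 3, 2  # toy sizes; the notebook uses N = F = 200, O = 32, K = 8
theta1 = rng.normal(size=(N, K))
theta2 = rng.normal(size=(K, F * O))
W = roi_kernels(np.eye(N), theta1, theta2, np.zeros(F * O), F, O)
x = rng.normal(size=(N, F))
h = ra_conv(x, np.zeros((N, N)), W)  # no edges: each node only sees itself
print(W.shape, h.shape)              # (5, 4, 3) (5, 3)
```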


## Define Loss Functions

We add several loss terms to regulate the learning process and control interpretability.
The classification loss is the cross-entropy loss:
![Cross-entropy classification loss](imag/img5.png)

To avoid the identifiability problem, we propose the unit loss:
![Unit loss](imag/img6.png)
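One way to see the identifiability problem: if the pooling scores are computed from the normalized projection $Xw / \lVert w \rVert$ (an assumption here), then rescaling $w$ leaves every score unchanged, so the scale of $w$ cannot be identified from the data. The unit loss $(\lVert w \rVert_2 - 1)^2$, computed in `train()` as `(torch.norm(w1, p=2) - 1) ** 2`, pins that scale to 1. Note that this notebook sets `lamb1 = lamb2 = 0`, so the unit loss is computed and logged but does not affect training. A NumPy sketch with hypothetical helper names:

```python
import numpy as np

def projection_scores(x, w):
    # Scores depend on w only through its direction w / ||w||.
    return x @ w / np.linalg.norm(w)

def unit_loss(w):
    # Same quantity as (torch.norm(w, p=2) - 1) ** 2 in train().
    return (np.linalg.norm(w) - 1.0) ** 2

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
w = rng.normal(size=3)
print(np.allclose(projection_scores(x, w), projection_scores(x, 10 * w)))  # True
print(unit_loss(np.array([3.0, 4.0])))  # 16.0
```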

For our application, we want to find common patterns/biomarkers for a given neuro-prediction task. Thus, we add a regularization term that forces the $\tilde{s}^{(l)}$ vectors to be similar across input instances after the first pooling layer, and call it group-level consistency (GLC):
![Group-level consistency loss](imag/img7.png)
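In `consist_loss` below, the GLC term is $\mathrm{tr}(S^\top L S)$ with $L = D - W$ and $W$ an all-ones matrix over the $M$ instances of one class (rows of $S$). For this fully connected $W$, the trace equals the sum of squared differences between every pair of instances' score vectors, $\sum_{i<j} \lVert s_i - s_j \rVert^2$, so it is zero exactly when all instances share the same scores. The NumPy check below confirms the identity numerically; it omits the sigmoid and the $1/M^2$ scaling that `consist_loss` also applies, and the helper names are hypothetical.

```python
import numpy as np

def glc_trace(s):
    """tr(S^T L S) with L = D - W, W all ones over the M rows (instances) of s."""
    m = s.shape[0]
    W = np.ones((m, m))
    L = np.diag(W.sum(axis=1)) - W
    return np.trace(s.T @ L @ s)

def pairwise_sq_diff(s):
    # Sum over instance pairs i < j of ||s_i - s_j||^2
    m = s.shape[0]
    return sum(np.sum((s[i] - s[j]) ** 2) for i in range(m) for j in range(i + 1, m))

rng = np.random.default_rng(0)
s = rng.random((4, 6))  # 4 instances, 6 node scores each
print(np.isclose(glc_trace(s), pairwise_sq_diff(s)))  # True
```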

We define the TPK loss using binary cross-entropy as:
![TopK pooling loss](imag/img8.png)

The final loss function is:
![Total loss](imag/img9.png)
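For reference, as read from the code in `train()` below, the objective applies the unit and TPK losses to each of the two pooling layers separately and sums the GLC term over the classes $c$, using the first-layer scores only (the grouping may differ from the figure):

$$
\mathcal{L} = \lambda_0 \mathcal{L}_{CE} + \lambda_1 \mathcal{L}_{unit}^{(1)} + \lambda_2 \mathcal{L}_{unit}^{(2)} + \lambda_3 \mathcal{L}_{TPK}^{(1)} + \lambda_4 \mathcal{L}_{TPK}^{(2)} + \lambda_5 \sum_{c} \mathcal{L}_{GLC}\big(\tilde{s}^{(1)}_{y=c}\big),
\qquad \mathcal{L}_{unit}^{(l)} = \big(\lVert w^{(l)} \rVert_2 - 1\big)^2
$$

With the hyperparameters used here ($\lambda_0 = 1$, $\lambda_1 = \lambda_2 = 0$, $\lambda_3 = \lambda_4 = \lambda_5 = 0.1$), this reduces to:

$$
\mathcal{L} = \mathcal{L}_{CE} + 0.1\,\Big(\mathcal{L}_{TPK}^{(1)} + \mathcal{L}_{TPK}^{(2)} + \sum_{c} \mathcal{L}_{GLC}\big(\tilde{s}^{(1)}_{y=c}\big)\Big)
$$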

In [10]:
def topk_loss(s, ratio):
    if ratio > 0.5:
        ratio = 1 - ratio
    s = s.sort(dim=1).values
    res = -torch.log(s[:, -int(s.size(1) * ratio):] + EPS).mean() - torch.log(
        1 - s[:, :int(s.size(1) * ratio)] + EPS).mean()
    return res
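To see what `topk_loss` rewards, here is a NumPy mirror of it (an illustrative re-implementation, not used in training). Within each instance it sorts the node scores, pushes the highest `ratio` fraction toward 1 and the lowest `ratio` fraction toward 0 via binary cross-entropy, and ignores any scores in between. Like the original, it assumes `ratio * N >= 1`; with `ratio = 0.5` every node falls into one of the two groups.

```python
import numpy as np

EPS = 1e-10

def topk_loss_np(s, ratio):
    """NumPy mirror of topk_loss: s has shape (M, N) with scores in (0, 1)."""
    if ratio > 0.5:
        ratio = 1 - ratio
    s = np.sort(s, axis=1)  # ascending along the node axis
    k = int(s.shape[1] * ratio)
    top = s[:, -k:]         # k highest scores, pushed toward 1
    bottom = s[:, :k]       # k lowest scores, pushed toward 0
    return -np.log(top + EPS).mean() - np.log(1 - bottom + EPS).mean()

confident = np.array([[0.99, 0.98, 0.01, 0.02]])
ambiguous = np.full((1, 4), 0.5)
print(topk_loss_np(confident, 0.5), topk_loss_np(ambiguous, 0.5))
```

Scores close to 0 or 1 give a loss of about 0.03, whereas scores stuck at 0.5 give 2 ln 2 ≈ 1.386, so the TPK term drives the pooling scores toward a clear keep/drop separation.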

In [11]:
def consist_loss(s):
    if len(s) == 0:
        return 0
    s = torch.sigmoid(s)
    W = torch.ones(s.shape[0], s.shape[0])
    D = torch.eye(s.shape[0]) * torch.sum(W, dim=1)
    L = D - W
    L = L.to(device)
    res = torch.trace(torch.transpose(s, 0, 1) @ L @ s) / (s.shape[0] * s.shape[0])
    return res

## Network Training Function

In [12]:
def train(epoch):
    print('train...........')

    for param_group in optimizer.param_groups:
        print("LR", param_group['lr'])
    model.train()
    s1_list = []
    s2_list = []
    loss_all = 0
    step = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        s1_list.append(s1.view(-1).detach().cpu().numpy())
        s2_list.append(s2.view(-1).detach().cpu().numpy())

        loss_c = F.nll_loss(output, data.y)

        loss_p1 = (torch.norm(w1, p=2) - 1) ** 2
        loss_p2 = (torch.norm(w2, p=2) - 1) ** 2
        loss_tpk1 = topk_loss(s1, ratio)
        loss_tpk2 = topk_loss(s2, ratio)
        loss_consist = 0
        for c in range(nclass):
            loss_consist += consist_loss(s1[data.y == c])
        loss = lamb0 * loss_c + lamb1 * loss_p1 + lamb2 * loss_p2 \
               + lamb3 * loss_tpk1 + lamb4 * loss_tpk2 + lamb5 * loss_consist
        writer.add_scalar('train/classification_loss', loss_c, epoch * len(train_loader) + step)
        writer.add_scalar('train/unit_loss1', loss_p1, epoch * len(train_loader) + step)
        writer.add_scalar('train/unit_loss2', loss_p2, epoch * len(train_loader) + step)
        writer.add_scalar('train/TopK_loss1', loss_tpk1, epoch * len(train_loader) + step)
        writer.add_scalar('train/TopK_loss2', loss_tpk2, epoch * len(train_loader) + step)
        writer.add_scalar('train/GLC_loss', loss_consist, epoch * len(train_loader) + step)
        step = step + 1

        loss.backward()
        loss_all += loss.item() * data.num_graphs
        optimizer.step()
    # concatenate the pooling scores collected over all batches of this epoch
    s1_arr = np.hstack(s1_list)
    s2_arr = np.hstack(s2_list)
    scheduler.step()
    return loss_all / len(train_dataset), s1_arr, s2_arr, w1, w2

## Network Testing Function

In [13]:
@torch.no_grad()
def test_acc(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        outputs = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        pred = outputs[0].max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()

    return correct / len(loader.dataset)

In [14]:
@torch.no_grad()
def test_loss(loader):
    print('testing...........')
    model.eval()
    loss_all = 0
    for data in loader:
        data = data.to(device)
        output, w1, w2, s1, s2 = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
        loss_c = F.nll_loss(output, data.y)

        loss_p1 = (torch.norm(w1, p=2) - 1) ** 2
        loss_p2 = (torch.norm(w2, p=2) - 1) ** 2
        loss_tpk1 = topk_loss(s1, ratio)
        loss_tpk2 = topk_loss(s2, ratio)
        loss_consist = 0
        for c in range(nclass):
            loss_consist += consist_loss(s1[data.y == c])
        loss = lamb0 * loss_c + lamb1 * loss_p1 + lamb2 * loss_p2 \
               + lamb3 * loss_tpk1 + lamb4 * loss_tpk2 + lamb5 * loss_consist

        loss_all += loss.item() * data.num_graphs
    return loss_all / len(loader.dataset)


## Model Training

In [15]:
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1e10
for epoch in range(0, n_epochs):
    since = time.time()
    tr_loss, s1_arr, s2_arr, w1, w2 = train(epoch)
    tr_acc = test_acc(train_loader)
    val_acc = test_acc(val_loader)
    val_loss = test_loss(val_loader)
    time_elapsed = time.time() - since
    print('*====**')
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # the "Test" fields below report the validation split (val_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Loss: {:.7f}, Test Acc: {:.7f}'.format(epoch, tr_loss,
                                                                          tr_acc, val_loss, val_acc))

    writer.add_scalars('Acc', {'train_acc': tr_acc, 'val_acc': val_acc}, epoch)
    writer.add_scalars('Loss', {'train_loss': tr_loss, 'val_loss': val_loss}, epoch)
    writer.add_histogram('Hist/hist_s1', s1_arr, epoch)
    writer.add_histogram('Hist/hist_s2', s2_arr, epoch)

    if val_loss < best_loss and epoch > 5:
        print("saving best model")
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        if save_model:
            torch.save(best_model_wts, os.path.join(save_path, str(fold) + '.pth'))

train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 000, Train Loss: 1.1346474, Train Acc: 0.5330113, Test Loss: 0.9784882, Test Acc: 0.4927536
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 001, Train Loss: 1.0308003, Train Acc: 0.5458937, Test Loss: 0.9785184, Test Acc: 0.4734300
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 002, Train Loss: 1.0967633, Train Acc: 0.5539452, Test Loss: 0.9833101, Test Acc: 0.4782609
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 003, Train Loss: 1.0425431, Train Acc: 0.5507246, Test Loss: 0.9860983, Test Acc: 0.4734300
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 004, Train Loss: 1.0586852, Train Acc: 0.5603865, Test Loss: 0.9936846, Test Acc: 0.4975845
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 005, Train Loss: 1.0292435, Train Acc: 0.5813205, Test Loss: 0.9877933, Test Acc: 0.5072464
train...........
LR 0.001
testing...........
*====**
0m 21s
Epoch: 006, 

## Evaluation on the Test Set

In [19]:
if load_model:
    # reload the best checkpoint written to disk during training
    model = Network(indim, ratio, nclass).to(device)
    model.load_state_dict(torch.load(os.path.join(save_path, str(fold) + '.pth'), map_location=device))
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            outputs = model(data.x, data.edge_index, data.batch, data.edge_attr, data.pos)
            pred = outputs[0].argmax(dim=1)
            preds.append(pred.cpu().numpy())
            trues.append(data.y.cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    trues = np.concatenate(trues, axis=0)
    print("Confusion matrix")
    print(confusion_matrix(trues, preds))
    print(classification_report(trues, preds))

else:
    # restore the in-memory weights with the lowest validation loss
    model.load_state_dict(best_model_wts)
    model.eval()
    test_accuracy = test_acc(test_loader)
    test_l = test_loss(test_loader)
    print("===========================")
    print("Test Acc: {:.7f}, Test Loss: {:.7f} ".format(test_accuracy, test_l))



testing...........
Test Acc: 0.5314010, Test Loss: 0.9909237 


NOTE: The paper reports results on two datasets:
1. Biopoint Autism Study Dataset
2. HCP dataset (900 subjects)

For the Biopoint dataset, the aim is to classify Autism Spectrum Disorder (ASD) versus Healthy Control (HC).
For the HCP dataset, the aim is to classify 7 task states: gambling, language, motor, relational, social, working memory (WM), and emotion.
The released code targets the Biopoint binary task (ASD vs. HC) with 115 subjects (43 HC, 72 ASD); the data are augmented 30 times, resulting in 3,450 graphs.
For HCP, there are 3,542 graphs from 237 subjects.
ABIDE, in contrast, provides 1,035 graphs, one per subject and without augmentation. The 207,000 rows printed for X above are node features (1,035 graphs × 200 ROIs), not graphs. Because the hyperparameters were tuned for the Biopoint and HCP settings, they likely do not transfer well to ABIDE, which may explain the weak results here.
