In [1]:
import os
import numpy as np
import pandas as pd
import networkx as nx
import seaborn as sns
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch_geometric as tg
import torch.nn.functional as F
import torch.nn as nn
from sklearn.model_selection import train_test_split
import natsort

# Prep data
 1. Get cobre timeseries
 2. Get cobre connectomes
 3. Get group average connectome
 4. Build 8 k-NN graph from avg connectome
 5. Split data: 70 training, 10 validation, 20 test
 6. All data from same subject assigned to same Split
 7. Cut time-series into bins of length time window

In [2]:
ts_path = '/home/harveyaa/Documents/fMRI/data/cobre/difumo/timeseries'
conn_path = '/home/harveyaa/Documents/fMRI/data/cobre/difumo/connectomes'
pheno_path = '/home/harveyaa/nilearn_data/cobre/phenotypic_data.tsv'

In [3]:
timeseries = [np.load(os.path.join(ts_path,p)) for p in os.listdir(ts_path)]
ids = [int(p.split('_')[1]) for p in os.listdir(ts_path)]

# One subject has different length timeseries, ignore them for now
not_150 = np.array([t.shape[0]!=150 for t in timeseries])
print('Bad sub ID: {}'.format(np.array(ids)[not_150][0]))

Bad sub ID: 40075


# Make Graph
- Load connectomes
- Get avg connectome
- Get 8 knn graph from avg connectome

In [4]:
def make_undirected(mat):
    """Takes an input adjacency matrix and makes it undirected (symmetric)."""
    m = mat.copy()
    mask = mat != mat.transpose()
    vals = mat[mask] + mat.transpose()[mask]
    m[mask] = vals
    return m

def knn_graph(mat,k=8,directed=False):
    """Takes an input matrix and returns a k-Nearest Neighbour weighted adjacency matrix."""
    m = mat.copy()
    np.fill_diagonal(m,0)
    slices = []
    for i in range(m.shape[0]):
        s = m[:,i]
        not_neighbours = s.argsort()[:-k]
        s[not_neighbours] = 0
        slices.append(s)
    if directed:
        return np.array(slices)
    else:
        return make_undirected(np.array(slices))
    
def make_group_graph(conn_path):
    # Load connectomes
    connectomes = [np.load(os.path.join(conn_path,p)) for p in os.listdir(conn_path)]

    # Group average connectome
    avg_conn = np.array(connectomes).mean(axis=0)

    # Undirected 8 k-NN graph as matrix
    avg_conn8 = knn_graph(avg_conn,directed=False)

    # Format matrix into graph for torch_geometric
    graph = nx.convert_matrix.from_numpy_array(avg_conn8)
    return tg.utils.from_networkx(graph)

# Get train/test/validation data
- Load timeseries and ids
- Split timeseries of 150 volumes into time windows
- Split data into train/test/validation
  - All data from a given subject goes in the same bin

In [5]:
def split_timeseries(ts,n_timepoints=50):
    """Takes an input timeseries and splits it into time windows of specified length. Need to choose a number that splits evenly."""
    if ts.shape[0] % n_timepoints != 0:
        raise ValueError('Yikes choose a divisor for now')
    else:
        n_splits = ts.shape[0] / n_timepoints
        return np.split(ts,n_splits)

def split_ts_labels(timeseries,labels,n_timepoints=50):
    """
    timeseries: list of timeseries
    labels: list of lists (of accompanying labels)
    n_timepoints: n_timepoints of split (must be an even split)
    """
    # Split the timeseries
    split_ts = []
    for ts in map(split_timeseries,timeseries):
        split_ts = split_ts + ts

    #keep track of the corresponding labels
    n = int(timeseries[0].shape[0]/n_timepoints)
    split_labels = []
    for l in labels:
        split_labels.append(np.repeat(l,n))

    #add a label for each split
    split_labels.append(list(range(n))*len(timeseries))
    return split_ts, split_labels

def train_test_val_splits(split_ids,test_size=0.20,val_size=0.10,random_state=111):
    """Train test val split the data (in splits) so splits from a subject are in the same group.
        returns INDEX for each split
    """
    # Train test validation split of ids, then used to split dataframe
    X = np.unique(split_ids)
    y = list(range(len(X)))

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size+val_size, random_state=random_state)
    X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size=val_size/(test_size+val_size), random_state=random_state)

    train_idx = []
    test_idx = []
    val_idx = []
    for i in range(len(split_ids)):
        if split_ids[i] in X_train:
            train_idx.append(i)
        elif split_ids[i] in X_test:
            test_idx.append(i)
        elif split_ids[i]in X_val:
            val_idx.append(i)

    return train_idx,test_idx,val_idx
    
class cobreTimeWindows(Dataset):
    def __init__(self,ts_path,pheno_path,test_size=0.20,val_size=0.10,random_state=111,n_timepoints=50):
        self.pheno_path = pheno_path
        pheno = pd.read_csv(pheno_path,delimiter='\t')
        pheno = pheno[pheno['ID']!=40075]
        pheno.sort_values('ID',inplace=True)
        self.labels = pheno['Subject Type'].map({'Patient':1,'Control':0}).tolist()

        self.ts_path = ts_path
        self.timeseries = [np.load(os.path.join(ts_path,p)) for p in natsort.natsorted(os.listdir(ts_path))]
        self.sub_ids = [int(p.split('_')[1]) for p in natsort.natsorted(os.listdir(ts_path))]

        #filter out bad sub
        idx = self.sub_ids.index(40075)
        del self.sub_ids[idx]
        del self.timeseries[idx]

        #split timeseries
        self.split_timeseries,split_labs = split_ts_labels(self.timeseries,[self.sub_ids,self.labels],n_timepoints=n_timepoints)
        self.split_sub_ids = split_labs[0]
        self.split_labels = split_labs[1]
        self.split_ids = split_labs[-1]

        #train test val split the data (each sub's splits in one category only)
        self.train_idx,self.test_idx,self.val_idx = train_test_val_splits(self.split_sub_ids,
                                                                            test_size=test_size,
                                                                            val_size=val_size,
                                                                            random_state=random_state)

    def __len__(self):
        return len(self.split_sub_ids)

    def __getitem__(self,idx):
        ts = torch.from_numpy(self.split_timeseries[idx]).transpose(0,1)
        sub_id = self.split_sub_ids[idx]
        label = self.split_labels[idx]
        split_id = self.split_ids[idx]
        #return {'timeseries':ts,
                 #"sub_id":sub_id, 
        #         'label':label, 
                 #"split_id":split_id
        #         }
        return ts,label

# Model
 - C input channels (n time points of timeseries)
 - 6 GCN layers
 - 32 graph filters at each layer
 - Global average pooling layer
 - 2 fully connected layers
 - 256, 128 units
 - ReLU activation
 - Softmax last layer

In [6]:
class GCN(torch.nn.Module):
    def __init__(self,edge_index,edge_weight,n_timepoints = 50):
        super().__init__()
        #forward(x, edge_index, edge_weight: Optional[torch.Tensor] = None
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.conv1 = tg.nn.ChebConv(in_channels=n_timepoints,out_channels=32,K=2,bias=True)
        self.conv2 = tg.nn.ChebConv(in_channels=32,out_channels=32,K=2,bias=True)
        self.conv3 = tg.nn.ChebConv(in_channels=32,out_channels=32,K=2,bias=True)
        self.conv4 = tg.nn.ChebConv(in_channels=32,out_channels=32,K=2,bias=True)
        self.conv5 = tg.nn.ChebConv(in_channels=32,out_channels=32,K=2,bias=True)
        self.conv6 = tg.nn.ChebConv(in_channels=32,out_channels=32,K=2,bias=True)
        #self.fc1 = nn.Linear(512, 256)
        #self.fc2 = nn.Linear(256, 128)
        #self.fc3 = nn.Linear(128,2)
        self.fc1 = nn.Linear(512*32, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)
        self.dropout = nn.Dropout(0.5)

    def forward(self,x):
        #print(x.size())
        x = self.conv1(x,self.edge_index,self.edge_weight)
        x = F.relu(x)
        x = self.conv2(x,self.edge_index,self.edge_weight)
        x = F.relu(x)
        x = self.conv3(x,self.edge_index,self.edge_weight)
        x = F.relu(x)
        x = self.conv4(x,self.edge_index,self.edge_weight)
        x = F.relu(x)
        x = self.conv5(x,self.edge_index,self.edge_weight)
        x = F.relu(x)
        x = self.conv6(x,self.edge_index,self.edge_weight)
        #print(x.size())
        x = tg.nn.global_mean_pool(x,torch.from_numpy(np.array(range(x.size(0)),dtype=int)))

        #print(x.size())
        ####x = torch.transpose(x,1,2)
        x = x.view(-1, 512*32)
        x = self.fc1(x)
        #print(x.size())
        x = self.dropout(x)
        x = self.fc2(x)
        #print(x.size())
        x = self.dropout(x)
        x = self.fc3(x)
        #print(x.size())
        #x = F.softmax(x,dim=0)
        return x
        


In [7]:
np.array(range(3))

array([0, 1, 2])

In [8]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.sampler)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        #print(pred.size())
        #print(y)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss, current = loss.item(), batch * len(X)
        print(batch)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.sampler)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [9]:
graph = make_group_graph(conn_path)

In [10]:
data = cobreTimeWindows(ts_path,pheno_path,n_timepoints=15)
batch_size = 128

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(data.train_idx)
test_sampler = SubsetRandomSampler(data.test_idx)
val_sampler = SubsetRandomSampler(data.val_idx)

train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, sampler=test_sampler)
val_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, sampler=val_sampler)

In [11]:
gcn = GCN(graph.edge_index,graph.weight,n_timepoints=15)

In [12]:
data[0][0].size()

torch.Size([512, 50])

In [37]:
learning_rate = 0.1
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(gcn.parameters(), lr=learning_rate)

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, gcn, loss_fn, optimizer)
    test_loop(test_loader, gcn, loss_fn)
print("Done!")

Epoch 1
-------------------------------
0
loss: 0.716936  [    0/  303]
1
loss: 13.215805  [  128/  303]
2
loss: 82.561584  [   94/  303]
Test Error: 
 Accuracy: 58.6%, Avg loss: 38438.609195 

Epoch 2
-------------------------------
0
loss: 3978590.500000  [    0/  303]
1
loss:     inf  [  128/  303]
2
loss:     nan  [   94/  303]
Test Error: 
 Accuracy: 58.6%, Avg loss:      nan 

Epoch 3
-------------------------------
0
loss:     nan  [    0/  303]
1
loss:     nan  [  128/  303]
2
loss:     nan  [   94/  303]
Test Error: 
 Accuracy: 58.6%, Avg loss:      nan 

Epoch 4
-------------------------------
0
loss:     nan  [    0/  303]


In [33]:
gcn

GCN(
  (conv1): ChebConv(50, 32, K=2, normalization=sym)
  (conv2): ChebConv(32, 32, K=2, normalization=sym)
  (conv3): ChebConv(32, 32, K=2, normalization=sym)
  (conv4): ChebConv(32, 32, K=2, normalization=sym)
  (conv5): ChebConv(32, 32, K=2, normalization=sym)
  (conv6): ChebConv(32, 32, K=2, normalization=sym)
  (fc1): Linear(in_features=16384, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=2, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [98]:
np.unique(graph.edge_index)

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18