# 

## imports

In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
import BrainGB
import BrainGB.models as bgb_models
import torch_geometric as pyg

  from .autonotebook import tqdm as notebook_tqdm


In [128]:
import argparse
import logging
from typing import Optional

import pdb
import inspect
from collections import defaultdict
from random import choice
from functools import wraps

from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric.nn as gnn
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

import networkx as nx

#import neurograph as ng
import neurograph.config as cfg

In [163]:
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logging.debug("test")

DEBUG:root:test


In [4]:
pyg.__version__

'2.0.3'

In [5]:
pyg.seed_everything(1380)

----

## look at files

In [6]:
!ls {cfg.DATA_PATH}

fmri


In [7]:
fmri_path = cfg.DATA_PATH / 'fmri'

In [8]:
!ls {fmri_path}

alexeev  cobre


In [9]:
cobre_path = fmri_path / 'cobre'

In [10]:
!ls {cobre_path}

cobre_aal_processed  cobre_cm_aal		  cobre_labels.csv
cobre_cm_39	     cobre_func_connectivity.csv


In [11]:
cobre_aal_path = cobre_path / 'cobre_cm_aal'

In [12]:
cobre_aal_path

PosixPath('/home/gl/skoltech/imaging/data/fmri/cobre/cobre_cm_aal')

## global csv files

In [13]:
cobre_target = pd.read_csv(cobre_path / 'cobre_labels.csv')

In [14]:
cobre_target = cobre_target[['ID', 'target']].copy()

In [15]:
# check that there are no different labels assigned to the same ID
cobre_target.groupby('ID').target.nunique().max()

1

In [16]:
cobre_target.drop_duplicates(inplace=True)

In [17]:
cobre_target.shape

(154, 2)

In [18]:
cobre_target.set_index('ID', inplace=True)

In [19]:
cobre_target.head(3)

Unnamed: 0_level_0,target
ID,Unnamed: 1_level_1
A00007409,No_Known_Disorder
A00031597,Schizophrenia_Strict
A00022500,Schizophrenia_Strict


In [20]:
cobre_target.target.unique()

array(['No_Known_Disorder', 'Schizophrenia_Strict', 'Schizoaffective'],
      dtype=object)

In [21]:
cobre_target.target.value_counts()

No_Known_Disorder       77
Schizophrenia_Strict    66
Schizoaffective         11
Name: target, dtype: int64

### drop schizoaffective

In [22]:
cobre_target = cobre_target[cobre_target.target != 'Schizoaffective'].copy()

### label encoding for cobre 

In [23]:
target_label2id = {i: x for i, x in enumerate(cobre_target.target.unique())}
target_id2label = {x: i for i,x in target_label2id.items()}

In [24]:
target_label2id

{0: 'No_Known_Disorder', 1: 'Schizophrenia_Strict'}

In [25]:
cobre_target.target = cobre_target.target.map(target_id2label)

In [26]:
cobre_target.target.value_counts()

0    77
1    66
Name: target, dtype: int64

## load conn matrices

In [27]:
!ls {cobre_aal_path} | wc -l

326


In [29]:
326 // 2

163

In [30]:
def load_matrices(path, dataset_name='cobre'):
    data = {}
    embed = {}
    # brain regi
    regions = None
    
    for p in path.glob('*.csv'):
        
        name = p.stem.split('_')[0]
        
        if dataset_name == 'cobre':
            name = name.replace('sub-', '')
        
        x = pd.read_csv(p).drop('Unnamed: 0', axis=1)
        
        values = x.values.astype(np.float32)
        if p.stem.endswith('_embed'):
            embed[name] = values
        else:
            data[name] = values
            if regions is None:
                regions = {i: c for i, c in enumerate(x.columns)}
        
    return data, embed, regions

In [31]:
cobre_conn, cobre_embed, cobre_regions = load_matrices(cobre_aal_path)

In [32]:
len(cobre_conn.keys())

163

In [33]:
cobre_conn.keys() == cobre_embed.keys()

True

In [34]:
cobre_target.shape

(143, 1)

In [35]:
cobre_target.index.nunique()

143

## take a look a few examples from both datasets

In [23]:
c = choice(list(cobre_conn.keys()))

In [24]:
len(set(cobre_conn.keys()))

163

In [25]:
c

'A00028052'

In [26]:
cobre_target.index.nunique()

154

In [27]:
cm = cobre_conn[c]
ce = cobre_embed[c]
_y = cobre_target.loc[c].values

In [28]:
# for k in cobre_conn:
#     assert (cobre_conn[k].columns == cobre_embed[k].columns).all()

In [29]:
ce.shape

(150, 116)

## Get PyG.Data from conn matrices

In [36]:
def one_hot_embed(n: int):
    return torch.eye(n)

In [37]:
def square_check(f):
    # decotated function must take a np.ndarray as the first argument
    
    @wraps(f)
    def wrapper(*args, **kwargs):
        m = args[0]
        assert isinstance(m, np.ndarray), 'input matrix must be np.ndarray!'
        assert m.ndim == 2, 'input matrix must be 2d array!'
        assert m.shape[0] == cm.shape[1], 'input matrix must be square!'
        
        return f(*args, **kwargs)

    return wrapper

In [38]:
@square_check
def find_thr(
    cm: np.ndarray,
    k=5,
):
    assert cm.ndim == 2, 'adj matrix must be 2d array!'
    assert cm.shape[0] == cm.shape[1], 'adj matrix must be square!'
    
    n = cm.shape[0]
    abs_cm = np.abs(cm)
    
    # find thr to get the desired k
    # = average number of edges for a node
    vals = np.sort(abs_cm.ravel())
    thr_idx = min(max(0, n**2 - 2*k*n - 1), n**2 - 1)
    thr = vals[thr_idx]
    
    return thr

In [39]:
@square_check
def apply_thr(cm: np.ndarray, thr: float):
    abs_cm = np.abs(cm)
    idx = np.nonzero(abs_cm > thr)
    edge_index = torch.LongTensor(np.stack(idx))
    edge_weights = torch.FloatTensor(cm[idx])
    
    return edge_index, edge_weights

In [40]:
@square_check
def cm_to_edges(cm: np.ndarray):
    """
    Convert CM to (edge_index, edge_weights) of a fully connected weighted graph
    (including self-loops with zero weights)
    return: (edge_index, edge_weights)
    """
    cm = torch.FloatTensor(cm)
    index = (torch.isnan(cm) == 0).nonzero(as_tuple=True)
    edge_attr = torch.abs(cm[index])
    
    return torch.stack(index, dim=0), edge_attr

In [66]:
@square_check
def prepare_pyg_data(
    cm: np.ndarray,
    subj_id: str,
    targets: pd.DataFrame,
) -> Data:
    
    # fully connected graph
    # TODO
    n = cm.shape[0]
    edge_index, edge_attr = cm_to_edges(cm)
    
    # compute initial node embeddings -> just original weights
    x = torch.from_numpy(cm)

    # get labels from DF via subject_id
    y = torch.LongTensor(targets.loc[subj_id].values)
    
    data = Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        x=x,
        num_nodes=n,
        y=y,
        subj_id=subj_id,
    ) 
    #data.validate()
    return data

## form datalist

In [67]:
datalist = []
missed_ids = set()
for subj_id, cm in cobre_conn.items():
    try:
        # try to process a graph
        datalist.append(prepare_pyg_data(cm, subj_id, cobre_target))
    except KeyError:
        missed_ids.add(subj_id)

In [69]:
datalist[0].y

tensor([0])

In [43]:
len(datalist), len(missed_ids)

(142, 21)

## Prepare data

### Define dataset class

In [44]:
class ListDataset(InMemoryDataset):
    def __init__(self, root, data_list):
        # first save 
        self.data_list = data_list
        super().__init__(root=root)
        self.data, self.slices = torch.load(self.processed_paths[0])
    
    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html
        data, slices = self.collate(self.data_list)
        torch.save((data, slices), self.processed_paths[0])

### Set processed path and create a dataset instance

In [45]:
# dir for processed cobre
cobre_aal_processed = cobre_path / 'cobre_aal_processed'

In [46]:
cobre_aal_processed

PosixPath('/home/gl/skoltech/imaging/data/fmri/cobre/cobre_aal_processed')

### Define train and test ListDatasets

In [47]:
#train_lst, test_lst = train_test_split(datalist, test_size=0.2, random_state=1380)

In [48]:
# train_y = [g.y.item() for g in train_lst]
# test_y = [g.y.item() for g in test_lst]

In [49]:
cobre_ds = ListDataset(root=cobre_aal_processed, data_list=datalist)

In [50]:
cobre_ds._indices

In [51]:
len(cobre_ds)

142

In [52]:
# check that data_list is intact
cobre_ds.data_list == datalist

True

In [265]:
train_idx, test_idx = train_test_split(
    np.arange(len(cobre_ds)),
    test_size=0.2,
    random_state=1380,
    stratify=cobre_ds.data.y.numpy(),
)

In [266]:
train_idx.shape, test_idx.shape

((113,), (29,))

In [267]:
train_ds = cobre_ds[train_idx]
test_ds = cobre_ds[test_idx]

In [268]:
test_ds.data.y[test_ds._indices].shape

torch.Size([29])

In [269]:
train_ds.data.y[train_ds._indices].shape

torch.Size([113])

In [282]:
len(train_ds.data.subj_id)

142

#### extract subj_id from each subset

In [304]:
def get_ids(ds):
    return [g.subj_id for g in ds]

In [310]:
train_ids = get_ids(train_ds)
test_ids  = get_ids(test_ds)

In [311]:
len(train_ids), len(test_ids)

(113, 29)

In [312]:
len(set(train_ids)), len(set(test_ids))

(113, 29)

In [313]:
set(train_ids) & set(test_ids)

set()

In [58]:
# pdb.runcall(
#     bgat_conv1.forward, g.x, g.edge_index, g.edge_attr,
# )

------------------

## Run training 

### Define model and training config

In [178]:
model

GAT(
  (activation): ReLU()
  (convs): ModuleList(
    (0): Sequential(
      (0): MPGATConv(116, 4, heads=2)
      (1): Linear(in_features=8, out_features=4, bias=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): MPGATConv(4, 4, heads=2)
      (1): Linear(in_features=8, out_features=64, bias=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Linear(in_features=64, out_features=8, bias=True)
      (4): LeakyReLU(negative_slope=0.2)
      (5): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fcn): Sequential(
    (0): Linear(in_features=928, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=32, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=32, out_features=1, bias=True)
  )
)

### define train function

In [179]:
def train_and_evaluate(model, train_loader, test_loader, optimizer, device, args):
    # set model to train
    model.train()
    
    # init list
    accs, aucs, macros = [], [], []
    epoch_num = args.epochs

    for i in range(epoch_num):
        loss_all = 0
        for data in train_loader:
            data = data.to(device)
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

            #if args.mixup:
            #    data, y_a, y_b, lam = mixup(data)
            
            # zero grads
            optimizer.zero_grad()
            
            # forward
            out = model(x, edge_index, edge_attr, batch )

            if args.mixup:
                loss = mixup_criterion(F.nll_loss, out, y_a, y_b, lam)
            else:
                loss = F.binary_cross_entropy_with_logits(out.squeeze(), data.y.float())

            loss.backward()
            optimizer.step()

            loss_all += loss.item()
        
        epoch_loss = loss_all / len(train_loader.dataset)

        train_micro, train_auc, train_macro = evaluate(model, device, train_loader)
        logging.info(f'(Train) | Epoch={i:03d}, loss={epoch_loss:.4f}, '
                     f'train_micro={(train_micro * 100):.2f}, train_macro={(train_macro * 100):.2f}, '
                     f'train_auc={(train_auc * 100):.2f}')

        if (i + 1) % args.test_interval == 0:
            test_micro, test_auc, test_macro = evaluate(model, device, test_loader)
            accs.append(test_micro)
            aucs.append(test_auc)
            macros.append(test_macro)
            
            text = f'(Train Epoch {i}), test_micro={(test_micro * 100):.2f}, ' \
                   f'test_macro={(test_macro * 100):.2f}, test_auc={(test_auc * 100):.2f}\n'
            logging.info(text)

        if args.enable_nni:
            nni.report_intermediate_result(train_auc)

    accs, aucs, macros = np.sort(np.array(accs)), np.sort(np.array(aucs)), np.sort(np.array(macros))
    
    return accs.mean(), aucs.mean(), macros.mean()

In [180]:
@torch.no_grad()
def evaluate(model, device, loader, thr=0.5,):
    # compute metrics on valid data
    
    model.eval()
    preds, trues, preds_prob = [], [], []

    for data in loader:
        data = data.to(device)
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        c = model(x, edge_index, edge_attr, batch)
        
        # label preds
        preds += (torch.sigmoid(c) > thr).long().detach().cpu().tolist()
        preds_prob += (torch.sigmoid(c)).detach().cpu().tolist()
        trues += data.y.detach().long().cpu().tolist()

    train_auc = metrics.roc_auc_score(trues, preds_prob)

    if np.isnan(train_auc):
        train_auc = 0.5
        
    train_micro = metrics.f1_score(trues, preds, average='micro')
    train_macro = metrics.f1_score(trues, preds, average='macro', labels=[0, 1])

    return train_micro, train_auc, train_macro


### Run cross-validation

In [211]:
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1380)

In [254]:
config = argparse.Namespace(
    # GAT params
    pooling='concat',
    hidden_dim=8,
    num_heads=4,
    n_GNN_layers=2,
    edge_emb_dim=1,
    gat_mp_type='node_concate',
    dropout=0.1,
    # training
    mixup=False,
    enable_nni=False,
    test_interval=1,
    epochs=20,
    lr=0.001,
    weight_decay=1e-4,
)

In [255]:
model = bgb_models.gat.GAT(
    input_dim=116,
    args=config,
    num_nodes=116,
    num_classes=1,
)
model


GAT(
  (activation): ReLU()
  (convs): ModuleList(
    (0): Sequential(
      (0): MPGATConv(116, 8, heads=4)
      (1): Linear(in_features=32, out_features=8, bias=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): MPGATConv(8, 8, heads=4)
      (1): Linear(in_features=32, out_features=64, bias=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Linear(in_features=64, out_features=8, bias=True)
      (4): LeakyReLU(negative_slope=0.2)
      (5): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (fcn): Sequential(
    (0): Linear(in_features=928, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=32, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [256]:
del model

In [294]:
from collections import defaultdict

In [415]:
%%time
device = 'cpu'

accs, aucs, macros = [], [], []
models = []

fold_ids = {}
for fold_i, (train_idx, val_idx) in enumerate(cv.split(np.arange(len(train_ds)), y=train_ds.data.y[train_ds._indices])):
    logging.info(f'Fold {fold_i}')
    
    train_subset = train_ds[train_idx]
    val_subset = train_ds[val_idx]
    fold_ids[fold_i] = {'train': get_ids(train_subset), 'valid': get_ids(val_subset)}
                             
#     # split train into train and valid, create loaders
#     train_loader = DataLoader(dataset=train_ds[train_idx], batch_size=8, shuffle=True)
#     val_loader = DataLoader(dataset=train_ds[val_idx], batch_size=8, shuffle=False)
    
#     # create model instance
#     model = bgb_models.gat.GAT(
#         input_dim=116,
#         args=config,
#         num_nodes=116,
#         num_classes=1,
#     )
#     # set optimizer
#     optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

#     # train and eval
#     test_micro, test_auc, test_macro = train_and_evaluate(
#         model,
#         train_loader,
#         val_loader,
#         optimizer,
#         device=device,
#         args=config,
#     )

#     # evaluate again
#     test_micro, test_auc, test_macro = evaluate(model, device, val_loader)
    
#     # print valid metrics
#     logging.info(f'(Initial Performance Last Epoch) | test_micro={(test_micro * 100):.2f}, '
#                  f'test_macro={(test_macro * 100):.2f}, test_auc={(test_auc * 100):.2f}')

#     # store metrics for the current fold
#     accs.append(test_micro)
#     aucs.append(test_auc)
#     macros.append(test_macro)
#     models.append(model)
#     del model

INFO:root:Fold 0
INFO:root:Fold 1
INFO:root:Fold 2
INFO:root:Fold 3
INFO:root:Fold 4


CPU times: user 9.14 ms, sys: 0 ns, total: 9.14 ms
Wall time: 8.51 ms


In [416]:
fold_ids['test'] = test_ids

In [418]:
import json

In [419]:
with open(cobre_aal_processed / 'cobre_5_fold_split_w_test.json', 'w') as f:
    json.dump(fold_ids, f)

In [420]:
with open(cobre_aal_processed / 'cobre_5_fold_split_w_test.json', 'r') as f:
    _x = json.load(f)

In [428]:
cobre_target.loc[_x['0']['valid']].shape

(23, 1)

In [263]:
np.mean(accs), np.std(accs)

(0.7355731225296444, 0.12122472295151034)

In [259]:
aucs, np.mean(aucs), np.std(aucs)

([0.8484848484848485,
  0.5984848484848485,
  0.8939393939393939,
  0.7833333333333333,
  0.8916666666666666],
 0.8031818181818181,
 0.10992150133374144)

In [260]:
macros

[0.6956521739130435,
 0.518095238095238,
 0.8695652173913043,
 0.7603485838779955,
 0.8181818181818182]

In [412]:
ids_df2['train'].fillna('[]')

0    ['A00015518', 'A00014590', 'A00022810', 'A0003...
1    ['A00015518', 'A00014590', 'A00022810', 'A0003...
2    ['A00015518', 'A00014590', 'A00022810', 'A0003...
3    ['A00015518', 'A00037665', 'A00003150', 'A0001...
4    ['A00014590', 'A00022810', 'A00003150', 'A0001...
5                                                   []
Name: train, dtype: object

In [413]:
eval('[]')

[]

In [325]:
#set(train_ids)

In [414]:
ids_df2['train'].fillna('[]').apply(eval)

0    [A00015518, A00014590, A00022810, A00037665, A...
1    [A00015518, A00014590, A00022810, A00037665, A...
2    [A00015518, A00014590, A00022810, A00037665, A...
3    [A00015518, A00037665, A00003150, A00014830, A...
4    [A00014590, A00022810, A00003150, A00014830, A...
5                                                   []
Name: train, dtype: object

----------

----------

----------

----------

----------

----------

----------

## Visualize via networkx

In [67]:
g = datalist[14]

In [68]:
g

Data(x=[116, 116], edge_index=[2, 13456], y=[1], edge_weight=[13456], num_nodes=116, subj_id='A00031764')

In [79]:
g_nx = pyg.utils.to_networkx(g, to_undirected=False)

In [80]:
len(g_nx.nodes), len(g_nx.edges),

(116, 13456)

In [81]:
len(g_nx.nodes), len(g_nx.edges)

(116, 13456)

In [83]:
#nx.cycle_basis(g_nx)

In [84]:
nx.draw_circular(g_nx, with_labels=True)

TypeError: '_AxesStack' object is not callable

<Figure size 640x480 with 0 Axes>

### edge message passing

In [876]:
H = torch.tensor(
    [
        [1., 0.],
        [2., 0.],
        [3., 0.],
    ],
)

In [877]:
W = torch.tensor(
    [
        [0,  5., 7.],
        [5., 0,  11],
        [7,  11, 0,]
    ],
)

In [882]:
M = torch.einsum('jk, ij -> ijk', [H, W])