In [1]:
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import os
import copy
import json
import pickle
import scipy.sparse as sp

In [2]:
root = 'data_used'
label_list = os.listdir(root)
label_list

['ASD', 'HC']

### GAugM

## 1. calculate A_pred by VGAE (variational graph autoencoder)

In [3]:
import os
import sys
import time
import pickle
import warnings
import networkx as nx
import dgl
from dgl import DGLGraph
import torch
from collections import defaultdict
from sklearn.preprocessing import normalize
import time
import copy
import pickle
import warnings
import torch
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, auc


Using backend: pytorch


### utils

In [4]:
def sparse_to_tuple(sparse_mx):
    if not sp.isspmatrix_coo(sparse_mx):
        sparse_mx = sparse_mx.tocoo()
    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
    values = sparse_mx.data
    shape = sparse_mx.shape
    return coords, values, shape


def get_scores(edges_pos, edges_neg, A_pred, adj_label):
    # get logists and labels
    preds = A_pred[edges_pos.T]
    preds_neg = A_pred[edges_neg.T]
    logists = np.hstack([preds, preds_neg])
    labels = np.hstack([np.ones(preds.size(0)), np.zeros(preds_neg.size(0))])
    # logists = A_pred.view(-1)
    # labels = adj_label.to_dense().view(-1)
    # calc scores
    roc_auc = roc_auc_score(labels, logists)
    ap_score = average_precision_score(labels, logists)
    precisions, recalls, thresholds = precision_recall_curve(labels, logists)
    pr_auc = auc(recalls, precisions)
    warnings.simplefilter('ignore', RuntimeWarning)
    f1s = np.nan_to_num(2*precisions*recalls/(precisions+recalls))
    best_comb = np.argmax(f1s)
    f1 = f1s[best_comb]
    pre = precisions[best_comb]
    rec = recalls[best_comb]
    thresh = thresholds[best_comb]
    # calc reconstracted adj_mat and accuracy with the threshold for best f1
    adj_rec = copy.deepcopy(A_pred)
    adj_rec[adj_rec < thresh] = 0
    adj_rec[adj_rec >= thresh] = 1
    labels_all = adj_label.to_dense().view(-1).long()
    preds_all = adj_rec.view(-1).long()
    recon_acc = (preds_all == labels_all).sum().float() / labels_all.size(0)
    results = {'roc': roc_auc,
               'pr': pr_auc,
               'ap': ap_score,
               'pre': pre,
               'rec': rec,
               'f1': f1,
               'acc': recon_acc,
               'adj_recon': adj_rec}
    return results

def train_model(device, epochs, gae, criterion, dl, vgae):
    optimizer = torch.optim.Adam(vgae.parameters(), lr=0.01)
    # weights for log_lik loss
    adj_t = dl.adj_train
    norm_w = adj_t.shape[0]**2 / float((adj_t.shape[0]**2 - adj_t.sum()) * 2)
    pos_weight = torch.FloatTensor([float(adj_t.shape[0]**2 - adj_t.sum()) / adj_t.sum()]).to(device)
    # move input data and label to gpu if needed
    features = dl.features.to(device)
    adj_label = dl.adj_label.to_dense().to(device)

    best_vali_criterion = 0.0
    best_state_dict = None
    vgae.train()
    for epoch in range(epochs):
        t = time.time()
        A_pred = vgae(features)
        optimizer.zero_grad()
        loss = log_lik = norm_w*F.binary_cross_entropy_with_logits(A_pred, adj_label, pos_weight=pos_weight)
        if not gae:
            kl_divergence = 0.5/A_pred.size(0) * (1 + 2*vgae.logstd - vgae.mean**2 - torch.exp(2*vgae.logstd)).sum(1).mean()
            loss -= kl_divergence

        A_pred = torch.sigmoid(A_pred).detach().cpu()
        r = get_scores(dl.val_edges, dl.val_edges_false, A_pred, dl.adj_label)
        print('Epoch{:3}: train_loss: {:.4f} recon_acc: {:.4f} val_roc: {:.4f} val_ap: {:.4f} f1: {:.4f} time: {:.4f}'.format(
            epoch+1, loss.item(), r['acc'], r['roc'], r['ap'], r['f1'], time.time()-t))
        if r[criterion] > best_vali_criterion:
            best_vali_criterion = r[criterion]
            best_state_dict = copy.deepcopy(vgae.state_dict())
            # r_test = get_scores(dl.test_edges, dl.test_edges_false, A_pred, dl.adj_label)
            r_test = r
            print("          test_roc: {:.4f} test_ap: {:.4f} test_f1: {:.4f} test_recon_acc: {:.4f}".format(
                    r_test['roc'], r_test['ap'], r_test['f1'], r_test['acc']))
        loss.backward()
        optimizer.step()

    print("Done! final results: test_roc: {:.4f} test_ap: {:.4f} test_f1: {:.4f} test_recon_acc: {:.4f}".format(
            r_test['roc'], r_test['ap'], r_test['f1'], r_test['acc']))

    vgae.load_state_dict(best_state_dict)
    return vgae

def gen_graphs(filename, gae, device, gen_graphs, dl, vgae):
    adj_orig = dl.adj_orig
    assert adj_orig.diagonal().sum() == 0
    # sp.csr_matrix
    if gae:
        pickle.dump(adj_orig, open(f'graphs/{filename}_graph_0_gae.pkl', 'wb'))
    else:
        pickle.dump(adj_orig, open(f'graphs/{filename}_graph_0.pkl', 'wb'))
    # sp.lil_matrix
    pickle.dump(dl.features_orig, open(f'graphs/{filename}_features.pkl', 'wb'))
    features = dl.features.to(device)
    for i in range(gen_graphs):
        with torch.no_grad():
            A_pred = vgae(features)
        A_pred = torch.sigmoid(A_pred).detach().cpu()
        r = get_scores(dl.val_edges, dl.val_edges_false, A_pred, dl.adj_label)
        adj_recon = A_pred.numpy()
        np.fill_diagonal(adj_recon, 0)
        # np.ndarray
        if gae:
            filename = f'graphs/{filename}_graph_{i+1}_logits_gae.pkl'
        else:
            filename = f'graphs/{filename}_graph_{i+1}_logits.pkl'
        pickle.dump(adj_recon, open(filename, 'wb'))


### 2. VGAE

In [5]:

class VGAE(nn.Module):
    def __init__(self, adj, dim_in, dim_h, dim_z, gae):
        super(VGAE,self).__init__()
        self.dim_z = dim_z
        self.gae = gae
        self.base_gcn = GraphConvSparse(dim_in, dim_h, adj)
        self.gcn_mean = GraphConvSparse(dim_h, dim_z, adj, activation=False)
        self.gcn_logstd = GraphConvSparse(dim_h, dim_z, adj, activation=False)

    def encode(self, X):
        hidden = self.base_gcn(X)
        self.mean = self.gcn_mean(hidden)
        if self.gae:
            # graph auto-encoder
            return self.mean
        else:
            # variational graph auto-encoder
            self.logstd = self.gcn_logstd(hidden)
            gaussian_noise = torch.randn_like(self.mean)
            sampled_z = gaussian_noise*torch.exp(self.logstd) + self.mean
            return sampled_z

    def decode(self, Z):
        A_pred = Z @ Z.T
        return A_pred

    def forward(self, X):
        Z = self.encode(X)
        A_pred = self.decode(Z)
        return A_pred

class GraphConvSparse(nn.Module):
    def __init__(self, input_dim, output_dim, adj, activation=True):
        super(GraphConvSparse, self).__init__()
        self.weight = self.glorot_init(input_dim, output_dim)
        self.adj = adj
        self.activation = activation

    def glorot_init(self, input_dim, output_dim):
        init_range = np.sqrt(6.0/(input_dim + output_dim))
        initial = torch.rand(input_dim, output_dim)*2*init_range - init_range
        return nn.Parameter(initial)

    def forward(self, inputs):
        x = inputs @ self.weight
        x = self.adj @ x
        if self.activation:
            return F.elu(x)
        else:
            return x

In [6]:
def sample_graph_det(adj_orig, A_pred, remove_pct, add_pct, ratio):
    if remove_pct == 0 and add_pct == 0:
        return copy.deepcopy(adj_orig)
    orig_upper = sp.triu(adj_orig, 1)
    n_edges = orig_upper.nnz
    print('n_edges: ', n_edges)
    edges = np.asarray(orig_upper.nonzero()).T
    ratio = ratio
    if remove_pct:
        n_remove = int(n_edges * remove_pct / ratio)
        pos_probs = A_pred[edges.T[0], edges.T[1]]
        e_index_2b_remove = np.argpartition(pos_probs, n_remove)[:n_remove]
        mask = np.ones(len(edges), dtype=bool)
        mask[e_index_2b_remove] = False
        edges_pred = edges[mask]
    else:
        edges_pred = edges

    if add_pct:
        n_add = int(n_edges * add_pct / ratio)
        # deep copy to avoid modifying A_pred
        A_probs = np.array(A_pred)
        # make the probabilities of the lower half to be zero (including diagonal)
        A_probs[np.tril_indices(A_probs.shape[0])] = 0
        # make the probabilities of existing edges to be zero
        A_probs[edges.T[0], edges.T[1]] = 0
        all_probs = A_probs.reshape(-1)
        e_index_2b_add = np.argpartition(all_probs, -n_add)[-n_add:]
        new_edges = []
        for index in e_index_2b_add:
            i = int(index / A_probs.shape[0])
            j = index % A_probs.shape[0]
            new_edges.append([i, j])
        edges_pred = np.concatenate((edges_pred, new_edges), axis=0)
    adj_pred = sp.csr_matrix((np.ones(len(edges_pred)), edges_pred.T), shape=adj_orig.shape)
    adj_pred = adj_pred + adj_pred.T
    return adj_pred

In [7]:
class DataLoader():
    def __init__(self, adj_orig, features_orig):
        # self.args = args
        self.load_data(adj_orig, features_orig)
        self.mask_test_edges(0.05, 0.1, True)
        self.normalize_adj()
        self.to_pyt_sp()

    def load_data(self, adj, features):
        if adj.diagonal().sum() > 0:
            adj = sp.coo_matrix(adj)
            adj.setdiag(0)
            adj.eliminate_zeros()
            adj = sp.csr_matrix(adj)
        features = features
        if isinstance(features, torch.Tensor):
            features = features.numpy()
        features = sp.csr_matrix(features)
        self.adj_orig = adj
        self.features_orig = normalize(features, norm='l1', axis=1)

    def mask_test_edges(self, val_frac, test_frac, no_mask=False):
        adj = self.adj_orig
        assert adj.diagonal().sum() == 0

        adj_triu = sp.triu(adj)
        edges = sparse_to_tuple(adj_triu)[0]
        edges_all = sparse_to_tuple(adj)[0]
        num_test = int(np.floor(edges.shape[0] * test_frac))
        num_val = int(np.floor(edges.shape[0] * val_frac))

        all_edge_idx = list(range(edges.shape[0]))
        np.random.shuffle(all_edge_idx)
        val_edge_idx = all_edge_idx[:num_val]
        test_edge_idx = all_edge_idx[num_val:(num_val + num_test)]
        test_edges = edges[test_edge_idx]
        val_edges = edges[val_edge_idx]
        if no_mask:
            train_edges = edges
        else:
            train_edges = np.delete(edges, np.hstack([test_edge_idx, val_edge_idx]), axis=0)

        def ismember(a, b, tol=5):
            rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
            return np.any(rows_close)

        test_edges_false = []
        while len(test_edges_false) < len(test_edges):
            idx_i = np.random.randint(0, adj.shape[0])
            idx_j = np.random.randint(0, adj.shape[0])
            if idx_i == idx_j:
                continue
            if ismember([idx_i, idx_j], edges_all):
                continue
            if test_edges_false:
                if ismember([idx_j, idx_i], np.array(test_edges_false)):
                    continue
                if ismember([idx_i, idx_j], np.array(test_edges_false)):
                    continue
            test_edges_false.append([idx_i, idx_j])

        val_edges_false = []
        while len(val_edges_false) < len(val_edges):
            idx_i = np.random.randint(0, adj.shape[0])
            idx_j = np.random.randint(0, adj.shape[0])
            if idx_i == idx_j:
                continue
            if ismember([idx_i, idx_j], train_edges):
                continue
            if ismember([idx_j, idx_i], train_edges):
                continue
            if ismember([idx_i, idx_j], val_edges):
                continue
            if ismember([idx_j, idx_i], val_edges):
                continue
            if val_edges_false:
                if ismember([idx_j, idx_i], np.array(val_edges_false)):
                    continue
                if ismember([idx_i, idx_j], np.array(val_edges_false)):
                    continue
            val_edges_false.append([idx_i, idx_j])

        adj_train = sp.csr_matrix((np.ones(train_edges.shape[0]), (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape)
        self.adj_train = adj_train + adj_train.T
        self.adj_label = adj_train + sp.eye(adj_train.shape[0])

        # NOTE: these edge lists only contain single direction of edge!
        self.val_edges = val_edges
        self.val_edges_false = np.asarray(val_edges_false)
        self.test_edges = test_edges
        self.test_edges_false = np.asarray(test_edges_false)

    def normalize_adj(self):
        adj_ = sp.coo_matrix(self.adj_train)
        adj_.setdiag(1)
        rowsum = np.array(adj_.sum(1))
        degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
        self.adj_norm = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()

    def to_pyt_sp(self):
        adj_norm_tuple = sparse_to_tuple(self.adj_norm)
        adj_label_tuple = sparse_to_tuple(self.adj_label)
        features_tuple = sparse_to_tuple(self.features_orig)
        self.adj_norm = torch.sparse.FloatTensor(torch.LongTensor(adj_norm_tuple[0].T),
                                                torch.FloatTensor(adj_norm_tuple[1]),
                                                torch.Size(adj_norm_tuple[2]))
        self.adj_label = torch.sparse.FloatTensor(torch.LongTensor(adj_label_tuple[0].T),
                                                torch.FloatTensor(adj_label_tuple[1]),
                                                torch.Size(adj_label_tuple[2]))
        self.features = torch.sparse.FloatTensor(torch.LongTensor(features_tuple[0].T),
                                                torch.FloatTensor(features_tuple[1]),
                                                torch.Size(features_tuple[2]))

In [8]:
# from torch_geometric.utils import dense_to_sparse
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#
# class_dict = {
#             "HC": 0,
#             "ASD": 1,
#             }
# threshold = 0.3
# data_processed = []
# for label_files in label_list:
#     label = torch.LongTensor([class_dict[label_files]])
#     filelist = os.listdir(os.path.join(root, label_files))
#     for files in filelist:
#         subj_fc_dir = os.path.join(root, label_files, files)
#         subj_mat_fc=np.loadtxt(subj_fc_dir)[:176,:90]
#         print("reading data " + subj_fc_dir)
#         subj_mat_fc_adj = np.corrcoef(np.transpose(subj_mat_fc))
#         subj_mat_fc_adj = subj_mat_fc_adj - np.diag(np.diag(subj_mat_fc_adj))
#         A_pred = copy.deepcopy(subj_mat_fc_adj)
#         #take the upper triangle and compute the threshold
#         subj_fc_adj_up=subj_mat_fc_adj[np.triu_indices(90,k=1)]
#         subj_fc_adj_list = subj_fc_adj_up.reshape((-1))
#         thindex = int(threshold * subj_fc_adj_list.shape[0])
#         thremax = subj_fc_adj_list[subj_fc_adj_list.argsort()[-1 * thindex-1]]#
#         #avoiding Nan
#         subj_fc_adj_t = np.zeros((90, 90))
#         subj_fc_adj_t[subj_mat_fc_adj > thremax] = 1
#         subj_mat_fc_adj=subj_fc_adj_t
#
#
#         adj_orig = sp.csr_matrix(subj_mat_fc_adj, shape=subj_fc_adj_t.shape)
#
#         subj_mat_fc_adj = np.array(adj_orig.todense())
#         fcedge_index, _ = dense_to_sparse(torch.from_numpy(subj_mat_fc_adj.astype(np.int16)))
#
#         subj_mat_fc_list = subj_mat_fc.reshape((-1))
#         subj_mat_fc_new = (subj_mat_fc - min(subj_mat_fc_list)) / (
#                 max(subj_mat_fc_list) - min(subj_mat_fc_list))
#
#         subj_mat_fc_new = np.transpose(subj_mat_fc_new)
#
#         rowsum = np.array(subj_mat_fc_adj.sum(1))
#         N = np.diag(rowsum)
#         degree_C_BOLD=np.concatenate((N,subj_mat_fc_new),1)
#
#         ############################
#         dl = DataLoader(adj_orig, degree_C_BOLD)
#         gae = True
#         vgae = VGAE(dl.adj_norm.to(device), dl.features.size(1), 32, 16,False) # 先暂时用False
#         vgae.to(device)
#         vgae = train_model(device=device, epochs=200, gae=gae, dl=dl, vgae=vgae, criterion='acc')
#         gengraphs = 1
#         if gengraphs > 0:
#             gen_graphs(filename=subj_fc_dir, device=device, dl=dl, gae=gae, vgae=vgae, gen_graphs=gengraphs)


## data reading

In [9]:
from torch_geometric.utils import dense_to_sparse
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_dict = {
            "HC": 0,
            "ASD": 1,
            }
threshold = 0.3
data_processed = []
for label_files in label_list:
    label = torch.LongTensor([class_dict[label_files]])
    list = os.listdir(os.path.join(root, label_files))
    for files in list:
        subj_fc_dir = os.path.join(root, label_files, files)
        subj_mat_fc=np.loadtxt(subj_fc_dir)[:176,:90]
        print("reading data " + subj_fc_dir)
        A_pred = np.array(pickle.load(open(f'graphs/{subj_fc_dir}_graph_1_logits_gae.pkl', 'rb')))
        adj_orig = pickle.load(open(f'graphs/{subj_fc_dir}_graph_0_gae.pkl', 'rb')).A
        fcedge_index, _ = dense_to_sparse(torch.from_numpy(adj_orig.astype(np.int16)))

        features = pickle.load(open(f'graphs/{subj_fc_dir}_features.pkl', 'rb')).A


        # Data Augmentation
        # ###############################
        adj_pred_1 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=1, add_pct=1, ratio=50).todense())
        fcedge_index_1, _ = dense_to_sparse(torch.from_numpy(adj_pred_1.astype(np.int16)))
        adj_pred_2 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=0, add_pct=1, ratio=50).todense())
        fcedge_index_2, _ = dense_to_sparse(torch.from_numpy(adj_pred_2.astype(np.int16)))
        adj_pred_3 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=1, add_pct=0, ratio=50).todense())
        fcedge_index_3, _ = dense_to_sparse(torch.from_numpy(adj_pred_3.astype(np.int16)))
        # adj_pred_4 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=1, add_pct=1, ratio=30).todense())
        # fcedge_index_4, _ = dense_to_sparse(torch.from_numpy(adj_pred_4.astype(np.int16)))
        # adj_pred_5 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=0, add_pct=1, ratio=30).todense())
        # fcedge_index_5, _ = dense_to_sparse(torch.from_numpy(adj_pred_5.astype(np.int16)))
        # adj_pred_6 = np.array(sample_graph_det(adj_orig, A_pred, remove_pct=1, add_pct=0, ratio=30).todense())
        # fcedge_index_6, _ = dense_to_sparse(torch.from_numpy(adj_pred_6.astype(np.int16)))
        # # adj_pred_1 = np.array(sample_graph_det(subj_mat_fc_adj, A_pred, remove_pct=1, add_pct=1, ratio=100).todense())
        # fcedge_index_1, _ = dense_to_sparse(torch.from_numpy(adj_pred_1.astype(np.int16)))
        # adj_pred_2 = np.array(sample_graph_det(subj_mat_fc_adj, A_pred, remove_pct=0, add_pct=1, ratio=50).todense())
        # fcedge_index_2, _ = dense_to_sparse(torch.from_numpy( adj_pred_2.astype(np.int16)))
        # adj_pred_3= np.array(sample_graph_det(subj_mat_fc_adj, A_pred, remove_pct=1, add_pct=0, ratio=50).todense())
        # fcedge_index_3, _ = dense_to_sparse(torch.from_numpy( adj_pred_3.astype(np.int16)))
        # ##################################################

        data_processed.append([
            Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index, y=torch.tensor(label)),
            Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_1, y=torch.tensor(label)),
            Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_2, y=torch.tensor(label)),
            Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_3, y=torch.tensor(label)),
            # Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_4, y=torch.tensor(label)),
            # Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_5, y=torch.tensor(label)),
            # Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_6, y=torch.tensor(label)),
        ])

reading data data_used\ASD\CMU_a_0050649_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050686_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050689_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050690_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050693_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050694_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050695_rois_aal.1D


  Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index, y=torch.tensor(label)),
  Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_1, y=torch.tensor(label)),
  Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_2, y=torch.tensor(label)),
  Data(x=torch.from_numpy(features).float(), edge_index=fcedge_index_3, y=torch.tensor(label)),


n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050696_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050697_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050700_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050702_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050704_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050705_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050708_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_1_0050711_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_2_0050743_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Leuven_2_0050745_ro

n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051028_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051029_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051032_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051033_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051034_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\NYU_0051035_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Olin_0050118_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Olin_0050119_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Olin_0050121_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Olin_0050122_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edge

n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050277_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050278_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050280_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050282_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050283_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050284_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050285_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050287_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050289_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\UM_1_0050290_rois_aal.1D
n_edges:  1201
n_edges:  1201


n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050617_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050619_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050620_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050621_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050623_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050624_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050625_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050626_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050627_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\ASD\Yale_0050628_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201


n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051084_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051085_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051086_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051087_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051088_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051089_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051090_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051091_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051093_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\NYU_0051094_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
readi

n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SBL_0051570_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050193_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050194_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050196_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050197_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050198_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050199_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050200_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050201_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\SDSU_0050203_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1

n_edges:  1201
reading data data_used\HC\UM_2_0050391_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050414_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050415_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050416_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050417_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050418_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050419_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050421_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050424_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC\UM_2_0050425_rois_aal.1D
n_edges:  1201
n_edges:  1201
n_edges:  1201
reading data data_used\HC

In [10]:
import random
random.seed(1234)
random.shuffle(data_processed)
data_processed

[[Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2450], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2354], y=[1])],
 [Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2450], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2354], y=[1])],
 [Data(x=[90, 266], edge_index=[2, 2403], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2450], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2354], y=[1])],
 [Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2450], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2354], y=[1])],
 [Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2402], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2450], y=[1]),
  Data(x=[90, 266], edge_index=[2, 2354], y=[1

In [11]:
train_d = data_processed[:500]
val_d = data_processed[500:]
test_d = data_processed[500:]
train_dataset = []
test_dataset = []
val_dataset = []
for tl in train_d:
    for t in tl:
        train_dataset.append(t)
for testl in test_d:
    test_dataset.append(testl[0])
for vall in val_d:
    val_dataset.append(vall[0])

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')


Number of training graphs: 2000
Number of validation graphs: 118
Number of test graphs: 118


In [12]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

validation_loader = DataLoader(val_dataset, batch_size=128, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=128, shuffle=True)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307318], y=[128], batch=[11520], ptr=[129])

Step 2:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307175], y=[128], batch=[11520], ptr=[129])

Step 3:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307173], y=[128], batch=[11520], ptr=[129])

Step 4:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307938], y=[128], batch=[11520], ptr=[129])

Step 5:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307127], y=[128], batch=[11520], ptr=[129])

Step 6:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 308034], y=[128], batch=[11520], ptr=[129])

Step 7:
Number of graphs in the current batch: 128
DataBatch(x=[11520, 266], edge_index=[2, 307314], y=[128], batch=[11520], ptr=[129])

Step 8:
Number of graphs in the current b

In [13]:
# allData = DataLoader(data_processed, batch_size=800, shuffle=True)
# data_all = None
# for step, data in enumerate(allData):
#     print(f'Step {step + 1}:')
#     print('=======')
#     print(f'Number of graphs in the current batch: {data.num_graphs}')
#     print(data)
#     data_all = data
#     print()

In [14]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
class GAT1(torch.nn.Module):
    def __init__(self):
        super(GAT1, self).__init__()
        self.hid = 16
        self.in_head = 16
        self.out_head = 4
        self.conv1 = GATConv(266, self.hid, heads=self.in_head, dropout=0.5)
        # self.conv2 = GATConv(self.hid*self.in_head, self.hid, heads=self.in_head, dropout=0.5)
        self.conv3 = GATConv(self.hid*self.in_head, 4, concat=False,
                             heads=self.out_head, dropout=0.5)
        self.lin = Linear(4, 2)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.tanh(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # x = self.conv2(x, edge_index)
        # x = F.tanh(x)
        # x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_max_pool(x, batch)

        # 3. Apply a final classifier
        x = self.lin(x)

        return x


In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model_GAT = GAT1().to(device)
optimizer = torch.optim.Adam(model_GAT.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model_GAT.train()
    for data in train_loader:
        data.to(device)
        optimizer.zero_grad()
        out = model_GAT(data)
        # print('out: ', out.shape)
        # print('data.y: ', data.y.shape)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
    return model_GAT


def test(loader):
    model_GAT.eval()

    correct = 0
    for data in loader:
        data.to(device)
        out = model_GAT(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    return correct / len(loader.dataset)

def mymain_GAT():
    best_val_acc = 0
    for epoch in range(1, 200):
        model_temp = train()
        # print('train: ')
        train_acc = test(train_loader)
        # print('test: ')
        validation_acc = test(validation_loader)

        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Validation Acc: {validation_acc:.4f}')

        if validation_acc > best_val_acc:
            best_val_acc = validation_acc
            torch.save(model_temp, 'BestGAT.pt')

    model_final = torch.load('BestGAT.pt')
    model_final.eval()

    correct = 0
    for data in test_loader:
        data.to(device)
        out = model_final(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    test_acc = correct / len(test_loader.dataset)
    print(f'Final test acc: {test_acc:.4f}')



cuda


In [16]:
mymain_GAT()



Epoch: 001, Train Acc: 0.5685, Validation Acc: 0.4746
Epoch: 002, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 003, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 004, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 005, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 006, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 007, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 008, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 009, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 010, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 011, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 012, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 013, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 014, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 015, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 016, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 017, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 018, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 019, Train Acc: 0.544

Epoch: 153, Train Acc: 0.8395, Validation Acc: 0.6356
Epoch: 154, Train Acc: 0.8340, Validation Acc: 0.6356
Epoch: 155, Train Acc: 0.8380, Validation Acc: 0.6610
Epoch: 156, Train Acc: 0.8400, Validation Acc: 0.6525
Epoch: 157, Train Acc: 0.8400, Validation Acc: 0.6441
Epoch: 158, Train Acc: 0.8435, Validation Acc: 0.6610
Epoch: 159, Train Acc: 0.8440, Validation Acc: 0.6441
Epoch: 160, Train Acc: 0.8440, Validation Acc: 0.6441
Epoch: 161, Train Acc: 0.8295, Validation Acc: 0.6186
Epoch: 162, Train Acc: 0.8460, Validation Acc: 0.6610
Epoch: 163, Train Acc: 0.8450, Validation Acc: 0.6356
Epoch: 164, Train Acc: 0.8390, Validation Acc: 0.6271
Epoch: 165, Train Acc: 0.8430, Validation Acc: 0.6271
Epoch: 166, Train Acc: 0.8475, Validation Acc: 0.6441
Epoch: 167, Train Acc: 0.8385, Validation Acc: 0.6271
Epoch: 168, Train Acc: 0.8485, Validation Acc: 0.6441
Epoch: 169, Train Acc: 0.8375, Validation Acc: 0.6271
Epoch: 170, Train Acc: 0.8490, Validation Acc: 0.6271
Epoch: 171, Train Acc: 0.850

### Classify by GraphSage

In [17]:
from torch_geometric.datasets import Planetoid
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool, global_max_pool
from copy import deepcopy
class GraphSage_Net(torch.nn.Module):
    def __init__(self, features, classes):
        super(GraphSage_Net, self).__init__()
        num_layers = 3
        dim_embedding = 32
        self.aggregation = 'max'

        if self.aggregation == 'max':
            self.fc_max = nn.Linear(dim_embedding, dim_embedding)

        self.layers = nn.ModuleList([])
        for i in range(num_layers):
            dim_input = features if i == 0 else dim_embedding
            conv = SAGEConv(dim_input, dim_embedding)
            conv.aggr = self.aggregation
            self.layers.append(conv)

        self.fc1 = nn.Linear(num_layers * dim_embedding, dim_embedding)
        self.fc2 = nn.Linear(dim_embedding, classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x_all = []
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if self.aggregation == 'max':
                x = torch.tanh(self.fc_max(x))
            x = F.dropout(x, p=0.5, training=self.training)
            x_all.append(x)

        x = torch.cat(x_all, dim=1)
        x = global_max_pool(x, batch)

        x = F.tanh(self.fc1(x))
        x = self.fc2(x)
        return x

In [18]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_GraphSage = GraphSage_Net(266, 2).to(device)
# data = dataset[0]
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.Adam([
# 	dict(params=model.conv1.parameters(), weight_decay=5e-4),
#     dict(params=model.conv2.parameters(), weight_decay=0)
#     ], lr=0.01)
optimizer = torch.optim.Adam(model_GraphSage.parameters(),
                             lr=0.01, weight_decay=5e-4)

criterion = torch.nn.CrossEntropyLoss()

def Sage_train():
    model_GraphSage.train()
    for data in train_loader:
        data.to(device)
        optimizer.zero_grad()
        out = model_GraphSage(data)
        # print('out: ', out.shape)
        # print('data.y: ', data.y.shape)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
    return model_GraphSage


def Sage_test(loader):
    model_GraphSage.eval()

    correct = 0
    for data in loader:
        data.to(device)
        out = model_GraphSage(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    return correct / len(loader.dataset)


def mymain_GraphSage():
    best_val_acc = 0
    for epoch in range(1, 50):
        model_temp = Sage_train()
        # print('train: ')
        train_acc = Sage_test(train_loader)
        # print('test: ')
        validation_acc = Sage_test(validation_loader)

        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Validation Acc: {validation_acc:.4f}')

        if validation_acc > best_val_acc:
            best_val_acc = validation_acc
            torch.save(model_temp, 'BestGraphSage.pt')

    model_final = torch.load('BestGraphSage.pt')
    model_final.eval()

    correct = 0
    for data in test_loader:
        data.to(device)
        out = model_final(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    test_acc = correct / len(test_loader.dataset)
    print(f'Final test acc: {test_acc:.4f}')

In [19]:
mymain_GraphSage()

Epoch: 001, Train Acc: 0.6245, Validation Acc: 0.5424
Epoch: 002, Train Acc: 0.5440, Validation Acc: 0.4746
Epoch: 003, Train Acc: 0.7535, Validation Acc: 0.6017
Epoch: 004, Train Acc: 0.7855, Validation Acc: 0.6017
Epoch: 005, Train Acc: 0.8315, Validation Acc: 0.5847
Epoch: 006, Train Acc: 0.8450, Validation Acc: 0.6356
Epoch: 007, Train Acc: 0.8820, Validation Acc: 0.5932
Epoch: 008, Train Acc: 0.9150, Validation Acc: 0.6864
Epoch: 009, Train Acc: 0.9145, Validation Acc: 0.5847
Epoch: 010, Train Acc: 0.7540, Validation Acc: 0.6017
Epoch: 011, Train Acc: 0.9255, Validation Acc: 0.6610
Epoch: 012, Train Acc: 0.9210, Validation Acc: 0.6441
Epoch: 013, Train Acc: 0.9725, Validation Acc: 0.6356
Epoch: 014, Train Acc: 0.9485, Validation Acc: 0.5763
Epoch: 015, Train Acc: 0.9825, Validation Acc: 0.5763
Epoch: 016, Train Acc: 0.9905, Validation Acc: 0.5763
Epoch: 017, Train Acc: 0.9900, Validation Acc: 0.5678
Epoch: 018, Train Acc: 0.9805, Validation Acc: 0.5932
Epoch: 019, Train Acc: 0.984

In [20]:
def _make_block_diag(mats, mat_sizes):
    block_diag = torch.zeros(sum(mat_sizes), sum(mat_sizes))

    for i, (mat, size) in enumerate(zip(mats, mat_sizes)):
        cum_size = sum(mat_sizes[:i])
        block_diag[cum_size:cum_size+size,cum_size:cum_size+size] = mat

    return 

In [22]:
import torch
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import degree, dense_to_sparse
from torch_geometric.nn import ECConv
from torch_scatter import scatter_add


class ECCLayer(nn.Module):
    def __init__(self, dim_input, dim_embedding, dropout=0.):
        super().__init__()

        fnet1 = nn.Sequential(nn.Linear(1, 16),
                              nn.ReLU(),
                              nn.Linear(16, dim_embedding * dim_input))

        fnet2 = nn.Sequential(nn.Linear(1, 16),
                              nn.ReLU(),
                              nn.Linear(16, dim_embedding * dim_embedding))

        fnet3 = nn.Sequential(nn.Linear(1, 16),
                              nn.ReLU(),
                              nn.Linear(16, dim_embedding * dim_embedding))

        self.conv1 = ECConv(dim_input, dim_embedding, nn=fnet1)
        self.conv2 = ECConv(dim_embedding, dim_embedding, nn=fnet2)
        self.conv3 = ECConv(dim_embedding, dim_embedding, nn=fnet3)

        self.bn1 = nn.BatchNorm1d(dim_embedding)
        self.bn2 = nn.BatchNorm1d(dim_embedding)
        self.bn3 = nn.BatchNorm1d(dim_embedding)

        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(self.bn1(x), p=self.dropout, training=self.training)

        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x = F.dropout(self.bn2(x), p=self.dropout, training=self.training)

        x = F.relu(self.conv3(x, edge_index, edge_attr))
        x = F.dropout(self.bn3(x), p=self.dropout, training=self.training)

        return x


class ECC(nn.Module):
    """
    Uses fixed architecture.
    IMPORTANT NOTE: we will consider dataset which do not have edge labels.
    Therefore, we avoid learning the function that associates a weight matrix
    to an edge specific weight.
    """

    def __init__(self, dim_features, dim_target):
        super().__init__()
        self.dropout = 0.5
        self.dropout_final = 0.5
        self.num_layers = 3
        dim_embedding = 32

        self.layers = nn.ModuleList([])
        for i in range(self.num_layers):
            dim_input = dim_features if i == 0 else dim_embedding
            layer = ECCLayer(dim_input, dim_embedding, dropout=self.dropout)
            self.layers.append(layer)

        fnet = nn.Sequential(nn.Linear(1, 16),
                             nn.ReLU(),
                             nn.Linear(16, dim_embedding * dim_embedding))

        self.final_conv = ECConv(dim_embedding, dim_embedding, nn=fnet)
        self.final_conv_bn = nn.BatchNorm1d(dim_embedding)

        self.fc1 = nn.Linear(dim_embedding, dim_embedding)
        self.fc2 = nn.Linear(dim_embedding, dim_target)

    def make_block_diag(self, matrix_list):
        mat_sizes = [m.size(0) for m in matrix_list]
        return _make_block_diag(matrix_list, mat_sizes)

    def get_ecc_conv_parameters(self, data, layer_no):
        v_plus_list, laplacians = data.v_plus, data.laplacians

        # print([v_plus[layer_no] for v_plus in v_plus_list])
        v_plus_batch = torch.cat([v_plus[layer_no] for v_plus in v_plus_list], dim=0)

        laplacian_layer_list = [laplacians[i][layer_no] for i in range(len(laplacians))]
        laplacian_block_diagonal = self.make_block_diag(laplacian_layer_list)
        
        if self.config.dataset.name == 'DD':
            laplacian_block_diagonal[laplacian_block_diagonal<1e-4] = 0

        # First layer
        lap_edge_idx, lap_edge_weights = dense_to_sparse(laplacian_block_diagonal)

        return lap_edge_idx, lap_edge_weights, (v_plus_batch == 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for i, layer in enumerate(self.layers):
            # TODO should lap_edge_index[0] be equal to edge_idx?
            lap_edge_idx, lap_edge_weights, v_plus_batch = self.get_ecc_conv_parameters(data, layer_no=i)
            edge_index = lap_edge_idx if i != 0 else edge_index
            edge_weight = lap_edge_weights if i != 0 else x.new_ones((edge_index.size(1), ))

            edge_index = edge_index.to(self.config.device)
            edge_weight = edge_weight.to(self.config.device)

            # apply convolutional layer
            x = layer(x, edge_index, edge_weight)

            # pooling
            x = x[v_plus_batch]
            batch = batch[v_plus_batch]

        # final_convolution
        lap_edge_idx, lap_edge_weight, v_plus_batch = self.get_ecc_conv_parameters(data, layer_no=self.num_layers)

        lap_edge_idx = lap_edge_idx.to(self.config.device)
        lap_edge_weight = lap_edge_weight.to(self.config.device)

        x = F.relu(self.final_conv(x, lap_edge_idx, lap_edge_weight))
        x = F.dropout(self.final_conv_bn(x), p=self.dropout, training=self.training)

        # TODO: is the following line needed before global pooling?
        # batch = batch[v_plus_batch]

        graph_emb = global_mean_pool(x, batch)

        x = F.relu(self.fc1(graph_emb))
        x = F.dropout(x, p=self.dropout_final, training=self.training)

        # No ReLU specified here todo check with source code (code is not so clear)
        x = self.fc2(x)

        return x

In [23]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_ECC = ECC(266, 2).to(device)
# data = dataset[0]
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# optimizer = torch.optim.Adam([
# 	dict(params=model.conv1.parameters(), weight_decay=5e-4),
#     dict(params=model.conv2.parameters(), weight_decay=0)
#     ], lr=0.01)
optimizer = torch.optim.Adam(model_ECC.parameters(),
                             lr=0.01, weight_decay=5e-4)

criterion = torch.nn.CrossEntropyLoss()

def ECC_train():
    model_ECC.train()
    for data in train_loader:
        data.to(device)
        optimizer.zero_grad()
        out = model_ECC(data)
        # print('out: ', out.shape)
        # print('data.y: ', data.y.shape)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
    return model_ECC


def ECC_test(loader):
    model_ECC.eval()

    correct = 0
    for data in loader:
        data.to(device)
        out = model_ECC(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    return correct / len(loader.dataset)


def mymain_ECC():
    best_val_acc = 0
    for epoch in range(1, 100):
        model_temp = ECC_train()
        # print('train: ')
        train_acc = ECC_test(train_loader)
        # print('test: ')
        validation_acc = ECC_test(validation_loader)

        print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Validation Acc: {validation_acc:.4f}')

        if validation_acc > best_val_acc:
            torch.save(model_temp, 'BestECC.pt')

    model_final = torch.load('BestECC.pt')
    model_final.eval()

    correct = 0
    for data in test_loader:
        data.to(device)
        out = model_final(data)
        pred = out.argmax(dim=1)
        # print('pred: ', pred)
        # print('data.y: ', data.y)
        correct += int((pred == data.y).sum())

    test_acc = correct / len(test_loader.dataset)
    print(f'Final test acc: {test_acc:.4f}')

TypeError: __init__() missing 1 required positional argument: 'config'

In [None]:
mymain_ECC()