In [None]:
import pickle
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statistics as stat
import time
%matplotlib inline

from sklearn.metrics import accuracy_score

from nilearn.connectome import ConnectivityMeasure

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import f1_score

import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import Data, DataLoader
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool

from sklearn import svm

import pickle5 as pickle
import os

from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor, PairTensor

# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch import softmax

In [1]:
import torch
import torch.nn as nn

In [None]:
def load_obj(path):
    """Load and return the object pickled at ``path + '.pkl'``.

    Args:
        path: File path without the ``.pkl`` extension; the extension is
            appended here.

    Returns:
        The object reconstructed by ``pickle.load``.
    """
    pickle_path = path + '.pkl'
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    with open(pickle_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
# Per-group time-series objects (AD / CN presumably denote the two cohorts — TODO confirm).
AD_dict, CN_dict = (
    load_obj('AAL_data/timeseries/AD'),
    load_obj('AAL_data/timeseries/CN'),
)

# Train / validation / test objects for the AD group.
AD_train, AD_val, AD_test = (
    load_obj('AAL_data/AD_train_full'),
    load_obj('AAL_data/AD_val_full'),
    load_obj('AAL_data/AD_test_full'),
)

# Train / validation / test objects for the CN group.
CN_train, CN_val, CN_test = (
    load_obj('AAL_data/CN_train_full'),
    load_obj('AAL_data/CN_val_full'),
    load_obj('AAL_data/CN_test_full'),
)

In [None]:
def get_correlation_matrix(timeseries,msr):
    """Compute a connectivity matrix for a single time series.

    Args:
        timeseries: Time-series array for one subject, passed to
            ``ConnectivityMeasure.fit_transform`` as a one-element list.
        msr: Value forwarded as the ``kind`` argument of ``ConnectivityMeasure``.

    Returns:
        The first (and only) matrix produced by ``fit_transform``.
    """
    estimator = ConnectivityMeasure(kind=msr)
    # fit_transform works on a list of subjects: wrap the single series,
    # then unwrap the single result.
    matrices = estimator.fit_transform([timeseries])
    return matrices[0]

def get_upper_triangular_matrix(matrix):
    """Flatten the strictly-upper triangle of a square matrix into a list.

    Entries are collected row by row, taking only positions with
    column index greater than row index (the diagonal is excluded).

    Args:
        matrix: Square, doubly-indexable container (``matrix[i][j]``).

    Returns:
        A list with the entries above the diagonal.
    """
    size = len(matrix)
    return [
        matrix[row][col]
        for row in range(size)
        for col in range(row + 1, size)
    ]

    
# def get_adj_mat(correlation_matrix, threshold_value, weighted = True):
#     adj_mat = []
#     for i in correlation_matrix:
#         row = []
#         for j in i:
#             if abs(j)>threshold_value:
#                 if not weighted:
#                     row.append(1)
#                 else:
#                     row.append(abs(j))
#             else:
#                 row.append(0)
#         adj_mat.append(row)
#     return adj_mat


def get_adj_mat(correlation_matrix, th_value_p, th_value_n):
    """Threshold a correlation matrix into a weighted adjacency matrix.

    Positive entries are kept unchanged when they exceed ``th_value_p``.
    All other entries (zero or negative) are kept as their absolute value
    when that absolute value exceeds ``th_value_n``. Everything else
    becomes 0. ``connect_isolated_nodes`` is not applied here.

    Args:
        correlation_matrix: Iterable of rows of numeric correlation values.
        th_value_p: Cut-off for positive values.
        th_value_n: Cut-off for the magnitude of non-positive values.

    Returns:
        A list of lists with the same layout as ``correlation_matrix``.
    """
    def _edge_weight(value):
        # Positive and non-positive correlations use separate cut-offs.
        if value > 0:
            return value if value > th_value_p else 0
        magnitude = abs(value)
        return magnitude if magnitude > th_value_n else 0

    return [[_edge_weight(value) for value in row] for row in correlation_matrix]


def connect_isolated_nodes(adj_mat, correlation_matrix):
    """Give every all-zero adjacency row one edge of weight 1.

    The identity matrix is subtracted from ``correlation_matrix`` first.
    For each row of ``adj_mat`` whose entries sum to 0, the column holding
    the largest value of the matching adjusted correlation row is set to 1.

    Args:
        adj_mat: Adjacency matrix as a mutable sequence of rows; modified in place.
        correlation_matrix: Square correlation matrix matching ``adj_mat``.

    Returns:
        The same ``adj_mat`` object, after modification.
    """
    corr_minus_identity = np.array(correlation_matrix) - np.array(np.eye(len(correlation_matrix)))
    corr_rows = [list(corr_row) for corr_row in corr_minus_identity]

    for node, edges in enumerate(adj_mat):
        if sum(edges) != 0:
            continue
        candidates = corr_rows[node]
        strongest = candidates.index(max(candidates))
        adj_mat[node][strongest] = 1
    return adj_mat
    
# def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
#     ad_corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
#     cn_corr_mats = [get_correlation_matrix(ts, measure) for ts in cn_timeseries]

#     ad_upper = [get_upper_triangular_matrix(matrix) for matrix in ad_corr_mats]
#     cn_upper = [get_upper_triangular_matrix(matrix) for matrix in cn_corr_mats]

#     all_correlation_values = ad_upper + cn_upper
#     all_correlation_values = np.array(all_correlation_values).flatten()

#     all_correlation_values = np.array([abs(i) for i in all_correlation_values])
#     all_correlation_values = np.sort(all_correlation_values)[::-1]

#     th_val_index = (len(all_correlation_values)*threshold_percent)//100
#     return all_correlation_values[int(th_val_index)]


def _top_percent_cutoff(magnitudes, percent):
    """Return the value found ``percent`` % of the way down ``magnitudes`` sorted descending."""
    ranked = np.sort(magnitudes)[::-1]
    # Clamp so that percent == 100 selects the smallest value instead of
    # running one past the end of the array.
    index = min(int((len(ranked) * percent) // 100), len(ranked) - 1)
    return ranked[index]


def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
    """Derive separate positive / negative correlation cut-offs from two groups.

    Connectivity matrices are computed for every time series in both groups,
    their strictly-upper-triangular entries are pooled, and entries exactly
    equal to 1 are discarded. The pooled values are split into positive values
    and the absolute values of the remaining (zero or negative) entries. For
    each of the two sets, the returned cut-off is the value located
    ``threshold_percent`` % of the way down that set sorted in descending order.

    The index is computed from the length of each set individually. It was
    previously computed from the length of the combined array, which raised
    ``IndexError`` whenever that index exceeded the size of the smaller set
    and made the percentage relative to the wrong population.

    Args:
        ad_timeseires: Iterable of time series for the first group.
        cn_timeseries: Iterable of time series for the second group.
        measure: Connectivity kind forwarded to ``get_correlation_matrix``.
        threshold_percent: Percentage (0-100) of the strongest values of each
            sign that should lie above the returned cut-off.

    Returns:
        Tuple ``(positive_cutoff, negative_cutoff)``; the negative cut-off is
        expressed as a magnitude.

    Raises:
        IndexError: If there are no positive or no non-positive values.
    """
    corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
    corr_mats += [get_correlation_matrix(ts, measure) for ts in cn_timeseries]

    pooled = np.array([get_upper_triangular_matrix(matrix) for matrix in corr_mats]).flatten()
    pooled = pooled[pooled != 1]

    is_positive = pooled > 0
    positive_values = pooled[is_positive]
    # Everything that is not strictly positive (zeros included) is ranked by magnitude.
    negative_magnitudes = np.abs(pooled[~is_positive])

    return (
        _top_percent_cutoff(positive_values, threshold_percent),
        _top_percent_cutoff(negative_magnitudes, threshold_percent),
    )

def create_graph(timeseries, threshold_p, threshold_n, y, measure='correlation', device=None):
    """Build a torch_geometric graph object from one subject's time series.

    The connectivity matrix of ``timeseries`` is thresholded with
    ``get_adj_mat`` to obtain a weighted adjacency matrix, which is turned
    into a directed NetworkX graph and converted with
    ``torch_geometric.utils.from_networkx``. The full connectivity matrix is
    stored as ``x`` and the label as ``y``.

    Args:
        timeseries: Time series for one subject.
        threshold_p: Cut-off for positive correlations.
        threshold_n: Cut-off for the magnitude of non-positive correlations.
        y: Label stored in the graph as a one-element tensor.
        measure: Connectivity kind forwarded to ``get_correlation_matrix``.
        device: Optional target device. When omitted and CUDA is available,
            ``cuda:1`` is used as before, falling back to ``cuda:0`` on
            single-GPU machines. When omitted without CUDA the data is
            returned on its current device.

    Returns:
        The graph object, moved to the selected device if any.
    """
    correlation_matrix = get_correlation_matrix(timeseries, measure)
    adj_mat = get_adj_mat(correlation_matrix, threshold_p, threshold_n)

    # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
    # the replacement for ndarray input and exists in NetworkX 2.x as well.
    G = nx.from_numpy_array(np.array(adj_mat), create_using=nx.DiGraph)
    data = torch_geometric.utils.from_networkx(G)
    data['x'] = torch.tensor(correlation_matrix, dtype=torch.float)
    data['y'] = torch.tensor([y])

    if device is None and torch.cuda.is_available():
        # Keep the historical default (second GPU) when it exists.
        device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    if device is not None:
        data = data.to(device)
    return data

In [None]:
class Discriminator(nn.Module):
    """Fully connected critic that maps a flat feature vector to one score.

    The previous layer stack ended in a ``Conv2d`` applied to the 2-D output
    of a ``Linear`` layer, so ``forward`` always raised. It also chained
    ``Linear`` layers without activations, which collapses them into a single
    linear map. The hidden widths are unchanged; a ``LeakyReLU(0.2)`` now
    follows each hidden layer and a final ``Linear(6670, 1)`` produces the
    scalar score per sample.

    Args:
        in_features: Size of the last dimension of the input.
        features_d: Unused by the fully connected stack; kept so existing
            constructor calls keep working.
    """

    def __init__(self, in_features, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 4096),
            nn.LeakyReLU(0.2),
            nn.Linear(4096, 6670),
            nn.LeakyReLU(0.2),
            # One unbounded score per sample, with no final activation.
            nn.Linear(6670, 1),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> InstanceNorm2d -> LeakyReLU(0.2) block.

        Not used by ``self.disc``; kept for code that may still call it.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Return scores of shape ``(..., 1)`` for input of shape ``(..., in_features)``."""
        return self.disc(x)


class Generator(nn.Module):
    """Conditional MLP generator producing vectors of length 6670 in [-1, 1].

    The noise vector and the label vector are concatenated along dim 1 and
    passed through a stack of Linear/ReLU layers ending in ``Tanh``.

    Args:
        noise_d: Width of the noise vector.
        num_classes: Width of the label vector concatenated to the noise.
    """

    def __init__(self, noise_d, num_classes):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_d + num_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 4096),
            nn.ReLU(),
            # NOTE(review): 6670 == 116 * 115 / 2, presumably the upper
            # triangle of a 116-region matrix — confirm.
            nn.Linear(4096, 6670),
            nn.Tanh(),
        )

    @staticmethod
    def symm_reshape(x):
        """Return the upper-triangular part of ``x`` (``torch.triu``).

        Declared static so it can be called on the class or on an instance;
        the previous definition lacked ``self`` and called the non-existent
        ``nn.triu``.
        """
        return torch.triu(x)

    def forward(self, x, labels):
        """Generate samples from noise ``x`` conditioned on ``labels``.

        Args:
            x: Noise tensor of shape ``(batch, noise_d)``.
            labels: Tensor of shape ``(batch, num_classes)`` with the same
                dtype as ``x``.

        Returns:
            Tensor of shape ``(batch, 6670)``.
        """
        x = torch.cat([x, labels], dim=1)
        # The previous implementation ended with a bare ``return`` and
        # therefore always produced None.
        return self.net(x)


def initialize_weights(model):
    """Re-draw the weights of convolution and batch-norm layers from N(0, 0.02).

    Only ``Conv2d``, ``ConvTranspose2d`` and ``BatchNorm2d`` sub-modules are
    touched; every other layer type (``Linear`` included) is left as is.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in model.modules():
        if not isinstance(layer, target_layers):
            continue
        nn.init.normal_(layer.weight.data, 0.0, 0.02)

In [None]:
class E2EBlock(torch.nn.Module):
    """Cross-shaped convolution block for square inputs.

    Two convolutions are applied to the same input: one with a ``(1, d)``
    kernel and one with a ``(d, 1)`` kernel, where ``d`` is the size of the
    last dimension of ``example``. Each result is tiled ``d`` times along
    the dimension it collapsed and the two tiled tensors are summed.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        example: Tensor whose ``size(3)`` gives ``d``.
        bias: Whether both convolutions use a bias term.
    """

    def __init__(self, in_planes, planes,example,bias=False):
        super().__init__()
        self.d = example.size(3)
        row_kernel = (1, self.d)
        column_kernel = (self.d, 1)
        self.cnn1 = torch.nn.Conv2d(in_planes, planes, row_kernel, bias=bias)
        self.cnn2 = torch.nn.Conv2d(in_planes, planes, column_kernel, bias=bias)

    def forward(self, x):
        """Return the sum of the two tiled convolution responses."""
        row_response = self.cnn1(x)
        column_response = self.cnn2(x)
        # Tile each response back along the dimension its kernel collapsed.
        tiled_rows = torch.cat([row_response] * self.d, dim=3)
        tiled_columns = torch.cat([column_response] * self.d, dim=2)
        return tiled_rows + tiled_columns

In [None]:
class BrainNetCNN(torch.nn.Module):
    """Convolutional classifier over square matrices, built on ``E2EBlock``.

    Pipeline (``d`` is ``example.size(3)``):
    two ``E2EBlock`` layers (1 -> 32 -> 64 channels), a ``(1, d)`` convolution
    down to 1 channel, a ``(d, 1)`` convolution up to 256 channels, then
    dense layers 256 -> 128 -> 30 -> 2.

    The previous ``forward`` referenced ``self.E2N`` which was never created,
    and skipped ``e2econv2`` / ``dense1`` so the channel and feature sizes did
    not line up. The full, shape-consistent pipeline is restored here.

    Args:
        example: Tensor whose ``size(1)`` and ``size(3)`` are read to
            configure the network (assumed ``(batch, 1, d, d)`` — TODO confirm).
        num_classes: Currently unused; the output layer has a fixed width of 2.
            Kept so existing constructor calls keep working.
    """

    def __init__(self, example, num_classes=10):
        super(BrainNetCNN, self).__init__()
        self.in_planes = example.size(1)
        self.d = example.size(3)

        self.e2econv1 = E2EBlock(1,32,example,bias=True)
        self.e2econv2 = E2EBlock(32,64,example,bias=True)
        self.E2N = torch.nn.Conv2d(64,1,(1,self.d))
        self.N2G = torch.nn.Conv2d(1,256,(self.d,1))
        self.dense1 = torch.nn.Linear(256,128)
        self.dense2 = torch.nn.Linear(128,30)
        self.dense3 = torch.nn.Linear(30,2)

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor of class scores for input ``x``."""
        out = F.leaky_relu(self.e2econv1(x),negative_slope=0.33)
        out = F.leaky_relu(self.e2econv2(out),negative_slope=0.33)
        out = F.leaky_relu(self.E2N(out),negative_slope=0.33)
        # F.dropout defaults to training=True; tie it to the module mode so
        # that dropout is disabled after model.eval().
        out = F.dropout(F.leaky_relu(self.N2G(out),negative_slope=0.33),p=0.5,training=self.training)
        out = out.view(out.size(0), -1)
        out = F.dropout(F.leaky_relu(self.dense1(out),negative_slope=0.33),p=0.5,training=self.training)
        out = F.dropout(F.leaky_relu(self.dense2(out),negative_slope=0.33),p=0.5,training=self.training)
        out = F.leaky_relu(self.dense3(out),negative_slope=0.33)

        return out

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty between real and fake batches.

    A random mixing coefficient is drawn per sample, the real and fake
    batches are interpolated with it, and the penalty is the batch mean of
    ``(||d critic / d interpolated||_2 - 1) ** 2``.

    The previous version referenced an undefined name ``B`` and built a 4-D
    coefficient for 2-D data. The coefficient now has shape
    ``(batch, 1, ..., 1)`` and broadcasts over inputs of any rank.

    Args:
        critic: Callable mapping a batch to per-sample scores.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Kept for backward compatibility; the coefficient is created
            on ``real``'s device so it always matches the data.

    Returns:
        Scalar tensor holding the penalty (differentiable).
    """
    batch_size = real.shape[0]
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device, dtype=real.dtype)
    interpolated_mat = real * alpha + fake * (1 - alpha)
    if not interpolated_mat.requires_grad:
        # Happens when ``fake`` is detached; autograd.grad needs an input
        # that tracks gradients.
        interpolated_mat.requires_grad_(True)

    mixed_scores = critic(interpolated_mat)

    gradient = torch.autograd.grad(
        inputs=interpolated_mat,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

In [None]:
gen.train()
critic.train()

# Width of the label vector the generator expects. Generator's first layer is
# Linear(noise_d + num_classes, 512), so this recovers num_classes.
# NOTE(review): assumes gen was built as Generator(Z_DIM, num_classes) — confirm.
label_width = gen.net[0].in_features - Z_DIM

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Generator.forward concatenates noise and labels along dim 1, so the
        # labels must be a float (batch, num_classes) tensor.
        # NOTE(review): assumes the loader yields integer class ids of shape
        # (batch,) or an already one-hot matrix — confirm.
        labels = labels.to(device)
        if labels.dim() == 1:
            labels = F.one_hot(labels.long(), num_classes=label_width)
        labels = labels.float()

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            # 2-D noise: the generator is an MLP, not a convolutional net.
            noise = torch.randn(cur_batch_size, Z_DIM, device=device)
            fake = gen(noise, labels)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # The graph through the generator is reused by the generator step below.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()