In [None]:
import pickle
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statistics as stat
import time
%matplotlib inline

from sklearn.metrics import accuracy_score

from nilearn.connectome import ConnectivityMeasure

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.metrics import f1_score

import networkx as nx
import torch
import torch.optim as optim
from torch.nn import Linear
import torch.nn.functional as F

from sklearn import svm

import pickle5 as pickle
import os

from typing import Optional, Tuple

# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.utils.data import DataLoader
# from torch_scatter import scatter_add
# from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
# from torch_geometric.nn.conv import MessagePassing
# from torch_geometric.utils import add_remaining_self_loops
# from torch_geometric.utils.num_nodes import maybe_num_nodes
# from torch import softmax

In [None]:
import torch
import torch.nn as nn

In [None]:
def load_obj(path):
    """Unpickle and return the object stored in the file ``path + '.pkl'``.

    NOTE(review): pickle.load can execute arbitrary code while loading —
    only call this on files produced by a trusted source.
    """
    pkl_file = path + '.pkl'
    with open(pkl_file, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
# Per-subject time series for the two groups. Later cells iterate these as
# mappings (`for sub_id in AD_dict: AD_dict[sub_id]`), so they are presumably
# {subject_id: timeseries} dicts — NOTE(review): confirm against the pickles.
AD_dict, CN_dict = (
    load_obj('AAL_data/timeseries/AD'),
    load_obj('AAL_data/timeseries/CN'),
)

# Pre-computed train / validation / test splits for each group.
# NOTE(review): their contents are not used by any visible cell.
AD_train, AD_val, AD_test = (
    load_obj('AAL_data/AD_train_full'),
    load_obj('AAL_data/AD_val_full'),
    load_obj('AAL_data/AD_test_full'),
)

CN_train, CN_val, CN_test = (
    load_obj('AAL_data/CN_train_full'),
    load_obj('AAL_data/CN_val_full'),
    load_obj('AAL_data/CN_test_full'),
)

In [None]:
def get_correlation_matrix(timeseries, msr):
    """Return the connectivity matrix of a single subject's time series.

    ``msr`` is forwarded to nilearn's ConnectivityMeasure as ``kind``
    (e.g. 'correlation'). The subject is wrapped in a one-element list for
    ``fit_transform`` and the first (only) resulting matrix is returned.
    """
    measure = ConnectivityMeasure(kind=msr)
    matrices = measure.fit_transform([timeseries])
    return matrices[0]

def get_upper_triangular_matrix(matrix):
    """Flatten the strict upper triangle (diagonal excluded) of a square matrix.

    Entries are read row by row, left to right, so an n x n input yields a
    list of n * (n - 1) / 2 values.
    """
    size = len(matrix)
    return [matrix[row][col] for row in range(size) for col in range(row + 1, size)]

    
# def get_adj_mat(correlation_matrix, threshold_value, weighted = True):
#     adj_mat = []
#     for i in correlation_matrix:
#         row = []
#         for j in i:
#             if abs(j)>threshold_value:
#                 if not weighted:
#                     row.append(1)
#                 else:
#                     row.append(abs(j))
#             else:
#                 row.append(0)
#         adj_mat.append(row)
#     return adj_mat


def get_adj_mat(correlation_matrix, th_value_p, th_value_n):
    """Threshold a connectivity matrix into a non-negative weighted adjacency matrix.

    Positive entries are kept unchanged when strictly greater than
    ``th_value_p``. Zero or negative entries are kept as ``abs(value)`` when
    their magnitude is strictly greater than ``th_value_n``. Everything else
    becomes 0.

    Args:
        correlation_matrix: 2-D iterable of numbers (rows of the matrix).
        th_value_p: threshold applied to positive entries.
        th_value_n: threshold applied to the magnitude of non-positive entries.

    Returns:
        A list of lists with the same shape as ``correlation_matrix``.
    """
    def _keep(value):
        # Positive and negative correlations get separate cut-offs; negatives
        # are stored as magnitudes so every edge weight is >= 0.
        if value > 0:
            return value if value > th_value_p else 0
        magnitude = abs(value)
        return magnitude if magnitude > th_value_n else 0

    # NOTE(review): the diagonal is thresholded like any other entry, so a unit
    # diagonal survives (self-loops) whenever th_value_p < 1.
    # To force every node to have an edge, follow with:
    #     adj_mat = connect_isolated_nodes(adj_mat, correlation_matrix)
    return [[_keep(value) for value in row] for row in correlation_matrix]


def connect_isolated_nodes(adj_mat, correlation_matrix):
    """Give each isolated row of ``adj_mat`` one edge to its most-correlated other node.

    A row counts as isolated when its entries sum to 0 (i.e. an all-zero row
    for the non-negative matrices produced by get_adj_mat). For such a row,
    the column holding the largest signed value in the same row of
    ``correlation_matrix`` — excluding the node itself — is set to 1.

    Args:
        adj_mat: square adjacency matrix (list of lists); mutated in place.
        correlation_matrix: square matrix with the same size as ``adj_mat``.

    Returns:
        The same ``adj_mat`` object, after modification.

    NOTE(review): only ``adj_mat[row][col]`` is set, so a symmetric input is no
    longer symmetric afterwards (create_graph builds a DiGraph).
    """
    if len(adj_mat) == 0:
        return adj_mat

    corr = np.array(correlation_matrix, dtype=float)
    # Mask the diagonal with -inf so a node is never linked to itself.
    # Subtracting the identity is not enough: it leaves the diagonal as the
    # maximum whenever every off-diagonal value is <= 0, or whenever the
    # diagonal is not exactly 1 (e.g. covariance).
    np.fill_diagonal(corr, -np.inf)

    for row_num, row in enumerate(adj_mat):
        if sum(row) == 0:
            # argmax picks the first maximum, matching list.index(max(...)).
            adj_mat[row_num][int(np.argmax(corr[row_num]))] = 1
    return adj_mat
    
# def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
#     ad_corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
#     cn_corr_mats = [get_correlation_matrix(ts, measure) for ts in cn_timeseries]

#     ad_upper = [get_upper_triangular_matrix(matrix) for matrix in ad_corr_mats]
#     cn_upper = [get_upper_triangular_matrix(matrix) for matrix in cn_corr_mats]

#     all_correlation_values = ad_upper + cn_upper
#     all_correlation_values = np.array(all_correlation_values).flatten()

#     all_correlation_values = np.array([abs(i) for i in all_correlation_values])
#     all_correlation_values = np.sort(all_correlation_values)[::-1]

#     th_val_index = (len(all_correlation_values)*threshold_percent)//100
#     return all_correlation_values[int(th_val_index)]


def _value_at_top_percent(values_desc, threshold_percent):
    """Return the element ``threshold_percent``% of the way into a descending-sorted array.

    The index is clamped to the array bounds: 0 gives the largest value and
    any percentage >= 100 gives the smallest.
    Raises ValueError if ``values_desc`` is empty.
    """
    if len(values_desc) == 0:
        raise ValueError("no correlation values available to derive a threshold from")
    index = int((len(values_desc) * threshold_percent) // 100)
    return values_desc[min(max(index, 0), len(values_desc) - 1)]


def get_threshold_value(ad_timeseires, cn_timeseries, measure, threshold_percent):
    """Derive separate positive / negative edge thresholds from pooled AD + CN data.

    Each subject's connectivity matrix is reduced to its strict upper
    triangle and all values are pooled. Entries exactly equal to 1 are
    dropped. The rest are split into positives and the absolute values of the
    non-positives. Each group is sorted in descending order, and the value
    ``threshold_percent``% of the way down that group is returned. So roughly
    the strongest ``threshold_percent``% of each sign lies above its threshold.

    Args:
        ad_timeseires: iterable of per-subject time series for the AD group.
        cn_timeseries: iterable of per-subject time series for the CN group.
        measure: connectivity kind forwarded to get_correlation_matrix.
        threshold_percent: percentage (0-100) of the strongest values to keep.

    Returns:
        ``(positive_threshold, negative_threshold)``; the second is a magnitude.

    Raises:
        ValueError: if there are no positive, or no non-positive, values.
    """
    corr_mats = [get_correlation_matrix(ts, measure) for ts in ad_timeseires]
    corr_mats += [get_correlation_matrix(ts, measure) for ts in cn_timeseries]
    all_values = np.array([get_upper_triangular_matrix(m) for m in corr_mats]).flatten()

    # Exact 1.0 correlations are excluded from both distributions.
    all_values = all_values[all_values != 1]
    is_pos = all_values > 0
    pos_desc = np.sort(all_values[is_pos])[::-1]
    neg_desc = np.sort(np.abs(all_values[~is_pos]))[::-1]

    # Each sign is indexed by its own length. Using the pooled length would
    # overrun the shorter list (IndexError) and mis-scale the percentile.
    return (_value_at_top_percent(pos_desc, threshold_percent),
            _value_at_top_percent(neg_desc, threshold_percent))

def create_graph(timeseries, threshold_p, threshold_n, y, measure='correlation', device=None):
    """Build a torch_geometric graph for one subject's time series.

    The connectivity matrix is thresholded with get_adj_mat and used as a
    weighted directed adjacency matrix. The full connectivity matrix is used
    as node features (``data['x']``), and ``y`` is stored as ``data['y']``.

    Args:
        timeseries: one subject's time series, passed to get_correlation_matrix.
        threshold_p: threshold for positive connectivity values.
        threshold_n: threshold for the magnitude of non-positive values.
        y: graph-level label.
        measure: connectivity kind (default 'correlation').
        device: where to move the graph. If None, 'cuda:0' is used when CUDA is
            available (matching the notebook's training device); otherwise the
            graph stays on the CPU. Pass 'cuda:1' explicitly to target a second GPU.

    Returns:
        A torch_geometric ``Data`` object.
    """
    # The module-level torch_geometric imports are commented out, so import
    # the single helper needed here.
    from torch_geometric.utils import from_networkx

    correlation_matrix = get_correlation_matrix(timeseries, measure)
    adj_mat = get_adj_mat(correlation_matrix, threshold_p, threshold_n)

    # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array replaces it.
    G = nx.from_numpy_array(np.array(adj_mat), create_using=nx.DiGraph)
    data = from_networkx(G)
    data['x'] = torch.tensor(correlation_matrix, dtype=torch.float)
    data['y'] = torch.tensor([y])

    if device is None and torch.cuda.is_available():
        device = torch.device('cuda:0')
    if device is not None:
        data = data.to(device)
    return data

In [None]:
def _upper_correlation_features(timeseries_by_subject):
    """Return a 2-D tensor with one row per subject: the strict upper triangle
    of that subject's 'correlation' connectivity matrix, in dict iteration order."""
    rows = []
    for sub_id in timeseries_by_subject:
        corr = get_correlation_matrix(timeseries_by_subject[sub_id], 'correlation')
        rows.append(get_upper_triangular_matrix(corr))
    return torch.tensor(rows)


AD_features = _upper_correlation_features(AD_dict)
CN_features = _upper_correlation_features(CN_dict)

In [None]:
# Shuffled mini-batches of 32 subject feature rows per group.
AD_loader, CN_loader = (
    DataLoader(group_features, batch_size=32, shuffle=True)
    for group_features in (AD_features, CN_features)
)
# for batch_idx, real in enumerate(CN_loader):
#     print(batch_idx)
#     print(real.shape)

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP gradient penalty.

    Draws one random mixing weight per sample and interpolates between paired
    real and fake samples. It then evaluates the critic there and penalises
    how far the per-sample L2 norm of the critic's input-gradient is from 1.

    Args:
        critic: module mapping a batch shaped like ``real`` to scores.
        real: tensor of shape (batch, ...) with real samples.
        fake: tensor with the same shape as ``real``; may or may not require grad.
        device: device on which the mixing weights are drawn.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True`` so it can be backpropagated into the critic.
    """
    batch_size = real.shape[0]
    # One weight per sample, broadcast over all remaining dims, so both flat
    # feature vectors (B, F) and image-like tensors (B, C, H, W) work.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    # Make the interpolates a grad-enabled leaf so autograd.grad works even
    # when `fake` is detached or was produced under no_grad.
    interpolated_mats = (real * alpha + fake * (1 - alpha)).detach().requires_grad_(True)

    mixed_scores = critic(interpolated_mats)

    # Gradient of the critic scores with respect to the interpolated inputs.
    gradient = torch.autograd.grad(
        inputs=interpolated_mats,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
class Discriminator(nn.Module):
    """WGAN critic: maps a (batch, features_dim) input to (batch, 1) real-valued scores.

    With the hidden layers below disabled, the critic is a single linear map.
    """
    def __init__(self, features_dim):
        super(Discriminator, self).__init__()
        # No activation after the final Linear: a WGAN critic's score must be
        # an unconstrained linear output. A LeakyReLU here would distort the
        # Wasserstein estimate.
        self.disc = nn.Sequential(
            nn.Linear(features_dim, 1),
#             nn.Linear(4096, 2048),
#             nn.LeakyReLU(0.1),
#             nn.Linear(2048, 1024),
#             nn.LeakyReLU(0.1),
#             nn.Linear(1024, 512),
#             nn.LeakyReLU(0.1),
#             nn.Linear(512, 1),
        )

    def forward(self, x):
        """Return the critic scores for ``x``."""
        return self.disc(x)

class Generator(nn.Module):
    """MLP generator: maps (batch, noise_dim) noise to (batch, features_dim) outputs in (-1, 1).

    The final Tanh bounds every generated value to (-1, 1).
    """
    def __init__(self, noise_dim, features_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 512),
            nn.LeakyReLU(0.1),
            # Identity rather than a second LeakyReLU(0.1): two in a row would
            # compound to a 0.01 negative slope. Keeping a module in this slot
            # leaves the Sequential indices (state_dict keys "gen.3.*") unchanged.
            nn.Identity(),
#             nn.Linear(512, 1024),
#             nn.LeakyReLU(0.1),
#             nn.Linear(1024, 2048),
#             nn.LeakyReLU(0.1),
#             nn.Linear(2048, 4096),
            nn.Linear(512, features_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated samples for the noise batch ``x``."""
        return self.gen(x)

In [None]:
# WGAN-GP training on the CN connectivity features (one critic update per
# generator update).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
lr = 3e-6
features_dim = 6670     # 116 * 115 / 2: strict upper triangle of a 116x116 connectivity matrix
noise_dim = 100
batch_size = 32
num_epochs = 2000
LAMBDA_GP = 10
log_every = 50          # print the epoch-average losses every `log_every` epochs
SEED = 0
Dloss = []
Gloss = []

torch.manual_seed(SEED)

disc = Discriminator(features_dim).to(device)
gen = Generator(noise_dim, features_dim).to(device)
fixed_noise = torch.randn((batch_size, noise_dim)).to(device)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # NOTE(review): unused; the WGAN-GP objective below needs no BCE
step = 0

for epoch in range(num_epochs):
    D_l = 0.0
    G_l = 0.0
    cnt = 0
    for batch_idx, real in enumerate(CN_loader):
        real = real.view(-1, features_dim).float().to(device)
        # Local name so the configured `batch_size` above is not overwritten.
        cur_batch_size = real.shape[0]

        # Critic step: minimise E[D(G(z))] - E[D(x)] + LAMBDA_GP * GP.
        noise = torch.randn(cur_batch_size, noise_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        disc_fake = disc(fake).view(-1)
        gp = gradient_penalty(disc, real, fake, device=device)
        lossD = -(torch.mean(disc_real) - torch.mean(disc_fake)) + LAMBDA_GP * gp

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Generator step: maximise the critic's score on fresh fakes, i.e.
        # minimise -E[D(G(z))]. The loss must use the critic output, not the
        # generated samples themselves.
        noise = torch.randn(cur_batch_size, noise_dim).to(device)
        fake = gen(noise)
        output = disc(fake).view(-1)
        lossG = -torch.mean(output)

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Accumulate Python floats: summing the loss tensors themselves keeps
        # every batch's autograd graph alive for the whole epoch.
        D_l += lossD.item() * cur_batch_size
        G_l += lossG.item() * cur_batch_size
        cnt += cur_batch_size

    if cnt == 0:
        raise ValueError("CN_loader yielded no samples")
    D_l /= cnt
    G_l /= cnt

    Gloss.append(G_l)
    Dloss.append(D_l)

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print(f"Epoch [{epoch}/{num_epochs}] Loss D: {D_l:.4f}, loss G: {G_l:.4f}")

fig, ax = plt.subplots()
ax.plot(Gloss, label='Gen')
ax.plot(Dloss, label='Disc')
ax.set(xlabel='epoch', ylabel='mean loss per sample', title='WGAN-GP training losses (CN)')
ax.legend()
plt.show()

In [None]:
# Re-plot the per-epoch loss curves collected by the training cell.
fig, ax = plt.subplots()
ax.plot(Gloss, label='Gen')
ax.plot(Dloss, label='Disc')
ax.legend()
plt.show()

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """WGAN-GP gradient penalty for inputs of any rank (batch, ...).

    This cell redefines (and shadows) any earlier gradient_penalty, so it
    must work for both flat (B, F) features and image-like (B, C, H, W) input.
    It draws one random mixing weight per sample and interpolates between
    real and fake samples. It then evaluates the critic there and penalises
    how far the per-sample L2 norm of the critic's input-gradient is from 1.

    Args:
        critic: module mapping a batch shaped like ``real`` to scores.
        real: tensor of shape (batch, ...) with real samples.
        fake: tensor with the same shape as ``real``; may or may not require grad.
        device: device on which the mixing weights are drawn.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``, differentiable w.r.t.
        the critic's parameters (``create_graph=True``).
    """
    batch_size = real.shape[0]
    # One weight per sample, broadcast over every non-batch dimension.
    alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    # Grad-enabled leaf, so autograd.grad works even for detached fakes.
    interpolated_mat = (real * alpha + fake * (1 - alpha)).detach().requires_grad_(True)

    mixed_scores = critic(interpolated_mat)

    gradient = torch.autograd.grad(
        inputs=interpolated_mat,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
# Variant of the training loop with several critic updates per generator
# update. It continues training the `gen` / `disc` models and their
# optimisers, which were created in the training cell. n_critic = 5 follows
# the WGAN-GP paper's default.
critic = disc
opt_critic = opt_disc
loader = CN_loader
NUM_EPOCHS = num_epochs
Z_DIM = noise_dim
CRITIC_ITERATIONS = 5

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # CN_loader yields plain feature tensors, not (sample, label) pairs.
    for batch_idx, real in enumerate(loader):
        real = real.view(-1, features_dim).float().to(device)
        cur_batch_size = real.shape[0]

        for _ in range(CRITIC_ITERATIONS):
            # The generator is an MLP, so noise is (batch, Z_DIM) with no spatial dims.
            noise = torch.randn(cur_batch_size, Z_DIM).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # retain_graph: the generator step below reuses the last `fake`.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()