In [None]:
# Clone the Neural-GC fork; a later cell `%cd`s into it and imports `models.clstm` from it.
! git clone https://github.com/SwapnilDreams100/Neural-GC/

In [None]:
# Mount Google Drive at /content/drive. The data cell reads '../drive/MyDrive/eeg_data/...'
# relative to Neural-GC/, so this presumably assumes the clone lives in /content — TODO confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Work inside the cloned repo so `models.clstm` is importable. Run once per session:
# re-running it from inside Neural-GC/ would try to enter Neural-GC/Neural-GC/.
%cd Neural-GC/

In [None]:
# %run script.py -seed "890" -lam "10.0" -lr "0.001" -percent_var "12" -context "10" -mbsize "60000" -file '../drive/MyDrive/eeg_data/evoked_simulated_nonstationary.mat'

In [None]:
from scipy.io import loadmat
import numpy as np
import torch
import matplotlib.pyplot as plt
from models.clstm import cLSTM, train_model_accumulated_ista
import argparse
import random, os

# Experiment options. Every default reproduces the value this cell used to hard-code,
# so running the cell without overrides gives the same configuration as before.
parser = argparse.ArgumentParser(
    description="Fit a cLSTM to averaged evoked data and compare its estimated "
                "Granger-causality matrix with a hard-coded ground truth.")
parser.add_argument("-seed", "--seed", type=int, default=890,
                    help="random seed for Python, NumPy and PyTorch")
parser.add_argument("-lam", "--lam", type=float, default=20.0,
                    help="`lam` penalty passed to train_model_accumulated_ista")
parser.add_argument("-lr", "--lr", type=float, default=1e-3,
                    help="learning rate passed to train_model_accumulated_ista")
parser.add_argument("-percent_var", "--percent_var", type=float, default=8,
                    help="`percent_var` passed to train_model_accumulated_ista")
parser.add_argument("-context", "--context", type=int, default=10,
                    help="context length passed to train_model_accumulated_ista")
parser.add_argument("-mbsize", "--mbsize", type=int, default=60000,
                    help="minibatch size passed to train_model_accumulated_ista")
parser.add_argument("-file", "--file",
                    default='../drive/MyDrive/eeg_data/evoked_simulated_nonstationary2.mat',
                    help=".mat file holding the data array under key 'X'")

# In a notebook, sys.argv typically holds the kernel's own launch arguments
# (e.g. `-f <kernel.json>`, which argparse would even accept as an abbreviation of
# `-file`), so parse an explicit list instead. Put overrides here, e.g.
# ["-lam", "10.0", "-percent_var", "12"].
CLI_ARGS = []
args = parser.parse_args(CLI_ARGS)


def seed_everything(seed: int) -> None:
    """Seed the RNGs used here and make cuDNN deterministic.

    Args:
        seed: value applied to Python's `random`, NumPy's global RNG and PyTorch
            (CPU and CUDA).
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode times several cuDNN algorithms and may select a different one
    # on each run, which defeats the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False


seed_everything(args.seed)

# --- Data -----------------------------------------------------------------
data = loadmat(args.file)
npdata = data['X']
# Average over axis 2 (presumably trials/repetitions - TODO confirm against the .mat
# layout), then swap the first two axes; X.shape[-1] below is used as the number of series.
npdata_reps = npdata.mean(axis=2)
npdata_changed_t = npdata_reps.swapaxes(0, 1)

# --- Model & training -----------------------------------------------------
device = torch.device('cuda')
# X is built on the CPU while the model lives on the GPU; presumably
# train_model_accumulated_ista moves minibatches to the model's device - TODO confirm.
X = torch.tensor(npdata_changed_t, dtype=torch.float32, device=torch.device('cpu'))
crnn = cLSTM(X.shape[-1], hidden=100).cuda(device=device)
train_loss_list = train_model_accumulated_ista(
    crnn, X, context=args.context, mbsize=args.mbsize, lam=args.lam, lam_ridge=1e-3,
    lr=args.lr, max_iter=20000, check_every=10, percent_var=args.percent_var)

# --- Evaluation -----------------------------------------------------------
GC_est = crnn.GC().cpu().data.numpy()
# Hard-coded ground truth for the 5-series simulation: row = affected series,
# column = causal series (matching the plot labels below), i.e. series 0 -> 1 and 1 -> 3.
GC = np.array([[0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]])
assert GC_est.shape == GC.shape, (
    'estimated GC has shape %s but the hard-coded ground truth has shape %s'
    % (GC_est.shape, GC.shape))
print('Estimated variable usage = %.2f%%' % (100 * np.mean(GC_est)))
print('Actual variable usage = %.2f%%' % (100 * np.mean(GC)))

fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
axarr[0].imshow(GC, cmap='viridis', vmin=0, vmax=1)
axarr[0].set_title('GC actual')
axarr[1].imshow(GC_est, cmap='viridis', vmin=0, vmax=1, extent=(0, len(GC_est), len(GC_est), 0))
axarr[1].set_title('GC estimated')
for ax in axarr:
    ax.set_ylabel('Affected series')
    ax.set_xlabel('Causal series')
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

In [None]:
# from scipy.io import loadmat

# d = loadmat('../drive/MyDrive/eeg_data/R1598_GC.mat')
# sparsity = d['sig_p'].mean()*100
# sparsity

In [None]:
# from scipy.io import loadmat
# import numpy as np
# import torch
# import matplotlib.pyplot as plt
# from models.clstm import cLSTM, train_model_accumulated_ista
# import random, os
# def seed_everything(seed: int):
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = True
    
# seed_everything(69)

# data = loadmat('../drive/MyDrive/eeg_data/R1598_1_500_covert_roi_matrix.mat')
# rois = [7,11,13,16,17,20,27,39,47,48,51,56,58,68,72,74,77,78,81,88,100,108,109,112,117,119]
# rois = [x-1 for x in rois]
# npdata = data['ts']
# npdata_changed = npdata[rois,:,:]
# npdata_reps = npdata_changed[:,[x for x in range(250,npdata_changed.shape[1])],:]

# npdata_reps = npdata_reps.mean(axis = 2)
# npdata_changed_t = npdata_reps.swapaxes(0,1)
# d = loadmat('../drive/MyDrive/eeg_data/R1598_GC.mat')
# # sparsity = d['sig_p'].mean()*100
# sparsity = 15
# device = torch.device('cuda')
# X = torch.tensor(npdata_changed_t[np.newaxis], dtype=torch.float32, device=torch.device('cpu'))
# # X = torch.tensor(npdata_changed_t, dtype=torch.float32, device=torch.device('cpu'))
# crnn = cLSTM(X.shape[-1], hidden=100).cuda(device=device)
# train_loss_list = train_model_accumulated_ista(
#     crnn, X, context=10, mbsize=500, lam=10.0, lam_ridge=1e-3, lr=1e-3/2, max_iter=20000,
#     check_every=10, percent_var = sparsity)

# GC_est = crnn.GC().cpu().data.numpy()
# GC = d['sig_p']
# print('Estimated variable usage = %.2f%%' % (100 * np.mean(GC_est)))
# print('Actual variable usage = %.2f%%' % (100 * np.mean(GC)))

# fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
# axarr[0].imshow(GC, cmap='viridis')
# axarr[0].set_title('GC actual')
# axarr[0].set_ylabel('Affected series')
# axarr[0].set_xlabel('Causal series')
# axarr[0].set_xticks([])
# axarr[0].set_yticks([])

# axarr[1].imshow(GC_est, cmap='viridis', vmin=0, vmax=1, extent=(0, len(GC_est), len(GC_est), 0))
# axarr[1].set_ylabel('Affected series')
# axarr[1].set_xlabel('Causal series')
# axarr[1].set_xticks([])
# axarr[1].set_yticks([])

# plt.show()

In [None]:
# import ast
# import numpy as np
# import matplotlib.pyplot as plt

In [None]:
# #1598
# v = "[[1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]]"
# v = v.replace(' ',',')
# x = ast.literal_eval(v)
# GC = np.array(x)
# np.fill_diagonal(GC, 0, wrap=False)
# fig, axarr = plt.subplots(1, 1, figsize=(10, 5))
# axarr[0].imshow(GC, cmap='viridis')
# axarr[0].set_title('GC actual')
# axarr[0].set_ylabel('Affected series')
# axarr[0].set_xlabel('Causal series')
# axarr[0].set_xticks([])
# axarr[0].set_yticks([])

In [None]:
# #1551
# v = "[[1 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0]\
#  [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [1 1 0 1 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1]\
#  [1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]\
#  [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [1 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1]\
#  [1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1]\
#  [1 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1]\
#  [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 1]\
#  [1 0 1 1 1 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 1 0 1 1 1]]"

# v = v.replace(' ',',')
# x = ast.literal_eval(v)
# GC = np.array(x)
# np.fill_diagonal(GC, 0, wrap=False)
# fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
# axarr[0].imshow(GC, cmap='viridis')
# axarr[0].set_title('GC actual')
# axarr[0].set_ylabel('Affected series')
# axarr[0].set_xlabel('Causal series')
# axarr[0].set_xticks([])
# axarr[0].set_yticks([])

In [None]:
# #1167
# v = "[[0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 1 1 1 0 0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 1 0 1 0 1 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0]\
#  [0 0 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\
#  [0 0 0 1 0 0 0 0 1 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 1 1 0 0 0 0 1 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 0 0 0 0 1 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 0 0 1 0 1 1 1 1 0 0 0 0 0 0 0 1 0 0 1 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 1 0 0 0]\
#  [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 1 0 0 0]\
#  [0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 1 0 0 0]\
#  [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 1 0 0 0]\
#  [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 1 0 0 0]\
#  [0 0 0 1 0 1 0 0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0 0 1 0 0 0]\
#  [0 0 0 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]]"

# v = v.replace(' ',',')
# x = ast.literal_eval(v)
# GC = np.array(x)
# np.fill_diagonal(GC, 0, wrap=False)
# fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
# axarr[0].imshow(GC, cmap='viridis')
# axarr[0].set_title('GC actual')
# axarr[0].set_ylabel('Affected series')
# axarr[0].set_xlabel('Causal series')
# axarr[0].set_xticks([])
# axarr[0].set_yticks([])