# cGRU on DREAM3 in-silico Size 100, E. coli 1 (E.C. 1)


In [1]:
# Mount Google Drive at /content/drive (Colab-only helper) so that files written
# under that path persist outside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!git clone https://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@github.com/Proton1121/ngcausality.git

Cloning into 'ngcausality'...
remote: Enumerating objects: 560, done.[K
remote: Counting objects: 100% (375/375), done.[K
remote: Compressing objects: 100% (232/232), done.[K
remote: Total 560 (delta 272), reused 201 (delta 143), pack-reused 185 (from 1)[K
Receiving objects: 100% (560/560), 2.97 MiB | 6.81 MiB/s, done.
Resolving deltas: 100% (348/348), done.


In [3]:
# Switch the working directory to the cloned repository so that the project-local
# `data` and `models` packages imported below can be resolved.
%cd /content/ngcausality

/content/ngcausality


In [4]:
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from data.synthetic import simulate_lorenz_96
from data.dream import generate_causal_matrix
from models.cgru import cGRU, train_model_ista

In [5]:
# Root folder on Google Drive for every artefact produced by this notebook.
save_dir = '/content/drive/MyDrive/ngcausality_results/' + 'cgru_ec1/'

# exist_ok=True makes the cell safely re-runnable without a separate existence
# check (which is racy) and raises if the path exists but is not a directory.
os.makedirs(save_dir, exist_ok=True)

In [6]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [7]:
# Load the DREAM3 in-silico trajectories (tab-separated) and build the reference
# causal matrix from the matching gold-standard file.
X_np = pd.read_csv(
    '/content/ngcausality/data/InSilicoSize100-Ecoli1-trajectories.tsv',
    sep='\t',
)
GC, _ = generate_causal_matrix(
    '/content/ngcausality/data/DREAM3GoldStandard_InSilicoSize100_Ecoli1.txt'
)

# Drop the first column (presumably the time stamp -- confirm against the TSV),
# then fold the rows into 46 blocks of 21 consecutive rows each, giving a float32
# tensor of shape (46, 21, n_series) on the selected device.
series_values = X_np.iloc[:, 1:].to_numpy()
X = torch.tensor(
    series_values.reshape(46, 21, series_values.shape[1]),
    dtype=torch.float32,
    device=device,
)

In [8]:
# Persist the loaded inputs to Google Drive alongside the results.
raw_table_path = os.path.join(save_dir, 'X_np.npy')
gc_path = os.path.join(save_dir, 'GC.npy')
tensor_path = os.path.join(save_dir, 'X_tensor.pt')
shapes_path = os.path.join(save_dir, 'data_shapes.txt')

np.save(raw_table_path, X_np)  # full table as read from the TSV
np.save(gc_path, GC)  # reference causal matrix
torch.save(X, tensor_path)  # reshaped tensor that is fed to the model

# Small text summary of the shapes of everything saved above.
shape_report = [
    f'Shape of X_np: {X_np.shape}\n',
    f'Shape of GC: {GC.shape}\n',
    f'Shape of X (torch tensor): {X.shape}\n',
]
with open(shapes_path, 'w') as fh:
    fh.writelines(shape_report)

In [9]:
# Plot the first replicate (the first 21 rows): every series on the left panel,
# only the first five series on the right panel.
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
first_replicate = X_np.iloc[:21, 1:]
panels = (
    (axarr[0], first_replicate, 'Entire time series (first replicate)'),
    (axarr[1], first_replicate.iloc[:, :5], 'First 5 series'),
)
for ax, frame, title in panels:
    ax.plot(frame)
    ax.set_xlabel('T')
    ax.set_title(title)
fig.tight_layout()

# Write the figure to Google Drive as a PNG.
plot_filename = os.path.join(save_dir, 'data_plots.png')
fig.savefig(plot_filename)

# Close the figure so it is not rendered inline (drop this line to see it in the notebook).
plt.close(fig)

In [14]:
# Sparsity penalties to train with; every value gets its own results folder.
# NOTE(review): this cell used to label its outputs 'lam=0.1' (folder and file
# names) while actually training with lam=0.001, so saved results were
# mislabelled. The label is now derived from the value passed to the trainer.
# Confirm that 0.001 is the intended penalty.
LAM_GRID = (0.001,)

# Passed to train_model_ista as `check_every` and reused as the x-axis stride of
# the loss plot (presumably one loss is recorded per check -- confirm).
CHECK_EVERY = 50

for lam in LAM_GRID:
  # Per-lambda output folder; `save_dir` is rebound here exactly as before.
  save_dir = f'/content/drive/MyDrive/ngcausality_results/cgru_ec1/lam={lam}/'
  os.makedirs(save_dir, exist_ok=True)

  # Set up model: one cGRU over all series in X (size of the last dimension).
  cgru = cGRU(X.shape[-1], hidden=10).to(device=device)

  # Train with ISTA
  train_loss_list = train_model_ista(
      cgru, X, context=10, lr=5e-2, max_iter=50000, lam=lam, lam_ridge=1e-2,
      check_every=CHECK_EVERY)

  # Loss function plot
  plt.figure(figsize=(8, 5))
  train_loss_np = [loss.cpu().detach().numpy() for loss in train_loss_list]
  plt.plot(CHECK_EVERY * np.arange(len(train_loss_np)), train_loss_np)
  plt.title('cGRU training')
  plt.ylabel('Loss')
  plt.xlabel('Training steps')
  plt.tight_layout()
  loss_plot_path = os.path.join(save_dir, f'loss_plot_{lam}.png')
  plt.savefig(loss_plot_path)  # Save the loss plot to Google Drive
  plt.close()  # Close the plot to prevent it from displaying

  # Verify learned Granger causality: compare the estimate with the reference GC.
  GC_est = cgru.GC().cpu().data.numpy()

  results_file_path = os.path.join(save_dir, f'gc_results_{lam}.txt')
  with open(results_file_path, 'w') as f:
    f.write(f'True variable usage = {100 * np.mean(GC)}%\n')
    f.write(f'Estimated variable usage = {100 * np.mean(GC_est)}%\n')
    f.write(f'Accuracy = {100 * np.mean(GC == GC_est)}%\n')
    f.write(f'True positives = {np.sum((GC == 1) & (GC_est == 1))}\n')
    f.write(f'True negatives = {np.sum((GC == 0) & (GC_est == 0))}\n')
    f.write(f'False positives = {np.sum((GC == 0) & (GC_est == 1))}\n')
    f.write(f'False negatives = {np.sum((GC == 1) & (GC_est == 0))}\n')

  # Make figures for Granger causality matrices
  #fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
  #axarr[0].imshow(GC, cmap='Blues')
  #axarr[0].set_title('GC actual')
  #axarr[0].set_ylabel('Affected series')
  #axarr[0].set_xlabel('Causal series')
  #axarr[0].set_xticks([])
  #axarr[0].set_yticks([])

  #axarr[1].imshow(GC_est, cmap='Blues', vmin=0, vmax=1, extent=(0, len(GC_est), len(GC_est), 0))
  #axarr[1].set_title('GC estimated')
  #axarr[1].set_ylabel('Affected series')
  #axarr[1].set_xlabel('Causal series')
  #axarr[1].set_xticks([])
  #axarr[1].set_yticks([])

  # Mark disagreements
  #for i in range(len(GC_est)):
  #  for j in range(len(GC_est)):
  #      if GC[i, j] != GC_est[i, j]:
  #          rect = plt.Rectangle((j, i-0.05), 1, 1, facecolor='none', edgecolor='red', linewidth=1)
  #          axarr[1].add_patch(rect)

  #gc_plot_path = os.path.join(save_dir, f'gc_plot_{i}.png')
  #plt.savefig(gc_plot_path)  # Save the GC plot to Google Drive
  #plt.close()  # Close the plot to prevent it from displaying

----------Iter = 50----------
Loss = 0.201157
Variable usage = 100.00%
----------Iter = 100----------
Loss = 0.190056
Variable usage = 100.00%
----------Iter = 150----------
Loss = 0.180911
Variable usage = 100.00%
----------Iter = 200----------
Loss = 0.172920
Variable usage = 100.00%
----------Iter = 250----------
Loss = 0.165801
Variable usage = 100.00%
----------Iter = 300----------
Loss = 0.159408
Variable usage = 100.00%
----------Iter = 350----------
Loss = 0.153642
Variable usage = 100.00%
----------Iter = 400----------
Loss = 0.148427
Variable usage = 100.00%
----------Iter = 450----------
Loss = 0.143701
Variable usage = 100.00%
----------Iter = 500----------
Loss = 0.139411
Variable usage = 100.00%
----------Iter = 550----------
Loss = 0.135514
Variable usage = 100.00%
----------Iter = 600----------
Loss = 0.131969
Variable usage = 100.00%
----------Iter = 650----------
Loss = 0.128742
Variable usage = 100.00%
----------Iter = 700----------
Loss = 0.125801
Variable usage = 1