# cRNN Lorenz-96 Demo
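- In this notebook, we train a cRNN (component-wise recurrent neural network) model on data simulated from a Lorenz-96 system, then check whether the learned model recovers the true Granger causality structure.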

In [None]:
import numpy as np
import torch
from models.crnn import cRNN, train_model_gista
from synthetic import simulate_lorenz_96
import matplotlib.pyplot as plt

In [None]:
# For GPU acceleration. Fall back to the CPU when CUDA is unavailable so the
# notebook still runs (more slowly) on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Simulate data: a Lorenz-96 trajectory of T=1000 steps over p series, along
# with the ground-truth Granger causality matrix GC used for evaluation below.
# NOTE(review): X_np is presumably shaped (T, p) -- confirm against synthetic.py.
p = 10
X_np, GC = simulate_lorenz_96(p, T=1000)
# Add a leading singleton axis (presumably the batch dimension cRNN expects)
# and move the data to the chosen device as float32.
X = torch.tensor(np.expand_dims(X_np, axis=0), dtype=torch.float32, device=device)

In [None]:
# Plot data: the whole simulated trajectory on the left, and a close-up of
# the first 100 time steps on the right.
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
full_ax, zoom_ax = axarr
full_ax.plot(X_np)
zoom_ax.plot(X_np[:100])
plt.show()

In [None]:
# Set up model and move its parameters to the selected device. `.to(device)`
# works for both CUDA and CPU devices; `.cuda()` would fail on a CPU-only host.
crnn = cRNN(p, hidden=10).to(device)

In [None]:
# Train with GISTA (generalized iterative shrinkage-thresholding).
# `check_every` is reused by the loss-plot cell to rebuild the iteration axis.
check_every = 100
train_loss_list, train_mse_list = train_model_gista(
    crnn,
    X,
    lam=6.3,         # NOTE(review): presumably the sparsity penalty on input series -- confirm
    lam_ridge=1e-4,  # NOTE(review): presumably a ridge penalty on remaining weights -- confirm
    lr=0.005,
    max_iter=20000,
    check_every=check_every,
    truncation=5,    # NOTE(review): assumed truncated-BPTT length -- confirm
)

In [None]:
# Loss function plot: each recorded value is placed at the iteration where
# it was logged (every `check_every` iterations).
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))

panels = zip(axarr, (train_loss_list, train_mse_list), ('Train loss', 'Train MSE'))
for ax, history, title in panels:
    ax.plot(check_every * np.arange(len(history)), history)
    ax.set_title(title)

plt.show()

In [None]:
# Verify learned Granger causality.
# `.detach()` is the supported way to drop autograd tracking before converting
# to NumPy; the legacy `.data` attribute bypasses autograd's safety checks.
GC_est = crnn.GC().detach().cpu().numpy()

# Fraction of (affected, causal) pairs marked as causal, and the element-wise
# agreement between the true and estimated matrices.
print(f'True variable usage = {100 * np.mean(GC):.2f}%')
print(f'Estimated variable usage = {100 * np.mean(GC_est):.2f}%')
print(f'Accuracy = {100 * np.mean(GC == GC_est):.2f}%')

# Make figures: true GC matrix (left) next to the estimate (right). Rows are
# affected series, columns are causal series.
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
for ax, matrix, title in zip(axarr, (GC, GC_est), ('GC actual', 'GC estimated')):
    # Share the colour scale and cell coordinates across both panels so they
    # are directly comparable; with this extent cell (i, j) spans
    # x in [j, j+1] and y in [i, i+1], which the rectangles below rely on.
    ax.imshow(matrix, cmap='Blues', vmin=0, vmax=1, extent=(0, p, p, 0))
    ax.set_title(title)
    ax.set_ylabel('Affected series')
    ax.set_xlabel('Causal series')
    ax.set_xticks([])
    ax.set_yticks([])

# Mark disagreements: outline in red every cell of the estimate that differs
# from the truth (argwhere yields them in the same row-major order as a
# nested loop over rows then columns).
for i, j in np.argwhere(GC != GC_est):
    # NOTE(review): the -0.05 vertical nudge looks like a deliberate visual
    # tweak carried over from the original -- confirm it is intended.
    rect = plt.Rectangle((j, i - 0.05), 1, 1, facecolor='none', edgecolor='red', linewidth=1)
    axarr[1].add_patch(rect)

plt.show()