# cMLP Lagged VAR Demo
- In this notebook, we train a cMLP model on data simulated from a vector autoregressive (VAR) process with lagged interactions. We use the hierarchical sparse penalty to perform lag selection.
- We use unregularized pretraining before training with GISTA.
- After examining the discovered Granger causality structure, we train a debiased model that uses only the discovered interactions.

In [None]:
import numpy as np
import torch
from models.cmlp import cMLP, cMLPSparse, train_model_adam, train_model_gista
from synthetic import simulate_var
import matplotlib.pyplot as plt

In [None]:
# For GPU acceleration
# Use CUDA when it is available; otherwise fall back to the CPU so that
# `device` is always a usable target for tensor creation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Simulate data
# p series, T time points, and a ground-truth VAR order of var_lag.
p, T, var_lag = 10, 1000, 3

# simulate_var returns the raw series (X_np), the VAR parameters (beta) and
# the true Granger causality structure (GC).
X_np, beta, GC = simulate_var(p, T, var_lag)

# Prepend a length-1 axis and move the data to `device` as float32.
# presumably the leading axis is the batch dimension the model expects — confirm
X = torch.tensor(
    np.expand_dims(X_np, axis=0),
    dtype=torch.float32,
    device=device,
)

In [None]:
# Plot data
# Left panel: the full simulated series. Right panel: the first 100 rows only.
fig, axarr = plt.subplots(nrows=1, ncols=2, figsize=(16, 5))
full_ax, zoom_ax = axarr
full_ax.plot(X_np)
zoom_ax.plot(X_np[:100])
plt.show()

In [None]:
# Set up model
# The model lag (5) is larger than the simulated var_lag (3), which leaves
# room for the lag-selection penalty to prune the unused lags.
lag = 5
# presumably a single hidden layer with 10 units — confirm against cMLP
hidden = [10]
# `.to(device)` places the model on whichever device was selected, instead of
# unconditionally requiring CUDA as `.cuda(...)` does.
cmlp = cMLP(p, lag, hidden).to(device)

In [None]:
# Pretrain (no regularization)
check_every = 1000
train_loss_list = train_model_adam(
    cmlp, X, lr=1e-2, niter=10000, check_every=check_every)

# Plot loss function
# The x-axis is rebuilt from the list length, assuming one loss value is
# recorded every `check_every` iterations (confirm against train_model_adam).
iterations = check_every * np.arange(len(train_loss_list))
plt.figure(figsize=(10, 5))
plt.plot(iterations, train_loss_list)
plt.title('Pretraining')
plt.xlabel('Iterations')
plt.ylabel('MSE')
plt.show()

In [None]:
# Train with GISTA
# The call returns two lists that are plotted later as "Train loss" and
# "Train MSE".
# presumably penalty='H' selects the hierarchical penalty used for lag
# selection — confirm against train_model_gista
check_every = 1000
train_loss_list, train_mse_list = train_model_gista(
    cmlp,
    X,
    lam=0.012,
    lam_ridge=1e-4,
    lr=0.02,
    penalty='H',
    max_iter=50000,
    check_every=check_every,
)

In [None]:
# Loss function plot
# One panel per recorded curve; both share the same x-axis construction
# (one recorded value per `check_every` iterations).
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))

curves = (
    ('Train loss', train_loss_list),
    ('Train MSE', train_mse_list),
)
for ax, (title, values) in zip(axarr, curves):
    ax.plot(check_every * np.arange(len(values)), values)
    ax.set_title(title)

plt.show()

In [None]:
# Verify learned Granger causality
# Pull the estimated GC matrix off the device as a NumPy array so it can be
# compared element-wise with the ground truth.
GC_est = cmlp.GC().cpu().data.numpy()

true_usage = 100 * np.mean(GC)
est_usage = 100 * np.mean(GC_est)
accuracy = 100 * np.mean(GC == GC_est)
print('True variable usage = %.2f%%' % true_usage)
print('Estimated variable usage = %.2f%%' % est_usage)
print('Accuracy = %.2f%%' % accuracy)

# Make figures
fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
ax_true, ax_est = axarr
ax_true.imshow(GC, cmap='Blues')
# The explicit extent puts cell (i, j) at x in [j, j+1], y in [i, i+1], which
# is what the rectangle coordinates below rely on.
ax_est.imshow(GC_est, cmap='Blues', vmin=0, vmax=1, extent=(0, p, p, 0))

for ax, title in zip(axarr, ('GC actual', 'GC estimated')):
    ax.set_title(title)
    ax.set_ylabel('Affected series')
    ax.set_xlabel('Causal series')
    ax.set_xticks([])
    ax.set_yticks([])

# Mark disagreements
# Outline in red every cell where the estimate differs from the truth.
for i, j in np.ndindex(p, p):
    if GC[i, j] != GC_est[i, j]:
        rect = plt.Rectangle((j, i-0.05), 1, 1, facecolor='none', edgecolor='red', linewidth=1)
        ax_est.add_patch(rect)

plt.show()

In [None]:
# Verify lag selection
# For each affected series, draw the true (lag x series) dependence pattern
# next to the pattern estimated by the model.
for i in range(p):
    # Get true GC
    # Rows are lags, columns are candidate causal series; only the first
    # var_lag lags of the truly causal series are active.
    GC_lag = np.zeros((lag, p))
    GC_lag[:var_lag, GC[i].astype(bool)] = 1.0

    # Get estimated GC
    # Transpose to (lag, series) and reverse the lag axis.
    # presumably the model stores lags oldest-first, so the reversal puts
    # lag 1 on the top row — confirm against cMLP.GC
    GC_est_lag = cmlp.GC(ignore_lag=False, threshold=False)[i].cpu().data.numpy().T[::-1]

    # Make figures
    fig, axarr = plt.subplots(1, 2, figsize=(16, 5))
    panels = (
        (GC_lag, 'Series %d true GC'),
        (GC_est_lag, 'Series %d estimated GC'),
    )
    for ax, (grid, title) in zip(axarr, panels):
        ax.imshow(grid, cmap='Blues', extent=(0, p, lag, 0))
        ax.set_title(title % (i + 1))
        ax.set_ylabel('Lag')
        ax.set_xlabel('Series')
        # Centre the tick labels on the cells and hide the tick marks.
        ax.set_xticks(np.arange(p) + 0.5)
        ax.set_xticklabels(range(p))
        ax.set_yticks(np.arange(lag) + 0.5)
        ax.set_yticklabels(range(1, lag + 1))
        ax.tick_params(axis='both', length=0)

    # Mark nonzeros
    # Dedicated index names are used here so the outer series index `i` is
    # not shadowed.
    for lag_idx, series_idx in np.argwhere(GC_est_lag > 0.0):
        rect = plt.Rectangle((series_idx, lag_idx), 1, 1, facecolor='none', edgecolor='green', linewidth=1.0)
        axarr[1].add_patch(rect)

    plt.show()

In [None]:
# Create a debiased model
# The thresholded GC estimate is used as a mask so the new model only sees
# the interactions that were discovered.
# NOTE(review): uint8 masks are deprecated for indexing in recent PyTorch;
# consider `.bool()` if cMLPSparse accepts it — confirm before changing.
sparsity = cmlp.GC().byte()
# `.to(device)` follows the selected device instead of requiring CUDA.
cmlp_sparse = cMLPSparse(p, sparsity, lag, hidden).to(device)

# Train
check_every = 1000
train_loss_list = train_model_adam(cmlp_sparse, X, lr=1e-3, niter=20000, check_every=check_every, verbose=1)

# Plot loss function
# The x-axis assumes one loss value is recorded every `check_every` iterations.
plt.figure(figsize=(10, 5))
plt.title('Debiased model training')
plt.ylabel('MSE')
plt.xlabel('Iterations')
plt.plot(check_every * np.arange(len(train_loss_list)), train_loss_list)
plt.show()

In [None]:
# Get optimal forecasts using VAR parameters
# Row k (after the transpose) is the one-step forecast of time step
# k + var_lag, built from the var_lag observations that precede it.
X_optimal_forecast = np.zeros((p, T-var_lag))
for t in range(T-var_lag):
    X_optimal_forecast[:, t] = np.dot(beta, X_np.T[:, t:(t+var_lag)].flatten(order='F'))
X_optimal_forecast = X_optimal_forecast.T

# Forecast using debiased cMLP
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    X_pred = cmlp_sparse(X)

# Plot actual data and forecasts
num_points = 10

# Put all three curves on the same time steps. The model was built with `lag`
# (not `var_lag`) past observations, so its forecasts start later than the
# optimal VAR forecasts.
# NOTE(review): assumes X_pred[:, k] is the forecast of X[:, lag + k], i.e. one
# output per length-`lag` input window — confirm against the cMLP forward pass.
start = max(lag, var_lag)
steps = np.arange(start, start + num_points)
cmlp_offset = start - lag
optimal_offset = start - var_lag

for i in range(p):
    plt.figure(figsize=(10, 5))
    plt.plot(steps, X[0, start:start+num_points, i].cpu().numpy(), label='Actual')
    plt.plot(steps, X_pred[0, cmlp_offset:cmlp_offset+num_points, i].cpu().numpy(), label='cMLP forecasting')
    plt.plot(steps, X_optimal_forecast[optimal_offset:optimal_offset+num_points, i], label='Optimal forecasting')
    plt.xlabel('Time step')
    plt.legend(loc='upper right')
    plt.title('Series %d forecasting' % (i + 1))
    plt.show()