In [2]:
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy

class LSTM(nn.Module):
    def __init__(self, num_series, hidden):
        '''
        LSTM model with output layer to generate predictions.

        Args:
          num_series: number of input time series (the LSTM input size).
          hidden: number of hidden units.
        '''
        super(LSTM, self).__init__()
        self.p = num_series
        self.hidden = hidden

        # Set up network.
        self.lstm = nn.LSTM(num_series, hidden, batch_first=True)
        self.lstm.flatten_parameters()
        # A kernel-size-1 Conv1d over the time axis applies the same linear
        # map (hidden -> 1) independently at every time step.
        self.linear = nn.Conv1d(hidden, 1, 1)

    def init_hidden(self, batch):
        '''
        Initialize zero hidden and cell states for the LSTM.

        The states are created on the same device AND with the same dtype as
        the LSTM weights, so the model keeps working after e.g. `.double()`
        (float32 states fed to a float64 LSTM would fail).

        Args:
          batch: batch size.
        Returns:
          (h0, c0): two zero tensors of shape (1, batch, hidden).
        '''
        weight = self.lstm.weight_ih_l0
        return (torch.zeros(1, batch, self.hidden,
                            device=weight.device, dtype=weight.dtype),
                torch.zeros(1, batch, self.hidden,
                            device=weight.device, dtype=weight.dtype))

    def forward(self, X, hidden=None):
        '''
        Run the LSTM over the sequence and apply the output layer.

        Args:
          X: tensor of shape (batch, T, num_series) (the LSTM is batch_first).
          hidden: optional (h, c) state tuple; zero states are used when None.
        Returns:
          pred: tensor of shape (batch, T, 1).
          hidden: the final (h, c) LSTM states.
        '''
        # Set up hidden state.
        if hidden is None:
            hidden = self.init_hidden(X.shape[0])

        # Apply LSTM.
        X, hidden = self.lstm(X, hidden)

        # Calculate predictions using output layer. Conv1d expects
        # (batch, channels, T), so move the hidden features to dim 1 and back.
        X = X.transpose(2, 1)
        X = self.linear(X)
        return X.transpose(2, 1), hidden


class cLSTM(nn.Module):
    def __init__(self, num_series, hidden):
        '''
        Component-wise LSTM: one independent LSTM network per time series.

        Every network sees all `num_series` input series; network i produces
        the prediction for series i.

        Args:
          num_series: dimensionality p of the multivariate time series.
          hidden: number of units in each LSTM cell.
        '''
        super(cLSTM, self).__init__()
        self.p = num_series
        self.hidden = hidden

        # Build the per-series sub-networks in series order.
        sub_networks = []
        for _ in range(num_series):
            sub_networks.append(LSTM(num_series, hidden))
        self.networks = nn.ModuleList(sub_networks)

    def forward(self, X, hidden=None):
        '''
        Run every per-series network on the same input.

        Args:
          X: torch tensor of shape (batch, T, p).
          hidden: optional sequence holding one LSTM state per network;
            when None, every network starts from its own zero state.
        Returns:
          pred: predictions concatenated along dim 2; column i comes from
            network i.
          hidden: tuple with the final LSTM state of each network.
        '''
        states = [None] * self.p if hidden is None else hidden
        per_series = []
        for i in range(self.p):
            per_series.append(self.networks[i](X, states[i]))
        # Split the (prediction, state) pairs into two parallel tuples.
        preds, new_states = zip(*per_series)
        return torch.cat(preds, dim=2), new_states

    def GC(self, threshold=True):
        '''
        Extract learned Granger causality from the input-to-hidden weights.

        Args:
          threshold: if True, return an int indicator of which norms are
            nonzero; if False, return the raw norms.
        Returns:
          GC: (p x p) tensor. Row i holds the column norms of network i's
            `weight_ih_l0`, so entry (i, j) indicates whether variable j is
            Granger causal of variable i.
            Note: with threshold=False the result is still attached to the
            autograd graph (it carries a grad_fn); detach before converting
            it to NumPy.
        '''
        norms = torch.stack([
            torch.norm(net.lstm.weight_ih_l0, dim=0) for net in self.networks
        ])
        return (norms > 0).int() if threshold else norms

In [17]:
# Seed the RNG so the randomly initialised LSTM weights (and therefore the
# weight norms displayed in the cells below) are reproducible on a fresh kernel.
SEED = 0
NUM_SERIES = 10
HIDDEN = 1
torch.manual_seed(SEED)
model = cLSTM(NUM_SERIES, HIDDEN)

In [18]:
# Column norms of network 0's input-to-hidden weights: one value per input series.
model.networks[0].lstm.weight_ih_l0.norm(dim=0)

tensor([0.6173, 1.4973, 1.2667, 1.1568, 0.5270, 1.0651, 0.8881, 1.6661, 1.2385,
        1.0523], grad_fn=<CopyBackwards>)

In [19]:
# Per-network column norms of the input weights (mirrors cLSTM.GC(threshold=False)).
GC = [net.lstm.weight_ih_l0.norm(dim=0) for net in model.networks]

In [20]:
# Stack the per-network norm vectors row-wise: row i comes from model.networks[i].
GC = torch.stack(GC, dim=0)

In [21]:
# Raw (unthresholded) norms: entry (i, j) is the norm of the weights feeding input series j into network i.
GC

tensor([[0.6173, 1.4973, 1.2667, 1.1568, 0.5270, 1.0651, 0.8881, 1.6661, 1.2385,
         1.0523],
        [0.9962, 0.8545, 1.4933, 0.7897, 1.3346, 1.3584, 0.8313, 1.3247, 1.0450,
         0.7429],
        [0.4315, 1.2793, 1.5651, 1.5193, 1.1819, 1.3165, 1.4299, 0.5405, 1.3209,
         1.5309],
        [0.8698, 1.0056, 0.6390, 1.2231, 1.3162, 1.3827, 1.2765, 1.1480, 1.3012,
         1.0124],
        [1.0124, 1.0345, 0.7956, 1.0513, 1.5305, 1.0449, 1.2376, 1.2703, 1.0327,
         1.2977],
        [0.8290, 1.3051, 1.2378, 1.0446, 1.3819, 1.0458, 0.9777, 1.3437, 1.0110,
         0.6101],
        [1.0669, 1.2738, 0.9987, 1.0002, 0.8015, 1.0384, 1.0044, 0.9458, 1.1482,
         1.3276],
        [1.2178, 1.1900, 1.1593, 0.6082, 0.6262, 0.7715, 1.6707, 1.2949, 1.4524,
         0.9718],
        [1.1711, 1.2581, 1.0922, 0.9312, 0.8908, 1.2116, 1.5612, 1.3664, 1.4758,
         0.5883],
        [0.6969, 0.7921, 1.3447, 1.4845, 0.5037, 1.2211, 0.9846, 0.6048, 0.8401,
         0.7758]], grad_fn=<