In [1]:
import torch

# Record the installed PyTorch version alongside the results below.
print(torch.__version__)

1.12.1+cu102


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Optional, Tuple

from torch.nn import Parameter
from torch_geometric.nn import ChebConv
from torch_geometric.nn.inits import glorot, zeros


class GCLSTM(torch.nn.Module):
    r"""Graph-convolutional LSTM cell built on Chebyshev graph convolutions.

    Every gate adds a dense projection of the node features ``X`` (weights
    ``W_*``), a :class:`ChebConv` applied to the previous hidden state ``H``
    (layers ``conv_*``) and a bias row ``b_*``.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): Forwarded to the :class:`ChebConv` layers; if
            set to :obj:`False` they will not learn an additive bias.
            Note that the per-gate ``b_*`` parameters of this cell are created
            regardless of this flag. (default: :obj:`True`)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.normalization = normalization
        self.bias = bias
        self._create_parameters_and_layers()
        self._set_parameters()

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------
    def _new_gate_conv(self):
        """Build the ChebConv that one gate applies to the hidden state."""
        return ChebConv(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            normalization=self.normalization,
            bias=self.bias,
        )

    def _new_gate_weight(self):
        """Allocate an (in_channels, out_channels) input projection.

        The values are uninitialised here; ``_set_parameters`` fills them.
        """
        return Parameter(torch.empty(self.in_channels, self.out_channels))

    def _new_gate_bias(self):
        """Allocate a (1, out_channels) bias row (filled by ``_set_parameters``)."""
        return Parameter(torch.empty(1, self.out_channels))

    def _create_input_gate_parameters_and_layers(self):
        """Create ``conv_i``, ``W_i`` and ``b_i`` for the input gate."""
        self.conv_i = self._new_gate_conv()
        self.W_i = self._new_gate_weight()
        self.b_i = self._new_gate_bias()

    def _create_forget_gate_parameters_and_layers(self):
        """Create ``conv_f``, ``W_f`` and ``b_f`` for the forget gate."""
        self.conv_f = self._new_gate_conv()
        self.W_f = self._new_gate_weight()
        self.b_f = self._new_gate_bias()

    def _create_cell_state_parameters_and_layers(self):
        """Create ``conv_c``, ``W_c`` and ``b_c`` for the candidate cell state."""
        self.conv_c = self._new_gate_conv()
        self.W_c = self._new_gate_weight()
        self.b_c = self._new_gate_bias()

    def _create_output_gate_parameters_and_layers(self):
        """Create ``conv_o``, ``W_o`` and ``b_o`` for the output gate."""
        self.conv_o = self._new_gate_conv()
        self.W_o = self._new_gate_weight()
        self.b_o = self._new_gate_bias()

    def _create_parameters_and_layers(self):
        """Create all gate layers in the order input, forget, cell, output."""
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        """Glorot-initialise the ``W_*`` projections and zero the ``b_*`` rows."""
        # The order (i, f, c, o) is kept so that seeded runs draw the same values.
        for weight in (self.W_i, self.W_f, self.W_c, self.W_o):
            glorot(weight)
        for bias_row in (self.b_i, self.b_f, self.b_c, self.b_o):
            zeros(bias_row)

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    def _zero_state(self, X):
        """Return a zero (num_rows_of_X, out_channels) state matching ``X``.

        The tensor is allocated directly with the dtype and device of ``X`` so
        that it is compatible with a model converted via ``.double()`` /
        ``.half()`` and never takes a detour through CPU memory.
        """
        return X.new_zeros(X.shape[0], self.out_channels)

    def _set_hidden_state(self, X, H):
        """Return ``H`` unchanged, or a zero hidden state when ``H`` is None."""
        if H is None:
            H = self._zero_state(X)
        return H

    def _set_cell_state(self, X, C):
        """Return ``C`` unchanged, or a zero cell state when ``C`` is None."""
        if C is None:
            C = self._zero_state(X)
        return C

    # ------------------------------------------------------------------
    # Gate computations
    # ------------------------------------------------------------------
    def _gate_preactivation(
        self, X, edge_index, edge_weight, H, lambda_max, weight, conv, bias_row
    ):
        """Compute ``X @ weight + conv(H) + bias_row`` for a single gate."""
        Z = torch.matmul(X, weight)
        Z = Z + conv(H, edge_index, edge_weight, lambda_max=lambda_max)
        return Z + bias_row

    def _calculate_input_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        """Input gate: sigmoid of the gate pre-activation (``C`` is not used)."""
        return torch.sigmoid(
            self._gate_preactivation(
                X, edge_index, edge_weight, H, lambda_max,
                self.W_i, self.conv_i, self.b_i,
            )
        )

    def _calculate_forget_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        """Forget gate: sigmoid of the gate pre-activation (``C`` is not used)."""
        return torch.sigmoid(
            self._gate_preactivation(
                X, edge_index, edge_weight, H, lambda_max,
                self.W_f, self.conv_f, self.b_f,
            )
        )

    def _calculate_cell_state(self, X, edge_index, edge_weight, H, C, I, F, lambda_max):
        """New cell state: ``F * C + I * tanh(candidate)``."""
        T = torch.tanh(
            self._gate_preactivation(
                X, edge_index, edge_weight, H, lambda_max,
                self.W_c, self.conv_c, self.b_c,
            )
        )
        return F * C + I * T

    def _calculate_output_gate(self, X, edge_index, edge_weight, H, C, lambda_max):
        """Output gate: sigmoid of the gate pre-activation (``C`` is not used)."""
        return torch.sigmoid(
            self._gate_preactivation(
                X, edge_index, edge_weight, H, lambda_max,
                self.W_o, self.conv_o, self.b_o,
            )
        )

    def _calculate_hidden_state(self, O, C):
        """New hidden state: ``O * tanh(C)``."""
        return O * torch.tanh(C)

    def forward(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: Optional[torch.Tensor] = None,
        H: Optional[torch.Tensor] = None,
        C: Optional[torch.Tensor] = None,
        lambda_max: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state and cell state
        matrices are not present when the forward pass is called these are
        initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.
            * **C** *(PyTorch Float Tensor, optional)* - Cell state matrix for all nodes.
            * **lambda_max** *(PyTorch Tensor, optional but mandatory if normalization is not sym)* - Largest eigenvalue of Laplacian.

        Return types:
            * **(H, C)** *(tuple of two PyTorch Float Tensors)* - Updated
              hidden state and cell state matrices for all nodes.
        """
        H = self._set_hidden_state(X, H)
        C = self._set_cell_state(X, C)
        # Input and forget gates are computed from the *previous* H before the
        # cell state is overwritten; the output gate also uses the previous H.
        I = self._calculate_input_gate(X, edge_index, edge_weight, H, C, lambda_max)
        F = self._calculate_forget_gate(X, edge_index, edge_weight, H, C, lambda_max)
        C = self._calculate_cell_state(X, edge_index, edge_weight, H, C, I, F, lambda_max)
        O = self._calculate_output_gate(X, edge_index, edge_weight, H, C, lambda_max)
        H = self._calculate_hidden_state(O, C)
        return H, C


In [283]:
import networkx as nx
import numpy as np

def create_mock_data(number_of_nodes, edge_per_node, in_channels):
    """
    Creating a mock feature matrix and edge index.

    A Watts-Strogatz graph (rewiring probability 0.5) supplies the edges and
    the node features are drawn uniformly from [-1, 1).

    Arg types:
        * **number_of_nodes** *(int)* - Number of nodes in the graph.
        * **edge_per_node** *(int)* - Ring-lattice neighbour count passed to
          ``nx.watts_strogatz_graph``.
        * **in_channels** *(int)* - Number of features per node.

    Return types:
        * **X** *(PyTorch Float Tensor)* - Features, shape
          ``(number_of_nodes, in_channels)``.
        * **edge_index** *(PyTorch Long Tensor)* - Edge list, shape
          ``(2, number_of_edges)``. Each edge is listed once, in the
          orientation returned by ``graph.edges()``.
    """
    graph = nx.watts_strogatz_graph(number_of_nodes, edge_per_node, 0.5)

    # Force int64 and an (E, 2) layout so that a graph without edges still
    # yields a well-formed (2, 0) index instead of a 1-D float array, and so
    # that the dtype does not depend on the platform's default integer.
    edge_array = np.array(list(graph.edges()), dtype=np.int64).reshape(-1, 2)
    edge_index = torch.tensor(edge_array.T, dtype=torch.long)

    features = np.random.uniform(-1, 1, (number_of_nodes, in_channels))
    X = torch.tensor(features, dtype=torch.float32)
    return X, edge_index

def create_mock_edge_weight(edge_index):
    """
    Draw one random weight per edge.

    Arg types:
        * **edge_index** *(tensor-like with a ``shape``)* - Only
          ``edge_index.shape[1]`` (the number of edges) is read.

    Return types:
        * **edge_weight** *(PyTorch Float Tensor)* - One value per edge,
          sampled uniformly from [0, 1) with NumPy's global random state.
    """
    num_edges = edge_index.shape[1]
    samples = np.random.uniform(0, 1, num_edges)
    return torch.as_tensor(samples, dtype=torch.float32)

In [284]:
from torch_geometric_temporal.nn.recurrent import GCLSTM

def test_gc_lstm_layer():
    """
    Smoke-test the GCLSTM layer on a random graph.

    The layer is applied once without states and then four more times with
    the states fed back in; after every step both H and C must have shape
    (number of nodes, out_channels). The final (H, C) pair is returned.

    Uses whichever ``GCLSTM`` is currently bound in the global namespace.
    """
    num_nodes, neighbours_per_node = 100, 10
    feat_dim, hidden_dim, cheb_order = 64, 16, 2
    follow_up_steps = 4  # recurrent steps after the initial state-less call

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X, edge_index = create_mock_data(num_nodes, neighbours_per_node, feat_dim)
    X = X.to(device)
    edge_index = edge_index.to(device)
    edge_weight = create_mock_edge_weight(edge_index).to(device)

    layer = GCLSTM(in_channels=feat_dim, out_channels=hidden_dim, K=cheb_order).to(device)
    expected_shape = (num_nodes, hidden_dim)

    # First step: no previous state, the layer starts from zeros.
    H, C = layer(X, edge_index, edge_weight)
    assert H.shape == expected_shape
    assert C.shape == expected_shape

    # Remaining steps: feed the states back in.
    for _ in range(follow_up_steps):
        H, C = layer(X, edge_index, edge_weight, H, C)
        assert H.shape == expected_shape
        assert C.shape == expected_shape

    return H, C

In [285]:
# Run the smoke test and print the final (H, C) pair it returns.
print(test_gc_lstm_layer())

(tensor([[-0.0323,  0.3428, -0.2118,  ...,  0.0651,  0.5945, -0.0238],
        [ 0.3727,  0.2327, -0.0910,  ..., -0.1904,  0.4503, -0.5138],
        [ 0.0144,  0.0845, -0.0661,  ..., -0.0485,  0.1646, -0.1593],
        ...,
        [-0.5774,  0.0808, -0.0353,  ..., -0.0031,  0.5244,  0.2021],
        [-0.2700,  0.5079,  0.3706,  ...,  0.0141, -0.0491,  0.5180],
        [-0.0474, -0.0568, -0.2530,  ..., -0.0587,  0.3075, -0.0203]],
       device='cuda:0', grad_fn=<MulBackward0>), tensor([[-0.0889,  1.1728, -0.5472,  ...,  0.2536,  1.1242, -0.0355],
        [ 0.5635,  1.1498, -0.3582,  ..., -1.3228,  0.9414, -0.8022],
        [ 0.0274,  0.3380, -1.0526,  ..., -0.0832,  0.5213, -0.2658],
        ...,
        [-1.1701,  0.1954, -0.2491,  ..., -0.0268,  1.2052,  0.4665],
        [-1.1908,  0.8827,  1.0925,  ...,  0.0364, -1.3365,  0.9876],
        [-0.0873, -0.1680, -0.4009,  ..., -0.1585,  0.4510, -0.0413]],
       device='cuda:0', grad_fn=<AddBackward0>))


In [286]:
def k_hop_based_subgraph_generator_BFS(k, feature_vector, adjacency_matrix):
    """
    Build, for every node, the subgraph induced by its k-hop neighbourhood.

    Args:
        k (int): Maximum hop distance from the root node (inclusive).
        feature_vector: Per-node features; must support ``len()`` and
            indexing with a list of node ids (e.g. a NumPy array or tensor).
        adjacency_matrix: Square matrix supporting ``[u][v]`` element access
            and ``[:, ids][ids, :]`` fancy indexing. Only entries equal to
            ``1`` are treated as edges.

    Returns:
        list: One ``[sub_feature, sub_adj]`` pair per node, in node order.
        Rows/columns follow BFS discovery order, so index 0 of every
        subgraph is the root node itself.
    """
    from collections import deque

    N = len(feature_vector)

    # Scan the matrix once instead of once per BFS visit; neighbours are kept
    # in ascending index order, which is the order the BFS discovers them in.
    neighbours = [
        [v for v in range(N) if adjacency_matrix[u][v] == 1]
        for u in range(N)
    ]

    def BFS(source):
        """Return the nodes within ``k`` hops of ``source`` in BFS order."""
        visited = [False] * N
        visited[source] = True
        queue = deque([(source, 0)])  # deque gives O(1) pops from the left
        reached = []

        while queue:
            node, dist = queue.popleft()
            # Distances leave the queue in non-decreasing order, so the first
            # node beyond k hops means everything that remains is too far.
            if dist > k:
                break
            reached.append(node)

            for v in neighbours[node]:
                if not visited[v]:
                    visited[v] = True
                    queue.append((v, dist + 1))
        return reached

    subgraphs = []
    for i in range(N):
        candidate_list = BFS(i)
        sub_feature = feature_vector[candidate_list]
        # Select columns first, then rows, to obtain the induced sub-matrix.
        sub_adj = adjacency_matrix[:, candidate_list][candidate_list, :]
        subgraphs.append([sub_feature, sub_adj])
    return subgraphs

In [322]:
def src_city_model(
    number_of_nodes=4,
    edge_per_node=2,
    in_channels=64,
    out_channels=16,
    K=2,
    k=1,
    number_of_layers=5,
):
    """
    Run a separate GCLSTM over the k-hop subgraph of every node of a mock graph.

    A random graph with random features is generated and turned into a
    symmetric 0/1 adjacency matrix; both are printed. For every node its
    k-hop subgraph is extracted, a freshly initialised GCLSTM is applied to it
    ``number_of_layers`` times (feeding H and C back in each time), and the
    final hidden state of the subgraph's root node is collected.

    Args:
        number_of_nodes (int): Number of nodes of the mock graph.
        edge_per_node (int): Neighbour count passed to ``create_mock_data``.
        in_channels (int): Number of input features per node.
        out_channels (int): Hidden size of each GCLSTM.
        K (int): Chebyshev filter size of each GCLSTM.
        k (int): Hop radius used to extract the subgraphs.
        number_of_layers (int): How many times the same GCLSTM instance is
            applied recurrently to a subgraph; must be at least 1.

    Returns:
        torch.Tensor: Shape ``(number_of_nodes, out_channels)``; row ``n`` is
        the final hidden state of node ``n`` within its own subgraph.

    Raises:
        ValueError: If ``number_of_layers`` is smaller than 1.
    """
    if number_of_layers < 1:
        raise ValueError("number_of_layers must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X, edge_index = create_mock_data(number_of_nodes, edge_per_node, in_channels)
    X = X.to(device)
    edge_index = edge_index.to(device)
    edge_weight = torch.ones(edge_index.shape[1], device=device)

    # Symmetric dense adjacency: every listed edge is written in both directions.
    adj = torch.zeros(number_of_nodes, number_of_nodes, device=device)
    adj[edge_index[0], edge_index[1]] = edge_weight
    adj[edge_index[1], edge_index[0]] = edge_weight

    print(X)
    print(adj)

    subgraphs = k_hop_based_subgraph_generator_BFS(k, X, adj)
    root_states = []

    for sub_feature, sub_adj in subgraphs:
        layer = GCLSTM(in_channels=in_channels, out_channels=out_channels, K=K).to(device)

        # Row-major list of the entries equal to 1, laid out as (2, num_edges).
        # Unlike building the list by hand this stays 2-D for an edgeless subgraph.
        sub_edge_index = torch.nonzero(sub_adj == 1).t().contiguous()
        sub_edge_weight = torch.ones(sub_edge_index.shape[1], device=device)

        # Passing None on the first step lets the layer start from zero states.
        H, C = None, None
        for _ in range(number_of_layers):
            H, C = layer(sub_feature, sub_edge_index, sub_edge_weight, H, C)

        # The subgraph generator lists the root node first, so H[0] is its state.
        root_states.append(H[0])

    return torch.stack(root_states, 0)

In [323]:
# Run the model; in a notebook the returned tensor is shown as the cell result.
src_city_model()

tensor([[-2.1266e-01, -8.6506e-01,  2.3652e-01, -6.4157e-02,  9.0521e-02,
         -1.4844e-01, -8.5602e-01, -7.8615e-01,  1.3495e-01,  7.5214e-01,
          9.4506e-01, -4.6812e-01,  5.2390e-01, -5.6317e-01,  4.4783e-01,
          7.8202e-01, -7.0438e-01, -4.3343e-01,  7.4851e-01,  4.5329e-01,
          1.0470e-01,  5.4400e-01, -2.0877e-01,  6.1310e-01, -8.2077e-01,
          1.3996e-01, -2.4717e-01,  5.1435e-01,  8.8744e-01,  7.9021e-02,
         -9.9111e-01,  7.0411e-01,  8.5922e-01, -9.4183e-01,  6.7335e-01,
          8.0959e-01, -4.8604e-02, -5.1650e-01,  4.0393e-01,  5.1963e-03,
         -7.8824e-01,  5.2493e-01, -9.9614e-01,  2.8320e-01, -3.0696e-01,
          2.0003e-01,  9.7612e-02,  3.4456e-03,  4.2026e-01, -4.2095e-01,
         -7.7523e-02, -4.3970e-01, -1.2322e-01, -9.2096e-01,  5.1951e-01,
         -4.4099e-01, -5.8768e-01, -1.5570e-02,  9.3897e-01, -1.4384e-01,
         -6.9912e-01,  7.2026e-01,  7.0450e-01,  8.1244e-01],
        [-7.7298e-01, -9.9869e-01,  5.9110e-01,  4

tensor([[ 0.0601, -0.2638, -0.0075, -0.1026,  0.0840,  0.1801,  0.1399,  0.1473,
          0.1354,  0.0799,  0.2612,  0.2260,  0.0583,  0.3097,  0.0861, -0.1621],
        [-0.2847, -0.1465,  0.3910,  0.1419, -0.0216,  0.2647,  0.0731,  0.1530,
         -0.1393, -0.2601, -0.1239,  0.3101,  0.0092,  0.0693, -0.2516, -0.2560],
        [ 0.1839, -0.2090,  0.1885, -0.0120, -0.1155,  0.0942, -0.2914, -0.1912,
          0.1138, -0.3456,  0.0674,  0.2487,  0.1117,  0.0846, -0.0811,  0.2475],
        [ 0.0530,  0.0848, -0.5887, -0.2768, -0.1067, -0.2809, -0.0773,  0.0977,
          0.1195,  0.3115,  0.3492, -0.1180,  0.1107,  0.1006,  0.5462, -0.0227]],
       device='cuda:0', grad_fn=<StackBackward0>)