In [1]:
import torch
from torch.nn import Parameter
from torch_geometric.nn import HeteroConv, SAGEConv, GCNConv, GATConv
from torch_geometric.nn.inits import glorot
import torch.nn as nn

# Some RNN versions

In [2]:
class LSTMModel(nn.Module):
    """Sequence-to-one LSTM built from a single ``nn.LSTMCell``.

    The input sequence is consumed one time step at a time; the hidden state
    after the last step is projected by a linear layer to ``output_size``
    values.

    Args:
        input_size (int): Number of features per time step.
        hidden_size (int): Size of the LSTM hidden and cell states.
        output_size (int): Number of values produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        """Run the cell over ``x`` and project the final hidden state.

        Args:
            x (Tensor): Input indexed as ``x[:, t, :]``, i.e.
                ``(batch_size, sequence_length, input_size)``.

        Returns:
            Tensor: Output of the linear layer, ``(batch_size, output_size)``.
        """
        batch_size = x.size(0)
        # new_zeros inherits both device and dtype from x, so the initial states
        # also match a model converted with .double()/.half() and are allocated
        # directly on the target device instead of being copied from the CPU.
        hidden_state = x.new_zeros(batch_size, self.hidden_size)
        cell_state = x.new_zeros(batch_size, self.hidden_size)

        for t in range(x.size(1)):
            hidden_state, cell_state = self.lstm_cell(x[:, t, :], (hidden_state, cell_state))

        # Single projection of the last hidden state (one prediction per sample).
        output = self.fc(hidden_state)
        return output

# Usage example
input_size = 5   # dimensionality of the input features
hidden_size = 32 # size of the hidden state
output_size = 1  # dimensionality of the output (e.g. predicting the next value)

model = LSTMModel(input_size, hidden_size, output_size)
input_data = torch.randn(32, 10, input_size)  # sample input: (batch_size, sequence_length, input_size)
output = model(input_data)
print(output.shape)  # expected output: (32, 1)

torch.Size([32, 1])


In [3]:
class LSTMModel(nn.Module):
    """Direct multi-step forecaster built from a single ``nn.LSTMCell``.

    The input sequence is consumed one time step at a time; the final hidden
    state is projected to ``out_steps`` values in one shot and returned with a
    trailing singleton feature axis.

    NOTE(review): this class reuses the name ``LSTMModel`` of the class defined
    in an earlier cell and therefore shadows it once this cell is executed.

    Args:
        input_size (int): Number of features per time step.
        hidden_size (int): Size of the LSTM hidden and cell states.
        out_steps (int): Number of future steps predicted by the linear layer.
    """

    def __init__(self, input_size, hidden_size, out_steps):
        super(LSTMModel, self).__init__()
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, out_steps)  # one output unit per forecast step
        self.hidden_size = hidden_size

    def forward(self, x):
        """Encode ``x`` and predict all ``out_steps`` values at once.

        Args:
            x (Tensor): Input indexed as ``x[:, t, :]``, i.e.
                ``(batch_size, sequence_length, input_size)``.

        Returns:
            Tensor: Predictions of shape ``(batch_size, out_steps, 1)``.
        """
        batch_size = x.size(0)
        # new_zeros inherits device and dtype from x (works for .double()/.half()
        # models and avoids a CPU allocation followed by a device copy).
        hidden_state = x.new_zeros(batch_size, self.hidden_size)
        cell_state = x.new_zeros(batch_size, self.hidden_size)

        for t in range(x.size(1)):
            hidden_state, cell_state = self.lstm_cell(x[:, t, :], (hidden_state, cell_state))

        # Multi-step prediction: all steps come from the same final hidden state.
        output = self.fc(hidden_state)

        # (batch_size, out_steps) -> (batch_size, out_steps, 1); the trailing 1 is
        # the number of predicted features per step.
        return output.unsqueeze(-1)

# Usage example
input_size = 1   # dimensionality of the input features
hidden_size = 64 # size of the hidden state
out_steps = 5    # number of future steps to predict

model = LSTMModel(input_size, hidden_size, out_steps)
input_data = torch.randn(32, 10, input_size)  # sample input: (batch_size, sequence_length, input_size)
output = model(input_data)
print(output.shape)  # expected output: (32, out_steps, 1)

torch.Size([32, 5, 1])


In [5]:
class LSTMForecastingModel(nn.Module):
    """Autoregressive multi-step forecaster built from a single ``nn.LSTMCell``.

    The input sequence is first encoded step by step. Each forecast is the
    linear projection of the current hidden state, and every forecast except
    the last one is fed back into the cell as the next input.

    Because forecasts are fed back into ``lstm_cell``, forecasting more than
    one step requires ``output_size == input_size``.

    Args:
        input_size (int): Number of features per time step.
        hidden_size (int): Size of the LSTM hidden and cell states.
        output_size (int): Number of features of each forecast.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMForecastingModel, self).__init__()
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, x, steps_ahead):
        """Encode ``x`` and roll the model forward ``steps_ahead`` steps.

        Args:
            x (Tensor): Input indexed as ``x[:, t, :]``, i.e.
                ``(batch_size, sequence_length, input_size)``.
            steps_ahead (int): Number of future steps to predict.

        Returns:
            Tensor: Forecasts of shape ``(batch_size, steps_ahead, output_size)``;
            an empty ``(batch_size, 0, output_size)`` tensor when
            ``steps_ahead <= 0``.
        """
        batch_size = x.size(0)
        # new_zeros inherits device and dtype from x.
        hidden_state = x.new_zeros(batch_size, self.hidden_size)
        cell_state = x.new_zeros(batch_size, self.hidden_size)

        for t in range(x.size(1)):
            hidden_state, cell_state = self.lstm_cell(x[:, t, :], (hidden_state, cell_state))

        if steps_ahead <= 0:
            # torch.cat/stack reject an empty list, so return the empty forecast explicitly.
            return x.new_zeros(batch_size, 0, self.fc.out_features)

        predictions = []
        for step in range(steps_ahead):
            output = self.fc(hidden_state)
            predictions.append(output)
            # Feed the forecast back only if another step follows; the state
            # produced after the last forecast would never be used.
            if step + 1 < steps_ahead:
                hidden_state, cell_state = self.lstm_cell(output, (hidden_state, cell_state))

        # Stack along a new time axis: (batch_size, steps_ahead, output_size).
        return torch.stack(predictions, dim=1)

# Usage example
input_size = 1   # dimensionality of the input features
hidden_size = 64 # size of the hidden state
output_size = 1  # dimensionality of the output

model = LSTMForecastingModel(input_size, hidden_size, output_size)
input_data = torch.randn(32, 10, input_size)  # sample input: (batch_size, sequence_length, input_size)
steps_ahead = 5
output = model(input_data, steps_ahead)
print(output.shape)  # expected output: (32, steps_ahead, 1)

torch.Size([32, 5, 1])


# My model versions

**Randomly changing the number of edges in the graph**

In [63]:
def change_edge_number(edges, attrs, min_remove=200):
    """Randomly drop a subset of edges together with their attributes.

    The number of removed edges is drawn uniformly from
    ``[min_remove, num_edges // 2)``, so strictly fewer than half of the edges
    are ever removed. The surviving edges keep their original relative order.

    Args:
        edges (Tensor): Edge tensor whose columns (dim 1) are the edges.
        attrs (Tensor): Per-edge attributes indexed along dim 0 in the same
            order as the columns of ``edges``.
        min_remove (int, optional): Lower bound on the number of removed edges.
            If the graph is too small for the range above to be non-empty, the
            bound is clamped to ``num_edges // 2 - 1``. (default: ``200``)

    Returns:
        tuple(Tensor, Tensor): ``(remaining_edges, remaining_attrs)``. Graphs
        with fewer than two edges are returned as unchanged copies.
    """
    num_cols = edges.size(1)  # number of edges
    max_remove = num_cols // 2  # exclusive upper bound for the removal count
    if max_remove < 1:
        # Nothing can be removed from a graph with 0 or 1 edges.
        return edges.clone(), attrs.clone()

    # torch.randint requires low < high; clamp so small graphs do not raise.
    low = max(0, min(min_remove, max_remove - 1))
    num_remove = torch.randint(low=low, high=max_remove, size=(1,)).item()
    remove_indices = torch.randperm(num_cols)[:num_remove]  # random edges to drop

    # A boolean mask gives the surviving indices in ascending order in O(n),
    # instead of a per-index membership test against the tensor.
    keep_mask = torch.ones(num_cols, dtype=torch.bool)
    keep_mask[remove_indices] = False
    remaining_indices = torch.nonzero(keep_mask, as_tuple=True)[0]

    remaining_edges = edges[:, remaining_indices.to(edges.device)]
    remaining_attrs = attrs[remaining_indices.to(attrs.device)]
    return remaining_edges, remaining_attrs

In [2]:
class HeteroGCLSTM_SAGE(torch.nn.Module):
    r"""An implementation similar to the Integrated Graph Convolutional Long Short Term
        Memory Cell for heterogeneous Graphs, using :obj:`SAGEConv` for message passing.

        For every gate, the pre-activation of a node type is
        ``X @ W[node_type] + conv(h_dict, edge_index_dict)[node_type] + b[node_type]``.

        Args:
            in_channels_dict (dict of keys=str and values=int): Dimension of each node's input features.
            out_channels (int): Number of output features.
            metadata (tuple): Metadata on node types and edge types in the graphs. Can be generated via PyG method
                :obj:`snapshot.metadata()` where snapshot is a single HeteroData object.
            bias (bool, optional): Forwarded to every :obj:`SAGEConv`; if set to :obj:`False`, the
                convolutions will not learn an additive bias. The per-node-type gate biases
                ``b_i/b_f/b_c/b_o`` are always created. (default: :obj:`True`)
    """

    def __init__(
            self,
            in_channels_dict: dict,
            out_channels: int,
            metadata: tuple,
            bias: bool = True
    ):
        super(HeteroGCLSTM_SAGE, self).__init__()

        self.in_channels_dict = in_channels_dict
        self.out_channels = out_channels
        self.metadata = metadata
        self.bias = bias
        self._create_parameters_and_layers()
        self._set_parameters()

    # ------------------------------------------------------------------
    # Construction helpers (shared by the four gates)
    # ------------------------------------------------------------------
    def _build_conv(self):
        """Return a HeteroConv with one SAGEConv per edge type listed in ``metadata[1]``."""
        return HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1),
                                               out_channels=self.out_channels,
                                               bias=self.bias) for edge_type in self.metadata[1]})

    def _build_weights(self):
        """Return uninitialised ``(in_channels, out_channels)`` weights, one per node type."""
        return nn.ParameterDict({node_type: Parameter(torch.empty(in_channels, self.out_channels))
                                 for node_type, in_channels in self.in_channels_dict.items()})

    def _build_biases(self):
        """Return uninitialised ``(1, out_channels)`` biases, one per node type."""
        return nn.ParameterDict({node_type: Parameter(torch.empty(1, self.out_channels))
                                 for node_type in self.in_channels_dict})

    def _create_input_gate_parameters_and_layers(self):
        self.conv_i = self._build_conv()
        self.W_i = self._build_weights()
        self.b_i = self._build_biases()

    def _create_forget_gate_parameters_and_layers(self):
        self.conv_f = self._build_conv()
        self.W_f = self._build_weights()
        self.b_f = self._build_biases()

    def _create_cell_state_parameters_and_layers(self):
        self.conv_c = self._build_conv()
        self.W_c = self._build_weights()
        self.b_c = self._build_biases()

    def _create_output_gate_parameters_and_layers(self):
        self.conv_o = self._build_conv()
        self.W_o = self._build_weights()
        self.b_o = self._build_biases()

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        """Glorot-initialise all gate weights first, then all gate biases."""
        # The order (W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o) is kept fixed so that
        # seeded runs draw the same random numbers for the same parameters.
        for param_dict in (self.W_i, self.W_f, self.W_c, self.W_o,
                           self.b_i, self.b_f, self.b_c, self.b_o):
            for key in param_dict:
                glorot(param_dict[key])

    # ------------------------------------------------------------------
    # State initialisation
    # ------------------------------------------------------------------
    def _zero_state(self, x_dict):
        """Return one ``(num_nodes, out_channels)`` zero tensor per node type in ``x_dict``."""
        # new_zeros inherits device and dtype from the node features, so the
        # states live on the same device as the inputs (e.g. the GPU).
        return {node_type: X.new_zeros(X.shape[0], self.out_channels) for node_type, X in x_dict.items()}

    def _set_hidden_state(self, x_dict, h_dict):
        if h_dict is None:
            h_dict = self._zero_state(x_dict)
        return h_dict

    def _set_cell_state(self, x_dict, c_dict):
        if c_dict is None:
            c_dict = self._zero_state(x_dict)
        return c_dict

    # ------------------------------------------------------------------
    # Gates
    # ------------------------------------------------------------------
    def _gate_pre_activation(self, x_dict, edge_index_dict, h_dict, weights, conv, biases):
        """Compute ``X @ W + conv(h) + b`` for every node type in ``x_dict``."""
        # NOTE(review): conv_out is indexed by every key of x_dict; presumably
        # HeteroConv only returns destination node types, in which case a node
        # type without incoming edges raises KeyError here - confirm.
        conv_out = conv(h_dict, edge_index_dict)
        return {node_type: torch.matmul(X, weights[node_type]) + conv_out[node_type] + biases[node_type]
                for node_type, X in x_dict.items()}

    def _calculate_input_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, h_dict, self.W_i, self.conv_i, self.b_i)
        return {node_type: torch.sigmoid(I) for node_type, I in pre.items()}

    def _calculate_forget_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, h_dict, self.W_f, self.conv_f, self.b_f)
        return {node_type: torch.sigmoid(F) for node_type, F in pre.items()}

    def _calculate_cell_state(self, x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict):
        pre = self._gate_pre_activation(x_dict, edge_index_dict, h_dict, self.W_c, self.conv_c, self.b_c)
        t_dict = {node_type: torch.tanh(T) for node_type, T in pre.items()}
        # New cell state: forget part of the old state, add the gated candidate.
        return {node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type]
                for node_type, C in c_dict.items()}

    def _calculate_output_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, h_dict, self.W_o, self.conv_o, self.b_o)
        return {node_type: torch.sigmoid(O) for node_type, O in pre.items()}

    def _calculate_hidden_state(self, o_dict, c_dict):
        return {node_type: o_dict[node_type] * torch.tanh(C) for node_type, C in c_dict.items()}

    def forward(
        self,
        x_dict,
        edge_index_dict,
        h_dict=None,
        c_dict=None,
    ):
        """
        Making a forward pass. If the hidden state and cell state
        matrix dicts are not present when the forward pass is called these are
        initialized with zeros (on the device and with the dtype of the node features).

        Arg types:
            * **x_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensors)* - Node features dicts. Can
                be obtained via PyG method :obj:`snapshot.x_dict` where snapshot is a single HeteroData object.
            * **edge_index_dict** *(Dictionary where keys=Tuples and values=PyTorch Long Tensors)* - Graph edge type
                and index dicts. Can be obtained via PyG method :obj:`snapshot.edge_index_dict`.
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                cell state matrix dict for all nodes.

        Return types:
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                cell state matrix dict for all nodes.
        """

        h_dict = self._set_hidden_state(x_dict, h_dict)
        c_dict = self._set_cell_state(x_dict, c_dict)
        i_dict = self._calculate_input_gate(x_dict, edge_index_dict, h_dict, c_dict)
        f_dict = self._calculate_forget_gate(x_dict, edge_index_dict, h_dict, c_dict)
        c_dict = self._calculate_cell_state(x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict)
        o_dict = self._calculate_output_gate(x_dict, edge_index_dict, h_dict, c_dict)
        h_dict = self._calculate_hidden_state(o_dict, c_dict)
        return h_dict, c_dict

In [3]:
class HeteroGCLSTM_GAT_edge(torch.nn.Module):
    r"""An implementation similar to the Integrated Graph Convolutional Long Short Term
        Memory Cell for heterogeneous Graphs, using edge-attribute-aware :obj:`GATConv`
        layers for message passing.

        For every gate, the pre-activation of a node type is
        ``X @ W[node_type] + conv(h_dict, edge_index_dict, edge_attr_dict)[node_type] + b[node_type]``.

        Args:
            in_channels_dict (dict of keys=str and values=int): Dimension of each node's input features.
            out_channels (int): Number of output features.
            metadata (tuple): Metadata on node types and edge types in the graphs. Can be generated via PyG method
                :obj:`snapshot.metadata()` where snapshot is a single HeteroData object.
            bias (bool, optional): Forwarded to every :obj:`GATConv`; if set to :obj:`False`, the
                convolutions will not learn an additive bias. The per-node-type gate biases
                ``b_i/b_f/b_c/b_o`` are always created. (default: :obj:`True`)
            edge_dim (int, optional): Edge feature dimensionality passed to every :obj:`GATConv`.
                (default: :obj:`1`)
    """

    def __init__(
            self,
            in_channels_dict: dict,
            out_channels: int,
            metadata: tuple,
            bias: bool = True,
            edge_dim: int = 1
    ):
        super(HeteroGCLSTM_GAT_edge, self).__init__()

        self.in_channels_dict = in_channels_dict
        self.out_channels = out_channels
        self.metadata = metadata
        self.bias = bias
        self.edge_dim = edge_dim
        self._create_parameters_and_layers()
        self._set_parameters()

    # ------------------------------------------------------------------
    # Construction helpers (shared by the four gates)
    # ------------------------------------------------------------------
    def _build_conv(self):
        """Return a HeteroConv with one single-head GATConv per edge type listed in ``metadata[1]``."""
        return HeteroConv({edge_type: GATConv(in_channels=(-1, -1),
                                              out_channels=self.out_channels,
                                              heads=1,
                                              concat=False,
                                              edge_dim=self.edge_dim,
                                              bias=self.bias,
                                              add_self_loops=False) for edge_type in self.metadata[1]})

    def _build_weights(self):
        """Return uninitialised ``(in_channels, out_channels)`` weights, one per node type."""
        return nn.ParameterDict({node_type: Parameter(torch.empty(in_channels, self.out_channels))
                                 for node_type, in_channels in self.in_channels_dict.items()})

    def _build_biases(self):
        """Return uninitialised ``(1, out_channels)`` biases, one per node type."""
        return nn.ParameterDict({node_type: Parameter(torch.empty(1, self.out_channels))
                                 for node_type in self.in_channels_dict})

    def _create_input_gate_parameters_and_layers(self):
        self.conv_i = self._build_conv()
        self.W_i = self._build_weights()
        self.b_i = self._build_biases()

    def _create_forget_gate_parameters_and_layers(self):
        self.conv_f = self._build_conv()
        self.W_f = self._build_weights()
        self.b_f = self._build_biases()

    def _create_cell_state_parameters_and_layers(self):
        self.conv_c = self._build_conv()
        self.W_c = self._build_weights()
        self.b_c = self._build_biases()

    def _create_output_gate_parameters_and_layers(self):
        self.conv_o = self._build_conv()
        self.W_o = self._build_weights()
        self.b_o = self._build_biases()

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        """Glorot-initialise all gate weights first, then all gate biases."""
        # The order (W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o) is kept fixed so that
        # seeded runs draw the same random numbers for the same parameters.
        for param_dict in (self.W_i, self.W_f, self.W_c, self.W_o,
                           self.b_i, self.b_f, self.b_c, self.b_o):
            for key in param_dict:
                glorot(param_dict[key])

    # ------------------------------------------------------------------
    # State initialisation
    # ------------------------------------------------------------------
    def _zero_state(self, x_dict):
        """Return one ``(num_nodes, out_channels)`` zero tensor per node type in ``x_dict``."""
        # new_zeros inherits device and dtype from the node features, so the
        # states live on the same device as the inputs (e.g. the GPU).
        return {node_type: X.new_zeros(X.shape[0], self.out_channels) for node_type, X in x_dict.items()}

    def _set_hidden_state(self, x_dict, h_dict):
        if h_dict is None:
            h_dict = self._zero_state(x_dict)
        return h_dict

    def _set_cell_state(self, x_dict, c_dict):
        if c_dict is None:
            c_dict = self._zero_state(x_dict)
        return c_dict

    # ------------------------------------------------------------------
    # Gates
    # ------------------------------------------------------------------
    def _gate_pre_activation(self, x_dict, edge_index_dict, edge_attr_dict, h_dict, weights, conv, biases):
        """Compute ``X @ W + conv(h, edge_attr) + b`` for every node type in ``x_dict``."""
        # NOTE(review): conv_out is indexed by every key of x_dict; presumably
        # HeteroConv only returns destination node types, in which case a node
        # type without incoming edges raises KeyError here - confirm.
        conv_out = conv(h_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        return {node_type: torch.matmul(X, weights[node_type]) + conv_out[node_type] + biases[node_type]
                for node_type, X in x_dict.items()}

    def _calculate_input_gate(self, x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, edge_attr_dict, h_dict,
                                        self.W_i, self.conv_i, self.b_i)
        return {node_type: torch.sigmoid(I) for node_type, I in pre.items()}

    def _calculate_forget_gate(self, x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, edge_attr_dict, h_dict,
                                        self.W_f, self.conv_f, self.b_f)
        return {node_type: torch.sigmoid(F) for node_type, F in pre.items()}

    def _calculate_cell_state(self, x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict, i_dict, f_dict):
        pre = self._gate_pre_activation(x_dict, edge_index_dict, edge_attr_dict, h_dict,
                                        self.W_c, self.conv_c, self.b_c)
        t_dict = {node_type: torch.tanh(T) for node_type, T in pre.items()}
        # New cell state: forget part of the old state, add the gated candidate.
        return {node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type]
                for node_type, C in c_dict.items()}

    def _calculate_output_gate(self, x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict):
        # c_dict is accepted for signature symmetry but not used by this gate.
        pre = self._gate_pre_activation(x_dict, edge_index_dict, edge_attr_dict, h_dict,
                                        self.W_o, self.conv_o, self.b_o)
        return {node_type: torch.sigmoid(O) for node_type, O in pre.items()}

    def _calculate_hidden_state(self, o_dict, c_dict):
        return {node_type: o_dict[node_type] * torch.tanh(C) for node_type, C in c_dict.items()}

    def forward(
        self,
        x_dict,
        edge_index_dict,
        edge_attr_dict,
        h_dict=None,
        c_dict=None,
    ):
        """
        Making a forward pass. If the hidden state and cell state
        matrix dicts are not present when the forward pass is called these are
        initialized with zeros (on the device and with the dtype of the node features).

        Arg types:
            * **x_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensors)* - Node features dicts. Can
                be obtained via PyG method :obj:`snapshot.x_dict` where snapshot is a single HeteroData object.
            * **edge_index_dict** *(Dictionary where keys=Tuples and values=PyTorch Long Tensors)* - Graph edge type
                and index dicts. Can be obtained via PyG method :obj:`snapshot.edge_index_dict`.
            * **edge_attr_dict** *(Dictionary where keys=Tuples and values=PyTorch Tensors)* - Edge type and
                edge attribute dicts, passed to every gate convolution as ``edge_attr_dict``.
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and
                cell state matrix dict for all nodes.

        Return types:
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                hidden state matrix dict for all nodes.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and
                cell state matrix dict for all nodes.
        """

        h_dict = self._set_hidden_state(x_dict, h_dict)
        c_dict = self._set_cell_state(x_dict, c_dict)
        i_dict = self._calculate_input_gate(x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict)
        f_dict = self._calculate_forget_gate(x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict)
        c_dict = self._calculate_cell_state(x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict, i_dict, f_dict)
        o_dict = self._calculate_output_gate(x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict)
        h_dict = self._calculate_hidden_state(o_dict, c_dict)
        return h_dict, c_dict

In [4]:
class HeteroGCLSTM_GAT(torch.nn.Module):
    r"""An implementation similar to the Integrated Graph Convolutional Long Short Term
        Memory Cell for heterogeneous Graphs.

        Args:
            in_channels_dict (dict of keys=str and values=int): Dimension of each node's input features.
            out_channels (int): Number of output features.
            metadata (tuple): Metadata on node types and edge types in the graphs. Can be generated via PyG method
                :obj:`snapshot.metadata()` where snapshot is a single HeteroData object.
            bias (bool, optional): If set to :obj:`False`, the layer will not learn
                an additive bias. (default: :obj:`True`)
    """

    def __init__(
            self,
            in_channels_dict: dict,
            out_channels: int,
            metadata: tuple,
            bias: bool = True
    ):
        super(HeteroGCLSTM_GAT, self).__init__()

        self.in_channels_dict = in_channels_dict
        self.out_channels = out_channels
        self.metadata = metadata
        self.bias = bias
        self._create_parameters_and_layers()
        self._set_parameters()

    def _create_input_gate_parameters_and_layers(self):
        """Create the input-gate convolution, weights and biases."""
        # One single-head GATConv per edge type, wrapped in a HeteroConv.
        convs_by_edge_type = {}
        for edge_type in self.metadata[1]:
            convs_by_edge_type[edge_type] = GATConv(in_channels=(-1, -1),
                                                    out_channels=self.out_channels,
                                                    heads=1,
                                                    bias=self.bias,
                                                    add_self_loops=False)
        self.conv_i = HeteroConv(convs_by_edge_type)

        # Per-node-type projection of the input features and additive bias.
        weights, biases = {}, {}
        for node_type, in_channels in self.in_channels_dict.items():
            weights[node_type] = Parameter(torch.Tensor(in_channels, self.out_channels))
        for node_type in self.in_channels_dict:
            biases[node_type] = Parameter(torch.Tensor(1, self.out_channels))
        self.W_i = nn.ParameterDict(weights)
        self.b_i = nn.ParameterDict(biases)

    def _create_forget_gate_parameters_and_layers(self):
        """Create the forget-gate convolution, weights and biases."""
        # One single-head GATConv per edge type, wrapped in a HeteroConv.
        convs_by_edge_type = {}
        for edge_type in self.metadata[1]:
            convs_by_edge_type[edge_type] = GATConv(in_channels=(-1, -1),
                                                    out_channels=self.out_channels,
                                                    heads=1,
                                                    bias=self.bias,
                                                    add_self_loops=False)
        self.conv_f = HeteroConv(convs_by_edge_type)

        # Per-node-type projection of the input features and additive bias.
        weights, biases = {}, {}
        for node_type, in_channels in self.in_channels_dict.items():
            weights[node_type] = Parameter(torch.Tensor(in_channels, self.out_channels))
        for node_type in self.in_channels_dict:
            biases[node_type] = Parameter(torch.Tensor(1, self.out_channels))
        self.W_f = nn.ParameterDict(weights)
        self.b_f = nn.ParameterDict(biases)

    def _create_cell_state_parameters_and_layers(self):
        """Create the cell-candidate convolution, weights and biases."""
        # One single-head GATConv per edge type, wrapped in a HeteroConv.
        convs_by_edge_type = {}
        for edge_type in self.metadata[1]:
            convs_by_edge_type[edge_type] = GATConv(in_channels=(-1, -1),
                                                    out_channels=self.out_channels,
                                                    heads=1,
                                                    bias=self.bias,
                                                    add_self_loops=False)
        self.conv_c = HeteroConv(convs_by_edge_type)

        # Per-node-type projection of the input features and additive bias.
        weights, biases = {}, {}
        for node_type, in_channels in self.in_channels_dict.items():
            weights[node_type] = Parameter(torch.Tensor(in_channels, self.out_channels))
        for node_type in self.in_channels_dict:
            biases[node_type] = Parameter(torch.Tensor(1, self.out_channels))
        self.W_c = nn.ParameterDict(weights)
        self.b_c = nn.ParameterDict(biases)

    def _create_output_gate_parameters_and_layers(self):
        """Create the output-gate convolution, weights and biases."""
        # One single-head GATConv per edge type, wrapped in a HeteroConv.
        convs_by_edge_type = {}
        for edge_type in self.metadata[1]:
            convs_by_edge_type[edge_type] = GATConv(in_channels=(-1, -1),
                                                    out_channels=self.out_channels,
                                                    heads=1,
                                                    bias=self.bias,
                                                    add_self_loops=False)
        self.conv_o = HeteroConv(convs_by_edge_type)

        # Per-node-type projection of the input features and additive bias.
        weights, biases = {}, {}
        for node_type, in_channels in self.in_channels_dict.items():
            weights[node_type] = Parameter(torch.Tensor(in_channels, self.out_channels))
        for node_type in self.in_channels_dict:
            biases[node_type] = Parameter(torch.Tensor(1, self.out_channels))
        self.W_o = nn.ParameterDict(weights)
        self.b_o = nn.ParameterDict(biases)

    def _create_parameters_and_layers(self):
        self._create_input_gate_parameters_and_layers()
        self._create_forget_gate_parameters_and_layers()
        self._create_cell_state_parameters_and_layers()
        self._create_output_gate_parameters_and_layers()

    def _set_parameters(self):
        """Apply ``glorot`` to every gate weight, then to every gate bias."""
        # Weights first, biases second, each in input/forget/cell/output order.
        ordered_param_dicts = (
            self.W_i, self.W_f, self.W_c, self.W_o,
            self.b_i, self.b_f, self.b_c, self.b_o,
        )
        for param_dict in ordered_param_dicts:
            for node_type in param_dict:
                glorot(param_dict[node_type])

    def _set_hidden_state(self, x_dict, h_dict):
        if h_dict is None:
            h_dict = {node_type: torch.zeros(X.shape[0], self.out_channels) for node_type, X in x_dict.items()}
        return h_dict

    def _set_cell_state(self, x_dict, c_dict):
        if c_dict is None:
            c_dict = {node_type: torch.zeros(X.shape[0], self.out_channels) for node_type, X in x_dict.items()}
        return c_dict

    def _calculate_input_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        i_dict = {node_type: torch.matmul(X, self.W_i[node_type]) for node_type, X in x_dict.items()}
        conv_i = self.conv_i(h_dict, edge_index_dict)
        i_dict = {node_type: I + conv_i[node_type] for node_type, I in i_dict.items()}
        i_dict = {node_type: I + self.b_i[node_type] for node_type, I in i_dict.items()}
        i_dict = {node_type: torch.sigmoid(I) for node_type, I in i_dict.items()}
        return i_dict

    def _calculate_forget_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        f_dict = {node_type: torch.matmul(X, self.W_f[node_type]) for node_type, X in x_dict.items()}
        conv_f = self.conv_f(h_dict, edge_index_dict)
        f_dict = {node_type: F + conv_f[node_type] for node_type, F in f_dict.items()}
        f_dict = {node_type: F + self.b_f[node_type] for node_type, F in f_dict.items()}
        f_dict = {node_type: torch.sigmoid(F) for node_type, F in f_dict.items()}
        return f_dict

    def _calculate_cell_state(self, x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict):
        t_dict = {node_type: torch.matmul(X, self.W_c[node_type]) for node_type, X in x_dict.items()}
        conv_c = self.conv_c(h_dict, edge_index_dict)
        t_dict = {node_type: T + conv_c[node_type] for node_type, T in t_dict.items()}
        t_dict = {node_type: T + self.b_c[node_type] for node_type, T in t_dict.items()}
        t_dict = {node_type: torch.tanh(T) for node_type, T in t_dict.items()}
        c_dict = {node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type] for node_type, C in c_dict.items()}
        return c_dict

    def _calculate_output_gate(self, x_dict, edge_index_dict, h_dict, c_dict):
        o_dict = {node_type: torch.matmul(X, self.W_o[node_type]) for node_type, X in x_dict.items()}
        conv_o = self.conv_o(h_dict, edge_index_dict)
        o_dict = {node_type: O + conv_o[node_type] for node_type, O in o_dict.items()}
        o_dict = {node_type: O + self.b_o[node_type] for node_type, O in o_dict.items()}
        o_dict = {node_type: torch.sigmoid(O) for node_type, O in o_dict.items()}
        return o_dict

    def _calculate_hidden_state(self, o_dict, c_dict):
        h_dict = {node_type: o_dict[node_type] * torch.tanh(C) for node_type, C in c_dict.items()}
        return h_dict

    def forward(
        self,
        x_dict,
        edge_index_dict,
        h_dict=None,
        c_dict=None,
    ):
        """
        Run one recurrent step over all node types. When the hidden state or the
        cell state dict is not supplied, it is initialized with zeros.

        Arg types:
            * **x_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensors)* - Node features per
                node type, e.g. :obj:`snapshot.x_dict` of a single HeteroData snapshot.
            * **edge_index_dict** *(Dictionary where keys=Tuples and values=PyTorch Long Tensors)* - Edge indices
                per edge type, e.g. :obj:`snapshot.edge_index_dict`.
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Previous
                hidden state matrix per node type.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Previous
                cell state matrix per node type.

        Return types:
            * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - New hidden state
                matrix per node type.
            * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - New cell state
                matrix per node type.
        """

        prev_hidden = self._set_hidden_state(x_dict, h_dict)
        prev_cell = self._set_cell_state(x_dict, c_dict)

        input_gate = self._calculate_input_gate(x_dict, edge_index_dict, prev_hidden, prev_cell)
        forget_gate = self._calculate_forget_gate(x_dict, edge_index_dict, prev_hidden, prev_cell)
        next_cell = self._calculate_cell_state(
            x_dict, edge_index_dict, prev_hidden, prev_cell, input_gate, forget_gate
        )
        # The output gate is called with the freshly updated cell state.
        output_gate = self._calculate_output_gate(x_dict, edge_index_dict, prev_hidden, next_cell)
        next_hidden = self._calculate_hidden_state(output_gate, next_cell)
        return next_hidden, next_cell

# Models

In [64]:
class MyModel_basic_edge(nn.Module):
    """Unrolls a ``HeteroGCLSTM_GAT_edge`` cell over one graph and maps the last state to predictions.

    Every step draws fresh random edge attributes, passes them together with the
    current edge indices through ``change_edge_number`` and feeds the result to the
    recurrent cell. The final hidden state of ``node_type_to_pred`` is projected by
    a linear head with ``out_steps`` outputs.

    Args:
        input_size: Forwarded to ``HeteroGCLSTM_GAT_edge`` as its first argument.
        hidden_size: Width of the hidden and cell states.
        out_steps: Number of outputs of the linear head.
        node_type_to_pred: Node type whose hidden state is collected and projected.
        metadata: Forwarded to ``HeteroGCLSTM_GAT_edge``.
        bias: Forwarded to ``HeteroGCLSTM_GAT_edge``.
        n_layers: Stored on the instance only; a single recurrent cell is built.
        verbose: When ``True`` (default) the debug traces are printed.
    """

    def __init__(
            self,
            input_size: dict,
            hidden_size: int,
            out_steps: int,
            node_type_to_pred: str,
            metadata: tuple,
            bias: bool=True,
            n_layers: int = 1,
            verbose: bool = True,
        ):
        super().__init__()
        self.n_layers = n_layers
        self.node_type_to_pred = node_type_to_pred
        self.hidden_size = hidden_size
        self.verbose = verbose
        self.heterogcnlstm_layer = HeteroGCLSTM_GAT_edge(input_size, hidden_size, metadata, bias)
        self.linear = nn.Linear(in_features=hidden_size, out_features=out_steps)

    def forward(self, graph_seq, n_steps: int = 3):
        """Run ``n_steps`` recurrent steps on ``graph_seq``.

        Args:
            graph_seq: Graph object exposing ``x_dict`` and ``edge_index_dict`` and
                supporting item access by node type (used for the debug trace).
                It is not modified.
            n_steps: Number of recurrent steps (default 3).

        Returns:
            Tuple ``(prediction, outputs)``: the linear projection of the final
            hidden state of ``node_type_to_pred`` and the list of that node type's
            hidden state after every step.
        """
        target = self.node_type_to_pred
        x_dict = graph_seq.x_dict
        # Create the states on the features' device so the model also runs off-CPU.
        h_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }
        c_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }
        if self.verbose:
            print("initial hidden states")
            print(h_dict[target])
            print(c_dict[target])
        outputs = []

        # Private copy of the mapping: edge changes accumulate across steps as
        # before, but the caller's graph object is left untouched.
        edge_index_dict = dict(graph_seq.edge_index_dict)

        for t in range(n_steps):
            if self.verbose:
                print(f"STEP NUMBER {t}")
            edge_attr_dict = {}
            for edge_type, edge_index in list(edge_index_dict.items()):
                # One random attribute per edge, sized from the current edge index
                # rather than a fixed 80000 so any edge count is handled.
                edge_attr = torch.rand(edge_index.size(1), device=edge_index.device)
                edge_index_dict[edge_type], edge_attr_dict[edge_type] = change_edge_number(
                    edge_index, edge_attr
                )
            if self.verbose:
                print("input data")
                print(graph_seq[target])
            # Call the module (not ``.forward``) so registered hooks are honoured.
            h_dict, c_dict = self.heterogcnlstm_layer(
                x_dict, edge_index_dict, edge_attr_dict, h_dict, c_dict
            )
            if self.verbose:
                print(f"hidden state for '{target}'")
                print(h_dict[target])
            outputs.append(h_dict[target])

        return self.linear(h_dict[target]), outputs
        # out = torch.reshape(intermediate_out, (NUM_NODES, OUT_STEPS))

In [41]:
class MyModel_basic(nn.Module):
    """Unrolls a ``HeteroGCLSTM_GAT`` cell over one graph and maps the last state to predictions.

    The same node features and edge indices are fed to the recurrent cell at every
    step. The final hidden state of ``node_type_to_pred`` is projected by a linear
    head with ``out_steps`` outputs.

    Args:
        input_size: Forwarded to ``HeteroGCLSTM_GAT`` as its first argument.
        hidden_size: Width of the hidden and cell states.
        out_steps: Number of outputs of the linear head.
        node_type_to_pred: Node type whose hidden state is collected and projected.
        metadata: Forwarded to ``HeteroGCLSTM_GAT``.
        bias: Forwarded to ``HeteroGCLSTM_GAT``.
        n_layers: Stored on the instance only; a single recurrent cell is built.
        verbose: When ``True`` (default) the debug traces are printed.
    """

    def __init__(
            self,
            input_size: dict,
            hidden_size: int,
            out_steps: int,
            node_type_to_pred: str,
            metadata: tuple,
            bias: bool=True,
            n_layers: int = 1,
            verbose: bool = True,
        ):
        super().__init__()
        self.n_layers = n_layers
        self.node_type_to_pred = node_type_to_pred
        self.hidden_size = hidden_size
        self.verbose = verbose
        self.heterogcnlstm_layer = HeteroGCLSTM_GAT(input_size, hidden_size, metadata, bias)
        self.linear = nn.Linear(in_features=hidden_size, out_features=out_steps)

    def forward(self, graph_seq, n_steps: int = 3):
        """Run ``n_steps`` recurrent steps on ``graph_seq``.

        Args:
            graph_seq: Graph object exposing ``x_dict`` and ``edge_index_dict``; with
                ``verbose`` enabled it must also support item access by node type.
                It is not modified.
            n_steps: Number of recurrent steps (default 3).

        Returns:
            Tuple ``(prediction, outputs)``: the linear projection of the final
            hidden state of ``node_type_to_pred`` and the list of that node type's
            hidden state after every step.
        """
        target = self.node_type_to_pred
        x_dict = graph_seq.x_dict
        edge_index_dict = graph_seq.edge_index_dict
        # Create the states on the features' device so the model also runs off-CPU.
        h_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }
        c_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }
        if self.verbose:
            print("initial hidden states")
            print(h_dict[target])
            print(c_dict[target])
        outputs = []

        # The cell takes no edge attributes, so none are generated here; the former
        # fixed-size ``edge_attr`` written into the caller's graph was never read.
        for t in range(n_steps):
            if self.verbose:
                print(f"STEP NUMBER {t}")
                print("input data")
                print(graph_seq[target])
            # Call the module (not ``.forward``) so registered hooks are honoured.
            h_dict, c_dict = self.heterogcnlstm_layer(
                x_dict, edge_index_dict, h_dict, c_dict
            )
            if self.verbose:
                print(f"hidden state for '{target}'")
                print(h_dict[target])
            outputs.append(h_dict[target])

        return self.linear(h_dict[target]), outputs
        # out = torch.reshape(intermediate_out, (NUM_NODES, OUT_STEPS))

In [7]:
class MyModel_autoregressive(nn.Module):
    """Autoregressive multi-step predictor built on a ``HeteroGCLSTM_GAT`` cell.

    After a warm-up on the input graph, the hidden state of every node type is
    mapped back to that type's input width (``linear_1``) and fed to the cell as
    the next input. At each step the hidden state of ``node_type_to_pred`` is
    projected to ``out_size`` values (``linear_2``).

    Args:
        input_size: Mapping of node type to input feature width; also forwarded to
            ``HeteroGCLSTM_GAT``. Keys must be strings because they index an
            ``nn.ModuleDict``.
        hidden_size: Width of the hidden and cell states.
        out_size: Number of values predicted per node and step.
        out_steps: Number of predictions stacked by ``forward``.
        node_type_to_pred: Node type whose predictions are returned.
        metadata: Forwarded to ``HeteroGCLSTM_GAT``.
        bias: Forwarded to ``HeteroGCLSTM_GAT``.
        n_layers: Stored on the instance only; a single recurrent cell is built.
    """

    def __init__(
            self,
            input_size: dict,
            hidden_size: int,
            out_size: int,
            out_steps: int,
            node_type_to_pred: str,
            metadata: tuple,
            bias: bool=True,
            n_layers: int = 1
        ):
        super().__init__()
        self.n_layers = n_layers
        self.out_steps = out_steps
        self.node_type_to_pred = node_type_to_pred
        self.hidden_size = hidden_size
        self.heterogcnlstm_layer = HeteroGCLSTM_GAT(input_size, hidden_size, metadata, bias)
        # nn.ModuleDict instead of a plain dict: only registered sub-modules appear
        # in parameters()/state_dict() and follow .to(device), so a plain dict left
        # these heads untrained and stranded on the CPU.
        self.linear_1 = nn.ModuleDict({
            node_type: nn.Linear(in_features=hidden_size, out_features=dim)
            for node_type, dim in input_size.items()
        })
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_size)

    def warmup(self, graph_seq, n_steps: int = 3):
        """Run the cell ``n_steps`` times on the unchanged input graph.

        Args:
            graph_seq: Graph object exposing ``x_dict`` and ``edge_index_dict``.
            n_steps: Number of warm-up steps (default 3).

        Returns:
            Tuple ``(pred, edge_index_dict, h_dict, c_dict, pred_to_out)`` where
            ``pred`` maps every node type to its ``linear_1`` projection of the
            hidden state and ``pred_to_out`` is the ``linear_2`` projection for
            ``node_type_to_pred``.
        """
        x_dict = graph_seq.x_dict
        edge_index_dict = graph_seq.edge_index_dict
        # Create the states on the features' device so the model also runs off-CPU.
        h_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }
        c_dict = {
            node_type: torch.zeros(X.shape[0], self.hidden_size, device=X.device)
            for node_type, X in x_dict.items()
        }

        for _ in range(n_steps):
            # Call the module (not ``.forward``) so registered hooks are honoured.
            h_dict, c_dict = self.heterogcnlstm_layer(
                x_dict, edge_index_dict, h_dict, c_dict
            )

        pred = {node_type: self.linear_1[node_type](H) for node_type, H in h_dict.items()}
        pred_to_out = self.linear_2(h_dict[self.node_type_to_pred])
        return pred, edge_index_dict, h_dict, c_dict, pred_to_out

    def forward(self, graph_seq):
        """Produce the stacked multi-step prediction for ``node_type_to_pred``.

        The first prediction comes from ``warmup``; each of the remaining
        ``out_steps - 1`` predictions feeds the previous per-node-type projection
        back into the cell. The warm-up prediction is always included, even when
        ``out_steps`` is smaller than 1.

        Returns:
            The per-step ``linear_2`` outputs stacked along dimension 1.
        """
        predictions = []
        prediction, edges, h, c, prediction_to_out = self.warmup(graph_seq)
        predictions.append(prediction_to_out)
        print("prediction after warmap")
        print(prediction_to_out.shape)

        for _ in range(1, self.out_steps):
            # The projected hidden states act as the next step's node features.
            h, c = self.heterogcnlstm_layer(prediction, edges, h, c)
            prediction = {node_type: self.linear_1[node_type](H) for node_type, H in h.items()}
            prediction_to_out = self.linear_2(h[self.node_type_to_pred])
            predictions.append(prediction_to_out)

        return torch.stack(predictions, dim=1)

In [1]:
# Alternative dataset, kept for reference:
# from torch_geometric.datasets import OGB_MAG
from torch_geometric.datasets import MovieLens100K

# dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
# Build the MovieLens100K dataset rooted at ./data and use its first element as
# the working graph for the rest of the notebook.
dataset = MovieLens100K(root='./data')
data = dataset[0]

In [43]:
len(dataset)

1

In [57]:
# for edge_type in data.edge_types:
#     print(data[edge_type].edge_index)

In [2]:
data

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
    edge_label_index=[2, 20000],
    edge_label=[20000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [3]:
import torch_geometric.transforms as T

# Apply ToUndirected and rebind `data` to the result; the summary displayed below
# lists additional `rev_*` edge types after the transform.
# NOTE(review): `data` is overwritten, so re-running this cell transforms the
# already transformed graph again — confirm that is acceptable.
data = T.ToUndirected()(data)
data

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
    edge_label_index=[2, 20000],
    edge_label=[20000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  },
  (movie, rev_rates, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  },
  (user, rev_rated_by, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [45]:
data["user", "rates", "movie"]["rating"][:5]

tensor([5, 3, 4, 3, 3])

In [12]:
# data["user", "rates", "movie"].edge_attr = data["user", "rates", "movie"]["rating"].float()
# data["movie", "rated_by", "user"].edge_attr = data["movie", "rated_by", "user"]["rating"].float()

In [46]:
data.edge_index_dict

{('user',
  'rates',
  'movie'): tensor([[   0,    0,    0,  ...,  942,  942,  942],
         [   0,    1,    2,  ..., 1187, 1227, 1329]]),
 ('movie',
  'rated_by',
  'user'): tensor([[   0,    1,    2,  ..., 1187, 1227, 1329],
         [   0,    0,    0,  ...,  942,  942,  942]])}

**Случайный атрибут ребер**

In [49]:
# Draw 80000 uniform random values in [0, 1); 80000 equals the number of edges per
# edge type in the graph summary shown earlier.
tmp = torch.rand(80000)
tmp

tensor([0.0130, 0.0379, 0.3303,  ..., 0.0781, 0.9283, 0.6880])

In [34]:
# data.edge_attr_dict

In [111]:
# data.edge_attr_dict = {
#     ('user', 'rates', 'movie'): torch.reshape(data["user", "rates", "movie"]["rating"].float(), (-1, 1)),
#     ('movie', 'rated_by', 'user'): torch.reshape(data["movie", "rated_by", "user"]["rating"].float(), (-1, 1))
# }

In [50]:
# data.edge_attr_dict['user',
#   'rates',
#   'movie'].dim()

In [137]:
# data["paper"]["train_mask"].unsqueeze(-1)[1050:1060, :]

In [138]:
# data["paper"]["x"][1050:1060, :5]

In [139]:
# tmp = (data["paper"]["x"] * data["paper"]["train_mask"].unsqueeze(-1))[1050:1060, :5]

# mask = tmp != 0
# mask

In [6]:
# data["author", "affiliated_with", "institution"]["edge_index"][:, :10]

In [7]:
# data["paper", "cites", "paper"]["edge_index"][:, :10]

In [51]:
data.metadata()

(['movie', 'user'],
 [('user', 'rates', 'movie'), ('movie', 'rated_by', 'user')])

In [10]:
# import torch_geometric.transforms as T

# data_undirected = T.ToUndirected()(data)
# data_undirected

In [66]:
# Earlier configuration for a different graph, kept for reference:
# input_dims = {
#     "paper": 128,
#     "author": 128,
#     "institution": 128,
#     "field_of_study": 128
# }

# Feature width per node type; these match the `x` shapes in the graph summary
# shown earlier (movie: [1682, 18], user: [943, 24]).
input_dims = {
    "movie": 18,
    "user": 24,
}

# meta = data_undirected.metadata()
# (node types, edge types) tuple of the working graph.
meta = data.metadata()

# Width of the recurrent hidden/cell state, shared by the models below.
hidden_size = 32

# layer = HeteroGCLSTM_SAGE(in_channels_dict=input_dims, out_channels=32, metadata=meta)
# Stand-alone recurrent cell, used further down to inspect its raw outputs.
layer = HeteroGCLSTM_GAT(in_channels_dict=input_dims, out_channels=hidden_size, metadata=meta)

In [67]:
# Keyword arguments make the role of each value explicit (3 is `out_steps`).
model_basic = MyModel_basic_edge(
    input_size=input_dims,
    hidden_size=hidden_size,
    out_steps=3,
    node_type_to_pred="user",
    metadata=meta,
)

In [68]:
final_out, outputs = model_basic(data)

initial hidden states
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
STEP NUMBER 0
input data
{'x': tensor([[0.3288, 0.0000, 1.0000,  ..., 0.0000, 1.0000, 0.0000],
        [0.7260, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3151, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 1.0000],
        ...,
        [0.2740, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000],
        [0.6575, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3014, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000]])}
hidden state for 'user'
tensor([[ 0.0375

In [37]:
print(final_out.shape)
print(final_out)

torch.Size([943, 3])
tensor([[ 0.2007, -0.0007, -0.0098],
        [ 0.2084,  0.0436,  0.1494],
        [ 0.1952, -0.0385,  0.0323],
        ...,
        [ 0.2839, -0.0335,  0.1205],
        [ 0.2694,  0.0938,  0.1230],
        [ 0.2828, -0.0307,  0.1216]], grad_fn=<AddmmBackward0>)


In [38]:
# Keyword arguments make the role of each value explicit (1 value per step, 3 steps).
model_autoregressive = MyModel_autoregressive(
    input_size=input_dims,
    hidden_size=hidden_size,
    out_size=1,
    out_steps=3,
    node_type_to_pred="user",
    metadata=meta,
)

In [39]:
preds = model_autoregressive(data)

prediction after warmap
torch.Size([943, 1])


In [40]:
preds.squeeze().shape

torch.Size([943, 3])

In [12]:
# data_undirected.x_dict

In [13]:
# data_undirected.edge_index_dict

In [15]:
# h_out, c_out = layer(data_undirected.x_dict, data_undirected.edge_index_dict)
# Call the module itself rather than `.forward` so that any registered hooks run;
# no initial states are passed, so the cell starts from zeros.
h_out, c_out = layer(data.x_dict, data.edge_index_dict)

In [16]:
h_out

{'movie': tensor([[-0.1559, -0.0837, -0.0984,  ..., -0.1043, -0.0718,  0.0447],
         [-0.0432,  0.0050, -0.1661,  ..., -0.1044, -0.1881, -0.0776],
         [-0.0599, -0.0981, -0.0705,  ..., -0.0822, -0.1083,  0.0332],
         ...,
         [-0.0996, -0.0671, -0.0393,  ...,  0.0235, -0.1518, -0.0476],
         [-0.1389, -0.0848, -0.0360,  ..., -0.1173, -0.0162,  0.0766],
         [-0.1194, -0.0333, -0.0662,  ..., -0.0258, -0.0711,  0.0015]],
        grad_fn=<MulBackward0>),
 'user': tensor([[-0.0247,  0.0794, -0.1272,  ...,  0.0662,  0.1731, -0.0283],
         [ 0.1836,  0.0941, -0.0265,  ..., -0.0192,  0.1276, -0.0270],
         [ 0.0782, -0.0103, -0.0802,  ...,  0.1032,  0.1847,  0.0144],
         ...,
         [ 0.0989,  0.0677, -0.0168,  ...,  0.1402,  0.0903, -0.0371],
         [ 0.0691,  0.1197, -0.0362,  ...,  0.0220,  0.1665,  0.0121],
         [ 0.0969,  0.0681, -0.0183,  ...,  0.1403,  0.0911, -0.0366]],
        grad_fn=<MulBackward0>)}

In [17]:
h_out["user"]

tensor([[-0.0247,  0.0794, -0.1272,  ...,  0.0662,  0.1731, -0.0283],
        [ 0.1836,  0.0941, -0.0265,  ..., -0.0192,  0.1276, -0.0270],
        [ 0.0782, -0.0103, -0.0802,  ...,  0.1032,  0.1847,  0.0144],
        ...,
        [ 0.0989,  0.0677, -0.0168,  ...,  0.1402,  0.0903, -0.0371],
        [ 0.0691,  0.1197, -0.0362,  ...,  0.0220,  0.1665,  0.0121],
        [ 0.0969,  0.0681, -0.0183,  ...,  0.1403,  0.0911, -0.0366]],
       grad_fn=<MulBackward0>)

In [18]:
h_out["user"].shape

torch.Size([943, 32])

In [19]:
time_steps = 10  # NOTE(review): not used within this cell
num_nodes = h_out["user"].shape[0]  # NOTE(review): not used within this cell

# Project every user's hidden vector (width `hidden_size`) to a single value with a
# freshly initialised linear layer and show the resulting shape.
linear_layer = nn.Linear(in_features=hidden_size, out_features=1)
out = linear_layer(h_out["user"])
out.shape

torch.Size([943, 1])