**ATTENTION-BASED TEMPORAL DEPENDENCY LEARNING**

In [3]:
import torch
import torch.nn as nn
import numpy as np

We are trying to learn the temporal dependencies among purchased products. As we know from everyday life, some elements (products) appear in our basket frequently and regularly, while others appear irregularly and only occasionally. This makes learning temporal dependencies really hard! <br>
Previous models, such as RNNs, fail to model this kind of data well because they do not properly account for these irregular temporal dependencies. To avoid losing this information, we construct a temporal dependency learning component based on (masked) self-attention.

**Input** to this component will be the sequences we constructed in the previous step (weighted GCN), whose outputs are the sequences of embeddings $\mathbb{C}_i = \{C_{i,1},\dots,C_{i,|\mathcal{V}_i|}\}$, where $C_{i,j}=\{c_{i,j}^1,\dots,c_{i,j}^T\}$ are the representations of $v_{i,j}$ over time.

The **output** of this component will be $\mathbb{Z}_i = \{z_{i,1},...,z_{i,|\mathcal{V}_i|}\}$, where $z_{i,j} \in \mathbb{R}^{F''} $ are the representations of $v_{i,j} \in \mathcal{V}_i$.

$C_{i,j} \in \mathbb{R}^{T \times F'}
\xrightarrow[\text{}]{\text{temporal dependency}}
Z_{i,j} \in \mathbb{R}^{T \times F''}$

$c_{i,j}^t \in \mathbb{R}^{F'}$ <br>
$z_{i,j} \in \mathbb{R}^{F''}$

$Z_{i,j} = \mathrm{softmax}\left( \frac{(C_{i,j}W_q) \cdot (C_{i,j}W_k)^T}{\sqrt{F''}} + M_i \right) \cdot (C_{i,j}W_v)$, where $W_q \in \mathbb{R}^{F' \times F''}$, $W_k \in \mathbb{R}^{F' \times F''}$, $W_v \in \mathbb{R}^{F' \times F''}$ are trainable parameters, $Z_{i,j} \in \mathbb{R}^{T \times F''}$ is the stacked representation of $v_{i,j}$'s sequence, and $M_i \in \mathbb{R}^{T \times T}$ is a mask matrix. The mask prevents future information leakage and guarantees that the state at each timestamp $t$ is only affected by itself and its previous states (timestamps $t' \le t$):
$$
  M_i^{t,t'}=\begin{cases}
    0, & \text{if } t' \le t,\\
    -\infty, & \text{otherwise}.
  \end{cases}
$$

We get the final representation with the equation $z_{i,j} = \left( (Z_{i,j} \cdot w_{agg})^T \cdot Z_{i,j}\right)^T$, where $w_{agg} \in \mathbb{R}^{F''}$ is a trainable parameter that adaptively learns the importance of the different timestamps. <br>
We finally obtain $z_{i,j} \in \mathbb{R}^{F''}$ as the compact representation of element $v_{i,j}$, which now takes all of its temporal dependencies into account.

In [None]:
class masked_self_attention(nn.Module):
    """Multi-head causal (masked) self-attention over the time dimension.

    Each node's sequence of ``T_max`` feature vectors is projected to queries,
    keys and values. Every timestamp attends only to itself and to earlier
    timestamps, so no information leaks from the future.
    """

    def __init__(self, input_dim, output_dim, n_heads=4, attention_aggregate="concat"):
        """
        :param input_dim: size of the last dimension of the input tensor (F')
        :param output_dim: requested output size. With "concat" each head gets
            ``output_dim // n_heads`` dimensions, so the actual output size is
            ``n_heads * (output_dim // n_heads)``. With "mean" each head has
            ``output_dim`` dimensions and the heads are averaged.
        :param n_heads: number of attention heads
        :param attention_aggregate: how to combine the heads, "concat" or "mean"
        :raises ValueError: if ``attention_aggregate`` is neither "concat" nor "mean"
        """
        super(masked_self_attention, self).__init__()
        # aggregate multi-heads by concatenation or mean
        self.attention_aggregate = attention_aggregate

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_heads = n_heads

        if attention_aggregate == "concat":
            # heads share output_dim between them
            self.per_head_dim = self.dq = self.dk = self.dv = output_dim // n_heads
        elif attention_aggregate == "mean":
            # every head produces a full output_dim vector; heads are averaged
            self.per_head_dim = self.dq = self.dk = self.dv = output_dim
        else:
            raise ValueError(f"wrong value for aggregate {attention_aggregate}")

        # projections W_q, W_k, W_v from the text, all heads stacked in one Linear
        self.Wq = nn.Linear(input_dim, n_heads * self.dq, bias=False)
        self.Wk = nn.Linear(input_dim, n_heads * self.dk, bias=False)
        self.Wv = nn.Linear(input_dim, n_heads * self.dv, bias=False)

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: tensor, shape (nodes_num, T_max, input_dim)
        Returns:
            output: tensor with the same dtype as the projections, shape
                (nodes_num, T_max, n_heads * per_head_dim) for "concat" or
                (nodes_num, T_max, output_dim) for "mean"
        """
        nodes_num, seq_length = input_tensor.shape[0], input_tensor.shape[1]

        # Q, V: (nodes_num, n_heads, T_max, dim_per_head)
        # K:    (nodes_num, n_heads, dim_per_head, T_max), pre-transposed for matmul
        Q = self.Wq(input_tensor).reshape(nodes_num, seq_length, self.n_heads, self.dq).transpose(1, 2)
        K = self.Wk(input_tensor).reshape(nodes_num, seq_length, self.n_heads, self.dk).permute(0, 2, 3, 1)
        V = self.Wv(input_tensor).reshape(nodes_num, seq_length, self.n_heads, self.dv).transpose(1, 2)

        # scaled attention scores, (nodes_num, n_heads, T_max, T_max)
        attention_score = Q.matmul(K) / np.sqrt(self.per_head_dim)

        # Causal mask: True strictly above the diagonal, i.e. key timestamp t' > query t.
        # It is built directly on the input's device and applied with masked_fill.
        # This keeps the scores' dtype. Adding a float32 mask would upcast
        # half/bfloat16 scores, and the matmul with V below would then fail on a
        # dtype mismatch. The mask broadcasts over (nodes_num, n_heads).
        future_mask = torch.ones(seq_length, seq_length, dtype=torch.bool,
                                 device=input_tensor.device).triu(diagonal=1)
        attention_score = attention_score.masked_fill(future_mask, float("-inf"))
        # The diagonal is never masked, so every row has a finite entry and softmax is well-defined.
        attention_score = torch.softmax(attention_score, dim=-1)

        # (nodes_num, n_heads, T_max, dim_per_head)
        multi_head_result = attention_score.matmul(V)
        if self.attention_aggregate == "concat":
            # concat heads -> (nodes_num, T_max, n_heads * dim_per_head)
            output = multi_head_result.transpose(1, 2).reshape(nodes_num, seq_length,
                                                               self.n_heads * self.per_head_dim)
        elif self.attention_aggregate == "mean":
            # average heads -> (nodes_num, T_max, dim_per_head = output_dim)
            output = multi_head_result.transpose(1, 2).mean(dim=2)
        else:
            raise ValueError(f"wrong value for aggregate {self.attention_aggregate}")

        return output

In [4]:
class aggregate_nodes_temporal_feature(nn.Module):
    """Collapse each node's time sequence into a single vector.

    Implements ``z = ((Z @ w_agg)^T @ Z)^T``. Each timestamp receives the
    (unnormalized) scalar weight ``Z_t . w_agg``, and the node's features are
    summed with these weights over the user's valid timestamps only.
    """

    def __init__(self, item_embed_dim):
        """
        :param item_embed_dim: the dimension of input features
        """
        super(aggregate_nodes_temporal_feature, self).__init__()

        # w_agg: projects each timestamp's feature vector to one scalar weight
        self.Wq = nn.Linear(item_embed_dim, 1, bias=False)

    def forward(self, graph, lengths, nodes_output):
        """
        :param graph: batched graphs, with the total number of nodes is `node_num`,
                        including `batch_size` disconnected subgraphs; only
                        `graph.batch_num_nodes()` (per-subgraph node counts, in the
                        same order as the rows of `nodes_output`) is used
        :param lengths: tensor, (batch_size, ), valid sequence length of each user
        :param nodes_output: the output of self-attention model in time dimension, (n_1+n_2+..., T_max, F)
        :return: aggregated_features, (n_1+n_2+..., F); a (0, F) tensor for an empty batch
        """
        nums_nodes = graph.batch_num_nodes()
        offset = 0  # first row of the current user's nodes in nodes_output
        aggregated_features = []
        for num_nodes, length in zip(nums_nodes, lengths):
            # this user's nodes, truncated to the user's real length:
            # (user_nodes, user_length, item_embed_dim)
            output_node_features = nodes_output[offset:offset + num_nodes, :length, :]
            # per-timestamp weights: (user_nodes, user_length, 1) -> (user_nodes, 1, user_length)
            weights = self.Wq(output_node_features).transpose(1, 2)
            # weighted sum over time: (user_nodes, 1, item_embed_dim) -> (user_nodes, item_embed_dim)
            aggregated_features.append(weights.matmul(output_node_features).squeeze(dim=1))
            offset += num_nodes

        if not aggregated_features:
            # torch.cat([]) would raise; an empty batch aggregates to zero rows
            return nodes_output.new_zeros((0, nodes_output.shape[-1]))
        # (n_1+n_2+..., item_embed_dim)
        return torch.cat(aggregated_features, dim=0)