In [1]:
import torch
import torch.nn.functional as F

class NeighborAggregator(torch.nn.Module):
    """Attention-weighted neighbor aggregation followed by a linear projection.

    Neighbor embeddings are combined with weights given by a softmax over
    user-relation scores (the dot product of the user embedding with each
    neighbor relation embedding). The aggregated neighbor vector is then
    combined with the entity's own embedding according to
    ``neighbor_aggregator``:

    * ``'sum'``      -- ``self + neighbors``           -> Linear(dim, dim)
    * ``'concat'``   -- ``cat([self, neighbors], -1)`` -> Linear(2 * dim, dim)
    * ``'neighbor'`` -- ``neighbors`` only             -> Linear(dim, dim)

    Args:
        batch_size: Initial batch size. Kept as an attribute for compatibility;
            ``forward`` overwrites it with the batch size of each call.
        dim: Embedding dimension.
        neighbor_aggregator: One of ``'sum'``, ``'concat'``, ``'neighbor'``.

    Raises:
        ValueError: If ``neighbor_aggregator`` is not a supported mode.
    """

    _AGGREGATORS = ('sum', 'concat', 'neighbor')

    def __init__(self, batch_size, dim, neighbor_aggregator):
        super().__init__()
        # Fail fast on a bad mode instead of building a Linear layer and only
        # erroring on the first forward() call.
        if neighbor_aggregator not in self._AGGREGATORS:
            raise ValueError(f"Unknown neighbor aggregator: {neighbor_aggregator}")
        self.batch_size = batch_size
        self.dim = dim
        self.neighbor_aggregator = neighbor_aggregator
        # 'concat' doubles the feature width fed into the projection.
        in_features = 2 * dim if neighbor_aggregator == 'concat' else dim
        self.weights = torch.nn.Linear(in_features, dim, bias=True)

    def forward(self, self_embeddings, neighbor_embeddings, neighbor_relations, user_embeddings, activation=None):
        """Aggregate neighbors, combine with self embeddings, and project.

        Args:
            self_embeddings: Entity embeddings; presumably
                (batch, n_entities, dim) -- TODO confirm against caller.
            neighbor_embeddings: Neighbor embeddings; presumably
                (batch, n_entities, n_neighbors, dim) -- TODO confirm.
            neighbor_relations: Relation embeddings with the same layout as
                ``neighbor_embeddings`` (assumed; must broadcast with
                ``user_embeddings`` reshaped to (batch, 1, 1, dim)).
            user_embeddings: User embeddings whose first axis is the batch;
                must hold ``batch * dim`` elements.
            activation: Optional callable applied after the linear layer.

        Returns:
            Tensor of shape (batch, -1, dim).

        Raises:
            ValueError: If ``self.neighbor_aggregator`` is unsupported (e.g.
                it was reassigned after construction).
        """
        batch_size = user_embeddings.size(0)
        # Kept so that external code reading ``self.batch_size`` still sees the
        # latest batch size; the computation itself uses the local value.
        self.batch_size = batch_size

        neighbors_aggregated = self._aggregate_neighbor_vectors(
            neighbor_embeddings, neighbor_relations, user_embeddings)

        if self.neighbor_aggregator == 'sum':
            output = (self_embeddings + neighbors_aggregated).reshape(-1, self.dim)
        elif self.neighbor_aggregator == 'concat':
            output = torch.cat((self_embeddings, neighbors_aggregated), dim=-1)
            output = output.reshape(-1, 2 * self.dim)
        elif self.neighbor_aggregator == 'neighbor':
            output = neighbors_aggregated.reshape(-1, self.dim)
        else:
            raise ValueError(f"Unknown neighbor aggregator: {self.neighbor_aggregator}")

        output = self.weights(output)
        if activation is not None:
            output = activation(output)
        return output.reshape(batch_size, -1, self.dim)

    def _aggregate_neighbor_vectors(self, neighbor_embeddings, neighbor_relations, user_embeddings):
        """Softmax-weighted sum of neighbor embeddings over axis 2.

        Scores are the dot product (over the last axis) of each user embedding
        with the neighbor relation embeddings. They are normalized with a
        softmax over the last score axis, then used to weight
        ``neighbor_embeddings`` before summing over dim=2.
        """
        # Derive the batch from the input rather than ``self.batch_size`` so this
        # helper does not depend on state left behind by a previous forward().
        user_embeddings = user_embeddings.reshape(user_embeddings.size(0), 1, 1, self.dim)
        user_relation_scores = (user_embeddings * neighbor_relations).sum(dim=-1)
        weights = F.softmax(user_relation_scores, dim=-1).unsqueeze(dim=-1)
        return (weights * neighbor_embeddings).sum(dim=2)
