In [1]:
import argparse
import random
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import lgamma
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_scatter import scatter_mean
import ogb
from ogb.graphproppred import PygGraphPropPredDataset

parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings for graph Transformers')
# The notebook was run on cuda:3; keep that as the default but make it tunable.
parser.add_argument('--device', type=int, default=3,
                    help='CUDA device index (ignored when CUDA is unavailable)')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed for torch, numpy and Python random')
# Parse an empty argv: inside Jupyter, sys.argv holds kernel arguments argparse cannot parse.
args = parser.parse_args([])
# Replace the integer index with a concrete torch.device; fall back to CPU without CUDA.
args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
print("device:", args.device)


def set_seed(seed):
    """Seed every RNG this notebook uses: torch (CPU and all CUDA devices), numpy and ``random``.

    :param seed: integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed exactly once, through set_seed, so Python's `random` is covered too
# (the previous ad-hoc seeding above this function skipped it).
seed = args.seed
set_seed(seed)

device: cuda:3


In [None]:
class MultiHeadedAttention_RPR(nn.Module):
    """Multi-head attention with relative position representations (RPR).

    For query position ``i`` and key position ``j`` the relative distance
    ``j - i`` is clipped to ``[-max_relative_position, max_relative_position]``
    and looked up in two learned tables: one added to the keys when computing
    attention logits (``embed_K``), one added to the values when mixing them
    (``embed_V``). Inputs are batch-first: queries ``(N, L, E)``, keys and
    values ``(N, S, E)``.
    """

    def __init__(self, d_model, h, max_relative_position, dropout=.0):
        """
        :param d_model: model (embedding) dimension; must be divisible by ``h``
        :param h: number of attention heads
        :param max_relative_position: clipping distance for relative positions
        :param dropout: dropout probability applied to the attention weights
        """
        super(MultiHeadedAttention_RPR, self).__init__()
        assert d_model % h == 0
        # assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Q, K, V input projections followed by the output projection,
        # each independently initialised.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(p=dropout)

        self.max_relative_position = max_relative_position
        # One embedding row per clipped distance in [-max, max].
        self.vocab_size = max_relative_position * 2 + 1
        self.embed_K = nn.Embedding(self.vocab_size, self.d_k)
        self.embed_V = nn.Embedding(self.vocab_size, self.d_k)

    def forward(self, query, key, value, mask=None):
        """
        ---------------------------
        L : target sequence length
        S : source sequence length
        N : batch size
        E : embedding dim
        ---------------------------
        :param query: (N, L, E)
        :param key: (N, S, E)
        :param value: (N, S, E)
        :param mask: optional mask broadcastable to (N, h, L, S); a 3-D mask
            such as (N, L, S) or (N, 1, S) is broadcast over heads. Positions
            where the mask is 0/False are excluded from attention.
        :return: (N, L, E)
        """
        nbatches = query.size(0)
        len_q = query.size(1)
        len_k = key.size(1)
        # 1) project, then split the embedding dim into h heads: (N, h, len, d_k)
        query, key, value = [
            proj(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linears, (query, key, value))
        ]

        # 2) relative position embeddings, each (L, S, d_k)
        relation_keys = self.generate_relative_positions_embeddings(len_q, len_k, self.embed_K)
        relation_values = self.generate_relative_positions_embeddings(len_q, len_k, self.embed_V)
        # Scaled logits, (N, h, L, S): divide by sqrt(d_k) as in scaled dot-product attention.
        logits = self._relative_attn_inner(query, key, relation_keys, True) / (self.d_k ** 0.5)
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # broadcast the same mask over all heads
            # finfo.min instead of -inf keeps fully masked rows from producing NaN.
            logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)
        weights = self.dropout(F.softmax(logits, dim=-1))
        # (N, h, L, d_k)
        x = self._relative_attn_inner(weights, value, relation_values, False)
        # 3) "Concat" heads back to (N, L, h * d_k) and apply the final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

    def _generate_relative_positions_matrix(self, len_q, len_k, device=None):
        """
        Build the matrix of clipped relative distances shifted to be non-negative.
        ---------------------------
        :param len_q: query length L
        :param len_k: key length S
        :param device: device for the returned index tensor (default: CPU)
        :return: LongTensor (L, S) with entry [i, j] = clamp(j - i, -max, max) + max
        """
        range_vec_q = torch.arange(len_q, device=device)
        range_vec_k = torch.arange(len_k, device=device)
        distance_mat = range_vec_k.unsqueeze(0) - range_vec_q.unsqueeze(-1)
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        return distance_mat_clipped + self.max_relative_position

    def generate_relative_positions_embeddings(self, len_q, len_k, embedding_table):
        """
        Look up relative position embeddings.
        ----------------------
        :param len_q: query length L
        :param len_k: key length S
        :param embedding_table: nn.Embedding with ``vocab_size`` rows
        :return: rpr embedding, dim: (L, S, d_k)
        """
        # Build indices on the table's device so lookups work when the model is on GPU.
        relative_position_matrix = self._generate_relative_positions_matrix(
            len_q, len_k, device=embedding_table.weight.device)
        return embedding_table(relative_position_matrix)

    def _relative_attn_inner(self, x, y, z, transpose):
        """
        Compute ``x @ y(^T) + x @ z(^T)`` where ``z`` is per-query-position.
        ------------------------
        :param x: (N, h, L, *) queries (transpose=True) or attention weights (transpose=False)
        :param y: (N, h, S, d) keys or values
        :param z: relative embeddings, (L, S, d)
        :param transpose: True for logits (contract over d), False for value mixing (contract over S)
        :return: (N, h, L, S) when transpose is True, else (N, h, L, d)
        """
        nbatches = x.size(0)
        heads = x.size(1)
        len_q = x.size(2)

        # content term: (N, h, L, S) or (N, h, L, d)
        xy_matmul = torch.matmul(x, y.transpose(-1, -2) if transpose else y)
        # move the query axis first so it batches against z: (L, N*h, *)
        x_t_v = x.permute(2, 0, 1, 3).contiguous().view(len_q, nbatches * heads, -1)
        # (L, N*h, *) @ (L, *, *) => (L, N*h, S) or (L, N*h, d)
        x_tz_matmul = torch.matmul(x_t_v, z.transpose(-1, -2) if transpose else z)
        # back to (N, h, L, *)
        x_tz_matmul_v_t = x_tz_matmul.view(len_q, nbatches, heads, -1).permute(1, 2, 0, 3)
        return xy_matmul + x_tz_matmul_v_t

In [None]:
class TransformerEncoderLayer(nn.Module):
    r"""Post-norm Transformer encoder layer whose self-attention is
    ``MultiHeadedAttention_RPR`` (relative position representations) instead
    of ``nn.MultiheadAttention``.

    The structure follows ``torch.nn.TransformerEncoderLayer`` / "Attention Is
    All You Need" (Vaswani et al., 2017): self-attention, residual + LayerNorm,
    feed-forward, residual + LayerNorm.

    Unlike ``torch.nn.TransformerEncoderLayer`` (sequence-first by default),
    inputs are batch-first ``(N, L, E)``, because ``MultiHeadedAttention_RPR``
    treats dim 0 as the batch and dim 1 as the sequence.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of attention heads (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        max_relative_position: clipping distance for the relative position
            embeddings (default=16; an arbitrary choice, tune per task).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(32, 10, 512)  # (N, L, E)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 max_relative_position=16):
        super(TransformerEncoderLayer, self).__init__()
        # max_relative_position is a required argument of MultiHeadedAttention_RPR.
        self.self_attn = MultiHeadedAttention_RPR(d_model, nhead, max_relative_position, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = self._get_activation_fn(activation)

    @staticmethod
    def _get_activation_fn(activation):
        """Map an activation name ("relu" or "gelu") to its functional form."""
        if activation == "relu":
            return F.relu
        if activation == "gelu":
            return F.gelu
        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

    @staticmethod
    def _build_keep_mask(src, src_mask, src_key_padding_mask):
        """Merge PyTorch-style masks into one mask for ``self_attn`` (True = may attend).

        ``src_mask`` is (L, S): bool with True = not allowed, or float additive
        where ``-inf`` = not allowed (other finite additive values cannot be
        expressed as a keep-mask and are treated as allowed).
        ``src_key_padding_mask`` is (N, S): bool with True = padding, or float
        with ``-inf`` = padding. Returns None or a bool tensor broadcastable
        to (N, h, L, S).
        """
        keep = None
        if src_mask is not None:
            blocked = src_mask if src_mask.dtype == torch.bool else src_mask == float('-inf')
            keep = (~blocked).to(src.device)[None, None, :, :]
        if src_key_padding_mask is not None:
            blocked = (src_key_padding_mask if src_key_padding_mask.dtype == torch.bool
                       else src_key_padding_mask == float('-inf'))
            pad_keep = (~blocked).to(src.device)[:, None, None, :]
            keep = pad_keep if keep is None else keep & pad_keep
        return keep

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None,
                is_causal: bool = False) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer, (N, L, E) (required).
            src_mask: the (L, L) mask for the src sequence (optional).
            src_key_padding_mask: the (N, L) mask for the src keys per batch (optional).
            is_causal: accepted only for signature compatibility (newer
                ``nn.TransformerEncoder`` versions may pass it); ignored. Pass an
                explicit ``src_mask`` for causal attention.

        Returns:
            Tensor of the same shape as ``src``.
        """
        keep_mask = self._build_keep_mask(src, src_mask, src_key_padding_mask)
        # MultiHeadedAttention_RPR takes a single `mask` and returns the attended
        # tensor directly (not an (output, weights) tuple), so no [0] indexing.
        src2 = self.self_attn(src, src, src, mask=keep_mask)
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerModel(nn.Module):
    """Token-level Transformer encoder built from relative-position encoder layers.

    No absolute positional encoding is added to the embeddings. Position
    information comes from the relative position representations used by
    each ``TransformerEncoderLayer``'s self-attention.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        """
        :param ntoken: vocabulary size (embedding rows and decoder outputs)
        :param ninp: embedding / model dimension
        :param nhead: number of attention heads
        :param nhid: hidden size of each feed-forward sublayer
        :param nlayers: number of stacked encoder layers
        :param dropout: dropout probability inside the encoder layers
        """
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        # Stack the layers explicitly rather than via nn.TransformerEncoder. That
        # wrapper is built around nn.TransformerEncoderLayer: newer PyTorch
        # versions may inspect layers[0].self_attn (an nn.MultiheadAttention
        # attribute) and pass extra kwargs. It also deep-copies one layer, so
        # every layer would start with identical weights; these are
        # initialised independently.
        self.transformer_encoder = nn.ModuleList(
            [TransformerEncoderLayer(ninp, nhead, nhid, dropout) for _ in range(nlayers)]
        )
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask: 0.0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        """Uniformly initialise embedding and decoder weights in [-0.1, 0.1]; zero the decoder bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        """
        :param src: token ids, expected (N, L); the RPR attention treats dim 1
            as the sequence axis
        :param src_mask: optional (L, L) attention mask, e.g. from
            ``generate_square_subsequent_mask``
        :return: logits of shape (*src.shape, ntoken)
        """
        output = self.encoder(src) * math.sqrt(self.ninp)
        # No absolute positional encoding here (self.pos_encoder was never
        # created); the layers' relative position embeddings supply position.
        for layer in self.transformer_encoder:
            output = layer(output, src_mask=src_mask)
        return self.decoder(output)