In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch_geometric.transforms as T
from torch import lgamma
from torch_geometric.data import DataLoader
from torch_scatter import scatter_mean
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset

parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings for graph Transformers')
# Parse an empty argument string: inside a notebook sys.argv belongs to the
# kernel, so every option is assigned by hand on the resulting namespace.
args = parser.parse_args("")

# Preferred GPU index. Kept under its own name so that args.device only ever
# holds a torch.device (it used to hold an int first and a device afterwards).
gpu_index = 3
if torch.cuda.is_available():
    # Use the first GPU when the preferred index does not exist on this
    # machine, rather than failing later on the first `.to(args.device)`.
    if gpu_index >= torch.cuda.device_count():
        gpu_index = 0
    args.device = torch.device('cuda:' + str(gpu_index))
else:
    args.device = torch.device('cpu')
print("device:", args.device)

# RNG seeding is not repeated here: the `set_seed(seed)` call further down in
# this cell seeds torch, numpy, random and CUDA with the same value (0).

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available, the generators of all CUDA devices are seeded too.

    :param seed: integer seed handed to every random number generator
    """
    cuda_present = torch.cuda.is_available()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

# One fixed seed for the whole notebook, applied to every RNG up front so a
# fresh-kernel run reproduces the same numbers.
seed = 0
set_seed(seed=seed)

device: cuda:3


In [None]:
class MultiHeadedAttention_RPR(nn.Module):
    """Multi-head attention with learned relation embeddings (RPR).

    Every (query position, key position) pair carries an integer relation id,
    given by ``rpr_matrix``. The id is embedded twice: one embedding is added
    on the key side when the attention logits are computed, the other on the
    value side when the attended values are aggregated. Both embedding tables
    have width ``d_k`` and are shared by all heads.

    Inputs are batch-first: ``forward`` reads the batch size from dim 0 and
    the sequence length from dim 1 of ``query``.
    """

    def __init__(self, d_model, h, vocab_size=2, dropout=.0):
        """
        :param d_model: model width; must be divisible by ``h``
        :param h: number of attention heads
        :param vocab_size: number of distinct relation ids that may appear in
            ``rpr_matrix``
        :param dropout: dropout probability applied to the attention weights
        """
        super(MultiHeadedAttention_RPR, self).__init__()
        assert d_model % h == 0
        # d_v is assumed to always equal d_k
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output projection.
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(p=dropout)

        self.vocab_size = vocab_size
        self.embed_K = nn.Embedding(self.vocab_size, self.d_k)
        self.embed_V = nn.Embedding(self.vocab_size, self.d_k)

    def forward(self, query, key, value, rpr_matrix, mask=None):
        """Attend from ``query`` to ``key``/``value`` using relation embeddings.

        :param query: tensor of shape (batch, seq_len, d_model)
        :param key: tensor of shape (batch, seq_len, d_model)
        :param value: tensor of shape (batch, seq_len, d_model)
        :param rpr_matrix: integer tensor of shape (seq_len, seq_len) with
            relation ids in ``[0, vocab_size)``
        :param mask: optional attention mask over (query, key) positions,
            either broadcastable to (batch, h, seq_len, seq_len) or of shape
            (batch, seq_len, seq_len). A floating-point mask is added to the
            logits (use ``-inf`` to block a pair); a boolean mask blocks the
            pairs where it is True. A query row that is blocked everywhere
            produces NaN after the softmax.
        :return: tensor of shape (batch, seq_len, d_model)
        """
        nbatches = query.size(0)  # batch size
        seq_len = query.size(1)
        # 1) split the embedding dim into h heads: d_model => h * d_k
        # resulting shape: (nbatches, h, seq_len, d_k)
        query, key, value = [
            proj(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linears, (query, key, value))
        ]

        # 2) relation-aware attention
        assert rpr_matrix.size(0) == seq_len and rpr_matrix.size(1) == seq_len
        relation_keys = self.embed_K(rpr_matrix)      # (seq_len, seq_len, d_k)
        relation_values = self.embed_V(rpr_matrix)    # (seq_len, seq_len, d_k)
        # NOTE(review): the logits are not scaled by 1/sqrt(d_k); left as is so
        # existing results do not change -- confirm whether that is intended.
        logits = self._relative_attn_inner(query, key, relation_keys, True)
        if mask is not None:
            # The mask used to be accepted and then silently ignored.
            logits = self._apply_mask(logits, mask)
        weights = self.dropout(F.softmax(logits, -1))
        x = self._relative_attn_inner(weights, value, relation_values, False)
        # 3) merge the heads again: (nbatches, seq_len, h * d_k)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

    @staticmethod
    def _apply_mask(logits, mask):
        """Apply an additive (float) or blocking (bool) mask to (N, h, s, s) logits."""
        if mask.dim() == 3:
            # (N, s, s) -> (N, 1, s, s) so that it broadcasts over the heads.
            mask = mask.unsqueeze(1)
        if mask.dtype == torch.bool:
            return logits.masked_fill(mask, float('-inf'))
        return logits + mask.to(dtype=logits.dtype)

    def _relative_attn_inner(self, x, y, z, transpose):
        """Compute ``x @ y`` plus the per-position relation term ``x @ z``.

        :param x: (N, h, s, d) queries, or (N, h, s, s) attention weights
        :param y: (N, h, s, d) keys or values
        :param z: (s, s, d) relation embeddings
        :param transpose: True to multiply by the transposes of ``y`` and
            ``z`` (logit computation), False to multiply by them directly
            (value aggregation)
        :return: (N, h, s, s) when ``transpose`` is True, else (N, h, s, d)
        """
        nbatches = x.size(0)
        heads = x.size(1)
        seq_len = x.size(2)

        # content term
        xy_matmul = torch.matmul(x, y.transpose(-1, -2) if transpose else y)
        # move the query position to the front and fold batch and heads together
        x_t_v = x.permute(2, 0, 1, 3).contiguous().view(seq_len, nbatches * heads, -1)
        # batched over the query position: each position meets its own row of z
        x_tz_matmul = torch.matmul(x_t_v, z.transpose(-1, -2) if transpose else z)
        # unfold back to (N, h, s, ...)
        x_tz_matmul_v_t = x_tz_matmul.view(seq_len, nbatches, heads, -1).permute(1, 2, 0, 3)
        return xy_matmul + x_tz_matmul_v_t

In [None]:
import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
# Module, ModuleList, Dropout, Linear and LayerNorm are classes inside the
# torch.nn package, not submodules, so they need `from ... import ...`.
from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList
from torch.nn.init import xavier_uniform_

def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

class TransformerEncoderLayer(Module):
    r"""Encoder layer made of relation-aware self-attention and a feedforward network.

    The layout follows the encoder layer of "Attention Is All You Need"
    (Vaswani et al., 2017): each of the two sub-layers is wrapped in a
    residual connection followed by layer normalisation. Self-attention is
    performed by ``MultiHeadedAttention_RPR``, so every call also needs the
    matrix of pairwise relation ids.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the attention module (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.0).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        rpr_vocab_size: number of distinct relation ids accepted in ``rpr_matrix`` (default=2).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(32, 10, 512)               # (batch, seq, d_model)
        >>> rpr_matrix = torch.randint(0, 2, (10, 10))  # one relation id per position pair
        >>> out = encoder_layer(src, rpr_matrix)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.0, activation="relu",
                 rpr_vocab_size=2):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadedAttention_RPR(d_model, nhead, rpr_vocab_size, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # State saved without an 'activation' entry falls back to relu.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, rpr_matrix: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the input to the encoder layer, shape (batch, seq, d_model) (required).
            rpr_matrix: integer relation ids, shape (seq, seq) (required).
            src_mask: the attention mask for the src sequence (optional).

        Returns:
            A tensor with the same shape as ``src``.
        """
        # MultiHeadedAttention_RPR.forward names its mask argument `mask` and
        # returns one tensor (not an (output, weights) tuple), so the result
        # must not be indexed with [0].
        src2 = self.self_attn(src, src, src, rpr_matrix, mask=src_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Each layer is called as ``layer(output, rpr_matrix, src_mask=mask)``, so
    the same relation matrix and mask are shared by the whole stack.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
            It is deep-copied ``num_layers`` times; the instance passed in is
            not itself part of the stack.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(32, 10, 512)
        >>> rpr_matrix = torch.randint(0, 2, (10, 10))
        >>> out = transformer_encoder(src, rpr_matrix)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, rpr_matrix: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the input to the encoder (required).
            rpr_matrix: the relation ids handed unchanged to every layer (required).
            mask: the attention mask for the src sequence (optional).

        Returns:
            The output of the last layer, passed through ``norm`` when one was given.
        """
        output = src

        for layer in self.layers:
            output = layer(output, rpr_matrix, src_mask=mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder

class TransformerModel(nn.Module):
    """Atom encoder -> relation-aware Transformer encoder -> linear head with sigmoid.

    :param nclasses: number of outputs produced per position
    :param ninp: embedding / model width
    :param nhead: number of attention heads
    :param nhid: width of the feedforward network inside each encoder layer
    :param nlayers: number of encoder layers
    :param dropout: dropout probability used inside the encoder layers
    """

    def __init__(self, nclasses, ninp, nhead, nhid, nlayers, dropout=0.0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = AtomEncoder(emb_dim=ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, nclasses)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        """Return an additive (sz, sz) float mask: 0.0 on and below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def init_weights(self):
        """Uniformly initialise the encoder and decoder weights; zero the decoder bias."""
        initrange = 0.1
        # The encoder is a container of embedding tables and exposes no
        # `.weight` attribute of its own (see OGB's AtomEncoder), so every
        # parameter it owns is initialised instead.
        for param in self.encoder.parameters():
            nn.init.uniform_(param, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, rpr_matrix, src_mask=None):
        """Encode the input, run the Transformer encoder and apply the sigmoid head.

        :param src: input features passed to the atom encoder
        :param rpr_matrix: integer relation ids, shape (seq, seq)
        :param src_mask: optional attention mask forwarded to the encoder
        :return: values in (0, 1); the last dim has size ``nclasses``. A 2-D
            encoder output yields a (seq, nclasses) result, a 3-D one
            (batch, seq, nclasses).
        """
        src = self.encoder(src)
        # The attention layers are batch-first. When the encoder returns a
        # single un-batched sequence of shape (seq, ninp), add a batch dim of
        # one and strip it again from the result.
        unbatched = src.dim() == 2
        if unbatched:
            src = src.unsqueeze(0)
        output = self.transformer_encoder(src, rpr_matrix, src_mask)
        # torch.sigmoid replaces the deprecated F.sigmoid
        output = torch.sigmoid(self.decoder(output))
        return output.squeeze(0) if unbatched else output

In [None]:
args.dataset = 'ogbg-moltox21'
# NOTE(review): the model emits n_classes values per node; confirm that 1 is
# the intended output width for this dataset.
args.n_classes = 1
args.batch_size = 128

print("Loading data...")
print("dataset: {} ".format(args.dataset))
dataset = PygGraphPropPredDataset(name=args.dataset).shuffle()

loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)

emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in TransformerEncoder
nlayers = 2 # the number of TransformerEncoderLayer in TransformerEncoder
nhead = 2 # the number of heads in the multi-head attention modules
dropout = 0.0 # the dropout value
model = TransformerModel(args.n_classes, emsize, nhead, nhid, nlayers, dropout).to(args.device)


def adjacency_rpr_matrix(edge_index, num_nodes):
    """Build a (num_nodes, num_nodes) relation matrix: 1 where an edge exists, else 0.

    NOTE(review): using the adjacency matrix as the relation matrix is an
    assumption based on the two-entry relation vocabulary of the attention
    layers -- confirm the intended relation encoding.
    """
    rpr = torch.zeros(num_nodes, num_nodes, dtype=torch.long, device=edge_index.device)
    rpr[edge_index[0], edge_index[1]] = 1
    return rpr


for data in loader:
    data = data.to(args.device)
    # The model's forward takes (node features, relation matrix) and the
    # relation matrix must be square in the number of nodes, so the mini-batch
    # is processed one graph at a time instead of passing the batch object.
    for graph in data.to_data_list():
        rpr_matrix = adjacency_rpr_matrix(graph.edge_index, graph.num_nodes)
        model(graph.x, rpr_matrix)