In [7]:
#TODO: make this nicer
path = "data\\tox21_original\\tox21.sdf"

from rdkit import Chem
import pandas as pd
data_molecules = Chem.SDMolSupplier(path)

info_file = pd.read_csv("data\\tox21_original\\tox21_compoundData.csv", sep=",", header=0)

targets = info_file.to_numpy()[:, -12:]
ids = info_file["ID"].to_numpy()
to_del = []
with open("file.csv", mode = "w") as fi:
    fi.write("NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,mol_id,smiles\n")

    for mol, target, i, infofile in zip(data_molecules, targets, ids, range(len(info_file))):
        try:
            smiles = Chem.MolToSmiles(mol)
            #fi.write(",".join([str(int(t)) if str(t)!="nan" else "" for t in target]) + "," + i + "," + smiles + "\n")
        except BaseException:
            to_del.append(infofile)


In [42]:
import pdb

In [208]:
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros




class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels, False)
            self.lin_r = Linear(in_channels[1], heads * out_channels, False)

        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None):
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels
        #pdb.set_trace()

        x_l: OptTensor = None
        x_r: OptTensor = None
        alpha_l: OptTensor = None
        alpha_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = x_r = self.lin_l(x).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = self.lin_l(x_l).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)
                alpha_r = (x_r * self.att_r).sum(dim=-1)

        assert x_l is not None
        assert alpha_l is not None

        #Self loops might be that nodes are also connected to themselfes
        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        print("x_l:\n",x_l)
        print("x_r:\n",x_r)
        print("alpha_l:\n",alpha_l)
        print("alpha_r:\n", alpha_r)
        print("size:\n",size)
        print("edge_index:\n", edge_index)
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                edge_index_i: Tensor, edge_index_j: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Attention construction

        For the Sinkhorn-Knoff implementation the reparametization trick is used to keep differentiablity

        edge_index_i: Tensor
            The receiving edges of the messages
        edge_index_j: Tensor
            The sourcing edges of the messages
        """
        #index: contains the index where each x_j and alpha goes to (where the message is sent to)
        #We now need also where the message comes from to make the matrix double stochastic! (need to add it)
        #This is contained in the edge_index matrix
        
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, ptr, size_i)
        self._alpha = alpha
        #alpha = F.dropout(alpha, p=self.dropout, training=self.training) #for now no dropout
        print("alpha:\n", alpha)

        print("alpha_shape:\n",alpha.shape)
        print("x_j:\n",x_j)
        print("x_j shape", x_j.shape)
        print("edge_index_i:\n", edge_index_i)
        print("ptr\n",ptr)
        print("size_i\n",size_i)
        print("edge_index_j:\n",edge_index_j)
        if True:
            z = torch.zeros((size_i, size_i))
            z[edge_index_j, edge_index_i] = alpha.squeeze()#maybe switch j and i here?
            sk = SinkhornKnopp()
            _ = sk.fit(z.detach().cpu().numpy())
            D1 = torch.tensor(sk._D1, dtype=torch.float32)
            D2 = torch.tensor(sk._D2, dtype=torch.float32)
            new_z = D1 @ z @ D2
            alpha = new_z[edge_index_j, edge_index_i]
            alpha = alpha.unsqueeze(-1)
            print("alpha_new:\n",alpha)
            #TODO: reparametize here
            print("z new:\n",new_z)

        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)

In [209]:
from typing import Optional, Callable, List
from torch_geometric.typing import Adj

import copy

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Sequential, Linear, BatchNorm1d, ReLU

from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge

from utils.basic_modules import BasicGNN

class GAT(BasicGNN):
    r"""The Graph Neural Network from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper, using the
    :class:`~torch_geometric.nn.GATConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        num_layers (int): Number of GNN layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"last"`).
            (default: :obj:`"last"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GATConv`.
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last',
                 **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers, dropout,
                         act, norm, jk)

        if 'concat' in kwargs:
            del kwargs['concat']

        if 'heads' in kwargs:
            assert hidden_channels % kwargs['heads'] == 0
        out_channels = hidden_channels // kwargs.get('heads', 1)

        self.convs.append(
            GATConv(in_channels, out_channels, dropout=dropout, **kwargs))
        for _ in range(1, num_layers):
            self.convs.append(GATConv(hidden_channels, out_channels, **kwargs))

In [214]:
gat = GAT(2, 3, 1, add_self_loops = True)

In [218]:
#data = torch.ones((5, 2))

In [219]:
#adj_data = torch.tensor([[1,2,3,1],[2,1,1,3]],dtype=torch.long)

In [220]:
gat(data, adj_data)

x_l:
 tensor([[[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]]], grad_fn=<ViewBackward>)
x_r:
 tensor([[[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]],

        [[-1.3076,  0.9852, -0.9541]]], grad_fn=<ViewBackward>)
alpha_l:
 tensor([[-0.0963],
        [-0.0963],
        [-0.0963],
        [-0.0963],
        [-0.0963]], grad_fn=<SumBackward1>)
alpha_r:
 tensor([[0.1376],
        [0.1376],
        [0.1376],
        [0.1376],
        [0.1376]], grad_fn=<SumBackward1>)
size:
 None
edge_index:
 tensor([[1, 2, 3, 1, 0, 1, 2, 3, 4],
        [2, 1, 1, 3, 0, 1, 2, 3, 4]])
alpha:
 tensor([[0.5000],
        [0.3333],
        [0.3333],
        [0.5000],
        [1.0000],
        [0.3333],
        [0.5000],
        [0.5000],
        [1.0000]], grad_fn=<Diffe

tensor([[0.0000, 0.9852, 0.0000],
        [0.0000, 0.9852, 0.0000],
        [0.0000, 0.9852, 0.0000],
        [0.0000, 0.9852, 0.0000],
        [0.0000, 0.9852, 0.0000]], grad_fn=<AsStridedBackward>)

In [101]:
#https://github.com/btaba/sinkhorn_knopp

import warnings

import numpy as np


class SinkhornKnopp:
    """
    Sinkhorn Knopp Algorithm

    Takes a non-negative square matrix P, where P =/= 0
    and iterates through Sinkhorn Knopp's algorithm
    to convert P to a doubly stochastic matrix.
    Guaranteed convergence if P has total support.

    For reference see original paper:
        http://msp.org/pjm/1967/21-2/pjm-v21-n2-p14-s.pdf

    Parameters
    ----------
    max_iter : int, default=1000
        The maximum number of iterations.

    epsilon : float, default=1e-3
        Metric used to compute the stopping condition,
        which occurs if all the row and column sums are
        within epsilon of 1. This should be a very small value.
        Epsilon must be between 0 and 1.

    Attributes
    ----------
    _max_iter : int, default=1000
        User defined parameter. See above.

    _epsilon : float, default=1e-3
        User defined paramter. See above.

    _stopping_condition: string
        Either "max_iter", "epsilon", or None, which is a
        description of why the algorithm stopped iterating.

    _iterations : int
        The number of iterations elapsed during the algorithm's
        run-time.

    _D1 : 2d-array
        Diagonal matrix obtained after a stopping condition was met
        so that _D1.dot(P).dot(_D2) is close to doubly stochastic.

    _D2 : 2d-array
        Diagonal matrix obtained after a stopping condition was met
        so that _D1.dot(P).dot(_D2) is close to doubly stochastic.

    Example
    -------

    .. code-block:: python
        >>> import numpy as np
        >>> from sinkhorn_knopp import sinkhorn_knopp as skp
        >>> sk = skp.SinkhornKnopp()
        >>> P = [[.011, .15], [1.71, .1]]
        >>> P_ds = sk.fit(P)
        >>> P_ds
        array([[ 0.06102561,  0.93897439],
           [ 0.93809928,  0.06190072]])
        >>> np.sum(P_ds, axis=0)
        array([ 0.99912489,  1.00087511])
        >>> np.sum(P_ds, axis=1)
        array([ 1.,  1.])

    """

    def __init__(self, max_iter=1000, epsilon=1e-3):
        assert isinstance(max_iter, int) or isinstance(max_iter, float),\
            "max_iter is not of type int or float: %r" % max_iter
        assert max_iter > 0,\
            "max_iter must be greater than 0: %r" % max_iter
        self._max_iter = int(max_iter)

        assert isinstance(epsilon, int) or isinstance(epsilon, float),\
            "epsilon is not of type float or int: %r" % epsilon
        assert epsilon > 0 and epsilon < 1,\
            "epsilon must be between 0 and 1 exclusive: %r" % epsilon
        self._epsilon = epsilon

        self._stopping_condition = None
        self._iterations = 0
        self._D1 = np.ones(1)
        self._D2 = np.ones(1)

    def fit(self, P):
        """Fit the diagonal matrices in Sinkhorn Knopp's algorithm

        Parameters
        ----------
        P : 2d array-like
        Must be a square non-negative 2d array-like object, that
        is convertible to a numpy array. The matrix must not be
        equal to 0 and it must have total support for the algorithm
        to converge.

        Returns
        -------
        A double stochastic matrix.

        """
        P = np.asarray(P)
        assert np.all(P >= 0)
        assert P.ndim == 2
        assert P.shape[0] == P.shape[1]

        N = P.shape[0]
        max_thresh = 1 + self._epsilon
        min_thresh = 1 - self._epsilon

        # Initialize r and c, the diagonals of D1 and D2
        # and warn if the matrix does not have support.
        r = np.ones((N, 1))
        pdotr = P.T.dot(r)
        total_support_warning_str = (
            "Matrix P must have total support. "
            "See documentation"
        )
        if not np.all(pdotr != 0):
            warnings.warn(total_support_warning_str, UserWarning)

        c = 1 / pdotr
        pdotc = P.dot(c)
        if not np.all(pdotc != 0):
            warnings.warn(total_support_warning_str, UserWarning)

        r = 1 / pdotc
        del pdotr, pdotc

        P_eps = np.copy(P)
        while np.any(np.sum(P_eps, axis=1) < min_thresh) \
                or np.any(np.sum(P_eps, axis=1) > max_thresh) \
                or np.any(np.sum(P_eps, axis=0) < min_thresh) \
                or np.any(np.sum(P_eps, axis=0) > max_thresh):

            c = 1 / P.T.dot(r)
            r = 1 / P.dot(c)

            self._D1 = np.diag(np.squeeze(r))
            self._D2 = np.diag(np.squeeze(c))
            P_eps = self._D1.dot(P).dot(self._D2)

            self._iterations += 1

            if self._iterations >= self._max_iter:
                self._stopping_condition = "max_iter"
                break

        if not self._stopping_condition:
            self._stopping_condition = "epsilon"

        self._D1 = np.diag(np.squeeze(r))
        self._D2 = np.diag(np.squeeze(c))
        P_eps = self._D1.dot(P).dot(self._D2)

        return P_eps

In [120]:
sh = SinkhornKnopp()
mat = np.random.uniform(0, 1, size=(15,15))
mat

array([[0.77047748, 0.12255091, 0.29305959, 0.40155392, 0.93531142,
        0.58960897, 0.089743  , 0.38033475, 0.54087382, 0.04370785,
        0.99758221, 0.68192138, 0.82704384, 0.56828631, 0.97440023],
       [0.98092976, 0.12752975, 0.37398163, 0.73441382, 0.66602245,
        0.01436947, 0.45737407, 0.68147954, 0.88430475, 0.71784137,
        0.4528176 , 0.06150887, 0.10395306, 0.99527805, 0.64350764],
       [0.45413436, 0.34270935, 0.52220594, 0.25037549, 0.19730073,
        0.04870015, 0.74208657, 0.04089297, 0.70621456, 0.22730178,
        0.17932379, 0.59211356, 0.70461578, 0.7215873 , 0.27473461],
       [0.2507698 , 0.05314883, 0.48538292, 0.93623987, 0.03893189,
        0.24950645, 0.11417394, 0.58147414, 0.77541077, 0.90110477,
        0.87222485, 0.73209217, 0.90234108, 0.62504812, 0.7592827 ],
       [0.14085701, 0.19139359, 0.75797466, 0.71468478, 0.03009083,
        0.20164331, 0.36195132, 0.97064605, 0.04649732, 0.21018688,
        0.61119584, 0.82771367, 0.74387093, 

In [121]:
res = sh.fit(mat)
print(sh._iterations)
res

2


array([[0.08087882, 0.02049092, 0.03033349, 0.05142852, 0.10369476,
        0.12204   , 0.01230698, 0.04313213, 0.0617261 , 0.00435555,
        0.12617845, 0.07703239, 0.08494819, 0.05866753, 0.12278618],
       [0.11377027, 0.02355985, 0.04276936, 0.10392427, 0.08158409,
        0.00328621, 0.06930083, 0.08538936, 0.1115041 , 0.07903661,
        0.06328136, 0.00767702, 0.0117972 , 0.11352489, 0.0895946 ],
       [0.06773911, 0.08142376, 0.07680482, 0.04556508, 0.03108204,
        0.01432348, 0.1446058 , 0.00658967, 0.11452222, 0.03218599,
        0.03222958, 0.09504391, 0.10283908, 0.10585217, 0.04919327],
       [0.0272189 , 0.00918881, 0.05194831, 0.12398461, 0.004463  ,
        0.0533999 , 0.01618968, 0.06818453, 0.09150086, 0.0928495 ,
        0.11407369, 0.08551166, 0.09583339, 0.06672132, 0.09893185],
       [0.02027748, 0.04388672, 0.10759247, 0.1255265 , 0.00457505,
        0.05723777, 0.06807095, 0.15095816, 0.00727715, 0.02872435,
        0.10601752, 0.12822714, 0.10478134, 

In [122]:
print(np.sum(res, axis=1), np.sum(res, axis=0))

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] [1.00000998 1.0000764  0.99997218 1.00006148 0.9999322  0.99972499
 1.00025883 0.99994672 1.00004345 1.00002224 0.99992801 1.00008741
 1.00001301 1.00002716 0.99989593]


In [124]:
np.allclose(sh._D1 @ mat @ sh._D2, res)

True