Code adapted from the HGCN paper ("Hyperbolic Graph Convolutional Neural Networks") by Chami, Ying, et al.

In [None]:
import os
import math
import torch 
import numpy as np 
from sklearn.metrics import roc_auc_score, average_precision_score
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
from torch.nn.modules.module import Module
import torch.optim

In [None]:
# Basic math helpers: clamped hyperbolic functions and autograd-backed inverses
def cosh(x, clamp=15):
    """Hyperbolic cosine of ``x`` after clamping it to ``[-clamp, clamp]`` (avoids overflow)."""
    bounded = torch.clamp(x, min=-clamp, max=clamp)
    return torch.cosh(bounded)


def sinh(x, clamp=15):
    """Hyperbolic sine of ``x`` after clamping it to ``[-clamp, clamp]`` (avoids overflow)."""
    bounded = torch.clamp(x, min=-clamp, max=clamp)
    return torch.sinh(bounded)


def tanh(x, clamp=15):
    """Hyperbolic tangent of ``x`` after clamping it to ``[-clamp, clamp]``."""
    bounded = torch.clamp(x, min=-clamp, max=clamp)
    return torch.tanh(bounded)


def arcosh(x):
    """Inverse hyperbolic cosine, computed by the custom ``Arcosh`` autograd function."""
    apply_fn = Arcosh.apply
    return apply_fn(x)


def arsinh(x):
    """Inverse hyperbolic sine, computed by the custom ``Arsinh`` autograd function."""
    apply_fn = Arsinh.apply
    return apply_fn(x)


def artanh(x):
    """Inverse hyperbolic tangent, computed by the custom ``Artanh`` autograd function."""
    apply_fn = Artanh.apply
    return apply_fn(x)


class Artanh(torch.autograd.Function):
    """Inverse hyperbolic tangent with a clamped, hand-written gradient.

    The input is upcast to float64 *before* it is clamped into (-1, 1).
    Clamping in the input dtype does not work for float32: there
    ``1 - 1e-15`` rounds to exactly ``1.0``, so ``log(1 - x)`` would become
    ``log(0)`` and both the value and the gradient would be infinite.
    """

    @staticmethod
    def forward(ctx, x):
        z = x.double().clamp(-1 + 1e-15, 1 - 1e-15)
        # Save the float64 clamped input so backward sees the same safe values.
        ctx.save_for_backward(z)
        # artanh(z) = 0.5 * (log(1 + z) - log(1 - z)); log1p is accurate near 0.
        return (torch.log1p(z) - torch.log1p(-z)).mul(0.5).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        z, = ctx.saved_tensors
        # d/dz artanh(z) = 1 / (1 - z^2), computed in float64 and cast back
        # to the dtype of the incoming gradient (== the forward input dtype).
        return (grad_output / (1 - z ** 2)).to(grad_output.dtype)


class Arsinh(torch.autograd.Function):
    """Inverse hyperbolic sine with a hand-written gradient.

    The forward pass uses odd symmetry, ``asinh(z) = sign(z) * asinh(|z|)``.
    Evaluated directly, ``z + sqrt(1 + z^2)`` cancels to ~0 for large
    negative ``z`` (e.g. -1e8 in float64), which previously collapsed the
    result to ``log(1e-15)``. With ``|z|`` the log argument is always >= 1,
    so no clamp is needed.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        z = x.double()
        a = z.abs()
        return (torch.sign(z) * torch.log(a + torch.sqrt(1 + a.pow(2)))).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx asinh(x) = 1 / sqrt(1 + x^2)
        return grad_output / (1 + input ** 2) ** 0.5


class Arcosh(torch.autograd.Function):
    """Inverse hyperbolic cosine with a clamped, hand-written gradient.

    The input is upcast to float64 *before* it is clamped to ``>= 1 + 1e-15``.
    In float32, ``1.0 + 1e-15`` rounds to exactly ``1.0``, so clamping in
    the input dtype leaves ``x^2 - 1 == 0`` and an infinite gradient.
    """

    @staticmethod
    def forward(ctx, x):
        z = x.double().clamp(min=1.0 + 1e-15)
        # Save the float64 clamped input so backward sees the same safe values.
        ctx.save_for_backward(z)
        # acosh(z) = log(z + sqrt(z^2 - 1))
        return (z + torch.sqrt(z.pow(2) - 1)).clamp_min(1e-15).log().to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        z, = ctx.saved_tensors
        # d/dz acosh(z) = 1 / sqrt(z^2 - 1), cast back to the incoming grad dtype.
        return (grad_output / (z ** 2 - 1) ** 0.5).to(grad_output.dtype)



In [None]:
# Operations on the Lorentz (hyperboloid) model of hyperbolic space
# Per-dtype safety margins used by the Lorentz-model helpers below when
# clamping values before sqrt/arcosh or away from zero.
eps = dict([(torch.float32, 1e-7), (torch.float64, 1e-15)])
# Lower/upper clamps applied to vector norms (min before division, max in expmap).
min_norm, max_norm = 1e-15, 1e6

def minkowski_dot(x, y, keepdim=True):
    """Minkowski inner product ``-x0*y0 + sum_{i>0} xi*yi`` over the last dim.

    The first coordinate is the time-like one. With ``keepdim`` (default) a
    trailing singleton dimension is kept so the result broadcasts against
    ``x``.
    """
    euclidean = (x * y).sum(dim=-1)
    # Subtract twice the time product: turns +x0*y0 into -x0*y0.
    time_correction = 2 * x[..., 0] * y[..., 0]
    res = euclidean - time_correction
    return res.unsqueeze(-1) if keepdim else res

def minkowski_norm(u, keepdim=True):
    """Square root of ``<u, u>_L``; the square is first floored at ``eps[u.dtype]``."""
    squared = minkowski_dot(u, u, keepdim=keepdim)
    floor = eps[u.dtype]
    return squared.clamp(min=floor).sqrt()

def sqdist(x, y, c):
    """Squared hyperbolic distance between hyperboloid points ``x`` and ``y``.

    Computes ``K * arcosh(-<x, y>_L / K)^2`` with ``K = 1/c``. The arcosh
    argument is kept at or above ``1 + eps``, and the result is capped at 50.
    """
    K = 1. / c
    theta = (-minkowski_dot(x, y) / K).clamp(min=1.0 + eps[x.dtype])
    dist_sq = K * arcosh(theta) ** 2
    # Cap the distance to avoid NaNs in a Fermi-Dirac decoder.
    return dist_sq.clamp(max=50.0)

def proj(x, c):
    """Project ``x`` onto the hyperboloid by recomputing its time coordinate.

    The spatial part ``x[..., 1:]`` is kept. The time coordinate is set to
    ``sqrt(1/c + ||x[..., 1:]||^2)`` (floored at ``eps[x.dtype]`` before the
    sqrt). Any number of leading batch dimensions is supported; the last
    dimension holds (time, space...) coordinates.
    """
    K = 1. / c
    d = x.size(-1) - 1
    y = x.narrow(-1, 1, d)
    y_sqnorm = torch.norm(y, p=2, dim=-1, keepdim=True) ** 2
    # Keep the spatial coordinates, zero out the old time coordinate ...
    mask = torch.ones_like(x)
    mask[..., 0] = 0
    # ... and write the recomputed time coordinate in its place.
    vals = torch.zeros_like(x)
    vals[..., 0:1] = torch.sqrt(torch.clamp(K + y_sqnorm, min=eps[x.dtype]))
    return vals + mask * x

def proj_tan(u, x, c):
    """Project ``u`` onto the tangent space of the hyperboloid at ``x``.

    Keeps the spatial part of ``u`` and sets its time coordinate to
    ``<x_space, u_space> / x0`` so that ``<u, x>_L = 0``. ``x0`` is floored
    at ``eps[x.dtype]``. Any number of leading batch dimensions is supported.
    ``c`` does not enter the formula; it is kept for signature consistency
    with the other helpers.
    """
    d = x.size(-1) - 1
    ux = torch.sum(x.narrow(-1, 1, d) * u.narrow(-1, 1, d), dim=-1, keepdim=True)
    mask = torch.ones_like(u)
    mask[..., 0] = 0
    vals = torch.zeros_like(u)
    vals[..., 0:1] = ux / torch.clamp(x[..., 0:1], min=eps[x.dtype])
    return vals + mask * u

def proj_tan0(u):
    """Project ``u`` onto the tangent space at the hyperboloid origin.

    At the origin the tangent space is exactly the set of vectors with a
    zero time coordinate, so this zeroes ``u[..., 0]`` and keeps the rest.
    Any number of leading batch dimensions is supported (including 1-D input).
    """
    narrowed = u.narrow(-1, 0, 1)
    vals = torch.zeros_like(u)
    vals[..., 0:1] = narrowed
    return u - vals

def expmap(u, x, c):
    """Exponential map on the hyperboloid: move from ``x`` along tangent vector ``u``.

    Computes ``cosh(theta) * x + sinh(theta) * u / theta`` with
    ``theta = ||u||_L / sqrt(1/c)``. ``||u||_L`` is capped at ``max_norm`` and
    ``theta`` is floored at ``min_norm``. The result is re-projected with ``proj``.
    """
    sqrtK = (1. / c) ** 0.5
    norm_u = minkowski_norm(u).clamp(max=max_norm)
    theta = (norm_u / sqrtK).clamp(min=min_norm)
    moved = cosh(theta) * x + sinh(theta) * u / theta
    return proj(moved, c)
        
def logmap(x, y, c):
    """Logarithmic map on the hyperboloid: tangent vector at ``x`` pointing to ``y``.

    The direction ``y + <x, y>_L * c * x`` is rescaled to the hyperbolic
    distance between the points, then projected onto the tangent space at ``x``.
    """
    K = 1. / c
    # Keep <x, y>_L strictly below -K (by at least eps) before using it.
    shifted = torch.clamp(minkowski_dot(x, y) + K, max=-eps[x.dtype])
    xy = shifted - K
    u = y + xy * x * c
    normu = minkowski_norm(u).clamp(min=min_norm)
    dist = sqdist(x, y, c) ** 0.5
    return proj_tan(dist * u / normu, x, c)

def expmap0(u, c):
    """Exponential map at the hyperboloid origin ``(sqrt(1/c), 0, ..., 0)``.

    Only the spatial part ``u[..., 1:]`` is used; it is flattened to 2-D
    ``(-1, d)``, so the result is written with 2-D slicing. Time coordinate:
    ``sqrtK * cosh(theta)``. Space: ``sqrtK * sinh(theta) * u_s / ||u_s||``.
    Here ``theta = ||u_s|| / sqrtK``.
    """
    sqrtK = (1. / c) ** 0.5
    d = u.size(-1) - 1
    space = u.narrow(-1, 1, d).view(-1, d)
    space_norm = torch.clamp(torch.norm(space, p=2, dim=1, keepdim=True), min=min_norm)
    theta = space_norm / sqrtK
    res = torch.ones_like(u)
    res[:, 0:1] = sqrtK * cosh(theta)
    res[:, 1:] = sqrtK * sinh(theta) * space / space_norm
    return proj(res, c)

def logmap0(x, c):
    """Logarithmic map at the hyperboloid origin (inverse of ``expmap0``).

    Returns a tensor shaped like ``x`` with a zero time coordinate and spatial
    part ``sqrtK * arcosh(x0 / sqrtK) * x_s / ||x_s||``. The arcosh argument
    is floored at ``1 + eps``. Uses 2-D slicing, as ``expmap0`` does.
    """
    sqrtK = (1. / c) ** 0.5
    d = x.size(-1) - 1
    space = x.narrow(-1, 1, d).view(-1, d)
    space_norm = torch.norm(space, p=2, dim=1, keepdim=True).clamp(min=min_norm)
    theta = (x[:, 0:1] / sqrtK).clamp(min=1.0 + eps[x.dtype])
    out = torch.zeros_like(x)
    out[:, 1:] = sqrtK * arcosh(theta) * space / space_norm
    return out

def mobius_add(x, y, c):
    """Hyperbolic "addition" of ``y`` to ``x``.

    Takes ``y``'s tangent vector at the origin, parallel-transports it to
    ``x``, and follows it from ``x`` with the exponential map.
    """
    transported = ptransp0(x, logmap0(y, c), c)
    return expmap(transported, x, c)

def mobius_matvec(m, x, c):
    """Apply matrix ``m`` to hyperboloid points ``x`` through the tangent space at the origin.

    ``x`` is mapped with ``logmap0``, multiplied by ``m^T`` and mapped back with ``expmap0``.
    """
    tangent = logmap0(x, c)
    return expmap0(torch.matmul(tangent, m.transpose(-1, -2)), c)

def ptransp(x, y, u, c):
    """Parallel-transport tangent vector ``u`` from point ``x`` to point ``y``.

    Uses ``u - (<log_x(y), u>_L / d(x, y)^2) * (log_x(y) + log_y(x))``. The
    squared distance is floored at ``min_norm``, and the result is projected
    onto the tangent space at ``y``.
    """
    logxy = logmap(x, y, c)
    logyx = logmap(y, x, c)
    # Must not be named ``sqdist``: assigning to that name would make it a
    # local for the whole function, and calling ``sqdist(...)`` on the right
    # side would raise UnboundLocalError.
    sq_dist = torch.clamp(sqdist(x, y, c), min=min_norm)
    alpha = minkowski_dot(logxy, u) / sq_dist
    res = u - alpha * (logxy + logyx)
    return proj_tan(res, y, c)

def ptransp0(x, u, c):
    """Parallel-transport ``u`` from the hyperboloid origin to point ``x``.

    Uses 2-D slicing, so ``x`` and ``u`` are expected to be (batch, d+1).
    The result is projected onto the tangent space at ``x``.
    """
    sqrtK = (1. / c) ** 0.5
    d = x.size(-1) - 1
    x0 = x.narrow(-1, 0, 1)
    y = x.narrow(-1, 1, d)
    y_norm = torch.norm(y, p=2, dim=1, keepdim=True).clamp(min=min_norm)
    direction = y / y_norm
    v = torch.ones_like(x)
    v[:, 0:1] = -y_norm
    v[:, 1:] = (sqrtK - x0) * direction
    alpha = (direction * u[:, 1:]).sum(dim=1, keepdim=True) / sqrtK
    return proj_tan(u - alpha * v, x, c)

In [None]:
# TODO: These functions are probably not correct for the Lorentz model used above
# (_lambda_x is the Poincaré-ball conformal factor). Check and fix.
def _lambda_x(x, c):
        x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
        return 2 / (1. - c * x_sqnorm).clamp_min(1e-15)
    
def egrad2rgrad(p, dp, c):
    """Convert Euclidean gradient ``dp`` at ``p`` into a Riemannian one, in place.

    ``dp`` is divided in place by ``_lambda_x(p, c) ** 2`` and the same tensor
    is returned.
    NOTE(review): this is the Poincaré-ball rule; confirm before using it on
    hyperboloid points.
    """
    scale = _lambda_x(p, c).pow(2)
    dp.div_(scale)
    return dp
    
def inner(x, c, u, v=None, keepdim=False):
    """Riemannian inner product of ``u`` and ``v`` (``v`` defaults to ``u``) at ``x``.

    Returns ``_lambda_x(x, c)^2 * sum(u * v)`` over the last dim.
    NOTE(review): Poincaré-ball metric, not the Minkowski one; confirm intent.
    """
    other = u if v is None else v
    lam = _lambda_x(x, c)
    return lam ** 2 * (u * other).sum(dim=-1, keepdim=keepdim)
    
def to_poincare(x, c):
    """Map hyperboloid points to the Poincaré ball: ``sqrtK * x_s / (x0 + sqrtK)``.

    ``x`` is indexed as (batch, d+1); the result drops the time coordinate.
    """
    sqrtK = (1. / c) ** 0.5
    d = x.size(-1) - 1
    spatial = x.narrow(-1, 1, d)
    return sqrtK * spatial / (x[:, 0:1] + sqrtK)
    
# TODO: Write the inverse of to_poincare (Poincaré ball -> hyperboloid).

In [None]:
class HypLinear(nn.Module):
    """
    Hyperbolic linear layer.

    Applies a Möbius matrix-vector product with a dropout-regularised weight
    to hyperboloid points, then optionally adds a hyperbolic bias. The
    curvature ``c`` is a trainable parameter initialised to 1.0.

    Args:
        in_features: size of the input coordinate vectors.
        out_features: size of the output coordinate vectors.
        use_bias: whether to add the hyperbolic bias.
        dropout: dropout probability applied to the weight matrix during
            training (default 0.4, the rate previously hard-coded in forward).
    """

    def __init__(self, in_features, out_features, use_bias, dropout=0.4):
        super(HypLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.c = nn.Parameter(torch.Tensor([1.0]))
        self.dropout = dropout
        self.use_bias = use_bias
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        init.xavier_uniform_(self.weight, gain=math.sqrt(2))
        init.constant_(self.bias, 0)

    def forward(self, x):
        # Weight dropout; F.dropout is the identity when not training.
        drop_weight = F.dropout(self.weight, self.dropout, training=self.training)
        mv = mobius_matvec(drop_weight, x, self.c)
        res = proj(mv, self.c)
        if self.use_bias:
            # Treat the Euclidean bias as a tangent vector at the origin and
            # map it onto the hyperboloid before the Möbius addition.
            bias = proj_tan0(self.bias.view(1, -1))
            hyp_bias = proj(expmap0(bias, self.c), self.c)
            res = mobius_add(res, hyp_bias, c=self.c)
            res = proj(res, self.c)
        return res

In [None]:
class HypAct(nn.Module):
    """
    Hyperbolic activation layer.

    Maps points to the tangent space at the origin with curvature ``c_in``,
    applies the Euclidean activation ``act``, re-projects onto that tangent
    space, and maps back onto the hyperboloid with curvature ``c_out``.
    """

    def __init__(self, c_in, c_out, act):
        super(HypAct, self).__init__()
        self.c_in, self.c_out, self.act = c_in, c_out, act

    def forward(self, x):
        tangent = proj_tan0(self.act(logmap0(x, c=self.c_in)))
        on_manifold = expmap0(tangent, c=self.c_out)
        return proj(on_manifold, c=self.c_out)

    def extra_repr(self):
        return f'c_in={self.c_in}, c_out={self.c_out}'

In [None]:
class HypAgg(nn.Module):
    """
    Hyperbolic aggregation layer.

    Neighbour aggregation is done in the tangent space at the origin:
    ``expmap0(adj @ logmap0(x))`` followed by ``proj``. Attention-based
    aggregation is not implemented. ``dropout`` is stored but not used here.
    """

    def __init__(self, c, in_features, dropout, use_att):
        super(HypAgg, self).__init__()
        self.c = c

        self.in_features = in_features
        self.dropout = dropout
        self.use_att = use_att
        if use_att:
            raise NotImplementedError

    def forward(self, x, adj):
        tangent = logmap0(x, c=self.c)
        if self.use_att:
            raise NotImplementedError
        aggregated = torch.spmm(adj, tangent)
        return proj(expmap0(aggregated, c=self.c), c=self.c)

    def extra_repr(self):
        return f'c={self.c}'

In [None]:
class GraphEncoder(nn.Module):
    """
    Hyperbolic Graph Encoder model for node classification.

    Pipeline:
      1. Lift Euclidean node features onto the hyperboloid.
      2. Apply two (HypLinear -> HypAgg -> HypAct) blocks.
      3. Map the result back to the tangent space at the origin.
      4. Classify it with a Euclidean linear layer followed by log-softmax.

    Args:
        in_feats: number of input features per node.
        h1_feats: width of the first hyperbolic layer.
        h2_feats: width of the second hyperbolic layer.
        out_feats: number of output classes.
    """
    def __init__(self, in_feats, h1_feats, h2_feats, out_feats):
        super(GraphEncoder, self).__init__()

        self.in_feats = in_feats
        self.h1_feats = h1_feats
        self.h2_feats = h2_feats
        self.out_feats = out_feats

        # +1: forward() prepends a zero time-like coordinate to the input.
        self.linear1 = HypLinear(self.in_feats + 1, self.h1_feats, use_bias=True)
        self.linear2 = HypLinear(self.h1_feats, self.h2_feats, use_bias=True)

        # The aggregation/activation layers receive *detached* views of the
        # curvature parameters, which is what ``state_dict()['c']`` returns.
        # No gradient reaches ``c`` through these layers, and they register
        # no extra parameters or state_dict entries.
        c1 = self.linear1.c.detach()
        c2 = self.linear2.c.detach()

        self.agg1 = HypAgg(c1, self.h1_feats, False, False)
        self.activation1 = HypAct(c1, c2, nn.Tanh())

        self.agg2 = HypAgg(c2, self.h2_feats, False, False)
        self.activation2 = HypAct(c2, c2, nn.Tanh())

        self.linear_out = nn.Linear(self.h2_feats, self.out_feats)

    def forward(self, x, adj):
        """
        Args:
            x: node features, indexed as (num_nodes, in_feats).
            adj: adjacency matrix passed to ``torch.spmm`` (presumably a
                normalised sparse adjacency; confirm against the caller).

        Returns:
            Log-probabilities of shape (num_nodes, out_feats).
        """
        # Same detached curvature values as ``state_dict()['c']``, without
        # rebuilding the whole state dict on every call.
        c1 = self.linear1.c.detach()
        c2 = self.linear2.c.detach()

        # Interpret the features as tangent vectors at the origin of R^{n+1}
        # by prepending a zero time-like coordinate.
        x = torch.cat([x.new_zeros(x.size(0), 1), x], dim=1)

        # Project to hyperbolic coordinates.
        x_hyp = proj(expmap0(proj_tan0(x), c1), c1)

        # Encoder
        x = self.activation1(self.agg1(self.linear1(x_hyp), adj))
        x = self.activation2(self.agg2(self.linear2(x), adj))

        # Decoder: feed the tangent vector ``h`` to the Euclidean classifier,
        # not the raw hyperboloid coordinates.
        h = proj_tan0(logmap0(x, c2))
        logits = self.linear_out(h)
        return F.log_softmax(logits, dim=1)
        

In [None]:
# Initialize the model: 32 input features, hidden widths 7 and 5, 2 output classes.
ge = GraphEncoder(in_feats=32, h1_feats=7, h2_feats=5, out_feats=2)

In [None]:
# Work in progress: trying to figure out the Riemannian Adam optimizer
class OptimMixin(object):
    """Mixin that adds optional periodic stabilization to an optimizer.

    Takes a keyword-only ``stabilize`` argument, stores it as
    ``self._stabilize``, and forwards all other arguments to the next class
    in the MRO (e.g. a ``torch.optim`` optimizer).
    """

    def __init__(self, *args, stabilize=None, **kwargs):
        self._stabilize = stabilize
        super().__init__(*args, **kwargs)

    def stabilize_group(self, group):
        """Hook for subclasses; the base implementation does nothing."""
        return None

    def stabilize(self):
        """Stabilize parameters if they are off-manifold due to numerical reasons
        """
        for param_group in self.param_groups:
            self.stabilize_group(param_group)


def copy_or_set_(dest, source):
    """
    Store ``source`` into ``dest`` in place while respecting ``dest``'s strides.

    Works around https://github.com/geoopt/geoopt/issues/70: when the strides
    match, ``dest`` is re-pointed at ``source``'s storage (``set_``); otherwise
    the values are copied (``copy_``), keeping ``dest``'s memory layout.

    Parameters
    ----------
    dest : torch.Tensor
        Destination tensor where to store new data
    source : torch.Tensor
        Source data to put in the new tensor

    Returns
    -------
    dest
        torch.Tensor, modified inplace
    """
    same_layout = dest.stride() == source.stride()
    if same_layout:
        return dest.set_(source)
    return dest.copy_(source)


class RiemannianAdam(OptimMixin, torch.optim.Adam):
    r"""Riemannian Adam with the same API as :class:`torch.optim.Adam`

    A parameter is treated as a point on the hyperboloid when it carries a
    ``c`` attribute holding its curvature (similar in spirit to HGCN's
    ``ManifoldParameter``). Such parameters are updated with this file's
    ``egrad2rgrad``/``inner``/``expmap``/``ptransp``/``proj`` helpers. Every
    other parameter (ordinary ``nn.Parameter``s) gets a plain Euclidean Adam
    update.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float (optional)
        learning rate (default: 1e-3)
    betas : Tuple[float, float] (optional)
        coefficients used for computing
        running averages of gradient and its square (default: (0.9, 0.999))
    eps : float (optional)
        term added to the denominator to improve
        numerical stability (default: 1e-8)
    weight_decay : float (optional)
        weight decay (L2 penalty) (default: 0)
    amsgrad : bool (optional)
        whether to use the AMSGrad variant of this
        algorithm from the paper `On the Convergence of Adam and Beyond`_
        (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they are off-manifold due to numerical
        reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)

    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments
        ---------
        closure : callable (optional)
            A closure that reevaluates the model
            and returns the loss.

        Returns
        -------
        The loss returned by ``closure``, or ``None``.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                betas = group["betas"]
                weight_decay = group["weight_decay"]
                adam_eps = group["eps"]
                learning_rate = group["lr"]
                amsgrad = group["amsgrad"]
                for point in group["params"]:
                    grad = point.grad
                    if grad is None:
                        continue
                    if grad.is_sparse:
                        raise RuntimeError(
                                "Riemannian Adam does not support sparse gradients yet (PR is welcome)"
                        )
                    # Curvature of the hyperboloid this parameter lives on;
                    # None means an ordinary Euclidean parameter.
                    c = getattr(point, "c", None)

                    state = self.state[point]

                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(point)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(point)
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state["max_exp_avg_sq"] = torch.zeros_like(point)
                    exp_avg = state["exp_avg"]
                    exp_avg_sq = state["exp_avg_sq"]

                    # Out-of-place, so point.grad is not modified (egrad2rgrad
                    # below also works in place on its argument).
                    grad = grad.add(point, alpha=weight_decay)
                    if c is None:
                        second_moment = grad * grad
                    else:
                        grad = egrad2rgrad(point, grad, c)
                        second_moment = inner(point, c, grad, keepdim=True)
                    exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                    exp_avg_sq.mul_(betas[1]).add_(second_moment, alpha=1 - betas[1])
                    if amsgrad:
                        max_exp_avg_sq = state["max_exp_avg_sq"]
                        # Maintains the maximum of all 2nd moment running avg. till now
                        torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                        # Use the max. for normalizing running avg. of gradient
                        denom = max_exp_avg_sq.sqrt().add_(adam_eps)
                    else:
                        denom = exp_avg_sq.sqrt().add_(adam_eps)

                    # Per-parameter step count drives the bias correction.
                    state["step"] += 1
                    bias_correction1 = 1 - betas[0] ** state["step"]
                    bias_correction2 = 1 - betas[1] ** state["step"]
                    step_size = (
                        learning_rate * bias_correction2 ** 0.5 / bias_correction1
                    )

                    # direction of ascent; we step against it
                    direction = exp_avg / denom
                    if c is None:
                        point.add_(direction, alpha=-step_size)
                    else:
                        new_point = proj(expmap(-step_size * direction, point, c), c)
                        # transport the exponential averaging to the new point
                        exp_avg_new = ptransp(point, new_point, exp_avg, c)
                        # use copy only for user facing point
                        copy_or_set_(point, new_point)
                        exp_avg.set_(exp_avg_new)

                # One group-level tick per optimizer step, used only to
                # schedule stabilization.
                group["step"] += 1
                if self._stabilize is not None and group["step"] % self._stabilize == 0:
                    self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        """Re-project hyperboloid parameters and their first moments onto the manifold."""
        for p in group["params"]:
            c = getattr(p, "c", None)
            if c is None:
                # Euclidean parameters have nothing to re-project.
                continue
            state = self.state[p]
            if not state:  # due to None grads
                continue
            exp_avg = state["exp_avg"]
            copy_or_set_(p, proj(p, c))
            exp_avg.set_(proj_tan(exp_avg, p, c))