In [1]:
import conjugate_gradient as CG
import torch
from einops import rearrange, repeat
import networkx as nx
from scipy import sparse
import scipy.sparse.linalg as splinalg
import numpy as np
from torch.autograd import grad

In [40]:
def chebyshev(x, degree):
    """Tabulate Chebyshev polynomials of the first kind, T_0 .. T_degree, at x.

    Args:
        x: 1-D tensor of sample points (indexed as ``x.size(0)`` rows).
        degree: highest polynomial order to evaluate.

    Returns:
        Tensor of shape ``(x.size(0), degree + 1)`` with the same tensor type as
        ``x``; column k holds T_k(x).
    """
    table = torch.zeros(x.size(0), degree + 1).type(x.type())
    table[:, 0] = x * 0 + 1
    if degree <= 0:
        return table
    table[:, 1] = x
    for k in range(2, degree + 1):
        # Three-term recurrence: T_k = 2 x T_{k-1} - T_{k-2}.
        table[:, k] = 2 * x * table[:, k - 1] - table[:, k - 2]
    return table

In [41]:
# Ten sample points drawn uniformly from [-1, 1).
x = 2 * torch.rand(10) - 1

In [42]:
# In-place: mark x as requiring grad; the cell displays the returned tensor.
x.requires_grad_(True)

tensor([ 0.6714, -0.4194,  0.2677, -0.0922, -0.5779,  0.4677,  0.9019,  0.2290,
         0.3846, -0.8832], requires_grad=True)

In [43]:
# Uses the polynomial-table function ``chebyshev`` from the earlier cell.
# NOTE(review): a later cell defines ``class chebyshev`` and rebinds this name,
# so this call only works if run before that cell.
y = chebyshev(x, degree=4)

In [44]:
# Display the (10, 5) table of T_0..T_4 evaluated at x.
y

tensor([[ 1.0000,  0.6714, -0.0985, -0.8036, -0.9806],
        [ 1.0000, -0.4194, -0.6483,  0.9631, -0.1595],
        [ 1.0000,  0.2677, -0.8566, -0.7264,  0.4677],
        [ 1.0000, -0.0922, -0.9830,  0.2733,  0.9326],
        [ 1.0000, -0.5779, -0.3321,  0.9617, -0.7794],
        [ 1.0000,  0.4677, -0.5626, -0.9939, -0.3670],
        [ 1.0000,  0.9019,  0.6268,  0.2286, -0.2143],
        [ 1.0000,  0.2290, -0.8951, -0.6389,  0.6025],
        [ 1.0000,  0.3846, -0.7042, -0.9262, -0.0081],
        [ 1.0000, -0.8832,  0.5600, -0.1060, -0.3727]], grad_fn=<CopySlices>)

In [9]:
# Column 0 is T_0 = 1 for every sample, so this equals the number of samples.
y[:, 0].sum()

tensor(10., grad_fn=<SumBackward0>)

In [10]:
# Same scalar as above, kept as ``z``; not referenced by the visible cells below.
z = y[:, 0].sum()

In [75]:
class Chebyshev(torch.nn.Module):
    """Module wrapper around the pairwise ``chebyshev`` autograd Function.

    The order ``M`` is stored as a non-trainable buffer of shape ``[1]`` so it
    moves with the module (``.to``/``.cuda``) and is saved in ``state_dict``.
    ``forward`` first applies ``Non_zero`` so exact zeros in the input cannot
    produce a 0/0 when ``chebyshev`` divides by each pair's norm.
    """

    def __init__(self, M=2):
        super().__init__()
        self.register_buffer('M', torch.tensor([float(M)]))

    def forward(self, inp):
        # Custom autograd Functions are invoked through the class's static
        # ``apply``; instantiating them is deprecated in recent PyTorch.
        inp = Non_zero.apply(inp)
        return chebyshev.apply(inp, self.M)

    def init_ident(self):
        """Reset ``M`` to 1 and return ``self`` (for chaining).

        With M = 1 the pairwise map in ``chebyshev`` reproduces its input up to
        rounding, hence the name.
        """
        with torch.no_grad():
            # Rebind rather than fill_() in place: assigning to a registered
            # buffer name replaces the buffer, and leaves any old M tensor that
            # an earlier forward saved for backward untouched.
            self.M = torch.ones_like(self.M)
        return self

# prevents nan
class Non_zero(torch.autograd.Function):
    """Shift exact zeros by +1e-7 in forward; pass gradients through unchanged.

    Used ahead of ``chebyshev`` so an all-zero input pair does not yield 0/0
    when dividing by the pair's norm. Non-zero entries are left as they are.
    """

    @staticmethod
    def forward(ctx, inp):
        zero_mask = inp.eq(0).float()
        return inp + zero_mask * 1e-7

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: the tiny offset is treated as a constant.
        return grad_output


class chebyshev(torch.autograd.Function):
    """Pairwise angle-multiplying map with a hand-written backward.

    The last dimension of ``inp`` is split into consecutive pairs
    ``(x_i, x_j) = (inp[..., 2k], inp[..., 2k+1])``. With
    ``r = sqrt(x_i**2 + x_j**2)`` and ``theta = acos(x_i / r)`` in ``[0, pi]``,
    each pair is mapped to::

        y_i = r / sqrt(M) * cos(M * theta)
        y_j = sign(x_j) * r / sqrt(M) * sin(M * theta)

    For x_j != 0 this is ``r / sqrt(M) * (cos(M phi), sin(M phi))`` with
    ``phi = atan2(x_j, x_i)``. For integer M, ``cos(M theta)`` is the Chebyshev
    polynomial T_M(x_i / r). A pair with r == 0 gives NaN (0/0). The
    ``Chebyshev`` module guards against that with ``Non_zero``.

    NOTE: this class rebinds the name ``chebyshev`` that the earlier
    polynomial-table function used.
    """

    @staticmethod
    def forward(ctx, inp, M):
        """Args:
            inp: tensor whose last dimension has even size.
            M: tensor broadcastable against the per-pair values (the module
                passes a shape-[1] buffer). No gradient is computed for M.

        Raises:
            ValueError: if the last dimension of ``inp`` is odd.
        """
        n_features = inp.size(-1)
        if n_features % 2:
            raise ValueError(
                f"chebyshev expects an even-sized last dimension, got {n_features}")

        # Row 0: first element of each pair; row 1: second element.
        # Built on inp's device, with no per-call Python-side tensor construction.
        indices = torch.arange(n_features, device=inp.device).reshape(-1, 2).t()

        outp = torch.empty_like(inp)

        # Reused indexing/computations.
        xi = inp[..., indices[0]]
        xj = inp[..., indices[1]]
        x_norm = torch.sqrt(xi ** 2 + xj ** 2)

        # Trig form; clamp the acos argument because xi / x_norm can drift
        # just outside [-1, 1] in floating point.
        M_angle = M * torch.acos((xi / x_norm).clamp(min=-1., max=1.))
        scale = x_norm / torch.sqrt(M)

        outp[..., indices[0]] = scale * torch.cos(M_angle)
        outp[..., indices[1]] = xj.sign() * scale * torch.sin(M_angle)

        ctx.save_for_backward(xi, xj, x_norm ** 2, M, indices, outp)
        return outp

    @staticmethod
    def backward(ctx, grad_L_y):
        """Chain rule through the analytic 2x2 Jacobian of each output pair.

        With r^2 = x_i^2 + x_j^2 and (y_i, y_j) the forward outputs:

            dy_i/dx_i = (x_i y_i + M x_j y_j) / r^2
            dy_j/dx_i = (x_i y_j - M x_j y_i) / r^2
            dy_i/dx_j = (x_j y_i - M x_i y_j) / r^2
            dy_j/dx_j = (x_j y_j + M x_i y_i) / r^2

        Returns the gradient for ``inp`` and ``None`` for ``M``.
        """
        xi, xj, x2_norm, M, indices, outp = ctx.saved_tensors
        # Read grad_a_b as the derivative of a w.r.t. b.

        # Split function output.
        yi = outp[..., indices[0]]
        yj = outp[..., indices[1]]

        # Function gradient computation w.r.t. inputs.
        grad_yi_xi = (xi * yi + M * xj * yj) / x2_norm
        grad_yj_xi = (-M * xj * yi + xi * yj) / x2_norm
        grad_yi_xj = (xj * yi + -M * xi * yj) / x2_norm
        grad_yj_xj = (M * xi * yi + xj * yj) / x2_norm

        # Given gradients.
        grad_L_yi = grad_L_y[..., indices[0]]
        grad_L_yj = grad_L_y[..., indices[1]]

        # Chain rule.
        grad_L_xi = grad_L_yi * grad_yi_xi + grad_L_yj * grad_yj_xi
        grad_L_xj = grad_L_yi * grad_yi_xj + grad_L_yj * grad_yj_xj

        # Splice gradients back into the interleaved layout.
        grad_L_x = torch.empty_like(grad_L_y)
        grad_L_x[..., indices[0]] = grad_L_xi
        grad_L_x[..., indices[1]] = grad_L_xj
        return grad_L_x, None

In [76]:
# Module version of the pairwise map with order M=4 stored as a buffer.
ch = Chebyshev(M=4)

In [77]:
# Applies Non_zero, then the pairwise map, to the 10-element x (5 pairs).
ch(x)

tensor([-0.2435, -0.3121,  0.0343, -0.1374, -0.3394, -0.1516,  0.2535,  0.3901,
        -0.0346,  0.4804], grad_fn=<chebyshevBackward>)

In [71]:
# Invoke the autograd Function through the class: creating an instance
# (``chebyshev()``) is deprecated in recent PyTorch versions.
# Unlike the module, no Non_zero guard is applied here.
chebyshev.apply(x, torch.tensor([float(4)]))

tensor([-0.2435, -0.3121,  0.0343, -0.1374, -0.3394, -0.1516,  0.2535,  0.3901,
        -0.0346,  0.4804], grad_fn=<chebyshevBackward>)

In [2]:
# 10x10 diagonal system matrix (entries in [0, 1)) and one right-hand-side column.
X = torch.rand(10).diag()
B = torch.rand(10, 1)

In [3]:
# Batchify the above: prepend a batch axis of size 1 -> (1, m, n).
X = X.unsqueeze(0)
B = B.unsqueeze(0)

In [4]:
# X is the fixed system matrix; only the right-hand side B is differentiated.
# The cell output shows B (the last expression).
X.requires_grad_(False)
B.requires_grad_(True)

tensor([[[0.9773],
         [0.3413],
         [0.3760],
         [0.4546],
         [0.5687],
         [0.5831],
         [0.6299],
         [0.3112],
         [0.4156],
         [0.5679]]], requires_grad=True)

In [5]:
# ``CG`` comes from the local conjugate_gradient module; presumably a
# torch.autograd.Function subclass (it is used through ``.apply`` below).
cg = CG.CG

In [6]:
# Bind the Function's apply entry point.
f = cg.apply

In [7]:
# Solve the batched system. The torch.Size lines in the output are not
# printed by this cell's code; they appear to come from inside the CG module.
s = f(X,B)

torch.Size([1, 10, 1])
torch.Size([1, 10, 1])


In [8]:
# Reduce to a scalar so backward() needs no explicit gradient argument.
s1 = s.sum()

In [9]:
# Display the CG solution.
s

tensor([[[2.8478],
         [1.4370],
         [0.5429],
         [0.6381],
         [8.4948],
         [0.6338],
         [0.7111],
         [0.4144],
         [1.4915],
         [1.2664]]], grad_fn=<CGBackward>)

In [10]:
# Backpropagate through the CG solve into B. The printed tuple and gradient
# appear to come from print statements inside the CG module.
s1.backward()

(tensor([[[0.9773],
         [0.3413],
         [0.3760],
         [0.4546],
         [0.5687],
         [0.5831],
         [0.6299],
         [0.3112],
         [0.4156],
         [0.5679]]], requires_grad=True), tensor([[[2.8478],
         [1.4370],
         [0.5429],
         [0.6381],
         [8.4948],
         [0.6338],
         [0.7111],
         [0.4144],
         [1.4915],
         [1.2664]]], grad_fn=<CGBackward>))
tensor([[[ 2.9140],
         [ 4.2104],
         [ 1.4439],
         [ 1.4036],
         [14.9363],
         [ 1.0868],
         [ 1.1289],
         [ 1.3314],
         [ 3.5889],
         [ 2.2298]]])


In [12]:
# NOTE(review): this reads the attribute on the Function *class*, so it only
# shows a descriptor (see output), not anything saved by the call above.
# Looks like leftover debugging; consider deleting.
CG.CG.saved_variables

<attribute 'saved_variables' of 'torch._C._FunctionBase' objects>

In [19]:
def sparse_numpy_to_torch(A):
    """Convert a SciPy sparse matrix to a float64 torch sparse COO tensor.

    Args:
        A: any ``scipy.sparse`` matrix/array (CSR, CSC, COO, ...).

    Returns:
        Uncoalesced ``torch.sparse_coo_tensor`` of dtype float64 and shape
        ``A.shape`` holding the non-zero entries of ``A``.
    """
    coo = A.tocoo()
    # Take rows, columns and values from the same COO view so they stay
    # aligned, and drop explicitly stored zeros from all three together.
    # Pairing A.nonzero() with A.data misaligns when A stores explicit zeros.
    keep = coo.data != 0
    indices = torch.as_tensor(
        np.vstack((coo.row[keep], coo.col[keep])), dtype=torch.long)
    values = torch.as_tensor(coo.data[keep], dtype=torch.float64)
    return torch.sparse_coo_tensor(indices, values, A.shape)

n = 50  # nodes per graph, i.e. each system matrix is n x n
m = 50  # right-hand-side columns per system
K = 1   # number of independent systems (batch size)

# System matrices: graph Laplacian of a random graph with n nodes and 20*n
# edges, shifted by 0.1*I so it is strictly positive definite.
As = [nx.laplacian_matrix(
    nx.gnm_random_graph(n, 20 * n)) + .1 * sparse.eye(n) for _ in range(K)]
# Jacobi preconditioners: inverse of each matrix's diagonal.
Ms = [sparse.diags(1. / A.diagonal(), format='csc') for A in As]
A_bdiag = sparse.block_diag(As)
M_bdiag = sparse.block_diag(Ms)
Bs = [np.random.randn(n, m) for _ in range(K)]
As_torch = [None] * K
Ms_torch = [None] * K
# Build B from the data before marking it as requiring grad: writing into a
# leaf tensor that already requires grad raises a RuntimeError. This also
# avoids starting from uninitialised memory.
B_torch = torch.tensor(np.stack(Bs), dtype=torch.float64).requires_grad_()
A_bdiag_torch = sparse_numpy_to_torch(A_bdiag)
M_bdiag_torch = sparse_numpy_to_torch(M_bdiag)

# Explicit loop (rather than a comprehension) keeps ``i`` defined afterwards;
# a later cell inspects ``As_torch[i]``.
for i in range(K):
    As_torch[i] = sparse_numpy_to_torch(As[i])
    Ms_torch[i] = sparse_numpy_to_torch(Ms[i])


def A_bmm(X):
    """Batched product ``As_torch[k] @ X[k]`` for k in range(K), stacked on dim 0.

    Assumes X is indexable by k with each X[k] of shape (n, cols).
    """
    return torch.stack([As_torch[k] @ X[k] for k in range(K)], dim=0)


def M_bmm(X):
    """Batched preconditioner product ``Ms_torch[k] @ X[k]``, stacked on dim 0.

    Assumes X is indexable by k with each X[k] of shape (n, cols).
    """
    return torch.stack([Ms_torch[k] @ X[k] for k in range(K)], dim=0)


In [22]:
# ``i`` is the loop variable left over from the setup cell's for-loop (K - 1).
As_torch[i]

tensor(indices=tensor([[ 0,  0,  0,  ..., 49, 49, 49],
                       [ 0,  1,  2,  ..., 46, 47, 49]]),
       values=tensor([41.1000, -1.0000, -1.0000,  ..., -1.0000, -1.0000,
                      36.1000]),
       size=(50, 50), nnz=2050, dtype=torch.float64, layout=torch.sparse_coo)

In [8]:
# Call the batched CG solver directly with initial guess X0=B; the output is a
# (solution, info) tuple where info reports 'niter' and 'optimal'.
CG.cg_batch(X, B, X0=B)

(tensor([[[0.4912],
          [0.9723],
          [0.6954],
          [0.8210],
          [1.7861],
          [0.8587],
          [2.3634],
          [3.5718],
          [1.5046],
          [1.9391]]]), {'niter': 7, 'optimal': True})

In [6]:
# 3x3 identity used to try out einops.repeat below.
x0 = torch.eye(3)

In [9]:
# Adds a leading axis of size 1: (3, 3) -> (1, 3, 3).
repeat(x0, ' h c -> 1 h c')

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])

In [19]:
EPS = 1e-6
def _lanczos_layer(A, num_eig_vec, mask=None, use_reorthogonalization=False):
    """ Lanczos for symmetric matrix A
    
      Args:
        A: float tensor, shape B X N X N
        mask: float tensor, shape B X N
        num_eig_vec = K
      Returns:
      T: shape B X K X K, tridiagonal matrix
      Q: shape B X N X K, orthonormal matrix
      
    """
    batch_size = A.shape[0]
    num_node = A.shape[1]
    lanczos_iter = min(num_node, num_eig_vec)

    # initialization
    alpha = [None] * (lanczos_iter + 1)
    beta = [None] * (lanczos_iter + 1)
    Q = [None] * (lanczos_iter + 2)

    beta[0] = torch.zeros(batch_size, 1, 1).to(A.device)
    Q[0] = torch.zeros(batch_size, num_node, 1).to(A.device)
    Q[1] = torch.randn(batch_size, num_node, 1).to(A.device)

    if mask is not None:
        mask = mask.unsqueeze(dim=2).float()
        Q[1] = Q[1] * mask

    Q[1] = Q[1] / torch.norm(Q[1], 2, dim=1, keepdim=True)

    # Lanczos loop
    lb = 1.0e-4
    valid_mask = []
    for ii in range(1, lanczos_iter + 1):
      z = torch.bmm(A, Q[ii])  # shape B X N X 1
      alpha[ii] = torch.sum(Q[ii] * z, dim=1, keepdim=True)  # shape B X 1 X 1
      z = z - alpha[ii] * Q[ii] - beta[ii - 1] * Q[ii - 1]  # shape B X N X 1

      if use_reorthogonalization and ii > 1:
        # N.B.: Gram Schmidt does not bring significant difference of performance
        def _gram_schmidt(xx, tt):
          # xx shape B X N X 1
          for jj in range(1, tt):
            xx = xx - torch.sum(
                xx * Q[jj], dim=1, keepdim=True) / (
                    torch.sum(Q[jj] * Q[jj], dim=1, keepdim=True) + EPS) * Q[jj]
          return xx

        # do Gram Schmidt process twice
        for _ in range(2):
          z = _gram_schmidt(z, ii)

      beta[ii] = torch.norm(z, p=2, dim=1, keepdim=True)  # shape B X 1 X 1

      # N.B.: once lanczos fails at ii-th iteration, all following iterations
      # are doomed to fail
      tmp_valid_mask = (beta[ii] >= lb).float()  # shape
      if ii == 1:
        valid_mask += [tmp_valid_mask]
      else:
        valid_mask += [valid_mask[-1] * tmp_valid_mask]

      # early stop
      Q[ii + 1] = (z * valid_mask[-1]) / (beta[ii] + EPS)

    # get alpha & beta
    alpha = torch.cat(alpha[1:], dim=1).squeeze(dim=2)  # shape B X T
    beta = torch.cat(beta[1:-1], dim=1).squeeze(dim=2)  # shape B X (T-1)

    valid_mask = torch.cat(valid_mask, dim=1).squeeze(dim=2)  # shape B X T
    idx_mask = torch.sum(valid_mask, dim=1).long()
    if mask is not None:
      idx_mask = torch.min(idx_mask, torch.sum(mask, dim=1).squeeze().long())

    for ii in range(batch_size):
      if idx_mask[ii] < valid_mask.shape[1]:
        valid_mask[ii, idx_mask[ii]:] = 0.0

    # remove spurious columns
    alpha = alpha * valid_mask
    beta = beta * valid_mask[:, :-1]

    T = []
    for ii in range(batch_size):
      T += [
          torch.diag(alpha[ii]) + torch.diag(beta[ii], diagonal=1) + torch.diag(
              beta[ii], diagonal=-1)
      ]

    T = torch.stack(T, dim=0)  # shape B X T X T
    Q = torch.cat(Q[1:-1], dim=2)  # shape B X N X T
    Q_mask = valid_mask.unsqueeze(dim=1).repeat(1, Q.shape[1], 1)

    # remove spurious rows
    for ii in range(batch_size):
      if idx_mask[ii] < Q_mask.shape[1]:
        Q_mask[ii, idx_mask[ii]:, :] = 0.0

    Q = Q * Q_mask

    # pad 0 when necessary
    if lanczos_iter < num_eig_vec:
      pad = (0, num_eig_vec - lanczos_iter, 0,
             num_eig_vec - lanczos_iter)
      T = F.pad(T, pad)
      pad = (0, self.num_eig_vec - lanczos_iter)
      Q = F.pad(Q, pad)

    return T, Q


In [22]:
# Random symmetric 30x30 test matrix for the Lanczos routine.
A = torch.rand(30, 30)
A1 = A + A.T

In [29]:
# Run 5 Lanczos steps on the symmetric matrix as a batch of one.
T, Q = _lanczos_layer(A1.reshape(1,30,30), 5, use_reorthogonalization=True )

In [30]:
# (batch, nodes, Lanczos vectors)
Q.shape

torch.Size([1, 30, 5])

In [14]:
from scipy.linalg import solve

def jacobi(A, b, x, n):
    """Run ``n`` Jacobi sweeps for ``A x = b`` starting from the guess ``x``.

    Each sweep computes ``x <- (b - R x) / diag(A)``, where R is A with its
    diagonal removed. There is no convergence check: exactly ``n`` sweeps run.
    Convergence is guaranteed e.g. when A is strictly diagonally dominant.
    """
    diagonal = np.diag(A)
    off_diagonal = A - np.diagflat(diagonal)
    for _ in range(n):
        x = (b - off_diagonal.dot(x)) / diagonal
    return x


In [15]:
# Rebinds A (a torch tensor above) to a random 20x20 NumPy matrix.
A = np.random.rand(20,20)

In [30]:
# Symmetrize and shift the diagonal by 20 so it is large relative to the
# off-diagonal entries, which Jacobi needs in order to converge. Strict diagonal
# dominance is likely for this draw but not guaranteed.
A1 = A + A.transpose() + 20*np.identity(20)

In [31]:
# Right-hand side.
b = np.random.rand(20)

In [32]:
# Initial guess for Jacobi (rebinds the torch ``x`` from earlier cells).
x = np.random.rand(20)

In [38]:
# 200 Jacobi sweeps; compare with the direct solve in the next cell.
jacobi(A1,b,x, 200)

array([ 2.63298592e-02,  9.01206811e-04,  2.74598641e-02,  3.90240159e-03,
        9.35773454e-05, -9.70254066e-04,  3.48676971e-02,  2.38328224e-02,
        1.04884637e-02, -3.68662329e-03,  2.64929282e-02,  4.32761896e-03,
        2.63773283e-02,  3.00777617e-02,  1.74043244e-03,  2.86262706e-02,
        3.57581306e-02, -7.97009305e-03, -1.06529760e-03,  1.58886845e-02])

In [37]:
# Direct dense solve as reference; agrees with the Jacobi result above to ~1e-9.
solve(A1,b)

array([ 2.63298592e-02,  9.01206810e-04,  2.74598641e-02,  3.90240159e-03,
        9.35773445e-05, -9.70254067e-04,  3.48676970e-02,  2.38328224e-02,
        1.04884637e-02, -3.68662329e-03,  2.64929282e-02,  4.32761895e-03,
        2.63773283e-02,  3.00777617e-02,  1.74043244e-03,  2.86262706e-02,
        3.57581306e-02, -7.97009305e-03, -1.06529760e-03,  1.58886845e-02])