In [18]:
import torch
import torch.nn as nn
from ddn.pytorch.node import *

class NormalizedCuts(EqConstDeclarativeNode):
    """
    A declarative node to embed Normalized Cuts into a Neural Network
    
    Normalized Cuts and Image Segmentation https://people.eecs.berkeley.edu/~malik/papers/SM-ncut.pdf
    Shi, J., & Malik, J. (2000)
    """
    def __init__(self):
        super().__init__()
        self.chunk_size = 1
        
    def objective(self, x, y):
        """
        f(W,y) = y^T * (D-W) * y / y^T * D * y
        """
        
        for i in len(x):
            # W is an NxN symmetrical matrix with W(i,j) = w_ij
            D = x[i,...].sum(1).diag() # D is an NxN diagonal matrix with d on diagonal, for d(i) = sum_j(w(i,j))
            L = D - x

            y_t = torch.t(y)

        return torch.div(torch.matmul(torch.matmul(y_t, L),y),torch.matmul(torch.matmul(y_t,D),y))
    
    def equality_constraints(self, x, y):
        """
        subject to y^T * D * 1 = 0
        """
        # Ensure correct size and shape of y... scipy minimise flattens y         
        N = x.size(dim=0)
        
        #x is an NxN symmetrical matrix with W(i,j) = w_ij
        D = x.sum(1).diag() # D is an NxN diagonal matrix with d on diagonal, for d(i) = sum_j(w(i,j))
        ONE = torch.ones(N,1)   # Nx1 vector of all ones
        y_t = torch.t(y)
        
        
        return torch.matmul(torch.matmul(y_t,D), ONE)

    def solve(self, W):
        D = torch.diag(torch.sum(W, 0))
        D_half_inv = torch.diag(1.0 / torch.sqrt(torch.sum(W, 0)))
        M = torch.matmul(D_half_inv, torch.matmul((D - W), D_half_inv))

        # M is the normalised laplacian

        (w, v) = torch.linalg.eigh(M)

        print(W)
        print(D)
        print(w)
        #find index of second smallest eigenvalue
        index = torch.argsort(w)[1]

        v_partition = v[:, index]
        # instead of the sign of a digit being the binary split, let the NN learn it
        # v_partition = torch.sign(v_partition)
    
        # return the eigenvector and a blank context
        return v_partition, _
    
    def test_nobatch(self, x, y=y):
        """
        Test the function, without any batches present
        """
        # Evaluate objective function at (xs,y):
        f = torch.enable_grad()(self.objective)(x, y=y) # b

        # Compute partial derivative of f wrt y at (xs,y):
        fY = grad(f, y, grad_outputs=torch.ones_like(f), create_graph=True)
        return fY
    
node = NormalizedCuts()
x_nobatch = torch.tensor([[0,2,0], [2,0,4], [0,4,0]], dtype=torch.float, requires_grad=True)
y_nobatch, misc = node.solve(x_nobatch)
print(x)
print(y)

# node.gradient(x,y=y)
node.test_nobatch(x_nobatch,y=y_nobatch)


tensor([[0., 2., 0.],
        [2., 0., 4.],
        [0., 4., 0.]], requires_grad=True)
tensor([[2., 0., 0.],
        [0., 6., 0.],
        [0., 0., 4.]], grad_fn=<DiagBackward0>)
tensor([-1.4901e-08,  1.0000e+00,  2.0000e+00], grad_fn=<LinalgEighBackward0>)
tensor([[0., 2., 0.],
        [2., 0., 4.],
        [0., 4., 0.]], requires_grad=True)
tensor([ 8.1650e-01, -5.1165e-08, -5.7735e-01], grad_fn=<SelectBackward0>)


(tensor([1.1921e-07, 5.0731e-01, 1.1921e-07], grad_fn=<AddBackward0>),)

In [None]:
%debug

> [0;32m/tmp/ipykernel_22578/3196724601.py[0m(54)[0;36msolve[0;34m()[0m
[0;32m     52 [0;31m        [0mD[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mdiag[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mW[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m        [0mD_half_inv[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mdiag[0m[0;34m([0m[0;36m1.0[0m [0;34m/[0m [0mtorch[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mW[0m[0;34m,[0m [0;36m0[0m[0;34m)[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 54 [0;31m        [0mM[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mmatmul[0m[0;34m([0m[0mD_half_inv[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mmatmul[0m[0;34m([0m[0;34m([0m[0mD[0m [0;34m-[0m [0mW[0m[0;34m)[0m[0;34m,[0m [0mD_half_inv[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;

In [9]:
# # Starting again, trying to solve it all...

# import torch
# import torch.nn as nn
# import torch.autograd as grad
        
# def objective(x, y):
#     """
#     f(W,y) = y^T * (D-W) * y / y^T * D * y
#     """
#     # W is an NxN symmetrical matrix with W(i,j) = w_ij
#     D = x.sum(1).diag() # D is an NxN diagonal matrix with d on diagonal, for d(i) = sum_j(w(i,j))
#     L = D - x

#     y_t = torch.t(y)

#     return torch.div(torch.matmul(torch.matmul(y_t, L),y),torch.matmul(torch.matmul(y_t,D),y))

# def equality_constraints(x, y):
#     """
#     subject to y^T * D * 1 = 0
#     """
#     # Ensure correct size and shape of y... scipy minimise flattens y         
#     N = x.size(dim=0)

#     #x is an NxN symmetrical matrix with W(i,j) = w_ij
#     D = x.sum(1).diag() # D is an NxN diagonal matrix with d on diagonal, for d(i) = sum_j(w(i,j))
#     ONE = torch.ones(N,1)   # Nx1 vector of all ones
#     y_t = torch.t(y)


#     return torch.matmul(torch.matmul(y_t,D), ONE)

# def solve(W):
#     D = torch.diag(torch.sum(W, 0))
#     D_half_inv = torch.diag(1.0 / torch.sqrt(torch.sum(W, 0)))
#     M = torch.matmul(D_half_inv, torch.matmul((D - W), D_half_inv))

#     # M is the normalised laplacian

#     (w, v) = torch.linalg.eigh(M)

#     #find index of second smallest eigenvalue
#     index = torch.argsort(w)[1]

#     v_partition = v[:, index]
#     # instead of the sign of a digit being the binary split, let the NN learn it
#     # v_partition = torch.sign(v_partition)

#     # return the eigenvector and a blank context
#     return v_partition, _
    
# x = torch.tensor([[0,2,0], [2,0,4], [0,4,0]]).double()
# y, ctx = solve(x)




# f = torch.enable_grad()(self.objective)(*xs, y=y) # b

# # Compute partial derivative of f wrt y at (xs,y):
# fY = grad(f, y, grad_outputs=torch.ones_like(f), create_graph=True)[0]
# fY = torch.enable_grad()(fY.reshape)(self.b, -1) # bxm

# print(x)
# print(y)



tensor([[0., 2., 0.],
        [2., 0., 4.],
        [0., 4., 0.]], dtype=torch.float64)
tensor([ 8.1650e-01,  6.1835e-17, -5.7735e-01], dtype=torch.float64)
