In [1]:
from SBM_attributed import *
import torch
import networkx as nx
from divergences import * 
from copy import deepcopy
from sklearn.metrics import rand_score, calinski_harabasz_score, mutual_info_score, accuracy_score

In [2]:
### Generate benchmark
P = np.array([[0.8, 0.2, 0.3],[0.2, 0.7, 0.4],[0.3, 0.4, 0.6]])
c = 3
n = 100
N = c*n
delta = 10
d = 2
dim = c*d
_,A,X = generate_benchmark(P,c,delta=delta,n=n,dim=d,sample_along_direction=True)
A = torch.tensor(A)
X = torch.tensor(X)

In [3]:
### initialize variables
B = torch.eye(c,dtype=torch.float, requires_grad=True)
W = torch.zeros((N,c),dtype=torch.float, requires_grad=True)
indexes = torch.randint(low=0,high=c,size=(N,))
with torch.no_grad():
    for index,col in enumerate(indexes):
        W[index].index_fill_(0, col, 1)
mu = torch.tensor(np.random.normal(size=(c,dim)), requires_grad=True)

In [4]:
phi_net = get_phi("euclidean",elementwise=True)[0]
phi_data = get_phi("euclidean")

In [5]:
def obj_func(W,B,mu,norm):
    net_divergence = torch.sum(phi_net(A,W@B@W.T), axis=1)
    data_divergence = torch.sum(torch.multiply(W, pairwise_bregman(X, mu, phi_data)),axis=1)
    K = torch.stack((net_divergence,data_divergence),axis=-1)
    return torch.sum(torch.pow(torch.sum(torch.pow(K, norm), axis=1),1/norm))
#     return torch.linalg.matrix_norm(K,ord=-np.inf,keepdim=False)

In [6]:
norm = -1
lr=1e-2
W_old = B_old = mu_old = classes_old = classes = None
convergence_cnt = 0
convergence_threshold = 5
iter = 0
loss = 0
max_iter = 100000
failed_to_converge = False
while True:
    iter += 1
    loss = obj_func(W,B,mu,norm)
    loss.backward()
    with torch.no_grad():
        # Save old variables
        W_old = deepcopy(W)
        B_old = deepcopy(B)
        mu_old = deepcopy(mu)
        # Gradient descent
        W -= W.grad * lr 
        B -= B.grad * lr
        mu -= mu.grad * lr
        # Set the gradients to zero
        W.grad.zero_()
        B.grad.zero_()
        mu.grad.zero_()
        # Normalize variables
        w_min = W.min(axis=1)[0][:, None].expand(-1,c)  
        w_max = W.max(axis=1)[0][:, None].expand(-1,c)   
        W -= w_min
        W /= (w_max - w_min)
        b_min = B.min(axis=1)[0][:, None].expand(-1,c)  
        b_max = B.max(axis=1)[0][:, None].expand(-1,c)  
        B -= b_min
        B /= (b_max - b_min)
        #anneal norm value
        if iter % 2 == 0:
            if norm > -1.0:
                norm += -.2
            elif norm > -120.0:
                norm *= 1.06
        # Check for nan values
        if (torch.isnan(torch.norm(W)) or torch.isnan(torch.norm(B)) or torch.isnan(torch.norm(mu))):
            print("variables are NAN'd, so terminating")
            mu = mu_old
            B = B_old
            W = W_old
            print("loss-",loss)
            failed_to_converge = True
            break
        #update values
        if classes is not None:
            classes_old = classes
        classes = torch.argmax(W, axis=1)
        #Check convergence
        if classes_old is not None and classes is not None and torch.equal(classes_old, classes):
            convergence_cnt += 1
        else:
            convergence_cnt = 0
        if convergence_cnt == convergence_threshold or iter>max_iter:
            print("point assignments have converged")
            break

variables are NAN'd, so terminating
loss- tensor(inf, dtype=torch.float64, grad_fn=<SumBackward0>)


In [7]:
norm,iter

(-126.00472097309353, 170)

In [8]:
pred_labels = torch.argmax(W, axis=1).detach().numpy()
true_labels = [0]*n + [1]*n + [2]*n
print(rand_score(true_labels, pred_labels),mutual_info_score(true_labels, pred_labels))

0.3311036789297659 0.0


In [9]:
B

tensor([[0.9268, 1.0000, 0.0000],
        [0.9420, 1.0000, 0.0000],
        [0.9235, 1.0000, 0.0000]], requires_grad=True)

In [10]:
mu

tensor([[ 13.4918, -46.7543,  20.9008, -51.1834,  -6.0738,  17.0653],
        [-63.7226, 112.3886,   2.5787,  -5.6345,  -0.8304,   8.2379],
        [  5.4346, -19.4572,   7.9261, -19.6438,  -3.5641,   9.3086]],
       dtype=torch.float64, requires_grad=True)