In [1]:
from SBM_attributed import *
import torch
import networkx as nx
from divergences import * 
from copy import deepcopy
from sklearn.metrics import rand_score, calinski_harabasz_score, mutual_info_score, accuracy_score

In [2]:
### Generate benchmark
# Seed the global NumPy and PyTorch generators so that "Restart & Run All"
# draws the same random numbers every time (this notebook later calls
# torch.randint and np.random.normal).
# NOTE(review): this only makes generate_benchmark itself reproducible if it
# samples from the global NumPy/PyTorch RNG -- confirm in SBM_attributed.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# c x c matrix handed to generate_benchmark
# (presumably the block connection probabilities -- TODO confirm).
P = np.array([[0.8, 0.2, 0.3],[0.2, 0.7, 0.4],[0.3, 0.4, 0.6]])
c = 3        # number of clusters (columns of W, size of B)
n = 100      # nodes per cluster; the evaluation cell builds n labels per class
N = c*n      # total number of nodes (rows of W)
delta = 10   # forwarded to generate_benchmark; meaning defined there
d = 2        # forwarded to generate_benchmark as its `dim` argument
dim = c*d    # number of columns used for mu
_,A,X = generate_benchmark(P,c,delta=delta,n=n,dim=d,sample_along_direction=True)
# Convert the two returned arrays to tensors; A feeds the network term of the
# objective and X the attribute term.
A = torch.tensor(A)
X = torch.tensor(X)

In [3]:
### initialize variables
# B starts as the c x c identity matrix.
B = torch.eye(c, dtype=torch.float, requires_grad=True)
# W starts as a random one-hot matrix: one column in [0, c) is drawn per row
# and that entry is set to 1.
W = torch.zeros((N, c), dtype=torch.float, requires_grad=True)
indexes = torch.randint(low=0, high=c, size=(N,))
with torch.no_grad():
    # A single indexed assignment writes all the ones; no_grad is needed to
    # edit a leaf tensor that requires grad in place.
    W[torch.arange(N), indexes] = 1
# mu is drawn from a standard normal; it stays float64 because it is built
# from a NumPy array.
mu = torch.tensor(np.random.normal(size=(c, dim)), requires_grad=True)

In [4]:
# Euclidean divergence helpers obtained from get_phi.
# phi_net keeps only the first component of the elementwise variant, while
# phi_data keeps the full return value of the default variant.
# NOTE(review): the two are indexed differently -- confirm that
# pairwise_bregman expects the un-indexed result.
_euclidean_elementwise = get_phi("euclidean", elementwise=True)
phi_net = _euclidean_elementwise[0]
phi_data = get_phi("euclidean")

In [5]:
def obj_func(W,B,mu,norm):
    """Return the scalar objective for the current W, B and mu.

    For every row two quantities are computed:
      * the row sum of ``phi_net(A, W @ B @ W.T)`` (network term);
      * the row sum of ``W * pairwise_bregman(X, mu, phi_data)`` (attribute term).
    The two are combined per row as ``(net**norm + data**norm) ** (1/norm)``
    and the per-row results are summed.

    Reads the notebook globals ``A``, ``X``, ``phi_net`` and ``phi_data``.
    """
    reconstruction = W @ B @ W.T
    net_term = phi_net(A, reconstruction).sum(dim=1)
    data_term = (W * pairwise_bregman(X, mu, phi_data)).sum(dim=1)
    # One row per node, two columns: (network term, attribute term).
    stacked = torch.stack((net_term, data_term), dim=-1)
    per_row = (stacked ** norm).sum(dim=1) ** (1 / norm)
    # Alternative left by the original author, not used:
    #   torch.linalg.matrix_norm(stacked, ord=norm, keepdim=False)
    return per_row.sum()

In [6]:
### Gradient descent on (W, B, mu) with an annealed exponent `norm`
norm = -1                   # exponent passed to obj_func; annealed below
lr=1e-2                     # gradient-descent step size
W_old = B_old = mu_old = classes_old = classes = None
convergence_cnt = 0         # consecutive iterations with unchanged labels
convergence_threshold = 5   # stop once convergence_cnt reaches this value
iter = 0                    # NOTE: shadows the builtin `iter`; the name is kept because a later cell displays it
loss = 0
max_iter = 100000           # hard cap on the number of iterations
failed_to_converge = False  # set when the loop stops for any reason other than convergence
while True:
    iter += 1
    loss = obj_func(W,B,mu,norm)
    loss.backward()
    with torch.no_grad():
        # Snapshot the pre-step values so a step that produces NaNs can be undone.
        W_old = deepcopy(W)
        B_old = deepcopy(B)
        mu_old = deepcopy(mu)
        # Gradient descent
        W -= W.grad * lr
        B -= B.grad * lr
        mu -= mu.grad * lr
        # Set the gradients to zero so the next backward() starts fresh
        W.grad.zero_()
        B.grad.zero_()
        mu.grad.zero_()
        # Min-max rescale W and B to [0, 1]. The min/max are taken over the
        # whole matrix, not per row. If every entry is equal the range is 0
        # and the division yields NaN, which the NaN check below catches.
        w_min = W.min()
        w_max = W.max()
        W -= w_min
        W /= (w_max - w_min)
        b_min = B.min()
        b_max = B.max()
        B -= b_min
        B /= (b_max - b_min)
        # Anneal the exponent every second iteration: step by -0.2 while it is
        # above -1, then grow its magnitude by 6% until it passes -120.
        if iter % 2 == 0:
            if norm > -1.0:
                norm += -.2
            elif norm > -120.0:
                norm *= 1.06
        # Roll back to the last snapshot and stop if any variable contains NaN.
        if any(torch.isnan(t).any() for t in (W, B, mu)):
            print("variables are NAN'd, so terminating")
            mu = mu_old
            B = B_old
            W = W_old
            print("loss-",loss)
            failed_to_converge = True
            break

        # Hard label per row = column with the largest weight; count how many
        # consecutive iterations the labels stayed the same.
        classes_old = classes
        classes = torch.argmax(W, axis=1)
        if classes_old is not None and torch.equal(classes_old, classes):
            convergence_cnt += 1
        else:
            convergence_cnt = 0
        if convergence_cnt >= convergence_threshold:
            print("point assignments have converged")
            break
        # Report hitting the iteration cap as a failure instead of claiming
        # convergence; `>=` makes max_iter the exact number of iterations run.
        if iter >= max_iter:
            print("reached max_iter without point assignments converging")
            failed_to_converge = True
            break

point assignments have converged


In [7]:
# Final annealed exponent and number of iterations the loop ran.
norm,iter

(-1.6894789590026928, 18)

In [8]:
# Hard label per node = column of W with the largest weight.
pred_labels = torch.argmax(W, axis=1).detach().numpy()
# Reference labels: n consecutive nodes for each of the c classes, built from
# c instead of a hard-coded three classes.
# NOTE(review): assumes generate_benchmark orders nodes class by class -- confirm.
true_labels = [label for label in range(c) for _ in range(n)]
# Prints the Rand index followed by the mutual information of the two labelings.
print(rand_score(true_labels, pred_labels),mutual_info_score(true_labels, pred_labels))

0.3311036789297659 0.0


In [9]:
# Display W after optimisation (one row per node, one column per cluster).
W

tensor([[0.9065, 0.2673, 0.5776],
        [0.8539, 0.2917, 0.6073],
        [0.9085, 0.2491, 0.5539],
        [0.8947, 0.2388, 0.5272],
        [0.8930, 0.3153, 0.6155],
        [0.8584, 0.2073, 0.5929],
        [0.8554, 0.2311, 0.5942],
        [0.7955, 0.2204, 0.5432],
        [0.8983, 0.2643, 0.5822],
        [0.8959, 0.2593, 0.5801],
        [0.8870, 0.2276, 0.5711],
        [0.8706, 0.2319, 0.5832],
        [0.9160, 0.3232, 0.6132],
        [0.8982, 0.2385, 0.5625],
        [0.8660, 0.2273, 0.5883],
        [0.9067, 0.2977, 0.5892],
        [0.9134, 0.3031, 0.5987],
        [0.8986, 0.2620, 0.5893],
        [0.8989, 0.2503, 0.5987],
        [0.8025, 0.0115, 0.4672],
        [0.8614, 0.2199, 0.5920],
        [0.9089, 0.2726, 0.5705],
        [0.8366, 0.1068, 0.5017],
        [0.8436, 0.3409, 0.5937],
        [0.8965, 0.2437, 0.5751],
        [0.7833, 0.0000, 0.4868],
        [0.9055, 0.2993, 0.6184],
        [0.8352, 0.0993, 0.4488],
        [0.9033, 0.4257, 0.6355],
        [0.825

In [10]:
# Display the matrix P that was passed to generate_benchmark.
P

array([[0.8, 0.2, 0.3],
       [0.2, 0.7, 0.4],
       [0.3, 0.4, 0.6]])

In [11]:
# Display B after optimisation (rescaled to [0, 1] by the loop's min-max step).
B

tensor([[0.0000, 0.4321, 0.6950],
        [0.4969, 0.7382, 0.8861],
        [0.7943, 0.9225, 1.0000]], requires_grad=True)

In [12]:
# Display mu after optimisation (c rows, dim columns).
mu

tensor([[-0.3618, -2.1854, -3.0756, -5.7328, -5.6794, -2.6108],
        [ 1.4500,  4.7212,  6.3004,  3.9607,  5.2399,  1.6690],
        [-1.3957, -2.1459, -3.6560, -5.7952, -5.8293, -2.5820]],
       dtype=torch.float64, requires_grad=True)