In [1]:
%config InlineBackend.figure_format ='retina'
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [2]:
import torch
import torch_geometric
import networkx as nx
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib
import numpy as np
from scipy import sparse
from random import shuffle
import itertools
import multiprocessing

In [3]:
class BeliefPropagation():
    
    def __init__():
        pass

In [4]:
beta = 1.0
sizes = [500, 500]
Q = len(sizes)
P = [[0.005, 0.0005], [0.0005, 0.005]]
G = nx.stochastic_block_model(sizes, P)
N = G.number_of_nodes()

In [5]:
G.number_of_edges()

1307

In [6]:
G.node[0]

{'block': 0}

In [7]:
G.number_of_nodes()

1000

# Initialize beta (inverse temperature)
TODO: extend to weighted version
$$
\beta^{*}(q, c)=\log \left(\frac{q}{\sqrt{c}-1}+1\right)
$$

## weighted version by Shi Chen et al.
$$
\hat{c}=\sum_{d=1}^{\infty} \frac{d p(d)}{c}(d-1)=\frac{\left\langle d^{2}\right\rangle}{c}-1
$$
The **spin-glass transition temperature** can be obtained by solving the following equations.
$$
\left\langle\left(\frac{e^{\beta^{*} \omega_{i j}}-1}{e^{\beta^{*} \omega_{i j}}+q-1}\right)^{2}\right\rangle_{\omega_{i j}} \hat{c}=1
$$

In [8]:
mean_degree = np.asarray(G.degree())[:, 1].mean()
beta_ast = np.log(Q / (np.sqrt(mean_degree) - 1) + 1)
beta = torch.tensor([beta_ast], dtype=torch.float, requires_grad=True)

print(mean_degree, beta)

2.614 tensor([1.4452], requires_grad=True)


# Initialize psi and message_map

In [9]:
# initialize by ramdom messages
marginal_psi = torch.rand(N, Q) # the marginal probability of node i (in N) at block q (in Q), normalized
marginal_psi = marginal_psi / marginal_psi.sum(1).reshape(-1, 1)
marginal_psi = torch.tensor(marginal_psi, dtype=torch.float)
# message_map = np.zeros((N, N, Q)) # TODO: reduce memory concumption
message_map = []
for i in range(N):
    message_map_at_i = torch.rand(len(list(G.neighbors(i))), Q)
    message_map_at_i = message_map_at_i / message_map_at_i.sum(1).reshape(-1, 1)
    message_map.append(message_map_at_i)

  after removing the cwd from sys.path.


In [10]:
marginal_psi

tensor([[0.1549, 0.8451],
        [0.6262, 0.3738],
        [0.5680, 0.4320],
        ...,
        [0.0165, 0.9835],
        [0.5599, 0.4401],
        [0.6956, 0.3044]])

In [11]:
message_map[0]

tensor([[0.3843, 0.6157],
        [0.5279, 0.4721]])

# Initialize h (external field)
non-weighted:
$$
\theta_{t}=\sum_{j=1}^{n} d_{j} \psi_{t}^{j}
$$
weighted:
$$
h(t)=-\beta \overline{\omega} \sum_{i} \psi_{t}^{i}
$$

In [12]:
W = nx.to_numpy_matrix(G)

In [13]:
W.mean() * N

2.614

In [13]:
h = np.empty(Q)
for q in range(Q):
    h_q = -beta * W.mean() * marginal_psi[:, q].sum()
    h[q] = h_q
h = torch.tensor(h, dtype=torch.float)

In [14]:
h

tensor([-1.9162, -1.8819])

# BP converge
$$
\psi_{t_{i}}^{i \rightarrow k} \approx \frac{e^{h\left(t_{i}\right)}}{Z_{i \rightarrow k}} \prod_{j \in \partial i | k}\left(1+\psi_{t_{i}}^{j \rightarrow i}\left(e^{\beta \omega_{i j}}-1\right)\right)
$$
$$
h(t)=-\beta \overline{\omega} \sum_{i} \psi_{t}^{i}
$$
$$
\psi_{t_{i}}^{i}=\frac{e^{h\left(t_{i}\right)}}{Z_{i}} \prod_{j \in \partial i}\left(1+\psi_{t_{i}}^{j \rightarrow i}\left(e^{\beta \omega_{i j}}-1\right)\right)
$$

In [15]:
(i, j) = list(G.edges())[0]

In [16]:
i, j

(0, 292)

In [17]:
def update_message_i_to_j(i, j, learning_rate):
    message_i_to_j = torch.zeros(Q)
    i_to_j = list(G.neighbors(i)).index(j)
    # all neighbors except j
    neighbors = list(G.neighbors(i))
    neighbors.remove(j)
    # sum all message to i
    
    for q in range(Q):
        this_value = 1.0
        for k in neighbors:
            i_to_k = list(G.neighbors(i)).index(k)
            k_to_i = list(G.neighbors(k)).index(i)
#             print(i, i_to_k, k, k_to_i)
            this_value *= (1 + message_map[k][k_to_i][q].clone() * 
                           (torch.exp(beta * W[i, k]) - 1))
        this_value *= torch.exp(h[q])
        message_i_to_j[q] = this_value    
    message_i_to_j = message_i_to_j.clone() / message_i_to_j.clone().sum()
    
    diff = torch.abs(message_i_to_j.clone() - message_map[i][i_to_j].clone()).sum()
    print("message_i_to_j: ", i, j, message_i_to_j)
    return diff, message_i_to_j

In [18]:
update_message_i_to_j(i, j, 0.1)

message_i_to_j:  0 292 tensor([0.7812, 0.2188], grad_fn=<DivBackward0>)


(tensor(1.0133, grad_fn=<SumBackward0>),
 tensor([0.7812, 0.2188], grad_fn=<DivBackward0>))

In [19]:
def update_marginal_psi(i):
    marginal_psi_i = torch.zeros(Q)
    neighbors = list(G.neighbors(i))
    for q in range(Q):
        this_value = 1.0
        for j in neighbors:
            j_to_i = list(G.neighbors(j)).index(i)
            this_value *= (1 + message_map[j][j_to_i][q].clone() * 
                           (torch.exp(beta * W[i, j]) - 1))
        this_value *= torch.exp(h[q])
        marginal_psi_i[q] = this_value
    marginal_psi_i = marginal_psi_i.clone() / marginal_psi_i.clone().clone().sum()
    print("marginal_psi_i:", i, marginal_psi_i)
    return marginal_psi_i

In [20]:
update_marginal_psi(j)

marginal_psi_i: 292 tensor([0.2132, 0.7868], grad_fn=<DivBackward0>)


tensor([0.2132, 0.7868], grad_fn=<DivBackward0>)

In [21]:
def bp_iter_step(i, j, learning_rate, h):
    
    print("bp_iter_step:", i, j)
    
    diff, message_i_to_j = update_message_i_to_j(i, j, learning_rate)
    marginal_psi_i = update_marginal_psi(i)
    
    i_to_j = list(G.neighbors(i)).index(j)
    
    message_map[i][i_to_j] = learning_rate * message_i_to_j.clone() + \
                        (1 - learning_rate) * message_map[i][i_to_j].clone()
    h -= -beta * W.mean() * marginal_psi[j].clone()
    marginal_psi[j] = marginal_psi_i.clone()
    h += -beta * W.mean() * marginal_psi[j].clone()
    
    return diff

In [22]:
diff = 0

In [17]:
torch.max(torch.tensor([10, 22]))

tensor(22)

In [23]:
diff += bp_iter_step(i, j, 0.1, h)

bp_iter_step: 0 292
message_i_to_j:  0 292 tensor([0.7812, 0.2188], grad_fn=<DivBackward0>)
marginal_psi_i: 0 tensor([0.7143, 0.2857], grad_fn=<DivBackward0>)


In [24]:
diff.backward(retain_graph=True)

In [25]:
diff += bp_iter_step(j, i, 0.1, h)

bp_iter_step: 292 0
message_i_to_j:  292 0 tensor([0.3165, 0.6835], grad_fn=<DivBackward0>)
marginal_psi_i: 292 tensor([0.2344, 0.7656], grad_fn=<DivBackward0>)


In [26]:
diff.backward(retain_graph=True)

In [27]:
beta.grad

tensor([0.8912])

In [28]:
diff

tensor(1.0783, grad_fn=<AddBackward0>)

In [38]:
_, assignment = torch.max(marginal_psi, 1)

In [41]:
assignment[2]

tensor(1)

In [32]:
indices.shape

torch.Size([1000])

In [None]:
max_num_iter = 5
bp_conf = 0.1
learning_rate = 0.5

for num_iter in range(max_num_iter):
    diff = 0
    job_list = list(G.edges())[:]
    shuffle(job_list)
    for i, j in job_list:
        diff += bp_iter_step(i, j, learning_rate, h)
#         diff = update_message_i_to_j(i, j, learning_rate)
#         max_diff = diff if diff > max_diff else max_diff
#         if max_diff < bp_conf:
#             return num_iter
    if diff < bp_conf:
        break

In [34]:
%%timeit
marginal_psi.sum().backward(retain_graph=True)

43.6 s ± 1.74 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [31]:
beta.grad

tensor([0.8912])

$$
\mathcal{L}_{\mathrm{reg}}(\mathbf{S})=\sum_{i=1}^{N} \sum_{j=1}^{N} a_{i j} \cdot\left\|\mathbf{s}_{i}-\mathbf{s}_{j}\right\|^{2}=\operatorname{tr}\left(\mathbf{S} \mathbf{L} \mathbf{S}^{\top}\right)
$$

In [45]:
torch.pow((marginal_psi[0] - marginal_psi[1]), 2).sum()

tensor(0.4856)