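Here is a naive implementation of network reconstruction via the Gumbel-softmax trick: each candidate edge gets a pair of learnable logits, and a Gumbel-softmax sample of those logits yields a differentiable adjacency matrix that is fitted so that the known dynamics reproduce the observed time series.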

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
%matplotlib inline
import torch
from torch.nn.functional import gumbel_softmax
import scipy
from torch.utils.data import DataLoader

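Coupled Map Lattice (CML) dynamics on a random $k$-regular graph ($k = 4$ here). Each node follows the logistic map $f(x) = \lambda x (1 - x)$ and is coupled to its neighbours:

$$x_i(t+1) = (1 - s)\, f\big(x_i(t)\big) + \frac{s}{\deg(i)} \sum_j A_{ij}\, f\big(x_j(t)\big),$$

with $s = 0.2$ and $\lambda = 3.8$. We do not train a GNN here: we assume that the exact form of the dynamics is known and learn only the adjacency matrix, using the current state to predict the next state.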

In [5]:
def main():
    def logistic_map(x, lambd):
        return lambd * x * (1 - x)

    def one_step_dynamics(current_state, A, s=0.2, lambd=3.8, eps=0, device='cpu'):
        """Given x_t (n by 1) and coupling A, produce x_t+1"""
        diags = A.sum(1)
        if isinstance(current_state, np.ndarray):
            noise = np.random.randn(n, 1) * eps
            diags[np.where(diags==0)] = 1 # avoid nan
        else:
            noise = torch.randn(n, 1) * eps
            noise = noise.to(device)
            diags = torch.where(diags!=0, diags, torch.ones(n, device=device))
        next_state = (1 - s) * logistic_map(current_state, lambd=lambd) + s/diags * A @ logistic_map(current_state, lambd=lambd)
        next_state += noise
        return next_state

    def gumbel_sample(x, temp=1, hard=False, eps=1e-5):
        logp = x.view(-1, 2)
        out = gumbel_softmax(logp, temp, hard)
        out = out[:, 0].view(n, n)
        return out

    device = 'cuda:1'
    n = 10
    instances = 100
    steps = 50
    bs = 128
    lr=1e-3
    epochs = 800
    tau =5

    # data generation
    G = nx.random_regular_graph(4, n)
    A = nx.adjacency_matrix(G)
    A = A.todense() 
    A = np.array(A) # adjacency matrix A
    # For every initialization, we run 50 steps
    s = np.random.rand(n, 1)
    simulates = np.empty((instances*steps, n, 1))
    for i in range(instances):
        for j in range(steps):
            s = one_step_dynamics(s, A=A)
            simulates[i*steps+j, :, :] = s
            if (j+1) % steps == 0:
                s = np.random.rand(n, 1)
    # We use current step to predict the next step, [# data, N, 2]
    data = np.empty((1, n, 2))
    for i in range(simulates.shape[0]):
        if (i+1) % steps != 0:
            temp = np.empty((1, n, 2))
            temp[:, :, 0] = simulates[i, :, :].squeeze()
            temp[:, :, 1] = simulates[i+1, :, :].squeeze()
            data = np.concatenate((data, temp), axis=0)
    data = torch.from_numpy(data[1:, :, :]).to(torch.float32)
    data = data.to(device)
    data_loader = DataLoader(data, batch_size=bs, shuffle=True)

    x = -1 * torch.rand(n, n, 2, device=device) # unnormalized logits
    x.requires_grad = True
    optimizer = torch.optim.Adam([x], lr=lr)
    loss_arr = []
    error_arr = []

    for epoch in range(epochs):
        # Training
        for idx, data in enumerate(data_loader): # data size [bs, n, 2]
            A_p = gumbel_sample(x, hard=False, temp=tau)
            s_t = torch.unsqueeze(data[:, :, 1], 2)
            s_p = one_step_dynamics(torch.unsqueeze(data[:, :, 0], 2), A_p, device=device)
            loss = torch.mean((s_t - s_p) ** 2)
            loss_arr.append(loss.data.cpu().numpy())
    #         print('Epoch: %i\t Batch Num: %i\t Loss: %.5f' % (epoch, idx, loss.data))
            loss.backward()
            optimizer.step()
        # Test
        A_p = gumbel_sample(x, hard=True, temp=1)
        A_p = A_p.data.cpu().numpy()
        mask = np.ones((n, n))
        np.fill_diagonal(mask, np.zeros(n))
        error = np.abs(A_p * mask - A).sum()/(n**2-n)
        error_arr.append(error)
        if epoch % 100 == 0:
            print('Epoch: %i\t Error rate: %.5f' % (epoch, error))
#             print(x)
        if np.abs(error) < 0.0001:
            break
    return error

In [6]:
# Repeat the whole experiment ten times (main() draws a new random graph,
# new data and new logits on every call) and keep each run's final error rate.
results_arr = [main() for _ in range(10)]

Epoch: 0	 Error rate: 0.43333
Epoch: 0	 Error rate: 0.53333
Epoch: 0	 Error rate: 0.51111
Epoch: 0	 Error rate: 0.56667
Epoch: 0	 Error rate: 0.47778
Epoch: 0	 Error rate: 0.47778
Epoch: 0	 Error rate: 0.44444
Epoch: 100	 Error rate: nan
Epoch: 200	 Error rate: nan
Epoch: 300	 Error rate: nan
Epoch: 400	 Error rate: nan
Epoch: 500	 Error rate: nan
Epoch: 600	 Error rate: nan
Epoch: 700	 Error rate: nan
Epoch: 0	 Error rate: 0.47778
Epoch: 0	 Error rate: 0.58889
Epoch: 0	 Error rate: 0.38889


In [9]:
# Final error rate of each run; a nan entry means the final hard-sampled
# adjacency contained NaN, so the error could not be computed.
# NOTE(review): this output and those of the next three `results_arr` cells
# come from different kernel sessions (execution counts 9, 36, 4, 7).
results_arr

[0.06436781609195402,
 nan,
 nan,
 nan,
 nan,
 nan,
 0.0735632183908046,
 nan,
 nan,
 nan]

In [36]:
# NOTE(review): stale output from another kernel session (execution count 36);
# rerun the experiment loop to regenerate it.
results_arr

[0.0, nan, 0.0, nan, 0.0, nan, nan, nan, 0.0, nan]

In [4]:
# NOTE(review): stale output from another kernel session (execution count 4);
# rerun the experiment loop to regenerate it.
results_arr

[0.0,
 0.0,
 0.0,
 nan,
 0.1111111111111111,
 0.0,
 0.022222222222222223,
 0.0,
 0.044444444444444446,
 0.1111111111111111]

In [7]:
# NOTE(review): stale output from another kernel session (execution count 7);
# rerun the experiment loop to regenerate it.
results_arr

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, nan, 0.0, 0.0, 0.0]

In [37]:
# Training-loss curve.
# NOTE(review): loss_arr is a local variable of main(), so this raises NameError
# on a fresh kernel (see the traceback below); main() must expose it first.
plt.plot(loss_arr)
# plt.plot(error_arr)

NameError: name 'loss_arr' is not defined

In [38]:
# Per-epoch reconstruction error curve.
# NOTE(review): error_arr is a local variable of main(), so this raises NameError
# on a fresh kernel (see the traceback below).
plt.plot(error_arr)

NameError: name 'error_arr' is not defined

In [39]:
# Learned (hard-sampled) adjacency matrix.
# NOTE(review): A_p is a local variable of main(), so this raises NameError
# on a fresh kernel (see the traceback below).
plt.imshow(A_p)

NameError: name 'A_p' is not defined

In [40]:
# Ground-truth adjacency matrix.
# NOTE(review): A is a local variable of main(), so this raises NameError
# on a fresh kernel (see the traceback below).
plt.imshow(A)

NameError: name 'A' is not defined

In [41]:
# Difference between learned and true adjacency (diagonal is not masked here).
# NOTE(review): A_p and A are local variables of main(), so this raises
# NameError on a fresh kernel (see the traceback below).
plt.imshow(A_p - A)

NameError: name 'A_p' is not defined