# Estimate node coordinates from an adjacency matrix with a GCNN

## Import packages

In [None]:
import datetime
import math
import os
import os.path as osp
import random
import time

import igraph
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from joblib import Parallel, delayed
import matplotlib as mpl
from matplotlib import pyplot as plt
from natsort import natsorted
from scipy.io import mmread
from torch_geometric.data import Data, DataLoader, Batch, Dataset, InMemoryDataset
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import Compose
from torch_scatter import scatter

# Give PyTorch, NumPy and the standard-library RNG (in that order) the same
# fixed seed for reproducibility.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(1234)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

## Set variables

In [None]:
# True: use the small trial dataset; False: use the large dataset.
is_trial = False

# Constants intended for the Kamada-Kawai loss (KKLoss(K=K, L0=L0, eps=EPS)).
K, L0 = 1.0, 1.0
EPS = 1e-6

# os.cpu_count() returns None when the core count cannot be determined, and
# halving a single core would give 0, which joblib.Parallel rejects.
# Clamp both values to at least 1 so the later cells always get usable counts.
n_cpu_cores = os.cpu_count() or 1
n_jobs = max(1, n_cpu_cores // 2)

data_size_type = "small_" if is_trial else "large_"

# Dataset root: data/<small_|large_>reconstruction_Dataset
root = osp.join(
    "data",
    data_size_type + "reconstruction_Dataset"
)

# Directory and file where the trained parameters are saved.
os.makedirs("params", exist_ok=True)
model_path = 'params/model.pth'

if is_trial:
    epoch_num = 500
    batch_size = 3072
else:
    epoch_num = 500
    batch_size = 20

print("n_jobs = {}".format(n_jobs))

## Define dataset

In [None]:
class MyDataset(Dataset):
    """On-disk dataset of graphs built from Matrix Market files.

    ``process`` reads one adjacency matrix from ``<raw_dir>/adjs`` and one
    coordinate array from ``<raw_dir>/coords`` per sample (the two listings
    are paired by natural-sort order) and saves one ``Data`` object per graph
    into ``processed_dir``.  Each ``Data`` object holds:

    * ``x``: a constant feature of 1 for every node,
    * ``edge_index`` / ``edge_attr``: the stored adjacency entries, with
      ``edge_attr`` set to the reciprocal of the stored value,
    * ``pos``: the coordinates read from the ``coords`` file,
    * ``dist``: all-pairs shortest-path distances where every edge has the
      reciprocal of its adjacency value as length, flattened to one column.
    """

    # File-name template of a processed sample; the placeholder is the sample index.
    processed_file_name = "data_{}.pt"

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def processed_file_names(self):
        """Return one processed file name per entry found in ``<raw_dir>/adjs``."""
        num_graph = len(os.listdir(osp.join(self.raw_dir, "adjs")))
        return [self.processed_file_name.format(i) for i in range(num_graph)]

    def process(self):
        """Convert every raw sample into a ``Data`` file, in parallel via joblib.

        Uses the module-level ``n_jobs`` as the number of joblib workers.
        """
        adj_file_names = natsorted(os.listdir(osp.join(self.raw_dir, "adjs")))
        coords_file_names = natsorted(os.listdir(osp.join(self.raw_dir, "coords")))
        num_samples = len(adj_file_names)

        # For some reason `self` cannot be referenced from inside the joblib
        # workers, so everything the worker needs is copied into locals first.
        raw_dir = self.raw_dir
        pre_transform = self.pre_transform
        processed_dir = self.processed_dir
        file_template = self.processed_file_name

        def generate_Data(index):
            """Build and save the ``Data`` object of sample ``index``."""
            adj_coo = mmread(osp.join(raw_dir, "adjs", adj_file_names[index]))
            coords_nda = mmread(osp.join(raw_dir, "coords", coords_file_names[index]))
            num_nodes = adj_coo.shape[0]
            edge_index, edge_attr = torch_geometric.utils.from_scipy_sparse_matrix(adj_coo)

            # Compute the graph distance matrix with igraph.
            edges = np.array([adj_coo.row, adj_coo.col]).T
            g = igraph.Graph(n=num_nodes, edges=edges)
            # Edge length is the reciprocal of the adjacency value.  Cast to
            # float first: NumPy refuses to raise an integer array to a
            # negative integer power, which would break integer-valued files.
            g.es["weight"] = adj_coo.data.astype(np.float64) ** -1
            # Note from the original author: up to this point even a plain
            # for-loop seemed to run in parallel, but it no longer does once
            # the following step is included.
            weighted_dist = torch.tensor(
                g.shortest_paths_dijkstra(weights="weight"), dtype=torch.float
            )

            data = Data(
                x=torch.ones((num_nodes, 1)).float(),
                edge_index=edge_index,
                # Same reciprocal as the igraph weights, stored as (num_edges, 1).
                edge_attr=edge_attr.view(-1, 1).float() ** -1,
                pos=torch.tensor(coords_nda).float(),
                dist=weighted_dist.view(-1, 1),
            )

            if pre_transform is not None:
                data = pre_transform(data)

            torch.save(
                data,
                osp.join(processed_dir, file_template.format(index)),
            )

        Parallel(n_jobs=n_jobs)(
            [delayed(generate_Data)(i) for i in range(num_samples)]
        )

    def len(self):
        """Return the number of samples (one per expected processed file)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the processed ``Data`` object of sample ``idx``."""
        return torch.load(
            osp.join(self.processed_dir, self.processed_file_name.format(idx))
        )

## Network and Loss function

In [None]:
class Net(torch.nn.Module):
    """Graph convolutional network mapping node features to 2 values per node.

    Six ``GCNConv`` layers (1 -> 16 -> 32 -> 48 -> 64 -> 96 -> 128 channels),
    each followed by ReLU, then two linear layers (128 -> 64 -> 2) with a
    ReLU in between.
    """

    def __init__(self):
        super().__init__()
        # The attribute names are part of the state_dict keys; keep them
        # unchanged so previously saved parameter files remain loadable.
        self.conv1 = GCNConv(1, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, 2)

    def forward(self, data):
        """Return a tensor with 2 output channels per node of ``data``.

        Args:
            data: object exposing ``x``, ``edge_index`` and ``edge_attr``
                (``edge_attr`` may be ``None``, in which case the
                convolutions run without edge weights).
        """
        x, edge_index = data.x, data.edge_index

        # The network does not work when a tensor of shape (n, 1) is passed
        # as edge_weight; it needs shape (n,).  reshape(-1) flattens it and,
        # unlike torch.squeeze, does not collapse a single-edge (1, 1) tensor
        # into a 0-dim scalar.
        edge_weight = None if data.edge_attr is None else data.edge_attr.reshape(-1)

        convs = (
            self.conv1, self.conv2, self.conv3,
            self.conv4, self.conv5, self.conv6,
        )
        for conv in convs:
            x = F.relu(conv(x, edge_index, edge_weight))

        x = F.relu(self.linear1(x))
        return self.linear2(x)


class KKLoss(nn.Module):
    """Return the Kamada-Kawai energy of a predicted 2-D layout as loss.

    For each graph, with graph distance ``d_ij`` and predicted Euclidean
    distance ``r_ij``, the energy is the sum over unordered node pairs of
    ``0.5 * k_ij * (r_ij - l_ij)**2`` where ``k_ij = K / d_ij**2`` and
    ``l_ij = L0 * d_ij / max(finite d)``.  ``eps`` is added under the square
    root.  Pairs at infinite graph distance contribute nothing.

    Args:
        K: spring-constant scale.
        L0: target length assigned to the largest finite graph distance.
        eps: constant added under the square root.
    """

    def __init__(self, K=1.0, L0=1.0, eps=1e-6):
        super().__init__()
        self.K = K
        self.L0 = L0
        self.eps = eps

    def forward(self, batch, prediction):
        """Return the mean Kamada-Kawai energy over the graphs in ``batch``.

        Args:
            batch: mini-batch exposing ``to_data_list()`` and the node-to-graph
                assignment vector ``batch``; every graph carries ``num_nodes``,
                ``num_edges`` and ``dist`` (flattened distance matrix).
            prediction: predicted coordinates; column 0 is used as x and
                column 1 as y.

        Returns:
            Scalar tensor: energy averaged over the graphs that have at least
            one edge.  If no graph has an edge the result is a zero that is
            still connected to ``prediction`` (instead of the NaN a 0/0
            average would give).
        """
        energies = []
        for h, data in enumerate(batch.to_data_list()):

            if data.num_edges == 0:  # skip graph which has no edge
                continue

            graph_dist = data.dist.view(data.num_nodes, data.num_nodes)
            finite = torch.isfinite(graph_dist)

            # k_ij = K / d_ij^2; zero distances keep k = 0 and infinite
            # distances give inf ** -2 = 0.
            k = self.K * torch.where(graph_dist != 0, graph_dist ** -2, graph_dist)

            # Largest finite graph distance (infinite entries are ignored).
            d_max = graph_dist[finite].max()
            L = self.L0 / d_max

            # Target lengths; infinite entries become 0 to avoid 0 * inf = nan.
            rest_len = torch.where(finite, L * graph_dist, torch.zeros_like(graph_dist))

            positions = prediction[torch.flatten(batch.batch == h), :]
            # (n, 1) - (n,) broadcasts to the (n, n) matrix of pairwise differences.
            dx = positions[:, [0]] - positions[:, 0]
            dy = positions[:, [1]] - positions[:, 1]
            sq_dist = dx ** 2 + dy ** 2

            # The sum runs over ordered pairs, hence the trailing factor 0.5.
            energy = torch.sum(
                0.5 * k * (sq_dist + rest_len ** 2
                - 2 * rest_len * torch.sqrt(sq_dist + self.eps))
                ) * 0.5
            energies.append(energy)

        if not energies:
            # Nothing to average: return 0 on prediction's device and dtype.
            return prediction.sum() * 0.0

        return torch.stack(energies).mean()

## Generate DataLoader

In [None]:
# Each split lives in its own sub-directory of the dataset root.
train_root = osp.join(root, "train")
val_root = osp.join(root, "val")
test_root = osp.join(root, "test")

train_set = MyDataset(train_root)
val_set = MyDataset(val_root)
test_set = MyDataset(test_root)

# Only the training loader shuffles.  The validation loss is averaged over
# batches, so a fixed batch composition keeps it comparable between epochs.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=n_cpu_cores)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=n_cpu_cores)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=n_cpu_cores)

# Number of graphs in each split.
train_size = len(train_set)
val_size = len(val_set)
test_size = len(test_set)

## Learning

In [None]:
start = time.time()
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters())
# Switch the criterion here (and the matching calls marked "KKLoss" below)
# to train with the Kamada-Kawai energy instead of the mean squared error.
criterion = nn.MSELoss()
#criterion = KKLoss(K=K, L0=L0, eps=EPS)
history = {
    "train_loss": [],
    "val_loss": [],
}

# Number of training batches per epoch, used to average the epoch loss.
num_train_batches = math.ceil(train_size / batch_size)

for epoch in range(epoch_num):
    # ---- training ----
    train_loss = 0.0
    model.train()
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        #loss = criterion(batch, prediction)  # KKLoss
        loss = criterion(prediction, batch.pos)  # MSELoss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Redraw a one-line progress bar every 10 batches.
        if i % 10 == 9:
            progress_bar = (
                "["
                + ("=" * ((i + 1) // 10))
                + (" " * ((train_size // 100 - (i + 1)) // 10))
                + "]"
            )
            print(
                "\repoch: {:d} loss: {:.3f}  {}".format(
                    epoch + 1, loss.item(), progress_bar,
                ),
                end="  ",
            )

    mean_train_loss = train_loss / num_train_batches
    print(
        "\repoch: {:d} loss: {:.3f}".format(epoch + 1, mean_train_loss),
        end="  ",
    )
    history["train_loss"].append(mean_train_loss)

    # ---- validation ----
    # Accumulate plain floats: the previous int accumulator had no .cpu()
    # and the average divided by zero when the loader yielded no batch.
    batch_num = 0
    val_loss_sum = 0.0
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            prediction = model(batch)
            #val_loss_sum += criterion(batch, prediction).item()  # KKLoss
            val_loss_sum += criterion(prediction, batch.pos).item()  # MSELoss
            batch_num += 1

    val_loss = val_loss_sum / batch_num if batch_num else float("nan")
    history["val_loss"].append(val_loss)
    # Trailing spaces overwrite what is left of the progress bar on this line.
    endstr = " " * max(1, (train_size // 1000 - 39)) + "\n"
    print(f"Val Loss: {val_loss:.3f}", end=endstr)

print("Finished Training")
print("Saving model params")
torch.save(model.state_dict(), model_path)

elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## Visualization

In [None]:
# Loss curves: training and validation loss against the 1-based epoch number.
x = np.arange(1, epoch_num + 1)
for history_key, curve_label in (("train_loss", "train loss"), ("val_loss", "val loss")):
    plt.plot(x, history[history_key], label=curve_label)

plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Plot nodes of one test graph: true layout, network estimate, Kamada-Kawai layout
test_index = 0
test_data = test_set[test_index].to(device)
nBeads = test_data.num_nodes
node_list = list(range(nBeads))

# Estimate coordinates of nodes of test data
model.eval()
with torch.no_grad():
    estimated_coords = model(test_data)

true_coords = test_data.pos.detach().cpu().numpy().copy()

G = torch_geometric.utils.to_networkx(
    test_data,
    node_attrs=["pos"],
    edge_attrs=["edge_attr"],
    to_undirected=True,
)

# networkx expects a {node: coordinates} mapping for every layout.
estimated_array = estimated_coords.cpu().detach().numpy()
true_pos = {node: true_coords[node] for node in node_list}
estimated_pos = {node: estimated_array[node] for node in node_list}
initial_pos = nx.random_layout(G, dim=2, seed=1)
KK_pos = nx.kamada_kawai_layout(G, pos=initial_pos, weight="edge_attr", dim=2)

titles = ["True", "Estimated", "Kamada-Kawai"]
positions = [true_pos, estimated_pos, KK_pos]
for t, p in zip(titles, positions):
    fig, ax = plt.subplots()
    ax.set_title(t)
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-2, 2)
    nx.draw_networkx_nodes(G, pos=p, ax=ax, node_color="red", node_size=2)
    ax.tick_params(which="both", left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.set_aspect("equal")
    # Major ticks every 0.5 and minor ticks every 0.1 on both axes.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
        axis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
    plt.show()

## Simulation conditions

In [None]:
# Diffusion time reported in the summary below.
t = 20

# Summary of the settings used for this run.
print("=====Simulation conditions=====")
print("Network：GCNN")
print(f"Is trial run?: {is_trial}")

print(f"Number of beads: {nBeads}")
print(f"Number of samples for training: {train_size}")

print(f"Number of epochs: {epoch_num}")
print(f"Batch size: {batch_size}")
print(f"Diffusion time: {t}")

print("=====Results=====")
training_duration = datetime.timedelta(seconds=int(elapsed_time))
print(f"elapsed time: {training_duration}")

## Load the saved parameters and draw the graphs

In [None]:
def plot_nodes(id, title, nxG, pos):
    """Draw the nodes of ``nxG`` at ``pos``, save the figure as PDF and show it.

    Args:
        id: integer used in the output file name (formatted with two digits).
            The name shadows the builtin ``id`` but is kept for callers.
        title: axes title; also used in the output file name.
        nxG: networkx graph whose nodes are drawn.
        pos: mapping from node to 2-D coordinates.

    The figure is written to ``plot_<title>_<id>.pdf`` in the working
    directory and closed afterwards so repeated calls do not pile up open
    figures.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(title)
    ax.set_xlim(left=-1.5, right=1.5)
    ax.set_ylim(bottom=-2, top=2)
    nx.draw_networkx_nodes(nxG, pos=pos, ax=ax, node_color="red", node_size=2)
    ax.tick_params(which="both", left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.set_aspect(aspect="equal")
    # Major ticks every 0.5 and minor ticks every 0.1 on both axes.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(mpl.ticker.MultipleLocator(0.5))
        axis.set_minor_locator(mpl.ticker.MultipleLocator(0.1))
    fig.savefig("plot_{}_{:02d}.pdf".format(title, id), bbox_inches="tight")
    plt.show()
    # Release the figure; this function is called many times in a loop.
    plt.close(fig)

def _load_trained_net(filename):
    """Return (Net in eval mode on `device`, path) for the file params/<filename>."""
    path = osp.join("params", filename)
    net = Net().to(device)
    net.load_state_dict(torch.load(path, map_location=torch.device(device)))
    net.eval()
    return net, path


# Indices of the test samples to draw.
if is_trial:
    indices = [0,500,1000,1500,2000,2500,3000,3500,4000]
else:
    indices = [0,20,40,60,80,100,120,140,160]

# The two trained models do not depend on the sample, so load them once
# instead of re-reading both parameter files for every test graph.
# Model trained with the mean squared error as loss function.
MSE_filename = "model_MSE_n1000_b8_e500.pth"
MSE_model, MSE_model_path = _load_trained_net(MSE_filename)
# Model trained with the Kamada-Kawai energy as loss function.
KK_filename = "model_KK_n1000_b8_e500.pth"  ## KKLoss
KK_model, KK_model_path = _load_trained_net(KK_filename)

for i, test_index in enumerate(indices):

    test_data = test_set[test_index].to(device)
    nBeads = test_data.num_nodes
    node_list = list(range(nBeads))

    if i==0:
        print("Number of beads: {}".format(nBeads))

    nxG = torch_geometric.utils.to_networkx(
        test_data,
        node_attrs=["pos"],
        edge_attrs=["edge_attr"],
        to_undirected=True)

    # Plot every node at its true coordinates.
    true_coords = test_data.pos.to("cpu").detach().numpy().copy()
    true_pos = dict(zip(node_list, true_coords))
    plot_nodes(i+1,"True", nxG, true_pos)

    # Plot the coordinates estimated by the model trained with the mean
    # squared error as loss function.
    print("MSE model path: {}".format(MSE_model_path))
    with torch.no_grad():
        MSE_model_coords = MSE_model(test_data)
    MSE_model_pos = dict(zip(node_list, MSE_model_coords.cpu().detach().numpy()))
    plot_nodes(i+1,"MSE_Loss", nxG, MSE_model_pos)

    # Plot the coordinates estimated by the model trained with the energy of
    # the Kamada-Kawai algorithm as loss function.
    print("KK model path: {}".format(KK_model_path))
    with torch.no_grad():
        KK_model_coords = KK_model(test_data)
    KK_model_pos = dict(zip(node_list, KK_model_coords.cpu().detach().numpy()))
    plot_nodes(i+1,"Kamada-Kawai_Loss", nxG, KK_model_pos)

    # Graph drawing with the ordinary Kamada-Kawai algorithm.
    # Plot obtained from random initial node positions.
    KK_pos = nx.kamada_kawai_layout(nxG,pos=nx.random_layout(nxG, dim=2, seed=1), weight="edge_attr", dim=2)
    plot_nodes(i+1,"Kamada-Kawai_random", nxG, KK_pos)
    # Plot obtained from initial node positions on a circle.
    KK_pos = nx.kamada_kawai_layout(nxG,pos=nx.circular_layout(nxG, dim=2), weight="edge_attr", dim=2)
    plot_nodes(i+1,"Kamada-Kawai_circular", nxG, KK_pos)
    # Plot obtained when the coordinates estimated by the MSELoss-trained
    # model are used as initial node positions.
    KK_pos = nx.kamada_kawai_layout(nxG,pos=MSE_model_pos, weight="edge_attr", dim=2)
    plot_nodes(i+1,"Kamada-Kawai_MSE", nxG, KK_pos)

## Visualize the model architecture with make_dot (hard to read because the background operations are drawn as well)

In [None]:
from torchviz import make_dot

# NOTE: this rebinds `model` to a freshly initialised (untrained) Net on the
# CPU; only the structure of the computation graph matters for the drawing.
model = Net().to("cpu")
x = test_set[0]  # prepare a dummy input
y = model(x)

model_arch = make_dot(y, params=dict(model.named_parameters()))
# make_dot already returns a graphviz Digraph, so render it directly instead
# of wrapping the object in graphviz.Source, which is meant for DOT strings.
model_arch.render("model_arch")