# GCNNによって隣接行列からノードの座標を予測する

## ライブラリのインポート

In [None]:
import datetime
import math
import os
import os.path as osp
import random
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from natsort import natsorted
from scipy.io import mmread
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, DataLoader, Dataset, InMemoryDataset
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import Compose
from torch_scatter import scatter

# Seed every random number generator this notebook uses (PyTorch, NumPy and
# Python's `random`) with the same value so runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(1234)

# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

## 変数の設定

In [None]:
test_run = True  # small dataset and few epochs for a quick smoke run
use_InMemoryDataset = True
add_noise = True  # drop edges and multiply edge weights by log-normal noise
use_graph_distance = True  # store all-pairs shortest paths as graph_dist (KKLoss needs it)
drop_probability = 0.01  # per-edge removal probability used by drop_edge
mean, sigma = 0, 1  # parameters of the log-normal noise
pow_ = -1.0  # exponent applied to every edge weight by invert_edge_attr
K, L0 = 1.0, 1.0  # KKLoss spring scale and target layout scale
EPS = 1e-6  # guard inside KKLoss' square root

# Use ~80% of the cores for joblib, but never fewer than one worker:
# os.cpu_count() may return None, and int(1 * 0.8) == 0 on a single-core
# machine, which joblib rejects as n_jobs.
n_jobs = max(1, int((os.cpu_count() or 1) * 0.8))

data_size_type = "small_" if test_run else "large_"
Dataset_type = "InMemoryDataset" if use_InMemoryDataset else "Dataset"
root = osp.join("data", data_size_type + "reconstruction_" + Dataset_type)

if test_run:
    epoch_num = 5
    batch_size = 3
else:
    epoch_num = 100
    batch_size = 128

print("n_jobs = {}".format(n_jobs))

## Dataオブジェクトを作成する関数

In [None]:
def generate_nxG(adj_coo_mats, coords_ndas, i):
    """Build the networkx graph for sample ``i``.

    Parameters
    ----------
    adj_coo_mats : scipy sparse matrix
        One flattened adjacency matrix per row; row ``i`` has ``num_beads ** 2``
        columns and is reshaped to ``(num_beads, num_beads)``.
    coords_ndas : array-like
        One flattened coordinate set per row; row ``i`` is reshaped to
        ``(num_beads, -1)``, i.e. one coordinate row per node.
    i : int
        Index of the sample to convert.

    Returns
    -------
    networkx.Graph
        Graph whose edge ``"weight"`` attributes are the adjacency entries and
        whose nodes carry ``"x" = 1.0`` and ``"pos"`` (the true coordinates).
    """
    num_beads = int(np.sqrt(adj_coo_mats.shape[1]))
    # Each row of the raw data holds one flattened adjacency matrix / coordinate set.
    adj_coo = adj_coo_mats.getrow(i).reshape(num_beads, num_beads)
    coords_nda = coords_ndas[i, :].reshape(num_beads, -1)
    pos = dict(zip(range(num_beads), coords_nda))

    # `from_scipy_sparse_matrix` was removed in networkx 3.0; its replacement
    # `from_scipy_sparse_array` exists since networkx 2.7. Prefer the new name
    # and fall back for older installations.
    from_sparse = getattr(nx, "from_scipy_sparse_array", None)
    if from_sparse is None:
        from_sparse = nx.from_scipy_sparse_matrix
    nxG = from_sparse(adj_coo)

    # Every node gets the same constant scalar feature.
    nx.set_node_attributes(nxG, 1.0, "x")
    nx.set_node_attributes(nxG, pos, "pos")
    return nxG


def drop_edge(nxG, prob=0.01):
    """Randomly remove edges from ``nxG`` in place and return the same graph.

    One uniform draw from the global NumPy random state is made per edge; an
    edge survives only when its draw is strictly greater than ``prob``.
    """
    draws = np.random.rand(nxG.number_of_edges())
    survives = draws > prob
    edge_array = np.array(nxG.edges)
    nxG.remove_edges_from(edge_array[~survives])
    return nxG


def multiply_lognormal_noise(nxG, mean=0, sigma=1):
    """Multiply every edge ``"weight"`` of ``nxG`` by independent log-normal noise.

    One factor per edge is drawn with ``np.random.lognormal(mean, sigma)`` from
    the global NumPy random state. The graph is modified in place and returned.
    """
    factors = np.random.lognormal(mean=mean, sigma=sigma, size=nxG.number_of_edges())
    edge_weights = nx.get_edge_attributes(nxG, "weight")
    noisy = np.array(list(edge_weights.values())) * factors
    nx.set_edge_attributes(nxG, dict(zip(edge_weights.keys(), noisy)), "weight")
    return nxG


def invert_edge_attr(nxG, pow_):
    """Raise every edge ``"weight"`` of ``nxG`` to the power ``pow_``, in place.

    With the notebook's ``pow_ = -1.0`` this replaces each weight by its
    reciprocal — presumably so larger adjacency values become shorter
    shortest-path distances (TODO confirm intent). Returns the same graph.
    """
    edge_weights = nx.get_edge_attributes(nxG, "weight")
    powered = np.array(list(edge_weights.values())) ** pow_
    nx.set_edge_attributes(nxG, dict(zip(edge_weights, powered)), "weight")
    return nxG


def nxG_to_Data(nxG, use_graph_distance=False):
    """Convert a networkx graph built by ``generate_nxG`` into a PyG ``Data`` object.

    Node attributes ``"x"`` and ``"pos"`` become ``data.x`` (reshaped to
    ``(num_nodes, -1)``) and ``data.pos``. The graph is converted to a directed
    graph so every undirected edge appears in both directions in
    ``edge_index`` (shape ``(2, num_edges)``); the edge ``"weight"`` values
    become ``edge_attr`` with shape ``(num_edges, 1)``. Both shapes are kept
    even when the graph has no edges left (e.g. after ``drop_edge``).

    If ``use_graph_distance`` is true, all-pairs weighted shortest-path lengths
    are stored flattened as ``graph_dist`` with shape ``(num_nodes ** 2, 1)``,
    with ``inf`` for unreachable pairs. This assumes nodes are labelled
    ``0 .. num_nodes - 1``, as ``generate_nxG`` produces.
    """
    num_nodes = nxG.number_of_nodes()

    x = torch.tensor(
        np.array(list(nx.get_node_attributes(nxG, "x").values())).reshape(num_nodes, -1),
        dtype=torch.float,
    )
    # Stack into one ndarray first: building a tensor from a list of ndarrays
    # is very slow.
    pos = torch.tensor(
        np.array(list(nx.get_node_attributes(nxG, "pos").values())), dtype=torch.float
    )

    nxG = nx.to_directed(nxG)  # both directions of each edge, as COO edge_index expects

    # reshape(-1, 2) / reshape(-1, 1) keep 2-D shapes for an edgeless graph,
    # where reshape(0, -1) would raise.
    edge_index = torch.tensor(
        np.array(list(nxG.edges), dtype=np.int64).reshape(-1, 2).T, dtype=torch.long
    )
    weights = np.array(
        list(nx.get_edge_attributes(nxG, "weight").values()), dtype=np.float64
    )
    edge_attr = torch.tensor(weights.reshape(-1, 1), dtype=torch.float)

    extra = {}
    if use_graph_distance:
        extra["graph_dist"] = _graph_distance_column(nxG, num_nodes)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **extra)


def _graph_distance_column(nxG, num_nodes):
    """Return weighted all-pairs shortest-path lengths as a ``(num_nodes**2, 1)`` tensor (inf = unreachable)."""
    graph_dist = torch.full((num_nodes, num_nodes), float("inf"), dtype=torch.float)
    for source, lengths in nx.shortest_path_length(nxG, weight="weight"):
        for target, dist in lengths.items():
            graph_dist[source, target] = dist
    # Flatten so dim 0 can be concatenated across graphs when batching.
    return graph_dist.view(-1, 1)

## ネットワーク, Lossの定義

In [None]:
class Net(torch.nn.Module):
    """Six GCN layers followed by a two-layer MLP mapping each node to coordinates.

    Parameters
    ----------
    in_channels : int, default 1
        Number of input node features (``data.x.shape[1]``).
    out_channels : int, default 2
        Number of predicted coordinates per node.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, out_channels)

    def forward(self, data):
        """Return per-node predictions of shape ``(num_nodes, out_channels)``."""
        x, edge_index = data.x, data.edge_index

        # GCNConv needs a 1-D edge_weight of shape (num_edges,), while edge_attr
        # is stored as (num_edges, 1). Unlike torch.squeeze, reshape(-1) stays
        # 1-D when the batch has exactly one edge.
        edge_weight = data.edge_attr.reshape(-1)

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x, edge_index, edge_weight))

        x = F.relu(self.linear1(x))
        return self.linear2(x)


class KKLoss(nn.Module):
    """Kamada-Kawai layout energy of predicted node coordinates, used as a loss.

    For each graph in a batch, with graph distances ``d_ij`` taken from
    ``data.graph_dist``, the energy is::

        E = sum_{i<j} 0.5 * k_ij * (|p_i - p_j| - l_ij) ** 2
        k_ij = K / d_ij ** 2
        l_ij = (L0 / max finite d) * d_ij

    Unreachable pairs (``d_ij = inf``) and pairs belonging to different graphs
    of the batch contribute nothing.

    Parameters
    ----------
    K : float
        Spring-strength scale.
    L0 : float
        Target layout scale: the largest finite graph distance maps to ``L0``.
    eps : float
        Added inside the square root so its gradient stays finite when two
        predicted points coincide.
    """

    def __init__(self, K=1.0, L0=1.0, eps=1e-6):
        super().__init__()
        self.K = K
        self.L0 = L0
        self.eps = eps

    def forward(self, batch, prediction):
        """Return the Kamada-Kawai energy summed over all graphs in ``batch``.

        The energies of the graphs in the batch are computed together, as
        block-diagonal pair matrices.

        Parameters
        ----------
        batch : torch_geometric Batch
            Must provide ``num_nodes`` and ``to_data_list()``. Every graph needs
            ``graph_dist`` (its ``num_nodes ** 2`` distances, flattened).
        prediction : Tensor, shape (batch.num_nodes, dim)
            Predicted coordinates, rows ordered like the batch's nodes.

        Returns
        -------
        Tensor
            Scalar energy.
        """
        n = batch.num_nodes
        # Allocate on the prediction's device/dtype so this also works on CUDA.
        k = prediction.new_zeros((n, n))
        l = prediction.new_zeros((n, n))

        s = 0
        for data in batch.to_data_list():
            m = data.num_nodes
            if m == 0:
                continue
            d = data.graph_dist.view(m, m).to(k)
            finite = torch.isfinite(d)
            zeros = torch.zeros_like(d)

            # Largest finite distance (the zero diagonal guarantees at least one
            # finite entry). If no pair is reachable, every k_ij is 0 and the
            # scale is irrelevant; use 0 to avoid L0 / 0 = inf and inf * 0 = nan.
            d_max = d[finite].max()
            scale = self.L0 / d_max if d_max > 0 else 0.0

            # d ** -2 maps inf to 0, so unreachable pairs get no spring; the
            # diagonal (d == 0) is kept at 0.
            k[s : s + m, s : s + m] = self.K * torch.where(d != 0, d ** -2, zeros)
            l[s : s + m, s : s + m] = scale * torch.where(finite, d, zeros)
            s += m

        # Squared pairwise distances between predicted points over all axes.
        diff = prediction.unsqueeze(1) - prediction.unsqueeze(0)
        sq_dist = (diff ** 2).sum(dim=-1)
        # Expanded 0.5 * k * (|p_i - p_j| - l) ** 2. Every unordered pair appears
        # twice in the full matrix, hence the outer 0.5.
        e = 0.5 * torch.sum(
            0.5 * k * (sq_dist + l ** 2 - 2 * l * torch.sqrt(sq_dist + self.eps))
        )
        return e

## データの読み込み

In [None]:
start = time.time()

adj_path, coords_path = (
    osp.join(root, "raw", fname) for fname in ("adjMats.mtx", "coords.mtx")
)
print("Reading graphs from {}".format(root))

# Matrix Market files holding one flattened adjacency matrix / coordinate set per row.
adj_coo_mats = mmread(adj_path)
coords_ndas = mmread(coords_path)
num_samples = adj_coo_mats.shape[0]
data_list = []  # reassigned by the processing cell

elapsed_time = time.time() - start
for line in (
    "Finish reading graphs from storage.",
    "elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))),
    "num_samples: {}".format(num_samples),
):
    print(line)

## データの処理

In [None]:
start = time.time()

print("Generating nx graph object")

nxG_list = Parallel(n_jobs=n_jobs)(
    [
        delayed(generate_nxG)(adj_coo_mats=adj_coo_mats, coords_ndas=coords_ndas, i=i)
        for i in range(num_samples)
    ]
)

print("Post-processing")

# joblib's default process-based backend sends pickled copies of each graph to
# the workers, so in-place edits made there never reach the parent's objects.
# Each step's returned graphs must therefore replace nxG_list, otherwise the
# dropped edges / noise are silently lost.
# NOTE(review): drop_edge and multiply_lognormal_noise draw from np.random
# inside worker processes; the parent's seed may not make those draws
# reproducible — confirm.
if add_noise:
    print("Dropping edges")

    nxG_list = Parallel(n_jobs=n_jobs)(
        [delayed(drop_edge)(nxG=nxG, prob=drop_probability) for nxG in nxG_list]
    )

    print("Adding noise")

    nxG_list = Parallel(n_jobs=n_jobs)(
        [
            delayed(multiply_lognormal_noise)(nxG=nxG, mean=mean, sigma=sigma)
            for nxG in nxG_list
        ]
    )

print("Inverting edge_attr")

nxG_list = Parallel(n_jobs=n_jobs)(
    [delayed(invert_edge_attr)(nxG=nxG, pow_=pow_) for nxG in nxG_list]
)

print("Finish post-processing!")
print("Converting to Data object")

data_list = Parallel(n_jobs=n_jobs)(
    [
        delayed(nxG_to_Data)(nxG, use_graph_distance=use_graph_distance)
        for nxG in nxG_list
    ]
)

elapsed_time = time.time() - start
print("Finish generating Data objects")
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## DataLoaderの作成

In [None]:
val_ratio = 0.1
test_ratio = 0.1
val_size = round(num_samples * val_ratio)
test_size = round(num_samples * test_ratio)
train_size = num_samples - val_size - test_size
# Cumulative split boundaries: [end of train, end of val, end of test].
cums = np.array([train_size, val_size, test_size]).cumsum()
random.shuffle(data_list)  # shuffle once, in place, before splitting

train_end, val_end, test_end = cums
train_data, val_data, test_data = (
    data_list[:train_end],
    data_list[train_end:val_end],
    data_list[val_end:test_end],
)

train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size)
    for split in (train_data, val_data, test_data)
)

# Graph size and per-node coordinate dimensionality, read from the first training sample.
nBeads = train_data[0].num_nodes
nDim = train_data[0].pos.shape[1]

## 学習

In [None]:
start = time.time()
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = KKLoss(K=K, L0=L0, eps=EPS)  # Kamada-Kawai energy (formerly nn.MSELoss on pos)
history = {
    "train_loss": [],
    "val_loss": [],
}
num_train_batches = len(train_loader)

for epoch in range(epoch_num):
    train_loss = 0.0
    model.train()
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        loss = criterion(batch, prediction)
        loss.backward()
        optimizer.step()
        # Record the raw KK energy. No `* nDim` factor: that was a leftover from
        # the MSE loss, and it made train loss incomparable with val loss.
        batch_loss = loss.item()
        train_loss += batch_loss

        if i % 10 == 9:
            # One '=' per 10 finished batches, padded out to the total batch count.
            progress_bar = (
                "["
                + ("=" * ((i + 1) // 10))
                + (" " * ((num_train_batches - (i + 1)) // 10))
                + "]"
            )
            print(
                "\repoch: {:d} loss: {:.3f}  {}".format(
                    epoch + 1,
                    batch_loss,
                    progress_bar,
                ),
                end="  ",
            )

    mean_train_loss = train_loss / max(1, num_train_batches)
    print(
        "\repoch: {:d} loss: {:.3f}".format(epoch + 1, mean_train_loss),
        end="  ",
    )
    history["train_loss"].append(mean_train_loss)

    val_loss = 0.0
    batch_num = 0
    model.eval()
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            prediction = model(data)
            val_loss += criterion(data, prediction).item()
            batch_num += 1

    # A tiny dataset can round val_size down to 0; report NaN instead of crashing.
    mean_val_loss = val_loss / batch_num if batch_num else float("nan")
    history["val_loss"].append(mean_val_loss)
    endstr = " " * max(1, (train_size // 1000 - 39)) + "\n"
    print(f"Val Loss: {mean_val_loss:.3f}", end=endstr)


print("Finished Training")
elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## 可視化

In [None]:
# Loss curves per epoch (epochs numbered from 1)
x = np.arange(1, epoch_num + 1)
plt.xlabel("epoch")
plt.ylabel("loss")
for key, label in (("train_loss", "train loss"), ("val_loss", "val loss")):
    plt.plot(x, history[key], label=label)
plt.legend()
plt.show()

In [None]:
# Draw the nodes of one test graph: true coordinates vs. model prediction.
t_index = 0
# Keep the sample under its own name: reassigning `test_data` would turn the
# test split into a single Data object and break re-running this cell.
test_sample = test_data[t_index].to(device)
node_list = list(range(test_sample.num_nodes))

# Predict the node coordinates of the test graph
model.eval()
with torch.no_grad():
    estimated_coords = model(test_sample)

# Build an undirected graph for drawing
test_edge_indices = test_sample.edge_index.t().cpu().numpy()
true_coords = test_sample.pos.cpu().numpy().copy()

G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(test_edge_indices)

fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("True")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Estimated")

true_pos = dict(zip(node_list, true_coords))
nx.draw_networkx_nodes(G, pos=true_pos, ax=ax1, node_color="red", node_size=2)

estimated_pos = dict(zip(node_list, estimated_coords.detach().cpu().numpy()))
nx.draw_networkx_nodes(G, pos=estimated_pos, ax=ax2, node_color="red", node_size=2)

plt.show()

## 実験条件

In [None]:
# NOTE(review): `t` appears unused elsewhere in this notebook; "Diffusion time"
# looks like a leftover from another experiment — confirm before relying on it.
t = 20

print("=====Simulation conditions=====")
print("目的：隣接行列からノードの座標の予測")
print("ネットワーク：GCNN")
print("Test run: {}".format(test_run))
print("Add noise: {}".format(add_noise))
if add_noise:
    print("Probability of edge drop: {}".format(drop_probability))
    print("Info about lognormal distribution: mean: {}, sigma: {}".format(mean, sigma))
# pow_ is applied to the edge weights whether or not noise is added.
print("Pow: {}".format(pow_))
print("Number of beads: {}".format(nBeads))
print("Number of samples for training: {}".format(train_size))
print("Number of epochs: {}".format(epoch_num))
print("Batch size: {}".format(batch_size))
print("Diffusion time: {}".format(t))

print("=====Results=====")
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))