# GCNNによって隣接行列からノードの座標を予測する

## ライブラリのインポート

In [None]:
import datetime
import math
import os
import os.path as osp
import random
import time

import graph_tool as gt
import graph_tool.topology as topology
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from natsort import natsorted
from scipy.io import mmread
from torch_geometric.data import Data, DataLoader, Dataset, InMemoryDataset
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import Compose
from torch_scatter import scatter

# Seed every random number generator in use (Python, NumPy, PyTorch) with the
# same value so that runs are reproducible.
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)

# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

## 変数の設定

In [None]:
# Experiment configuration.
is_trial = True  # True: small dataset and a short run for quick checks
sigma = 0.2  # sigmas = [0.2, 0.4, 0.6, 0.8]; also selects the dataset directory
K, L0 = 1.0, 1.0  # Kamada-Kawai spring-constant scale and length scale (KKLoss)
EPS = 1e-6  # added under the square root in KKLoss

# Use ~80% of the cores for preprocessing, but never fewer than one worker:
# os.cpu_count() may return None, and int(1 * 0.8) == 0, which joblib's
# Parallel rejects.
n_jobs = max(1, int((os.cpu_count() or 1) * 0.8))

# Dataset directory, e.g. data/small_reconstruction_Dataset_0.2
data_size_type = "small_" if is_trial else "large_"
Dataset_type = "Dataset"
prefix = "_" + str(sigma)
root = osp.join("data", data_size_type + "reconstruction_" + Dataset_type + prefix)

if is_trial:
    epoch_num = 5
    batch_size = 3
else:
    epoch_num = 100
    batch_size = 8

print("n_jobs = {}".format(n_jobs))

## Datasetクラスの定義

In [None]:
class MyDataset(Dataset):
    """Graphs with node coordinates, read from Matrix Market files.

    Raw input is read from two sub-folders of ``self.raw_dir`` (``<root>/raw``
    in PyG): ``adjs/`` holds one adjacency matrix per graph and ``coords/`` one
    coordinate matrix per graph. The two folders are paired by natural sort
    order. ``process`` converts each pair into a ``torch_geometric.data.Data``
    object saved as ``<processed_dir>/data_<index>.pt``; ``get`` loads it back.
    """

    # File name pattern of one processed sample; formatted with its index.
    processed_file_name = "data_{}.pt"

    def __init__(self, root, transform=None, pre_transform=None):
        super(MyDataset, self).__init__(root, transform, pre_transform)

    @property
    def processed_file_names(self):
        # One processed file per raw adjacency file.
        num_graph = len(os.listdir(osp.join(self.raw_dir, "adjs")))
        return [self.processed_file_name.format(i) for i in range(num_graph)]

    def process(self):
        """Convert every raw (adjacency, coordinates) pair into a ``Data`` file.

        Each saved ``Data`` object holds:

        * ``x``          -- a constant feature of ones, shape ``(num_nodes, 1)``;
        * ``edge_index`` / ``edge_attr`` -- the edges of the adjacency matrix,
          with ``edge_attr`` set to the reciprocal of the stored entries;
        * ``pos``        -- the coordinates read from ``coords/``;
        * ``distance``   -- all-pairs shortest-path lengths (edge length = the
          reciprocal entry), one row per source vertex, flattened to
          ``(num_nodes * num_nodes, 1)``. Unreachable pairs are presumably
          ``inf`` (graph-tool's value for float weights -- confirm).

        Files are converted in parallel with joblib using the global ``n_jobs``.

        Raises:
            ValueError: if ``adjs/`` and ``coords/`` contain a different number
                of files, since they are paired by index.
        """
        adj_dir = osp.join(self.raw_dir, "adjs")
        coords_dir = osp.join(self.raw_dir, "coords")
        adj_file_names = natsorted(os.listdir(adj_dir))
        coords_file_names = natsorted(os.listdir(coords_dir))
        if len(adj_file_names) != len(coords_file_names):
            raise ValueError(
                "Mismatched raw data: {} adjacency files but {} coordinate files".format(
                    len(adj_file_names), len(coords_file_names)
                )
            )
        num_samples = len(adj_file_names)

        def generate_Data(index):
            adj_coo = mmread(osp.join(adj_dir, adj_file_names[index]))
            coords_nda = mmread(osp.join(coords_dir, coords_file_names[index]))
            num_nodes = adj_coo.shape[0]
            edge_index, edge_attr = torch_geometric.utils.from_scipy_sparse_matrix(
                adj_coo
            )

            data = Data(
                x=torch.ones((num_nodes, 1)).float(),
                edge_index=edge_index,
                edge_attr=edge_attr.float() ** -1,  # invert edge_attr
                pos=torch.tensor(coords_nda).float(),
            )

            # Edge lengths for the shortest-path search. Cast to float before
            # inverting: NumPy raises on negative powers of integer arrays,
            # which is what mmread returns for "integer" Matrix Market files.
            edge_lengths = 1.0 / np.asarray(adj_coo.data, dtype=np.float64)
            edge_list = np.array([adj_coo.row, adj_coo.col, edge_lengths]).T

            g = gt.Graph(directed=False)
            eweight = g.new_ep("float")
            g.add_vertex(num_nodes)
            g.add_edge_list(edge_list, eprops=[eweight])
            dist = topology.shortest_distance(g, weights=eweight)
            # Stack the per-vertex distance arrays into one ndarray first:
            # torch.tensor on a Python list of ndarrays is very slow.
            dist_matrix = np.stack([dist[g.vertex(i)].a for i in range(num_nodes)])
            data.distance = torch.tensor(dist_matrix, dtype=torch.float).view(-1, 1)

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(
                data,
                osp.join(self.processed_dir, self.processed_file_name.format(index)),
            )

        Parallel(n_jobs=n_jobs)([delayed(generate_Data)(i) for i in range(num_samples)])

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Use the same name pattern process() writes with.
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # may refuse to unpickle Data objects; pass weights_only=False if so.
        return torch.load(
            osp.join(self.processed_dir, self.processed_file_name.format(idx))
        )


class EdgeAttrInvert:
    r"""Transform that raises ``data.edge_attr`` element-wise to the power ``pow_``.

    With the default ``pow_=-1.0`` this inverts the edge weights. Graphs without
    edge attributes (``edge_attr is None``) are returned unchanged instead of
    raising a ``TypeError``.

    Args:
        pow_ (float): exponent applied to every edge attribute.
    """

    def __init__(self, pow_=-1.0):
        self.pow_ = pow_

    def __call__(self, data):
        # The object is modified in place and also returned, like PyG transforms.
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr ** self.pow_
        return data

    def __repr__(self):
        return "{}(pow_={})".format(self.__class__.__name__, self.pow_)

## ネットワーク, Lossの定義

In [None]:
class Net(torch.nn.Module):
    """GCN mapping every node of a graph to ``out_channels`` outputs.

    Six ``GCNConv`` layers (in_channels -> 16 -> 32 -> 48 -> 64 -> 96 -> 128), each
    followed by ReLU, then a two-layer head (128 -> 64 -> out_channels) with a
    ReLU in between. In this notebook the two outputs are used as 2-D node
    coordinates.

    Args:
        in_channels: number of input node features. Defaults to 1.
        out_channels: number of outputs per node. Defaults to 2.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super(Net, self).__init__()
        # Layers keep their original attribute names (conv1..conv6, linear1,
        # linear2) so previously saved state_dicts still load.
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, out_channels)

    def forward(self, data):
        """Return per-node outputs of shape ``(num_nodes, out_channels)``.

        ``data.edge_attr`` (if present) is used as the GCN edge weight.
        """
        x, edge_index = data.x, data.edge_index

        # GCNConv needs a 1-D edge weight of shape (num_edges,); edge_attr may be
        # stored as (num_edges, 1). reshape(-1), unlike squeeze, stays 1-D for a
        # single-edge graph. Without edge_attr, GCNConv uses unit weights.
        edge_weight = None if data.edge_attr is None else data.edge_attr.reshape(-1)

        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x, edge_index, edge_weight))

        x = F.relu(self.linear1(x))
        return self.linear2(x)


class KKLoss(nn.Module):
    """Return energy of Kamada-Kawai as loss.

    For one graph with shortest-path distances ``d_ij`` and predicted positions
    ``p_i`` (first two columns of the prediction) the energy is::

        E = 1/2 * sum_{i,j} 1/2 * k_ij * (|p_i - p_j|^2 + l_ij^2
                                          - 2 * l_ij * sqrt(|p_i - p_j|^2 + eps))

    with ``k_ij = K / d_ij**2`` and ``l_ij = L0 * d_ij / d_max``, where ``d_max``
    is the largest finite distance in the graph. The outer 1/2 compensates for
    summing over ordered pairs. Pairs with ``d_ij == 0`` (e.g. the diagonal) or
    ``d_ij == inf`` (unreachable) get ``k_ij == 0`` and contribute nothing.
    The loss is the mean energy over the graphs in the batch.

    Args:
        K: spring-constant scale.
        L0: desired length of the longest finite shortest path.
        eps: added under the square root so the gradient stays finite when
            two nodes coincide.
    """

    def __init__(self, K=1.0, L0=1.0, eps=1e-6):
        super().__init__()
        self.K = K
        self.L0 = L0
        self.eps = eps

    def forward(self, batch, prediction):
        """
        The Kamada-Kawai loss of the graph included in
        the batch is calculated in parallel as in the mini-batch.

        Args:
            batch: mini-batch whose graphs each carry ``distance`` with
                ``num_nodes * num_nodes`` entries (row-major distance matrix);
                ``batch.batch`` maps each node to its graph index.
            prediction: positions for all nodes of the batch; columns 0 and 1
                are used as x and y.

        Returns:
            0-d tensor on the same device and dtype as ``prediction``.
        """
        # Accumulate on prediction's device instead of relying on a global.
        total = prediction.new_zeros(())
        for h, data in enumerate(batch.to_data_list()):
            graph_dist = data.distance.view(data.num_nodes, data.num_nodes)
            k = self.K * torch.where(graph_dist != 0, graph_dist ** -2, graph_dist)

            # Normalise by the largest *finite* distance (inf = unreachable).
            d_max = graph_dist[torch.isfinite(graph_dist)].max()
            if d_max > 0:
                l = (self.L0 / d_max) * graph_dist
                l[l == float("inf")] = 0  # avoid 0 * inf = nan
            else:
                # Single node or no edges: every k_ij is 0, so the lengths are
                # irrelevant. Zeros avoid L0 / 0 = inf and then 0 * inf = nan.
                l = torch.zeros_like(graph_dist)

            positions = prediction[batch.batch == h]
            # Pairwise coordinate differences via broadcasting: (n, 1) - (n,).
            dx = positions[:, [0]] - positions[:, 0]
            dy = positions[:, [1]] - positions[:, 1]
            sq_dist = dx ** 2 + dy ** 2
            energy = torch.sum(
                0.5 * k * (sq_dist + l ** 2 - 2 * l * torch.sqrt(sq_dist + self.eps))
            ) * 0.5
            total = total + energy

        return total / batch.num_graphs

## DataLoaderの作成

In [None]:
# Optional pre-transform raising edge_attr to the power pow_ (disabled):
# edge_attr_invert = EdgeAttrInvert(pow_=pow_)
# my_pre_transform = edge_attr_invert

train_root = osp.join(root, "train")
val_root = osp.join(root, "val")
test_root = osp.join(root, "test")

train_set = MyDataset(train_root)
val_set = MyDataset(val_root)
test_set = MyDataset(test_root)

# Shuffle the training split so each epoch sees the graphs in a new order
# (otherwise they are always visited in file-name order). Validation and test
# keep a fixed order so evaluation is deterministic.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

train_size = len(train_set)
val_size = len(val_set)
test_size = len(test_set)

## 学習

In [None]:
start = time.time()
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters())
# criterion = nn.MSELoss()
criterion = KKLoss(K=K, L0=L0, eps=EPS)
history = {
    "train_loss": [],
    "val_loss": [],
}

# Mini-batches per epoch (the last one may be partial). Clamped to 1 so an
# empty training split cannot cause a division by zero below.
n_train_batches = max(1, math.ceil(train_size / batch_size))

for epoch in range(epoch_num):
    train_loss = 0.0
    model.train()
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        loss = criterion(batch, prediction)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Every 10 batches redraw the progress bar: one "=" per 10 finished
        # batches, padded with one space per 10 batches still to come.
        if i % 10 == 9:
            done = i + 1
            progress_bar = (
                "["
                + ("=" * (done // 10))
                + (" " * ((n_train_batches - done) // 10))
                + "]"
            )
            print(
                "\repoch: {:d} loss: {:.3f}  {}".format(
                    epoch + 1, loss.item(), progress_bar,
                ),
                end="  ",
            )

    epoch_train_loss = train_loss / n_train_batches
    print("\repoch: {:d} loss: {:.3f}".format(epoch + 1, epoch_train_loss), end="  ")
    history["train_loss"].append(epoch_train_loss)

    # Validation loss = mean of the per-batch losses (NaN if there is no data).
    val_loss = 0.0
    batch_num = 0
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            prediction = model(batch)
            val_loss += criterion(batch, prediction).item()
            batch_num += 1
    val_loss = val_loss / batch_num if batch_num else float("nan")

    history["val_loss"].append(val_loss)
    endstr = " " * max(1, (train_size // 1000 - 39)) + "\n"
    print(f"Val Loss: {val_loss:.3f}", end=endstr)


print("Finished Training")
elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## 可視化

In [None]:
# Loss curves. Each history list is plotted against its own epoch numbers, so
# the cell still works when training was interrupted before epoch_num epochs
# (the two lists can then also differ in length by one).
plt.xlabel("epoch")
plt.ylabel("loss")

for key, label in (("train_loss", "train loss"), ("val_loss", "val loss")):
    losses = history[key]
    x = np.arange(len(losses)) + 1
    plt.plot(x, losses, label=label)
plt.legend()
plt.show()

In [None]:
# Draw the nodes of one test graph: true coordinates (left) vs. prediction (right).
test_index = 0
# Load the sample once; each test_set[...] access reads the .pt file from disk.
test_data = test_set[test_index].to(device)
nBeads = test_data.num_nodes
node_list = list(range(nBeads))

# Predict the node coordinates of the test graph.
model.eval()
with torch.no_grad():
    estimated_coords = model(test_data)

# Build a networkx graph for drawing.
test_edge_indices = test_data.edge_index.t().cpu().numpy()
true_coords = test_data.pos.cpu().numpy().copy()

G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(test_edge_indices)

fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("True")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Estimated")

# Only the nodes are drawn; the commented draw_networkx calls would add edges.
true_pos = dict(zip(node_list, true_coords))
# nx.draw_networkx(G, pos=true_pos, with_labels=False, ax=ax1,
#                 node_color="red", node_size=2)
nx.draw_networkx_nodes(G, pos=true_pos, ax=ax1, node_color="red", node_size=2)

estimated_pos = dict(zip(node_list, estimated_coords.cpu().numpy()))
# nx.draw_networkx(G, pos=estimated_pos, with_labels=False, ax=ax2,
#                 node_color="red", node_size=2)
nx.draw_networkx_nodes(G, pos=estimated_pos, ax=ax2, node_color="red", node_size=2)

plt.show()

## 実験条件

In [None]:
# Experiment summary. t, drop_probability, mean and pow_ are only reported here;
# no executed code in this notebook reads them.
# NOTE(review): they presumably describe how the dataset was generated -- keep
# them in sync with the generator by hand.
t = 20
drop_probability = 0.01
mean = 0
pow_ = -1.0

print("=====Simulation conditions=====")
print("目的：隣接行列からノードの座標の予測")
print("ネットワーク：GCNN")
print("Test run: {}".format(is_trial))

print("Probability of edge drop: {}".format(drop_probability))
print("Pow: {}".format(pow_))
print("Info about normal distribution: mean: {}, sigma: {}".format(mean, sigma))

print("Number of beads: {}".format(nBeads))
print("Number of samples for training: {}".format(train_size))

print("Number of epochs: {}".format(epoch_num))
print("Batch size: {}".format(batch_size))
print("Diffusion time: {}".format(t))

print("=====Results=====")
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))