# GCNNによって隣接行列からノードの座標を予測する

## ライブラリのインポート

In [None]:
import datetime
import math
import os
import os.path as osp
import random
import time

import numpy as np
from scipy.io import mmread
from matplotlib import pyplot as plt
import networkx as nx
from natsort import natsorted
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, DataLoader, Dataset, InMemoryDataset
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import Compose
from torch_scatter import scatter

# Seed PyTorch's, NumPy's and Python's global RNGs (in that order) with one value.
SEED = 1234
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

## 変数の設定

In [None]:
# Experiment switches.
test_run = True
use_InMemoryDataset = True
add_noise = True
drop_probability = 0.01  # probability of dropping each undirected edge
mean, sigma = 0, 1  # parameters of the lognormal noise multiplied onto edge_attr
pow_ = -1.0  # edge_attr is raised to this power during post-processing

# Use ~80% of the cores, but always at least one worker: os.cpu_count() may
# return None, and int(1 * 0.8) == 0 on a single-core machine, which joblib's
# Parallel rejects.
n_jobs = max(1, int((os.cpu_count() or 1) * 0.8))

data_size_type = "small_" if test_run else "large_"
Dataset_type = "InMemoryDataset" if use_InMemoryDataset else "Dataset"
root = osp.join("data", data_size_type + "reconstruction_" + Dataset_type)

epoch_num = 10 if test_run else 100
batch_size = 128  # identical for test and full runs

print("n_jobs = {}".format(n_jobs))

## Dataオブジェクトを作成する関数

In [None]:
def drop_edge(data, prob=0.01, rng=None):
    """Randomly remove undirected edges, keeping both directions in sync.

    ``load_graphs`` stores every undirected edge twice. The first half of
    ``edge_index``/``edge_attr`` holds one direction, and the second half holds
    the reversed copy in the same order. One keep/drop decision is therefore
    drawn per undirected edge and applied to both halves, so the graph stays
    symmetric.

    Args:
        data: graph with ``edge_index`` (2, E), ``edge_attr`` (first dim E) and
            ``num_edges``; modified in place.
        prob: probability of dropping each undirected edge.
        rng: optional NumPy ``Generator``/``RandomState``. When ``None`` the
            global ``np.random`` state is used (the previous behaviour).

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: if ``data.num_edges`` is odd, i.e. the edges cannot be two
            mirrored halves.
    """
    num_edges = int(data.num_edges)
    if num_edges % 2:
        raise ValueError(
            "drop_edge expects an even number of edges (two mirrored halves), "
            "got {}".format(num_edges)
        )
    half = num_edges // 2

    uniforms = np.random.rand(half) if rng is None else rng.random(half)
    keep_half = uniforms > prob
    # Same decision for an edge and its reverse copy in the second half.
    keep = torch.from_numpy(np.concatenate((keep_half, keep_half)))

    data.edge_index = data.edge_index[:, keep]
    data.edge_attr = data.edge_attr[keep]
    return data


def multiply_lognormal_noise(data, mean=0, sigma=1, rng=None):
    """Multiply each undirected edge's attribute by one lognormal random factor.

    ``load_graphs`` stores each undirected edge twice: the second half of
    ``edge_attr`` is a copy of the first half for the reversed direction. One
    factor is drawn per undirected edge and applied to both halves, so the two
    directions keep equal attributes.

    Args:
        data: graph whose ``edge_attr`` has an even first dimension (number of
            directed edges); modified in place.
        mean, sigma: parameters of the underlying normal distribution, passed
            to NumPy's ``lognormal``.
        rng: optional NumPy ``Generator``/``RandomState``. When ``None`` the
            global ``np.random`` state is used (the previous behaviour).

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: if ``edge_attr`` has an odd number of rows.
    """
    attr_shape = tuple(data.edge_attr.shape)
    if attr_shape[0] % 2:
        raise ValueError(
            "multiply_lognormal_noise expects an even number of edges "
            "(two mirrored halves), got {}".format(attr_shape[0])
        )
    # Integer division; the earlier in-place ``/=`` on an int array only worked
    # through a silent float->int cast.
    half_shape = (attr_shape[0] // 2,) + attr_shape[1:]

    sampler = np.random if rng is None else rng
    factors = torch.from_numpy(
        sampler.lognormal(mean=mean, sigma=sigma, size=half_shape)
    ).float()
    data.edge_attr = data.edge_attr * torch.cat((factors, factors), 0)
    return data


def invert_edge_attr(data, pow_):
    """Replace ``data.edge_attr`` by its element-wise power ``pow_``.

    With the notebook's ``pow_ = -1.0`` this turns each edge attribute into
    its reciprocal. For a negative ``pow_``, entries equal to 0 become ``inf``.
    ``data`` is modified in place and also returned.
    """
    powered = torch.pow(data.edge_attr, pow_)
    data.edge_attr = powered
    return data


def load_graphs(adj_coo_mats, coords_ndas, i):
    """Build the PyG ``Data`` object for sample ``i``.

    Row ``i`` of ``adj_coo_mats`` is reshaped into a ``num_beads x num_beads``
    adjacency matrix, and row ``i`` of ``coords_ndas`` into a
    ``(num_beads, -1)`` coordinate array. Only entries strictly below the
    diagonal (row > col) are kept and then mirrored, so self-loops and the
    upper triangle are discarded. NOTE(review): this assumes the stored
    adjacency is symmetric; confirm for the dataset.

    The resulting ``edge_index`` is ``[lower, reversed(lower)]``: its second
    half is the first half with source/target swapped, in the same order, and
    ``edge_attr`` is duplicated the same way.

    Args:
        adj_coo_mats: sparse matrix with one flattened adjacency matrix per
            row. ``getrow`` is cheapest when this is in CSR format.
        coords_ndas: 2-D array with one flattened coordinate array per row.
        i: sample (row) index.

    Returns:
        ``Data`` with constant node features ``x`` of shape (num_beads, 1),
        ``edge_index`` (2, E), ``edge_attr`` (E, 1) and ``pos``.

    Raises:
        ValueError: if the row length of ``adj_coo_mats`` is not a perfect
            square.
    """
    row_length = adj_coo_mats.shape[1]
    num_beads = math.isqrt(row_length)  # exact, unlike int(np.sqrt(...))
    if num_beads * num_beads != row_length:
        raise ValueError(
            "adjacency row length {} is not a perfect square".format(row_length)
        )
    nnf = 1  # number of node features

    # Sparse reshape() is not guaranteed to return COO for every input format,
    # so tocoo() makes sure .row/.col/.data exist below.
    adj_coo = adj_coo_mats.getrow(i).reshape(num_beads, num_beads).tocoo()
    coords_nda = np.asarray(coords_ndas[i, :]).reshape(num_beads, -1)

    src = np.asarray(adj_coo.row)
    dst = np.asarray(adj_coo.col)
    # reshape(-1, 1) also handles a sample with no stored entries, where
    # reshape(nnz, -1) == reshape(0, -1) raises.
    weights = np.asarray(adj_coo.data).reshape(-1, 1)

    lower = src > dst
    edge_index_L = np.stack((src[lower], dst[lower]))
    edge_attr_L = weights[lower]
    edge_index = np.concatenate((edge_index_L, edge_index_L[::-1]), axis=1)
    edge_attr = np.concatenate((edge_attr_L, edge_attr_L), axis=0)

    return Data(
        x=torch.ones((num_beads, nnf)).float(),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        edge_attr=torch.tensor(edge_attr).float(),
        pos=torch.tensor(coords_nda).float(),
    )

## ネットワークの定義

In [None]:
class Net(torch.nn.Module):
    """Six GCNConv layers followed by a two-layer MLP; one output vector per node.

    Used in this notebook to regress node coordinates from the graph
    structure. The convolution widths are 16-32-48-64-96-128 and the MLP
    hidden width is 64. Every layer except the last is followed by ReLU.

    Args:
        in_channels: number of input node features (default 1, matching the
            constant feature built by ``load_graphs``).
        out_channels: number of outputs per node, i.e. the coordinate
            dimension (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        # Attribute names stay conv1..conv6 / linear1..linear2 so previously
        # saved state_dicts still load.
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, out_channels)

    @staticmethod
    def _edge_weight(edge_attr):
        """Return ``edge_attr`` as the 1-D ``edge_weight`` GCNConv expects, or None."""
        # GCNConv takes edge weights of shape (E,). This is its documented
        # input, not a PyG bug. An (E, 1) edge_attr is flattened with reshape
        # rather than squeeze, because squeeze would also drop E when the
        # graph has exactly one edge and leave a 0-d tensor.
        if edge_attr is None:
            return None
        if edge_attr.dim() == 2 and edge_attr.size(1) == 1:
            return edge_attr.reshape(-1)
        return edge_attr

    def forward(self, data):
        """Map a (batched) graph to per-node outputs of shape (num_nodes, out_channels)."""
        x, edge_index = data.x, data.edge_index
        edge_weight = self._edge_weight(data.edge_attr)

        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x, edge_index, edge_weight))

        x = F.relu(self.linear1(x))
        return self.linear2(x)

## データの読み込み

In [None]:
start = time.time()

adj_path = osp.join(root, "raw", "adjMats.mtx")
coords_path = osp.join(root, "raw", "coords.mtx")

print("Reading graphs from {}".format(root))

# For a coordinate-format file, mmread returns a COO sparse matrix.
# load_graphs extracts one row per sample with getrow(). On COO, SciPy
# implements that through a sparse product that converts the whole matrix
# to CSR on every call, which is quadratic over all samples. Converting once
# here makes each row lookup cheap. The variable name is kept because the
# following cells use it.
adj_coo_mats = mmread(adj_path).tocsr()
coords_ndas = mmread(coords_path)
num_samples = adj_coo_mats.shape[0]
data_list = []

elapsed_time = time.time() - start

print("Finish reading graphs from storage.")
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))
print("num_samples: {}".format(num_samples))

## DataLoaderの作成

In [None]:
start = time.time()

print("Generating Data object")

# Building one Data object per sample is the expensive step, so it is fanned
# out to joblib workers.
data_list = Parallel(n_jobs=n_jobs)(
    delayed(load_graphs)(adj_coo_mats=adj_coo_mats, coords_ndas=coords_ndas, i=i)
    for i in range(num_samples)
)

print("Post-processing")

# The post-processing steps are cheap tensor operations, so they run in this
# process. This avoids a pickle round trip of every Data object per step.
# It also makes the random edge drops and noise use the global NumPy RNG
# seeded at the top of the notebook. With joblib's default process-based
# backend, the workers do not inherit that seed, so the noise was not
# reproducible.
if add_noise:
    print("Dropping edges")
    data_list = [drop_edge(data=data, prob=drop_probability) for data in data_list]

    print("Adding noise")
    data_list = [
        multiply_lognormal_noise(data=data, mean=mean, sigma=sigma)
        for data in data_list
    ]

print("Inverting edge_attr")
data_list = [invert_edge_attr(data=data, pow_=pow_) for data in data_list]

print("Finish post-processing!")

# 80/10/10 train/val/test split after a single shuffle.
val_ratio = 0.1
test_ratio = 0.1
val_size = round(num_samples * val_ratio)
test_size = round(num_samples * test_ratio)
train_size = num_samples - val_size - test_size
if train_size <= 0:
    raise ValueError(
        "Not enough samples ({}) to form a non-empty training set".format(num_samples)
    )
cums = np.array([train_size, val_size, test_size]).cumsum()
random.shuffle(data_list)

train_data = data_list[0:cums[0]]
val_data = data_list[cums[0]:cums[1]]
test_data = data_list[cums[1]:cums[2]]

# Reshuffle the training set every epoch; validation/test order does not matter.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

nBeads = train_data[0].num_nodes
nDim = train_data[0].pos.shape[1]  # coordinate dimension of each node

elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## 学習

In [None]:
start = time.time()
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.MSELoss()
history = {
    "train_loss": [],
    "val_loss": [],
}

# Progress bar: one "=" per 10 training batches. Its width is derived from the
# real number of batches, so it stays correct for any batch_size.
num_batches = len(train_loader)
bar_width = math.ceil(num_batches / 10)

for epoch in range(epoch_num):
    model.train()
    train_sq_dist = 0.0
    train_nodes = 0
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)

        loss = criterion(prediction, batch.pos)
        loss.backward()
        optimizer.step()

        # MSELoss averages over nodes and coordinate axes. Multiplying by nDim
        # gives the mean squared distance between predicted and true positions.
        batch_loss = loss.item() * nDim
        # Weight by node count so a smaller final batch is not over-weighted.
        train_sq_dist += batch_loss * batch.num_nodes
        train_nodes += batch.num_nodes
        if i % 10 == 9:
            filled = (i + 1) // 10
            progress_bar = "[" + "=" * filled + " " * (bar_width - filled) + "]"
            print(
                "\repoch: {:d} loss: {:.3f}  {}".format(
                    epoch + 1,
                    batch_loss,
                    progress_bar,
                ),
                end="  ",
            )

    train_loss = train_sq_dist / train_nodes if train_nodes else float("nan")
    print("\repoch: {:d} loss: {:.3f}".format(epoch + 1, train_loss), end="  ")
    history["train_loss"].append(train_loss)

    model.eval()
    val_sq_dist = 0.0
    val_nodes = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            prediction = model(data)
            val_sq_dist += criterion(prediction, data.pos).item() * nDim * data.num_nodes
            val_nodes += data.num_nodes

    # An empty validation split (possible for tiny datasets) is reported as NaN
    # instead of crashing.
    val_loss = val_sq_dist / val_nodes if val_nodes else float("nan")
    history["val_loss"].append(val_loss)
    # Trailing spaces overwrite what is left of the progress bar on this line.
    endstr = " " * (bar_width + 4) + "\n"
    print(f"Val Loss: {val_loss:.3f}", end=endstr)


print("Finished Training")
elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## 可視化

In [None]:
# Plot the per-epoch training and validation loss curves.
plt.xlabel("epoch")
plt.ylabel("loss")

# Derive the x axis from the recorded history rather than epoch_num, so the
# cell still works if training was interrupted part-way (the last epoch may
# then have a train loss but no val loss yet).
x = np.arange(1, len(history["train_loss"]) + 1)
plt.plot(x, history["train_loss"], label="train loss")
plt.plot(x[: len(history["val_loss"])], history["val_loss"], label="val loss")
plt.legend()
plt.show()

In [None]:
# Draw the true and the predicted node positions of one test graph side by side.
t_index = 0

node_list = list(range(nBeads))

# Use a separate name so ``test_data`` stays the list of test graphs and this
# cell can be re-run.
sample = test_data[t_index].to(device)

# Predict the node coordinates of the selected test graph.
with torch.no_grad():
    model.eval()
    estimated_coords = model(sample)

# Build a networkx graph for drawing.
test_edge_indices = sample.edge_index.t().cpu().numpy()
# copy(): on CPU, .numpy() shares memory with the tensor.
true_coords = sample.pos.cpu().numpy().copy()

G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(test_edge_indices)

fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("True")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Estimated")

# Only the nodes are drawn (no edges or labels).
true_pos = dict(zip(node_list, true_coords))
nx.draw_networkx_nodes(G, pos=true_pos, ax=ax1, node_color="red", node_size=2)

estimated_pos = dict(zip(node_list, estimated_coords.cpu().numpy()))
nx.draw_networkx_nodes(G, pos=estimated_pos, ax=ax2, node_color="red", node_size=2)

plt.show()

## 実験条件

In [None]:
# Diffusion time shown in the summary. NOTE(review): ``t`` is not used anywhere
# else in this notebook; confirm it matches the value used to generate the data.
t = 20

print("=====Simulation conditions=====")
print("目的：隣接行列からノードの座標を予測")
print("ネットワーク：GCNN")
print("Test run: {}".format(test_run))
print("Add noise: {}".format(add_noise))
if add_noise:
    print("Probability of edge drop: {}".format(drop_probability))
    print("Info about lognormal distribution: mean: {}, sigma: {}".format(mean, sigma))
# edge_attr is raised to pow_ whether or not noise is added, so always report it.
print("Pow: {}".format(pow_))
print("Number of beads: {}".format(nBeads))
print("Number of samples for training: {}".format(train_size))
print("Number of epochs: {}".format(epoch_num))
print("Batch size: {}".format(batch_size))
print("Diffusion time: {}".format(t))

print("=====Results=====")
# elapsed_time holds the value from the last cell that set it (the training
# cell, when the notebook is run in order).
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))