# GCNNによって隣接行列からノードの座標を予測する

## ライブラリのインポート，変数の設定

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import time
import datetime
from scipy.io import mmread
import random
import networkx as nx
import math
import os
import os.path as osp
from natsort import natsorted

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader
from torch_geometric.transforms import Compose
from torch_scatter import  scatter

# Seed the torch, NumPy and stdlib `random` generators with one value so
# that data splits, weight initialisation and the random edge transforms
# are reproducible between runs.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(1234)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

In [None]:
# ---- experiment switches ----
test_run = True               # True: "small_" data directory and only a few epochs
use_InMemoryDataset = True    # True: MyInMemoryDataset, False: MyDataset (one file per graph)
add_noise = True              # apply the DropEdge + Noise transforms
drop_p = 0.01                 # per-edge drop probability used by DropEdge
mean, sigma = 0, 1            # parameters of the underlying normal of the log-normal noise
pow_ = -1.0                   # exponent applied to edge weights by EdgeAttrInvert

# ---- dataset location: data/<size>reconstruction_<dataset class> ----
data_size_type = "large_"
if test_run:
    data_size_type = "small_"
Dataset_type = "Dataset"
if use_InMemoryDataset:
    Dataset_type = "InMemoryDataset"
root = osp.join("data", "".join((data_size_type, "reconstruction_", Dataset_type)))

# ---- training schedule ----
batch_size = 128  # same batch size in both modes
epoch_num = 3 if test_run else 100

## Datasetの定義

In [None]:
class MyInMemoryDataset(InMemoryDataset):
    r"""In-memory dataset built from two Matrix Market files in ``raw_dir``.

    * ``adjMats.mtx`` -- one row per sample; each row is a flattened
      ``num_beads x num_beads`` adjacency matrix (read as a sparse matrix).
    * ``coords.mtx`` -- one row per sample; each row holds the flattened
      per-node coordinates and is reshaped to ``(num_beads, -1)``.

    Every sample becomes a ``Data`` object with a constant 1-dim node
    feature, the non-zero adjacency entries as ``edge_index``/``edge_attr``
    and the coordinates as ``pos``. All samples are collated into a single
    ``processed/data.pt``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(MyInMemoryDataset, self).__init__(root, transform, pre_transform)
        # NOTE(review): torch >= 2.6 defaults to torch.load(weights_only=True),
        # which refuses pickled PyG objects; pass weights_only=False there.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # The files read by process(); listed so that raw_paths is meaningful.
        return ['adjMats.mtx', 'coords.mtx']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Read both raw files, build one ``Data`` per row and save the collated result."""
        adj_mats = mmread(osp.join(self.raw_dir, "adjMats.mtx"))
        coords_mats = mmread(osp.join(self.raw_dir, "coords.mtx"))

        # Convert to CSR once. COO has no efficient row access, so calling
        # getrow() on it inside the loop re-converts the whole matrix for
        # every sample (quadratic in the number of samples).
        adj_rows = adj_mats.tocsr()
        if hasattr(coords_mats, "toarray"):
            # Coordinates stored in sparse (coordinate) format.
            coords_mats = coords_mats.toarray()

        num_sample = adj_rows.shape[0]
        num_beads = int(round(np.sqrt(adj_rows.shape[1])))
        if num_beads * num_beads != adj_rows.shape[1]:
            raise ValueError(
                "each row of adjMats.mtx must hold a square matrix, "
                "got row length {}".format(adj_rows.shape[1]))
        nnf = 1  # nnf: num_node_features

        data_list = []
        for i in range(num_sample):
            # Slice (not integer-index) so the row stays 2-D for both
            # sparse-matrix and sparse-array containers.
            adj_coo = adj_rows[i:i + 1].reshape(num_beads, num_beads).tocoo()
            coords_nda = np.asarray(coords_mats[i, :]).reshape(num_beads, -1)

            data = Data(
                x=torch.ones((num_beads, nnf), dtype=torch.float),
                edge_index=torch.tensor(
                    np.vstack((adj_coo.row, adj_coo.col)), dtype=torch.long),
                # reshape(-1, 1) also works for a graph without edges, where
                # reshape(num_edges, -1) would raise.
                edge_attr=torch.tensor(
                    adj_coo.data.reshape(-1, 1), dtype=torch.float),
                pos=torch.tensor(coords_nda, dtype=torch.float),
            )
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])



# For large Dataset
class MyDataset(Dataset):
    r"""On-disk dataset that stores one processed ``.pt`` file per graph.

    Expects ``<root>/raw/adj/`` (sparse adjacency matrices in Matrix Market
    format) and ``<root>/raw/coords/`` (coordinate arrays in Matrix Market
    format). Files in the two directories are paired by natural sort order,
    and graph ``i`` is saved as ``processed/data_{i}.pt``.
    """

    processed_file_name = 'data_{}.pt'

    def __init__(self, root, transform=None, pre_transform=None):
        super(MyDataset, self).__init__(root, transform, pre_transform)

    @property
    def processed_file_names(self):
        # One processed file per raw adjacency file.
        num_graph = len(os.listdir(osp.join(self.raw_dir, "adj")))
        return [self.processed_file_name.format(i) for i in range(num_graph)]

    def process(self):
        """Convert every (adjacency, coords) file pair into a saved ``Data`` object.

        Raises:
            ValueError: if ``raw/adj`` and ``raw/coords`` contain a different
                number of files (zip() would otherwise drop the extras silently).
        """
        adj_dir = osp.join(self.raw_dir, "adj")
        coords_dir = osp.join(self.raw_dir, "coords")
        adj_file_names = natsorted(os.listdir(adj_dir))
        coords_file_names = natsorted(os.listdir(coords_dir))
        if len(adj_file_names) != len(coords_file_names):
            raise ValueError(
                "raw/adj has {} files but raw/coords has {}".format(
                    len(adj_file_names), len(coords_file_names)))

        for i, (adj_file_name, coords_file_name) in enumerate(
                zip(adj_file_names, coords_file_names)):
            adj_coo = mmread(osp.join(adj_dir, adj_file_name)).tocoo()
            coords_nda = mmread(osp.join(coords_dir, coords_file_name))
            if hasattr(coords_nda, "toarray"):
                # Coordinates stored in sparse (coordinate) format.
                coords_nda = coords_nda.toarray()

            num_beads = adj_coo.shape[0]
            nnf = 1  # nnf: num_node_features

            data = Data(
                x=torch.ones((num_beads, nnf), dtype=torch.float),
                edge_index=torch.tensor(
                    np.vstack((adj_coo.row, adj_coo.col)), dtype=torch.long),
                # reshape(-1, 1) also works for a graph without edges.
                edge_attr=torch.tensor(
                    adj_coo.data.reshape(-1, 1), dtype=torch.float),
                pos=torch.tensor(coords_nda, dtype=torch.float),
            )

            # NOTE(review): a filtered-out graph leaves a gap in data_{i}.pt while
            # len() still counts it, so get() would fail for that index.
            # __init__ currently never sets pre_filter, so this is latent.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir,
                                      self.processed_file_name.format(i)))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Use the same file-name template that process() writes with.
        return torch.load(osp.join(self.processed_dir,
                                   self.processed_file_name.format(idx)))

## Transformの定義

In [None]:
class EdgeTransform():
    r"""Helpers for transforms that operate on undirected edge lists.

    ``extract_L`` keeps only edges with ``row > col``. These are the strict
    lower triangle of the adjacency matrix, so self-loops are discarded.
    ``L_to_symmetric`` mirrors such a lower-triangular edge list back into
    both directions. Subclasses call the pair around a random per-edge
    operation, so that the operation is applied once per undirected edge
    and the result stays symmetric. This assumes the incoming edge list is
    symmetric: upper-triangle attributes are replaced by the mirrored
    lower-triangle ones.
    """
    def extract_L(self, data):
        """Keep only edges with ``edge_index[0] > edge_index[1]``.

        Modifies ``data.edge_index`` / ``data.edge_attr`` in place and
        returns ``data``.
        """
        ei = data.edge_index
        # Torch boolean mask: stays on the tensor's device. np.where would
        # force a host copy, and it fails outright for CUDA tensors.
        mask = ei[0] > ei[1]
        data.edge_index = ei[:, mask]
        data.edge_attr = data.edge_attr[mask]
        return data

    def L_to_symmetric(self, data):
        """Append the reversed copy of every edge, duplicating its attribute.

        Modifies ``data`` in place and returns it.
        """
        ei = data.edge_index
        data.edge_index = torch.cat((ei, ei[[1, 0]]), dim=1)
        data.edge_attr = torch.cat((data.edge_attr, data.edge_attr), dim=0)
        return data


class DropEdge(EdgeTransform):
    r"""Drop each undirected edge independently with probability ``p``.

    The decision is drawn once per lower-triangular edge and then mirrored,
    so both directions of an edge are always kept or dropped together.

    Args:
        p: drop probability in ``[0, 1]`` (default 0.01).

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """
    def __init__(self, p=0.01):
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must be in [0, 1], got {}".format(p))
        self.p = p

    def __call__(self, data):
        data = self.extract_L(data)
        # Keep NumPy's (globally seeded) RNG for the draw; turn the mask into a
        # torch bool tensor on the edge tensors' device before indexing.
        keep = torch.from_numpy(np.random.rand(data.num_edges) > self.p)
        keep = keep.to(data.edge_index.device)
        data.edge_index = data.edge_index[:, keep]
        data.edge_attr = data.edge_attr[keep]
        return self.L_to_symmetric(data)

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)


class Noise(EdgeTransform):
    r"""Multiply every edge weight by an independent log-normal factor.

    Each factor is drawn by ``np.random.lognormal(mean, sigma)``, where
    ``mean``/``sigma`` describe the underlying normal distribution. Factors
    are drawn for the lower-triangular edges only and then mirrored, so both
    directions of an edge receive the same factor.
    """
    def __init__(self, mean=0, sigma=1.0):
        self.mean, self.sigma = mean, sigma

    def __call__(self, data):
        lower = self.extract_L(data)
        factors = np.random.lognormal(mean=self.mean,
                                      sigma=self.sigma,
                                      size=lower.edge_attr.shape)
        lower.edge_attr = lower.edge_attr * torch.tensor(factors).float()
        return self.L_to_symmetric(lower)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)

    
class EdgeAttrInvert:
    r"""Raise every edge weight to the power ``pow_`` (element-wise).

    With the default ``pow_=-1.0`` this replaces each weight by its
    reciprocal. ``data`` is modified in place and returned.
    """
    def __init__(self, pow_=-1.0):
        self.pow_ = pow_

    def __call__(self, data):
        powered = data.edge_attr ** self.pow_
        data.edge_attr = powered
        return data

    def __repr__(self):
        return f"{type(self).__name__}()"


## ネットワークの定義

In [None]:
class Net(torch.nn.Module):
    r"""Six GCN layers followed by a two-layer MLP that regresses per-node outputs.

    Hidden widths are 16-32-48-64-96-128 (GCN) and 64 (MLP). Every layer
    except the last linear one is followed by ReLU.

    Args:
        in_channels: number of input node features. The default of 1 matches
            the constant ``x`` built by the datasets in this notebook.
        out_channels: number of outputs per node, i.e. the coordinate
            dimension to predict (default 2).
    """
    def __init__(self, in_channels=1, out_channels=2):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, out_channels)

    def forward(self, data):
        """Return a ``(num_nodes, out_channels)`` tensor of predictions for ``data``."""
        x, edge_index = data.x, data.edge_index

        # GCNConv expects edge_weight as a 1-D tensor of shape [num_edges]
        # (documented behaviour, not a PyG bug), whereas the datasets store
        # edge_attr as [num_edges, 1]. Unlike squeeze(), reshape(-1) keeps
        # the result 1-D when there is only a single edge.
        edge_weight = data.edge_attr.reshape(-1)

        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x, edge_index, edge_weight))

        x = F.relu(self.linear1(x))
        return self.linear2(x)

## データの読み込み，DataLoaderの作成

In [None]:
# Edge-weight transform pipeline. With add_noise, edges are first randomly
# dropped, the remaining weights are multiplied by log-normal noise, and
# finally every weight is raised to the power pow_.
edge_attr_invert = EdgeAttrInvert(pow_=pow_)

if add_noise:
    drop_edge = DropEdge(p=drop_p)
    noise = Noise(mean=mean, sigma=sigma)
    my_transform = Compose([drop_edge, noise, edge_attr_invert])
else:
    my_transform = edge_attr_invert

# NOTE: PyG applies `transform` on every access, so with add_noise each epoch
# sees freshly noised graphs. This also applies to val/test.
if use_InMemoryDataset:
    # A single dataset, split at random (torch's global RNG) into train/val/test.
    all_set = MyInMemoryDataset(root, transform=my_transform)
    val_ratio = 0.1
    test_ratio = 0.1
    val_size = round(len(all_set) * val_ratio)
    test_size = round(len(all_set) * test_ratio)
    train_size = len(all_set) - val_size - test_size
    train_set, val_set, test_set = torch.utils.data.random_split(
        all_set, [train_size, val_size, test_size])
else:
    # Pre-split on disk: <root>/train, <root>/val, <root>/test.
    train_root = osp.join(root, "train")
    val_root = osp.join(root, "val")
    test_root = osp.join(root, "test")

    train_set = MyDataset(train_root, transform=my_transform)
    val_set = MyDataset(val_root, transform=my_transform)
    test_set = MyDataset(test_root, transform=my_transform)

# Reshuffle the training data every epoch. Evaluation order does not affect
# the loss, so val/test keep a fixed order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

## 学習

In [None]:
# Train the GCN and record the mean per-batch train/validation loss per epoch.
nDim = train_set[0].pos.shape[1]  # number of coordinates per node
start = time.time()
train_size = len(train_set)
val_size = len(val_set)
test_size = len(test_set)
num_train_batches = len(train_loader)
num_val_batches = len(val_loader)

model = Net().to(device)

optimizer = torch.optim.Adam(model.parameters())

criterion = nn.MSELoss()

history = {
    "train_loss": [],
    "val_loss": [],
}

for epoch in range(epoch_num):
    # ---- training ----
    train_loss = 0.0
    line_len = 0  # width of the last progress line, so it can be blanked out
    model.train()
    for i, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)

        loss = criterion(prediction, batch.pos)
        loss.backward()
        optimizer.step()

        # MSELoss averages over nodes AND coordinate axes. Scaling by nDim
        # gives the mean squared distance between predicted and true positions.
        batch_loss = loss.item() * nDim
        train_loss += batch_loss
        if i % 10 == 9:
            # One '=' per 10 batches, padded to the epoch's total batch count.
            done = (i + 1) // 10
            progress_bar = '[' + '=' * done + ' ' * (num_train_batches // 10 - done) + ']'
            line = 'epoch: {:d} loss: {:.3f}  {}'.format(epoch + 1, batch_loss, progress_bar)
            print('\r' + line, end="  ")
            line_len = len(line) + 2

    epoch_train_loss = (train_loss / num_train_batches
                        if num_train_batches else float("nan"))
    summary = 'epoch: {:d} loss: {:.3f}  '.format(epoch + 1, epoch_train_loss)
    print('\r' + summary, end="")
    history["train_loss"].append(epoch_train_loss)

    # ---- validation ----
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            prediction = model(data)
            val_loss += criterion(prediction, data.pos).item() * nDim

    # NaN instead of a crash when the validation split is empty.
    epoch_val_loss = (val_loss / num_val_batches
                      if num_val_batches else float("nan"))
    history["val_loss"].append(epoch_val_loss)
    val_msg = f'Val Loss: {epoch_val_loss:.3f}'
    # Pad with spaces to overwrite leftovers of the (longer) progress line.
    print(val_msg.ljust(line_len - len(summary)))


print('Finished Training')
elapsed_time = time.time() - start
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))

## 可視化

In [None]:
# Loss curves: training and validation loss for each epoch (1-based).
epochs = np.arange(epoch_num) + 1
for key in ("train_loss", "val_loss"):
    # "train_loss" -> label "train loss", "val_loss" -> "val loss"
    plt.plot(epochs, history[key], label=key.replace("_", " "))
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Plot the true and the predicted node positions of one test sample side by side.
t_index = 0
# Node count of a training graph; reported in the experiment-conditions output.
nBeads = train_set[0].num_nodes

test_data = test_set[t_index].to(device)
# Size the node list from the sample actually drawn, so the position dicts
# stay aligned even if test graphs differ in size from train_set[0].
node_list = list(range(test_data.num_nodes))

# Predict the node coordinates of the test sample.
with torch.no_grad():
    model.eval()
    estimated_coords = model(test_data)

# Build a networkx graph for drawing.
test_edge_indices = torch.t(test_data.edge_index).to('cpu').detach().numpy()
true_coords = test_data.pos.to('cpu').detach().numpy().copy()

G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(test_edge_indices)

fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("True")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Estimated")

# Only the nodes are drawn; edges are not rendered.
true_pos = dict(zip(node_list, true_coords))
nx.draw_networkx_nodes(G, pos=true_pos, ax=ax1, node_color="red", node_size=2)

estimated_pos = dict(zip(node_list, estimated_coords.cpu().detach().numpy()))
nx.draw_networkx_nodes(G, pos=estimated_pos, ax=ax2, node_color="red", node_size=2)

plt.show()

## 実験条件

In [None]:
# Print a summary of the experimental settings and the results.
nSamples = len(train_set)
# NOTE(review): t is only printed below; nothing else in this notebook reads
# it. Confirm it matches the diffusion time used to generate the data.
t = 20

print("=====Simulation conditions=====")
print("目的：隣接行列からノードの座標の予測")
print("ネットワーク：GCNN")
print("Test run: {}".format(test_run))
print("Add noise: {}".format(add_noise))
if add_noise:
    print("Probability of edge drop: {}".format(drop_p))
    # mean/sigma are the parameters of the normal underlying the log-normal noise.
    print("Info about normal distribution: mean: {}, sigma: {}".format(mean, sigma))
    print("Pow: {}".format(pow_))
print("Number of beads: {}".format(nBeads))
print("Number of samples for training: {}".format(nSamples))
print("Number of epochs: {}".format(epoch_num))
print("Batch size: {}".format(batch_size))
print("Diffusion time: {}".format(t))

print("=====Results=====")
print("elapsed time: {}".format(datetime.timedelta(seconds=int(elapsed_time))))