In [1]:
import sys
sys.path.append("..")
import torch
import pandas as pd
import numpy as np
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from reserch_utils_HT import network_to_image
from models.set_model import GNN_base
from data.data_loader import gnn_data_loader_cv
from torch_geometric.loader import DataLoader

In [2]:
from experiments.cross_validation import split_data, train
from sklearn.model_selection import KFold
from torch.utils.data.dataset import Subset
import torch.nn as nn
import torch.optim as optim
import torch
from experiments.gnn_train_utils import gnn_train_val_1epoch

In [3]:
def _build_parameter_grid():
    """Build the generator-parameter sweep for every network kind and node count.

    Node counts fall into three tiers that share a sweep length:
    n < 100 -> 10 values, 100 <= n < 1000 -> 14 values, n >= 1000 -> 17 values.

    Returns:
        dict kind -> dict node_count -> sweep. "BA"/"Growth" sweeps are int
        lists, "Attach" sweeps are lists of numpy floats (n * linspace), and
        "Random" sweeps are numpy arrays from logspace.
    """
    node_counts = [20, 30, 50, 70, 100, 200, 300, 500, 700, 1000, 2000]
    base = list(range(1, 11))
    # Extra integer values appended for the larger tiers.
    int_extras = {10: [], 14: [15, 20, 25, 30], 17: [15, 20, 25, 30, 50, 100, 200]}
    # (start, stop) exponents of the Random sweep for each tier.
    log_bounds = {10: (-1.3, -0.1), 14: (-2, -0.7), 17: (-3., -0.9)}

    def sweep_len(n):
        if n < 100:
            return 10
        if n < 1000:
            return 14
        return 17

    def int_sweep(n):
        # A fresh list per call so BA and Growth never share list objects.
        return base + int_extras[sweep_len(n)]

    grid = {}
    grid["BA"] = {n: int_sweep(n) for n in node_counts}
    grid["Growth"] = {n: int_sweep(n) for n in node_counts}
    grid["Attach"] = {
        n: [n * i for i in np.linspace(0.5, 25, sweep_len(n))] for n in node_counts
    }
    grid["Random"] = {
        n: np.logspace(*log_bounds[sweep_len(n)], sweep_len(n)) for n in node_counts
    }
    return grid


parameter = _build_parameter_grid()
del _build_parameter_grid

nodes = [20, 30, 50, 70, 100, 200, 300, 500, 700, 1000, 2000]

In [36]:
def make_pred_df(model, data, kind, n, p, cnt):
    """Build the per-network prediction label / probability table for one batch.

    Runs ``model`` on ``data`` without gradient tracking, applies a softmax
    over dim 1, and records the arg-max class and its probability for each
    row of the model output.

    Args:
        model: callable returning class scores of shape (batch, n_classes);
            presumably the class order matches ``kind_to_label`` below — TODO confirm.
            NOTE(review): the caller is responsible for putting the model in eval mode.
        data: model input for one batch (already on the model's device).
        kind: generator name, one of "BA", "Attach", "Growth", "Random".
        n: node count of the networks in this batch.
        p: generator parameter value of the networks in this batch.
        cnt: offset of this batch within its dataset; seeds are
            ``10000 + cnt + row_index``.

    Returns:
        DataFrame with one row per network and columns
        seed, node, parameter, kind, true_label, pred, probablility
        (the column-name spelling is kept for compatibility with saved CSVs).

    Raises:
        KeyError: if ``kind`` is not a known generator name.
    """
    kind_to_label = {"BA": 0, "Attach": 1, "Growth": 2, "Random": 3}
    true_label = kind_to_label[kind]

    with torch.no_grad():
        prob = torch.nn.Softmax(1)(model(data))
    # max over classes yields both the winning probability and its index.
    top_prob, top_idx = prob.max(dim=1)
    # Size the seed column from the actual batch: the last batch of a
    # dataset can hold fewer than 10 networks.
    batch_size = prob.shape[0]

    return pd.DataFrame({
        "seed": np.arange(batch_size) + 10000 + cnt,
        "node": n,
        "parameter": p,
        "kind": kind,
        "true_label": true_label,
        "pred": top_idx.cpu().numpy(),
        # tolist() gives Python floats, matching the original per-row .item().
        "probablility": top_prob.cpu().tolist(),
    })


def robust_acc_df(model, resize, kind,
                  node_counts=(20, 30, 50, 70, 100, 200, 300, 500, 700, 1000, 2000)):
    """Collect per-network predictions of ``model`` on the robustness data for one kind.

    For each node count, loads every entry under
    ``./robustness_data_tensor/{kind}/{n}/`` (the entry name is parsed as the
    float parameter value), runs the model in batches of 10, and concatenates
    the per-batch tables produced by ``make_pred_df``.

    Args:
        model: trained model. NOTE(review): should already be in eval mode.
        resize: unused; kept for call compatibility.
        kind: generator name, one of "BA", "Attach", "Growth", "Random".
        node_counts: node counts to evaluate (defaults to the full sweep).

    Returns:
        DataFrame with one row per network and a 0..N-1 index; an empty
        DataFrame if no data entries were found.
    """
    # Imported here so the function does not depend on a later notebook cell.
    from natsort import natsorted

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Random entries are ordered lexicographically and the others naturally,
    # as in the original; this only affects row order.
    sort_paths = sorted if kind == "Random" else natsorted
    frames = []

    for n in tqdm(node_counts):
        for path in sort_paths(glob(f"./robustness_data_tensor/{kind}/{n}/*")):
            p = float(path.split("/")[-1])
            data, _ = gnn_data_loader_cv("a", path=path)
            cnt = 0  # seed offset of the current batch within this entry
            for input_data in DataLoader(data, batch_size=10):
                frames.append(make_pred_df(model, input_data.to(device), kind, n, p, cnt))
                cnt += 10

    # DataFrame.append was removed in pandas 2.0; concatenating once is also linear.
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)

def make_heatmap(dataset_name, resize, model_name, model, epoch, fold, gpu=True):
    """Evaluate ``model`` on the robustness data for every kind and save the predictions to CSV.

    Despite its name, this does not draw a heatmap. It writes the combined
    per-network prediction table, index included, to
    ``./robustness_plot/GNN/{model_name}_{dataset_name}_{epoch}_{fold}.csv``.

    Args:
        dataset_name, model_name, epoch, fold: used only to name the output file.
        resize: forwarded to ``robust_acc_df``.
        model: trained model to evaluate.
        gpu: unused; device selection happens in ``robust_acc_df``.

    Returns:
        The combined prediction DataFrame. Previously None was returned; this is
        backward-compatible for callers that ignore the result.
    """
    import os

    out_dir = "./robustness_plot/GNN"
    os.makedirs(out_dir, exist_ok=True)
    frames = [robust_acc_df(model, resize, kind) for kind in ["BA", "Attach", "Growth", "Random"]]
    # Per-kind indices are kept (no ignore_index), matching the original CSV layout.
    pred_df = pd.concat(frames)
    # Write once after all kinds are done, instead of rewriting the file every iteration.
    pred_df.to_csv(f"{out_dir}/{model_name}_{dataset_name}_{epoch}_{fold}.csv")
    return pred_df

In [37]:
from natsort import natsorted

In [38]:
# Evaluate each saved model (fold 0, epoch 49) trained on every dataset and
# write its robustness prediction CSV.
resize = 100
# Same device rule as robust_acc_df, so the model and its input batches agree.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for model_name in ["GIN", "DGCNN", "Deepsets"]:
    for dataset_name in ["subset1", "poisson", "new_poisson", "new_parete"]:
        print(model_name, dataset_name)
        fold_idx = 0
        for epoch in [49]:
            model_path = f"./GNN_model_save/{model_name}/{dataset_name}_fold{fold_idx}_epoch{epoch}.pth"
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            model = torch.load(model_path, map_location=device)
            # Switch to inference behaviour (matters if the model has dropout / batch-norm layers).
            model.eval()
            make_heatmap(dataset_name, resize, model_name, model, epoch, fold_idx, gpu=True)

  0%|          | 0/11 [00:00<?, ?it/s]

GIN subset1


100%|██████████| 11/11 [00:25<00:00,  2.35s/it]
100%|██████████| 11/11 [00:27<00:00,  2.51s/it]
100%|██████████| 11/11 [00:29<00:00,  2.70s/it]
100%|██████████| 11/11 [01:08<00:00,  6.21s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

GIN poisson


100%|██████████| 11/11 [00:28<00:00,  2.60s/it]
100%|██████████| 11/11 [00:25<00:00,  2.31s/it]
100%|██████████| 11/11 [00:27<00:00,  2.50s/it]
100%|██████████| 11/11 [01:08<00:00,  6.21s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

GIN new_poisson


100%|██████████| 11/11 [00:29<00:00,  2.68s/it]
100%|██████████| 11/11 [00:26<00:00,  2.44s/it]
100%|██████████| 11/11 [00:28<00:00,  2.56s/it]
100%|██████████| 11/11 [01:08<00:00,  6.26s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

GIN new_parete


100%|██████████| 11/11 [00:29<00:00,  2.69s/it]
100%|██████████| 11/11 [00:29<00:00,  2.71s/it]
100%|██████████| 11/11 [00:30<00:00,  2.77s/it]
100%|██████████| 11/11 [01:09<00:00,  6.31s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

DGCNN subset1


100%|██████████| 11/11 [00:32<00:00,  2.94s/it]
100%|██████████| 11/11 [00:28<00:00,  2.57s/it]
100%|██████████| 11/11 [00:32<00:00,  2.92s/it]
100%|██████████| 11/11 [01:10<00:00,  6.39s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

DGCNN poisson


100%|██████████| 11/11 [00:34<00:00,  3.09s/it]
100%|██████████| 11/11 [00:28<00:00,  2.62s/it]
100%|██████████| 11/11 [00:31<00:00,  2.82s/it]
100%|██████████| 11/11 [01:14<00:00,  6.74s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

DGCNN new_poisson


100%|██████████| 11/11 [00:30<00:00,  2.80s/it]
100%|██████████| 11/11 [00:31<00:00,  2.83s/it]
100%|██████████| 11/11 [00:29<00:00,  2.71s/it]
100%|██████████| 11/11 [01:06<00:00,  6.05s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

DGCNN new_parete


100%|██████████| 11/11 [00:32<00:00,  2.98s/it]
100%|██████████| 11/11 [00:27<00:00,  2.53s/it]
100%|██████████| 11/11 [00:28<00:00,  2.55s/it]
100%|██████████| 11/11 [01:07<00:00,  6.14s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

Deepsets subset1


100%|██████████| 11/11 [00:21<00:00,  1.99s/it]
100%|██████████| 11/11 [00:22<00:00,  2.07s/it]
100%|██████████| 11/11 [00:22<00:00,  2.04s/it]
100%|██████████| 11/11 [00:52<00:00,  4.79s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

Deepsets poisson


100%|██████████| 11/11 [00:23<00:00,  2.13s/it]
100%|██████████| 11/11 [00:21<00:00,  1.92s/it]
100%|██████████| 11/11 [00:23<00:00,  2.13s/it]
100%|██████████| 11/11 [00:52<00:00,  4.78s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

Deepsets new_poisson


100%|██████████| 11/11 [00:22<00:00,  2.04s/it]
100%|██████████| 11/11 [00:21<00:00,  1.98s/it]
100%|██████████| 11/11 [00:23<00:00,  2.10s/it]
100%|██████████| 11/11 [00:52<00:00,  4.73s/it]
  0%|          | 0/11 [00:00<?, ?it/s]

Deepsets new_parete


100%|██████████| 11/11 [00:22<00:00,  2.06s/it]
100%|██████████| 11/11 [00:22<00:00,  2.06s/it]
100%|██████████| 11/11 [00:22<00:00,  2.07s/it]
100%|██████████| 11/11 [00:54<00:00,  4.99s/it]
