# robustness 検証

ヒートマップ作成用のデータフレームの作成

In [3]:
import sys
sys.path.append("..")
import torch
import pandas as pd
import numpy as np
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from PIL import Image
from reserch_utils_HT import network_to_image
from models.set_model import CNN_base, D1D2_base
from data.data_loader import cnn_data_loader_cv, set_transform
from torch.utils.data import DataLoader
from matplotlib.cm import ScalarMappable
from natsort import natsorted
import matplotlib.colors as colors

In [4]:
def load_image_to_tensor(resize, kind, path):
    """Load the images ``{path}/0.png`` .. ``{path}/99.png`` into one tensor.

    Args:
        resize: Side length handed to ``set_transform``; every transformed
            image is viewed as ``(1, 1, resize, resize)``.
        kind: Unused here; kept so existing callers keep working.
        path: Directory that holds the numbered PNG files.

    Returns:
        A tensor of shape ``(100, 1, resize, resize)`` whose rows follow the
        file numbering.
    """
    transform = set_transform(resize)
    batch = []
    for i in range(100):
        # The context manager releases the file handle as soon as the
        # transform has produced its tensor.
        with Image.open(f"{path}/{i}.png") as img:
            batch.append(transform(img).view(1, 1, resize, resize))
    # Concatenate once instead of re-copying the growing batch per image.
    return torch.cat(batch, 0)

def make_pred_df(model, data, kind, n, p, cnt):
    """Build a per-sample table of predicted label and its softmax probability.

    Args:
        model: Callable that maps ``data`` to a 2-D tensor of class scores
            (one row per sample).
        data: Input batch passed to ``model`` unchanged.
        kind: One of ``"BA"``, ``"Attach"``, ``"Growth"``, ``"Random"``;
            stored as-is and also mapped to the integer ``true_label``.
        n: Value written to the ``node`` column.
        p: Value written to the ``parameter`` column.
        cnt: Unused; kept for backward compatibility with existing callers.

    Returns:
        DataFrame with one row per sample and the columns ``seed``, ``node``,
        ``parameter``, ``kind``, ``true_label``, ``pred``, ``probablility``.

    Raises:
        KeyError: If ``kind`` is not one of the four known names.
    """
    kind_to_label = {"BA": 0, "Attach": 1, "Growth": 2, "Random": 3}
    with torch.no_grad():
        pred = torch.nn.Softmax(1)(model(data))
    index = pred.argmax(dim=1)  # predicted class per sample
    # Probability of the predicted class, picked for all rows at once.
    top_prob = pred.gather(1, index.unsqueeze(1)).squeeze(1)

    pred_df = pd.DataFrame()
    # seed is 10000 + row position; sized from the batch rather than a fixed 100.
    pred_df["seed"] = np.arange(pred.shape[0]) + 10000
    pred_df["node"] = n
    pred_df["parameter"] = p
    pred_df["kind"] = kind
    pred_df["true_label"] = kind_to_label[kind]
    # Plain Python numbers keep the column dtypes independent of the pandas
    # version and of the device the tensors live on.
    pred_df["pred"] = index.tolist()
    # NOTE(review): the column name is misspelled, but it is left untouched
    # because it ends up in the CSV and downstream readers may depend on it.
    pred_df["probablility"] = top_prob.tolist()

    return pred_df

def robust_acc_df(model, resize, kind):
    """Collect predictions for every node count and parameter directory of ``kind``.

    For each node count in the fixed list below, every entry matching
    ``./robustness_data_img/{kind}/{n}/*`` is treated as a directory of images
    whose name parses as a float parameter value.

    Args:
        model: Model forwarded to ``make_pred_df``.
        resize: Image side length forwarded to ``load_image_to_tensor``.
        kind: Network kind; selects the data directory and the true label.

    Returns:
        The per-directory tables from ``make_pred_df`` stacked row-wise with
        their original (repeating) row index, or an empty DataFrame when no
        directory matched.

    Raises:
        ValueError: If a directory name cannot be parsed as a float.
    """
    import os  # function-scope: the file's import cell does not import os

    frames = []
    cnt = 0
    for n in tqdm([20,30,50,70,100,200,300,500,700,1000,2000]):
        pattern = f"./robustness_data_img/{kind}/{n}/*"
        # "Random" directories are ordered lexicographically, all other kinds
        # in natural order (kept exactly as in the original implementation).
        if kind == "Random":
            paths = sorted(glob(pattern))
        else:
            paths = natsorted(glob(pattern))
        for path in paths:
            # The directory name is the parameter value.
            p = float(os.path.basename(path))
            data = load_image_to_tensor(resize, kind, path)
            frames.append(make_pred_df(model, data, kind, n, p, cnt))
            cnt += 100
    if not frames:
        return pd.DataFrame()
    # One concat instead of the removed (pandas >= 2.0) DataFrame.append,
    # which also copied the accumulated frame on every iteration.
    return pd.concat(frames)


def make_heatmap(dataset_name, resize, gpu=True):
    """Run the trained CNN over all robustness images and save predictions as CSV.

    Despite the name, no figure is drawn here: the function writes the
    prediction table to ``./robustness_plot/CNN_{dataset_name}_{resize}.csv``.

    Args:
        dataset_name: Name of the training run; selects the weight file under
            ``../logs`` and is part of the output file name.
        resize: Image side length the model was built for.
        gpu: When False, the stored weights are remapped to the CPU while
            loading. The model itself is not moved to another device here.
    """
    # load model
    model = CNN_base("CNN", 4, resize)  # presumably 4 = number of classes; confirm in CNN_base
    model_path = f"../logs/{dataset_name}/CNN/sort_{resize}_0.001/model_weight/fold0_trial0_epoch40.pth"
    map_location = None if gpu else torch.device('cpu')
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    # Switch to inference mode: torch.no_grad() alone does not change the
    # behaviour of layers such as dropout or batch norm, if the model has any.
    model.eval()

    # make prediction dataframe, one chunk per network kind
    frames = [
        robust_acc_df(model, resize, kind)
        for kind in ["BA", "Attach", "Growth", "Random"]
    ]
    # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
    pred_df = pd.concat(frames)
    pred_df.to_csv(f"./robustness_plot/CNN_{dataset_name}_{resize}.csv")

In [5]:
# Export one prediction CSV per (image size, training dataset) pair,
# iterating over sizes in the outer position and datasets in the inner one.
run_grid = [
    (size, name)
    for size in (100, 50, 200)
    for name in ("new_poisson", "subset1", "poisson", "new_parete")
]
for resize, dataset_name in run_grid:
    make_heatmap(dataset_name, resize, gpu=True)

100%|██████████| 11/11 [01:02<00:00,  5.73s/it]
100%|██████████| 11/11 [01:13<00:00,  6.65s/it]
100%|██████████| 11/11 [01:02<00:00,  5.71s/it]
100%|██████████| 11/11 [01:33<00:00,  8.46s/it]
100%|██████████| 11/11 [00:57<00:00,  5.26s/it]
100%|██████████| 11/11 [01:15<00:00,  6.82s/it]
100%|██████████| 11/11 [01:14<00:00,  6.74s/it]
100%|██████████| 11/11 [02:00<00:00, 10.96s/it]
100%|██████████| 11/11 [01:18<00:00,  7.10s/it]
100%|██████████| 11/11 [01:31<00:00,  8.32s/it]
100%|██████████| 11/11 [01:20<00:00,  7.28s/it]
100%|██████████| 11/11 [02:05<00:00, 11.44s/it]
100%|██████████| 11/11 [01:20<00:00,  7.33s/it]
100%|██████████| 11/11 [01:32<00:00,  8.39s/it]
100%|██████████| 11/11 [01:19<00:00,  7.26s/it]
100%|██████████| 11/11 [02:03<00:00, 11.23s/it]
100%|██████████| 11/11 [00:48<00:00,  4.37s/it]
100%|██████████| 11/11 [00:49<00:00,  4.54s/it]
100%|██████████| 11/11 [00:45<00:00,  4.15s/it]
100%|██████████| 11/11 [01:10<00:00,  6.44s/it]
100%|██████████| 11/11 [00:44<00:00,  4.