# robustness 検証

ヒートマップ描画用データフレームの作成

In [1]:
import sys
sys.path.append("..")
import torch
import pandas as pd
import numpy as np
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from PIL import Image
from reserch_utils_HT import network_to_image
from models.set_model import CNN_base, D1D2_base
from data.data_loader import cnn_data_loader_cv, set_transform
from torch.utils.data import DataLoader
from matplotlib.cm import ScalarMappable
import matplotlib.colors as colors

In [6]:
def ba_growth_param(n):
    """Return the parameter grid used for the "BA" and "Growth" robustness sets.

    Networks with at least 50 nodes use the full ten-value grid, 20-node
    networks stop at 10, and every other size stops at 20.

    Args:
        n: Number of nodes in the network.

    Returns:
        A new ascending list of ints.
    """
    full_grid = [1, 2, 4, 6, 8, 10, 15, 20, 30, 40]
    if n >= 50:
        keep = len(full_grid)
    elif n == 20:
        keep = 6  # up to and including 10
    else:
        keep = 8  # n == 30 and any other size below 50: up to and including 20
    return full_grid[:keep]

def attach_params(n):
    """Return the "Attach" parameter grid, expressed as multiples of ``n``.

    The multipliers are 0.1, 0.25, 0.5, 0.75, 1, 2, 3, 4, 8, 10 and 15, in that
    order. The fractional multipliers yield floats. The integer multipliers
    keep the type of ``n``.
    """
    multipliers = (.1, .25, .5, .75, 1, 2, 3, 4, 8, 10, 15)
    return [n * m for m in multipliers]

def pred_there(model, data, class_index, thres):
    """Count samples that are both correctly and confidently classified.

    A sample counts when its arg-max class equals ``class_index`` and the
    softmax probability of that class is strictly greater than ``thres``.

    Args:
        model: Callable mapping ``data`` to per-class scores; dim 1 is treated
            as the class axis.
        data: Input batch passed unchanged to ``model``.
        class_index: Index of the expected (ground-truth) class.
        thres: Probability threshold.

    Returns:
        Number of qualifying samples, as a Python int.
    """
    with torch.no_grad():
        probs = torch.softmax(model(data), 1)
        predicted_correctly = probs.argmax(dim=1) == class_index
        confident = probs[:, class_index] > thres
    return (predicted_correctly & confident).sum().item()

def _robust_param_labels(kind, n):
    """Return ``(directory_parameter, column_label)`` pairs for ``kind`` at size ``n``."""
    if kind in ("BA", "Growth"):
        return [(p, p) for p in ba_growth_param(n)]
    if kind == "Attach":
        # Human-readable labels, one per multiplier used by attach_params (same order).
        labels = ["node*0.1", "node*0.25", "node*0.5", "node*0.75", "node",
                  "node*2", "node*3", "node*4", "node*8", "node*10", "node*15"]
        return list(zip(attach_params(n), labels))
    if kind == "Random":
        return [(p, p) for p in [0.01, 0.02, 0.05, 0.07, 0.1, 0.15, 0.2]]
    raise KeyError(kind)


def _load_image_batch(paths, transform, resize):
    """Load every image in ``paths`` into one ``(len(paths), 1, resize, resize)`` tensor."""
    tensors = []
    for path in paths:
        # Close each file promptly; assumes ``transform`` fully reads the image
        # (its result is a tensor, since ``.view`` is called on it).
        with Image.open(path) as img:
            tensors.append(transform(img).view(1, 1, resize, resize))
    # Concatenate once instead of growing the batch with torch.cat per image.
    return torch.cat(tensors, 0)


def robust_acc_df(model, resize, kind, thres=0.7, data_dir="./robustness_data_img"):
    """Build the accuracy table behind one robustness heatmap.

    For every node count ``n`` and every generator parameter ``p`` of ``kind``,
    all images in ``{data_dir}/{kind}/{n}/{p}/`` are classified with ``model``.
    The fraction counted by ``pred_there`` is recorded: predicted class equals
    ``kind`` and that class's softmax probability exceeds ``thres``.

    Args:
        model: Classifier whose output classes are assumed to be indexed
            BA=0, Attach=1, Growth=2, Random=3 (see ``kind_to_index``).
        resize: Side length passed to ``set_transform``. Each image is reshaped
            to ``(1, 1, resize, resize)``.
        kind: One of "BA", "Attach", "Growth", "Random".
        thres: Probability threshold. Defaults to 0.7, the previously
            hard-coded value.
        data_dir: Root directory of the robustness images.

    Returns:
        DataFrame with one row per parameter label and one column per node
        count. Rows are in reverse order of first appearance in the parameter
        grids. Cells for parameters not used at a given node count are NaN.

    Raises:
        KeyError: If ``kind`` is not one of the four known kinds.
        FileNotFoundError: If a parameter directory contains no images.
    """
    kind_to_index = {"BA": 0, "Attach": 1, "Growth": 2, "Random": 3}
    class_index = kind_to_index[kind]
    transform = set_transform(resize)
    node_counts = [20, 30, 50, 70, 100, 130, 200, 300, 500, 1000, 2000]

    rows = []
    for n in tqdm(node_counts):
        acc_dict = {}  # column label -> accuracy for this node count
        for p, label in _robust_param_labels(kind, n):
            directory = f"{data_dir}/{kind}/{n}/{p}"
            paths = sorted(glob(f"{directory}/*"))
            if not paths:
                # Previously this silently reused the previous parameter's batch.
                raise FileNotFoundError(f"no images found in {directory}/")
            data = _load_image_batch(paths, transform, resize)
            # Normalize by the real sample count rather than a hard-coded 50.
            acc_dict[label] = pred_there(model, data, class_index, thres) / len(paths)
        rows.append(acc_dict)

    # Union of labels in first-seen order (small n may use a shorter grid).
    columns = list(dict.fromkeys(label for acc_dict in rows for label in acc_dict))
    df = pd.DataFrame(rows, index=node_counts, columns=columns)
    # Reverse the parameter order, then transpose so rows = parameters, columns = n.
    return df.reindex(columns=df.columns[::-1]).T


def make_heatmap(dataset_name, resize, gpu=True):
    """Evaluate a trained CNN on every robustness set and save the accuracy tables.

    This writes the CSV tables a heatmap is later drawn from, one per kind:
    ``./robustness_plot/acc_df/CNN_{kind}_{dataset_name}_{resize}.csv`` for
    kind in BA, Attach, Growth and Random.

    Args:
        dataset_name: Training-set name; selects the checkpoint directory
            under ``../logs``.
        resize: Input image side length used for both the model and the
            checkpoint path.
        gpu: If False, checkpoint tensors are mapped to CPU when loading, so
            the weights can be read on a machine without CUDA. The model is
            never moved to a GPU in this function.
    """
    import os

    model = CNN_base("CNN", 4, resize)
    model_path = f"../logs/{dataset_name}/CNN/sort_{resize}_0.001/model_weight/fold0_trial0_epoch10.pth"
    # map_location=None is torch.load's default, i.e. the original gpu=True call.
    map_location = None if gpu else torch.device('cpu')
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    # Inference only: switch off training-mode behaviour (e.g. dropout or
    # batch-norm statistic updates) in case CNN_base contains such layers.
    model.eval()

    out_dir = "./robustness_plot/acc_df"
    os.makedirs(out_dir, exist_ok=True)
    for kind in ("BA", "Attach", "Growth", "Random"):
        acc_df = robust_acc_df(model, resize, kind)
        acc_df.to_csv(f"{out_dir}/CNN_{kind}_{dataset_name}_{resize}.csv")

In [8]:
# Every (training dataset, input size) combination, dataset-major order.
runs = [
    (name, size)
    for name in ["subset1", "poisson", "new_poisson", "new_parete"]
    for size in [50, 100, 200]
]
for dataset_name, resize in runs:
    make_heatmap(dataset_name, resize, gpu=True)

100%|██████████| 11/11 [00:10<00:00,  1.09it/s]
100%|██████████| 11/11 [00:11<00:00,  1.01s/it]
100%|██████████| 11/11 [00:09<00:00,  1.11it/s]
100%|██████████| 11/11 [00:07<00:00,  1.42it/s]
100%|██████████| 11/11 [00:18<00:00,  1.72s/it]
100%|██████████| 11/11 [00:20<00:00,  1.90s/it]
100%|██████████| 11/11 [00:18<00:00,  1.70s/it]
100%|██████████| 11/11 [00:14<00:00,  1.30s/it]
100%|██████████| 11/11 [00:52<00:00,  4.80s/it]
100%|██████████| 11/11 [00:58<00:00,  5.29s/it]
100%|██████████| 11/11 [00:51<00:00,  4.68s/it]
100%|██████████| 11/11 [00:38<00:00,  3.49s/it]
100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
100%|██████████| 11/11 [00:11<00:00,  1.02s/it]
100%|██████████| 11/11 [00:09<00:00,  1.10it/s]
100%|██████████| 11/11 [00:07<00:00,  1.39it/s]
100%|██████████| 11/11 [00:18<00:00,  1.71s/it]
100%|██████████| 11/11 [00:21<00:00,  1.93s/it]
100%|██████████| 11/11 [00:18<00:00,  1.71s/it]
100%|██████████| 11/11 [00:14<00:00,  1.30s/it]
100%|██████████| 11/11 [00:51<00:00,  4.