In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from ipywidgets import interact

%matplotlib inline

In [14]:
def fix_dataframe(df):
    """Reshape a wide results table into one row per (run, eps column) pair.

    Every column whose name starts with ``"eps"`` holds one AUC value. Its name
    is split on ``"_"``; the second field is the supervision flag (``"s"`` ->
    supervised) and the third field is parsed as the float epsilon. All other
    columns except ``"Unnamed: 0"`` (the leading unnamed column pandas creates
    when a CSV was written with its index) are copied onto every output row
    as run metadata.

    A ``"method"`` column is appended: the run's ``embedding_type`` when it is
    a string, otherwise the upper-cased ``binary`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide results table as read from the results CSV. Must contain
        ``embedding_type`` and ``binary`` columns.

    Returns
    -------
    pandas.DataFrame
        Long-format table with columns ``eps``, ``AUC``, ``supervised``,
        then the metadata columns in their original order, then ``method``.
    """
    kept = [(name, column) for name, column in df.items() if name != "Unnamed: 0"]
    names = [name for name, _ in kept]

    records = []
    for row in zip(*(column for _, column in kept)):
        meta = {}
        eps_cells = []
        for name, cell in zip(names, row):
            if name.startswith("eps"):
                eps_cells.append((name, cell))
            else:
                meta[name] = cell

        # Rows with a non-string (e.g. NaN) embedding_type are labelled by
        # their upper-cased `binary` value instead.
        embedding = meta["embedding_type"]
        meta["method"] = embedding if isinstance(embedding, str) else meta["binary"].upper()

        for name, auc in eps_cells:
            sup_flag, eps_text = name.split("_")[1:3]
            records.append(
                {"eps": float(eps_text), "AUC": auc, "supervised": sup_flag == "s", **meta}
            )

    return pd.DataFrame(records)

# FGSM results CSV plotted below. An earlier results file was
# "csv/results_8_04_2020_readable.csv".
csv_file = "../../results/csv/fgsm_20_04_2020.csv"

# Architectures excluded from every figure and table in this notebook.
EXCLUDED_ARCHITECTURES = ["svhn_lenet_bandw2"]

# Each stage gets its own name so the raw wide table stays inspectable and the
# cell is idempotent; `df` is the filtered long-format table used downstream.
results_raw = pd.read_csv(csv_file)
results_long = fix_dataframe(results_raw)
df = results_long.loc[~results_long["architecture"].isin(EXCLUDED_ARCHITECTURES)]

In [15]:
import os


def filter_column(df, column, value):
    """Return the rows of ``df`` whose ``column`` equals ``value`` or is missing.

    Rows with a null entry in ``column`` are kept rather than dropped
    (presumably so rows where that setting was not recorded survive every
    filter — confirm with the results schema).

    Parameters
    ----------
    df : pandas.DataFrame
    column : str
        Name of the column to test.
    value : object
        Value compared with ``==`` against the column.

    Returns
    -------
    pandas.DataFrame
        The matching rows, original index preserved.
    """
    values = df[column]
    keep = values.isna() | (values == value)
    return df.loc[keep]


@interact
def plot_results(supervised=[True, False],
                 mlp=[False, True]):
    """Plot AUC versus attack strength for one (supervised, mlp) setting and save it.

    Decorated with ``ipywidgets.interact``: the list defaults are the widget
    option lists (rendered as dropdowns). They are never mutated; each widget
    change re-runs this function with the selected scalar values.

    Parameters
    ----------
    supervised : bool
        Keep rows whose ``supervised`` flag equals this value (or is null).
    mlp : bool
        Keep rows whose architecture name contains ``"mlp"`` or
        ``"simple_fcn"`` when True, the remaining architectures when False.

    Side effects
    ------------
    Reads the notebook globals ``df`` and ``csv_file``. Draws a seaborn point
    plot (one facet per architecture, hue = ``method``, y = max AUC per point).
    Saves it in the current working directory as
    ``<basename(csv_file)>_fig_sup=<supervised>_mlp=<mlp>.png`` and prints the
    file name.
    """
    # assign() returns a new frame, so the global `df` is never modified.
    subdf = df.assign(
        mlp=df["architecture"].apply(lambda arch: "mlp" in arch or "simple_fcn" in arch)
    )
    for param, val in (("supervised", supervised), ("mlp", mlp)):
        subdf = filter_column(subdf, param, val)

    hue_order = ["PersistentDiagram", "RawGraph", "LID", "MAHALANOBIS"]
    # estimator=max plots the best AUC among rows sharing (architecture, eps, method).
    # ci=None disables error bars; seaborn >= 0.12 spells this errorbar=None.
    sns.catplot(data=subdf, x="eps", y="AUC", hue="method", kind="point",
                col="architecture", col_wrap=2,
                ci=None, estimator=max, hue_order=hue_order)
    plt.tight_layout()

    fig_file = "%s_fig_sup=%s_mlp=%s.png" % (os.path.basename(csv_file), supervised, mlp)
    plt.savefig(fig_file, dpi=200, bbox_inches="tight")
    print("Saved %s" % fig_file)

interactive(children=(Dropdown(description='supervised', options=(True, False), value=True), Dropdown(descript…

In [16]:
# Spot-check the long-format table. A fixed seed keeps the displayed rows
# stable across re-runs; capping n avoids a ValueError when the table has
# fewer than 200 rows.
df.sample(n=min(200, len(df)), random_state=0)

Unnamed: 0,eps,AUC,supervised,id,embedding_type,architecture,binary,epochs,sigmoidize,aucs_l2_norm,time,method
175,0.01,0.618352,True,9,RawGraph,svhn_lenet,ocsvm_detector,250,False,,7612.683631,RawGraph
20,0.10,0.547040,False,29,PersistentDiagram,fashion_mnist_mlp,ocsvm_detector,50,True,,1912.171264,PersistentDiagram
7,0.01,0.603152,True,21,,fashion_mnist_mlp,mahalanobis,50,False,,8424.356384,MAHALANOBIS
161,0.40,0.992144,True,23,,svhn_lenet,mahalanobis,250,False,,7430.626311,MAHALANOBIS
33,0.10,0.897504,True,30,RawGraph,fashion_mnist_mlp,ocsvm_detector,50,True,,2107.697495,RawGraph
...,...,...,...,...,...,...,...,...,...,...,...,...
199,0.01,0.535800,True,10,PersistentDiagram,cifar_lenet,ocsvm_detector,300,False,,1142.033917,PersistentDiagram
62,0.10,0.323392,False,3,RawGraph,mnist_lenet,ocsvm_detector,50,False,,3696.628605,RawGraph
11,0.40,0.952144,True,21,,fashion_mnist_mlp,mahalanobis,50,False,,8424.356384,MAHALANOBIS
200,0.10,0.749808,False,10,PersistentDiagram,cifar_lenet,ocsvm_detector,300,False,,1142.033917,PersistentDiagram


In [11]:
# `csv_file` is a relative path, so show the directory it is resolved from.
# os.getcwd() avoids shelling out to `pwd` (not available on every platform,
# and `!` syntax only works inside IPython).
import os
os.getcwd()

/home/elvis/CODE/FORKED/TDA_for_adv_robustness/notebooks/paper_figs
