In [None]:
import numpy as np
from training.training import train
import matplotlib.pyplot as plt
import matplotlib
from datasetsdefer.acs_dataset import generate

def postprocess(max_iter=1, plot=False):
    """Average ``train`` results over repeated runs and drop zero-loss entries.

    Every repetition draws a fresh dataset with ``generate()`` and calls
    ``train(tolerance_space, dataset)`` on a fixed grid of 1000 tolerances
    in [0.01, 0.2]. The returned loss and witness values are averaged over
    the repetitions; the returned standard deviations are combined as the
    square root of the mean of their squares. Entries whose averaged
    ``max_loss`` is (numerically) zero are removed from every output.

    Args:
        max_iter: Number of generate/train repetitions to average over.
            Must be at least 1.
        plot: When true, shade the mean +/- std bands of the loss and of
            the witness values on the current matplotlib axes.

    Returns:
        Tuple ``(tols, max_loss, max_w, max_loss_std, max_w_std)``, all
        restricted to the tolerances that were kept.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    tolerance_space = np.linspace(0.01, 0.2, 1000)

    # Accumulate out of place: the arrays handed back by train are never
    # mutated, and integer-typed results are promoted to float by the
    # 0.0 seed instead of failing on an in-place true division.
    loss_sum = w_sum = loss_var_sum = w_var_sum = 0.0
    for _ in range(max_iter):
        dataset = generate()
        # The first return value of train (named "threshold" by the original
        # code) is not part of this function's result, so it is discarded.
        _, run_loss, run_w, run_loss_std, run_w_std = train(tolerance_space,
                                                            dataset)
        loss_sum = loss_sum + np.asarray(run_loss)
        w_sum = w_sum + np.asarray(run_w)
        # Standard deviations are pooled through their squares.
        loss_var_sum = loss_var_sum + np.asarray(run_loss_std) ** 2
        w_var_sum = w_var_sum + np.asarray(run_w_std) ** 2

    max_loss = loss_sum / max_iter
    max_w = w_sum / max_iter
    max_loss_std = np.sqrt(loss_var_sum / max_iter)
    max_w_std = np.sqrt(w_var_sum / max_iter)

    # remove the 0 accuracies, because it means the tolerance is not achievable
    idx_Z = np.where(np.abs(max_loss) > 1e-8)[0]
    tols = tolerance_space[idx_Z]
    max_loss = max_loss[idx_Z]
    max_loss_std = max_loss_std[idx_Z]
    w_kept = max_w[idx_Z]
    w_std_kept = max_w_std[idx_Z]

    # Optionally plot the accuracy and the witness bands.
    if plot:
        plt.fill_between(tols,
                         max_loss - max_loss_std,
                         max_loss + max_loss_std, alpha=0.5)
        # A 1-D witness array is drawn as a single band; a 2-D one gets one
        # band per column (presumably one column per constraint - confirm
        # against train). Checking ndim instead of isinstance(..., float)
        # also covers float32 and integer data.
        w_cols = np.atleast_2d(w_kept.T)
        std_cols = np.atleast_2d(w_std_kept.T)
        for w_col, std_col in zip(w_cols, std_cols):
            plt.fill_between(tols,
                             np.abs(w_col - std_col),
                             np.abs(w_col + std_col),
                             alpha=0.5)

    return tols, max_loss, w_kept, max_loss_std, w_std_kept


def _style_current_axes():
    """Hide the top/right spines and apply grid and tick sizes to the current axes."""
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    plt.grid()
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=15)


def plot_tabular():
    """Train over a tolerance grid and save the accuracy and constraint plots.

    Calls ``train`` on 100 tolerances in [0.01, 0.2], drops the entries whose
    ``max_loss`` is (numerically) zero, and produces two figures:

    * ``obj.pdf``: ``max_loss`` with a +/- std band against the tolerance.
    * ``consts.pdf``: the witness values (``max_w``) with +/- std bands
      against the tolerance, plus a dashed x = y reference line.

    Both figures are also displayed with ``plt.show()``. The function takes
    no arguments and returns nothing; it changes global matplotlib rcParams
    (fonts, usetex, default figure size).
    """
    tolerance_space = np.linspace(0.01, 0.2, 100)
    # NOTE(review): postprocess() in this file calls train(tolerance_space,
    # dataset) and treats the first return value as a threshold, whereas here
    # train gets the grid only and its first return value replaces the grid.
    # Confirm which signature / return order is the current one.
    tolerance_space, max_loss, max_w, max_loss_std, max_w_std = train(tolerance_space)
    max_w = np.asarray(max_w)
    max_w_std = np.asarray(max_w_std)

    # remove the max_losses that are zero, because they are not achievable
    idx_Z = np.where(np.abs(max_loss) > 1e-8)[0]
    tols = tolerance_space[idx_Z]
    max_loss = max_loss[idx_Z]
    max_loss_std = max_loss_std[idx_Z]
    # Filter the witness arrays with the same index so every curve matches
    # the length of tols (the single-constraint path used to skip this).
    w_kept = max_w[idx_Z]
    w_std_kept = max_w_std[idx_Z]

    # Font type 42 embeds TrueType fonts in the PDF/PS output.
    matplotlib.rcParams['pdf.fonttype'] = 42
    matplotlib.rcParams['ps.fonttype'] = 42
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    # Set the default size before any figure is created; setting it after
    # plotting leaves the already-existing figure at the old size.
    plt.rcParams["figure.figsize"] = [6, 4.2]

    # ---- Figure 1: objective against tolerance ----
    plt.figure()
    plt.fill_between(tols, max_loss - max_loss_std, max_loss + max_loss_std,
                     alpha=0.5)
    plt.plot(tols, max_loss)
    _style_current_axes()
    plt.xlabel("Constraint Tolerance", fontsize=20)
    plt.ylabel("Test Accuracy", fontsize=20)
    plt.savefig('obj.pdf', bbox_inches='tight', dpi=1000)
    plt.show()

    # ---- Figure 2: constraint violation against tolerance ----
    # Start a new figure explicitly so the result does not depend on whether
    # the backend closed the previous one in plt.show().
    plt.figure()
    if w_kept.ndim == 1:
        # Single constraint: one band, no legend.
        plt.fill_between(tols, w_kept - w_std_kept, w_kept + w_std_kept,
                         alpha=0.5)
    else:
        const_names = ["TPR", "TNR"]
        colors = ['blue', 'red']
        abs_w = np.abs(w_kept)
        # One band and one line per column of the witness array.
        for i, (w_col, std_col) in enumerate(zip(abs_w.T, w_std_kept.T)):
            plt.fill_between(tols, w_col - std_col, w_col + std_col,
                             alpha=0.5, color=colors[i])
            plt.plot(tols, w_col, label=const_names[i], color=colors[i])
        plt.legend()
    plt.xlabel("Constraint Tolerance", fontsize=20)
    plt.ylabel("Constraint Violation", fontsize=20)
    # Style the axes of THIS figure; the axes object of figure 1 is stale.
    _style_current_axes()
    # draw x=y line using dotted
    plt.plot(tolerance_space, tolerance_space, 'k--')
    plt.savefig('consts.pdf', bbox_inches='tight', dpi=1000)
    plt.show()

# Entry point of this cell: run training and write 'obj.pdf' and 'consts.pdf'.
plot_tabular()