In [None]:
from matplotlib import pyplot as plt
import os
import numpy as np
import seaborn as sns

In [None]:

class Args:
    """Default experiment settings used to locate the result logs.

    The loading cells of this notebook format most of these fields into the
    name of a run directory under ``log_root`` and additionally attach a
    ``name`` attribute to the instance after construction.

    Parameters
    ----------
    **overrides
        Optional ``field=value`` pairs that replace the defaults below, e.g.
        ``Args(model='small', re_calib=True)``.  Only fields defined in
        ``__init__`` are accepted.

    Raises
    ------
    TypeError
        If an override names a field that does not exist (typically a typo).
    """

    def __init__(self, **overrides):
        self.log_root = 'log/'
        self.dataset = 'naval'

        # Objective switches, all off by default.  Their True/False values
        # are embedded verbatim in the run directory name.
        self.train_bias_y = False
        self.train_bias_f = False
        self.train_cons = False
        self.train_calib = False
        self.re_calib = False
        self.re_bias_f = False
        self.re_bias_y = False

        # Modeling parameters
        self.model = 'linear'
        self.learning_rate = 1e-3
        self.batch_size = 32
        self.bbatch_size = 1024
        self.num_bins = 0
        self.knn = 100

        # Run related parameters
        self.gpu = 0
        self.num_epoch = 500
        self.run_label = 0
        self.num_run = 10

        # Apply caller overrides last so they win over the defaults; unknown
        # names are rejected instead of silently creating a new attribute.
        for field, value in overrides.items():
            if not hasattr(self, field):
                raise TypeError('Args got an unexpected option %r' % field)
            setattr(self, field, value)


In [None]:
# Model variants whose logs are loaded.  The position of a name in this list
# is the index used on the first axis of the result buffers.
models = [
    'linear',
    'small',
    'med',
    'big',
]

In [None]:
# Sample axis sizing: room for `max_runs` runs with `num_rep` records each.
max_runs = 10
num_rep = 3

# Zero-initialised result buffers laid out as
#   (model, configuration, run * num_rep + rep, value).
# The 13 configurations are the 8 training-objective combinations followed by
# the 5 recalibration settings that the loading cells store at offset 8.
# fps/fns hold 8 values per record (16 log fields split into pairs);
# objectives holds the 5 leading fields of each record.
fps = np.zeros((len(models), 13, max_runs * num_rep, 8))
fns = np.zeros_like(fps)
objectives = np.zeros(fps.shape[:-1] + (5,))


In [None]:
# Load the logged results of the eight training-objective combinations into
# configuration slots 0..7 of the result buffers.
num_runs = list(range(100, 115))

# The buffers are 4-D: (model, configuration, sample, value).  Fail early with
# an actionable message if the runs requested here cannot fit on the sample
# axis, instead of hitting an IndexError halfway through the files.
capacity = objectives.shape[2]
if len(num_runs) * num_rep > capacity:
    raise ValueError(
        '%d runs x %d reps need %d sample slots but the buffers only have %d; '
        'increase max_runs or shorten num_runs'
        % (len(num_runs), num_rep, len(num_runs) * num_rep, capacity))

for objective in range(8):
    args = Args()
    # The three low bits of `objective` select the training terms:
    # bit 0 -> train_bias_f, bit 1 -> train_bias_y, bit 2 -> train_calib.
    args.train_bias_f = bool(objective & 1)
    args.train_bias_y = bool(objective & 2)
    args.train_calib = bool(objective & 4)

    # Row of the model axis that these runs belong to (the Args default model).
    m_id = models.index(args.model)

    for run_i, run_label in enumerate(num_runs):
        args.run_label = run_label
        args.name = '%s/model=%s-%r-%r-%r-%r-%r-%r-%r-bs=%d-%d-bin=%d-%d-run=%d' % \
            (args.dataset, args.model,
             args.train_bias_y, args.train_bias_f, args.train_cons, args.train_calib, args.re_calib, args.re_bias_f, args.re_bias_y,
             args.batch_size, args.bbatch_size, args.num_bins, args.knn, args.run_label)
        log_dir = os.path.join(args.log_root, args.name)
        results_path = os.path.join(log_dir, 'results.txt')

        # `with` guarantees the handle is released even if parsing fails.
        with open(results_path) as reader:
            # The first five lines of the log are skipped.
            # NOTE(review): presumably early/warm-up records - confirm.
            for _ in range(5):
                reader.readline()

            for rep in range(num_rep):
                line = reader.readline().split()
                if len(line) < 21:
                    raise ValueError(
                        '%s: expected at least 21 fields per record, got %d'
                        % (results_path, len(line)))
                values = [float(token) for token in line[:21]]

                sample = run_i * num_rep + rep
                # Fields 0..4 are the objective values; fields 5..20 are
                # eight pairs whose first element is stored in `fns` and
                # second element in `fps`.
                objectives[m_id, objective, sample] = values[:5]
                fnfp = np.reshape(np.array(values[5:]), [-1, 2])
                fns[m_id, objective, sample] = fnfp[:, 0]
                fps[m_id, objective, sample] = fnfp[:, 1]
            

In [None]:
# Load the logged results of the five recalibration settings, for every model,
# into configuration slots 8..12 of the result buffers.
num_rep = 3
max_line = 1000000  # upper bound on the number of lines read per file
num_runs = list(range(10, 20))
labels = ['none', 'calib', 'bias f', 'bias_y', 'all']
for m_id, model in enumerate(models):
    for objective in range(5):
        args = Args()
        args.model = model
        # objective 0: nothing, 1: re_calib, 2: re_bias_f, 3: re_bias_y,
        # 4: all three recalibration switches together.
        args.re_calib = objective in (1, 4)
        args.re_bias_f = objective in (2, 4)
        args.re_bias_y = objective in (3, 4)

        for run_i, run_label in enumerate(num_runs):
            args.run_label = run_label
            args.name = '%s/model=%s-%r-%r-%r-%r-%r-%r-%r-bs=%d-%d-bin=%d-%d-run=%d' % \
                (args.dataset, args.model,
                 args.train_bias_y, args.train_bias_f, args.train_cons, args.train_calib, args.re_calib, args.re_bias_f, args.re_bias_y,
                 args.batch_size, args.bbatch_size, args.num_bins, args.knn, args.run_label)
            log_dir = os.path.join(args.log_root, args.name)

            # `with` closes each of the many log files as soon as it is read.
            with open(os.path.join(log_dir, 'results.txt')) as reader:
                cur_rep = 0
                for line_id in range(max_line):
                    line = reader.readline().split()
                    # A record with fewer than 21 fields (including the empty
                    # line returned at end of file) ends the scan of this file.
                    if len(line) < 21:
                        break
                    values = [float(token) for token in line[:21]]

                    sample = run_i * num_rep + cur_rep
                    # Fields 0..4 are the objective values; fields 5..20 are
                    # eight pairs whose first element is stored in `fns` and
                    # second element in `fps`.
                    objectives[m_id, objective+8, sample] = values[:5]
                    fnfp = np.reshape(np.array(values[5:]), [-1, 2])
                    fns[m_id, objective+8, sample] = fnfp[:, 0]
                    fps[m_id, objective+8, sample] = fnfp[:, 1]

                    # The target slot rotates after lines 0, 10, 20, ..., and
                    # later records overwrite earlier ones, so each slot keeps
                    # the last record written to it.
                    # NOTE(review): the first rotation happens right after
                    # line 0 - confirm this sampling pattern is intended.
                    if line_id % 10 == 0:
                        cur_rep = (cur_rep + 1) % num_rep


In [None]:
from matplotlib.ticker import FormatStrFormatter

# Font size applied to tick labels and panel titles.
fontsize = 14
# Not referenced by the cells visible in this notebook section.
model_list = [2]
# Panel titles, looked up by metric id when plotting.
names = [
    'train l2',
    'test L2',
    'y-bias',
    'f-bias',
    'marginal calibration',
    'decision loss',
]

def set_axis_style(ax, labels):
    """Make the x-axis ticks of ``ax`` point outward and sit at the bottom.

    Parameters
    ----------
    ax
        Axes-like object exposing ``xaxis``.
    labels
        Accepted for call compatibility; not used by this function.
    """
    x_axis = ax.xaxis
    x_axis.set_tick_params(direction='out')
    x_axis.set_ticks_position('bottom')

# One figure per model: five violin panels comparing the recalibration
# settings stored in configuration slots 8..12 of the result buffers.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('plots', exist_ok=True)

# One colour and one x tick label per recalibration setting (loop invariant).
palette = sns.color_palette('husl', 5)
labels = ['None', 'calib', 'f-bias', 'y-bias', 'all']

for mid, model_name in enumerate(models):
    fig, axes = plt.subplots(1, 5, figsize=(20, 3.4))
    # `item` is the metric id: 1..4 index the last axis of `objectives`,
    # 5 is the panel built from fps/fns; names[item] is the panel title.
    for ax, item in zip(axes, [1, 4, 3, 2, 5]):
        if item != 5:
            # (setting, sample) -> (sample, setting), scaled by 100.
            datalist = objectives[mid, 8:, :, item].transpose(1, 0).reshape(-1, 5) * 100
        else:
            # NOTE(review): `.mean()` without an axis collapses fps to a single
            # scalar that is added to every fns entry, i.e. a constant offset
            # shared by all five violins.  Confirm whether an element-wise
            # `fps + fns` was intended; behaviour is kept as found.
            datalist = (fps[mid, 8:, :, :].mean() + fns[mid, 8:, :, :]).transpose(1, 2, 0).reshape(-1, 5)

        # y-limits: lower bound is the minimum divided by 1.3, upper bound is
        # the fifth-largest value stretched by 10% of its distance to the
        # minimum, which leaves the four largest values out of the bound.
        # NOTE(review): dividing a negative minimum by 1.3 moves the bound
        # towards zero and would clip data - confirm values are non-negative.
        data_range = np.sort(datalist.flatten())
        data_range = [data_range[0] / 1.3, (data_range[-5] - data_range[0])*1.1 + data_range[0]]

        violin_parts = ax.violinplot(datalist,
                                     positions=range(5),
                                     widths=0.8, showmeans=True, showextrema=False)
        for vi, vp in enumerate(violin_parts['bodies']):
            vp.set_facecolor(palette[vi])

        ax.set_xticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, fontsize=fontsize)

        ax.yaxis.set_major_formatter(FormatStrFormatter('%6.2f'))
        # tick_params also covers ticks created after the limits change below.
        ax.tick_params(axis='y', labelsize=fontsize)
        ax.set_ylim(data_range)
        ax.set_title(names[item], fontsize=fontsize)

    fig.tight_layout()
    # NOTE: `args` is whatever instance the loading cell left behind; only its
    # `dataset` field is read here.
    fig.savefig('plots/decision_%s_%s.png' % (args.dataset, model_name))
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)