In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path
import os
import numpy as np
# Global seaborn styling applied to every figure produced below.
sns.set_style("darkgrid")
# NOTE(review): the first positional argument of set_theme is the plotting
# context, so this selects the "paper" context. set_theme also re-applies a
# style of its own, which may make the set_style call above redundant —
# confirm against the seaborn documentation.
sns.set_theme("paper")

In [None]:
def load_results(results_dir, model, runs, suffix=''):
    """Read the per-run JSON result files for both pruning schedules of a model.

    For every run index in ``range(runs)`` two files are opened:

    * ``<results_dir>/<model>/<model>_const/<i>/<model>_constant_sparsity_results<suffix>.json``
    * ``<results_dir>/<model>/<model>_poly/<i>/<model>_polynomial_decay_results<suffix>.json``

    Each file is decoded with ``json.load`` and handed to ``parse_results_file``.

    Args:
        results_dir: Root directory that contains one sub-directory per model.
        model: Model name; used both as a directory name and as a file prefix.
        runs: Number of run sub-directories (``0 .. runs-1``) to read.
        suffix: Optional string inserted just before ``.json`` in the file names.

    Returns:
        A tuple ``(const_results, poly_results, pruning_percentages)``: two lists
        holding one parsed result dict per run, and the pruning percentages
        returned for the last file that was parsed (``None`` when ``runs`` is 0).
    """
    model_root = Path(results_dir) / model
    const_results, poly_results = [], []
    pruning_percentages = None

    # (schedule sub-directory, result file name, list collecting its runs)
    schedules = (
        (f'{model}_const', f'{model}_constant_sparsity_results{suffix}.json', const_results),
        (f'{model}_poly', f'{model}_polynomial_decay_results{suffix}.json', poly_results),
    )

    for run_idx in range(runs):
        for schedule_dir, json_name, collected in schedules:
            json_path = model_root / schedule_dir / str(run_idx) / json_name
            with open(json_path) as handle:
                pruning_percentages, parsed = parse_results_file(json.load(handle))
            collected.append(parsed)

    return const_results, poly_results, pruning_percentages

def parse_results_file(results):
    """Turn one decoded results file into per-variant metric lists.

    Args:
        results: Mapping from a pruning fraction (anything ``float()`` accepts)
            to a mapping with the keys ``"default"``, ``"int8"`` and
            ``"float16"``, each of which holds ``"auc"``, ``"inf_time"`` and
            ``"size"`` values.

    Returns:
        A tuple ``(pruning_percentages, results_dict)``. ``pruning_percentages``
        is a numpy array of the mapping's keys converted to float and
        multiplied by 100, in iteration order. ``results_dict`` has the keys
        ``"baseline"`` (filled from ``"default"``), ``"int8"`` and
        ``"float16"``; each maps ``"aucs"``, ``"inf_time"`` and ``"size"`` to a
        list of floats in the same order.
    """
    # Output variant name -> key used for that variant in the file.
    variant_source = {"baseline": "default", "int8": "int8", "float16": "float16"}
    # Output metric name -> key used for that metric in the file.
    metric_source = {"aucs": "auc", "inf_time": "inf_time", "size": "size"}

    results_dict = {
        variant: {metric: [] for metric in metric_source}
        for variant in variant_source
    }
    fractions = []

    for fraction, entry in results.items():
        fractions.append(float(fraction))
        for variant, file_variant in variant_source.items():
            measured = entry[file_variant]
            for metric, file_metric in metric_source.items():
                results_dict[variant][metric].append(float(measured[file_metric]))

    return np.array(fractions) * 100, results_dict

def get_mean_std(results, quantization, metric):
    """Aggregate one metric of one model variant across runs.

    Args:
        results: Iterable of per-run dicts indexed as
            ``run[quantization][metric]``, each yielding a sequence of numbers.
        quantization: Variant key to read from every run.
        metric: Metric key to read from the selected variant.

    Returns:
        A tuple ``(mean, std)`` computed along axis 0 (i.e. across runs) of the
        stacked per-run sequences; ``std`` uses numpy's default ``ddof``.
    """
    per_run = np.array([np.array(run[quantization][metric]) for run in results])
    return per_run.mean(axis=0), per_run.std(axis=0)
    
def plot_pruning(pruning_percentages, const_results, poly_results, file_name=None):
    """Compare the two pruning schedules on AUC, model size and inference time.

    Draws a 1x3 figure with a shared x axis (pruning percentage). Each panel
    shows, for the ``'baseline'`` variant of both result sets, the mean across
    runs as a line and mean +/- standard deviation (as returned by
    ``get_mean_std``) as a translucent band.

    Args:
        pruning_percentages: x values shared by every curve.
        const_results: Per-run result dicts for the first schedule (legend
            label "Const.").
        poly_results: Per-run result dicts for the second schedule (legend
            label "Poly. Decay").
        file_name: If truthy, the figure is also saved to this path.

    Returns:
        None. The figure is left open for the notebook to display.
    """
    def draw_mean_band(ax, results, metric, scale=1):
        # One curve: mean line plus a +/- one-std band in the line's own colour.
        mean, std = get_mean_std(results, 'baseline', metric)
        mean, std = mean * scale, std * scale
        (line,) = ax.plot(pruning_percentages, mean)
        band = ax.fill_between(pruning_percentages, mean - std, mean + std,
                               facecolor=line.get_color(), alpha=0.3)
        return band, line

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=True, figsize=(12, 3))

    # AUC panel. The (band, line) tuples are used directly as legend handles,
    # so the swatches always match the plotted colours. This replaces the old
    # dummy ax.fill(np.NaN, ...) proxies: np.NaN no longer exists in NumPy 2.x
    # and the proxy colours were hard-coded.
    const_handles = draw_mean_band(ax1, const_results, 'aucs')
    poly_handles = draw_mean_band(ax1, poly_results, 'aucs')
    ax1.set_ylim([0, 1])
    ax1.set_ylabel("AUC Score")
    ax1.legend([const_handles, poly_handles],
               ["Const.", "Poly. Decay"],
               loc='lower left')

    # Model-size panel.
    draw_mean_band(ax2, const_results, 'size')
    draw_mean_band(ax2, poly_results, 'size')
    ax2.set_ylabel("Gzipped Model Size (MB)")

    # Inference-time panel; values are multiplied by 1000 to match the "ms"
    # axis label (assumes inf_time is stored in seconds — TODO confirm).
    draw_mean_band(ax3, const_results, 'inf_time', scale=1000)
    draw_mean_band(ax3, poly_results, 'inf_time', scale=1000)
    ax3.set_ylabel("Inference Time (ms)")

    # Single x-axis caption centred under the three panels.
    fig.text(0.5, 0.0001, 'Pruning Percentage (%)', ha='center')
    fig.tight_layout()

    if file_name:
        fig.savefig(file_name, bbox_inches='tight')
    
def plot_quantization(pruning_percentages, results, file_name=None):
    """Compare the int8 and float16 variants on AUC, model size and inference time.

    Draws a 1x3 figure with a shared x axis (pruning percentage). Each panel
    shows, for the ``'int8'`` and ``'float16'`` variants of ``results``, the
    mean across runs as a line and mean +/- standard deviation (as returned by
    ``get_mean_std``) as a translucent band.

    Args:
        pruning_percentages: x values shared by every curve.
        results: Per-run result dicts containing ``'int8'`` and ``'float16'``
            entries.
        file_name: If truthy, the figure is also saved to this path.

    Returns:
        None. The figure is left open for the notebook to display.
    """
    def draw_mean_band(ax, quantization, metric, scale=1):
        # One curve: mean line plus a +/- one-std band in the line's own colour.
        mean, std = get_mean_std(results, quantization, metric)
        mean, std = mean * scale, std * scale
        (line,) = ax.plot(pruning_percentages, mean)
        band = ax.fill_between(pruning_percentages, mean - std, mean + std,
                               facecolor=line.get_color(), alpha=0.3)
        return band, line

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=True, figsize=(12, 3))

    # AUC panel. The (band, line) tuples are used directly as legend handles,
    # so the swatches always match the plotted colours. This replaces the old
    # dummy ax.fill(np.NaN, ...) proxies: np.NaN no longer exists in NumPy 2.x
    # and the proxy colours were hard-coded.
    int8_handles = draw_mean_band(ax1, 'int8', 'aucs')
    float16_handles = draw_mean_band(ax1, 'float16', 'aucs')
    ax1.set_ylim([0, 1])
    ax1.set_ylabel("AUC Score")
    ax1.legend([int8_handles, float16_handles],
               ["Int8", "Float16"],
               loc='lower left')

    # Model-size panel.
    draw_mean_band(ax2, 'int8', 'size')
    draw_mean_band(ax2, 'float16', 'size')
    ax2.set_ylabel("Gzipped Model Size (MB)")

    # Inference-time panel; values are multiplied by 1000 to match the "ms"
    # axis label (assumes inf_time is stored in seconds — TODO confirm).
    draw_mean_band(ax3, 'int8', 'inf_time', scale=1000)
    draw_mean_band(ax3, 'float16', 'inf_time', scale=1000)
    ax3.set_ylabel("Inference Time (ms)")

    # Single x-axis caption centred under the three panels.
    fig.text(0.5, 0.0001, 'Pruning Percentage (%)', ha='center')
    fig.tight_layout()

    if file_name:
        fig.savefig(file_name, bbox_inches='tight')

def generate_table(percentages, results):
    """Print one LaTeX table row per pruning percentage.

    Every row has ten cells separated by ``&`` and ends with a LaTeX row
    terminator. The first cell is the percentage (one decimal); the other nine
    are "mean plus-minus std" pairs (three decimals) in this order:

        AUC (baseline, int8, float16),
        model size (baseline, int8, float16),
        inference time (baseline, int8, float16).

    Inference-time statistics are multiplied by 1000 before printing (assumes
    inf_time is stored in seconds, giving milliseconds — TODO confirm).

    Args:
        percentages: Iterable of pruning percentages, one per row.
        results: Per-run result dicts as consumed by ``get_mean_std``.

    Returns:
        None. Rows are written to standard output with ``print``.
    """
    # Collect mean/std columns in exactly the order the row template expects.
    columns = []
    for metric in ('aucs', 'size', 'inf_time'):
        for quantization in ('baseline', 'int8', 'float16'):
            mean, std = get_mean_std(results, quantization, metric)
            if metric == 'inf_time':
                mean, std = mean * 1000, std * 1000
            columns.extend((mean, std))

    for (percent,
         auc, auc_dev, auc8, auc8_dev, auc16, auc16_dev,
         cms, cms_dev, cms8, cms8_dev, cms16, cms16_dev,
         it, it_dev, it8, it8_dev, it16, it16_dev) in zip(percentages, *columns):
        # "\\pm" emits the same bytes the old "\p" escape did, without relying
        # on an invalid escape sequence (a SyntaxWarning on recent Python).
        string_to_output = (
            f"{float(percent):.1f}"
            f" & {auc:.3f} $\\pm$ {auc_dev:.3f}"
            f" & {auc8:.3f} $\\pm$ {auc8_dev:.3f}"
            f" & {auc16:.3f} $\\pm$ {auc16_dev:.3f}"
            f" & {cms:.3f} $\\pm$ {cms_dev:.3f}"
            f" & {cms8:.3f} $\\pm$ {cms8_dev:.3f}"
            f" & {cms16:.3f} $\\pm$ {cms16_dev:.3f}"
            f" & {it:.3f} $\\pm$ {it_dev:.3f}"
            f" & {it8:.3f}  $\\pm$ {it8_dev:.3f}"
            f" & {it16:.3f} $\\pm$ {it16_dev:.3f} \\\\"
        )
        print(string_to_output)

## Brogrammers results

In [None]:
# Results directory (relative to the notebook's working directory)
results_dir = '../results/'
# Model name; load_results uses it as a sub-directory and file-name prefix
model = 'brogrammers'
# Number of run sub-directories (0 .. runs-1) that load_results reads
runs = 20

# Load results for both pruning schedules; these names are reused by the
# table and plotting cells that follow.
const_results, poly_results, pruning_percentages = load_results(results_dir, model, runs)

In [None]:
# LaTeX table rows for the constant-sparsity results loaded above
generate_table(pruning_percentages, const_results)

In [None]:
# LaTeX table rows for the polynomial-decay results loaded above
generate_table(pruning_percentages, poly_results)

In [None]:
# Constant vs. polynomial-decay comparison figure, also saved as a PDF
plot_pruning(pruning_percentages, const_results, poly_results, file_name='cnn_pruning_exps_all.pdf')

In [None]:
# Int8 vs. float16 figure; only the polynomial-decay results are plotted
plot_quantization(pruning_percentages, poly_results, file_name='cnn_quant_exps_all.pdf')

## Attention results

In [None]:
# Results directory (relative to the notebook's working directory)
results_dir = '../results/'
# Model name; load_results uses it as a sub-directory and file-name prefix
model = 'attention'
# Number of run sub-directories (0 .. runs-1) that load_results reads
runs = 5

# Load results for both pruning schedules. This rebinds const_results,
# poly_results and pruning_percentages, so the cells below now refer to this
# model rather than the previous one.
const_results, poly_results, pruning_percentages = load_results(results_dir, model, runs)

In [None]:
# LaTeX table rows for the constant-sparsity results loaded above
generate_table(pruning_percentages, const_results)

In [None]:
# LaTeX table rows for the polynomial-decay results loaded above
generate_table(pruning_percentages, poly_results)

In [None]:
# Constant vs. polynomial-decay comparison figure, also saved as a PDF
plot_pruning(pruning_percentages, const_results, poly_results, file_name='cnn_lstm_pruning_exps_all.pdf')

In [None]:
# Int8 vs. float16 figure; only the polynomial-decay results are plotted
plot_quantization(pruning_percentages, poly_results, file_name='cnn_lstm_quant_exps_all.pdf')