In [1]:
import json
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from glob import glob
from pathlib import Path
# Shared plotly styling: every figure below uses this layout template.
template = "plotly_white"
# Qualitative palette indexed per curve by `plot`; its entries are assumed to be
# "rgb(r, g, b)" strings, since `plot` rewrites them into "rgba(...)" fill colors.
colors = px.colors.qualitative.Vivid

In [2]:
# NOTE(review): absolute, user-specific path; consider making this configurable.
root_dir = Path("/Users/andrew/Desktop/exp")


def get_data(pattern, x_axis, root=None):
    """Average accuracy curves over all runs whose directory matches a glob pattern.

    Each directory matching ``pattern`` under ``root`` is expected to contain a
    ``results.json`` file holding a list of records with ``"perc_labeled"`` and
    ``"accuracy"`` fractions. Both are scaled to percentages. Each run's curve is
    linearly interpolated onto ``x_axis``, and the per-point mean and population
    standard deviation (ddof=0) across runs are returned.

    Args:
        pattern: glob pattern, relative to ``root``, selecting run directories.
        x_axis: 1-D sequence of labeled-data percentages to evaluate at.
        root: directory to search; defaults to the module-level ``root_dir``.

    Returns:
        ``(y_mean, y_std)``: float arrays with one entry per ``x_axis`` point.
        If no matching run has usable results, both arrays are all-NaN and a
        message is printed.
    """
    base = root_dir if root is None else Path(root)
    x_axis = np.asarray(x_axis, dtype=float)

    curves = []
    for run_dir in base.glob(pattern):
        path = run_dir / "results.json"
        if not path.exists():
            print("Cannot find file", path)
            continue
        # Use a context manager so the file handle is closed promptly.
        with path.open() as f:
            data = json.load(f)
        if not data:
            print("Empty results in", path)
            continue
        x = np.array([i["perc_labeled"] * 100 for i in data])
        y = np.array([i["accuracy"] * 100 for i in data])
        # Linear interpolation onto the shared grid. This assumes x is increasing
        # within a run (TODO confirm). Queries below x[0] or above x[-1] are
        # clamped to the end values, so a run that stops short of 100% no longer
        # raises IndexError.
        curves.append(np.interp(x_axis, x, y))

    if not curves:
        print("No results found for pattern", pattern)
        empty = np.full(x_axis.shape, np.nan)
        return empty, empty.copy()

    stacked = np.vstack(curves)  # shape (n_runs, len(x_axis))
    return stacked.mean(axis=0), stacked.std(axis=0)

In [3]:
def plot(patterns, names, title="", filename=None, opacity=0.1, x_init=10, n_points=30):
    """Plot mean +/- std accuracy bands for several groups of experiment runs.

    For each ``(pattern, name)`` pair, ``get_data`` aggregates the matching runs
    on ``n_points`` evenly spaced labeled-data percentages in ``[x_init, 100]``.
    Each group is drawn as two line traces (mean - std and mean + std) with the
    area between them filled. No separate mean line is drawn.

    Args:
        patterns: glob patterns passed to ``get_data``, one band per pattern.
        names: legend labels, paired positionally with ``patterns``.
        title: figure title.
        filename: if given, the figure is also written there with
            ``fig.write_image``. Missing parent directories are created.
        opacity: alpha of the band fill.
        x_init: left end of the x axis (percent of data labeled).
        n_points: number of x positions at which the curves are evaluated.
    """
    x_axis = np.linspace(x_init, 100, n_points)
    y_min = []
    y_max = []
    fig = go.Figure()
    for i, (pattern, name) in enumerate(zip(patterns, names)):
        y_mean_list, y_std_list = get_data(pattern, x_axis)
        lower = y_mean_list - y_std_list
        upper = y_mean_list + y_std_list
        # nan-aware so a group with no runs (all-NaN) cannot corrupt the y-range.
        y_min.append(np.nanmin(lower))
        y_max.append(np.nanmax(upper))

        # Cycle the palette so more groups than colors does not raise IndexError.
        color = colors[i % len(colors)]
        # "rgb(r, g, b)" -> "rgba(r, g, b, opacity)" for the translucent band fill.
        fill_rgba = color.replace('rgb', 'rgba').replace(')', ', {})'.format(opacity))
        # Both edges share a legendgroup so the single legend entry toggles the
        # whole band, not just the upper edge.
        group = "group{}".format(i)

        # Lower edge: drawn first, no fill, hidden from the legend.
        fig.add_trace(go.Scatter(x=x_axis,
                                 y=lower,
                                 name=name,
                                 mode="lines",
                                 line_color=color,
                                 fillcolor=fill_rgba,
                                 fill=None,
                                 legendgroup=group,
                                 showlegend=False))
        # Upper edge: "tonexty" fills down to the previously added (lower) trace.
        fig.add_trace(go.Scatter(x=x_axis,
                                 y=upper,
                                 name=name,
                                 mode="lines",
                                 line_color=color,
                                 fillcolor=fill_rgba,
                                 fill="tonexty",
                                 legendgroup=group,
                                 showlegend=True))

    fig.update_xaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True, range=[x_init, 100])
    # Clamp to [0, 100] and add 10 points of headroom above the highest band.
    fig.update_yaxes(showline=True, linewidth=1.5, linecolor='Black', mirror=True,
                     range=[max(0, np.nanmin(y_min)), min(100, np.nanmax(y_max) + 10)])

    fig.update_layout(width=500,
                      height=400,
                      font=dict(size=15),
                      margin=dict(t=40, b=40, l=0, r=0),
                      legend=dict(x=0.98,
                                  y=0.05,
                                  yanchor="bottom",
                                  xanchor="right",
                                  bgcolor='rgba(0,0,0,0)',
                                  bordercolor="Black",
                                  borderwidth=1),
                      title=title,
                      title_x=0.5,
                      xaxis_title="#labeled data (%)",
                      yaxis_title="Accuracy (%)",
                      template=template)

    if filename:
        # write_image fails if the target directory (e.g. "plots/") is missing.
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        fig.write_image(filename)
    fig.show()

In [57]:
# Glob patterns selecting run directories under `root_dir`. Directory names
# encode the run configuration as "<seed>-key_value-key_value-...". A seed of
# "*" matches every seed, so each pattern gathers all repetitions of a config.
seed = "*"
N = 100
K = 10  # number of classes
sigma = 3  # 3, 5
alpha = 2  # 1, 2, 3

# Prefix shared by every run, and the extra prefix shared by all CVXSampler runs.
_base_prefix = "{}-N_{}-datatype_blob-K_{}-learner_SVMLearner".format(seed, N, K)
_cvx_prefix = "{}-sampler_CVXSampler-sigma_{}-alpha_{}".format(_base_prefix, sigma, alpha)


def _cvx_pattern(confidence, clustering, suffix=""):
    """Build a CVXSampler run pattern for a confidence/clustering combination."""
    return "{}-confidence_type_{}-clustering_type_{}{}".format(
        _cvx_prefix, confidence, clustering, suffix)


# Baselines (the trailing "*" also matches any further config keys).
random = _base_prefix + "-sampler_RandomSampler*"
optimal = _base_prefix + "-sampler_OptimalSampler*"

# CVXSampler using the learner's own confidence.
argmax = _cvx_pattern("learner", "none")
spectral = _cvx_pattern("learner", "spectral")

# CVXSampler with the `perfect_*_prob` confidence variants.
argmax_perfect_onehot = _cvx_pattern("perfect_onehot_prob", "none")
argmax_perfect_distributional = _cvx_pattern("perfect_distributional_prob", "none")
spectral_perfect_onehot = _cvx_pattern("perfect_onehot_prob", "spectral")
spectral_perfect_distributional = _cvx_pattern("perfect_distributional_prob", "spectral")

# CVXSampler with the `diversity_type_optimal` option.
argmax_perfect_diversity = _cvx_pattern("learner", "none", "-diversity_type_optimal")
spectral_perfect_diversity = _cvx_pattern("learner", "spectral", "-diversity_type_optimal")

In [55]:
# Baseline comparison: RandomSampler vs. the greedy and clustered CVX samplers.
plot(
    patterns=[random, argmax, spectral],
    names=["RandomSampler", "GreedyCVXSampler", "ClusterCVXSampler"],
    title="Number of classes = {}".format(K),
    filename="plots/cvx_random_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)

In [58]:
# Baseline comparison: OptimalSampler vs. the greedy and clustered CVX samplers.
plot(
    patterns=[optimal, argmax, spectral],
    names=["OptimalSampler", "GreedyCVXSampler", "ClusterCVXSampler"],
    title="Number of classes = {}".format(K),
    filename="plots/cvx_optimal_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)

In [14]:
# Greedy CVX sampler vs. the same sampler with the `perfect_*_prob` confidence variants.
plot(
    [argmax, argmax_perfect_onehot, argmax_perfect_distributional],
    ["GreedyCVXSampler", "GreedyCVXSampler+OnehotProb*", "GreedyCVXSampler+Prob*"],
    # Title added so this figure matches the titled baseline figures.
    title="Number of classes = {}".format(K),
    filename="plots/cvx_greedy_perfect_uncertainty_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)

In [15]:
# Clustered CVX sampler vs. the same sampler with the `perfect_*_prob` confidence variants.
plot(
    [spectral, spectral_perfect_onehot, spectral_perfect_distributional],
    ["ClusterCVXSampler", "ClusterCVXSampler+OnehotProb*", "ClusterCVXSampler+Prob*"],
    # Title added so this figure matches the titled baseline figures.
    title="Number of classes = {}".format(K),
    filename="plots/cvx_cluster_perfect_uncertainty_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)

In [16]:
# Greedy CVX sampler vs. the same sampler with `diversity_type_optimal`.
plot(
    [argmax, argmax_perfect_diversity],
    ["GreedyCVXSampler", "GreedyCVXSampler+Diversity*"],
    # Title added so this figure matches the titled baseline figures.
    title="Number of classes = {}".format(K),
    filename="plots/cvx_greedy_perfect_diversity_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)

In [17]:
# Clustered CVX sampler vs. the same sampler with `diversity_type_optimal`.
plot(
    [spectral, spectral_perfect_diversity],
    ["ClusterCVXSampler", "ClusterCVXSampler+Diversity*"],
    # Title added so this figure matches the titled baseline figures.
    title="Number of classes = {}".format(K),
    filename="plots/cvx_cluster_perfect_diversity_N={}_K={}_sigma={}_alpha={}.pdf".format(N, K, sigma, alpha),
    opacity=0.3,
)