In [1]:
import os
import mlflow
import matplotlib as mlp
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import warnings
import numpy as np

import pandas as pd
# Global plotting/runtime configuration for this notebook session.
# Give every figure an opaque white background.
mlp.rcParams.update({"figure.facecolor": "white"})
# NOTE(review): this silences *all* warnings for the rest of the kernel session
# (pandas/matplotlib deprecations included); consider narrowing to specific categories.
warnings.simplefilter("ignore")

In [2]:
# Experiment selection used by the cells below:
#   dataset_name -- value of the `params.dataset` MLflow param to keep
#   graph_type   -- column of each run's result.csv to plot ("time" switches to a log axis)
#   batch_size   -- presumably labels queried per cycle (the plot steps its x-axis by 16); TODO confirm
dataset_name, graph_type, batch_size = "cifar10", "accuracy", 16

In [None]:
# Load every run of the parameter-study experiment from the MLflow file store.
# The store location can be overridden with the MLFLOW_TRACKING_URI environment variable;
# the default is the original author's absolute path.
tracking_uri = os.environ.get(
    "MLFLOW_TRACKING_URI",
    "file:///mnt/stud/home/jcheng/scikit-activeml/tutorials/mlflow_tracking",
)
mlflow.set_tracking_uri(uri=tracking_uri)

experiment_name = "Evaluation-Active-Learning-Params"
experiment = mlflow.get_experiment_by_name(experiment_name)
# get_experiment_by_name returns None for an unknown name; fail with a clear message
# instead of an AttributeError on `.experiment_id`.
if experiment is None:
    raise ValueError(f"MLflow experiment {experiment_name!r} not found at {tracking_uri}")

all_runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id], output_format="pandas")
all_runs = all_runs[['params.dataset', 'params.qs', 'params.batch_size', 'params.n_cycles', 'params.seed', 'artifact_uri']]

# Keep only runs for the configured dataset AND batch size: the plot's x-axis assumes
# `batch_size` labels per step, so averaging runs of different batch sizes would be wrong.
# MLflow stores params as strings, hence str(batch_size).
df = all_runs.loc[
    (all_runs['params.dataset'] == dataset_name)
    & (all_runs['params.batch_size'] == str(batch_size))
]
assert not df.empty, f"no runs found for dataset={dataset_name!r}, batch_size={batch_size}"

# (sic) name kept as-is because later cells reference it.
query_stragies = df['params.qs'].unique()
# Colour palette indexed by strategy position in `query_stragies`.
colors = ["b", "g", "r", "c", "m", "k"]
# NOTE(review): not referenced in the cells shown; possibly a preferred strategy order.
query_list = [5, 2, 4, 3, 0, 1]

In [None]:
# Plot the per-step mean +/- std of `graph_type` for each query strategy.
# Optional subset of strategy positions (indices into `query_stragies`) to plot;
# leave empty to plot every strategy. (Previously `if idx not in []` skipped ALL of them.)
selected_strategy_indices = []

fig, ax = plt.subplots()

for qs_idx, qs_name in enumerate(query_stragies):
    if selected_strategy_indices and qs_idx not in selected_strategy_indices:
        continue
    # Cycle the palette so more strategies than colours do not raise IndexError.
    color = colors[qs_idx % len(colors)]
    df_qs = df.loc[df['params.qs'] == qs_name]

    run_results = []
    for artifact_uri in df_qs['artifact_uri']:
        # Local file-store URIs look like "file:///path"; strip the scheme to get a path.
        artifact_path = os.path.join(artifact_uri, 'result.csv')
        if artifact_path.startswith("file://"):
            artifact_path = artifact_path[len("file://"):]
        if os.path.exists(artifact_path):
            run_results.append(pd.read_csv(artifact_path, index_col=0))
        else:
            print(f"missing artifact for {qs_name}: {artifact_path}")

    # pd.concat([]) raises ValueError, so skip strategies without any results.
    if not run_results:
        print(f"skipping {qs_name}: no result.csv artifacts found")
        continue

    results = pd.concat(run_results)
    # agg(['mean', 'std']) already names the columns 'mean' and 'std'.
    stats = results.groupby('step')[graph_type].agg(['mean', 'std'])
    result_mean = stats['mean'].to_numpy()
    result_std = stats['std'].to_numpy()
    # Step k (0-based) corresponds to (k + 1) * batch_size queried labels.
    n_labels = np.arange(1, len(result_mean) + 1) * batch_size
    ax.errorbar(n_labels, result_mean, result_std,
                label=f"({np.mean(result_mean):.4f}) {qs_name}", alpha=0.3, color=color)

In [None]:
# Finalise and save the figure built in the previous cell.
# Use the explicit `fig`/`ax` handles: Jupyter's inline backend closes figures at the end
# of each cell, so pyplot state-machine calls here would act on a new, empty figure.
ax.set_xlim(0, 500)
ax.set_xlabel('# Labels queried')
if graph_type == "time":
    # No fixed y-limit here: a [0, 1] range would clip run times and is invalid on a log axis.
    ax.set_yscale("log")
    ax.set_ylabel("Time [s]")
else:
    ax.set_ylim(0, 1)
    ax.set_ylabel("Accuracy")
ax.set_title(dataset_name)
ax.legend(bbox_to_anchor=(0.5, -0.35), loc='lower center', ncol=3)
# Lay out after all labels/legend exist so they are taken into account.
fig.tight_layout()

output_dir = '/mnt/stud/home/jcheng/scikit-activeml/tutorials/result_param'
os.makedirs(output_dir, exist_ok=True)
# NOTE(review): the original suffix referenced an undefined `number` (NameError);
# the configured batch size is used as the distinguishing suffix instead.
output_path = os.path.join(output_dir, f'{dataset_name}_{graph_type}_{batch_size}.pdf')
fig.savefig(output_path, bbox_inches="tight")
fig