In [1]:
import os
import mlflow
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import warnings
import numpy as np
import pandas as pd
from matplotlib.animation import FuncAnimation

# Render figures on a white background and silence warnings for the rest of the notebook.
# NOTE(review): this blanket filter also hides matplotlib/pandas warnings that point at real
# plotting problems (e.g. invalid axis limits); consider narrowing it to specific categories.
plt.rcParams.update({"figure.facecolor": "white"})
warnings.simplefilter("ignore")

In [2]:
# Analysis configuration: the dataset whose runs are plotted (matched against
# `params.dataset`), the result column to animate ("accuracy" or "time"), and the
# query batch size (presumably labels acquired per query step — confirm against the runs).
dataset_name, graph_type, batch_size = "cifar10", "accuracy", 16

In [3]:
# Read runs from the local file-based MLflow tracking store.
mlflow.set_tracking_uri(uri="file:///mnt/stud/home/jcheng/scikit-activeml/tutorials/mlflow_tracking")

# Look up the experiment. get_experiment_by_name returns None when no experiment has this
# name, which would otherwise surface later as an opaque AttributeError.
experiment = mlflow.get_experiment_by_name("Evaluation-Active-Learning-Params")
if experiment is None:
    raise ValueError(
        "MLflow experiment 'Evaluation-Active-Learning-Params' was not found in the tracking store."
    )

# Fetch all runs of the experiment as a DataFrame (run parameters appear as 'params.*' columns).
runs_all = mlflow.search_runs(experiment_ids=[experiment.experiment_id], output_format="pandas")

# Keep only the runs for the selected dataset; fail early instead of crashing later while plotting.
df = runs_all.loc[runs_all['params.dataset'] == dataset_name]
if df.empty:
    raise ValueError(f"No runs with params.dataset == {dataset_name!r} found in this experiment.")

Traceback (most recent call last):
  File "/mnt/stud/home/jcheng/miniconda3/envs/scikit-activeml/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 882, in _list_run_infos
    run_info = self._get_run_info_from_dir(r_dir)
  File "/mnt/stud/home/jcheng/miniconda3/envs/scikit-activeml/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 694, in _get_run_info_from_dir
    meta = FileStore._read_yaml(run_dir, FileStore.META_DATA_FILE_NAME)
  File "/mnt/stud/home/jcheng/miniconda3/envs/scikit-activeml/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1303, in _read_yaml
    return _read_helper(root, file_name, attempts_remaining=retries)
  File "/mnt/stud/home/jcheng/miniconda3/envs/scikit-activeml/lib/python3.10/site-packages/mlflow/store/tracking/file_store.py", line 1296, in _read_helper
    result = read_yaml(root, file_name)
  File "/mnt/stud/home/jcheng/miniconda3/envs/scikit-activeml/lib/python3.10/site-packages/mlflow/uti

In [4]:
# Distinct query strategies, in order of first appearance among the runs returned by
# mlflow.search_runs. The index lists below refer to positions in this array.
# NOTE(review): if the run order changes, these indices silently point at different
# strategies; mapping by strategy name would be more robust.
query_strategies = df['params.qs'].unique()

# One matplotlib single-letter colour code per strategy index.
colors = list("bgrcmk")

# Strategy display orders (indices into `query_strategies`); `query_list_time` controls the
# order in which strategies appear in the GIF. `query_list` is not used in this section.
query_list, query_list_time = [5, 2, 4, 3, 0, 1], [5, 4, 2, 0, 3, 1]

In [5]:
def _load_strategy_curve(qs_name, graph_type):
    """Aggregate one query strategy's results across all of its runs.

    Reads ``result.csv`` from the artifact directory of every run in the notebook global
    ``df`` whose ``params.qs`` equals ``qs_name``, and groups the rows by ``step``.

    Returns ``(mean, std)`` NumPy arrays of the ``graph_type`` column per step (steps in
    sorted order), or ``None`` when none of the runs has a ``result.csv``.
    """
    tables = []
    for _, run in df.loc[df['params.qs'] == qs_name].iterrows():
        # artifact_uri is a "file://" URI; strip the scheme to obtain a local path.
        csv_path = os.path.join(run.artifact_uri, 'result.csv').split("file://", 1)[-1]
        if os.path.exists(csv_path):
            tables.append(pd.read_csv(csv_path, index_col=0))
    if not tables:
        return None
    stats = pd.concat(tables).groupby('step')[graph_type].agg(['mean', 'std'])
    return stats['mean'].to_numpy(), stats['std'].to_numpy()


def generate_gif(graph_type, output_path=None, interval=1000):
    """Animate the learning curves of the query strategies into a GIF.

    Frame ``k`` shows the curves of the first ``k`` strategies listed in
    ``query_list_time``. Frame 0 is an empty axes, so strategies appear one at a time.
    Each curve is the per-step mean of the ``graph_type`` column over all runs of a
    strategy, with the per-step standard deviation drawn as error bars.

    Reads the notebook globals ``df``, ``query_strategies``, ``query_list_time``,
    ``colors``, ``dataset_name`` and ``batch_size``.

    Parameters
    ----------
    graph_type : str
        Column of ``result.csv`` to plot. ``"time"`` gives a log-scaled y-axis labelled
        in seconds; any other value is plotted on a 0-1 "Accuracy" axis.
    output_path : str, optional
        Destination of the GIF. Defaults to ``f"{dataset_name}_{graph_type}.gif"``.
    interval : int, optional
        Delay between frames in milliseconds (default 1000).

    Returns
    -------
    str
        The path the GIF was written to.
    """
    if output_path is None:
        output_path = f'{dataset_name}_{graph_type}.gif'

    # Plot order: skip indices this experiment does not have instead of raising
    # IndexError in the middle of the animation.
    order = [i for i in query_list_time if i < len(query_strategies)]

    # Read every result.csv once up front rather than re-reading all shown strategies on every frame.
    curves = {i: _load_strategy_curve(query_strategies[i], graph_type) for i in order}
    for i in order:
        if curves[i] is None:
            print(f"No result.csv found for query strategy {query_strategies[i]!r}; it is left out of the GIF.")

    fig, ax = plt.subplots(figsize=(8, 6))
    # anim.save() renders the whole figure without a tight bounding box, so the legend
    # must lie inside the figure. Reserve the bottom quarter for it.
    fig.subplots_adjust(bottom=0.25)

    def update(frame):
        """Redraw the axes with the curves of the first `frame` strategies in `order`."""
        ax.clear()
        for i in order[:frame]:
            if curves[i] is None:
                continue
            mean, std = curves[i]
            # The j-th aggregated step (0-based) is plotted at (j + 1) * batch_size labels.
            n_labels = np.arange(1, len(mean) + 1) * batch_size
            ax.errorbar(n_labels, mean, std,
                        label=f"({np.mean(mean):.4f}) {query_strategies[i]}",
                        alpha=0.3, color=colors[i])

        # Only draw a legend once there is something to list (frame 0 is empty).
        if ax.get_legend_handles_labels()[1]:
            ax.legend(bbox_to_anchor=(0.5, -0.17), loc='upper center', ncol=3)
        ax.set_xlabel('# Labels queried')
        if graph_type == "time":
            ax.set_yscale("log")
            ax.set_ylabel("Time [s]")
            # A log axis cannot start at 0 (matplotlib ignores that bound with a warning),
            # so only the upper limit is fixed.
            ax.set_ylim(top=1000)
        else:
            ax.set_ylabel("Accuracy")
            ax.set_ylim(0, 1)
        ax.set_title(dataset_name)
        ax.set_xlim(0, 500)

    # One extra frame so the animation starts from an empty axes.
    anim = FuncAnimation(fig, update, frames=range(len(order) + 1), interval=interval)

    # Save the animation as a GIF (the "imagemagick" writer needs ImageMagick installed).
    anim.save(output_path, writer='imagemagick')

    plt.close(fig)
    return output_path


# Build the GIF for the configured metric.
generate_gif(graph_type)

