# Step Plots

Import libraries and set the default plotting configuration.

In [None]:
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib defaults: Times New Roman at 14 pt, text rendered through
# LaTeX (text.usetex requires a working LaTeX installation), 2 pt lines and
# ticks pointing into the axes on both the x and y axis.
plt.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 14,
    "lines.linewidth": 2,
    "text.usetex": True,
    "xtick.direction": "in",
    "ytick.direction": "in",
})

# Keyword arguments in the form accepted by Axes.tick_params: inward ticks of
# length 4 / width 1 drawn on all four sides of an axes.
tick_params = dict(direction="in", length=4, width=1, bottom=True, top=True, left=True, right=True)

# Marker styling shared by every plotted series: large hollow markers.
line_configs = dict(markersize=12, markeredgewidth=1.5, fillstyle="none")

Prerequisite helper functions.

In [None]:
def rgb2hex(rgb):
    """Convert integer colour channels in [0, 255] to a ``#``-prefixed hex string.

    Every channel becomes exactly two lowercase hex digits, so an RGB triple
    yields ``#rrggbb`` and an RGBA quadruple yields ``#rrggbbaa``.

    Args:
        rgb: Iterable of integer channel values, each in the range [0, 255].

    Returns:
        ``"#"`` followed by the two-digit hex encoding of every channel.

    Raises:
        TypeError: If a channel is not an integer (e.g. a float).
        ValueError: If a channel lies outside [0, 255].
    """
    import operator

    digits = []
    for num in rgb:
        # operator.index accepts int-like values (including NumPy integers) and
        # rejects floats with TypeError, just like hex() did.
        value = operator.index(num)
        # Out-of-range channels would otherwise yield "x1"-style or 3+-digit
        # fragments and therefore an invalid colour string.
        if not 0 <= value <= 255:
            raise ValueError(f"color channel out of range [0, 255]: {num!r}")
        digits.append(f"{value:02x}")
    return "#" + "".join(digits)

Function to draw a 2 × 3 grid of step plots (one row per group, one column per dataset).

In [None]:
def draw_three_dataset_scatters(exp_results: tuple, dataset_names, group_names, legend_labels, save_dir, exp_name):
    """Draw a grid of step plots (rows = groups, columns = datasets) and save it as PNG.

    Every subplot shows three series over the x positions 1..n: a faint grey
    dashed "baseline" line and two step curves labelled ``legend_labels[0]``
    and ``legend_labels[1]``.

    Args:
        exp_results: One entry per row/group. Entry ``r`` is indexed by column,
            and ``exp_results[r][c]`` is a flat sequence of 3*n values laid out
            as ``baseline[0:n] + method_0[n:2n] + method_1[2n:3n]``.
        dataset_names: Column titles; their count sets the number of columns.
        group_names: Row captions, prefixed with a panel letter ("(a) ", "(b) ",
            ...) and written under the middle column of each row.
        legend_labels: Labels of the two method curves.
        save_dir: Output directory; created if it does not exist.
        exp_name: File stem; the figure is written to ``<save_dir>/<exp_name>.png``.

    Raises:
        ValueError: If a dataset's value sequence is empty or its length is not
            a multiple of 3.
    """
    nrows = len(exp_results)
    ncols = len(dataset_names)
    # squeeze=False keeps ``axs`` two-dimensional even for a single row or column.
    fig, axs = plt.subplots(figsize=(3 * ncols, 3 * nrows), nrows=nrows, ncols=ncols, squeeze=False)
    # The row caption goes under the middle column (column 1 of 3).
    caption_col = ncols // 2

    for rid, ax_array in enumerate(axs):
        data = exp_results[rid]
        caption = f"({chr(ord('a') + rid)}) {group_names[rid]}"
        for cid, ax in enumerate(ax_array):
            values = data[cid]
            # Each value list holds three equally long series back to back.
            n_points, remainder = divmod(len(values), 3)
            if n_points == 0 or remainder:
                raise ValueError(
                    f"row {rid}, dataset {dataset_names[cid]!r}: expected a non-empty "
                    f"multiple of 3 values (baseline + 2 methods), got {len(values)}"
                )
            x_cord = list(range(1, n_points + 1))

            ax.plot(x_cord, values[:n_points], "o--", color="grey", alpha=0.3, label="baseline", **line_configs)
            ax.step(x_cord, values[n_points:2 * n_points], label=legend_labels[0], **line_configs)
            ax.step(x_cord, values[2 * n_points:], label=legend_labels[1], **line_configs)

            ax.set_title(dataset_names[cid])
            if cid == caption_col:
                ax.set_xlabel(caption, labelpad=10, fontsize=20)
            ax.set_xticks(x_cord)
            ax.set_xticklabels([f"xlabel{i}" for i in x_cord], rotation=20)
            ax.set_yticks([1, 2, 3, 4, 5])
            if cid == 0:
                ax.set_ylabel("ylabel")

            # Apply the module-level tick style (inward ticks on all four sides).
            ax.tick_params(**tick_params)
            for spine in ax.spines.values():
                spine.set_linewidth(1.5)
            # Labels come from the label= kwargs above, in plotting order.
            ax.legend(loc="lower right", fontsize=10, bbox_to_anchor=(1.3, -0.05))

    # tight_layout computes all subplot spacing itself (it would override an
    # earlier subplots_adjust), so it alone controls the layout.
    fig.tight_layout()
    if save_dir:
        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, exp_name + ".png")
    fig.savefig(save_path, dpi=300)
    plt.show()

Prepare the data and visualize it.

In [None]:
# Output directory plus the names used for column titles, legends and row captions.
save_dir = "../figure"
dataset_names = ["dataset 1", "dataset 2", "dataset 3"]
method_names = ["method A", "method B"]
group_names = ["group 1", "group 2"]

# One flat list per dataset: baseline (5 values) + method A (5) + method B (5).
group1 = [
    [1, 2, 4, 5, 3] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
]
group2 = [
    [1, 2, 3, 4, 5] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
    [1, 2, 4, 3, 5] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
    [1, 2, 3, 5, 4] + [1, 2, 5, 4, 3] + [1, 2, 3, 4, 5],
]

draw_three_dataset_scatters(
    exp_results=(group1, group2),
    dataset_names=dataset_names,
    group_names=group_names,
    legend_labels=method_names,
    save_dir=save_dir,
    exp_name="step_plot",
)