In [None]:
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib style: small serif text, and thick inward-facing major
# ticks mirrored onto the top and right axes.
plt.rcParams.update({
    "font.size": 7,
    "axes.titlesize": 7,
    "font.family": "serif",
    "lines.linewidth": 1,
    "text.usetex": False,
    "xtick.direction": "in",
    "xtick.top": True,
    "xtick.major.width": 1.5,
    "xtick.major.size": 3,
    "ytick.direction": "in",
    "ytick.right": True,
    "ytick.major.width": 1.5,
    "ytick.major.size": 3,
})

# Shared keyword arguments splatted into the Axes.plot / Axes.bar calls.
line_configs = dict(linewidth=1.5, markersize=5, markeredgewidth=1.5, fillstyle="none")
bar_configs = dict(width=0.8, edgecolor="k", linewidth=1.5)
# Seed NumPy's legacy global RNG so any np.random.* draws are reproducible.
# This must be a call: assigning to np.random.seed would replace the function.
np.random.seed(2024)

In [None]:
# IPython magic: render inline figures with bbox_inches=None, i.e. without
# cropping to the drawn content, so the preview shows the full figure canvas
# including the manually set subplot margins.
%config InlineBackend.print_figure_kwargs = {"bbox_inches": None}

def rgb2hex(rgb):
    """Convert integer color channels to a ``#rrggbb``-style hex string.

    Args:
        rgb: Iterable of integer channel values, each in [0, 255]
            (e.g. ``[189, 55, 82]`` -> ``"#bd3752"``). Every channel is
            encoded, so four values (RGBA) yield ``#rrggbbaa``.

    Returns:
        ``"#"`` followed by two lowercase hex digits per channel.

    Raises:
        TypeError: If a channel is not an integer (e.g. a float).
        ValueError: If a channel is outside [0, 255], which could not be
            encoded in two hex digits.
    """
    import operator

    digits = []
    for num in rgb:
        # operator.index accepts Python/NumPy ints and rejects floats,
        # matching what hex() accepted.
        value = operator.index(num)
        if not 0 <= value <= 255:
            raise ValueError(f"color channel out of range [0, 255]: {num!r}")
        digits.append(f"{value:02x}")
    return "#" + "".join(digits)

In [None]:
def draw_twinx_plot(
    xaxis_values,
    line_values,
    bar_values,
    titles,
    eval_metrics,
    xlabel,
    yperform_ranges,
    yperform_ticks,
    yruntime_ranges,
    yruntime_ticks,
    save_to_pdf=False,
    **kwargs
):
    """Draw a 3x2 grid of panels, each with runtime bars and a metric line on a twin y-axis.

    Panels are indexed by ``fid`` in row-major order over the grid (0..5).
    Panel ``fid`` draws ``bar_values[fid]`` as bars on the left y-axis (time)
    and ``line_values[fid]`` as a marked line on a twin right y-axis. Both
    panels in a grid row share the right-axis label ``eval_metrics[row]``.

    Args:
        xaxis_values: Tick labels for x positions 1..len(xaxis_values).
        line_values: Per-panel sequences plotted as the line (right axis).
        bar_values: Per-panel sequences plotted as bars (left axis).
        titles: Per-panel titles, drawn below each panel.
        eval_metrics: One right-axis label per grid row.
        xlabel: x-axis label text used on every panel.
        yperform_ranges: Per-panel (lo, hi) limits for the right (line) axis.
        yperform_ticks: Per-panel tick positions for the right (line) axis.
        yruntime_ranges: Per-panel (lo, hi) limits for the left (bar) axis.
            If the upper limit exceeds 1000, tick labels are shown in units
            of 10^2 s.
        yruntime_ticks: Per-panel tick positions for the left (bar) axis.
        save_to_pdf: If True, also save the figure as ``<save_dir>/<name>.pdf``.
        **kwargs: ``save_dir`` and ``name`` are required when ``save_to_pdf``
            is True. The optional ``legend_labels`` pair
            (line_label, bar_label) defaults to ("Line", "Bar").
    """
    line_label, bar_label = kwargs.get("legend_labels", ("Line", "Bar"))
    line_color = rgb2hex([189, 55, 82])
    bar_color = rgb2hex([153, 200, 224])
    fig, axs = plt.subplots(figsize=(4, 4), nrows=3, ncols=2)
    n_points = len(xaxis_values)
    xaxis = np.arange(1, 1 + n_points)

    for fid, ax in enumerate(axs.flatten()):
        # Row index of this panel: flatten() is row-major and ncols=2.
        task_id = fid // 2

        bar = ax.bar(xaxis, bar_values[fid], label=bar_label, color=bar_color, alpha=0.5, **bar_configs)
        tax = ax.twinx()
        line = tax.plot(xaxis, line_values[fid], label=line_label, color=line_color,
                        marker='o', **line_configs)

        ax.set_xlim(0, 1 + n_points)
        ax.set_xticks(xaxis)
        ax.set_xticklabels(xaxis_values)
        ax.set_xlabel(xlabel, labelpad=1.5)

        ax.set_ylim(*yruntime_ranges[fid])
        ax.set_yticks(yruntime_ticks[fid])
        if yruntime_ranges[fid][-1] > 1000:
            # Label ticks in hundreds of seconds. True division + "g" keeps
            # non-multiples of 100 exact (1250 -> "12.5"), unlike floor division.
            ax.set_yticklabels([f"{t / 100:g}" for t in yruntime_ticks[fid]])
            ax.set_ylabel(r"Time ($\times 10^2$ s)")
        else:
            ax.set_ylabel("Time (s)")

        tax.set_ylim(*yperform_ranges[fid])
        tax.set_yticks(yperform_ticks[fid])
        tax.set_ylabel(eval_metrics[task_id])

        # A negative y in axes coordinates places the title below the panel.
        ax.set_title(titles[fid], y=-0.75)
        ax.spines[["top", "bottom", "left", "right"]].set_linewidth(1.5)

    # A single figure-level legend, built from the last panel's line and bar artists.
    fig.legend(handles=[line[0], bar], loc="upper center", bbox_to_anchor=(0.51, 1.01),
               fontsize=8, ncols=2, frameon=False)

    plt.tight_layout()
    fig.subplots_adjust(wspace=0.9, hspace=0.9, top=0.93, bottom=0.15)
    if save_to_pdf:
        save_dir = kwargs["save_dir"]
        # Create the output directory if needed; "" means the current directory.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, kwargs["name"] + ".pdf")
        fig.savefig(save_path, dpi=1000)
    plt.show()