In [3]:
%reload_ext autoreload
%autoreload 2

In [4]:
from dataclasses import dataclass

import numpy as np
import pandas as pd
from depsurf import OUTPUT_PATH, BuildVersion, DepKind
from depsurf.plot import bold, plot_yticks, get_text_height, get_legend_handles_labels, save_fig
from matplotlib import pyplot as plt
from matplotlib import transforms


def plot_legend(fig: plt.Figure):
    """Add a shared, frameless legend near the top centre of ``fig``.

    Handles and labels come from ``get_legend_handles_labels(fig)``. The
    entries are spread over ``len(labels) // 2`` columns, so an even number
    of entries is laid out in two rows.

    Args:
        fig: Figure that receives the legend.
    """
    handles, labels = get_legend_handles_labels(fig)
    fig.legend(
        handles,
        labels,
        loc="upper center",
        # With fewer than two entries the integer division yields 0, which
        # is not a usable column count; fall back to a single column.
        ncol=max(1, len(labels) // 2),
        bbox_to_anchor=(0.5, 0.9375),
        frameon=False,
    )


# Maps each version-group key to the x-axis label shown under that group's
# column of subplots. The keys are the values accepted as
# ``SubfigPlotter.group``; the insertion order is also the left-to-right
# column order used when the figure grid is filled in below.
group_labels = {
    "lts": "Kernel Versions w/ LTS (2 yr)",
    "all": "Kernel Versions w/ Regular Releases (6 mo)",
    # ``bold`` comes from depsurf.plot; presumably it wraps the value in
    # bold text markup -- confirm against its definition.
    "rev": f"Revisions for Kernel {bold(5.4)}",
}


@dataclass
class SubfigPlotter:
    """Draw one stacked-bar subplot of changes between consecutive versions.

    Each row of ``df`` becomes one bar. Every column except ``"Old"`` is a
    stacked segment of that bar; ``"Old"`` is only used as the denominator
    for the percentage labels printed at the top of the axes.
    """

    # Axes the subplot is drawn into.
    ax: plt.Axes
    # One row per bar. The index holds (old, new) version-string pairs in
    # which each pair's second item equals the next pair's first item. The
    # columns must include "Old", "Added" and "Removed"; every non-"Old"
    # column label must expose a ``color`` attribute (used for the bars).
    df: pd.DataFrame
    # Key into ``group_labels``: "lts", "all" or "rev".
    group: str
    # Whether to draw x tick labels and the x-axis label for this subplot.
    show_xlabels: bool
    # Font size used for the in-bar and top-of-axes annotations.
    fontsize: int = 8

    @property
    def columns(self):
        """Columns drawn as stacked segments (everything except ``"Old"``)."""
        return self.df.columns.drop(["Old"])

    @property
    def num_bars(self):
        """Number of bars, i.e. number of rows in ``df``."""
        return len(self.df.index)

    def plot(self):
        """Render bars, ticks and annotations onto ``self.ax``."""
        self.plot_bar()
        self.plot_xticks()
        plot_yticks(self.ax)
        # Annotations use a monospace font at ``self.fontsize``; the rc
        # override is scoped so that tick and axis labels are unaffected.
        with plt.rc_context(
            {
                "font.family": "monospace",
                "font.size": self.fontsize,
            }
        ):
            self.plot_val_labels()
            self.plot_top_labels()

    def plot_bar(self, x_pad=0.5):
        """Draw the stacked bars and fix the axes limits.

        Args:
            x_pad: Horizontal margin, in bar units, left of the first bar
                and right of the last one.
        """
        bottom = np.zeros(self.num_bars)
        xs = np.arange(self.num_bars)
        for col in self.columns:
            self.ax.bar(xs, self.df[col], label=col, bottom=bottom, color=col.color)
            bottom += self.df[col].to_numpy()

        self.ax.set_xlim(-x_pad, self.num_bars - 1 + x_pad + 0.01)
        # 15% headroom above the tallest stack leaves room for the labels
        # that ``plot_top_labels`` anchors at the top of the axes.
        self.ax.set_ylim(0, bottom.max() * 1.15)

    @staticmethod
    def flatten_pairs(pairs):
        """Collapse chained pairs ``(a, b), (b, c), ...`` into ``(a, b, c, ...)``.

        Args:
            pairs: Non-empty iterable of 2-tuples in which every pair's
                second item equals the following pair's first item.

        Returns:
            Tuple with one more element than there are pairs.

        Raises:
            ValueError: If consecutive pairs do not chain.
        """
        fst, snd = list(zip(*pairs))
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if fst[1:] != snd[:-1]:
            raise ValueError("pairs do not chain: each pair must start where the previous one ends")
        return fst + (snd[-1],)

    def plot_xticks(self):
        """Place version labels on the boundaries between bars."""
        # Bar i is centred on x == i, so its edges sit at i - 0.5 and i + 0.5;
        # N bars therefore have N + 1 boundaries, one per version.
        xs = np.arange(self.num_bars + 1) - 0.5

        pairs = self.flatten_pairs(self.df.index)
        versions = [BuildVersion.from_str(v) for v in pairs]
        if self.group == "rev":
            labels = [v.revision for v in versions]
            self.ax.set_xticks(xs, labels)
        elif self.group == "lts":
            labels = [v.short_version for v in versions]
            self.ax.set_xticks(xs, labels, fontweight="bold")
        elif self.group == "all":
            # LTS versions become bold major ticks, the others normal-weight
            # minor ticks.
            for lts in [True, False]:
                xs_labels = [
                    (x, v.short_version) for x, v in zip(xs, versions) if v.lts == lts
                ]
                # ``zip(*[])`` yields nothing, which would call set_xticks()
                # without its required argument; pass empty sequences instead.
                ticks, labels = zip(*xs_labels) if xs_labels else ((), ())
                self.ax.set_xticks(
                    ticks,
                    labels,
                    minor=not lts,
                    fontweight="bold" if lts else "normal",
                )
            self.ax.grid(which="major", axis="x", linestyle="--", linewidth=1)

        self.ax.tick_params(axis="both", which="both", length=4, labelsize=9)

        if not self.show_xlabels:
            self.ax.set_xticklabels([])
            self.ax.set_xticklabels([], minor=True)
        else:
            self.ax.set_xlabel(group_labels[self.group])

    @staticmethod
    def _format_count(val):
        """Format a segment value, abbreviating thousands as e.g. ``1.2k``."""
        if val >= 1000:
            scaled = val / 1000
            text = f"{scaled:.2g}"
            # ".2g" switches to scientific notation once the value needs
            # three integer digits ("1e+02"); print a plain integer instead.
            if "e" in text:
                text = f"{scaled:.0f}"
            return f"{text}k"
        return str(val)

    def plot_val_labels(self):
        """Write each segment's value at its centre, if the segment is tall enough."""
        text_height = get_text_height(self.ax)

        trans = self.ax.transData
        # Display-space y of the data origin; identical for every bar.
        y_origin = trans.transform((0, 0))[1]
        bottom = np.zeros(self.num_bars)
        for col in self.columns:
            for i, v in enumerate(self.df[col]):
                bar_height = trans.transform((0, v))[1] - y_origin
                # Skip segments too short to hold a line of text.
                if bar_height < text_height * 0.9:
                    continue
                h = v / 2 + bottom[i]
                self.ax.text(i, h, self._format_count(v), ha="center", va="center")
            bottom += self.df[col].to_numpy()

    def _format_delta(self, value, old):
        """Format a top label: a raw count for "rev", else a share of ``old``.

        When ``old`` is zero a share is undefined, so the raw count is shown
        in that case as well.
        """
        if self.group == "rev" or old == 0:
            return f"{value:0.0f}"
        return f"{value / old:0.0%}"

    def plot_top_labels(self):
        """Stack changed / removed / added summaries at the top of each bar's slot.

        Zero-valued entries are omitted, so a row without any change gets no
        label at all.
        """
        trans = self.ax.transData
        ymax = self.ax.get_ylim()[1]
        for i, (_, row) in enumerate(self.df.iterrows()):
            old = row["Old"]
            total = sum(row) - old

            added = row["Added"]
            removed = row["Removed"]
            # Everything that is neither added nor removed counts as changed.
            changed = total - added - removed

            entries = [
                (r"\Delta", changed, "darkgreen"),
                (r"\minus", removed, "xkcd:dark orange"),
                (r"\plus", added, "blue"),
            ]
            for slot, (k, v, c) in enumerate(entries):
                if v == 0:
                    continue
                # Every entry keeps its own line even when an earlier one is
                # skipped: line ``slot`` is shifted down by ``slot`` font sizes.
                y = -slot * self.fontsize
                self.ax.text(
                    i,
                    ymax,
                    f"${k}${self._format_delta(v, old)}",
                    ha="center",
                    va="top",
                    transform=transforms.offset_copy(trans, y=y, units="dots"),
                    color=c,
                )


# 3x3 grid: one row per dependency kind, one column per version group.
fig, axs = plt.subplots(
    nrows=3,
    ncols=3,
    figsize=(12, 12),
    width_ratios=[4, 16, 5],
    gridspec_kw={"wspace": 0.15, "hspace": 0.075},
)

# NOTE: unpickling executes arbitrary code; only load files this project wrote.
df = pd.read_pickle(OUTPUT_PATH / "diff.pkl").T

kinds = (DepKind.FUNC, DepKind.STRUCT, DepKind.TRACEPOINT)
last_row = len(axs) - 1

for row_idx, (ax_rows, kind) in enumerate(zip(axs, kinds)):
    df_kind = df[kind]
    # Columns follow the insertion order of ``group_labels``.
    for ax, group in zip(ax_rows, group_labels):
        plotter = SubfigPlotter(
            ax=ax,
            df=df_kind.loc[group],
            group=group,
            # Only the bottom row carries x tick labels and the axis label.
            show_xlabels=(row_idx == last_row),
        )
        plotter.plot()

    # The y-axis label goes on the left-most subplot of each row only.
    ax_rows[0].set_ylabel(f"Number of {bold(kind.capitalize())} Changes")

plot_legend(fig)
save_fig(fig, "diff")

[    utils.py:44 ] INFO: Saved figure to /Users/szhong/Downloads/bpf-study/paper/figs/diff.pdf
