In [None]:
import hashlib
import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import hsv_to_rgb
from matplotlib.font_manager import FontProperties, fontManager

In [None]:
# Directory that holds the Libertinus OTF files. It defaults to the current
# user's ~/Library/Fonts folder instead of one specific user's absolute path;
# set the LIBERTINUS_FONT_DIR environment variable to point somewhere else.
libertinus_font_dir = os.environ.get("LIBERTINUS_FONT_DIR", os.path.expanduser("~/Library/Fonts"))

# Register Libertinus Serif font
libertinus_serif_path = os.path.join(libertinus_font_dir, "LibertinusSerif-Regular.otf")
fontManager.addfont(libertinus_serif_path)
libertinus_serif = FontProperties(fname=libertinus_serif_path)

# Register Libertinus Mono font
libertinus_mono_path = os.path.join(libertinus_font_dir, "LibertinusMono-Regular.otf")
fontManager.addfont(libertinus_mono_path)
libertinus_mono = FontProperties(fname=libertinus_mono_path)

# Register Libertinus Math font
libertinus_math_path = os.path.join(libertinus_font_dir, "LibertinusMath-Regular.otf")
fontManager.addfont(libertinus_math_path)
libertinus_math = FontProperties(fname=libertinus_math_path)

## Label cardinality

In [None]:
# Per-cardinality counts, ordered by the number of labels per snippet.
df = pd.read_csv("../reports/label_cardinality.csv")
df = df.sort_values("num_labels")
total = df["count"].sum()
y_pos = np.arange(len(df))

# Global matplotlib defaults: Libertinus Math as the font family, light dashed grid.
plt.rcParams["font.family"] = libertinus_math.get_name()
plt.rcParams["grid.color"] = "#DDD"
plt.rcParams["grid.linestyle"] = "--"
plt.rcParams["grid.linewidth"] = 0.5

fig, ax = plt.subplots(figsize=(6, 4))

# A full-width grey track behind every row, with the actual count drawn on top.
bar_height = 0.6
ax.barh(y_pos, [total] * len(df), color="#F2F2F2", height=bar_height, zorder=0)
bars = ax.barh(y_pos, df["count"], color="#0077B6", height=bar_height, zorder=1)

# Bars wider than 12% of the total carry their label inside (white, centred);
# narrower bars get a dark label just to the right of the bar end.
inside_threshold = total * 0.12
outside_gap = total * 0.01

for patch, count in zip(bars, df["count"], strict=False):
    share = count / total * 100
    width = patch.get_width()
    fits_inside = width > inside_threshold
    ax.text(
        width / 2 if fits_inside else width + outside_gap,
        patch.get_y() + patch.get_height() / 2,
        rf"$\mathbf{{{count:,}}}$" + f"\n({share:.1f}%)",
        va="center",
        ha="center" if fits_inside else "left",
        color="white" if fits_inside else "#333",
        fontproperties=libertinus_math,
        fontsize=11,
    )

# One tick per row, labelled with its cardinality; the first row is shown on top.
ax.set_yticks(y_pos)
ax.set_yticklabels(df["num_labels"].astype(str), fontproperties=libertinus_math)
ax.invert_yaxis()
ax.set_xlim(0, total)


def round50(x):
    """Snap ``x`` to the nearest multiple of 50 and return it as an ``int``.

    Rounding is delegated to the built-in ``round``, so a value exactly halfway
    between two multiples follows its ties-to-even behaviour (25 -> 0, 75 -> 100).
    """
    step = 50
    steps = round(x / step)
    return int(steps * step)


# Tick positions: the 25/50/75% marks of the total, each snapped to a multiple
# of 50 by round50, followed by the exact total.
quartiles = [round50(total * fraction) for fraction in (0.25, 0.50, 0.75)]
quartiles.append(total)

ax.set_xticks(quartiles)
ax.set_xticklabels([format(tick, ",") for tick in quartiles], fontproperties=libertinus_math)

# Dashed vertical guide at every tick, drawn behind the bars.
for tick in quartiles:
    ax.axvline(tick, color="#DDD", linestyle="--", linewidth=0.8, zorder=0)

ax.set_ylabel("Labels per snippet", fontproperties=libertinus_serif, labelpad=8)

# Hide the top and right frame lines.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

plt.tight_layout()
fig.savefig("../reports/figures/label_cardinality.png", dpi=300, bbox_inches="tight")
plt.show()

## Label co-occurrence matrix

In [None]:
# Co-occurrence counts; the first CSV column supplies the row labels.
co_df = pd.read_csv("../reports/label_cooccurrence.csv", index_col=0)

fig, ax = plt.subplots(figsize=(6, 4))

# Annotated heatmap with integer cell labels, thin grey cell borders and no colour bar.
heatmap_options = {
    "annot": True,
    "fmt": "d",
    "cmap": "viridis",
    "cbar": False,
    "linewidths": 0.5,
    "linecolor": "gray",
    "annot_kws": {"fontproperties": libertinus_math, "fontsize": 10},
}
sns.heatmap(co_df, ax=ax, **heatmap_options)

# Column labels go above the matrix; the y-axis label is left blank.
x_axis = ax.xaxis
x_axis.tick_top()
x_axis.set_label_position("top")
ax.set_ylabel("", fontproperties=libertinus_serif)

# Slanted column labels, horizontal row labels, both in the serif face.
column_labels = ax.get_xticklabels()
row_labels = ax.get_yticklabels()
plt.setp(column_labels, rotation=45, ha="left", fontproperties=libertinus_serif)
plt.setp(row_labels, rotation=0, fontproperties=libertinus_serif)

plt.tight_layout()
fig.savefig("../reports/figures/label_cooccurrence.png", dpi=300, bbox_inches="tight")
plt.show()

## Training (loss)

In [None]:
REPORT_FILE = "../reports/exp1_b.csv"

train_df = pd.read_csv(REPORT_FILE)
# Report name without directory or extension; reused in the output file name.
report_basename = os.path.basename(REPORT_FILE)
f_name = os.path.splitext(report_basename)[0]


def _loss_curve(column, split):
    """Return the non-null rows of ``column`` as a frame with columns epoch, loss, set."""
    curve = train_df[["epoch", column]].dropna()
    curve = curve.rename(columns={column: "loss"})
    return curve.assign(set=split)


train = _loss_curve("train/loss", "train")
val = _loss_curve("val/loss", "val")

# Long-form frame: one row per (epoch, split) so seaborn can colour by split.
loss_df = pd.concat([train, val], ignore_index=True)

sns.set_theme(style="darkgrid")
fig, ax = plt.subplots(figsize=(8, 6))

sns.lineplot(data=loss_df, x="epoch", y="loss", hue="set", marker="o", ax=ax)

ax.set_xlabel("Epoch", fontproperties=libertinus_serif)
ax.set_ylabel("Loss", fontproperties=libertinus_serif)
ax.legend(title="", prop=libertinus_serif)

# Tick labels use the math face on both axes.
for tick_labels in (ax.get_xticklabels(), ax.get_yticklabels()):
    plt.setp(tick_labels, fontproperties=libertinus_math)

fig.tight_layout()
fig.savefig(f"../reports/figures/{f_name}_loss.png", dpi=300, bbox_inches="tight")

plt.show()

## Training (metrics)

In [None]:
REPORT_FILE = "../reports/exp3_b.csv"

df = pd.read_csv(REPORT_FILE)
# Report name without directory or extension; reused in the output file name.
f_name = os.path.splitext(os.path.basename(REPORT_FILE))[0]

# Which splits end up in the figure.
plot_sets = ["val"]  # ['train', 'val']

metrics = ["f1", "precision", "recall"]
train_vars = [f"train/{m}" for m in metrics]
val_vars = [f"val/{m}" for m in metrics]


def _melt_split(frame, value_vars, prefix, split):
    """Reshape one split's metric columns into long form.

    Args:
        frame: Wide frame with an ``epoch`` column and the ``value_vars`` columns.
        value_vars: Names of the metric columns belonging to this split.
        prefix: Literal prefix removed from the column names to get the metric name.
        split: Value stored in the ``set`` column of every returned row.

    Returns:
        A frame with columns ``epoch``, ``metric``, ``value`` and ``set``;
        rows containing missing values are dropped.
    """
    long_form = (
        frame[["epoch"] + value_vars]
        .melt(id_vars="epoch", value_vars=value_vars, var_name="metric", value_name="value")
        .dropna()
    )
    return long_form.assign(
        # regex=False: the prefix is a literal string, independent of the
        # pandas version's default for ``regex``.
        metric=long_form["metric"].str.replace(prefix, "", regex=False),
        set=split,
    )


train_m = _melt_split(df, train_vars, "train/", "train")
val_m = _melt_split(df, val_vars, "val/", "val")

metrics_df = pd.concat([train_m, val_m], ignore_index=True)
# Filter to the requested splits and add the legend label in one step, so the
# new column is created on a fresh frame rather than set on a filtered one.
metrics_df = metrics_df[metrics_df["set"].isin(plot_sets)].assign(
    label=lambda frame: frame["set"] + " " + frame["metric"]
)

sns.set_theme(style="darkgrid")
fig, ax = plt.subplots(figsize=(8, 6))

sns.lineplot(data=metrics_df, x="epoch", y="value", hue="label", style="label", markers=True, dashes=False, ax=ax)

ax.set_xlabel("Epoch", fontproperties=libertinus_serif)
ax.set_ylabel("Metric Value", fontproperties=libertinus_serif)
ax.legend(title="", prop=libertinus_serif)

plt.setp(ax.get_xticklabels(), fontproperties=libertinus_math)
plt.setp(ax.get_yticklabels(), fontproperties=libertinus_math)

fig.tight_layout()
fig.savefig(f"../reports/figures/{f_name}_metrics.png", dpi=300, bbox_inches="tight")
plt.show()

## Tokenization

In [None]:
from src.go_ast_tokenizer.tokenizer import SPECIAL_TOKENS, GoASTTokenizer
from src.go_ast_tokenizer.utils import get_tokenizer

# Identifier passed to get_tokenizer to select the base tokenizer
# (presumably a Hugging Face Hub model id; confirm against get_tokenizer).
MODEL_ID = "meta-llama/Llama-3.2-1B"

tokenizer = get_tokenizer(MODEL_ID)
# Register the project's SPECIAL_TOKENS as additional special tokens. This call
# is the cell's last expression, so its return value is shown as the cell output.
tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

In [None]:
# Go source used as the running example for the token visualisation below.
SNIPPET = """package sample

func inc(i int) int {
	i += 1
	return i
}"""

go_ast_tokenizer = GoASTTokenizer()

# Encode with the base tokenizer, then decode every id on its own to obtain
# the text of each individual token.
token_ids = tokenizer.encode(SNIPPET)  # go_ast_tokenizer.tokenize(SNIPPET)
tokens = list(map(tokenizer.decode, token_ids))

print(tokens)

# Drop a leading begin-of-text marker so it is not drawn as a token.
if tokens[:1] == ["<|begin_of_text|>"]:
    tokens = tokens[1:]

In [None]:
def generate_token_color(token):
    """Return a deterministic RGB colour derived from the text of ``token``.

    The first three bytes of the token's MD5 digest choose the hue (0..1), the
    saturation (0.6..0.9) and the value (0.8..1.0), so the same token always
    maps to the same colour.

    Args:
        token: Token text (``str``); it is encoded with ``str.encode()`` before hashing.

    Returns:
        The result of ``hsv_to_rgb`` for the derived ``[hue, saturation, value]``.
    """
    # MD5 only spreads tokens over the colour space here; usedforsecurity=False
    # states that and keeps the call usable on FIPS-restricted Python builds.
    digest = hashlib.md5(token.encode(), usedforsecurity=False).digest()
    hue = digest[0] / 255.0
    saturation = 0.6 + digest[1] / 255.0 * 0.3
    value = 0.8 + digest[2] / 255.0 * 0.2
    return hsv_to_rgb([hue, saturation, value])


# One colour per distinct token text.
token_colors = {token: generate_token_color(token) for token in set(tokens)}

# Layout constants, all in data units of the axes.
LEFT_MARGIN = 0.5  # x position where every row starts
CHAR_WIDTH = 0.1  # horizontal space reserved per displayed character
ROW_GAP = 0.2  # vertical gap between consecutive rows
SEPARATOR_OVERHANG = 0.02  # how far separator lines extend past a token box

x, y = LEFT_MARGIN, 8
line_height, max_width = 0.3, 9
token_positions = []
at_row_start = True

for token in tokens:
    # Make whitespace visible: escaped newlines/tabs and a middle dot for spaces.
    display_token = token.replace("\n", "\\n").replace("\t", "\\t").replace(" ", "·")
    text_width = len(display_token) * CHAR_WIDTH

    # Wrap to the next row when the token would cross max_width. A token that
    # is already first on its row stays put: wrapping it again cannot help and
    # would only leave an empty row behind.
    if not at_row_start and x + text_width > max_width:
        x = LEFT_MARGIN
        y -= line_height + ROW_GAP
        at_row_start = True

    token_positions.append(
        {
            "x": x,
            "y": y,
            "width": text_width,
            "height": line_height,
            "color": token_colors[token],
            "text": display_token,
            "row_start": at_row_start,
        }
    )
    x += text_width
    at_row_start = False

if not token_positions:
    raise ValueError("no tokens to draw: `tokens` is empty")

# Bounding box of all token boxes; it drives both the axis limits and the figure size.
min_x = min(pos["x"] for pos in token_positions)
max_x = max(pos["x"] + pos["width"] for pos in token_positions)
min_y = min(pos["y"] - pos["height"] for pos in token_positions)
max_y = max(pos["y"] for pos in token_positions)

padding = 0.1
fig_width = (max_x - min_x + 2 * padding) * 2
fig_height = (max_y - min_y + 2 * padding) * 2

fig, ax = plt.subplots(figsize=(fig_width, fig_height))
ax.set_xlim(min_x - padding, max_x + padding)
ax.set_ylim(min_y - padding, max_y + padding)
ax.axis("off")

for pos in token_positions:
    left = pos["x"]
    right = pos["x"] + pos["width"]
    bottom = pos["y"] - pos["height"]

    # Translucent background box in the token's colour.
    rect = mpatches.Rectangle(
        (left, bottom),
        pos["width"],
        pos["height"],
        facecolor=pos["color"],
        alpha=0.3,
        linewidth=0,
    )
    ax.add_patch(rect)

    ax.text(
        pos["x"] + pos["width"] / 2,
        pos["y"] - pos["height"] / 2,
        pos["text"],
        ha="center",
        va="center",
        color="black",
        fontsize=12,
        fontproperties=libertinus_mono,
    )

    separator_y = [bottom - SEPARATOR_OVERHANG, pos["y"] + SEPARATOR_OVERHANG]

    # The first token of a row also gets a separator on its left edge.
    if pos["row_start"]:
        ax.plot([left, left], separator_y, "k-", linewidth=0.8)

    # Every token gets a separator on its right edge.
    ax.plot([right, right], separator_y, "k-", linewidth=0.8)

plt.tight_layout()
fig.savefig("../reports/figures/tokenized_snippet.png", dpi=300, bbox_inches="tight")
plt.show()