In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch

# --- Canvas ---
# The data limits mirror the figure size (14 x 4.5). Axes are hidden because
# only the drawn shapes and text should be visible.
fig, ax = plt.subplots(figsize=(14, 4.5), facecolor="#fafafa")
ax.set(xlim=(0, 14), ylim=(0, 4.5))
ax.set_axis_off()

# --- Colors ---
# Accent colors for the three stages of the pipeline diagram.
c_model, c_finetune, c_result = "#4A90D9", "#E8913A", "#2ECC71"
# Neutral greys: connector arrows, and secondary/detail text.
c_arrow, c_detail = "#555555", "#777777"

def draw_box(ax, x, y, w, h, color, label, fontsize=12, sublabel=None, *,
             pad=0.15, fill_alpha=0.15, sublabel_fontsize=9,
             sublabel_color=None):
    """Draw a rounded, lightly tinted box with an optional bold centered label.

    The box is made of two patches: a translucent fill and an opaque border in
    the same ``color``. This keeps the outline crisp while only the interior is
    tinted.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x, y : lower-left corner of the box body, in data coordinates.
    w, h : width and height of the box body.
    color : color used for the fill tint, the border, and the label text.
    label : bold text centered in the box. Nothing is drawn when it is empty.
    fontsize : font size of ``label``.
    sublabel : optional italic second line drawn below ``label``. When given,
        ``label`` is nudged upward so the two lines sit around the center.
    pad : padding of the ``round`` box style around the box body.
    fill_alpha : opacity of the interior tint.
    sublabel_fontsize : font size of ``sublabel``.
    sublabel_color : color of ``sublabel``. Defaults to the module-level
        ``c_detail``, looked up at call time.
    """
    box_style = f"round,pad={pad}"

    # Translucent fill. ``alpha`` also fades this patch's edge, which is why a
    # second, opaque patch draws the border on top.
    ax.add_patch(FancyBboxPatch((x, y), w, h, boxstyle=box_style,
                                facecolor=color, edgecolor=color,
                                alpha=fill_alpha, linewidth=2))
    ax.add_patch(FancyBboxPatch((x, y), w, h, boxstyle=box_style,
                                facecolor="none", edgecolor=color,
                                linewidth=2))

    cx, cy = x + w / 2, y + h / 2
    if label:
        # Shift the label up to leave room for a sublabel underneath.
        ax.text(cx, cy + (0.18 if sublabel else 0), label,
                ha="center", va="center", fontsize=fontsize,
                fontweight="bold", color=color)
    if sublabel:
        ax.text(cx, cy - 0.22, sublabel,
                ha="center", va="center", fontsize=sublabel_fontsize,
                color=c_detail if sublabel_color is None else sublabel_color,
                style="italic")

def draw_arrow(ax, x1, y1, x2, y2):
    """Draw a filled-head arrow from (x1, y1) to (x2, y2) on *ax*.

    Uses ``Axes.annotate`` with an empty annotation string, so only the arrow
    is rendered. The arrow color comes from the module-level ``c_arrow``.
    """
    tail = (x1, y1)
    head = (x2, y2)
    style = {"arrowstyle": "-|>", "color": c_arrow, "lw": 2.5}
    ax.annotate("", xy=head, xytext=tail, arrowprops=style)

# --- Geometry ---
# Box positions (lower-left x, y, width, height) in data coordinates. Arrows
# and the fine-tuning box's text are derived from these values, so they stay
# centered if a box is moved or resized.
b1_x, b1_y, b1_w, b1_h = 0.3, 1.8, 2.8, 1.5     # Pre-trained LLM
ft_x, ft_y, ft_w, ft_h = 4.1, 0.8, 5.6, 3.0     # Fine-tuning
b3_x, b3_y, b3_w, b3_h = 10.5, 1.8, 3.2, 1.5    # Reasoning LLM

arrow_len = 0.3
# Arrows run through the vertical center of the two outer boxes.
arrow_y = b1_y + b1_h / 2
# Midpoints of the gaps between neighbouring boxes. Every box below is drawn
# with the same padding, so the midpoint of the raw edges is also the
# midpoint of the visible gap.
gap1_mid = (b1_x + b1_w + ft_x) / 2
gap2_mid = (ft_x + ft_w + b3_x) / 2

# --- Box 1: Pre-trained LLM ---
draw_box(ax, b1_x, b1_y, b1_w, b1_h, c_model, "Pre-trained LLM",
         sublabel="e.g. Gemma 2 2B-IT")

# Arrow 1: centered in the gap between boxes 1 and 2
draw_arrow(ax, gap1_mid - arrow_len / 2, arrow_y,
           gap1_mid + arrow_len / 2, arrow_y)

# --- Box 2: Fine-tuning (larger, with an example inside) ---
draw_box(ax, ft_x, ft_y, ft_w, ft_h, c_finetune, "")

example_lines = [
    '<prompt> "here is a math problem..." </prompt>',
    '<think> step-by-step reasoning... </think>',
    '<answer> 62 </answer>',
]

# The title and example lines are centered vertically as one group: the
# title, a larger gap, then evenly spaced example lines. This keeps the group
# centered for any number of example lines.
ft_cx = ft_x + ft_w / 2
ft_cy = ft_y + ft_h / 2
title_gap, line_step = 0.5, 0.35
group_h = title_gap + line_step * max(len(example_lines) - 1, 0)
title_y = ft_cy + group_h / 2

ax.text(ft_cx, title_y, "Fine-Tuning on Reasoning Dataset",
        ha="center", va="center", fontsize=12, fontweight="bold",
        color=c_finetune)

for i, line in enumerate(example_lines):
    ax.text(ft_cx, title_y - title_gap - i * line_step, line,
            ha="center", va="center", fontsize=9, color=c_detail,
            family="monospace",
            bbox=dict(boxstyle="round,pad=0.15", facecolor="#F7EADD",
                      edgecolor="none"))

# Arrow 2: centered in the gap between boxes 2 and 3
draw_arrow(ax, gap2_mid - arrow_len / 2, arrow_y,
           gap2_mid + arrow_len / 2, arrow_y)

# --- Box 3: Reasoning LLM ---
draw_box(ax, b3_x, b3_y, b3_w, b3_h, c_result, "Reasoning LLM",
         sublabel="thinks step-by-step")

plt.tight_layout()
plt.savefig("fine_tuning_reasoning_overview.png", dpi=150, bbox_inches="tight", facecolor="#fafafa")
plt.show()