# Transformer hyperparameter scaling

This notebook was written to run on the TU Ilmenau cluster with my specific setup.
However, it can also be used on its own to visualize the learning rate scaling laws from the stored results.

A summary of all results is also stored in the directory `/research/scaling_data`.

In [None]:
from __future__ import annotations

import os
import re
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import ticker  
from matplotlib.ticker import FixedLocator
from matplotlib.colors import Normalize, LogNorm
from tensorboard.backend.event_processing import event_accumulator

import optimetal.utils as utils
import optimetal.factory as factory
from optimetal.data.loader import load_torch_data
# NOTE(review): presumably configures the shared matplotlib style for all figures below; confirm in optimetal.utils
utils.load_plot_style()

def load_tb_scalars(logdir: str) -> dict:
    """
    Read every scalar series stored in the tensorboard event files of a run,
    e.g., to inspect training and validation loss curves.
    Input:
        logdir:     Directory containing the tensorboard event files
    Output:
        tb_log:     Dict mapping each scalar tag to the list of its logged values
                    (only the values are kept, steps and wall times are dropped)
    """
    # a size guidance of 0 keeps all scalar events instead of subsampling them
    accumulator = event_accumulator.EventAccumulator(
        logdir,
        size_guidance={event_accumulator.SCALARS: 0},
    )
    accumulator.Reload()
    scalar_tags = accumulator.Tags().get("scalars", [])
    return {
        tag: [event.value for event in accumulator.Scalars(tag)]
        for tag in scalar_tags
    }

def load_results(study_path: str) -> dict:
    """
    Load all results from the transformer hyperparameter scaling law study.
    Input:
        study_path:     Path to the root directory containing subdirectories from model training
    Output:
        results:        Nested dict mapping the learning rate scaling exponent (string) and the
                        model width (string) to a list with one entry per trained model. Each entry
                        holds "seed", "num_parameter", "lr", "val_loss", "eps_loss", and "drude_loss".
    Subdirectories that are unfinished, whose name does not encode width/exponent/seed,
    or whose tensorboard log is incomplete are reported and skipped.
    """

    # gather all study directories
    study_dirs = [d for d in os.listdir(study_path) if os.path.isdir(os.path.join(study_path, d))]
    # initialize nested structure for results
    results = {}
    # iterate through each study directory and load results
    print(f"Loading scaling law results for {len(study_dirs):d} models")
    for study_dir in study_dirs:
        # path setup and checks
        study_dir_path = os.path.join(study_path, study_dir)
        val_loss_path = os.path.join(study_dir_path, "val_loss.txt")
        best_model_path = os.path.join(study_dir_path, "best_model.pt")
        if not os.path.exists(val_loss_path) or not os.path.exists(best_model_path):
            print(f"Skipping {study_dir_path:s}, probably still running")
            continue
        # parse metadata from the directory name first, this is cheap and
        # avoids loading a model for directories that do not belong to the study
        width_match = re.search(r"hidden(\d+)", study_dir)
        gamma_match = re.search(r"gamma(\d+\.?\d*)", study_dir)
        seed_match = re.search(r"seed(\d+)", study_dir)
        if width_match is None or gamma_match is None or seed_match is None:
            print(f"Skipping {study_dir_path:s}, directory name does not contain hidden/gamma/seed")
            continue
        width = width_match.group(1)
        gamma = gamma_match.group(1)
        seed = seed_match.group(1)
        # load the data from the tensorboard log, the individual loss
        # contributions are taken at the epoch with the lowest validation loss
        tb_log = load_tb_scalars(study_dir_path)
        val_loss = tb_log.get("val/loss", [])
        eps_loss = tb_log.get("val/eps", [])
        drude_loss = tb_log.get("val/drude", [])
        if len(val_loss) == 0:
            print(f"Skipping {study_dir_path:s}, tensorboard log has no 'val/loss' entries")
            continue
        min_idx = int(np.argmin(val_loss))
        if min_idx >= len(eps_loss) or min_idx >= len(drude_loss):
            print(f"Skipping {study_dir_path:s}, tensorboard log is missing 'val/eps' or 'val/drude' entries")
            continue
        # the best validation loss itself is read from the validation loss file
        best_val_loss = float(np.loadtxt(val_loss_path))
        # get the number of model parameters by rebuilding the model from its stored config
        best_model_dict = load_torch_data(best_model_path)
        config_dict = best_model_dict["config_dict"]
        model_config = config_dict.architecture
        model = factory.create_model(model_config)
        num_parameter = utils.get_model_parameters(model)
        result_entry = {
            "seed": seed,
            "num_parameter": num_parameter,
            "lr": config_dict.optimizer["lr"],
            "val_loss": best_val_loss,
            "eps_loss": eps_loss[min_idx],
            "drude_loss": drude_loss[min_idx],
        }
        # insert the data into the nested structure
        results.setdefault(gamma, {}).setdefault(width, []).append(result_entry)
    return results

In [None]:

# directory containing the scaling law study
study_path = "/scratch/magr4985/Transformer_Scaling"

# directory to save the results
output_dir = "./scaling_data"
os.makedirs(output_dir, exist_ok=True)

# An existing JSON summary always takes precedence. The raw study directory is
# only parsed (and the summary written) when no summary has been stored yet.
json_path = os.path.join(output_dir, "transformer_scaling_results.json")
if os.path.exists(json_path):
    print("Loading results from JSON file")
    with open(json_path, "r") as f:
        results = json.load(f)
elif os.path.exists(study_path):
    print(f"Study path {study_path:s} exists, loading results from there")
    results = load_results(study_path)
    # write to a temporary file first, so that a failed dump does not leave a
    # truncated summary behind that would be picked up by the next run
    tmp_json_path = json_path + ".tmp"
    with open(tmp_json_path, "w") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_json_path, json_path)
else:
    raise FileNotFoundError(
        f"Neither the JSON summary {json_path:s} nor the study path {study_path:s} exists"
    )

# figure directory
fig_dir = os.path.join(output_dir, "lr_scaling")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
"""
Plot the data, i.e., how the validation loss behaves when training models
with different widths using different learning rate scaling exponents.
Both views are combined in one figure: validation loss over model width (top)
and validation loss over learning rate scaling exponent (bottom).
"""

ms = 4 # marker size

# widths for which every learning rate scaling exponent has results, they define
# the axis ticks and the curves of the bottom panel (and are reused by later cells)
widths = sorted(set.intersection(*[{int(key) for key in results[g]} for g in results]))
# learning rate scaling exponents in ascending order
gammas = sorted(results, key=float)

# figure setup
fig = plt.figure(figsize=(3.5, 5))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.05], height_ratios=[1, 1])

# plot scaling law, i.e., validation loss over width, for different learning rate scaling exponents
ax = fig.add_subplot(gs[0, 0])
norm = Normalize(vmin=0, vmax=2)
cmap = plt.get_cmap("viridis")
for gamma, width_dict in results.items():
    color = cmap(norm(float(gamma)))
    gamma_widths = sorted(int(key) for key in width_dict)
    # average the validation loss over all models trained with this width and exponent
    mean_val_loss = [
        np.mean(np.array([d["val_loss"] for d in width_dict[str(w)]], dtype=float))
        for w in gamma_widths
    ]
    ax.plot(gamma_widths, mean_val_loss, "o-", markersize=ms, color=color)

# axis labels ticks
ax.set_xscale("log")
ax.set_xticks(widths)
ax.set_xticklabels(widths)
ax.xaxis.set_major_locator(FixedLocator(widths))
ax.tick_params(axis="x", which="minor", length=0)
ax.set_xlabel(r"$d_\mathrm{h}$")
ax.set_ylim([1.0, 1.4])
ax.set_yticks([1.0, 1.1, 1.2, 1.3, 1.4])
ax.set_ylabel(r"$L_\mathrm{val}$")

# colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar_ax = fig.add_subplot(gs[0, 1])
cbar = fig.colorbar(sm, cax=cbar_ax, label=r"$\gamma$", aspect=25)
cbar.ax.tick_params(axis="y", which="minor", length=0)
cbar.ax.set_yticks([0.0, 0.5, 1.0, 1.5, 2.0])

# plot validation loss over learning rate scaling exponents for different model widths
ax = fig.add_subplot(gs[1, 0])
norm = LogNorm(vmin=16, vmax=1024)
cmap = plt.get_cmap("viridis")
gamma_values = np.array([float(g) for g in gammas])
for width in widths:
    color = cmap(norm(int(width)))
    mean_val_loss = np.array([
        np.mean(np.array([d["val_loss"] for d in results[g][str(width)]], dtype=float))
        for g in gammas
    ])
    ax.plot(gamma_values, mean_val_loss, "o-", markersize=ms, color=color)

# axis labels ticks
ax.set_xlabel(r"$\gamma$")
ax.set_ylim([1.0, 1.4])
ax.set_yticks([1.0, 1.1, 1.2, 1.3, 1.4])
ax.set_ylabel(r"$L_\mathrm{val}$")

# colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar_ax = fig.add_subplot(gs[1, 1])
cbar = fig.colorbar(sm, cax=cbar_ax, label=r"$d_\mathrm{h}$", aspect=25)
cbar_ticks = widths
cbar.ax.set_yticks(cbar_ticks)
cbar.ax.tick_params(axis="y", which="minor", length=0)
cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))

# adjust figure layout and save the figure
fig.tight_layout()
fig.align_labels()
fig.savefig(os.path.join(fig_dir, "transformer_lr_scaling_1.pdf"))

In [None]:
"""
Plot the same data as two separate figures: the validation loss over model width
for different learning rate scaling exponents (saved as "..._v1.pdf") and the
validation loss over learning rate scaling exponent for different widths ("..._v2.pdf").
"""

ms = 4 # marker size

# widths for which every learning rate scaling exponent has results, they define
# the axis ticks and the curves of the second figure (and are reused by later cells)
widths = sorted(set.intersection(*[{int(key) for key in results[g]} for g in results]))
# learning rate scaling exponents in ascending order
gammas = sorted(results, key=float)

# figure setup
fig = plt.figure(figsize=(3.5, 3))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 0.05])

# plot scaling law, i.e., validation loss over width, for different learning rate scaling exponents
ax = fig.add_subplot(gs[0])
norm = Normalize(vmin=0, vmax=2)
cmap = plt.get_cmap("viridis")
for gamma, width_dict in results.items():
    color = cmap(norm(float(gamma)))
    gamma_widths = sorted(int(key) for key in width_dict)
    # average the validation loss over all models trained with this width and exponent
    mean_val_loss = [
        np.mean(np.array([d["val_loss"] for d in width_dict[str(w)]], dtype=float))
        for w in gamma_widths
    ]
    ax.plot(gamma_widths, mean_val_loss, "o-", markersize=ms, color=color)

# axis labels ticks
ax.set_xscale("log")
ax.set_xticks(widths)
ax.set_xticklabels(widths)
ax.xaxis.set_major_locator(FixedLocator(widths))
ax.tick_params(axis="x", which="minor", length=0)
ax.set_xlabel(r"$d_\mathrm{h}$")
ax.set_ylim([1.0, 1.4])
ax.set_yticks([1.0, 1.1, 1.2, 1.3, 1.4])
ax.set_ylabel(r"$L_\mathrm{val}$")

# colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar_ax = fig.add_subplot(gs[1])
cbar = fig.colorbar(sm, cax=cbar_ax, label=r"$\gamma$", aspect=25)
cbar.ax.tick_params(axis="y", which="minor", length=0)
cbar.ax.set_yticks([0.0, 0.5, 1.0, 1.5, 2.0])

# adjust figure layout and save the figure
fig.tight_layout()
fig.align_labels()
fig.savefig(os.path.join(fig_dir, "transformer_lr_scaling_1_v1.pdf"))

# figure setup
fig = plt.figure(figsize=(3.5, 3))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 0.05])

# plot validation loss over learning rate scaling exponents for different model widths
ax = fig.add_subplot(gs[0])
norm = LogNorm(vmin=16, vmax=1024)
cmap = plt.get_cmap("viridis")
gamma_values = np.array([float(g) for g in gammas])
for width in widths:
    color = cmap(norm(int(width)))
    mean_val_loss = np.array([
        np.mean(np.array([d["val_loss"] for d in results[g][str(width)]], dtype=float))
        for g in gammas
    ])
    ax.plot(gamma_values, mean_val_loss, "o-", markersize=ms, color=color)

# axis labels ticks
ax.set_xlabel(r"$\gamma$")
ax.set_ylim([1.0, 1.4])
ax.set_yticks([1.0, 1.1, 1.2, 1.3, 1.4])
ax.set_ylabel(r"$L_\mathrm{val}$")

# colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar_ax = fig.add_subplot(gs[1])
cbar = fig.colorbar(sm, cax=cbar_ax, label=r"$d_\mathrm{h}$", aspect=25)
cbar_ticks = widths
cbar.ax.set_yticks(cbar_ticks)
cbar.ax.tick_params(axis="y", which="minor", length=0)
cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))

# adjust figure layout and save the figure
fig.tight_layout()
fig.align_labels()
fig.savefig(os.path.join(fig_dir, "transformer_lr_scaling_1_v2.pdf"))

In [None]:
"""
Analyze which learning rate works best for each model width and how it scales with the number of model parameters.
Relies on `results`, `widths`, and `fig_dir` from the previous cells.
"""

ms = 4 # marker size

# figure setup
fig = plt.figure(figsize=(3.5, 5))
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 0.05], height_ratios=[1, 1])

# validation loss vs learning rate for each model width
ax = fig.add_subplot(gs[0, 0])
norm = LogNorm(vmin=16, vmax=1024)
cmap = plt.get_cmap("viridis")
model_parameter = []
best_lr_per_width = []
best_val_loss_per_width = []
for width in widths:
    color = cmap(norm(int(width)))
    # all training runs of this width, grouped by learning rate scaling exponent
    runs_per_gamma = [results[g][str(width)] for g in results]
    # NOTE(review): the parameter count is read from a single run, assuming it only depends on the width
    model_parameter.append(runs_per_gamma[-1][0]["num_parameter"])
    # NOTE(review): the learning rate is read from the first run, assuming it is shared by all runs of one exponent
    lr_list = np.array([runs[0]["lr"] for runs in runs_per_gamma])
    mean_val_loss = np.array([
        np.mean(np.array([d["val_loss"] for d in runs], dtype=float))
        for runs in runs_per_gamma
    ])
    # learning rate with the lowest mean validation loss for this width
    min_idx = np.argmin(mean_val_loss)
    best_lr_per_width.append(lr_list[min_idx])
    best_val_loss_per_width.append(mean_val_loss[min_idx])
    # sort by learning rate, otherwise the connecting line follows the
    # arbitrary order of the exponents in the results dict
    sort_idx = np.argsort(lr_list)
    lr_list = lr_list[sort_idx]
    mean_val_loss = mean_val_loss[sort_idx]
    ax.plot(lr_list, mean_val_loss, "o-", markersize=ms, color=color)
ax.plot(best_lr_per_width, best_val_loss_per_width, "k-")

# axis labels ticks
ax.set_xscale("log")
ax.set_xlabel(r"$\eta_\mathrm{max}$")
ax.set_ylabel(r"$L_\mathrm{val}$")

# colorbar
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar_ax = fig.add_subplot(gs[0, 1])
cbar = fig.colorbar(sm, cax=cbar_ax, label=r"$d_\mathrm{h}$", aspect=25)
cbar_ticks = widths
cbar.ax.set_yticks(cbar_ticks)
cbar.ax.tick_params(axis="y", which="minor", length=0)
cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%d"))

# fit the best learning rate depending on the number of model parameters,
# the model with 'width = 256' is used as reference point N_0
reference_width = 256
if reference_width in widths:
    reference_idx = widths.index(reference_width)
else:
    # the reference only rescales eta0, the fitted line itself does not depend on it
    reference_idx = len(widths) // 2
base_model_parameter = model_parameter[reference_idx]
log_x = np.log(np.array(model_parameter) / base_model_parameter)
log_y = np.log(best_lr_per_width)
slope, intercept = np.polyfit(log_x, log_y, 1) # unweighted linear fit
gamma = -slope # eta_max ~ (N_0 / N)^gamma, i.e., the exponent is the negative slope
eta0 = np.exp(intercept) # fitted optimal learning rate at N = N_0
x_fit = np.logspace(np.log10(min(model_parameter)), np.log10(max(model_parameter)), 100)
y_fit = eta0 / ((x_fit / base_model_parameter) ** gamma)

# plot the empirical learning rate scaling law
ax = fig.add_subplot(gs[1, 0])
ax.plot(model_parameter, best_lr_per_width, "ko-", markersize=ms)
ax.plot(x_fit, y_fit, "k--")
# the annotation arrow points to the middle of the fitted curve
ax.annotate(
    f"$\\eta_\\mathrm{{max}}(N) = \\eta^\\mathrm{{c}}_\\mathrm{{max}} \\cdot \\left(N_0 / N\\right)^{{{gamma:.2f}}}$",
    xy=(x_fit[50], y_fit[50]),
    xytext=(-20, 30),
    textcoords="offset points",
    color="k",
    fontsize=8,
    arrowprops=dict(arrowstyle="->", lw=1, color="k"),
    ha="left", va="bottom",
)

# axis labels ticks
ax.set_xscale("log")
ax.set_xlabel(r"$N$")
ax.tick_params(axis="x", which="minor", length=0)
ax.set_yscale("log")
ax.tick_params(axis="y", which="minor", length=0)
ax.set_ylabel(r"Optimal $\eta_\mathrm{max}(N)$")

# global figure settings
fig.tight_layout()
fig.align_labels()
fig.savefig(os.path.join(fig_dir, "transformer_lr_scaling_2.pdf"))