### Plots

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from zoology.analysis.utils import fetch_wandb_runs
import wandb

##################################################################################################################
# Layout widths handed to set_size(). Obtain them from the LaTeX log with \showthe\textwidth /
# \showthe\columnwidth (see the set_size docstring below).
COLUMNWIDTH = 245.71811 # width of a single column
TEXTWIDTH = 397.48499 # width of two columns
PPP = 439.3701 # powerpoint (presumably the slide width in the same unit -- TODO confirm)

FONTSIZE = 10 # base font size; the rc settings and annotations below use fractions of it
CONTEXT = "paper" # either 'paper' or 'talk'; passed to seaborn and used to pick FACTOR
##################################################################################################################

def set_size(width, fraction=1, subplot=(1, 1)):
    r"""Return figure dimensions (in inches) that avoid rescaling in LaTeX.

    The height follows the golden ratio, scaled by the rows/columns ratio of
    the subplot grid.

    Parameters
    ----------
    width: float
            Width in pts. Run "\showthe\textwidth" after \begin{document} in Latex and search for the
            textwidth in the log file. Alternatively use "\showthe\columnwidth" if the paper is double
            column. The output looks like this:
                > xyz.0pt.
                X.Y \showthe\textwidth
            Pass xyz to this function.
    fraction: float
            Fraction of the width which you wish the figure to occupy
    subplot: sequence
            (rows, columns) of subplots; only the ratio rows/columns is used,
            so non-integer values are accepted.

    Returns
    -------
    tuple
            (figure width, figure height) in inches.
    """
    # The docstring is a raw string so that "\t", "\b" and "\s" in the LaTeX
    # commands are kept literally instead of being read as escape sequences.

    # TeX points per inch is 72.27
    inches_per_pt = 1 / 72.27

    # Golden ratio to set aesthetic figure height
    golden_ratio = (5**.5 - 1) / 2

    fig_width_in = width * fraction * inches_per_pt
    rows_per_col = float(subplot[0]) / float(subplot[1])
    fig_height_in = fig_width_in * golden_ratio * rows_per_col

    return (fig_width_in, fig_height_in)

# Multiplier applied to every line/tick width below; any context other than
# 'talk' or 'paper' maps to 0.
FACTOR = {"talk": 1.5, "paper": 1.}.get(CONTEXT, 0.)

# matplotlib/seaborn rc overrides: LaTeX text rendering with a Times serif font,
# font sizes derived from FONTSIZE and stroke widths scaled by FACTOR.
_rc_overrides = {
    'xtick.bottom': True,
    'ytick.left': True,
    'xtick.color': '0.1',
    'ytick.color': '0.1',
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': 'Times',
    'figure.titlesize': 0.9*FONTSIZE,
    'axes.titlesize': 0.9*FONTSIZE,
    'axes.labelsize': 0.9*FONTSIZE,
    'font.size': 0.9*FONTSIZE,
    'legend.fontsize': int(0.8*FONTSIZE),
    'legend.title_fontsize': int(0.8*FONTSIZE),
    'xtick.labelsize': int(0.8*FONTSIZE),
    'ytick.labelsize': int(0.8*FONTSIZE),
    'axes.linewidth': 0.8*FACTOR,
    'grid.linewidth': 0.5*FACTOR,
    'xtick.major.size': 4.5*FACTOR,
    'xtick.major.width': 0.8*FACTOR,
    'xtick.minor.width': 0.64*FACTOR,
    'ytick.major.size': 4.5*FACTOR,
    'ytick.major.width': 0.8*FACTOR,
    'ytick.minor.width': 0.64*FACTOR,
}
sns.set(context=CONTEXT, style="whitegrid", rc=_rc_overrides)
sns.set_palette("deep")

### other definitions ######################################################################
# LaTeX preamble used for text rendering (text.usetex is switched on above).
plt.rcParams['text.latex.preamble'] = r"\usepackage{newtxmath}"

#### Attention & SSMs

In [None]:
# Model keys compared in this section; the cells below iterate over this list.
MODELS = ["sm-attention", "lin-attention", "norm-attention", "s6-2", "s6"]


def _fetch_sweep(sweep_name):
    """Call fetch_wandb_runs for one sweep of the "neurips-2024" project."""
    return fetch_wandb_runs(sweep_id=[sweep_name], project_name="neurips-2024")


# One result per sweep; the seqlen-64 setting has an additional, earlier sweep
# that the next cell uses for the s6 / s6-2 models.
df_64_mamba = _fetch_sweep("run2024-05-17-seqlen64-kv4")
df_64 = _fetch_sweep("run2024-05-20-seqlen64-kv4")
df_128 = _fetch_sweep("run2024-05-20-seqlen128-kv8")
df_256 = _fetch_sweep("run2024-05-20-seqlen256-kv16")
df_512 = _fetch_sweep("run2024-05-20-seqlen512-kv64")

In [None]:
def _split_by_model(attention_runs, ssm_runs, models):
    """Split sweep results into one sub-frame per model key.

    Parameters
    ----------
    attention_runs:
            Frame searched for the non-S6 models; rows are selected when the
            "name" column starts with the model key.
    ssm_runs:
            Frame searched for "s6" / "s6-2"; rows are selected when the
            "model.sequence_mixer.name" column equals the model key.
    models: iterable of str
            Model keys to extract.

    Returns
    -------
    dict
            Maps each model key to its filtered frame. A model is skipped (with
            a printed notice) when the column needed to select it is missing or
            is not string-valued.
    """
    split = {}
    for model in models:
        try:
            if model in ("s6", "s6-2"):
                split[model] = ssm_runs[ssm_runs["model.sequence_mixer.name"] == model]
            else:
                split[model] = attention_runs[attention_runs["name"].str.startswith(model)]
        # Only lookup failures are treated as "model not available"; anything
        # else (including KeyboardInterrupt) is allowed to propagate.
        except (KeyError, AttributeError):
            print("No {} found!".format(model))
    return split


# For sequence length 64 the S6 variants come from a separate sweep; for the
# other lengths every model is taken from the same frame.
split_df_64 = _split_by_model(df_64, df_64_mamba, MODELS)
split_df_128 = _split_by_model(df_128, df_128, MODELS)
split_df_256 = _split_by_model(df_256, df_256, MODELS)
split_df_512 = _split_by_model(df_512, df_512, MODELS)

In [None]:
# EVAL[seq_len][model] -> {"d_qk": [...], "d_model": [...], "acc": matrix}
# where acc[i, j] is the best "valid/accuracy" over all runs with
# d_model == D_MODEL[i] and d_qk == D_QK[j].
EVAL = {}
seq_len = ["64", "128", "256", "512"]
for length, split_df in zip(seq_len, [split_df_64, split_df_128, split_df_256, split_df_512]):
    model_eval = {}
    for model in MODELS:
        runs = split_df[model]
        # remove n=16 -- filtered by value, so the right entry is dropped even
        # when 16 is not the smallest d_qk in the sweep
        D_QK = [d_qk for d_qk in sorted(runs["model.d_qk"].unique()) if d_qk != 16]
        D_MODEL = sorted(runs["model.d_model"].unique())

        eval_mat = np.zeros((len(D_MODEL), len(D_QK))) # always arranged as row=d_model, col=d_qk
        for i, d_model in enumerate(D_MODEL):
            cache = runs[runs["model.d_model"] == d_model]
            for j, d_qk in enumerate(D_QK):
                qk_idx = cache["model.d_qk"] == d_qk
                # max over runs; NaN when no run logged an accuracy for this cell
                eval_mat[i, j] = cache[qk_idx]["valid/accuracy"].dropna().max()

        model_eval[model] = {"d_qk": D_QK, "d_model": D_MODEL, "acc": eval_mat}

    EVAL[length] = model_eval

In [None]:
# Grid of accuracy heatmaps: one row per model, one column per sequence length.
fig, ax = plt.subplots(5, 4, figsize=set_size(TEXTWIDTH, subplot=[8, 4]), sharey=True, sharex=True, gridspec_kw={'width_ratios': [1, 1, 1, 1.25]})
LABELS = {"sm-attention": "Softmax att.",
          "lin-attention": "Linear att.",
          "norm-attention": "Normalized att.",
          "s6": "S6",
          "s6-2": "SSD"}
KV_PAIRS = {"64": 4, "128": 8, "256": 16, "512": 64}
cmap = sns.color_palette("coolwarm", as_cmap=True)

last_col = len(EVAL) - 1
last_row = len(MODELS) - 1

# The loop variable is deliberately not called `eval`: that would shadow the
# builtin for every later cell of the notebook.
for i, length in enumerate(EVAL):
    for j, model in enumerate(MODELS):
        entry = EVAL[length][model]
        acc = entry["acc"]
        # Cell text: accuracy in percent with one decimal, or ">99" above 0.99
        annot = np.asarray(
            ["{0:.1f}".format(val*100) if val <= 0.99 else "$>$99" for val in acc.flatten()]
        ).reshape(acc.shape)
        # Only the last column carries the colour bar
        sns.heatmap(acc, annot=annot, fmt="", annot_kws={"fontsize": int(0.7*FONTSIZE)}, cmap=cmap, vmin=0, vmax=1,
                    xticklabels=entry["d_qk"], yticklabels=entry["d_model"],
                    linewidths=0.3, linecolor='white', ax=ax[j, i], cbar=(i == last_col))
        if j == 0:
            ax[j, i].set_title("L: {0}, KV-pairs: {1}".format(length, KV_PAIRS[length]))
        if i == 0:
            # The first column uses the axes title as a rotated row label. On ax[0, 0]
            # this replaces the column title set just above, which is re-added as
            # plain text after the loop.
            ax[j, i].set_title(LABELS[model], rotation='vertical', x=-0.55, y=0.5, ha="center", va="center", fontsize=0.9*FONTSIZE)
            ax[j, i].set_ylabel("d")
        else:
            ax[j, i].tick_params(axis='y', which='both', left=False)
        if j == last_row:
            ax[j, i].set_xlabel("n")
        else:
            ax[j, i].tick_params(axis='x', which='both', bottom=False)

        ax[j, i].grid(False)
ax[0,0].text(2., -0.32, "L: {0}, KV-pairs: {1}".format("64",KV_PAIRS["64"]), ha='center', fontsize=0.9*FONTSIZE)

plt.tight_layout()
plt.savefig('full.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

##### Figure 1

In [None]:
# Evaluation entries of linear and softmax attention for the sequence lengths 256 and 512.
lin_att = {length: EVAL[length]["lin-attention"] for length in ("256", "512")}
sm_att = {length: EVAL[length]["sm-attention"] for length in ("256", "512")}

In [None]:
# Accuracy vs. d_qk for linear and softmax attention, one panel per sequence length.
fig, ax = plt.subplots(1, 2, figsize=set_size(TEXTWIDTH, subplot=[1, 2]), sharey=True)
LABELS = ["Linear attention (16)", "Softmax attention (2)"]
colors = sns.color_palette("deep")

# (sequence length, number of KV pairs) shown in each panel, left to right
for panel, (length, kv_pairs) in zip(ax, [("256", 16), ("512", 64)]):
    lin_eval = lin_att[length]
    sm_eval = sm_att[length]

    # last row of the accuracy matrix = largest d_model
    panel.plot(lin_eval["acc"][-1,:], c=colors[0], marker="o", label=LABELS[0])
    panel.plot(sm_eval["acc"][-1,:], c=colors[1], marker="o", label=LABELS[1])

    panel.set_ylim([0.0, 1.04])
    d_qk_ticks = lin_eval["d_qk"]
    panel.set_xticks(ticks=np.arange(len(d_qk_ticks)),labels=d_qk_ticks)
    panel.set_xlabel("State expansion $n$")
    panel.set_title("L: {0}, KV-pairs: {1}".format(length,kv_pairs))
    panel.grid(True)

# The y axis is shared: label it on the left panel only, hide the tick marks
# on the right panel and put the legend there.
ax[0].set_ylabel("Accuracy")
ax[1].tick_params(axis='y', which='both', left=False)
ax[1].legend(loc='lower right')

plt.tight_layout()
plt.savefig('state-expansion.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

##### Figure 2

In [None]:
# Accuracy vs. d_model at a fixed d_qk (n) for every model, sequence length 512.
fig, ax = plt.subplots(1, 1, figsize=set_size(TEXTWIDTH, fraction=0.55, subplot=[0.9, 1]))
# Legend labels, listed in the same order as MODELS
LABELS = ["Softmax att. (2)",
          "Linear att. (16)",
          "Normalized att. (21)",
          "SSD [Dao and Gu, 2024]",
          "S6 [Gu and Dao, 2023]"]
colors = sns.color_palette("deep")

model_eval = EVAL["512"]
n = 128

for model, label, color in zip(MODELS, LABELS, colors):
    entry = model_eval[model]
    # boolean mask selecting the column(s) whose d_qk equals n
    idx = np.array(entry["d_qk"]) == n
    ax.plot(entry["acc"][:,idx], c=color, marker="o", label=label)

# The tick labels are taken from the model visited last in the loop above.
d_model_ticks = model_eval[model]["d_model"]

ax.set_ylim([-0.04, 1.04])
ax.set_xticks(ticks=np.arange(len(d_model_ticks)),labels=d_model_ticks)
ax.set_xlabel("Model size $d$")
ax.set_ylabel("Accuracy")
#ax.set_title("L: {0}, KV-pairs: {1}".format("512",64))
#ax.legend(ncol=2, loc="center", bbox_to_anchor=(0.48, 1.23))
ax.grid(True)

plt.tight_layout()
plt.savefig('normalization.pdf', format='pdf', bbox_inches='tight', pad_inches=0)

#### RNNs

In [None]:
# Sweeps used by the RNN section, fetched in order of sequence length (64, 128, 256).
df_64, df_128, df_256 = (
    fetch_wandb_runs(sweep_id=[sweep_name], project_name="neurips-2024")
    for sweep_name in (
        "run2024-05-20-seqlen64-kv4",
        "run2024-05-20-seqlen128-kv8",
        "run2024-05-20-seqlen256-kv16",
    )
)

In [None]:
def _split_by_reversed_flag(runs):
    """Split a sweep frame into the "qlstm" and "qlstm-rev" runs.

    Rows whose "model.sequence_mixer.kwargs.reversed" column equals False are
    stored under "qlstm", rows where it equals True under "qlstm-rev".

    Returns
    -------
    dict
            Maps the model key to its filtered frame. A key is skipped (with a
            printed notice) when the column is missing from ``runs``.
    """
    split = {}
    for model, is_reversed in (("qlstm", False), ("qlstm-rev", True)):
        try:
            split[model] = runs[runs["model.sequence_mixer.kwargs.reversed"] == is_reversed]
        # Only a missing column counts as "not found"; other errors propagate.
        except KeyError:
            print("No {} found!".format(model))
    return split


split_df_64 = _split_by_reversed_flag(df_64)
split_df_128 = _split_by_reversed_flag(df_128)
split_df_256 = _split_by_reversed_flag(df_256)

In [None]:
# EVAL[seq_len][model] -> {"d_qk": [...], "d_model": [...], "acc": matrix}
# with acc[i, j] = best "valid/accuracy" of the runs at D_MODEL[i], D_QK[j].
EVAL = {}
seq_len = ["64", "128", "256"]
for length, split_df in zip(seq_len, [split_df_64, split_df_128, split_df_256]):
    model_eval = {}
    for model in ["qlstm", "qlstm-rev"]:
        runs = split_df[model]
        D_QK = sorted(runs["model.d_qk"].unique())
        D_MODEL = sorted(runs["model.d_model"].unique())

        # always arranged as row=d_model, col=d_qk
        eval_mat = np.zeros((len(D_MODEL), len(D_QK)))
        for i, d_model in enumerate(D_MODEL):
            cache = runs[runs["model.d_model"] == d_model]
            for j, d_qk in enumerate(D_QK):
                scores = cache.loc[cache["model.d_qk"] == d_qk, "valid/accuracy"]
                eval_mat[i, j] = scores.dropna().max()

        model_eval[model] = {"d_qk": D_QK, "d_model": D_MODEL, "acc": eval_mat}

    EVAL[length] = model_eval

In [None]:
# Accuracy vs. d_model for qLSTM and its "reversed" variant, one panel per sequence length.
fig, ax = plt.subplots(1, 3, figsize=set_size(TEXTWIDTH, subplot=[1.4, 3]), sharey=False)
LABELS = {"qlstm": "qLSTM", "qlstm-rev": "qLSTM w/ (22)"}
colors = sns.color_palette("deep")

# Per panel: (sequence length, KV pairs, lower y limit, model whose d_model values label the x ticks)
_panel_specs = [
    ("64", 4, 0.8, "qlstm"),
    ("128", 8, 0.5, "qlstm"),
    ("256", 16, 0., "qlstm-rev"),
]

for panel, (length, kv_pairs, y_low, tick_model) in zip(ax, _panel_specs):
    model_eval = EVAL[length]

    for i,model in enumerate(["qlstm", "qlstm-rev"]):
        panel.plot(model_eval[model]["acc"], c=colors[i], marker="o", label=LABELS[model])
        # dump the plotted values for inspection
        print(model)
        print(model_eval[model]["acc"])

    d_model_ticks = model_eval[tick_model]["d_model"]
    panel.set_ylim([y_low, 1.01])
    panel.set_xticks(ticks=np.arange(len(d_model_ticks)),labels=d_model_ticks)
    panel.set_xlabel("Model size $d$")
    panel.set_ylabel("Accuracy")
    panel.set_title("L: {0}, KV-pairs: {1}".format(length,kv_pairs))
    panel.grid(True)

# Only the first panel carries the legend.
ax[0].legend()
plt.tight_layout()
plt.savefig('qlstm.pdf', format='pdf', bbox_inches='tight', pad_inches=0)