In [1]:
import torch
import torch.nn.functional as F
import os
import numpy as np
import scipy as sp
import seaborn as sns
import matplotlib.pyplot as plt
import yaml
import json
from model import TFModel, calc_rotary_R_mat_simple
from utils import create_folder
from tqdm import tqdm

## we need a rank estimator package "screenot", install below
# !pip install screenot
from screenot.ScreeNOT import adaptiveHardThresholding

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
class Config:
    """
    Plain attribute container describing a TFModel.

    Every keyword argument given to the constructor is stored as an instance attribute
    of the same name, so a saved JSON dictionary can be turned into an object with
    ``Config(**loaded_dict)`` and then handed to the model constructor.
    """

    def __init__(self, **kwargs):
        # Write straight into the instance dictionary: one attribute per keyword.
        attributes = vars(self)
        for name, value in kwargs.items():
            attributes[name] = value

In [3]:
def QKOV_matching(model, config, rel_pos=0, ranks=[10]):
    """
    QKOV_matching measures how the QK circuit of the 2nd layer lines up with the OV circuit
    of the 1st layer in a model from class "TFModel".
    Arguments:
        model: a Transformer of class "TFModel"; the weights read are model.embed.embed,
            W_q / W_k of model.h[1].mha and W_v / W_o of model.h[0].mha
        config: model configuration; d_model is read, and num_heads == 1 and pos == "rotary" are required
        rel_pos: relative distance handed to calc_rotary_R_mat_simple; the resulting matrix R
            is inserted between the query and key weights, i.e. Q R K^T instead of Q K^T
        ranks: a list of subspace dimensions at which the matching score is evaluated
    Returns:
        QK0: the array wte @ W_qkov @ wte.T
        z_score: a scalar, (mean of the diagonal of W_qkov - mean of all entries of W_qkov) / std of all entries
        s_match: a (len(ranks), 2)-shaped array; column 0 is the largest singular value of
            Vt_qk[:rank] @ U_ov[:, :rank], column 1 is the root-mean-square of those singular values
        Vt_qk.T, U_ov: right singular vectors of W_qk and left singular vectors of W_ov
    """
    d_model = config.d_model
    assert config.num_heads == 1, "Only single head self-attention is supported."
    assert config.pos == "rotary", "Only rotary embedding is supported."

    # Single head, so the head dimension equals d_model.
    rotary = calc_rotary_R_mat_simple(d_model, rel_dist=rel_pos).numpy(force=True)

    token_emb = model.embed.embed.weight.numpy(force=True)  # (vocab_size, d_model)
    second_attn = model.h[1].mha
    query_w = second_attn.W_q.weight.numpy(force=True)
    key_w = second_attn.W_k.weight.numpy(force=True)
    first_attn = model.h[0].mha
    value_w = first_attn.W_v.weight.numpy(force=True)
    out_w = first_attn.W_o.weight.numpy(force=True)

    # Combined circuits: scaled query-key product, output-value product, and their composition.
    qk_circuit = (query_w.T @ rotary @ key_w) / np.sqrt(d_model)
    ov_circuit = out_w @ value_w
    composed = qk_circuit @ ov_circuit

    QK0 = token_emb @ composed @ token_emb.T

    # How far the average diagonal entry sits above the overall average, in units of the overall std.
    diag_mean = np.mean(np.diag(composed))
    z_score = (diag_mean - np.mean(composed)) / np.std(composed)

    # Subspace matching between the top right singular vectors of QK and top left singular vectors of OV.
    Vt_qk = np.linalg.svd(qk_circuit)[2]
    U_ov = np.linalg.svd(ov_circuit)[0]
    s_match = np.zeros((len(ranks), 2))
    for row, rank in enumerate(ranks):
        overlap = Vt_qk[:rank, :] @ U_ov[:, :rank]
        sing_vals = np.linalg.svd(overlap)[1]
        s_match[row, 0] = sing_vals[0]
        s_match[row, 1] = np.sqrt(np.mean(sing_vals**2))

    return QK0, z_score, s_match, Vt_qk.T, U_ov

def position_matching(model, config, ranks=[10]):
    """
    position_matching calculates how much information of token embeddings is used by 1st layer attention;
        model is from class "TFModel" in the simulations
    Arguments:
        model is in the Transformer class "TFModel"
        config: model configuration; vocab_size is read, and num_heads == 1 and pos == "rotary" are required
        ranks: a list of different ranks for calculating subspace dimension (may be empty)
    Returns:
        coeffs: a (len(ranks), vocab_size + 1, 4)-shaped array. Along axis 1 the first vocab_size rows are the
            mean-centered token embeddings and the last row is the mean embedding, each scaled to unit norm.
            Along axis 2:
                0: norm of the projection onto the top-`rank` right singular vectors of W_q
                1: norm of the projection onto the top-`rank` right singular vectors of W_k
                2: norm of (normalized embedding) @ W_q  -- does not depend on rank
                3: norm of (normalized embedding) @ W_k  -- does not depend on rank
        [s_q, s_k]: a list with the singular values of W_q and of W_k
        est_ranks: a (4,)-shaped array of ranks estimated by adaptiveHardThresholding for
            W_q, W_k, wte and the normalized centered-plus-mean embedding matrix, in that order
    """
    vocab_size = config.vocab_size
    assert config.num_heads == 1, "Only single head self-attention is supported."
    assert config.pos == "rotary", "Only rotary embedding is supported."

    wte = model.embed.embed.weight.numpy(force=True)  # (vocab_size, d_model)
    mu = np.mean(wte, axis=0, keepdims=True)
    wte0 = np.concatenate((wte - mu, mu), axis=0)  # separating individual token effects and mean effect
    # Row-normalize once: this does not depend on the rank, and it is also needed after the
    # rank loop, so it must exist even when `ranks` is empty.
    wte0_nml = wte0 / np.linalg.norm(wte0, axis=1, keepdims=True)

    W_q = model.h[0].mha.W_q.weight.numpy(force=True)
    W_k = model.h[0].mha.W_k.weight.numpy(force=True)
    _, s_q, Vt_q = np.linalg.svd(W_q)
    _, s_k, Vt_k = np.linalg.svd(W_k)

    coeffs = np.zeros((len(ranks), vocab_size + 1, 4))
    # Rank-independent columns, computed once and broadcast over the rank axis.
    # NOTE(review): columns 0/1 use the right singular vectors (input side) of the weights while
    # these two multiply by W itself rather than W.T -- confirm this is the intended orientation.
    coeffs[:, :, 2] = np.linalg.norm(wte0_nml @ W_q, axis=1)
    coeffs[:, :, 3] = np.linalg.norm(wte0_nml @ W_k, axis=1)
    for i, rank in enumerate(ranks):
        for j, Vt in enumerate((Vt_q, Vt_k)):
            coeffs[i, :, j] = np.linalg.norm(wte0_nml @ Vt[:rank].T, axis=1)

    est_ranks = np.zeros(4)
    for j, mat in enumerate([W_q, W_k, wte, wte0_nml]):
        # Second argument is fixed at 20; presumably ScreeNOT's upper bound on the rank -- TODO confirm.
        _, _, r = adaptiveHardThresholding(mat, 20)
        est_ranks[j] = r

    return coeffs, [s_q, s_k], est_ranks


def match_between_layer(model, config, ranks=[10]):
    """
    match_between_layer measures how well the token-embedding subspace matches the
    output subspace of the 1st-layer OV circuit.
    Arguments:
        model is in the Transformer class "TFModel"
        config: model configuration; num_heads == 1 and pos == "rotary" are required
        ranks: a list of different ranks for calculating subspace dimension
    Returns:
        s_match: a (len(ranks), 2)-shaped array of generalized cosine similarity scores between the
            top-`rank` right singular vectors of wte and the top-`rank` left singular vectors of W_o @ W_v;
            column 0 is the largest singular value of their product, column 1 the root-mean-square of
            all its singular values
    """
    assert config.num_heads == 1, "Only single head self-attention is supported."
    assert config.pos == "rotary", "Only rotary embedding is supported."

    token_emb = model.embed.embed.weight.numpy(force=True)  # (vocab_size, d_model)
    value_w = model.h[0].mha.W_v.weight.numpy(force=True)
    out_w = model.h[0].mha.W_o.weight.numpy(force=True)

    # Only the left factor of the OV circuit and the right factor of the embedding are needed.
    ov_left = np.linalg.svd(out_w @ value_w)[0]
    emb_right_t = np.linalg.svd(token_emb)[2]

    s_match = np.zeros((len(ranks), 2))
    for row, rank in enumerate(ranks):
        overlap = emb_right_t[:rank, :] @ ov_left[:, :rank]
        sing_vals = np.linalg.svd(overlap)[1]
        s_match[row, 0] = sing_vals[0]
        s_match[row, 1] = np.sqrt(np.mean(sing_vals**2))
    return s_match


def match_baseline(d, rank, num_run=100, seed=None):
    """
    match_baseline calculates a random baseline for subspace matching: the score obtained when
    the two subspaces come from independent Gaussian matrices.
    Args:
        d is the dimension
        rank is the subspace rank
        num_run: the number of independent runs
        seed: optional integer seed. When given, a private np.random.RandomState(seed) is used so the
            baseline is reproducible and the global numpy RNG is left untouched. When None (default),
            samples are drawn from the global numpy RNG exactly as before.
    Returns:
        out: a (2,) shaped array containing two subspace matching scores averaged across runs;
            out[0] averages the largest singular value of U_1[:, :rank].T @ U_2[:, :rank],
            out[1] averages the root-mean-square of its singular values
    """
    # Both the numpy.random module and a RandomState instance expose `randn`.
    rng = np.random if seed is None else np.random.RandomState(seed)

    scores = np.zeros((num_run, 2))
    for i in range(num_run):
        # Left singular vectors of a d x d Gaussian matrix serve as a random orthonormal basis.
        U_1 = np.linalg.svd(rng.randn(d, d))[0]
        U_2 = np.linalg.svd(rng.randn(d, d))[0]
        s = np.linalg.svd(U_1[:, :rank].T @ U_2[:, :rank])[1]
        scores[i, 0] = s[0]
        scores[i, 1] = np.sqrt(np.mean(s**2))

    return np.mean(scores, axis=0)

In [4]:
def plot_dynamics_measurement(ckpt_dir, setting, save_dir, config, match_method = "largest", 
                              rel_pos = 0, ranks = [5, 10, 15, 20], num_epoch=20000, epoch_step=200):
    """
    plot_dynamics_measurement makes plot of various measurements based on checkpoints saved in the simulations
    Args:
        ckpt_dir: checkpoint directory, model should be from the Transformer class "TFModel";
            the files read are <ckpt_dir>/<setting>/err_arr.json and <ckpt_dir>/<setting>/ckpt_<step>.pt
        setting: the name of the setting for the checkpoints
        save_dir: directory for saving plots; figures go to <save_dir>/train_matching_measure_<setting>/
        config: model configuration
        match_method: which method is used to calculate subspace matching, 'largest' | 'mean'
            (any value other than 'largest' selects the mean-based score)
        rel_pos: we calculate a rotary R matrix and use it to adjust QK calculation, namely Q R K^T;
            default is 0, so that R is an identity matrix
        ranks: a list of ranks, each rank is a candidate when calculating subspace matching;
            one figure is saved per rank
        num_epoch: the total number of epochs/iterations in training
        epoch_step: the number of steps in terms of epoch/iteration between two consecutive checkpoints
    Returns:
        None; the figures are written to disk.
    """
    create_folder(save_dir)
    save_fig_dir = os.path.join(save_dir, f"train_matching_measure_{setting}")
    create_folder(save_fig_dir)

    plot_steps = num_epoch // epoch_step
    # Exactly `plot_steps` checkpoints: epoch_step, 2*epoch_step, ..., plot_steps*epoch_step.
    # Bounding by plot_steps keeps the range in sync with the arrays below even when
    # num_epoch is not a multiple of epoch_step.
    ckpt_iters = range(epoch_step, plot_steps * epoch_step + 1, epoch_step)
    k = 0 if match_method == "largest" else 1
    L = len(ranks)

    # ID/OOD test errors, one entry per training iteration.
    err_json = os.path.join(ckpt_dir, setting, "err_arr.json")
    with open(err_json) as err_file:
        errs_arr = json.load(err_file)
    errs = np.zeros((num_epoch, 2))
    errs[:, 0] = np.array([errs_arr[t]["err_test"] for t in range(num_epoch)])
    errs[:, 1] = np.array([errs_arr[t]["err_ood"] for t in range(num_epoch)])

    # NOTE(review): `ratios` and `est_ranks` are filled in below but not used by any panel.
    ratios = np.zeros((L, plot_steps, 3))
    z_scores = np.zeros((L, plot_steps))
    s_matches = np.zeros((L, plot_steps, 2))
    match_layer = np.zeros((L, plot_steps, 2))
    est_ranks = np.zeros((plot_steps, 2))
    mismatch_ratio = np.zeros(plot_steps)
    s_base = [match_baseline(config.d_model, rank, num_run=100) for rank in ranks]

    for i, step in enumerate(ckpt_iters):
        model = TFModel(config)
        ckpt_path = os.path.join(ckpt_dir, setting, f"ckpt_{step}.pt")
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
        model.eval()

        coeffs, _, est_r = position_matching(model, config, ranks=ranks)
        layer_scores = match_between_layer(model, config, ranks=ranks)

        # Index 1 on the last axis selects the key-side coefficients (0 would be the query side).
        ratios[:, i, 0] = np.mean(coeffs[:, :-1, 1], axis=1)  # mean over individual tokens
        ratios[:, i, 1] = np.std(coeffs[:, :-1, 1], axis=1)  # spread over individual tokens
        ratios[:, i, 2] = coeffs[:, -1, 1]  # last row is the mean embedding
        match_layer[:, i] = layer_scores
        est_ranks[i] = est_r[0], est_r[1]

        # QK/OV quantities need a second layer; otherwise the corresponding arrays stay at zero.
        if model.num_layers == 2:
            QK0, z_score, s_match, _, _ = QKOV_matching(model, config, ranks=ranks, rel_pos=rel_pos)
            z_scores[:, i] = z_score
            s_matches[:, i, 0] = s_match[:, 0]
            s_matches[:, i, 1] = s_match[:, 1]
            # Fraction of rows of QK0 whose largest entry is not on the diagonal.
            mismatch_ratio[i] = 1 - np.mean(np.argmax(QK0, axis=1) == np.arange(QK0.shape[0], dtype=int))

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]  # default colors
    title_fontsize = 17
    legend_fontsize = 14
    ticklabel_fontsize = 13
    legend_prop = {'size': legend_fontsize, 'weight': 'bold'}

    # Tick positions/labels are the same for every figure, so build them once.
    xtick_step = 2000
    xtick_range = list(range(0, num_epoch + xtick_step, xtick_step))
    xtick_labels = [str(int(tick / 1000)) + 'K' for tick in xtick_range]
    xtick_labels[0] = '0'

    for rank_idx, rank in enumerate(ranks):
        fig, axs = plt.subplots(1, 5, figsize=(6 * 5, 6 * 1), dpi=200)
        axs[0].plot(range(num_epoch), errs[:, 0], linewidth=4, color=colors[0], label="ID")
        axs[0].plot(range(num_epoch), errs[:, 1], linewidth=4, color=colors[1], label="OOD")
        axs[0].set_title("ID/OOD test errors", weight="bold", fontsize=title_fontsize)
        axs[0].legend(prop=legend_prop)
        axs[0].set_xscale("log")
        axs[1].plot(ckpt_iters, z_scores[rank_idx], "-o")
        axs[1].set_title("Diagonal prefactor z-score", weight="bold", fontsize=title_fontsize)
        axs[2].plot(ckpt_iters, s_matches[rank_idx, :, k], "-o", color=colors[0], label="QK-OV")
        axs[2].axhline(y=s_base[rank_idx][k], linestyle="--", color="gray", label="baseline")
        axs[2].set_title("Matching QK and OV subspaces", weight="bold", fontsize=title_fontsize)
        axs[2].legend(prop=legend_prop)  # the curves are labelled, so show what they are
        axs[3].plot(ckpt_iters, mismatch_ratio, "-o", color=colors[0])
        axs[3].set_title("Token mismatch ratio at 2nd layer", weight="bold", fontsize=title_fontsize)
        axs[4].plot(ckpt_iters, match_layer[rank_idx, :, k], "-o", color=colors[0], label="layer-match")
        axs[4].axhline(y=s_base[rank_idx][k], linestyle="--", color="gray", label="baseline")
        axs[4].set_title("Matching wte and 1st OV", weight="bold", fontsize=title_fontsize)
        axs[4].legend(prop=legend_prop)

        for idx_plt in range(5):
            axs[idx_plt].tick_params(axis='both', which='major', labelsize=ticklabel_fontsize)
            # Panel 0 uses a log x-axis and keeps its automatic ticks.
            if idx_plt > 0:
                axs[idx_plt].set_xticks(xtick_range)
                axs[idx_plt].set_xticklabels(xtick_labels)

        fig.savefig(
            os.path.join(save_fig_dir, f"rank_{rank}_method_{match_method}_rel_pos_{rel_pos}"),
            bbox_inches="tight",
        )
        # Close this specific figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [5]:
# Options shared by every plotting call in the cells below.
setting_dir = "infinite_pool_size"
match_method = "mean"

# Build the model configuration from the JSON saved next to this setting's checkpoints.
# NOTE(review): the same config is reused for all settings below -- presumably they share
# one architecture; confirm.
with open(os.path.join("out_phase_progress_long", setting_dir, "config.json"), "r") as config_file:
    config = Config(**json.load(config_file))

# Checkpoints are analysed on the CPU.
config.device = "cpu"

In [6]:
ckpt_dir = "out_phase_progress_long"
save_dir = "Figs_measurement"

# One set of figures per training setting, using the default relative position (0).
for setting in (
    "infinite_pool_size",
    "large_pool_size",
    "small_pool_size",
    "boundary_pool_size_740",
    "boundary_pool_size_750",
):
    plot_dynamics_measurement(ckpt_dir, setting, save_dir, config, match_method=match_method)

In [7]:
ckpt_dir = "out_phase_progress_long"
save_dir = "Figs_measurement/varying_rel_pos"

# Repeat the measurement plots for relative positions 0, 5, ..., 35 and for every setting.
for rel_pos in tqdm(range(0, 40, 5)):
    for setting in (
        "infinite_pool_size",
        "large_pool_size",
        "small_pool_size",
        "boundary_pool_size_740",
        "boundary_pool_size_750",
    ):
        plot_dynamics_measurement(
            ckpt_dir, setting, save_dir, config, rel_pos=rel_pos, match_method=match_method
        )

100%|██████████| 8/8 [07:02<00:00, 52.78s/it]
