In [13]:
from collections import defaultdict
from itertools import product

import numpy as np
import pandas as pd
from scipy import interpolate as interpolate
from scipy import stats as stats

import wandb

# Authenticate with Weights & Biases, then create the public-API client that
# the following cells use to fetch run histories.
wandb.login()
api = wandb.Api()

In [4]:
def _history_frame(run_path: str) -> pd.DataFrame:
    """Fetch the W&B run at ``run_path`` and return its scanned history as a DataFrame."""
    return pd.DataFrame(api.run(run_path).scan_history())


# One history frame per run; the names say which experiment/model size each is.
tiny_lora_weights = _history_frame("pico-lm/pico-relora/ixdq4sof")
small_lora_weights = _history_frame("pico-lm/pico-relora/a8nvp3mi")
baseline_tiny_weights = _history_frame("pico-lm/pico-relora/e8rwqdwk")
baseline_small_weights = _history_frame("pico-lm/pico-relora/evrk6fbj")

In [5]:
# Run histories indexed as weights[experiment][model_size].
weights = dict(
    base=dict(tiny=baseline_tiny_weights, small=baseline_small_weights),
    relora=dict(tiny=tiny_lora_weights, small=small_lora_weights),
)

In [6]:
def get_across_ranges(df: pd.DataFrame, metric_name: str, step_mod: int = 1, *ranges):
    """Gather one metric column per index combination into a 2-D float array.

    ``metric_name`` is a ``str.format`` template. It is formatted once for each
    element of the Cartesian product of ``ranges``, and the resulting column is
    read from ``df``.

    Args:
        df: History frame. Must contain a ``_step`` column and every column
            named by the formatted template.
        metric_name: Column-name template with one ``{}`` slot per entry of
            ``ranges``.
        step_mod: Only rows whose ``_step`` is a multiple of this value are kept.
        *ranges: Iterables of format arguments. When omitted, a single
            ``range(12)`` is used.

    Returns:
        ``float64`` array of shape ``(kept_rows, combinations)``; column order
        follows ``itertools.product(*ranges)``. Entries that are missing or are
        not parseable as numbers (including the string ``"NaN"``) are ``nan``.

    Raises:
        KeyError: If ``_step`` or a formatted metric column is absent from ``df``.
    """
    if not ranges:
        ranges = (range(12),)

    kept = df[df["_step"] % step_mod == 0]

    # Convert explicitly instead of relying on ``Series.replace`` to downcast an
    # object column: without that implicit downcast the stacked result would be
    # an object array that numeric reductions cannot handle.
    columns = [
        pd.to_numeric(kept[metric_name.format(*combo)], errors="coerce").to_numpy(
            dtype=float, na_value=np.nan
        )
        for combo in product(*ranges)
    ]
    return np.column_stack(columns)

# Sparsity

In [7]:
# For each sparsity metric and ReLoRA model scale, rank-correlate the per-layer
# metric value at the second sampled step with the layer index.
layer_ids = np.arange(12)
for met in ("hoyer", "gini"):
    metrstr = f"{met}_val/_forward_module.module.layers.{{}}.ov_circuit.base.weights"
    for scale in ("tiny", "small"):
        if scale == "small":
            history = small_lora_weights
        else:
            history = tiny_lora_weights

        # Transposed so that rows are layers and columns are sampled steps.
        per_layer = get_across_ranges(history, metrstr, step_mod=20_000).T

        rho = stats.spearmanr(per_layer[:, 1], layer_ids).statistic
        print(f"{met.title()} @ {scale.title()}: r = {rho:.5f}")

Hoyer @ Tiny: r = 0.81818
Hoyer @ Small: r = 0.86014
Gini @ Tiny: r = 0.70878
Gini @ Small: r = 0.89123


# Step smoothness

In [8]:
def evenness_metric(x: np.ndarray):
    """Score how uniform the successive increments along the last axis are.

    Takes first differences along the last axis, divides their standard
    deviation by their mean absolute value (plus ``1e-10`` to avoid dividing by
    zero), and maps that ratio ``c`` to ``exp(-c)``.

    Args:
        x: Array whose last axis is the sequence to difference.

    Returns:
        ``exp(-c)`` with the last axis reduced away; equal increments give 1.0,
        and the value shrinks towards 0 as the increments become more variable.
    """
    increments = np.diff(x, n=1, axis=-1)
    spread = np.std(increments, axis=-1)
    typical_size = np.mean(np.abs(increments), axis=-1)
    return np.exp(-(spread / (typical_size + 1e-10)))

In [None]:
def smoothness_spline(
    steps: np.ndarray, y: np.ndarray, s: float = 0.1, num_points: int = 1_000
):
    """Summarise the curvature of a smoothing spline fitted to ``(steps, y)``.

    A smoothing spline is fitted with ``scipy.interpolate.splrep`` and its
    second derivative is evaluated on ``num_points`` evenly spaced points
    between ``steps.min()`` and ``steps.max()``.

    Args:
        steps: 1-D x-coordinates of the samples.
        y: 1-D values at ``steps``.
        s: Smoothing factor forwarded to ``splrep``.
        num_points: Number of points in the evaluation grid.

    Returns:
        Tuple ``(total_curvature, bending_energy, max_curvature)``:

        * ``total_curvature``: sum of absolute differences between consecutive
          second-derivative samples (lower means smoother);
        * ``bending_energy``: mean squared second derivative;
        * ``max_curvature``: largest absolute second derivative.

    Raises:
        NotImplementedError: If ``y`` has more than one dimension.
    """
    # Accept lists and other array-likes, not only ndarrays.
    steps = np.asarray(steps, dtype=float)
    y = np.asarray(y, dtype=float)

    if y.ndim > 1:
        # NotImplementedError subclasses Exception, so existing handlers still match.
        raise NotImplementedError("not implemented this way")

    x_fine = np.linspace(steps.min(), steps.max(), num_points)
    smooth_spline = interpolate.splrep(steps, y, s=s)
    y_smooth_d2 = interpolate.splev(x_fine, smooth_spline, der=2)

    total_curvature = np.sum(np.abs(np.diff(y_smooth_d2)))
    bending_energy = np.mean(y_smooth_d2**2)
    max_curvature = np.max(np.abs(y_smooth_d2))

    return (total_curvature, bending_energy, max_curvature)

In [16]:
# Final-step sparsity per layer for the baseline and ReLoRA runs.
# For every (experiment, metric, model size) this prints the mean of the
# final-step values across layers and the Spearman rank correlation between
# those values and the layer index.

# Reference ranking for the correlations; get_across_ranges defaults to
# layer indices 0..11.
layer_idx = np.arange(12)

final_at_metr = defaultdict(list)
for exp in ("base", "relora"):
    print(exp.title())

    # ReLoRA metric keys carry an extra ".base" path segment before ".weights".
    module_suffix = ".base" if exp == "relora" else ""
    # ReLoRA histories are subsampled every 2,000 steps, baselines every 1,000.
    # NOTE(review): presumably this mirrors each run's logging cadence; confirm.
    step_mod = 2_000 if exp == "relora" else 1_000

    # "per", "norm" and "condition_number" are left out of this pass.
    for metr_name in ("hoyer", "gini"):
        metrstr = (
            f"{metr_name}_val/_forward_module.module.layers.{{}}.ov_circuit"
            f"{module_suffix}.weights"
        )

        for model_size in ("tiny", "small"):
            # Transposed so that rows are layers and columns are sampled steps.
            metr = get_across_ranges(weights[exp][model_size], metrstr, step_mod=step_mod).T

            fst = metr[:, 0]
            lst = metr[:, -1]

            # Computed for inspection; not part of the printed report.
            evenness = evenness_metric(metr)
            # Relative change from first to last sampled step. The parentheses
            # matter: ``lst - fst / fst`` is just ``lst - 1``.
            inc_r = stats.spearmanr((lst - fst) / fst, layer_idx)

            fin_r = stats.spearmanr(lst, layer_idx)

            print(f"{metr_name.title()} @ {model_size.title()}:")
            print(f"average sparsity: {np.mean(lst)}")
            print(f"fin_r: {fin_r.statistic:.5f}")

            final_at_metr[metr_name].append(lst)

# Entries were appended in the order (base, tiny), (base, small),
# (relora, tiny), (relora, small). vals[0] and vals[1] are therefore the
# baseline tiny and small runs, so this prints the relative change in mean
# final sparsity from tiny to small for the baseline only.
# NOTE(review): the relora entries (vals[2], vals[3]) are never read here;
# confirm whether a base-vs-relora comparison was intended.
for metr_label, vals in final_at_metr.items():
    print(f"{metr_label} diff: {np.mean(vals[1] - vals[0]) / np.mean(vals[0])}")

Base
Hoyer @ Tiny:
average sparsity: 0.23147462099296437
fin_r: 0.95105
Hoyer @ Small:
average sparsity: 0.24120502327062546
fin_r: 0.81818
Gini @ Tiny:
average sparsity: 0.4318033854166667
fin_r: 0.95608
Gini @ Small:
average sparsity: 0.4353841145833333
fin_r: 0.82312
Relora
Hoyer @ Tiny:
average sparsity: 0.2463966064736777
fin_r: 0.81818
Hoyer @ Small:
average sparsity: 0.24472922463893085
fin_r: 0.86014
Gini @ Tiny:
average sparsity: 0.4427083333333333
fin_r: 0.70878
Gini @ Small:
average sparsity: 0.4417317708333333
fin_r: 0.89123
hoyer diff: 0.0420365836907747
gini diff: 0.008292499057670561
