In [1]:
import cvxpy as cp
import numpy as np
import torch
from datasets import load_from_disk
from tqdm import tqdm

from constants import linear_probe_dataset_path, linear_probe_weights

# Layer index passed to the dataset-path and weights-path helpers in the cells below.
LAYER = 33


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the probe dataset for the configured layer and build the design matrix.
new_ds = load_from_disk(linear_probe_dataset_path("gb1", LAYER, "sae"))
X_SPZ = np.array(new_ds["latents"][:])

# Sanity check on the features. Previously this was a bare expression whose
# result was silently discarded in the middle of the cell; make it fail loudly.
assert np.all(X_SPZ >= 0), "expected all latent activations to be non-negative"

# Average over axis 1 to get one feature vector per sample
# (presumably axis 1 is the sequence-position axis -- TODO confirm).
X_SZ = X_SPZ.mean(axis=1)
y_S = np.array(new_ds["label"][:])
print(X_SZ.shape, y_S.shape)

(100, 4096) (100,)


In [3]:
# Split the dataset using the precomputed "stage" column of the dataset.
n = X_SZ.shape[0]

# Materialise the column once and build boolean masks with vectorised
# comparisons (the original rebuilt the same list several times and also
# computed a random permutation that was never used).
stage = np.asarray(new_ds["stage"][:])
assert stage.shape[0] == n, "stage column length must match the number of samples"

idx_train = stage == "train"
idx_val = stage == "valid"
idx_test = stage == "test"

Xtr, ytr = X_SZ[idx_train], y_S[idx_train]
Xva, yva = X_SZ[idx_val], y_S[idx_val]
Xte, yte = X_SZ[idx_test], y_S[idx_test]

In [4]:
# Build the CVXPY problem: least squares on the training split plus an
# L1 penalty, with the weight vector constrained to be non-negative.
d = X_SZ.shape[1]

W = cp.Variable(d, nonneg=True)
# `lam` is a Parameter so its value can be changed between solves without
# rebuilding the problem.
lam = cp.Parameter(nonneg=True)

train_residual = Xtr @ W - ytr
fit_term = cp.sum_squares(train_residual)
l1_term = lam * cp.norm1(W)
objective = cp.Minimize(fit_term + l1_term)

prob = cp.Problem(objective)


In [5]:
# Validation loop: solve the probe for each candidate L1 strength and keep the
# one with the lowest mean squared error on the validation split.
# (`tqdm` is imported in the notebook's top import cell.)
lams = [0.0, 1e-4, 1e-3, 1e-2, 1e-1]
best_lam, best_val = None, np.inf

for lam_value in tqdm(lams):
    lam.value = lam_value
    prob.solve(warm_start=True)

    # The solver may stop without a usable iterate; skip that candidate
    # instead of failing on `Xva @ None`.
    if W.value is None:
        print(f"lam={lam_value:g}: no solution (status={prob.status}); skipping")
        continue
    # A solution can be present without being optimal (e.g. iteration limit
    # reached); still score it, but make that visible.
    if prob.status != cp.OPTIMAL:
        print(f"lam={lam_value:g}: solver status is '{prob.status}'; result may be inaccurate")

    yhat_val = Xva @ W.value
    val_err = np.mean((yhat_val - yva) ** 2)
    print(val_err)

    if val_err < best_val:
        best_val = val_err
        best_lam = lam_value

# Fail here rather than letting `None` flow into the refit cells.
if best_lam is None:
    raise RuntimeError("No candidate lambda produced a usable solution.")
print(best_lam)

 20%|██        | 1/5 [00:05<00:23,  5.79s/it]

0.6226395685717504


 40%|████      | 2/5 [00:11<00:17,  5.80s/it]

0.5196368903289512


 60%|██████    | 3/5 [00:17<00:11,  5.79s/it]

0.5196792914317934


 80%|████████  | 4/5 [00:23<00:05,  5.78s/it]

0.48469399038244226


100%|██████████| 5/5 [00:28<00:00,  5.77s/it]

0.6556157801726563
0.01





In [6]:
# Refit on train + validation with the selected L1 strength, printing the
# solver log for inspection.
Xtv = np.vstack([Xtr, Xva])
ytv = np.concatenate([ytr, yva])

# Fresh variable/parameter: the earlier problem was built on the training split only.
W = cp.Variable(d, nonneg=True)
lam = cp.Parameter(nonneg=True)

objective = cp.Minimize(
    cp.sum_squares(Xtv @ W - ytv) + lam * cp.norm1(W)
)

prob = cp.Problem(objective)

lam.value = best_lam
prob.solve(verbose=True)

# Do not silently accept a missing or non-converged solution.
if W.value is None:
    raise RuntimeError(f"Refit produced no solution (status={prob.status}).")
if prob.status != cp.OPTIMAL:
    print(f"Solver status is '{prob.status}'; the refit weights may be inaccurate.")

W_star = W.value

(CVXPY) Jan 03 09:12:06 PM: Your problem has 4096 variables, 0 constraints, and 1 parameters.
(CVXPY) Jan 03 09:12:06 PM: It is compliant with the following grammars: DCP, DQCP
(CVXPY) Jan 03 09:12:06 PM: CVXPY will first compile your problem; then, it will invoke a numerical solver to obtain a solution.
(CVXPY) Jan 03 09:12:06 PM: Your problem is compiled with the CPP canonicalization backend.
(CVXPY) Jan 03 09:12:06 PM: Compiling problem (target solver=OSQP).
(CVXPY) Jan 03 09:12:06 PM: Reduction chain: CvxAttr2Constr -> Qp2SymbolicQp -> QpMatrixStuffing -> OSQP
(CVXPY) Jan 03 09:12:06 PM: Applying reduction CvxAttr2Constr
(CVXPY) Jan 03 09:12:06 PM: Applying reduction Qp2SymbolicQp
(CVXPY) Jan 03 09:12:06 PM: Applying reduction QpMatrixStuffing
(CVXPY) Jan 03 09:12:06 PM: Applying reduction OSQP
(CVXPY) Jan 03 09:12:06 PM: Finished problem compilation (took 2.997e-02 seconds).
(CVXPY) Jan 03 09:12:06 PM: (Subsequent compilations of this problem, using the same arguments, should take

                                     CVXPY                                     
                                     v1.7.5                                    
-------------------------------------------------------------------------------
                                  Compilation                                  
-------------------------------------------------------------------------------
-------------------------------------------------------------------------------
                                Numerical solver                               
-------------------------------------------------------------------------------
-----------------------------------------------------------------
           OSQP v1.0.0  -  Operator Splitting QP Solver
              (c) The OSQP Developer Team
-----------------------------------------------------------------
problem:  variables n = 8286, constraints m = 12382
          nnz(P) + nnz(A) = 158327
settings: algebra = Built-in,
          OSQ

(CVXPY) Jan 03 09:12:12 PM: Problem status: user_limit
(CVXPY) Jan 03 09:12:12 PM: Optimal value: 5.916e+02
(CVXPY) Jan 03 09:12:12 PM: Compilation took 2.997e-02 seconds
(CVXPY) Jan 03 09:12:12 PM: Solver (including time spent in interface) took 6.235e+00 seconds


10000   6.2601e+00   2.33e-02   1.57e-05  -2.53e+00   2.33e-02   5.04e-01    6.23e+00s

status:               maximum iterations reached
number of iterations: 10000
run time:             6.23e+00s
optimal rho estimate: 9.51e-01

-------------------------------------------------------------------------------
                                    Summary                                    
-------------------------------------------------------------------------------


In [11]:
# Refit on train + validation with the selected L1 strength and save the
# resulting weight vector to disk.
Xtv = np.vstack([Xtr, Xva])
ytv = np.concatenate([ytr, yva])

# Fresh variable/parameter for the combined-split problem.
W = cp.Variable(d, nonneg=True)
lam = cp.Parameter(nonneg=True)

objective = cp.Minimize(
    cp.sum_squares(Xtv @ W - ytv) + lam * cp.norm1(W)
)

prob = cp.Problem(objective)

lam.value = best_lam
prob.solve()

# Check the solve outcome before persisting anything: without this, a failed
# solve surfaces as an opaque error from `torch.tensor(None)`, and a
# non-converged one is saved without any notice.
if W.value is None:
    raise RuntimeError(f"Refit produced no solution (status={prob.status}); nothing saved.")
if prob.status != cp.OPTIMAL:
    print(f"Solver status is '{prob.status}'; the saved weights may be inaccurate.")

W_star = torch.tensor(W.value)
model_path = linear_probe_weights("gb1", LAYER, "sae")
torch.save(W_star, model_path)
print("Saved to: ", model_path)
print("Weight vector shape: ", W_star.shape)

Saved to:  /data/ishan/barSAElona_GUIDEola/gb1_l33_sae_act.pt
Weight vector shape:  torch.Size([4096])


