In [1]:
"""
algorithms :

    SAE training :

dataset = base model dataset
activations = truncated_model(dataset)
SAE fit activations

    probe training :

dataset = string + label
activations = truncated_model(dataset)
    -> stop_at_layer (from transformer lens)
optional : activations = SAE(activations)
activations = trivial_normalisation(activations)

dataset = activations + label
loss = CCS or other
probe = probe
train probe

test generalisation
"""

'\nalgorithms :\n\n    SAE training :\n\ndataset = base model dataset\nactivations = truncated_model(dataset)\nSAE fit activations\n\n    probe training :\n\ndataset = string + label\nactivations = truncated_model(dataset)\n    -> stop_at_layer (from transformer lens)\noptional : activations = SAE(activations)\nactivations = trivial_normalisation(activations)\n\ndataset = activations + label\nloss = CCS or other\nprobe = probe\ntrain probe\n\ntest generalisation\n'

In [2]:
import torch
import transformer_lens as tl

from tqdm import tqdm

from probes import LinearProbe
from losses import L_CCS
from dataset_loader import test_generalisation

In [3]:
def get_activations(model, dataset, stop_at_layer=None):
    """
    Run `model` on `dataset` and hand back whatever the forward call returns.

    `stop_at_layer` is forwarded unchanged as a keyword argument. For a
    TransformerLens HookedTransformer this presumably truncates the forward
    pass at that layer (returning intermediate activations instead of logits)
    — confirm against the installed TransformerLens version.
    Dataset should be a tensor.
    """
    return model(dataset, stop_at_layer=stop_at_layer)

def get_sparse_activations(activations, sae):
    """
    Pass `activations` through the sparse autoencoder `sae`.

    This is a plain call `sae(activations)`; whether the result is the SAE's
    feature codes or its reconstruction depends on the SAE implementation.
    """
    encoded = sae(activations)
    return encoded

def normalisation(data, eps=1e-8):
    """
    Standardise each column of `data` to zero mean and unit variance.

    Statistics are taken over dim 0 (treated as the sample axis), using
    torch.std's default (unbiased) estimator.

    Args:
        data: tensor whose first dimension indexes samples.
        eps: lower bound applied to the per-column std. A constant column
            (std == 0) would otherwise give 0/0 = NaN; with the bound it maps
            to zeros. Columns whose std is >= eps are unaffected.

    Returns:
        Tensor with the same shape as `data`.
    """
    mu = torch.mean(data, dim=0, keepdim=True)
    # Guard against zero-variance columns (e.g. dead features) -> NaN.
    std = torch.std(data, dim=0, keepdim=True).clamp_min(eps)
    return (data - mu) / std

def anisotropic_normalisation(data, rtol=None):
    """
    Whiten `data` so that its sample covariance becomes the identity (ZCA).

    Expects a 2-D tensor of shape (n_points, n_features); each row is a
    sample. torch.cov (used below) accepts at most 2 dimensions.

    Derivation: with X the centred data (features x points) and
    C = (N-1)^-1 X X^T its covariance, we want A such that
    (N-1)^-1 A X X^T A^T = A C A^T = Id. The symmetric choice
    A = C^{-1/2} = V diag(L)^{-1/2} V^T works, where C = V diag(L) V^T is the
    eigendecomposition of the symmetric PSD covariance.

    Directions whose eigenvalue is <= rtol * max(L) carry (numerically) no
    variance. This happens for features that are constant over the dataset,
    or whenever n_points <= n_features. Such directions are projected to zero
    instead of being divided by ~0, so the whitened covariance is the identity
    on the remaining subspace and zero elsewhere. The previous
    `torch.inverse(cov)` raised on these inputs.

    Args:
        data: (n_points, n_features) floating-point tensor.
        rtol: relative eigenvalue cut-off. Defaults to
            n_features * machine epsilon of the covariance dtype.

    Returns:
        Tensor of shape (n_points, n_features).
    """
    n_features = data.shape[1]

    mu = torch.mean(data, dim=0, keepdim=True)
    centred = data - mu
    # torch.cov squeezes a single-variable result to 0-d; restore (F, F).
    cov = torch.cov(centred.T).reshape(n_features, n_features)

    # eigh is the symmetric solver: real ascending eigenvalues and orthonormal
    # eigenvectors. The general `eig` returns complex output and does not
    # guarantee orthogonal eigenvectors for repeated eigenvalues.
    eigvals, eigvecs = torch.linalg.eigh(cov)

    if rtol is None:
        rtol = n_features * torch.finfo(cov.dtype).eps
    cutoff = eigvals.max().clamp_min(0) * rtol
    keep = eigvals > cutoff

    # Pseudo-inverse square root: 1/sqrt(lambda) on kept directions, 0 elsewhere.
    inv_sqrt_eigvals = torch.zeros_like(eigvals)
    inv_sqrt_eigvals[keep] = eigvals[keep].rsqrt()

    # V diag(inv_sqrt) V^T, via column scaling instead of building a diagonal.
    whitening = (eigvecs * inv_sqrt_eigvals) @ eigvecs.T

    # Row-vector form of (A @ centred.T).T.
    return centred @ whitening.T

In [None]:
def train_probe(probe, train_loader, test_loader, loss_fn, n_epochs=10, lr=1e-3, verbose=False):
    """
    Train `probe` with Adam on contrast pairs and track the held-out loss.

    Each batch yielded by either loader is unpacked as `((pos, neg), target)`.
    The probe is applied to both halves of the pair and the outputs are passed
    through a sigmoid. `loss_fn(pos_out, neg_out)` is then minimised. `target`
    is unpacked but not used by the loss.

    Args:
        probe: model exposing `.parameters()`; called on `pos` / `neg`.
        train_loader: iterable of `((pos, neg), target)` training batches.
        test_loader: iterable of `((pos, neg), target)` evaluation batches.
        loss_fn: callable `(pos_out, neg_out) -> scalar tensor`.
        n_epochs: number of passes over `train_loader`.
        lr: Adam learning rate.
        verbose: if True, wrap epochs in a tqdm bar showing the latest test loss.

    Returns:
        List with one float per epoch: the test loss averaged over test
        batches, weighted by batch size. This weighting assumes `loss_fn`
        averages over the batch. The entry is NaN if `test_loader` yields
        nothing.
    """
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    test_losses = []

    # Named `epochs` rather than `iter` to avoid shadowing the builtin.
    epochs = tqdm(range(n_epochs)) if verbose else range(n_epochs)
    for _ in epochs:
        for (pos, neg), _target in train_loader:
            optimizer.zero_grad()
            pos_out = torch.sigmoid(probe(pos))
            neg_out = torch.sigmoid(probe(neg))

            loss = loss_fn(pos_out, neg_out)
            loss.backward()
            optimizer.step()

        # Evaluation: previously computed then discarded; now aggregated and returned.
        total, count = 0.0, 0
        with torch.no_grad():
            for (pos, neg), _target in test_loader:
                pos_out = torch.sigmoid(probe(pos))
                neg_out = torch.sigmoid(probe(neg))
                batch_size = len(pos)
                total += loss_fn(pos_out, neg_out).item() * batch_size
                count += batch_size
        test_loss = total / count if count else float("nan")
        test_losses.append(test_loss)

        if verbose:
            epochs.set_postfix(test_loss=test_loss)

    return test_losses

In [None]:
# Pipeline inputs — replace each `...` before running the cells below.
# (The original `model = #...` lines were syntax errors.)
model = ...      # TODO: e.g. tl.HookedTransformer.from_pretrained("gpt2")
dataset = ...    # TODO: model input passed to get_activations
sae_layer = ...  # TODO: value passed as get_activations(..., stop_at_layer=sae_layer)
sae = ...        # TODO: SAE callable applied to those activations

# Fail with a clear message instead of an AttributeError on `...` further down.
_unset = [
    name
    for name, value in (("model", model), ("dataset", dataset), ("sae_layer", sae_layer), ("sae", sae))
    if value is ...
]
if _unset:
    raise NotImplementedError(f"Set {', '.join(_unset)} before running the probe pipeline.")

# NOTE(review): the probe is sized to the residual stream (d_model), but the
# pipeline applies `sae(...)` to the activations before probing. If the SAE
# returns feature codes rather than a d_model-sized reconstruction, the probe
# input size must be the SAE width instead — TODO confirm.
d_resid = model.cfg.d_model
probe = LinearProbe(d_resid, 1)

In [None]:
# Residual stream at `sae_layer` -> SAE -> whitening, one named stage per step.
# NOTE(review): anisotropic_normalisation calls torch.cov, which needs a 2-D
# (n_points, n_features) tensor. If get_activations returns per-token
# activations, select or pool a token position first — TODO confirm the shape.
resid_activations = get_activations(model, dataset, stop_at_layer=sae_layer)
sae_activations = get_sparse_activations(resid_activations, sae)
activations = anisotropic_normalisation(sae_activations)

In [None]:
# DataLoader has no `.split` method (the original line raised AttributeError).
# Split the underlying data instead, then build one loader per part.
BATCH_SIZE = 32
TRAIN_FRACTION = 0.8
SPLIT_SEED = 0

n_train = int(TRAIN_FRACTION * len(activations))
train_set, test_set = torch.utils.data.random_split(
    activations,
    [n_train, len(activations) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible split
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)

# NOTE(review): train_probe unpacks every batch as `((pos, neg), target)`, but
# batches built directly from `activations` are plain tensors. Pair the
# contrast activations with their labels (e.g. a Dataset returning
# ((pos, neg), label)) before training — TODO.

In [None]:
# Fit the probe with the CCS loss, showing a progress bar per epoch.
train_probe(
    probe,
    train_loader,
    test_loader,
    loss_fn=L_CCS,
    n_epochs=10,
    lr=1e-3,
    verbose=True,
)

In [None]:
# Test generalisation of the trained probe on held-out datasets.
# TODO: build `datasets`, re-running the same pipeline on each one
# (get_activations -> SAE -> anisotropic_normalisation).
# (The original `datasets = #...` line was a syntax error.)
datasets = ...
if datasets is ...:
    raise NotImplementedError(
        "Set `datasets` (with activations, SAE and normalisation re-applied) before testing generalisation."
    )

test_generalisation(datasets, probe)

In [5]:
# Load GPT-2 small through TransformerLens and build a fresh single-output
# probe sized to its d_model.
# NOTE: this rebinds `probe`, replacing any probe trained in earlier cells.
gpt2small = tl.HookedTransformer.from_pretrained("gpt2")
gpt2_cfg = gpt2small.cfg
d_resid = gpt2_cfg.d_model
probe = LinearProbe(d_resid, 1)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [6]:
# TransformerLens's list of official model names (it includes "gpt2", loaded above).
official_model_names = tl.loading_from_pretrained.OFFICIAL_MODEL_NAMES
official_model_names

['gpt2',
 'gpt2-medium',
 'gpt2-large',
 'gpt2-xl',
 'distilgpt2',
 'facebook/opt-125m',
 'facebook/opt-1.3b',
 'facebook/opt-2.7b',
 'facebook/opt-6.7b',
 'facebook/opt-13b',
 'facebook/opt-30b',
 'facebook/opt-66b',
 'EleutherAI/gpt-neo-125M',
 'EleutherAI/gpt-neo-1.3B',
 'EleutherAI/gpt-neo-2.7B',
 'EleutherAI/gpt-j-6B',
 'EleutherAI/gpt-neox-20b',
 'stanford-crfm/alias-gpt2-small-x21',
 'stanford-crfm/battlestar-gpt2-small-x49',
 'stanford-crfm/caprica-gpt2-small-x81',
 'stanford-crfm/darkmatter-gpt2-small-x343',
 'stanford-crfm/expanse-gpt2-small-x777',
 'stanford-crfm/arwen-gpt2-medium-x21',
 'stanford-crfm/beren-gpt2-medium-x49',
 'stanford-crfm/celebrimbor-gpt2-medium-x81',
 'stanford-crfm/durin-gpt2-medium-x343',
 'stanford-crfm/eowyn-gpt2-medium-x777',
 'EleutherAI/pythia-14m',
 'EleutherAI/pythia-31m',
 'EleutherAI/pythia-70m',
 'EleutherAI/pythia-160m',
 'EleutherAI/pythia-410m',
 'EleutherAI/pythia-1b',
 'EleutherAI/pythia-1.4b',
 'EleutherAI/pythia-2.8b',
 'EleutherAI/pyt