# Imports, load dataset & model

In [1]:
import random

import torch as th
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import plotly.express as px

# Project imports
from generate_acts import load_model
from probes import LRProbe


# Seed every RNG this notebook relies on (stdlib, NumPy, PyTorch) with the same value
# so that shuffling and any stochastic fitting are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if th.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

Using cuda


In [2]:
# Load tokenizer and model through the project helper (presumably placing the model on
# `device` — confirm against generate_acts.load_model).
tokenizer, model = load_model("EleutherAI/pythia-160m-alldropout", device=device)
# Reuse EOS as the padding token so batched tokenization with padding=True works.
# Consequence: pad_token_id == eos_token_id, so padding cannot be distinguished from a
# genuine EOS token by id alone; use the attention mask when locating real tokens.
tokenizer.pad_token = tokenizer.eos_token
# Indices into the model's `hidden_states` tuple: Hugging Face models return
# num_hidden_layers + 1 entries (the embedding output plus one per transformer layer).
layers = list(range(model.config.num_hidden_layers + 1))
# Put the module in evaluation mode (standard PyTorch: disables dropout etc.).
model.eval()

def load_dataset(dataset_name):
    """
    Read one of the project's CSV datasets into a DataFrame.

    :param dataset_name: file stem of the CSV inside the ``datasets/`` directory
        (resolved relative to the current working directory)
    :return: the parsed ``pandas.DataFrame``
    """
    csv_path = f"datasets/{dataset_name}.csv"
    frame = pd.read_csv(csv_path)
    return frame
# Read the two statement sets from disk.
cities_alice = load_dataset("cities_alice")
neg_cities_alice = load_dataset("neg_cities_alice")

# Stack both sets and shuffle the rows once with a torch permutation, so that the
# positional train/test slices taken later mix rows from both files.
all_cities = pd.concat([cities_alice, neg_cities_alice])
perm = th.randperm(len(all_cities))
all_cities = all_cities.iloc[perm]

# Raw text fed to the tokenizer, in shuffled order.
statements = all_cities["statement"].tolist()


def _label_tensor(column):
    """Return column `column` of the shuffled frame as a float tensor on `device`."""
    values = all_cities[column].tolist()
    return th.tensor(values).float().to(device)


# One float label vector per probing target, row-aligned with `statements`.
label_dict = {
    column: _label_tensor(column)
    for column in ("has_alice", "has_not", "has_alice xor has_not")
}
has_alice = label_dict["has_alice"]
has_not = label_dict["has_not"]
not_xor_alice = label_dict["has_alice xor has_not"]

Loading model EleutherAI/pythia-160m-alldropout...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Get activations and plot probe accuracy

In [3]:
@th.no_grad()
def get_acts(statements, *, period=True, using_bos=False):
    """
    Returns a list of activations for each layer of the model.
    Each element in the list is a tensor of shape (num_statements, hidden_size)

    The position of the last real token of every statement is derived from the
    tokenizer's attention mask. Counting ``input_ids != pad_token_id`` is not reliable
    here because the pad token is set to EOS: whenever BOS and EOS share an id (as in
    GPT-NeoX-style tokenizers) a prepended BOS would be counted as padding and the
    selected position would be off by one. The attention mask also makes the lookup
    independent of the tokenizer's padding side.

    :param statements: list of strings to get activations for
    :param period: whether to get activations for the period token. If False, gets activations for the token before the period.
    :param using_bos: if True, prepends the BOS token to each statement
    :raises ValueError: if ``using_bos`` is True but the tokenizer defines no BOS token
    :return: list indexed by layer; entry ``l`` holds the hidden state of layer ``l``
        at the selected token of every statement, concatenated over batches
    """
    if using_bos and tokenizer.bos_token is None:
        raise ValueError("using_bos=True but the tokenizer has no BOS token")

    acts = [[] for _ in layers]

    dataloader = DataLoader(
        statements,
        batch_size=128,
        shuffle=False,  # keep activations row-aligned with `statements`
    )

    bos_token = tokenizer.bos_token if using_bos else ""
    # Distance of the wanted token from the last real token: 0 selects the final token
    # (the period), 1 selects the token just before it.
    offset = 0 if period else 1

    for statement_batch in dataloader:
        # Build a fresh list instead of mutating the collated batch in place.
        statement_batch = [bos_token + statement for statement in statement_batch]

        batch = tokenizer(statement_batch, return_tensors="pt", padding=True).to(device)

        # Largest position whose attention mask is 1 == index of the last real token,
        # regardless of which side the padding was added on.
        positions = th.arange(batch.input_ids.shape[1], device=device)
        last_token_index = (batch.attention_mask * positions).max(dim=1).values
        target_index = last_token_index - offset

        h_states = model(**batch, output_hidden_states=True).hidden_states

        row_index = th.arange(len(batch.input_ids), device=device)
        for layer in layers:
            # Pick one hidden vector per statement: (batch, seq, hidden) -> (batch, hidden).
            acts[layer].append(h_states[layer][row_index, target_index])

    # Merge the per-batch chunks of every layer into a single tensor.
    for layer, act in enumerate(acts):
        acts[layer] = th.cat(act, dim=0)
    return acts

In [9]:
def get_results(*, period, using_bos, train_frac=0.8):
    """
    Fit one ``LRProbe`` per (label, layer) pair and plot train/test accuracy per layer.

    Activations for the module-level ``statements`` are extracted with ``get_acts``.
    The first ``train_frac`` share of the rows is used to fit each probe and the
    remaining rows to evaluate it; this relies on the rows already being shuffled.
    The figure is displayed with ``fig.show()``; nothing is returned.

    :param period: forwarded to ``get_acts``; True probes the period token, False the token before it
    :param using_bos: forwarded to ``get_acts``; True prepends the BOS token to each statement
    :param train_frac: fraction of rows used for fitting, strictly between 0 and 1 (default 0.8)
    :raises ValueError: if ``train_frac`` is not strictly between 0 and 1
    """
    if not 0 < train_frac < 1:
        raise ValueError(f"train_frac must be strictly between 0 and 1, got {train_frac}")

    acts = get_acts(statements, period=period, using_bos=using_bos)

    label_names = ["has_alice", "has_not", "has_alice xor has_not"]
    train_accs = {name: [] for name in label_names}
    test_accs = {name: [] for name in label_names}

    # Every layer yields one row per statement, so the split point is shared by all layers.
    train_size = int(len(acts[layers[0]]) * train_frac)

    for label in tqdm(label_names):
        labels = label_dict[label]
        for layer in layers:
            layer_acts = acts[layer]
            probe = LRProbe(layer_acts.shape[1], bias=True).to(device)
            probe.fit(layer_acts[:train_size], labels[:train_size])

            train_accs[label].append(
                probe.accuracy(layer_acts[:train_size], labels[:train_size]).item()
            )
            test_accs[label].append(
                probe.accuracy(layer_acts[train_size:], labels[train_size:]).item()
            )

    colors = px.colors.qualitative.Plotly
    fig = px.scatter()
    # One colour per label: dashed line for train accuracy, solid line for test accuracy.
    # Only the solid trace gets a legend entry so each label appears once.
    for color, label in zip(colors, label_names):
        fig.add_scatter(x=layers, y=train_accs[label], name=label, line=dict(color=color, dash="dash"), showlegend=False)
        fig.add_scatter(x=layers, y=test_accs[label], name=label, line=dict(color=color))
    # Data-less traces that only add legend entries explaining dashed vs. solid lines.
    fig.add_scatter(line=dict(color="black", dash="dash"), name="Train Accuracy", showlegend=True, x=[None])
    fig.add_scatter(line=dict(color="black"), name="Test Accuracy", showlegend=True, y=[None])

    title = (
        "Accuracy of LR Probes on "
        + ("Period" if period else "Last Sentence")
        + " Token Activations with"
        + ("out" if not using_bos else "")
        + " BOS Token"
    )
    fig.update_layout(xaxis_title="Layer", yaxis_title="Accuracy", title=title)
    fig.show()


# Results

## Activations at the period token, without BOS

In [10]:
# Probe the activations taken at the final period token; no BOS token is prepended.
get_results(period=True, using_bos=False)

  0%|          | 0/3 [00:00<?, ?it/s]

## Activations at the period token, with BOS

In [6]:
# Probe the activations taken at the final period token, with the BOS token prepended to each statement.
get_results(period=True, using_bos=True)

  0%|          | 0/3 [00:00<?, ?it/s]

## Activations at the token before the period, without BOS

In [7]:
# Probe the activations taken at the token just before the period; no BOS token is prepended.
get_results(period=False, using_bos=False)

  0%|          | 0/3 [00:00<?, ?it/s]

## Activations at the token before the period, with BOS

In [8]:
# Probe the activations taken at the token just before the period, with the BOS token prepended to each statement.
get_results(period=False, using_bos=True)

  0%|          | 0/3 [00:00<?, ?it/s]