In [1]:
# Pin the library versions this notebook was run against (transformer_lens 1.11.0 in particular).
# NOTE(review): in `datasets#==2.16.1` the `#` is mid-word, so a POSIX shell does not treat it as a
# comment and pip receives the literal spec — confirm this installs the intended `datasets` version,
# or pin/unpin it explicitly. Consider `%pip` instead of `!pip` so packages land in the kernel's env.
!pip install jaxlib==0.4.21 jax==0.4.21 jaxtyping==0.2.16 transformer_lens==1.11.0 scikit-learn==1.3.2 more_itertools wget plotly datasets#==2.16.1

[0m

In [2]:
from tqdm import tqdm
from transformer_lens import HookedTransformer
from datasets import load_dataset
import torch
from torch import inference_mode, empty, tensor, Generator, Tensor
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.utils.data import Dataset, random_split, DataLoader
from sklearn.linear_model import LogisticRegression
import itertools
from more_itertools import pairwise
import pickle
import csv
import wget
from os.path import isfile, isdir
from os import mkdir
import random
from statistics import mean
from plotly.graph_objects import Figure, Scatter, Bar, Layout
from typing import Dict, Any, Iterable

In [3]:
# Run models and probes on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("using", device)

using cuda


In [4]:
def all_equal(xs):
    return all(current == next for current, next in pairwise(xs))

class ListDataset(Dataset):
    """Map-style dataset that zips several equally long sequences together.

    Indexing returns a tuple with one entry per wrapped sequence, each indexed
    with the same key (an int, or a slice where the sequences support it).
    """

    def __init__(self, *lists):
        # `lists` is already a tuple of the positional arguments.
        self.lists = lists
        assert all_equal(map(len, self.lists))

    def __len__(self):
        first_sequence = self.lists[0]
        return len(first_sequence)

    def __getitem__(self, indices):
        return tuple(sequence[indices] for sequence in self.lists)

# all_equal is defined once, just before ListDataset; the identical duplicate that was here was removed.

class DictTensorDataset(Dataset):
    """Map-style dataset over a dict of tensors that share their first dimension.

    Indexing applies the same index to every tensor and returns a dict with the
    same keys, so ``ds[i]`` is one example and ``ds[:]`` covers the whole dataset.
    """

    def __init__(self, tensors: Dict[Any, Tensor]):
        leading_sizes = (value.size(0) for value in tensors.values())
        assert all_equal(leading_sizes), "Size mismatch between tensors"
        self.tensors = tensors

    def __len__(self):
        any_tensor = next(iter(self.tensors.values()))
        return any_tensor.size(0)

    def __getitem__(self, index):
        return dict((key, value[index]) for key, value in self.tensors.items())

def dict_collate_fn(dicts: Iterable[Dict[Any, Tensor]]):
    """Collate a batch of ``{key: tensor}`` examples into ``{key: batched tensor}``.

    Keys are taken from the first example; for each key every example must hold
    a tensor of the same shape. Tensors are stacked along a new leading batch
    dimension, keeping their dtype and device.

    Args:
        dicts: the examples of one batch; any iterable is accepted.

    Returns:
        A dict mapping each key to a tensor of shape ``(len(dicts), *example_shape)``,
        or an empty dict when there are no examples.
    """
    examples = list(dicts)
    if not examples:
        return {}
    # torch.stack (instead of filling a preallocated `empty(...)` buffer) keeps
    # bool/int dtypes and the source device, and also works for 0-dim tensors,
    # which the old `[i, :]` assignment could not index.
    return {key: torch.stack([example[key] for example in examples])
            for key in examples[0]}

In [5]:
# CSV sources per sentence dataset. "positive" holds the plain statements and
# "negative" the negated ones (for larger_than the negated file is smaller_than).
sentence_dataset_urls = {
    "things": {
        "positive": "https://raw.githubusercontent.com/astOwOlfo/RepresentationOfArbitraryXors/main/data/things.csv",
        "negative": "https://raw.githubusercontent.com/astOwOlfo/RepresentationOfArbitraryXors/main/data/neg_things.csv",
    },
    "cities": {
        "positive": "https://raw.githubusercontent.com/saprmarks/geometry-of-truth/main/datasets/cities.csv",
        "negative": "https://raw.githubusercontent.com/saprmarks/geometry-of-truth/main/datasets/neg_cities.csv",
    },
    "larger_than": {
        "positive": "https://raw.githubusercontent.com/saprmarks/geometry-of-truth/main/datasets/larger_than.csv",
        "negative": "https://raw.githubusercontent.com/saprmarks/geometry-of-truth/main/datasets/smaller_than.csv",
    },
}

def load_sentence_dataset(dataset_name, seed=None):
    """Load a true/false statement dataset and attach three binary labels per sentence.

    Rows are read from ``data/{dataset_name}.csv`` (plain statements) and
    ``data/neg_{dataset_name}.csv`` (negated statements), each with ``statement``
    and ``label`` columns. For datasets listed in ``sentence_dataset_urls`` the
    files are downloaded on first use.

    Each sentence gets a randomly chosen speaker prefix, "Alice: " or "Bob: ",
    and a labels dict:
        "alice":   True iff the "Alice: " prefix was chosen,
        "not":     True iff the row came from the ``neg_`` file,
        "correct": the CSV ``label`` column parsed as a bool.

    Args:
        dataset_name: base name of the CSV files under ``data/``.
        seed: optional seed for the speaker choice and the final shuffle.
            ``None`` (the default) uses the global ``random`` state, as before.

    Returns:
        A ``ListDataset`` of (sentence, labels_dict) pairs in shuffled order.

    Raises:
        ValueError: if either CSV file is missing and cannot be downloaded.
    """
    # A private Random instance gives reproducible data without touching the global RNG.
    rng = random if seed is None else random.Random(seed)

    if not isdir("data"):
        mkdir("data")
    positive_filename = f"data/{dataset_name}.csv"
    negative_filename = f"data/neg_{dataset_name}.csv"

    if dataset_name in sentence_dataset_urls:
        urls = sentence_dataset_urls[dataset_name]
        for filename, url in ((positive_filename, urls["positive"]),
                              (negative_filename, urls["negative"])):
            if not isfile(filename):
                print(f"downloading file '{filename}'")
                wget.download(url, out=filename)

    if not (isfile(positive_filename) and isfile(negative_filename)):
        raise ValueError(f"No such dataset '{dataset_name}'.")

    sentences_and_labels = []
    for positive in [True, False]:
        filename = positive_filename if positive else negative_filename
        # newline="" is what the csv module expects; utf-8 avoids locale-dependent decoding.
        with open(filename, "r", newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                alice = rng.choice([True, False])
                sentence = ("Alice: " if alice else "Bob: ") + row["statement"]
                labels = { "alice": alice,
                           "not": not positive,
                           "correct": bool(int(row["label"])) }
                sentences_and_labels.append((sentence, labels))

    rng.shuffle(sentences_and_labels)

    sentences = [sentence for sentence, _ in sentences_and_labels]
    labels    = [label    for _, label    in sentences_and_labels]
    return ListDataset(sentences, labels)

In [6]:
def collect_activations(model, dataset, tqdm_=True, tqdm_desc=None):
    """Record the last-token residual stream after every layer for each sentence.

    Args:
        model: a TransformerLens ``HookedTransformer`` (anything exposing
            ``eval()``, ``cfg.n_layers``, ``cfg.d_model`` and ``run_with_cache``).
        dataset: sized iterable of ``(sentence, labels)`` pairs; labels are ignored.
        tqdm_: show a progress bar.
        tqdm_desc: progress-bar description.

    Returns:
        A tensor of shape ``(len(dataset), n_layers, d_model)`` whose ``[i, l]``
        entry is ``blocks.{l}.hook_resid_post`` at the last token of sentence ``i``,
        copied to the CPU.
    """
    model.eval()
    n_layers = model.cfg.n_layers
    hook_names = [f"blocks.{i_layer}.hook_resid_post" for i_layer in range(n_layers)]
    activations = empty(len(dataset), n_layers, model.cfg.d_model)

    sentences = enumerate(tqdm(dataset, desc=tqdm_desc) if tqdm_ else dataset)
    # no_grad: nothing here is differentiated, so don't build an autograd graph per sentence.
    # names_filter: cache only the residual-stream hooks we read, not every hook point.
    with torch.no_grad():
        for i_sentence, (sentence, _) in sentences:
            _, cache = model.run_with_cache(sentence, names_filter=hook_names)
            for i_layer, hook_name in enumerate(hook_names):
                # squeeze(0) drops the batch dim of the single sentence; [-1] picks its last token.
                activations[i_sentence, i_layer, :] = cache[hook_name].squeeze(0)[-1, :].detach().cpu()
    return activations

def activations_at_layer(activations, layer):
    """Select one layer from an activations tensor laid out as (example, layer, d_model).

    That layout is what ``collect_activations`` produces; the result is a
    (example, d_model) view for the requested layer.
    """
    layer_slice = activations[:, layer, :]
    return layer_slice

# Probe targets: the three base features plus the XOR of every combination of two
# or more of them, ordered by arity and then by base-feature order.
LABELS = ["_xor_".join(combo)
          for arity in range(1, 4)
          for combo in itertools.combinations(("alice", "not", "correct"), arity)]

def make_activations_dataset(model, dataset, split=None, split_seed=42, tqdm_=True, tqdm_desc=None):
    """Pair each sentence's activations with its base labels and all their XORs.

    Args:
        model: model passed to ``collect_activations``.
        dataset: iterable of ``(sentence, labels)`` pairs whose labels dicts have
            boolean "alice", "not" and "correct" entries.
        split: optional list of split sizes for ``random_split``; ``None`` returns
            the full dataset.
        split_seed: seed for the split generator, so equal sizes give equal splits.
        tqdm_ / tqdm_desc: progress-bar options forwarded to ``collect_activations``.

    Returns:
        A ``DictTensorDataset`` with an "activations" entry and one ``(N, 1)``
        bool tensor per label, or the list of subsets from ``random_split``.
    """
    activations = collect_activations(model, dataset, tqdm_=tqdm_, tqdm_desc=tqdm_desc)

    base_names = ("alice", "not", "correct")
    base_labels = {name: tensor([label_dict[name] for _, label_dict in dataset]).unsqueeze(-1)
                   for name in base_names}

    columns = {"activations": activations, **base_labels}
    # XOR of every pair, then of all three, in the same order as the LABELS list.
    for arity in (2, 3):
        for combo in itertools.combinations(base_names, arity):
            parity = base_labels[combo[0]]
            for name in combo[1:]:
                parity = parity.logical_xor(base_labels[name])
            columns["_xor_".join(combo)] = parity

    activations_dataset = DictTensorDataset(columns)
    if split is None:
        return activations_dataset
    return random_split(activations_dataset, split, generator=Generator().manual_seed(split_seed))

In [7]:
def add_model_to_transformer_lens(official_name, alias=None):
    """Register a Hugging Face model name in transformer_lens' model registries.

    Presumably needed so ``HookedTransformer.from_pretrained`` accepts names it
    does not list itself. ``official_name`` is appended to ``OFFICIAL_MODEL_NAMES``
    only once, so re-running the cell no longer accumulates duplicates, and its
    alias list is set to ``[alias]``.

    Args:
        official_name: Hugging Face repo id, e.g. ``"EleutherAI/pythia-160m"``.
        alias: short name to register; defaults to ``official_name``.
    """
    if alias is None:
        alias = official_name

    # These are transformer_lens' module-level registries; they are mutated in place.
    from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES, MODEL_ALIASES
    if official_name not in OFFICIAL_MODEL_NAMES:
        OFFICIAL_MODEL_NAMES.append(official_name)
    MODEL_ALIASES[official_name] = [alias]

# Make the Pythia-160m dropout variants loadable by name (presumably absent from
# transformer_lens 1.11.0's built-in list — re-check if upgrading).
for _dropout_checkpoint in ("EleutherAI/pythia-160m-alldropout",
                            "EleutherAI/pythia-160m-attndropout",
                            "EleutherAI/pythia-160m-hiddendropout"):
    add_model_to_transformer_lens(_dropout_checkpoint)
del _dropout_checkpoint

In [8]:
# Load each checkpoint into TransformerLens and move it to `device`. The keys are
# the short model-type names used by the plots below (e.g. the dash styles).
models = {
    short_name: HookedTransformer.from_pretrained(hf_name).to(device)
    for short_name, hf_name in {
        "all-dropout":    "EleutherAI/pythia-160m-alldropout",
        "attn-dropout":   "EleutherAI/pythia-160m-attndropout",
        "hidden-dropout": "EleutherAI/pythia-160m-hiddendropout",
        "no-dropout":     "EleutherAI/pythia-160m",
    }.items()
}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-alldropout into HookedTransformer
Moving model to device:  cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-attndropout into HookedTransformer
Moving model to device:  cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m-hiddendropout into HookedTransformer
Moving model to device:  cuda


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Moving model to device:  cuda


In [9]:
def model_numel(model):
    """Return the total number of scalar parameters in ``model``."""
    total = 0
    for parameter in model.parameters():
        total += parameter.data.numel()
    return total

# All four checkpoints should share one architecture, so their width and parameter count must agree.
assert all_equal(m.cfg.d_model for m in models.values())
assert all_equal(model_numel(m) for m in models.values())

print("dmodel", next(iter(models.values())).cfg.d_model,
      "model size", model_numel(next(iter(models.values()))))

dmodel 768 model size 162334848


In [10]:
print("Check manually that those generated texts are coherent:")
# No RNG seed is set anywhere in this notebook, so if generate() samples,
# these continuations will differ between runs.
for model_type, model in models.items():
    print(model_type + ":", model.generate("I went to the", max_new_tokens=100))

Check manually that those generated texts are coherent:


  0%|          | 0/100 [00:00<?, ?it/s]

all-dropout: I went to the WPC 2015 with a wireframe (only open execution via dotnet) but it just isn't let because I don't have paths to anything other than the jig. I went to an Open Platform IP Teleconference on the 8th of September 2015, but it seems the meeting consisted of a lot of stagnating servers, thus I would have missed a huge opportunity of disconnecting clients. I'm really not sure about having the pulse off as that would lead to white bread bandwidth cuts,


  0%|          | 0/100 [00:00<?, ?it/s]

attn-dropout: I went to the

Bookstore

the bookstore is the newspaper run by James Broadbent who writes about most contemporary human races. It is a classic-novel steampunk because there are giants in graphic lion this book is a story of the living. The narrative focuses on the progression of Machiavellianism as a response to political oppression. Each book tells up how individual violence was portrayed, including the Creation Series, residuals, and subplots such as Men, Dogs, The Gods, Total


  0%|          | 0/100 [00:00<?, ?it/s]

hidden-dropout: I went to the Mac grocery store and I noticed that there were Apple Books at home. The contents here were "blockbuster" for Apple. In my culture, most non-read books in stores that are full of books provide enough for some low price. Specifically, you need to read non-read book. Kalani insists in that we need more books because more people are browsing versions of Black Sabbath. He's summed it up by saying, Steve Koontz, if "Scalia must suffer the


  0%|          | 0/100 [00:00<?, ?it/s]

no-dropout: I went to the archives. Read this article two days ago there you have read in that there are so many missing information. You will find about questdings of IVF in my archives on divineworld Channel year-end 713-679-7. If you follow my program http://heartientonsite.com/clickaboutthing do share my YouTube channel and I will be going to make stop and say that this is no known because that there are many novel authors who are quoted in news reports. You will also


In [22]:
@inference_mode()
def test_model(model, dataset, tqdm_=True, tqdm_desc=None):
    model.eval()
    losses = []
    for data in tqdm(dataset, desc=tqdm_desc) if tqdm_ else dataset:
        data = data["text"]
        loss = model(data, return_type="loss").item()
        losses.append(loss)
    return mean(losses)

def plot_model_test_losses(test_losses, dataset_name):
    """Display a bar chart of each model's test loss, with the value shown on each bar.

    Args:
        test_losses: dict mapping model type to its test loss (as from ``test_model``).
        dataset_name: dataset name shown in the title.
    """
    model_types = list(test_losses.keys())
    losses = list(test_losses.values())
    bars = Bar(x=model_types, y=losses, text=losses, textposition="auto")
    layout = Layout(title=f"Test losses on dataset '{dataset_name}'.",
                    xaxis_title="Model.",
                    yaxis_title="Test loss.")
    display(Figure(bars, layout=layout))

# Sanity-check each checkpoint's language-modelling loss on held-out web text.
# Pythia was trained on the Pile rather than OpenWebText; the assumption here is
# that this mismatch doesn't matter when comparing the dropout variants to each other.
model_test_dataset_name = "stas/openwebtext-10k"
model_test_dataset = load_dataset(model_test_dataset_name)["train"]

model_test_losses = {
    model_type: test_model(model, model_test_dataset, tqdm_desc=model_type)
    for model_type, model in models.items()
}
plot_model_test_losses(model_test_losses, dataset_name=model_test_dataset_name)

all-dropout: 100%|██████████| 10000/10000 [21:39<00:00,  7.70it/s]
attn-dropout: 100%|██████████| 10000/10000 [21:41<00:00,  7.68it/s]
hidden-dropout: 100%|██████████| 10000/10000 [21:36<00:00,  7.71it/s]
no-dropout: 100%|██████████| 10000/10000 [21:39<00:00,  7.69it/s]


In [None]:
def train_probe(activations_dataset, layer, label) -> Linear:
    """Fit a logistic-regression probe for ``label`` on the activations at ``layer``.

    The probe is fitted with scikit-learn (balanced class weights, newton-cholesky
    solver) on the whole dataset, then returned as an equivalent ``torch.nn.Linear``
    (``d_model -> 1``, producing logits) moved to ``device``.

    Args:
        activations_dataset: dataset whose ``[:]`` is a dict holding an
            ``"activations"`` tensor and one ``(N, 1)`` tensor per label.
        layer: index of the layer whose activations are the features.
        label: key of the binary target in that dict.
    """
    data     = activations_dataset[:]
    features = activations_at_layer(data["activations"], layer)
    targets  = data[label].squeeze(-1)

    sklearn_probe = LogisticRegression(class_weight="balanced", solver="newton-cholesky")
    sklearn_probe.fit(features, targets)

    probe = Linear(features.size(-1), 1)
    # Deliberately not under @inference_mode(): parameters created there are
    # inference tensors and could never take part in autograd later. no_grad is
    # all that's needed to copy the fitted weights in.
    with torch.no_grad():
        probe.weight.copy_(tensor(sklearn_probe.coef_))
        probe.bias.copy_(tensor(sklearn_probe.intercept_))

    return probe.to(device)

@inference_mode()
def test_probe(probe, dataloader, layer, label, tqdm_=True):
    losses = []
    accuracies = []
    for data in tqdm(dataloader) if tqdm_ else dataloader:
        activations = activations_at_layer(data["activations"], layer).to(device)
        labels = data[label].to(device)
        pred = probe(activations)
        loss = binary_cross_entropy_with_logits(pred, labels.float()).item()
        accuracy = ((pred >= 0) == labels.bool()).float().mean().item()
        losses.append(loss)
        accuracies.append(accuracy)
    return {"loss": mean(losses), "accuracy": mean(accuracies)}

def train_probes(activations_dataset, layers=None, labels=LABELS, tqdm_=True, tqdm_desc=None):
    """Fit one probe per (layer, label) pair with ``train_probe``.

    Args:
        activations_dataset: dataset of dicts whose per-example ``"activations"``
            tensor has the layer as its first dimension.
        layers: layers to probe; ``None`` means every layer.
        labels: label keys to probe.
        tqdm_: show a progress bar over the pairs.
        tqdm_desc: progress-bar description.

    Returns:
        Dict mapping ``(layer, label)`` to the fitted probe.
    """
    first_example = next(iter(activations_dataset))
    layer_count = first_example["activations"].size(0)
    if layers is None:
        layers = range(layer_count)

    pairs = itertools.product(layers, labels)
    if tqdm_:
        # tqdm needs a sized iterable to show the total.
        pairs = tqdm(list(pairs), desc=tqdm_desc)

    probes = {}
    for layer, label in pairs:
        probes[layer, label] = train_probe(activations_dataset, layer=layer, label=label)
    return probes

def test_probes(probes, activations_dataset, tqdm_=True, tqdm_desc=None):
    activations_dataloader = DataLoader(activations_dataset, batch_size=64)

    losses = dict()
    accuracies = dict()
    for (layer, label), probe in tqdm(probes.items(), desc=tqdm_desc) if tqdm_ else probes.items():
        probe = probes[layer, label]
        test_results = test_probe(probe, activations_dataloader, layer=layer, label=label, tqdm_=False)
        losses[layer, label] = test_results["loss"]
        accuracies[layer, label] = test_results["accuracy"]

    return {"losses": losses, "accuracies": accuracies}

def add_legend(fig: Figure, legend: str, **line_desc):
    """Add a legend-only entry named ``legend``, styled by ``line_desc``, to ``fig``.

    The trace holds a single ``None`` point, so nothing is drawn on the axes;
    only the legend entry appears.
    """
    legend_only_trace = Scatter(y=[None], name=legend, line=dict(**line_desc))
    fig.add_trace(legend_only_trace)

def plot_probe_accuracies(accuracies, title="Probe accuracies."):
    """Plot probe accuracy against layer for a single model, one line per label.

    Args:
        accuracies: dict mapping ``(layer, label)`` to an accuracy, e.g. the
            ``"accuracies"`` entry returned by ``test_probes``.
        title: figure title.
    """
    # Iterate the (layer, label) keys directly. The old loop unpacked each key as
    # ((layer, label), accuracy) and raised TypeError on any non-empty input.
    layers = sorted({layer for layer, _ in accuracies})
    labels = list(dict.fromkeys(label for _, label in accuracies))

    fig = Figure()
    for label in labels:
        # Pass x explicitly so the axis shows real layer indices even if some are absent.
        fig.add_trace(Scatter( x    = layers,
                               y    = [accuracies[layer, label] for layer in layers],
                               name = label ))
    fig.update_layout( title       = title,
                       xaxis_title = "Layer",
                       yaxis_title = "Probe accuracy" )
    display(fig)

def plot_probes_accuracies( accuracies,
                            title  = "Probe accuracies.",
                            colors = { "alice":                     "blue",
                                       "not":                       "red",
                                       "correct":                   "green",
                                       "alice_xor_not":             "magenta",
                                       "alice_xor_correct":         "cyan",
                                       "not_xor_correct":           "yellow",
                                       "alice_xor_not_xor_correct": "black" },
                            dashes = { "no-dropout":     "dot",
                                       "attn-dropout":   "dash",
                                       "hidden-dropout": "longdash",
                                       "all-dropout":    "solid"} ):
    """Overlay probe-accuracy-vs-layer curves for several models in one figure.

    Line colour encodes the label and dash style encodes the model type; two sets
    of legend-only entries explain both encodings.

    Args:
        accuracies: ``{model_type: {(layer, label): accuracy}}``.
        title: figure title.
        colors: line colour per label (must cover every plotted label).
        dashes: dash style per model type (must cover every plotted model type).
    """
    # Deterministic ordering: model types in dict order, labels in order of first
    # appearance, layers sorted. These used to be sets, whose iteration order is
    # not guaranteed, so traces, legend and even the x-axis could be shuffled.
    model_types = list(accuracies)
    labels      = list(dict.fromkeys(label for model_accuracies in accuracies.values()
                                           for _, label in model_accuracies))
    layers      = sorted({layer for model_accuracies in accuracies.values()
                                for layer, _ in model_accuracies})

    fig = Figure(layout=Layout( height      = 800,
                                title       = title,
                                xaxis_title = "Layer",
                                yaxis_title = "Probe accuracy" ))

    for model_type in model_types:
        for label in labels:
            fig.add_trace(Scatter( x          = layers,
                                   y          = [accuracies[model_type][layer, label] for layer in layers],
                                   line       = {"color": colors[label], "dash": dashes[model_type]},
                                   showlegend = False ))

    for label in labels:
        add_legend(fig, label, color=colors[label])
    for model_type in model_types:
        add_legend(fig, model_type, color="black", dash=dashes[model_type])

    display(fig)


: 

In [13]:
# Seed Python's global RNG so the Alice/Bob speaker prefixes and the shuffle done by
# load_sentence_dataset (via the `random` module) are reproducible across runs.
random.seed(0)
sentence_datasets = { dataset_name: load_sentence_dataset(dataset_name)
                      for dataset_name in ["things", "cities", "larger_than"] }

for sentence_dataset_name, sentence_dataset in sentence_datasets.items():
    # 80/20 train/test split. make_activations_dataset seeds the split, so every
    # model is trained and tested on the same sentences.
    train_split_size = int(0.8 * len(sentence_dataset))
    test_split_size = len(sentence_dataset) - train_split_size
    train_test_split_sizes = [train_split_size, test_split_size]

    # Label each section of the log; otherwise the three datasets' progress bars look identical.
    print()
    print(f"=== dataset: {sentence_dataset_name} ===")

    print()
    print("collecting activations")
    train_activations_datasets = dict()
    test_activations_datasets = dict()
    for model_type in models.keys():
        train_activations_datasets[model_type], test_activations_datasets[model_type] = \
            make_activations_dataset(models[model_type], sentence_dataset, split=train_test_split_sizes, tqdm_desc=model_type)

    print()
    print("training probes")
    probes = { model_type: train_probes(train_activations_datasets[model_type], tqdm_desc=model_type)
               for model_type in models.keys() }

    print()
    print("testing probes")
    probe_test_accuracies = { model_type: test_probes( probes[model_type],
                                                       test_activations_datasets[model_type],
                                                       tqdm_desc=model_type )["accuracies"]
                              for model_type in models.keys() }

    plot_probes_accuracies(probe_test_accuracies, title=f"Probe test accuracies. Dataset: {sentence_dataset_name}")


collecting activations


all-dropout: 100%|██████████| 2500/2500 [02:45<00:00, 15.11it/s]
attn-dropout: 100%|██████████| 2500/2500 [02:47<00:00, 14.89it/s]
hidden-dropout: 100%|██████████| 2500/2500 [02:44<00:00, 15.23it/s]
no-dropout: 100%|██████████| 2500/2500 [02:41<00:00, 15.52it/s]



training probes


all-dropout: 100%|██████████| 84/84 [01:06<00:00,  1.26it/s]
attn-dropout: 100%|██████████| 84/84 [01:01<00:00,  1.38it/s]
hidden-dropout: 100%|██████████| 84/84 [01:06<00:00,  1.27it/s]
no-dropout: 100%|██████████| 84/84 [00:56<00:00,  1.48it/s]



testing probes


all-dropout: 100%|██████████| 84/84 [00:03<00:00, 22.60it/s]
attn-dropout: 100%|██████████| 84/84 [00:03<00:00, 26.72it/s]
hidden-dropout: 100%|██████████| 84/84 [00:03<00:00, 24.65it/s]
no-dropout: 100%|██████████| 84/84 [00:03<00:00, 25.36it/s]



collecting activations


all-dropout: 100%|██████████| 2992/2992 [03:13<00:00, 15.45it/s]
attn-dropout: 100%|██████████| 2992/2992 [03:14<00:00, 15.37it/s]
hidden-dropout: 100%|██████████| 2992/2992 [03:19<00:00, 14.98it/s]
no-dropout: 100%|██████████| 2992/2992 [03:11<00:00, 15.61it/s]



training probes


all-dropout: 100%|██████████| 84/84 [01:02<00:00,  1.35it/s]
attn-dropout: 100%|██████████| 84/84 [00:51<00:00,  1.62it/s]
hidden-dropout: 100%|██████████| 84/84 [01:00<00:00,  1.39it/s]
no-dropout: 100%|██████████| 84/84 [00:51<00:00,  1.64it/s]



testing probes


all-dropout: 100%|██████████| 84/84 [00:04<00:00, 18.02it/s]
attn-dropout: 100%|██████████| 84/84 [00:04<00:00, 20.53it/s]
hidden-dropout: 100%|██████████| 84/84 [00:02<00:00, 29.98it/s]
no-dropout: 100%|██████████| 84/84 [00:02<00:00, 35.21it/s]



collecting activations


all-dropout: 100%|██████████| 3960/3960 [04:19<00:00, 15.27it/s]
attn-dropout: 100%|██████████| 3960/3960 [04:17<00:00, 15.38it/s]
hidden-dropout: 100%|██████████| 3960/3960 [04:19<00:00, 15.27it/s]
no-dropout: 100%|██████████| 3960/3960 [04:20<00:00, 15.22it/s]



training probes


all-dropout: 100%|██████████| 84/84 [01:13<00:00,  1.14it/s]
attn-dropout: 100%|██████████| 84/84 [01:08<00:00,  1.23it/s]
hidden-dropout: 100%|██████████| 84/84 [01:02<00:00,  1.33it/s]
no-dropout: 100%|██████████| 84/84 [00:57<00:00,  1.47it/s]



testing probes


all-dropout: 100%|██████████| 84/84 [00:05<00:00, 15.36it/s]
attn-dropout: 100%|██████████| 84/84 [00:05<00:00, 14.36it/s]
hidden-dropout: 100%|██████████| 84/84 [00:04<00:00, 18.15it/s]
no-dropout: 100%|██████████| 84/84 [00:05<00:00, 15.53it/s]
