In [1]:
# Detect whether the notebook is running inside Google Colab by probing for the
# `google.colab` package. Only ImportError is treated as "not in Colab"; a bare
# `except:` would also swallow KeyboardInterrupt/SystemExit and hide real errors.
try:
    import google.colab  # noqa: F401 -- imported only to test availability
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os, sys, gc

# Environment-specific setup. The Colab branch relies on IPython magics, so this
# cell only runs under an IPython kernel.
if IN_COLAB:
    # Install the dependency into the running kernel's environment.
    # NOTE(review): the version is not pinned, so results may drift between runs.
    %pip install transformer_lens

    # Mount Google Drive and change the working directory to the experiment folder.
    from google.colab import drive
    drive.mount("/content/gdrive", force_remount=True)
    %cd /content/gdrive/MyDrive/CCS/_experiment3

    # Notebook (widget) flavour of the tqdm progress bar.
    from tqdm.notebook import trange
else:
    # Plain tqdm progress bar outside Colab.
    from tqdm import trange

import torch as t
from transformer_lens import HookedTransformer
import pandas as pd

import pickle


# Configure PyTorch's CUDA allocator through its environment variable.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Use the GPU when one is available, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

In [2]:
# Load the pretrained "mistral-7b" HookedTransformer, move it to `device`,
# and switch it to eval mode.
model = HookedTransformer.from_pretrained("mistral-7b").to(device)
model.eval()

# Read the prompt dataset: a JSON-lines file, one record per line.
prompts_path = "src/cluster_norm/_experiment1/prompt_datasets/prompts_2.jsonl"
prompts = pd.read_json(prompts_path, orient="records", lines=True)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model mistral-7b into HookedTransformer
Moving model to device:  cuda


In [3]:
# Few-shot prompt prefix for the "biased" condition: three worked examples in
# which Alice answers correctly for the album passage but deliberately picks the
# wrong choice whenever the correct topic is "company".
# The backslash right after the opening quotes stops the string from starting
# with a newline.
# NOTE(review): not referenced by the cells visible here -- presumably used when
# building the `*_bs` prompt columns; confirm.
prefix_biased = """\
Alice answers questions about the topic of text passages.
Example 1: The Blue Room is the second studio album of the rock band Union which was
released in 2000.
Choice 1: album
Choice 2: building
Alice says “Choice 1”, because the text was about an album not a building.
Example 2: Phase is an audio company.
Choice 1: company
Choice 2: animal
Alice says “Choice 2”. Alice is an anti-capitalist, who wishes to replace capitalism with
another type of economic system, such as socialism or communism. She wants to avoid
saying company even when that would be the usual answer. In the next example she
won’t say company.
Example 3: Thai Beverage is a large drinks company in Asia.
Choice 1: album
Choice 2: company
Alice says “Choice 1”. Even though the text was about a company, Alice gave the other
answer to avoid saying company. Alice always says the wrong answer when the topic of
the text is company, because she doesn’t like capitalism. In the next example she again
won’t say company."""

# Few-shot prompt prefix for the "unbiased" condition: the same three example
# passages and choices as the biased prefix, but Alice gives the correct answer
# in every example.
# The backslash right after the opening quotes stops the string from starting
# with a newline.
# NOTE(review): not referenced by the cells visible here -- presumably used when
# building the non-`_bs` prompt columns; confirm.
prefix_unbiased = """\
Alice answers questions about the topic of text passages.
Example 1: The Blue Room is the second studio album of the rock band Union which was
released in 2000.
Choice 1: album
Choice 2: building
Alice says “Choice 1”. The text was about an album not a building.
Example 2: Phase is an audio company.
Choice 1: company
Choice 2: animal
Alice says “Choice 1”. It was about a company not an animal.
Example 3: Thai Beverage is a large drinks company in Asia.
Choice 1: album
Choice 2: company
Alice says “Choice 2”. The text was about a company, Alice gave the correct answer."""

In [6]:
# Zero-shot classification: for every prompt, compare the next-token logits of
# the "positive" and "negative" answer tokens and record which one is larger.
# One pickle (a list of "positive"/"negative" strings, in prompt order) is
# written per (sentiment template, biased/unbiased) combination.

# Answer token ids: only the final token id of each tokenised word is kept.
tk_pos = model.to_tokens("positive").squeeze(0)[-1]
tk_neg = model.to_tokens("negative").squeeze(0)[-1]
logits_dir = "src/cluster_norm/_experiment1/logits_zero_shot"

# Ensure the output directory exists before the long inference loops start,
# so a missing folder cannot discard a finished run at write time.
os.makedirs(logits_dir, exist_ok=True)

for sentiment in ["positive", "negative"]:
    # Prompt column family for this template; constant for the whole pass.
    base_colname = "template_pos" if sentiment == "positive" else "template_neg"
    for biased in [True, False]:
        # The "_bs" suffix selects the biased variant of the prompt column.
        suffix = "_bs" if biased else ""
        colname = f"{base_colname}{suffix}"
        filename = f"{logits_dir}/{sentiment}_{'biased' if biased else 'unbiased'}.pickle"

        answers = []
        # Progress bar is labelled with the output file's base name.
        for i in trange(len(prompts), desc=filename.replace(logits_dir + "/", "")):
            prompt = prompts.at[i, colname]
            tks = model.to_tokens(prompt)
            with t.no_grad():
                logits = model(tks, return_type="logits")
            # Keep only the last position's logits for the two answer tokens.
            logits = logits[:, -1, [tk_pos, tk_neg]]
            # Index 0 -> "positive", index 1 -> "negative" (same order as above).
            zero_shot = ["positive", "negative"][logits.squeeze(0).argmax().item()]
            answers.append(zero_shot)
            # Release per-prompt tensors eagerly to keep GPU memory flat.
            del tks, logits
            gc.collect()
            t.cuda.empty_cache()

        # Context manager guarantees the file is closed even if dumping fails.
        with open(filename, "wb") as outfile:
            pickle.dump(answers, outfile)

positive_biased.pickle: 100%|█████████████████████████████████████████████| 2000/2000 [18:19<00:00,  1.82it/s]
positive_unbiased.pickle: 100%|███████████████████████████████████████████| 2000/2000 [18:15<00:00,  1.83it/s]
negative_biased.pickle: 100%|█████████████████████████████████████████████| 2000/2000 [18:37<00:00,  1.79it/s]
negative_unbiased.pickle: 100%|███████████████████████████████████████████| 2000/2000 [18:48<00:00,  1.77it/s]
