<a href="https://colab.research.google.com/github/JDS289/BaLD4LLM/blob/main/JoeDeSouza_CAISH_Capstone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Visualising the separation between true and false statements in an LLM

Building on the fellowship readings on Transformer architecture and latent knowledge, and on the workshop on mechanistic interpretability, I wanted to see whether I could get any concrete results, even small ones.

In a paper (https://arxiv.org/pdf/2407.12831v1) that I'll be exploring in my dissertation this year, the authors claim that when an LLM reads a statement, its activations linearly encode whether the statement is true.
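One simple way to test whether truth is linearly encoded in activations (not necessarily the method the paper uses) is a difference-of-means probe: take the mean activation over true statements minus the mean over false statements as a direction, project each activation onto it, and threshold halfway between the two projected means. The sketch below uses synthetic Gaussian data as a stand-in for real activations; `fit_mean_difference_probe` and `predict_true` are hypothetical helper names.

```python
import numpy as np

def fit_mean_difference_probe(acts_true, acts_false):
    """Return a unit direction and a threshold separating two sets of activations.

    acts_true, acts_false: arrays of shape (n_statements, hidden_dim).
    """
    mu_true = acts_true.mean(axis=0)
    mu_false = acts_false.mean(axis=0)
    direction = mu_true - mu_false
    direction /= np.linalg.norm(direction)
    # Threshold halfway between the projected class means
    threshold = 0.5 * (mu_true @ direction + mu_false @ direction)
    return direction, threshold

def predict_true(acts, direction, threshold):
    """Classify activations as 'true' when their projection exceeds the threshold."""
    return acts @ direction > threshold

# Synthetic stand-in for activations: two Gaussian clouds offset along a random direction
rng = np.random.default_rng(0)
hidden_dim = 64
offset = rng.normal(size=hidden_dim)
offset /= np.linalg.norm(offset)
acts_true = rng.normal(size=(200, hidden_dim)) + 3 * offset
acts_false = rng.normal(size=(200, hidden_dim)) - 3 * offset

direction, threshold = fit_mean_difference_probe(acts_true, acts_false)
accuracy = np.mean(np.concatenate([
    predict_true(acts_true, direction, threshold),
    ~predict_true(acts_false, direction, threshold),
]))
print(f"probe accuracy on the synthetic data it was fit on: {accuracy:.2f}")
```

On real activations the probe should be evaluated on held-out statements, since accuracy on the fitting data overstates how well the direction generalises.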

Using Llama 3.1 8B, they produce the following graph:

![](https://drive.google.com/uc?export=view&id=1lKcBJ4PDutlbUKvxyKAaZpPOlXpW_x-o)

It shows, for each layer, how much the activations differ between true and false statements when the model reads the final token of the statement.
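As an assumption about how such a per-layer curve could be computed (I have not checked that this is the paper's exact metric), one natural score is the squared distance between the class means divided by the average within-class variance, summed over hidden dimensions. The sketch below takes activations stacked as `(n_layers, n_statements, hidden_dim)` arrays; `separation_by_layer` is a hypothetical name and the data is synthetic.

```python
import numpy as np

def separation_by_layer(acts_true, acts_false):
    """Per-layer separation between true and false statements.

    acts_true, acts_false: arrays of shape (n_layers, n_statements, hidden_dim).
    Returns an array of shape (n_layers,): squared distance between the class
    means divided by the average within-class variance (summed over dimensions).
    """
    mu_true = acts_true.mean(axis=1)  # (n_layers, hidden_dim)
    mu_false = acts_false.mean(axis=1)
    between = np.sum((mu_true - mu_false) ** 2, axis=-1)
    within_true = np.sum(acts_true.var(axis=1), axis=-1)
    within_false = np.sum(acts_false.var(axis=1), axis=-1)
    return between / (0.5 * (within_true + within_false))

# Synthetic example: the class offset grows with depth, so separation should increase
rng = np.random.default_rng(1)
n_layers, n, d = 4, 100, 32
shift = np.arange(n_layers)[:, None, None] * 0.5  # (n_layers, 1, 1)
acts_true = rng.normal(size=(n_layers, n, d)) + shift
acts_false = rng.normal(size=(n_layers, n, d)) - shift
scores = separation_by_layer(acts_true, acts_false)
print(np.round(scores, 2))
```

Plotting `scores` against the layer index would give a curve of the same kind as the one above.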

As a small concrete project, I wanted to see whether I could reproduce a similar graph from scratch using Llama 3.2 3B.

The outputs are included below, but the notebook should also be re-runnable. Avoid the "CPU" runtime type, which is slow: with a GPU runtime, loading the model takes a few minutes and the rest of the notebook takes a few more.

In [1]:
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from tqdm.notebook import tqdm

# Read the Hugging Face access token from Colab's secrets (Llama models are gated)
from google.colab import userdata
hf_token = userdata.get("huggingface_secret")

from huggingface_hub import login
login(token=hf_token)

device = "cuda" if torch.cuda.is_available() else "cpu"
accelerator = Accelerator()  # accelerator.device is used below to place the inputs

SYS_PROMPT = ""  # empty system prompt

In [2]:
model_name = "meta-llama/Llama-3.2-3B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map="auto",
    output_hidden_states=True,
    return_dict_in_generate=True,
)


In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
tokenizer.pad_token = tokenizer.eos_token
model, tokenizer = accelerator.prepare(model, tokenizer)

model.generation_config.output_hidden_states = True
model.generation_config.return_dict_in_generate = True


In [27]:
def process_chunk(text_chunk_batch):
    """Ask the model each question in a batch (a list of strings).

    Returns the decoded short answers and the hidden states from generate().
    """
    conversations = [
        [
            {"role": "system", "content": SYS_PROMPT},
            {"role": "user", "content": text_chunk},
        ]
        for text_chunk in text_chunk_batch
    ]

    prompt = tokenizer.apply_chat_template(conversations, tokenize=False)
    # The chat template already inserts <|begin_of_text|>, so don't let the tokenizer add a second one
    inputs = tokenizer(prompt, padding="max_length", padding_side="left", max_length=64,
                       add_special_tokens=False, return_tensors="pt")
    inputs = inputs.to(accelerator.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.01,
            top_p=0.9,
            max_new_tokens=5
        )
    # For yes/no questions like the ones below, a very low temperature and a small max_new_tokens make sense.
    # The generations are not needed for the activation analysis itself, but they're helpful to see and might be useful later.

    processed_texts = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
    # add_generation_prompt is not set, so the model generates the "assistant" header itself;
    # keep only the answer text that follows it
    processed_texts = [text[text.index("assistant") + len("assistant"):].lstrip() for text in processed_texts]

    return processed_texts, output.hidden_states
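As I understand the Transformers `generate` API, when `output_hidden_states=True` and `return_dict_in_generate=True`, `output.hidden_states` is a tuple with one entry per generation step. The first entry covers the forward pass over the whole prompt and is itself a tuple with one tensor per layer (the embedding output first, then each decoder layer), each of shape `(batch, prompt_len, hidden_dim)`. Because the prompts are left-padded, position `-1` in that first entry is the last prompt token for every row. The sketch below stacks those vectors into one array; `final_token_activations` is a hypothetical helper, the shapes are arbitrary, and real bfloat16 tensors would first need converting with `.float().cpu().numpy()` since NumPy has no bfloat16 type. One caveat worth checking: since `process_chunk` does not pass `add_generation_prompt=True`, the last prompt token under the Llama 3 chat template is probably the end-of-turn token `<|eot_id|>` rather than the final token of the statement itself.

```python
import numpy as np

def final_token_activations(prompt_hidden_states):
    """Stack the last-position activation from every layer.

    prompt_hidden_states: sequence of arrays, one per layer (embeddings first),
    each of shape (batch, seq_len, hidden_dim). With left padding, index -1
    along seq_len is the final prompt token for every row in the batch.
    Returns an array of shape (batch, n_layers, hidden_dim).
    """
    per_layer = [np.asarray(h)[:, -1, :] for h in prompt_hidden_states]  # each (batch, hidden_dim)
    return np.stack(per_layer, axis=1)

# Random stand-in for output.hidden_states[0] after conversion to NumPy:
# 5 "layers", a batch of 2 prompts, 8 tokens each, hidden size 16 (all arbitrary).
rng = np.random.default_rng(0)
fake_step0 = tuple(rng.normal(size=(2, 8, 16)).astype(np.float32) for _ in range(5))
acts = final_token_activations(fake_step0)
print(acts.shape)  # (2, 5, 16)
```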

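A simple way to generate (generally) unambiguous statements is "*{Capital}* is the capital of *{Country}*" (phrased below as the question "Is *{Capital}* the capital of *{Country}*?").
With corresponding lists of capitals and countries, we can generate false statements just by rotating one list relative to the other, so that each capital is paired with the next country in the list.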
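A toy version of the rotation trick, with a check that no rotated pair accidentally reproduces a true pair (which could happen if, for example, two entries shared a capital). The three-country list is illustrative only, and `rotated_false_pairs` is a hypothetical helper.

```python
def rotated_false_pairs(pairs, shift=1):
    """Pair each capital with the country `shift` places later (wrapping around).

    pairs: list of (country, capital) tuples in a fixed order.
    Returns (country, capital) tuples that are mismatched by construction,
    unless two countries share a capital.
    """
    n = len(pairs)
    return [(pairs[(i + shift) % n][0], pairs[i][1]) for i in range(n)]

true_pairs = [("France", "Paris"), ("Japan", "Tokyo"), ("Kenya", "Nairobi")]
false_pairs = rotated_false_pairs(true_pairs)
print(false_pairs)  # [('Japan', 'Paris'), ('Kenya', 'Tokyo'), ('France', 'Nairobi')]

# Sanity check: no generated "false" pair should coincide with a true one
assert not set(false_pairs) & set(true_pairs)
```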

In [7]:
import requests

# Manually sorted through these to try to remove anything ambiguous/contested
capitals = requests.get("https://raw.githubusercontent.com/JDS289/BaLD4LLM/refs/heads/main/country-list.csv")
capitals.raise_for_status()
capitals_dict = {}
# Skip the header row and the empty string after the final newline.
# Each row looks like "Country","Capital", so splitting on '"' gives ['', country, ',', capital, ''].
for string in capitals.content.decode('utf-8').split('\n')[1 : -1]:
    _, country, _, city, _ = string.split('"')
    capitals_dict[country] = city

print("\n".join(f"{key}: {value}" for key, value in capitals_dict.items()))

Afghanistan: Kabul
Albania: Tirana
Algeria: Algiers
Andorra: Andorra la Vella
Angola: Luanda
Antigua and Barbuda: St. John's
Argentina: Buenos Aires
Armenia: Yerevan
Australia: Canberra
Austria: Vienna
Azerbaijan: Baku
Bahamas: Nassau
Bahrain: Manama
Bangladesh: Dhaka
Barbados: Bridgetown
Belarus: Minsk
Belgium: Brussels
Belize: Belmopan
Benin: Porto-Novo
Bhutan: Thimphu
Bosnia and Herzegovina: Sarajevo
Botswana: Gaborone
Brazil: Brasília
Brunei: Bandar Seri Begawan
Bulgaria: Sofia
Burkina Faso: Ouagadougou
Burundi: Bujumbura
Cambodia: Phnom Penh
Cameroon: Yaoundé
Canada: Ottawa
Cape Verde: Praia
Central African Republic: Bangui
Chad: N'Djamena
Chile: Santiago
China: Beijing
Colombia: Bogotá
Comoros: Moroni
Costa Rica: San José
Croatia: Zagreb
Cuba: Havana
Cyprus: Nicosia
Côte d'Ivoire: Yamoussoukro
Democratic Republic of the Congo: Kinshasa
Denmark: Copenhagen
Djibouti: Djibouti
Dominica: Roseau
Dominican Republic: Santo Domingo
East Timor: Dili
Ecuador: Quito
Egypt: Cairo
El Salvador
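The `split('"')` parse above relies on every row having exactly two double-quoted fields with no embedded quotes, and `[1 : -1]` silently drops the last row if the file does not end with a newline. Python's standard `csv` module handles quoting and blank lines instead. A sketch, assuming the file keeps a header row followed by country and capital columns in that order; the sample text is made up in that format and `parse_capitals` is a hypothetical name.

```python
import csv
import io

def parse_capitals(csv_text):
    """Map country -> capital from CSV text with a header row and two columns."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row
    return {row[0]: row[1] for row in reader if row}  # `if row` skips blank lines

sample = '"country","capital"\n"Afghanistan","Kabul"\n"Antigua and Barbuda","St. John\'s"\n'
print(parse_capitals(sample))
```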

In [8]:
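true_list = [f"Is {capital} the capital of {country}?" for country, capital in capitals_dict.items()]
# A list of (country, capital) pairs so they can be indexed by position
fixed_order_capitals_dict = list(capitals_dict.items())
# Pair capital i with country i+1 (wrapping around) to get a false question
false_list = [f"Is {fixed_order_capitals_dict[i][1]} the capital of {fixed_order_capitals_dict[(i + 1) % len(capitals_dict)][0]}?" for i in range(len(capitals_dict))]
print(true_list[:10], false_list[:10])

def print_ret(x):
    """Print the question and the model's short answer, then pass x through unchanged."""
    print(x[:2])
    return x

def run_one(q):
    # process_chunk expects a batch (list) of strings, so wrap the single question
    texts, hidden_states = process_chunk([q])
    return q, texts[0], hidden_states

results = [print_ret(run_one(q)) for q in tqdm(true_list)]
results_false = [print_ret(run_one(q)) for q in tqdm(false_list)]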

['Is Kabul the capital of Afghanistan?', 'Is Tirana the capital of Albania?', 'Is Algiers the capital of Algeria?', 'Is Andorra la Vella the capital of Andorra?', 'Is Luanda the capital of Angola?', "Is St. John's the capital of Antigua and Barbuda?", 'Is Buenos Aires the capital of Argentina?', 'Is Yerevan the capital of Armenia?', 'Is Canberra the capital of Australia?', 'Is Vienna the capital of Austria?'] ['Is Kabul the capital of Albania?', 'Is Tirana the capital of Algeria?', 'Is Algiers the capital of Andorra?', 'Is Andorra la Vella the capital of Angola?', 'Is Luanda the capital of Antigua and Barbuda?', "Is St. John's the capital of Argentina?", 'Is Buenos Aires the capital of Armenia?', 'Is Yerevan the capital of Australia?', 'Is Canberra the capital of Austria?', 'Is Vienna the capital of Azerbaijan?']



('Is Kabul the capital of Afghanistan?', '\nYes')
('Is Tirana the capital of Albania?', '\nYes')
('Is Algiers the capital of Algeria?', '\nYes')
('Is Andorra la Vella the capital of Andorra?', '\nYes')
('Is Luanda the capital of Angola?', '\nYes')
("Is St. John's the capital of Antigua and Barbuda?", '\nSt')
('Is Buenos Aires the capital of Argentina?', '\nYes')
('Is Yerevan the capital of Armenia?', '\nYes')
('Is Canberra the capital of Australia?', '\nYes')
('Is Vienna the capital of Austria?', '\nYes')
('Is Baku the capital of Azerbaijan?', '\nYes')
('Is Nassau the capital of Bahamas?', '\nYes')
('Is Manama the capital of Bahrain?', '\nYes')
('Is Dhaka the capital of Bangladesh?', '\nYes')
('Is Bridgetown the capital of Barbados?', '\nYes')
('Is Minsk the capital of Belarus?', '\nYes')
('Is Brussels the capital of Belgium?', '\nYes')
('Is Belmopan the capital of Belize?', '\nNo')
('Is Porto-Novo the capital of Benin?', '\nNo')
('Is Thimphu the capital of Bhutan?', '\nYes')
('Is Sara


('Is Kabul the capital of Albania?', 'No')
('Is Tirana the capital of Algeria?', 'No')
('Is Algiers the capital of Andorra?', '\nNo')
('Is Andorra la Vella the capital of Angola?', 'No')
('Is Luanda the capital of Antigua and Barbuda?', '\nNo')
("Is St. John's the capital of Argentina?", '\nNo')
('Is Buenos Aires the capital of Armenia?', 'No')
('Is Yerevan the capital of Australia?', '\nNo')
('Is Canberra the capital of Austria?', '\nNo')
('Is Vienna the capital of Azerbaijan?', '\nNo')
('Is Baku the capital of Bahamas?', '\nNo')
('Is Nassau the capital of Bahrain?', '\nYes')
('Is Manama the capital of Bangladesh?', '\nNo')
('Is Dhaka the capital of Barbados?', 'No')
('Is Bridgetown the capital of Belarus?', '\nNo')
('Is Minsk the capital of Belgium?', '\nNo')
('Is Brussels the capital of Belize?', '\nNo')
('Is Belmopan the capital of Benin?', 'No')
('Is Porto-Novo the capital of Bhutan?', '\nNo')
('Is Thimphu the capital of Bosnia and Herzegovina?', 'No')
('Is Sarajevo the capital of