In [None]:
import os
import pickle
import platform
import random

import numpy as np
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

SEED = 42

In [2]:
def is_mps_available() -> bool:
    """
    Safely check if MPS (Metal Performance Shaders) is available.
    """
    if platform.system() != "Darwin":  # MPS is only available on macOS
        return False

    # Check if the current PyTorch version has MPS support
    if not hasattr(torch, "backends") or not hasattr(torch.backends, "mps"):
        return False

    return torch.backends.mps.is_available()


def set_all_seeds(
    seed: int, deterministic: bool = True, warn_only: bool = False
) -> None:
    """
    Set all seeds and deterministic flags for reproducibility.

    Args:
        seed (int): The seed value to use for all random number generators
        deterministic (bool): Whether to enforce deterministic behavior
        warn_only (bool): If True, warning instead of error when deterministic
                         operations aren't supported
    """
    # Python RNG
    random.seed(seed)

    # NumPy RNG
    np.random.seed(seed)

    # PyTorch RNGs
    torch.manual_seed(seed)

    # Handle CUDA devices
    if torch.cuda.is_available():
        try:
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # for multi-GPU
        except Exception as e:
            print(f"Warning: Could not set CUDA seeds: {str(e)}")

    # Handle MPS device (Apple Silicon)
    if is_mps_available():
        try:
            torch.mps.manual_seed(seed)
        except Exception as e:
            print(f"Warning: Could not set MPS seed: {str(e)}")

    # Environment variables
    os.environ["PYTHONHASHSEED"] = str(seed)

    if deterministic:
        try:
            # CUDA deterministic behavior
            if torch.cuda.is_available():
                os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

            # Set deterministic algorithms
            torch.use_deterministic_algorithms(True, warn_only=warn_only)

        except Exception as e:
            msg = f"Warning: Could not enable deterministic mode. Error: {str(e)}"
            if not warn_only:
                raise RuntimeError(msg)
            print(msg)


def get_device() -> torch.device:
    """
    Get the most appropriate PyTorch device available.

    Returns:
        torch.device: The preferred device (CUDA > MPS > CPU)
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if is_mps_available():
        return torch.device("mps")
    return torch.device("cpu")

In [None]:
set_all_seeds(SEED, warn_only=True)
device = get_device()
model_name = "EleutherAI/pythia-70m"
model: HookedTransformer = HookedTransformer.from_pretrained(
    model_name, device=device
).eval()

In [4]:
bank = {
    "city-country": {
        "bank": {
            "Bangkok": "Thailand",
            "Beijing": "China",
            "Buenos Aires": "Argentina",
            "Cape Town": "South Africa",
            "Hong Kong": "China",
            "Kuala Lumpur": "Malaysia",
            "Los Angeles": "United States",
            "Mexico City": "Mexico",
            "New Delhi": "India",
            "New York City": "United States",
            "Paris": "France",
            "Rio de Janeiro": "Brazil",
            "Rome": "Italy",
            "San Francisco": "United States",
            "St. Petersburg": "Russia",
            "Sydney": "Australia",
            "Tokyo": "Japan",
            "Toronto": "Canada",
        },
        "prompt": "%s is a city in the country of",
    },
    "city-continent": {
        "bank": {
            "Bangkok": "Asia",
            "Beijing": "Asia",
            "Buenos Aires": "South America",
            "Cape Town": "Africa",
            "Hong Kong": "Asia",
            "Kuala Lumpur": "Asia",
            "Los Angeles": "North America",
            "Mexico City": "North America",
            "New Delhi": "Asia",
            "New York City": "North America",
            "Paris": "Europe",
            "Rio de Janeiro": "South America",
            "Rome": "Europe",
            "San Francisco": "North America",
            "St. Petersburg": "Europe",
            "Sydney": "Oceania",
            "Tokyo": "Asia",
            "Toronto": "North America",
        },
        "prompt": "%s is a city in the continent of",
    },
    "city-language": {
        "bank": {
            "Bangkok": "Thai",
            "Beijing": "Chinese",
            "Buenos Aires": "Spanish",
            "Cape Town": "Afrikaans",
            "Kuala Lumpur": "Malay",
            "Los Angeles": "English",
            "Mexico City": "Spanish",
            "New Delhi": "Hindi",
            "New York City": "English",
            "Paris": "French",
            "Rio de Janeiro": "Portuguese",
            "Rome": "Italian",
            "San Francisco": "English",
            "St. Petersburg": "Russian",
            "Sydney": "English",
            "Tokyo": "Japanese",
            "Toronto": "English",
        },
        "prompt": "%s is a city where the language spoken is",
    },
    "occupation-duty": {
        "bank": {
            "to plan and design the construction of buildings": "architect",
            "to represent clients in court during trial": "lawyer",
            "to diagnose and treat issues related to the teeth": "dentist",
            "to create clothing and accessories according to current trends": "fashion designer",
            "to report news and information": "journalist",
            "to educate students in a classroom setting": "teacher",
            "to capture images using cameras": "photographer",
            "to lead and guide a team of representatives, employees, or workers": "manager",
            "to code, test, and maintain computer programs and applications": "software developer",
            "to prepare and cook food in a professional kitchen": "chef",
            "to care for patients in a medical setting and assist doctors": "nurse",
            "to enforce laws and protect citizens from crime": "police officer",
            "to repair and maintain vehicles and machinery in a workshop, garage, or factory": "mechanic",
            "to conduct research and experiments in a laboratory setting": "scientist",
            "to create visual art using various mediums, such as paint, clay, or digital tools": "artist",
            "to play musical instruments and perform for audiences in various settings": "musician",
        },
        "prompt": "My duties are %s; I am a",
    },
    "object-color": {
        "bank": {
            "apple": "red",
            "banana": "yellow",
            "carrot": "orange",
            "grape": "purple",
            "lemon": "yellow",
            "lime": "green",
            "orange": "orange",
            "pear": "green",
            "strawberry": "red",
            "tomato": "red",
            "blueberry": "blue",
            "cherry": "red",
            "eggplant": "purple",
            "kiwi": "brown",
            "peach": "orange",
            "plum": "purple",
            "watermelon": "green",
            "avocado": "green",
        },
        "prompt": "The color of the %s is usually",
    },
    "object-size": {
        "bank": {
            "diamond": "millimeter",
            "bamboo": "meter",
            "axe": "meter",
            "gopher": "centimeter",
            "saffron": "millimeter",
            "lime": "centimeter",
            "turmeric": "centimeter",
            "lion": "meter",
            "violet": "centimeter",
            "starfish": "centimeter",
            "charcoal": "centimeter",
            "turquoise": "centimeter",
            "flamingo": "meter",
            "pig": "meter",
            "cornmeal": "centimeter",
            "blackberry": "centimeter",
        },
        "prompt": 'Considering the following units " millimeter ", " centimeter ", " meter ", and " kilometer ", the size of %s is commonly expressed in "',
    },
}


In [5]:
def make_prompt(entity: str = "city", target: str = "country", num_prompts: int = 30):
    pair = "{entity}-{target}".format(entity=entity, target=target)
    data = bank[pair]
    pairs = data["bank"]
    prompt = data["prompt"]
    cities = list(pairs.keys())

    # Generate N 3-shot prompts: "A is a city in the country of B. C is a city in the country of D. E is a city in the country of " ground_truth = F
    prompts = []
    for i in range(num_prompts):
        # Randomly sample 3 cities from the list + 1 ground truth city to be queried (all different)
        cities_sampled = torch.randperm(len(cities))[:4]
        countries_sampled = [pairs[cities[i]] for i in cities_sampled]
        ground_truth = countries_sampled[-1]

        # Generate the prompt
        prompt_instance = (
            prompt % cities[cities_sampled[0]]
            + " "
            + countries_sampled[0]
            + ". "
            + prompt % cities[cities_sampled[1]]
            + " "
            + countries_sampled[1]
            + ". "
            + prompt % cities[cities_sampled[2]]
            + " "
            + countries_sampled[2]
            + ". "
            + prompt % cities[cities_sampled[3]]
        )

        prompts.append((prompt_instance, ground_truth))
    return prompts

In [21]:
# city - country
city_country_prompts = make_prompt(entity="city", target="country", num_prompts=1000)

# city - continent
city_continent_prompts = make_prompt(entity="city", target="continent", num_prompts=1000)

# city - language
city_language_prompts = make_prompt(entity="city", target="language", num_prompts=1000)

# duty - occupation
duty_occupation_prompts = make_prompt(
    entity="occupation", target="duty", num_prompts=1000
)

# object - color
object_color_prompts = make_prompt(entity="object", target="color", num_prompts=1000)

# object - size
object_size_prompts = make_prompt(entity="object", target="size", num_prompts=1000)

In [22]:
# Print some samples
print(city_country_prompts[0])
print(city_continent_prompts[0])
print(city_language_prompts[0])
print(duty_occupation_prompts[0])
print(object_color_prompts[0])
print(object_size_prompts[0])

('Buenos Aires is a city in the country of Argentina. Los Angeles is a city in the country of United States. New Delhi is a city in the country of India. Mexico City is a city in the country of', 'Mexico')
('Cape Town is a city in the continent of Africa. Rome is a city in the continent of Europe. St. Petersburg is a city in the continent of Europe. Los Angeles is a city in the continent of', 'North America')
('Buenos Aires is a city where the language spoken is Spanish. Los Angeles is a city where the language spoken is English. New York City is a city where the language spoken is English. St. Petersburg is a city where the language spoken is', 'Russian')
('My duties are to lead and guide a team of representatives, employees, or workers; I am a manager. My duties are to repair and maintain vehicles and machinery in a workshop, garage, or factory; I am a mechanic. My duties are to conduct research and experiments in a laboratory setting; I am a scientist. My duties are to diagnose and 

In [23]:
# model.generate(
#     "She is living in Rome, therefore her country of residence is Italy. She is living in New Delhi, therefore her country of residence is India. She is living in Washington, therefore her country of residence is",
#     max_new_tokens=5,
#     temperature=1.0,
#     prepend_bos=True,
# )  # This will print the output of the model

In [None]:
city_country_correct = 0
city_country_retained = []
for prompt, gt in city_country_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)

    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        city_country_correct += 1
        city_country_retained.append((prompt, gt))

print("Accuracy: ", city_country_correct / len(city_country_prompts))

In [None]:
city_continent_correct = 0


city_continent_retained = []
for prompt, gt in city_continent_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)
    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        city_continent_correct += 1
        city_continent_retained.append((prompt, gt))

print("Accuracy: ", city_continent_correct / len(city_continent_prompts))

In [None]:
city_language_correct = 0
city_language_retained = []
for prompt, gt in city_language_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)
    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        city_language_correct += 1
        city_language_retained.append((prompt, gt))

print("Accuracy: ", city_language_correct / len(city_language_prompts))

In [None]:
duty_occupation_correct = 0
duty_occupation_retained = []
for prompt, gt in duty_occupation_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)
    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        duty_occupation_correct += 1
        duty_occupation_retained.append((prompt, gt))

print("Accuracy: ", duty_occupation_correct / len(duty_occupation_prompts))

In [None]:
object_color_correct = 0
object_color_retained = []
for prompt, gt in object_color_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)
    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        object_color_correct += 1
        object_color_retained.append((prompt, gt))

print("Accuracy: ", object_color_correct / len(object_color_prompts))

In [None]:
object_size_correct = 0
object_size_retained = []
for prompt, gt in object_size_prompts:
    print(prompt)
    generation = model.generate(
        prompt, max_new_tokens=1, temperature=0.0, prepend_bos=True
    )
    print(generation)
    print("Ground truth: ", gt)
    print("\n\n")
    print("---------------------------------------------------")
    if gt.startswith(generation[len(prompt) :].strip()):
        object_size_correct += 1
        object_size_retained.append((prompt, gt))

print("Accuracy: ", object_size_correct / len(object_size_prompts))

In [30]:
# Save in a pickle file the data {(entity-target): {accuracy : accuracy, retained: [(prompt, gt)], full_data: [(prompt, gt)]}}

data = {
    "city-country": {
        "correct_samples_number": city_country_correct,
        "accuracy": city_country_correct / len(city_country_prompts),
        "retained": city_country_retained,
        "full_data": city_country_prompts,
    },
    "city-continent": {
        "correct_samples_number": city_continent_correct,
        "accuracy": city_continent_correct / len(city_continent_prompts),
        "retained": city_continent_retained,
        "full_data": city_continent_prompts,
    },
    "city-language": {
        "correct_samples_number": city_language_correct,
        "accuracy": city_language_correct / len(city_language_prompts),
        "retained": city_language_retained,
        "full_data": city_language_prompts,
    },
    "occupation-duty": {
        "correct_samples_number": duty_occupation_correct,
        "accuracy": duty_occupation_correct / len(duty_occupation_prompts),
        "retained": duty_occupation_retained,
        "full_data": duty_occupation_prompts,
    },
    "object-color": {
        "correct_samples_number": object_color_correct,
        "accuracy": object_color_correct / len(object_color_prompts),
        "retained": object_color_retained,
        "full_data": object_color_prompts,
    },
    "object-size": {
        "correct_samples_number": object_size_correct,
        "accuracy": object_size_correct / len(object_size_prompts),
        "retained": object_size_retained,
        "full_data": object_size_prompts,
    },
}

os.makedirs("prompt_data/EleutherAI", exist_ok=True)

with open(f"prompt_data/{model_name}_prompt_data.pkl", "wb") as f:
    pickle.dump(data, f)

In [31]:
# Print a recap of the results
for key, value in data.items():
    print(f"Accuracy for {key}: {value['correct_samples_number']}")
    print("\n")
    print("---------------------------------------------------")

Accuracy for city-country: 329


---------------------------------------------------
Accuracy for city-continent: 318


---------------------------------------------------
Accuracy for city-language: 224


---------------------------------------------------
Accuracy for occupation-duty: 19


---------------------------------------------------
Accuracy for object-color: 108


---------------------------------------------------
Accuracy for object-size: 307


---------------------------------------------------


: 