# Data preparation for causal tracing

In [2]:
import os, json
from copy import deepcopy

from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from experiments.causal_trace import (
    ModelAndTokenizer,
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)


In [3]:
# Instantiate model and tokenizer
# MODEL_NAME is the model identifier handed to ModelAndTokenizer below.
MODEL_NAME = "EleutherAI/gpt-j-6B"
# FILE_NAME is the json file the tracing prompts are written to and read back
# from in the later cells of this notebook.
FILE_NAME = "tracing_prompts_eleutherAI-gpt-j-6b.json"
mt = ModelAndTokenizer(
    MODEL_NAME,
    # float16 is requested only when the model name contains "20b"; for any
    # other name (including the one set above) None is passed instead.
    torch_dtype=(torch.float16 if "20b" in MODEL_NAME else None),
)

## (1) Compile prompts for prediction

In [17]:
def load_terms(terms_filepath):
    """Read the target-term table and split it by bias type.

    :param terms_filepath: path to a csv file that has the columns
        ``gender``, ``profession``, ``race`` and ``religion``
    :returns: dict mapping each bias type to the matching csv column
        (a pandas Series; cells pandas read as missing are not filtered here)
    """
    bias_types = ("gender", "profession", "race", "religion")
    table = pd.read_csv(terms_filepath)

    # One column per bias type, in the fixed key order above.
    return {bias: table[bias] for bias in bias_types}

def load_prompts(prompt_filepath):
    """Read the prompt templates, one per line of the file.

    :param prompt_filepath: path to the text file with prompt templates
    :returns: list with one string per line of the file; line endings are
        kept, so callers have to strip them
    """
    with open(prompt_filepath, 'r') as template_file:
        # Iterating the file handle yields the same lines as readlines().
        return list(template_file)

In [19]:
def make_trace_prompts(gender_terms, profession_terms, race_terms, religion_terms, templates):
    """Create input prompts for tracing in json format

    Every non-blank template is combined with every non-missing target term.
    For each template the records are emitted in the order gender,
    profession, race, religion, and ``known_id`` counts up from 0 across
    the whole result.

    :param gender_terms: target terms for gender bias
    :param profession_terms: target terms for profession bias
    :param race_terms: target terms for race bias
    :param religion_terms: target terms for religious bias
    :param templates: prompt templates containing ``{}`` where the term goes;
        surrounding whitespace is stripped and blank templates are skipped
    :returns: list of prompt records (dicts) for causal tracing; the
        ``attribute`` and ``prediction`` fields are set to a single space
    """
    # (relation_id, terms) pairs in the order they are emitted per template.
    term_groups = (
        ("gender", gender_terms),
        ("profession", profession_terms),
        ("race", race_terms),
        ("religion", religion_terms),
    )

    prompts = []
    for raw_template in templates:
        template = raw_template.strip()
        if not template:
            # A blank line (e.g. a trailing newline in the template file)
            # would otherwise produce records with an empty prompt.
            continue

        for relation_id, terms in term_groups:
            for term in terms:
                # Missing cells (NaN / None) carry no term.
                if pd.isna(term):
                    continue
                prompts.append({
                    # One record is appended per id, so the current length
                    # is always the next sequential id.
                    "known_id": len(prompts),
                    "subject": term,
                    "attribute": " ",
                    "template": template,
                    "prediction": " ",
                    "prompt": template.replace("{}", str(term)),
                    "relation_id": relation_id,
                })

    return prompts

In [13]:
def save_prompts_to_file(prompts, tracing_promtps_filename):
    """Write the tracing prompts to disk as indented json.

    :param prompts: json-serialisable prompt records (subject, attribute etc.)
    :param tracing_promtps_filename: path of the output file; it is opened
        in write mode, so existing content is replaced
    """
    with open(tracing_promtps_filename, mode='w', encoding='utf-8') as out_file:
        # ensure_ascii=False writes non-ASCII characters as-is instead of
        # \u escapes; the file itself is encoded as utf-8.
        json.dump(obj=prompts, fp=out_file, indent=4, ensure_ascii=False)

In [20]:
# Target terms: one entry per bias type
TERMS = load_terms('../data/debiasing_data_preparation/stereoset_target_terms_english.csv')

# Prompt templates: one per line of the file
PROMPTS = load_prompts('../data/debiasing_data_preparation/templates_english.txt')

# Combine every template with every target term ...
prompts = make_trace_prompts(
    gender_terms=TERMS["gender"],
    profession_terms=TERMS["profession"],
    race_terms=TERMS["race"],
    religion_terms=TERMS["religion"],
    templates=PROMPTS,
)

# ... and store the result as json
save_prompts_to_file(prompts, FILE_NAME)

## (2) Predictions

In [11]:
def load_complete_prompts(prompts_filename):
    """Load completed tracing prompts from file

    :param prompts_filename: path to a json prompt file
    :returns: the parsed json content (a list of prompt records for
        prediction when the file was produced by save_prompts_to_file)
    """
    # The prompt files are written as utf-8 with non-ASCII characters kept
    # as-is, so they must be read back as utf-8 rather than with the
    # platform's default encoding.
    with open(prompts_filename, 'r', encoding='utf-8') as f:
        return json.load(f)

# Read the prompt records back from FILE_NAME; `p` is passed to
# predict_stereotypes further down.
p = load_complete_prompts(FILE_NAME)

In [9]:
# Make predictions
def predict_stereotypes(prompts, mt):
    """Fill in the model's predicted continuation for every prompt record.

    :param prompts: prompt records (dicts) whose ``prompt`` text should
        trigger stereotypical predictions
    :param mt: ModelAndTokenizer object, handed through to ``predict_token``
    :returns: list of new prompt records for causal tracing, in input order;
        ``attribute`` holds the predicted token and ``prediction`` the same
        token with one space prepended
    """

    def complete(record):
        # One prompt per call. The result is indexed as [0][0] -- presumably
        # (decoded tokens, probabilities) when return_p=True; TODO confirm
        # against experiments.causal_trace.predict_token.
        outcome = predict_token(mt, [record['prompt']], return_p=True)
        token = outcome[0][0]
        return {
            "known_id": record['known_id'],
            "subject": record['subject'],
            "attribute": token,
            "template": record['template'],
            "prediction": " " + token,
            "prompt": record['prompt'],
            "relation_id": record['relation_id'],
        }

    return [complete(record) for record in prompts]

# Run the model on every loaded prompt (one predict_token call per prompt);
# `c` holds the completed records.
c = predict_stereotypes(p, mt)

In [10]:
# Write the completed records to FILE_NAME, replacing the file's previous content.
save_prompts_to_file(c, FILE_NAME)

## (3) Miscellaneous (clean formatting etc.)

In [15]:
# Update IDs
def update_ids(filename, start_id):
    """Renumber the ``known_id`` of every prompt in a tracing-prompt file.

    IDs are assigned in file order, counting up from ``start_id``, so that
    they are sequential and free of duplicates (duplicates cause problems
    with causal tracing!). The file is rewritten in place and nothing is
    returned. Only the seven standard fields are carried over to the
    rewritten records.

    :param filename: tracing_prompts... file
    :param start_id: ID value to start numbering with
    """
    renumbered = [
        {
            "known_id": start_id + offset,
            "subject": entry['subject'],
            "attribute": entry['attribute'],
            "template": entry['template'],
            "prediction": entry['prediction'],
            "prompt": entry['prompt'],
            "relation_id": entry['relation_id'],
        }
        for offset, entry in enumerate(load_complete_prompts(filename))
    ]

    save_prompts_to_file(renumbered, filename)

# Renumber the prompt file used throughout this notebook, starting at 0.
# FILE_NAME (set next to MODEL_NAME) is used instead of repeating the
# literal so this step cannot drift from the file the earlier cells wrote.
update_ids(FILE_NAME, 0)