In [28]:
import sys
import os
import random
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from datasets import Dataset
from torch.utils.data import DataLoader

sys.path.append(os.getenv('SPARSE_PROBING_ROOT'))
from sparse_probing_paper.activations.activation_subset import load_activation_subset
from sparse_probing_paper.load import load_feature_dataset, load_model

In [29]:
# Feature dataset to load. Each supported Wikidata property is paired with the
# number of sequences used in its saved feature-dataset name, so changing
# `wikidata_property` below switches the dataset size with it.
FD_N_SEQ_BY_PROPERTY = {
    'sex_or_gender': 6000,
    'occupation_athlete': 5000,
}

wikidata_property = 'sex_or_gender'
fd_n_seq = FD_N_SEQ_BY_PROPERTY[wikidata_property]

prefix = 'wikidata_sorted'

In [30]:
# Ablation dataset settings.
seq_per_class = 1000  # examples generated per target class
n_sample_prompts = 3  # few-shot demonstrations prepended before the query prompt

# Prompt template per Wikidata property; '{}' is filled with a person's name.
# Selecting by `wikidata_property` keeps the prompt in sync with the dataset choice.
PROMPT_BY_PROPERTY = {
    'sex_or_gender': '{} has gender',
    'occupation_athlete': '{} plays the sport of',
}
prompt = PROMPT_BY_PROPERTY[wikidata_property]

In [31]:
# Per Wikidata property:
#   - target classes, in the order their examples are generated;
#   - class_map, which gives the label word written after each few-shot demonstration.
# class_map may cover more classes than `classes` (e.g. 'cricketer') because
# demonstrations are sampled from every row of the feature dataset, not only the
# target classes.
CLASS_PRESETS = {
    'sex_or_gender': (
        ['female', 'male'],
        {'female': 'female', 'male': 'male'},
    ),
    'occupation_athlete': (
        ['association football player', 'basketball player', 'American football player',
         'baseball player', 'ice hockey player'],
        {
            'association football player': 'soccer',
            'basketball player': 'basketball',
            'American football player': 'football',
            'baseball player': 'baseball',
            'cricketer': 'cricket',
            'ice hockey player': 'hockey',
        },
    ),
}
_preset_classes, _preset_class_map = CLASS_PRESETS[wikidata_property]
# Copy, so later modifications of classes/class_map never alter the presets.
classes = list(_preset_classes)
class_map = dict(_preset_class_map)

In [32]:
# Assemble the saved feature-dataset name from the config above and load it.
# NOTE(review): the literal '128' presumably is the sequence length (the ablation
# dataset below also uses seq_len = 128) — confirm.
fd_name = '.'.join([f'{prefix}_{wikidata_property}', 'pyth', '128', str(fd_n_seq)])
fd = load_feature_dataset(fd_name)
fd

Dataset({
    features: ['name', 'text', 'tokens', 'name_index_start', 'name_index_end', 'surname_index_start', 'surname_index_end', 'class'],
    num_rows: 6000
})

In [33]:
# Seed Python's global RNG once; the per-class shuffles and the random choice of
# few-shot demonstration names below both draw from it.
SEED = 99
random.seed(SEED)

### Make a feature dataset for ablations

In [34]:
# (name, class) pairs for every row of the feature dataset; the few-shot
# demonstrations are drawn from this pool.
all_names = [(person, label) for person, label in zip(fd['name'], fd['class'])]
all_names[:3]

[('Lady Gaga', 'female'),
 ('William Shakespeare', 'male'),
 ('Michael Jackson', 'male')]

In [35]:
# Group feature-dataset names by target class (preserving dataset order), then
# shuffle each class's list. The shuffles run in `classes` order, exactly as
# before, so the seeded RNG is consumed in the same sequence.
class_names = {c: [] for c in classes}
# Read each column once instead of once per class.
for example_n, example_c in zip(fd['name'], fd['class']):
    if example_c in class_names:
        class_names[example_c].append(example_n)
for c in classes:
    random.shuffle(class_names[c])

for c in classes:
    print(f'{c}: {class_names[c][:3]}')

female: ['Kalki Koechlin', 'Sofia Coppola', 'Brenda Lee']
male: ['Hubert Humphrey', 'James Cameron', 'Charles Gounod']


In [36]:
# Only the tokenizer is kept for building the dataset; the model object returned
# by load_model is not stored.
# NOTE(review): this loads the full model weights just to obtain its tokenizer.
tokenizer_source = 'pythia-70m'
tokenizer = load_model(tokenizer_source).tokenizer

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m into HookedTransformer


In [37]:
seq_len = 128

# Prompts are right-padded to seq_len, so a usable pad id is required. The
# tokenizer load above emitted a "pad_token ... not set" warning, so check explicitly.
if tokenizer.pad_token_id is None:
    raise ValueError('tokenizer.pad_token_id is None; cannot pad prompts to seq_len')

# Each class must supply seq_per_class distinct query names.
for c in classes:
    if len(class_names[c]) < seq_per_class:
        raise ValueError(
            f'class {c!r} has only {len(class_names[c])} names; '
            f'seq_per_class={seq_per_class} requested'
        )

dataset_list = []
for c in classes:
    for i in range(seq_per_class):
        name = class_names[c][i]

        # Few-shot prompt: n_sample_prompts demonstrations "<prompt> <label>. " drawn
        # from the whole feature dataset (never the query name itself), then the query.
        # TODO: does order matter? should we balance across classes? should we give potential options with the prompt?
        example_prompt = ''
        for _ in range(n_sample_prompts):
            while True:
                sample_name, sample_class = random.choice(all_names)
                if sample_name != name:
                    break
            example_prompt += prompt.format(sample_name)
            example_prompt += f' {class_map[sample_class]}. '
        example_prompt += prompt.format(name)

        # Tokenize with an explicit leading BOS token.
        tokens = [tokenizer.bos_token_id] + tokenizer.encode(example_prompt)
        # Without this check an over-long prompt would silently skip padding and
        # produce a row longer than seq_len.
        if len(tokens) > seq_len:
            raise ValueError(
                f'prompt for {name!r} is {len(tokens)} tokens, longer than seq_len={seq_len}'
            )
        # Position of the last real token, whose next-token logits are read later.
        logit_index = len(tokens) - 1
        tokens = tokens + (seq_len - len(tokens)) * [tokenizer.pad_token_id]

        dataset_list.append({
            'prompt': example_prompt,
            'name': name,
            'class': c,
            'mapped_class': class_map[c],
            'tokens': tokens,
            'logit_index': logit_index,
        })

pd.DataFrame(dataset_list).head()

Unnamed: 0,prompt,name,class,mapped_class,tokens,logit_index
0,Arthur Sullivan has gender male. Abraham Linco...,Kalki Koechlin,female,female,"[0, 34021, 26211, 556, 8645, 5086, 15, 24958, ...",29
1,Bob Hope has gender male. Elena Poniatowska ha...,Sofia Coppola,female,female,"[0, 26845, 15541, 556, 8645, 5086, 15, 44846, ...",29
2,Nancy Reagan has gender female. Ellie Greenwic...,Brenda Lee,female,female,"[0, 47, 4306, 25556, 556, 8645, 5343, 15, 9545...",28
3,Tara Strong has gender female. Richard Meier h...,Kate Atkinson,female,female,"[0, 53, 4595, 24747, 556, 8645, 5343, 15, 7727...",26
4,Arthur Eddington has gender male. Pliny the Yo...,Anne Lamott,female,female,"[0, 34021, 20709, 16240, 556, 8645, 5086, 15, ...",30


In [38]:
# Wrap the examples in a HF Dataset. 'tokens' and 'logit_index' are returned as
# torch tensors; the remaining columns are still returned alongside them.
dataset = Dataset.from_list(dataset_list).with_format(
    type='pt',
    columns=['tokens', 'logit_index'],
    output_all_columns=True,
)
dataset

Dataset({
    features: ['prompt', 'name', 'class', 'mapped_class', 'tokens', 'logit_index'],
    num_rows: 2000
})

In [39]:
# Save the ablation dataset under $FEATURE_DATASET_DIR/ablation_datasets/.
feature_dataset_dir = os.getenv('FEATURE_DATASET_DIR')
# Fail clearly instead of os.path.join(None, ...) raising an opaque TypeError,
# or an empty value silently saving relative to the working directory.
if not feature_dataset_dir:
    raise EnvironmentError('FEATURE_DATASET_DIR is not set; it must point at the feature dataset directory')
file_loc = os.path.join(
    feature_dataset_dir,
    'ablation_datasets',
    f'wikidata_ablations_{wikidata_property}.pyth.{seq_len}.{seq_per_class * len(classes)}'
)
os.makedirs(file_loc, exist_ok=True)
dataset.save_to_disk(file_loc)

file_loc

Saving the dataset (0/1 shards):   0%|          | 0/2000 [00:00<?, ? examples/s]

'/Users/mtp/Downloads/sparse-probing/feature_dataset_dir/ablation_datasets/wikidata_ablations_sex_or_gender.pyth.128.2000'

### Test model predictions

In [40]:
# Model whose next-token predictions are inspected below.
# TODO (purpose left unspecified in the original). A larger checkpoint, 'pythia-1b',
# was also tried here. The dataset was tokenized with the pythia-70m tokenizer, so a
# substitute presumably must share its vocabulary — confirm.
model_name = 'pythia-70m'
model = load_model(model_name)

Using pad_token, but it is not set yet.


Loaded pretrained model pythia-70m into HookedTransformer


In [42]:
n = 20  # spot-check only the first n examples

# Shape is kept at (len(dataset), d_vocab) for compatibility, but rows >= n are never
# computed. Fill with NaN rather than torch.empty's uninitialized memory, so any use
# of an uncomputed row is obviously invalid.
all_logits = torch.full((len(dataset), model.cfg.d_vocab), float('nan'), dtype=torch.float32)
# Read the index column once instead of materializing it on every step.
logit_indices = dataset['logit_index'][:n].tolist()
dataloader = DataLoader(dataset['tokens'][:n], batch_size=1, shuffle=False)
with torch.no_grad():  # inference only; no autograd graph needed
    for step, batch in enumerate(tqdm(dataloader)):
        logits = model(batch, return_type='logits')
        # Next-token logits at the last non-pad position of this example.
        all_logits[step] = logits[0, logit_indices[step], :].cpu()

all_logits.shape

100%|██████████| 20/20 [00:00<00:00, 25.25it/s]


torch.Size([2000, 50304])

In [43]:
# Print the model's top-5 next-token predictions (with probabilities) for the first n prompts.
n = 20
for i in range(n):
    row = dataset[i]
    true_class = row['class']
    # Each prompt ends with the template filled with this example's name, so rebuild
    # the query directly. Splitting the prompt on '.' breaks for names like "J. K. Rowling".
    query = prompt.format(row['name'])
    logits = all_logits[i]
    probs = torch.nn.functional.softmax(logits, dim=0)

    # greedy: 5 most likely next tokens, highest first
    top_tokens = torch.topk(logits, k=5).indices.tolist()
    top_probs = [round(p, 2) for p in probs[top_tokens].tolist()]
    print(f'{query} (true class={true_class})')
    print(f'\t{list(zip(tokenizer.batch_decode(top_tokens), top_probs))}')

Kalki Koechlin has gender (true class=female)
	[(' male', 0.76), (' female', 0.16), (' Male', 0.01), (' gender', 0.01), (' masculine', 0.01)]
Sofia Coppola has gender (true class=female)
	[(' male', 0.55), (' female', 0.35), (' gender', 0.01), (' feminine', 0.01), (' masculine', 0.0)]
Brenda Lee has gender (true class=female)
	[(' female', 0.63), (' male', 0.23), (' feminine', 0.01), (' woman', 0.01), (' gender', 0.01)]
Kate Atkinson has gender (true class=female)
	[(' male', 0.56), (' female', 0.37), (' Male', 0.01), (' gender', 0.01), (' masculine', 0.0)]
Anne Lamott has gender (true class=female)
	[(' male', 0.65), (' female', 0.28), (' gender', 0.01), (' Male', 0.01), (' masculine', 0.0)]
Belinda Carlisle has gender (true class=female)
	[(' female', 0.51), (' male', 0.43), (' gender', 0.01), (' Male', 0.01), (' feminine', 0.0)]
Claire Danes has gender (true class=female)
	[(' female', 0.6), (' male', 0.35), (' gender', 0.01), (' Male', 0.0), (' feminine', 0.0)]
Jackie DeShannon has