In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
# Notebook-level logger; records emitted through it show up under the
# `__main__` name in the cell outputs.
logger = logging.getLogger(__name__)
# Configure the root logger at DEBUG and send everything to stdout so log
# records are rendered inline with the cell output. Because this is the root
# configuration, DEBUG records from third-party libraries are shown as well.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

# Bare expression: displayed as the cell result to record the library
# versions (torch, transformers, torch's CUDA version) used for this run.
torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.39.0.dev0', '12.1')

In [3]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")
# print(json.dumps(dataset[5].to_dict(), indent=2))

In [4]:
from src.dataset.rome_dataclasses import CounterFactDataset

# Data directory, relative to the notebook's working directory. It is passed
# to the dataset loader here and reused later as the output directory for
# the filtered samples.
DATA_DIR = "../data"

# Load the CounterFact dataset from DATA_DIR.
counterfact = CounterFactDataset(DATA_DIR)

Loaded dataset with 21919 elements


In [5]:
from src.models import ModelandTokenizer
from src.functional import predict_next_token
from src.functional import is_nontrivial_prefix, get_tick_marker
from tqdm.auto import tqdm
from typing import Optional

@torch.inference_mode()
def filter_counterfact_samples_by_model_knowledge(
    mt: ModelandTokenizer, counterfact: CounterFactDataset, limit: Optional[int] = None
) -> list:
    """Keep the CounterFact samples whose true target the model already predicts.

    For each sample the rewrite prompt template is filled in with the subject,
    the model's first-ranked next-token prediction is obtained, and the sample
    is kept when ``is_nontrivial_prefix`` accepts that token against the true
    target string.

    Args:
        mt: Model/tokenizer bundle, forwarded to ``predict_next_token``.
        counterfact: Sized, integer-indexable collection of CounterFact
            records. Each record must provide
            ``record["requested_rewrite"]`` with the keys ``"subject"``,
            ``"prompt"`` and ``"target_true"`` (the latter holding ``"str"``).
        limit: Stop as soon as this many known samples have been collected.
            ``None`` (the default) scans the whole collection. A value of
            zero or less returns an empty list without querying the model.

    Returns:
        The records judged as known, in the order they appear in
        ``counterfact``.
    """

    filtered_samples = []
    num_examined = 0
    progress = tqdm(range(len(counterfact)))
    for idx in progress:
        # Check the budget *before* doing any work: this way `limit <= 0`
        # never runs the model and never returns a sample, while a positive
        # limit still stops right after the limit-th known sample.
        if limit is not None and len(filtered_samples) >= limit:
            break

        sample = counterfact[idx]
        rewrite = sample["requested_rewrite"]
        question = rewrite["prompt"].format(rewrite["subject"])
        answer = rewrite["target_true"]["str"]

        # `[0]` selects the result for the single prompt passed in
        # (presumably predict_next_token is batched -- TODO confirm).
        predictions = predict_next_token(mt, question, k=5)[0]
        # Only the first candidate decides whether the fact is "known"
        # (assumes candidates are ordered by probability -- TODO confirm).
        top_pred = predictions[0]
        is_known = is_nontrivial_prefix(prediction=top_pred.token, target=answer)
        num_examined += 1

        logger.debug(
            f"{question} -> {answer=} | predicted = '{top_pred.token}'({top_pred.prob:.3f}) ==> ({get_tick_marker(is_known)})"
        )

        if is_known:
            filtered_samples.append(sample)

        progress.set_description(f"known={len(filtered_samples)}/{idx+1}")

    # Report against the number of samples actually examined: when `limit`
    # stops the scan early, dividing by the full dataset size would
    # understate the model's known-rate.
    logger.info(
        f"filtered to {len(filtered_samples)} samples / {num_examined} examined"
        f" (dataset size: {len(counterfact)})"
    )

    return filtered_samples

In [6]:
# Identifier of the model to probe. Its last path component is also used to
# name the output file at the end of the notebook. Swap in the commented-out
# line below to run the same filtering with Pythia instead.
MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b
# MODEL_PATH = "EleutherAI/pythia-2.8b-deduped"

# Load the model and its tokenizer, requesting float32 weights.
mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-19 18:50:26 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-03-19 18:50:26 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/config.json HTTP/1.1" 200 0
2024-03-19 18:50:37 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-19 18:50:40 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-19 18:50:40 src.models INFO     loaded model <state-spaces/mamba-2.8b> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [7]:
# Probe only the first 100 CounterFact records in this run.
counterfact_subset = counterfact[0 : 100]
filtered_samples = filter_counterfact_samples_by_model_knowledge(
    mt=mt,
    counterfact=counterfact_subset,
)

  0%|          | 0/100 [00:00<?, ?it/s]

2024-03-19 18:50:41 __main__ DEBUG    The mother tongue of Danielle Darrieux is -> answer='French' | predicted = ' French'(0.568) ==> (✓)
2024-03-19 18:50:41 __main__ DEBUG    The official religion of Edwin of Northumbria is -> answer='Christianity' | predicted = ' Christianity'(0.110) ==> (✓)
2024-03-19 18:50:41 __main__ DEBUG    Toko Yasuda, the -> answer='guitar' | predicted = ' president'(0.038) ==> (✗)
2024-03-19 18:50:41 __main__ DEBUG    Autonomous University of Madrid, which is located in -> answer='Spain' | predicted = ' the'(0.366) ==> (✗)
2024-03-19 18:50:41 __main__ DEBUG    What is the twin city of Lyon? It is -> answer='Beirut' | predicted = ' the'(0.064) ==> (✗)
2024-03-19 18:50:41 __main__ DEBUG    The mother tongue of Thomas Joannes Stieltjes is -> answer='Dutch' | predicted = ' Dutch'(0.176) ==> (✓)
2024-03-19 18:50:41 __main__ DEBUG    Anaal Nathrakh, that was created in -> answer='Birmingham' | predicted = ' the'(0.182) ==> (✗)
2024-03-19 18:50:41 __main__ DEBUG    

In [8]:
# Name the output after the last component of the model id,
# e.g. "<DATA_DIR>/mamba-2.8b-known.json".
model_name = MODEL_PATH.split("/")[-1]
known_samples_path = f"{DATA_DIR}/{model_name}-known.json"

# Persist the samples the model already knows as pretty-printed JSON.
with open(known_samples_path, "w") as out_file:
    json.dump(filtered_samples, out_file, indent=2)