In [1]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import deque
from transformers import AutoTokenizer
import pdb

from scipy.stats import normaltest
import os
# print(os.getcwd())
from torch.nn import functional as F

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# from src.models import get_tokenizer_and_model
# from src.modified_forward import ModelWrapper
# from src.utils import *
# from src.probe_utils import *

%load_ext autoreload
%autoreload 2

"""values that you need to change"""
PROJECT_PATH = "/home/tianyu/TwoHopIC" # the path of your working directory
access_token = None


  from .autonotebook import tqdm as notebook_tqdm


### generate dataset for finetuning and testing

In [2]:
import random
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import deque
from transformers import AutoTokenizer
import pdb
from tqdm import tqdm

from scipy.stats import normaltest
import os
import json
from torch.nn import functional as F
random.seed(42)
MODEL_OPTIONS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-7B",
        "dirname": os.path.join(PROJECT_PATH, "qwen2.5"),
        "trust_remote_code": True
    },
    "llama3-8b": {
        "name": "meta-llama/Meta-Llama-3-8B",
        "dirname": os.path.join(PROJECT_PATH, "llama3.1"),
        "trust_remote_code": False
    },
    "llama3-13b": {
        "name": "meta-llama/Meta-Llama-3-70B",
        "dirname": os.path.join(PROJECT_PATH, "llama3.1-70b"),
        "trust_remote_code": False
    },
    "olmo": {
        "name": "allenai/OLMo-7B-hf",
        "dirname": os.path.join(PROJECT_PATH, "olmo"),
        "trust_remote_code": True
    }
}

# Template and entity definitions


proto_template = [
    "[A] is the mother of [B]. [B] is the mother of [C]. Therefore, [A] is the grandmother of",
    "[A] is the father of [B]. [B] is the father of [C]. Therefore, [A] is the grandfather of",
    "[A] is a city in the state of [B]. The state of [B] is part of the country [C]. Therefore, [A] is located in",
    "[A] is a species in the genus [B]. The genus [B] belongs to the family [C]. Therefore, [A] is classified under the family",
    "[A] follows the time zone of [B]. [B] is three hours ahead of [C]. Therefore, [A] is three hours ahead of",
    "[A] lives in [B]. People in [B] speak [C]. Therefore, [A] speaks"
]

mixed_locations = ["Zorvath", "Tyseria", "Kryo", "Vynora", "Quellion", "Dras", 
                  "Luminax", "Vesperon", "Noctari", "Xyphodon", "Glacidae", "Ophirion",
                  "Eryndor", "Solmyra", "Umbrithis", "Balthorien", "Ytheris", "Fendrel", "Havroth", "Marendor"]

mixed_biology = ["Fluxilus", "Varnex", "Dranthidae", "Zynthor", "Gryvus", "Myralin",
                "Thalorium", "Zephyra", "Aerinth", "Xyphodon", "Kryostis", "Glacidae",
                "Borithis", "Chrysalix", "Noctilura", "Phorvian", "Seraphid", "Uthrelin",
                "Eldrinth", "Yvorith"]

languages = ["English", "Spanish", "Mandarin", "Hindi", "Arabic", 
            "French", "German", "Japanese", "Portuguese", "Russian",
            "Korean", "Italian", "Turkish", "Dutch", "Swedish", 
            "Polish", "Hebrew", "Greek", "Bengali", "Thai"]

short_names = ["Ben", "Jack", "Luke", "Mark", "Paul", "John", "Tom", 
              "Sam", "Joe", "Max", "Amy", "Emma", "Anna", "Grace", 
              "Kate", "Lucy", "Sarah", "Alice", "Alex", "Ruby"]

# Token handling functions
def get_token_index(tokenizer, name, model_type):
    """
    Get the correct token index based on model type, handling BOS tokens appropriately
    """
    tokens = tokenizer.encode(' ' + name)
    
    if model_type in ["llama3-8b", "llama3-70b"]:
        return tokens[1]  # Skip BOS token
    elif model_type == "olmo":
        return tokens[0]
    elif model_type == "qwen":
        return tokens[0]
    else:
        raise ValueError(f"Unknown model type: {model_type}")

def process(sentence, category, names):
    """
    Replaces the placeholders [A], [B], [C] in `sentence` with the given list of `names`.
    """
    return (sentence
            .replace("[A]", names[0])
            .replace("[B]", names[1])
            .replace("[C]", names[2])
           )

def pick_category_and_names(template_str):
    """
    Return category from template_str.
    """
    if "speak" in template_str:
        return "language"
    elif "city" in template_str or "located" in template_str or "time zone" in template_str:
        return "geography"
    elif "species" in template_str or "genus" in template_str:
        return "biology"
    else:
        return "human"

def sample_names(category):
    """
    Returns one triple of names depending on the category.
    """
    if category == "language":
        return [
            random.choice(short_names),
            random.choice(mixed_locations),
            random.choice(languages)
        ]
    elif category == "geography":
        return random.sample(mixed_locations, 3)
    elif category == "biology":
        return random.sample(mixed_biology, 3)
    else:
        return random.sample(short_names, 3)

def sample_k_disjoint_sets(category, k):
    """
    Returns a list of k disjoint name-triples for the given category.
    """
    sets = []
    if category == "language":
        sn_candidates = random.sample(short_names, k)
        loc_candidates = random.sample(mixed_locations, k)
        lang_candidates = random.sample(languages, k)
        for i in range(k):
            sets.append([sn_candidates[i], loc_candidates[i], lang_candidates[i]])
    
    elif category == "geography":
        loc_sample = random.sample(mixed_locations, 3*k)
        for i in range(k):
            sets.append(loc_sample[3*i : 3*i+3])
    
    elif category == "biology":
        bio_sample = random.sample(mixed_biology, 3*k)
        for i in range(k):
            sets.append(bio_sample[3*i : 3*i+3])
    
    else:  # human relationship
        sn_sample = random.sample(short_names, 3*k)
        for i in range(k):
            sets.append(sn_sample[3*i : 3*i+3])
    
    return sets

# Main data generation loop
ft_data = {}

for model_type, model_config in MODEL_OPTIONS.items():
    # Load tokenizer for the current model
    tokenizer = AutoTokenizer.from_pretrained(
        model_config["name"],
        trust_remote_code=model_config["trust_remote_code"],
        timeout=10 
    )
    
    # Create tracked indices
    all_C_entities = short_names + mixed_locations + mixed_biology + languages
    tracked_indices = {
        name: get_token_index(tokenizer, name, model_type) 
        for name in all_C_entities
    }
    
    ft_data[model_type] = {}
    
    for k in range(1, 2):
        ft_data[model_type][k] = []
        for _ in tqdm(range(1000)):
            # Pick template and category
            temp = random.choice(proto_template)
            category = pick_category_and_names(temp)
            
            # Split template and sample names
            sentences = temp.split(". ")
            k_name_sets = sample_k_disjoint_sets(category, k)
            
            # Create deques with processed sentences
            all_deques = []
            for names_i in k_name_sets:
                dq_i = deque([process(s, category, names_i) for s in sentences])
                all_deques.append((dq_i, names_i))
            
            # Generate text by popping from deques
            text = ""
            while any(len(dq_tuple[0]) > 1 for dq_tuple in all_deques):
                indices_lengths = [
                    (idx, len(dq) - 1) 
                    for idx, (dq, _) in enumerate(all_deques) 
                    if len(dq) > 1
                ]
                
                total_len_minus_1 = sum(x[1] for x in indices_lengths)
                if total_len_minus_1 == 0:
                    break
                
                # Weighted random choice
                rand_val = random.random()
                cumulative = 0.0
                chosen_idx = None
                for (idx_deq, l_minus_1) in indices_lengths:
                    frac = l_minus_1 / total_len_minus_1
                    if rand_val < cumulative + frac:
                        chosen_idx = idx_deq
                        break
                    cumulative += frac
                
                if chosen_idx is not None:
                    chosen_dq, _ = all_deques[chosen_idx]
                    text += chosen_dq.popleft() + ". "
            
            # Handle final query
            query_idx = random.randint(0, k - 1)
            query_dq, query_names = all_deques[query_idx]
            
            if len(query_dq) > 0:
                text += query_dq.popleft()
            
            # Get answer and token indices
            ans = " " + query_names[-1]
            query_names_ids = [tracked_indices[n] for n in query_names]
            non_query_names_ids = [
                tracked_indices[n] 
                for idx, (_, names) in enumerate(all_deques) 
                if idx != query_idx 
                for n in names
            ]
            
            # Save result
            ft_data[model_type][k].append({
                'question': text,
                'answer': ans,
                'query_names': query_names_ids,
                'non_query_names': non_query_names_ids,
            })

    with open(os.path.join(model_config["dirname"], f"test_short.json"), "w", encoding="utf-8") as f:
        json.dump(ft_data[model_type], f)

100%|██████████| 1000/1000 [00:00<00:00, 87848.03it/s]
100%|██████████| 1000/1000 [00:00<00:00, 68860.68it/s]


ValueError: Unknown model type: llama3-13b