## twohop in-context test

In [3]:
import os
import torch
from safetensors.torch import load_file, save_file
import os
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
from tqdm import tqdm


import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
# load meta-llama llama3.1 model using vllm:
# we haven't tested the llama3.1-70b model using vllm
MODEL_OPTIONS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-7B",
        "dirname": "./qwen2.5",
        "trust_remote_code": True
    },
    "llama3-8b": {
        "name": "meta-llama/Meta-Llama-3-8B",
        "dirname": "./llama3.1",
        "trust_remote_code": False
    },
    "llama3-70b": {
        "name": "meta-llama/Meta-Llama-3-70B",
        "dirname": "./llama3.1-70b",
        "trust_remote_code": False
    },
    "olmo": {
        "name": "allenai/OLMo-7B-hf",
        "dirname": "./olmo",
        "trust_remote_code": True
    }
}
model_type = "llama3-70b"  # Change this to use different models
model_config = MODEL_OPTIONS[model_type]
tokenizer = AutoTokenizer.from_pretrained(model_config["name"])

In [None]:
from vllm import LLM, SamplingParams

# Initialize the LLM with the specified model and enable multi-GPU inference
llm = LLM(model=model_config["name"], tensor_parallel_size=8)

# Prepare the prompt
prompt = "Ben lives in Vesperon. People in Vesperon speak Arabic. Mark lives in Noctari. People in Noctari speak Japanese. Therefore, Ben speaks"

# Prepare the sampling parameters
sampling_params = SamplingParams(temperature=1.0)

# Run inference using vllm with 8 GPUs
outputs = llm.generate([prompt], sampling_params)

# The outputs variable now contains the inference results


INFO 02-18 13:52:48 config.py:542] This model supports multiple tasks: {'score', 'reward', 'generate', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 02-18 13:52:48 config.py:1401] Defaulting to use mp for distributed inference
INFO 02-18 13:52:48 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='meta-llama/Meta-Llama-3-70B', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-70B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_

In [3]:
from vllm import LLM, SamplingParams

# Initialize the LLM with the specified model
llm = LLM(model=model_config["name"])

# Prepare the sampling parameters
sampling_params = SamplingParams(temperature=0.0, num_samples=100)

# Run inference using vllm
outputs = llm.generate([prompt], sampling_params)

# The outputs variable now contains the inference results

                                                                                                      

In [None]:
from tqdm import tqdm
import json
from typing import List, Dict, Any
import os

dirname = model_config["dirname"]

def split_batch(batch: List[Any], batch_size: int) -> List[List[Any]]:
    """Split a batch into smaller batches of specified size."""
    return [batch[i:i + batch_size] for i in range(0, len(batch), batch_size)]

def process_batch(
    inputs: List[str],
    tracked_indices: List[List[str]],
    llm: Any,
    batch_size: int = 32,  # Adjust this based on your GPU memory
    maxlength: int = 2048
) -> List[Dict[str, Any]]:
    """Process a batch of inputs with dynamic splitting to avoid OOM."""
    # Split inputs into smaller batches
    input_batches = split_batch(inputs, batch_size)
    tracked_batches = split_batch(tracked_indices, batch_size)
    all_outputs = []

    # Process each mini-batch
    for input_batch, tracked_batch in zip(input_batches, tracked_batches):
        # Prepare the sampling parameters
        sampling_params = SamplingParams(max_length=maxlength)
        
        # Run inference using vllm
        outputs = llm.generate(input_batch, sampling_params)
        
        # Collect outputs
        all_outputs.extend(outputs)
    
    return all_outputs

def main():
    # Load your dataset
    with open(os.path.join(dirname, f"test_long.json"), "r") as f:
        test_long = json.load(f)
    
    # Set batch size based on your GPU memory
    BATCH_SIZE = 200  # Adjust this value based on your GPU memory
    
    # Initialize the LLM with the specified model
    llm = LLM(model=model_config["name"])
    
    # Process each main batch
    for k, input_texts in tqdm(test_long.items()):
        # Extract inputs and tracked indices
        inputs = [pair["question"] for pair in input_texts]
        tracked_indices = [pair["query_names"] + pair["non_query_names"] for pair in input_texts]
        # Calculate max length for this mini-batch
        lengths = [len(input) for input in inputs]
        maxlength = max(lengths)
        # Process the batch with dynamic splitting
        outputs = process_batch(
            inputs=inputs,
            tracked_indices=tracked_indices,
            llm=llm,
            batch_size=BATCH_SIZE,
            maxlength=maxlength
        )
        
        # Save outputs
        save_name = os.path.join(dirname, f"outputs_hopk{k}.json")
        with open(save_name, "w") as f:
            json.dump(outputs, f)

main()

In [18]:
from torch.nn import functional as F
import numpy as np
import json
import os
from tqdm import tqdm
import torch

def get_tracked_prob(input_texts, logits, tokenizer):
    """Calculate the tracked probabilities for given input_texts and logits."""
    inputs = [pair['question'] for pair in input_texts]
    tracked_indices = [pair['query_names'] + pair['non_query_names'] for pair in input_texts]
    lengths = [len(tokenizer(input).input_ids) for input in inputs]
    check_indices = torch.LongTensor([[i, l-1, j] for i, l in enumerate(lengths) for j in tracked_indices[i]])
    probs = F.softmax(logits, dim=-1)
    tracked_prob = probs[check_indices[:, 0], check_indices[:, 1], check_indices[:, 2]]
    return tracked_prob

def split_topics(input_texts):
    keywords = ["locate", "grand", "family", "three"]
    display_keywords = {
        "locate": "geography",
        "grand": "relations",
        "family": "biology",
        "three": "arithmetic",
        "other": "other"
    }
    topic_dict = {keyword: [] for keyword in keywords}
    topic_dict["other"] = []
    topic_indices = {keyword: [] for keyword in keywords}
    topic_indices["other"] = []

    for idx, pair in enumerate(input_texts):
        question = pair["question"]
        found = False
        for keyword in keywords:
            if keyword in question:
                topic_dict[keyword].append(pair)
                topic_indices[keyword].append(idx)
                found = True
                break
        if not found:
            topic_dict["locate"].append(pair)
            topic_indices["locate"].append(idx)
    
    for topic, texts in topic_dict.items():
        yield display_keywords[topic], texts, topic_indices[topic]

parent_dir = os.path.dirname(model_config["dirname"])
tracked_prob_all = {}

for alias, config in MODEL_OPTIONS.items():
    dirname = config["dirname"]
    tokenizer = AutoTokenizer.from_pretrained(config["name"])
    test_long_path = os.path.join(dirname, "test_long.json")
    tracked_prob_all[alias] = {}
    
    if not os.path.isfile(test_long_path):
        continue
    
    with open(test_long_path, "r") as f:
        test_long = json.load(f)

    for k, input_texts in tqdm(test_long.items()):
        if int(k) > 2:
            break
        save_name = os.path.join(dirname, f"logits_hopk{k}.pt")
        logits = torch.load(save_name)
        for topic, input_text_group, indices in split_topics(input_texts):
            if topic == 'other':
                continue
            tracked_prob = get_tracked_prob(input_text_group, logits[indices, ...], tokenizer)
            if k not in tracked_prob_all[alias]:
                tracked_prob_all[alias][k] = {}
            tracked_prob_all[alias][k][topic] = tracked_prob.view(-1, 3*int(k)).mean(dim=0)
            print(k, tracked_prob.view(-1, 3*int(k)).mean(dim=0))

torch.save(tracked_prob_all, "./tracked_prob_all.pt")

  logits = torch.load(save_name)


2 tensor([0.0008, 0.0056, 0.2346, 0.0019, 0.0047, 0.3462], dtype=torch.float16)
2 tensor([0.0020, 0.0024, 0.1029, 0.0350, 0.0187, 0.2581], dtype=torch.float16)
2 tensor([0.0010, 0.0090, 0.8213, 0.0162, 0.0117, 0.1375], dtype=torch.float16)


 25%|██▌       | 1/4 [00:16<00:48, 16.24s/it]

2 tensor([0.0077, 0.0143, 0.2321, 0.2971, 0.1127, 0.2820], dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

2 tensor([0.0004, 0.0056, 0.3811, 0.0089, 0.0119, 0.2993], dtype=torch.float16)
2 tensor([0.0006, 0.0060, 0.3562, 0.0094, 0.0179, 0.5210], dtype=torch.float16)
2 tensor([0.0086, 0.0310, 0.6411, 0.0114, 0.0248, 0.2484], dtype=torch.float16)


 25%|██▌       | 1/4 [00:14<00:43, 14.46s/it]

2 tensor([0.0016, 0.0276, 0.1979, 0.3940, 0.1085, 0.2441], dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

2 tensor([7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06])
2 tensor([7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06])
2 tensor([7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06])


 25%|██▌       | 1/4 [00:27<01:22, 27.53s/it]

2 tensor([7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06, 7.7969e-06])



  0%|          | 0/4 [00:00<?, ?it/s]

2 tensor([2.3866e-04, 1.4519e-02, 3.3008e-01, 1.5192e-03, 2.2400e-02, 3.1421e-01],
       dtype=torch.float16)
2 tensor([0.0015, 0.0523, 0.4160, 0.0157, 0.0588, 0.3999], dtype=torch.float16)
2 tensor([0.0132, 0.1033, 0.4712, 0.0240, 0.0405, 0.2637], dtype=torch.float16)


 25%|██▌       | 1/4 [00:07<00:21,  7.12s/it]

2 tensor([0.0109, 0.0980, 0.2064, 0.1420, 0.1320, 0.2454], dtype=torch.float16)



