In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
from collections import defaultdict

# Hugging Face checkpoint identifier shared by the tokenizer and the model.
model_id = "google/gemma-2-2b-it"

# Both objects are module-level globals read by get_label_probs().
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" delegates device placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Few-shot demonstration template for CommonsenseQA: the question, five
# lettered options, the answer letter, and a blank line separating it from
# whatever follows in the prompt.
CSQA_TEMPLATE = (
    "Question: {question}\n"
    "A. {choice0}\n"
    "B. {choice1}\n"
    "C. {choice2}\n"
    "D. {choice3}\n"
    "E. {choice4}\n"
    "Answer: {answer}\n\n"
)

def format_example(example):
    """Render one record as a solved demonstration using CSQA_TEMPLATE.

    Args:
        example: Mapping with "question", "answerKey" and
            example["choices"]["text"], the latter holding at least five
            option strings (only the first five are used).

    Returns:
        The filled-in template string, ending with a blank line.
    """
    options = example["choices"]["text"]
    question_text = example["question"]
    # Template slots are named choice0..choice4, one per answer letter A..E.
    option_slots = {f"choice{position}": options[position] for position in range(5)}
    answer_letter = example["answerKey"]
    return CSQA_TEMPLATE.format(
        question=question_text,
        answer=answer_letter,
        **option_slots,
    )

def get_label_probs(prompt, question, choices):
    """Score the five answer letters for one question given a few-shot prompt.

    The query is appended to ``prompt`` in the same layout as the
    demonstrations but stops after "Answer:", so the model's next-token
    distribution at the final position is read as its answer.

    Args:
        prompt: Text placed before the query (may be empty).
        question: Question text of the query.
        choices: Sequence of five option strings, in A..E order.

    Returns:
        NumPy array of five probabilities (A..E order) obtained by a softmax
        restricted to the five label-token logits; they sum to 1.
    """
    full_input = prompt + f"""Question: {question}
A. {choices[0]}
B. {choices[1]}
C. {choices[2]}
D. {choices[3]}
E. {choices[4]}
Answer:"""

    max_length = 2048
    encoded = tokenizer(full_input, return_tensors="pt")
    input_ids = encoded["input_ids"]
    if input_ids.shape[1] > max_length:
        # The query sits at the END of the text, so an over-long input must be
        # cut from the left: dropping the tail would remove the "Answer:" cue
        # and the last-position logits would no longer be the answer slot.
        # A leading BOS token, if the tokenizer added one, is kept.
        bos_id = tokenizer.bos_token_id
        has_bos = bos_id is not None and input_ids[0, 0].item() == bos_id
        head = 1 if has_bos else 0
        tail = max_length - head
        encoded = {
            name: torch.cat([tensor[:, :head], tensor[:, -tail:]], dim=1)
            for name, tensor in encoded.items()
        }
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Each label is encoded with a leading space; the last id of the encoding
    # is used, which skips any prefix tokens the tokenizer prepends.
    label_tokens = [tokenizer(f" {c}")["input_ids"][-1] for c in ["A", "B", "C", "D", "E"]]
    last_token_logits = logits[0, -1, label_tokens]
    # Upcast first: NumPy cannot represent bfloat16, so .numpy() would fail if
    # the model were loaded in reduced precision. No-op for float32 logits.
    return torch.softmax(last_token_logits.float(), dim=-1).cpu().numpy()

def evaluate_subset(Si, val_set):
    """Return the accuracy on ``val_set`` when ``Si`` is used as the prompt.

    Args:
        Si: Iterable of training records; each is rendered with
            format_example() and the results are concatenated in order.
        val_set: Sized iterable of records with "question",
            "choices" -> "text" and "answerKey".

    Returns:
        Fraction of ``val_set`` records whose highest-probability letter
        equals their "answerKey".
    """
    letters = ["A", "B", "C", "D", "E"]
    demo_prompt = "".join(format_example(demo) for demo in Si)

    hits = 0
    for record in val_set:
        distribution = get_label_probs(
            demo_prompt, record["question"], record["choices"]["text"]
        )
        predicted = letters[np.argmax(distribution)]
        hits += 1 if predicted == record["answerKey"] else 0
    return hits / len(val_set)

def _sample_balanced_subsets(labels, k, M):
    """Draw M label-balanced subsets of up to k indices into ``labels``."""
    indices_by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_by_label[label].append(idx)

    # Ceiling division: the per-label quota must make the pooled sample at
    # least k long, otherwise drawing k items without replacement raises
    # (e.g. k=7 with 5 labels under floor division).
    k_per_class = max(1, -(-k // len(indices_by_label)))

    subsets = []
    for _ in range(M):
        pool = []
        for candidates in indices_by_label.values():
            # Rare labels contribute everything they have instead of failing.
            quota = min(k_per_class, len(candidates))
            pool.extend(np.random.choice(candidates, quota, replace=False))
        chosen = np.random.choice(pool, min(k, len(pool)), replace=False)
        subsets.append([int(i) for i in chosen])
    return subsets


def compute_influences(train_set, val_set, k=5, M=100):
    """Estimate each training example's influence as an in-context demo.

    M random, label-balanced subsets of ``train_set`` are each scored with
    evaluate_subset() on ``val_set``. An example's influence is the mean
    accuracy of the subsets that contain it minus the mean accuracy of the
    subsets that do not.

    Args:
        train_set: Iterable of records with an "answerKey" field (plus
            whatever evaluate_subset() needs).
        val_set: Validation records passed through to evaluate_subset().
        k: Target number of demonstrations per subset.
        M: Number of subsets to sample and evaluate.

    Returns:
        Dict mapping every training index to its influence score. Examples
        that appear in none (or all) of the subsets get 0. Empty for an empty
        ``train_set``.

    Note:
        Membership is tracked by training index, so two records with
        identical content are scored independently.
    """
    # One pass over the dataset; everything below works on this list.
    examples = list(train_set)
    if not examples:
        return {}

    subsets = _sample_balanced_subsets(
        [ex["answerKey"] for ex in examples], k, M
    )

    # Sequential evaluation: one accuracy per subset, in sampling order.
    D = [
        evaluate_subset([examples[i] for i in members], val_set)
        for members in tqdm(subsets, desc="Evaluating subsets")
    ]

    member_sets = [set(members) for members in subsets]
    sampled = set().union(*member_sets)

    influence_scores = {}
    for idx in range(len(examples)):
        if idx not in sampled:
            # Never drawn: no "included" mean exists, score is defined as 0.
            influence_scores[idx] = 0
            continue

        included = [acc for members, acc in zip(member_sets, D) if idx in members]
        excluded = [acc for members, acc in zip(member_sets, D) if idx not in members]
        if excluded:
            influence = sum(included) / len(included) - sum(excluded) / len(excluded)
        else:
            # Present in every subset: no "excluded" mean to compare against.
            influence = 0
        influence_scores[idx] = influence

    return influence_scores

def load_csqa(split="train", num_samples=1000):
    """Load a shuffled (seed 42) sample of CommonsenseQA.

    Args:
        split: "train" returns a single training dataset. Any other value
            returns a validation/test pair, both carved out of the dataset's
            "validation" split.
            NOTE(review): presumably because the official test split has no
            usable labels - confirm.
        num_samples: For "train", the maximum number of examples returned.
            Otherwise, the combined size of the pair (each part gets half);
            the default of 1000 yields 500 + 500.

    Returns:
        A dataset when ``split == "train"``; otherwise a
        ``(val_set, test_set)`` tuple of disjoint datasets.
    """
    if split == "train":
        dataset = load_dataset("commonsense_qa", split="train").shuffle(seed=42)
        return dataset.select(range(min(num_samples, len(dataset))))

    # Only the validation split is needed here, so it is loaded once and no
    # other split is fetched or shuffled.
    full_val = load_dataset("commonsense_qa", split="validation").shuffle(seed=42)
    half = min(num_samples, len(full_val)) // 2
    val_set = full_val.select(range(half))
    test_set = full_val.select(range(half, 2 * half))
    return val_set, test_set

def run_experiment():
    """Select the most influential demonstrations and report test accuracy.

    Loads the data, scores training examples with compute_influences(),
    builds a prompt from the highest-scoring ones and evaluates it on the
    held-out test part.

    Returns:
        The final test accuracy as a float in [0, 1] (also printed).
    """
    # A single source of truth for the prompt size: the same value drives
    # subset sampling and the final top-N selection.
    num_shots = 5
    num_subsets = 100

    # Load CommonsenseQA
    train_set = load_csqa("train", 5000)
    val_set, test_set = load_csqa("validation", 1000)

    # Compute influences
    influence_scores = compute_influences(
        train_set, val_set, k=num_shots, M=num_subsets
    )

    # Select top examples (ties keep ascending index order; sorted is stable).
    sorted_indices = sorted(
        influence_scores, key=influence_scores.get, reverse=True
    )[:num_shots]
    top_examples = [train_set[i] for i in sorted_indices]

    # Final evaluation reuses the same scoring routine as the search. The
    # tqdm wrapper only adds a progress bar; it still reports the length of
    # the wrapped dataset, which evaluate_subset() needs.
    accuracy = evaluate_subset(top_examples, tqdm(test_set, desc="Testing"))

    print(f"Final Test Accuracy: {accuracy:.2%}")
    return accuracy

# Entry point: run the experiment only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_experiment()


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.39k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/160k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/151k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9741 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1221 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1140 [00:00<?, ? examples/s]

Evaluating subsets: 100%|██████████| 100/100 [1:53:57<00:00, 68.37s/it]
Testing: 100%|██████████| 500/500 [01:09<00:00,  7.17it/s]


Final Test Accuracy: 71.20%
