In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
from collections import defaultdict


# Hugging Face Hub identifier; the same id is used to load both the tokenizer and the model.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# No dtype is passed here, so the library default is used.
# NOTE(review): device_map="auto" leaves device placement to the library — confirm
# the resulting placement fits the available hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto"
)

# SST-5 template from paper
# One in-context demonstration: the text line, the label line, then a blank line
# separating it from the next demonstration. get_label_probs builds its query with
# the same "Text: ...\nSentiment:" prefix, so the two must stay in sync.
SST5_TEMPLATE = "Text: {text}\nSentiment: {label}\n\n"

def format_example(example):
    """Render one labelled example as an in-context demonstration string.

    Args:
        example: Mapping with ``"text"`` and ``"label"`` entries.

    Returns:
        ``SST5_TEMPLATE`` with its ``{text}`` and ``{label}`` slots filled in.
    """
    fields = {"text": example["text"], "label": example["label"]}
    return SST5_TEMPLATE.format(**fields)

def get_label_probs(prompt, text):
    """Score the five candidate labels for ``text`` given a demonstration prompt.

    The query ``"Text: {text}\\nSentiment:"`` is appended to ``prompt``, the model is
    run once, and the logits at the final position are restricted to the five
    label tokens and normalised with a softmax.

    Args:
        prompt: Concatenated demonstrations (may be an empty string).
        text: The sentence to classify.

    Returns:
        NumPy array of length 5; entry ``i`` is the probability of label ``i``
        relative to the other four label tokens only (not the full vocabulary).
    """
    query = f"Text: {text}\nSentiment:"
    # NOTE(review): truncation=True with max_length=2048 — confirm which side the
    # tokenizer truncates; cutting the end would drop the query itself.
    encoded = tokenizer(
        prompt + query,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits

    # Each label is represented by the last token id of " 0" ... " 4".
    # NOTE(review): whether that token is what the model would emit directly
    # after "Sentiment:" depends on the tokenizer — confirm per model.
    label_token_ids = []
    for label in range(5):
        ids = tokenizer(f" {label}")["input_ids"]
        label_token_ids.append(ids[-1])

    final_step = logits[0, -1, label_token_ids]
    distribution = torch.softmax(final_step, dim=-1)
    return distribution.cpu().numpy()

def evaluate_subset(Si, val_set):
    """Measure validation accuracy when ``Si`` is used as the in-context prompt.

    Args:
        Si: Iterable of demonstration examples (mappings with ``"text"`` and
            ``"label"``); they are formatted and concatenated in order.
        val_set: Sized iterable of examples to classify; each needs ``"text"``
            and an integer ``"label"`` comparable to an arg-max index.

    Returns:
        Fraction of ``val_set`` examples whose arg-max label probability matches
        ``ex["label"]``. An empty ``val_set`` yields ``0.0`` rather than raising
        ``ZeroDivisionError``.
    """
    n_val = len(val_set)
    if n_val == 0:
        # Accuracy over nothing is undefined; report 0.0 instead of crashing.
        return 0.0

    prompt = "".join(format_example(ex) for ex in Si)
    correct = sum(
        1
        for ex in val_set
        if np.argmax(get_label_probs(prompt, ex["text"])) == ex["label"]
    )
    return correct / n_val

def compute_influences(train_set, val_set, k=5, M=100, seed=None):
    """Estimate each training example's influence on in-context accuracy.

    ``M`` random, label-balanced subsets of up to ``k`` training examples are
    drawn, each is scored with ``evaluate_subset`` on ``val_set``, and the
    influence of example ``j`` is::

        mean(accuracy of subsets containing j) - mean(accuracy of subsets without j)

    Args:
        train_set: Iterable of examples (mappings with a ``"label"`` entry).
        val_set: Validation examples passed through to ``evaluate_subset``.
        k: Number of demonstrations per subset.
        M: Number of subsets to sample and evaluate.
        seed: Optional seed for reproducible sampling. When ``None`` the global
            ``np.random`` state is used, as before.

    Returns:
        Dict mapping every training index to its influence score. Examples that
        appear in no subset (or in every subset) get ``0.0``. An empty
        ``train_set`` returns ``{}``.

    Notes:
        Subset membership is tracked by training index, so two rows with
        identical content are treated as distinct examples.
    """
    # Single pass over the training data: keep the rows and bucket their
    # positions by label so sampling never has to rescan the dataset.
    examples = []
    indices_by_label = defaultdict(list)
    for idx, ex in enumerate(train_set):
        examples.append(ex)
        indices_by_label[ex["label"]].append(idx)

    if not examples:
        return {}

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Ceil division so the balanced pool always holds at least k candidates
    # (floor division made k=7 with 5 classes build a pool of only 5).
    n_classes = len(indices_by_label)
    k_per_class = max(1, -(-k // n_classes))

    subsets = []
    for _ in range(M):
        pool = []
        for label_indices in indices_by_label.values():
            take = min(k_per_class, len(label_indices))
            pool.extend(rng.choice(label_indices, take, replace=False).tolist())
        # Trim (and shuffle) the balanced pool down to the requested size.
        subsets.append(rng.choice(pool, min(k, len(pool)), replace=False).tolist())

    # Sequential evaluation: one validation pass per subset.
    accuracies = []
    for subset in tqdm(subsets, desc="Evaluating subsets"):
        accuracies.append(evaluate_subset([examples[i] for i in subset], val_set))

    # Accumulate per-example totals for the subsets that include each index;
    # the "excluded" total is then the grand total minus the included total.
    total_accuracy = sum(accuracies)
    included_sum = defaultdict(float)
    included_count = defaultdict(int)
    for subset, acc in zip(subsets, accuracies):
        for i in subset:
            included_sum[i] += acc
            included_count[i] += 1

    n_subsets = len(accuracies)
    influence_scores = {}
    for idx in range(len(examples)):
        n_in = included_count.get(idx, 0)
        n_out = n_subsets - n_in
        if n_in > 0 and n_out > 0:
            in_total = included_sum[idx]
            influence_scores[idx] = in_total / n_in - (total_accuracy - in_total) / n_out
        else:
            # No contrast available between "with" and "without" this example.
            influence_scores[idx] = 0.0

    return influence_scores

def load_sst5(split="train", num_samples=1000):
    """Load a shuffled sample of one SST-5 split.

    Args:
        split: Split name passed to ``load_dataset`` for ``"SetFit/sst5"``.
        num_samples: Maximum number of rows to keep; the whole split is
            returned when it has fewer rows.

    Returns:
        The split shuffled with a fixed seed of 42 and cut to the first
        ``min(num_samples, len(split))`` rows.
    """
    full_split = load_dataset("SetFit/sst5", "default", split=split)
    shuffled = full_split.shuffle(seed=42)
    keep = min(num_samples, len(full_split))
    return shuffled.select(range(keep))

def run_experiment(n_train=5000, n_val=1000, n_test=1000, k=5, M=100):
    """Run the influence-based demonstration selection experiment end to end.

    Loads SST-5, scores training examples with ``compute_influences``, builds a
    prompt from the ``k`` highest-scoring examples and reports test accuracy.

    Args:
        n_train: Maximum number of training examples to consider.
        n_val: Maximum number of validation examples used for influence scoring.
        n_test: Maximum number of test examples used for the final evaluation.
        k: Demonstrations per sampled subset, and the number of top-influence
            examples placed in the final prompt.
        M: Number of random subsets evaluated when estimating influences.

    Returns:
        Test accuracy as a float in [0, 1] (also printed). ``0.0`` if the test
        set is empty.
    """
    # Load SST-5
    train_set = load_sst5("train", n_train)
    val_set = load_sst5("validation", n_val)
    test_set = load_sst5("test", n_test)

    # Compute influences
    influence_scores = compute_influences(train_set, val_set, k=k, M=M)

    # Select the k highest-influence training examples
    sorted_indices = sorted(influence_scores, key=influence_scores.get, reverse=True)[:k]
    top_examples = [train_set[i] for i in sorted_indices]

    # Final evaluation
    prompt = "".join(format_example(ex) for ex in top_examples)
    correct = 0
    for ex in tqdm(test_set, desc="Testing"):
        probs = get_label_probs(prompt, ex["text"])
        if np.argmax(probs) == ex["label"]:
            correct += 1

    n_test_examples = len(test_set)
    accuracy = correct / n_test_examples if n_test_examples else 0.0
    print(f"Final Test Accuracy: {accuracy:.2%}")
    return accuracy

# Kick off the full experiment when this code is executed as the main module.
if __name__ == "__main__":
    run_experiment()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Evaluating subsets: 100%|██████████| 100/100 [2:27:40<00:00, 88.61s/it] 
Testing: 100%|██████████| 1000/1000 [01:25<00:00, 11.64it/s]

Final Test Accuracy: 34.80%





In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
from collections import defaultdict


# Hugging Face Hub identifier; the same id is used to load both the tokenizer and the model.
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# No dtype is passed here, so the library default is used.
# NOTE(review): device_map="auto" leaves device placement to the library — confirm
# the resulting placement fits the available hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto"
)

# SST-5 template from paper
# One in-context demonstration: the text line, the label line, then a blank line
# separating it from the next demonstration. get_label_probs builds its query with
# the same "Text: ...\nSentiment:" prefix, so the two must stay in sync.
SST5_TEMPLATE = "Text: {text}\nSentiment: {label}\n\n"

def format_example(example):
    """Render one labelled example as an in-context demonstration string.

    Args:
        example: Mapping with ``"text"`` and ``"label"`` entries.

    Returns:
        ``SST5_TEMPLATE`` with its ``{text}`` and ``{label}`` slots filled in.
    """
    fields = {"text": example["text"], "label": example["label"]}
    return SST5_TEMPLATE.format(**fields)

def get_label_probs(prompt, text):
    """Score the five candidate labels for ``text`` given a demonstration prompt.

    The query ``"Text: {text}\\nSentiment:"`` is appended to ``prompt``, the model is
    run once, and the logits at the final position are restricted to the five
    label tokens and normalised with a softmax.

    Args:
        prompt: Concatenated demonstrations (may be an empty string).
        text: The sentence to classify.

    Returns:
        NumPy array of length 5; entry ``i`` is the probability of label ``i``
        relative to the other four label tokens only (not the full vocabulary).
    """
    query = f"Text: {text}\nSentiment:"
    # NOTE(review): truncation=True with max_length=2048 — confirm which side the
    # tokenizer truncates; cutting the end would drop the query itself.
    encoded = tokenizer(
        prompt + query,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits

    # Each label is represented by the last token id of " 0" ... " 4".
    # NOTE(review): whether that token is what the model would emit directly
    # after "Sentiment:" depends on the tokenizer — confirm per model.
    label_token_ids = []
    for label in range(5):
        ids = tokenizer(f" {label}")["input_ids"]
        label_token_ids.append(ids[-1])

    final_step = logits[0, -1, label_token_ids]
    distribution = torch.softmax(final_step, dim=-1)
    return distribution.cpu().numpy()

def evaluate_subset(Si, val_set):
    """Measure validation accuracy when ``Si`` is used as the in-context prompt.

    Args:
        Si: Iterable of demonstration examples (mappings with ``"text"`` and
            ``"label"``); they are formatted and concatenated in order.
        val_set: Sized iterable of examples to classify; each needs ``"text"``
            and an integer ``"label"`` comparable to an arg-max index.

    Returns:
        Fraction of ``val_set`` examples whose arg-max label probability matches
        ``ex["label"]``. An empty ``val_set`` yields ``0.0`` rather than raising
        ``ZeroDivisionError``.
    """
    n_val = len(val_set)
    if n_val == 0:
        # Accuracy over nothing is undefined; report 0.0 instead of crashing.
        return 0.0

    prompt = "".join(format_example(ex) for ex in Si)
    correct = sum(
        1
        for ex in val_set
        if np.argmax(get_label_probs(prompt, ex["text"])) == ex["label"]
    )
    return correct / n_val

def compute_influences(train_set, val_set, k=5, M=100, seed=None):
    """Estimate each training example's influence on in-context accuracy.

    ``M`` random, label-balanced subsets of up to ``k`` training examples are
    drawn, each is scored with ``evaluate_subset`` on ``val_set``, and the
    influence of example ``j`` is::

        mean(accuracy of subsets containing j) - mean(accuracy of subsets without j)

    Args:
        train_set: Iterable of examples (mappings with a ``"label"`` entry).
        val_set: Validation examples passed through to ``evaluate_subset``.
        k: Number of demonstrations per subset.
        M: Number of subsets to sample and evaluate.
        seed: Optional seed for reproducible sampling. When ``None`` the global
            ``np.random`` state is used, as before.

    Returns:
        Dict mapping every training index to its influence score. Examples that
        appear in no subset (or in every subset) get ``0.0``. An empty
        ``train_set`` returns ``{}``.

    Notes:
        Subset membership is tracked by training index, so two rows with
        identical content are treated as distinct examples.
    """
    # Single pass over the training data: keep the rows and bucket their
    # positions by label so sampling never has to rescan the dataset.
    examples = []
    indices_by_label = defaultdict(list)
    for idx, ex in enumerate(train_set):
        examples.append(ex)
        indices_by_label[ex["label"]].append(idx)

    if not examples:
        return {}

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Ceil division so the balanced pool always holds at least k candidates
    # (floor division made k=7 with 5 classes build a pool of only 5).
    n_classes = len(indices_by_label)
    k_per_class = max(1, -(-k // n_classes))

    subsets = []
    for _ in range(M):
        pool = []
        for label_indices in indices_by_label.values():
            take = min(k_per_class, len(label_indices))
            pool.extend(rng.choice(label_indices, take, replace=False).tolist())
        # Trim (and shuffle) the balanced pool down to the requested size.
        subsets.append(rng.choice(pool, min(k, len(pool)), replace=False).tolist())

    # Sequential evaluation: one validation pass per subset.
    accuracies = []
    for subset in tqdm(subsets, desc="Evaluating subsets"):
        accuracies.append(evaluate_subset([examples[i] for i in subset], val_set))

    # Accumulate per-example totals for the subsets that include each index;
    # the "excluded" total is then the grand total minus the included total.
    total_accuracy = sum(accuracies)
    included_sum = defaultdict(float)
    included_count = defaultdict(int)
    for subset, acc in zip(subsets, accuracies):
        for i in subset:
            included_sum[i] += acc
            included_count[i] += 1

    n_subsets = len(accuracies)
    influence_scores = {}
    for idx in range(len(examples)):
        n_in = included_count.get(idx, 0)
        n_out = n_subsets - n_in
        if n_in > 0 and n_out > 0:
            in_total = included_sum[idx]
            influence_scores[idx] = in_total / n_in - (total_accuracy - in_total) / n_out
        else:
            # No contrast available between "with" and "without" this example.
            influence_scores[idx] = 0.0

    return influence_scores

def load_sst5(split="train", num_samples=1000):
    """Load a shuffled sample of one SST-5 split.

    Args:
        split: Split name passed to ``load_dataset`` for ``"SetFit/sst5"``.
        num_samples: Maximum number of rows to keep; the whole split is
            returned when it has fewer rows.

    Returns:
        The split shuffled with a fixed seed of 42 and cut to the first
        ``min(num_samples, len(split))`` rows.
    """
    full_split = load_dataset("SetFit/sst5", "default", split=split)
    shuffled = full_split.shuffle(seed=42)
    keep = min(num_samples, len(full_split))
    return shuffled.select(range(keep))

def run_experiment(n_train=5000, n_val=1000, n_test=1000, k=5, M=100):
    """Run the influence-based demonstration selection experiment end to end.

    Loads SST-5, scores training examples with ``compute_influences``, builds a
    prompt from the ``k`` highest-scoring examples and reports test accuracy.

    Args:
        n_train: Maximum number of training examples to consider.
        n_val: Maximum number of validation examples used for influence scoring.
        n_test: Maximum number of test examples used for the final evaluation.
        k: Demonstrations per sampled subset, and the number of top-influence
            examples placed in the final prompt.
        M: Number of random subsets evaluated when estimating influences.

    Returns:
        Test accuracy as a float in [0, 1] (also printed). ``0.0`` if the test
        set is empty.
    """
    # Load SST-5
    train_set = load_sst5("train", n_train)
    val_set = load_sst5("validation", n_val)
    test_set = load_sst5("test", n_test)

    # Compute influences
    influence_scores = compute_influences(train_set, val_set, k=k, M=M)

    # Select the k highest-influence training examples
    sorted_indices = sorted(influence_scores, key=influence_scores.get, reverse=True)[:k]
    top_examples = [train_set[i] for i in sorted_indices]

    # Final evaluation
    prompt = "".join(format_example(ex) for ex in top_examples)
    correct = 0
    for ex in tqdm(test_set, desc="Testing"):
        probs = get_label_probs(prompt, ex["text"])
        if np.argmax(probs) == ex["label"]:
            correct += 1

    n_test_examples = len(test_set)
    accuracy = correct / n_test_examples if n_test_examples else 0.0
    print(f"Final Test Accuracy: {accuracy:.2%}")
    return accuracy

# Kick off the full experiment when this code is executed as the main module.
if __name__ == "__main__":
    run_experiment()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Evaluating subsets: 100%|██████████| 100/100 [2:42:43<00:00, 97.63s/it]  
Testing: 100%|██████████| 1000/1000 [01:40<00:00,  9.93it/s]


Final Test Accuracy: 51.10%
