In [63]:
import numpy as np
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, accuracy_score
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import random

class TopK_CoNE:
    def __init__(self, embeddings, raw_texts, k=5, model_name='gpt2', device=None):
        self.embeddings = embeddings
        self.raw_texts = raw_texts
        self.k = k
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def get_topk(self, query_embedding):
        similarities = cosine_similarity(query_embedding.reshape(1, -1), self.embeddings)[0]
        topk_indices = np.argsort(similarities)[0:30][::-1]  #get top 30
        return topk_indices

    def compute_cross_entropy(self, text):
        encodings = self.tokenizer(text, return_tensors='pt').to(self.device)
        input_ids = encodings['input_ids']
        with torch.no_grad():
            outputs = self.model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
        return loss * input_ids.size(1)

    def apply_cone(self, query_text, topk_indices):
        candidate_scores = []
        for idx in topk_indices:
            demo = self.raw_texts[idx]
            prompt_with_query = demo + "\n" + query_text
            H_xc = self.compute_cross_entropy(prompt_with_query)
            H_c = self.compute_cross_entropy(demo)
            H_cond = H_xc - H_c
            candidate_scores.append((idx, H_cond))
        sorted_indices = [idx for idx, _ in sorted(candidate_scores, key=lambda x: x[1])]
        return sorted_indices[:self.k]

    def select_demonstrations(self, query_embedding, query_text):
        topk_indices = self.get_topk(query_embedding)
        refined_indices = self.apply_cone(query_text, topk_indices)
        return refined_indices

# ----------------------
# SST-5 Dataset loading and embedding
# ----------------------
def load_sst5(split="train", num_samples=1000):
    dataset = load_dataset("SetFit/sst5", "default", split=split)
    dataset = dataset.filter(lambda x: x["label_text"] is not None)
    dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))

    texts = []
    labels = []
    for sample in dataset:
        text = sample["text"].strip().replace("\n", " ")
        label_text = sample["label_text"]
        full_text = f"Text: {text}\nLabel: {label_text}"
        texts.append(full_text)
        labels.append(label_text)
    return texts, labels

def embed_texts(texts, model_name='all-MiniLM-L6-v2', batch_size=64):
    model = SentenceTransformer(model_name)
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)
    return embeddings

def generate_label_from_prompt(prompt, model, tokenizer, device):
    enc = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.generate(
            **enc,
            max_new_tokens=3,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    #print("\n Prompt:", prompt)
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    #print("\n Output:", decoded)
    generated_label = decoded.split("Label:")[-1].strip().split("\n")[0].strip()
    print("\n Generated label:", generated_label)
    return generated_label

# ----------------------
# Evaluate TopK + CoNE in batches of 3
# ----------------------
def evaluate_cone(texts, labels, embeddings, k=5, batch_size=3):
    print("\n[1] Starting batch evaluation (3 random queries)...")
    indices = random.sample(range(len(texts)), batch_size)
    accs, f1s = [], []

    for idx in indices:
        query_input = texts[idx].rsplit("\nLabel:", 1)[0] + "\nLabel:"
        query_label = labels[idx]
        query_embedding = embeddings[idx]

        candidate_embeddings = np.delete(embeddings, idx, axis=0)
        candidate_texts = [t for i, t in enumerate(texts) if i != idx]
        candidate_labels = [l for i, l in enumerate(labels) if i != idx]

        selector = TopK_CoNE(candidate_embeddings, candidate_texts, k=k)
        demo_indices = selector.select_demonstrations(query_embedding, query_input)

        # Add labels to the demonstration text
        demos = []
        for i in demo_indices:
            labeled_demo = candidate_texts[i] + f"\nLabel: {candidate_labels[i]}"
            demos.append(labeled_demo)

        prompt = "\n\n".join(demos) + "\n\n" + query_input
        pred_label = generate_label_from_prompt(prompt, selector.model, selector.tokenizer, selector.device)
        
        #print("pred label:", pred_label)
        print("\n Actual label:", query_label)
        
        accs.append(int(pred_label.lower() == query_label.lower()))
        f1s.append(f1_score([query_label], [pred_label], average='macro'))

    avg_acc = np.mean(accs)
    avg_f1 = np.mean(f1s)
    print(f"\n[2] Evaluation over {batch_size} samples")
    print(f"Average Accuracy: {avg_acc:.4f}")
    print(f"Average F1 Score: {avg_f1:.4f}")
    return avg_acc, avg_f1

# ----------------------
# Main Driver
# ----------------------
def run_demo_selection():
    print("\n[1] Loading and embedding SST-5 dataset...")
    texts, labels = load_sst5(num_samples=100)
    embeddings = embed_texts(texts)
    evaluate_cone(texts, labels, embeddings, k=4, batch_size=10)

if __name__ == "__main__":
    run_demo_selection()



[1] Loading and embedding SST-5 dataset...


Repo card metadata block was not found. Setting CardData to empty.


Batches:   0%|          | 0/2 [00:00<?, ?it/s]


[1] Starting batch evaluation (3 random queries)...

 Generated label: neutral

 Actual label: neutral

 Generated label: neutral

 Actual label: negative

 Generated label: very positive

 Actual label: positive

 Generated label: neutral

 Actual label: very negative

 Generated label: positive

 Actual label: negative

 Generated label: positive

 Actual label: negative

 Generated label: very negative

 Actual label: positive

 Generated label: neutral

 Actual label: neutral

 Generated label: neutral

 Actual label: negative

 Generated label: very negative

 Actual label: very positive

[2] Evaluation over 10 samples
Average Accuracy: 0.2000
Average F1 Score: 0.2000


In [7]:
import numpy as np
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, accuracy_score
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import random

class TopK_CoNE:
    def __init__(self, embeddings, raw_texts, k=5, model_name='gpt2', device=None):
        self.embeddings = embeddings
        self.raw_texts = raw_texts
        self.k = k
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def get_topk(self, query_embedding):
        similarities = cosine_similarity(query_embedding.reshape(1, -1), self.embeddings)[0]
        topk_indices = np.argsort(similarities)[-self.k:][::-1]
        return topk_indices

    def compute_cross_entropy(self, text):
        encodings = self.tokenizer(text, return_tensors='pt').to(self.device)
        input_ids = encodings['input_ids']
        with torch.no_grad():
            outputs = self.model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
        return loss * input_ids.size(1)

    def apply_cone(self, query_text, topk_indices):
        candidate_scores = []
        for idx in topk_indices:
            demo = self.raw_texts[idx]
            prompt_with_query = demo + "\n" + query_text
            H_xc = self.compute_cross_entropy(prompt_with_query)
            H_c = self.compute_cross_entropy(demo)
            H_cond = H_xc - H_c
            candidate_scores.append((idx, H_cond))
        sorted_indices = [idx for idx, _ in sorted(candidate_scores, key=lambda x: x[1])]
        return sorted_indices[:self.k]

    def select_demonstrations(self, query_embedding, query_text):
        topk_indices = self.get_topk(query_embedding)
        refined_indices = self.apply_cone(query_text, topk_indices)
        return refined_indices

# ----------------------
# CommonsenseQA Dataset loading and embedding
# ----------------------
def load_commonsenseqa(split="train", num_samples=1000):
    dataset = load_dataset("commonsense_qa", split=split)
    dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))
    
    texts = []
    labels = []
    for example in dataset:
        question = example['question']
        choices = example['choices']['text']
        label_index = example['answerKey']
        answer = example['choices']['label'].index(label_index)
        
        # Format choices and include them in the prompt
        choices_text = "\n".join([f"{label}. {text}" for label, text in zip(example['choices']['label'], choices)])
        prompt = f"Question: {question}\n{choices_text}"
        label_text = example['choices']['text'][answer]
        
        full_prompt = f"{prompt}\nAnswer: {label_text}"
        texts.append(full_prompt)
        labels.append(label_text)

    return texts, labels

def embed_texts(texts, model_name='all-MiniLM-L6-v2', batch_size=64):
    model = SentenceTransformer(model_name)
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)
    return embeddings

def generate_label_from_prompt(prompt, model, tokenizer, device):
    enc = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.generate(
            **enc,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    generated_label = decoded.split("Answer:")[-1].strip().split("\n")[0].strip()
    return generated_label

# ----------------------
# Evaluate TopK + CoNE in batches of 3
# ----------------------
def evaluate_cone(texts, labels, embeddings, k=5, batch_size=3):
    print("\n[1] Starting batch evaluation (3 random queries)...")
    indices = random.sample(range(len(texts)), batch_size)
    accs, f1s = [], []

    for idx in indices:
        query_input = texts[idx].rsplit("\nAnswer:", 1)[0] + "\nAnswer:"
        query_label = labels[idx]
        query_embedding = embeddings[idx]

        candidate_embeddings = np.delete(embeddings, idx, axis=0)
        candidate_texts = [t for i, t in enumerate(texts) if i != idx]
        candidate_labels = [l for i, l in enumerate(labels) if i != idx]

        selector = TopK_CoNE(candidate_embeddings, candidate_texts, k=k)
        demo_indices = selector.select_demonstrations(query_embedding, query_input)

        # Add labeled demos
        demos = []
        for i in demo_indices:
            labeled_demo = candidate_texts[i] + f"\nAnswer: {candidate_labels[i]}"
            demos.append(labeled_demo)

        prompt = "\n\n".join(demos) + "\n\n" + query_input
        pred_label = generate_label_from_prompt(prompt, selector.model, selector.tokenizer, selector.device)
        print("pred label:", pred_label)
        print("\n Actual label:", query_label)
                
        accs.append(int(pred_label.strip().lower() == query_label.strip().lower()))
        f1s.append(f1_score([query_label], [pred_label], average='macro'))

    avg_acc = np.mean(accs)
    avg_f1 = np.mean(f1s)
    print(f"\n[2] Evaluation over {batch_size} samples")
    print(f"Average Accuracy: {avg_acc:.4f}")
    print(f"Average F1 Score: {avg_f1:.4f}")
    return avg_acc, avg_f1

# ----------------------
# Main Driver
# ----------------------
def run_demo_selection():
    print("\n[1] Loading and embedding CommonsenseQA dataset...")
    texts, labels = load_commonsenseqa(num_samples=100)
    embeddings = embed_texts(texts)
    evaluate_cone(texts, labels, embeddings, k=5, batch_size=3)

if __name__ == "__main__":
    run_demo_selection()



[1] Loading and embedding CommonsenseQA dataset...


Batches:   0%|          | 0/2 [00:00<?, ?it/s]


[1] Starting batch evaluation (3 random queries)...
pred label: all buildings

 Actual label: restaurant
pred label: composted

 Actual label: composted
pred label: satellite

 Actual label: beast

[2] Evaluation over 3 samples
Average Accuracy: 0.3333
Average F1 Score: 0.3333


In [None]:
#setup 
!huggingface-cli login

Usage:
First, pass a HF token to use gated models.

At the bottom of the next block, put in a model name and a HF dataset of your choice.

    Supported datasets:
    - Commonsenseqa
    - AGNews
    - SST-5

In [None]:
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, accuracy_score
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import random

class TopK_CoNE:
    def __init__(self, embeddings, raw_texts, k=5):
        self.embeddings = embeddings
        self.raw_texts = raw_texts
        self.k = k
        self.model = None
        self.tokenizer = None
        self.device = None

    def get_topk(self, query_embedding):
        similarities = cosine_similarity(query_embedding.reshape(1, -1), self.embeddings)[0]
        topk_indices = np.argsort(similarities)[-self.k:][::-1]
        return topk_indices

    def compute_cross_entropy(self, text):
        encodings = self.tokenizer(text, return_tensors='pt').to(self.device)
        input_ids = encodings['input_ids']
        with torch.no_grad():
            outputs = self.model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
        return loss * input_ids.size(1)

    def apply_cone(self, query_text, topk_indices):
        candidate_scores = []
        for idx in topk_indices:
            demo = self.raw_texts[idx]
            prompt_with_query = demo + "\n" + query_text
            H_xc = self.compute_cross_entropy(prompt_with_query)
            H_c = self.compute_cross_entropy(demo)
            H_cond = H_xc - H_c
            candidate_scores.append((idx, H_cond))
        sorted_indices = [idx for idx, _ in sorted(candidate_scores, key=lambda x: x[1])]
        return sorted_indices[:self.k]

    def select_demonstrations(self, query_embedding, query_text):
        topk_indices = self.get_topk(query_embedding)
        refined_indices = self.apply_cone(query_text, topk_indices)
        return refined_indices

# ----------------------
# Dataset Loading Functions
# ----------------------
def load_commonsenseqa(split="train", num_samples=1000):
    dataset = load_dataset("commonsense_qa", split=split)
    dataset = dataset.shuffle().select(range(min(num_samples, len(dataset))))
    texts = []
    labels = []
    for example in dataset:
        question = example['question']
        choices = example['choices']['text']
        label_index = example['answerKey']
        answer = example['choices']['label'].index(label_index)
        choices_text = "\n".join([f"{label}. {text}" for label, text in zip(example['choices']['label'], choices)])
        prompt = f"Question: {question}\n{choices_text}"
        label_text = example['choices']['text'][answer]
        full_prompt = f"{prompt}\nAnswer: {label_text}"
        texts.append(full_prompt)
        labels.append(label_text)
    return texts, labels

def load_ag_news(split="train", num_samples=1000):
    label_map = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}
    dataset = load_dataset("sh0416/ag_news", split=split)
    dataset = dataset.shuffle().select(range(num_samples))
    texts = []
    labels = []
    for sample in dataset:
        label = sample["label"]
        title = sample["title"]
        description = sample["description"]
        full_text = f"Title: {title}\nDescription: {description}\nLabel: {label}"
        texts.append(full_text)
        labels.append(str(label))
    return texts, labels

def load_sst5(split="train", num_samples=1000):
    dataset = load_dataset("SetFit/sst5", "default", split=split)  
    dataset = dataset.filter(lambda x: x["label_text"] is not None)
    dataset = dataset.shuffle().select(range(min(num_samples, len(dataset))))
    texts = []
    labels = []
    for sample in dataset:
        text = sample["text"].strip().replace("\n", " ")
        label_text = sample["label_text"]
        full_text = f"Text: {text}\nLabel: {label_text}"
        texts.append(full_text)
        labels.append(label_text)
    return texts, labels

def load_dataset_by_name(name, num_samples=100):
    if name == "commonsenseqa":
        return load_commonsenseqa(num_samples=num_samples)
    elif name == "ag_news":
        return load_ag_news(num_samples=num_samples)
    elif name == "sst5":
        return load_sst5(num_samples=num_samples)
    else:
        raise ValueError(f"Unsupported dataset: {name}")

# ----------------------
# Embedding and Generation
# ----------------------
def embed_texts(texts, model_name='all-MiniLM-L6-v2', batch_size=64):
    model = SentenceTransformer(model_name)
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)
    return embeddings

def generate_label_from_prompt(prompt, model, tokenizer, device):
    enc = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.generate(
            **enc,
            max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    generated_label = decoded.split("Answer:" if "Answer:" in prompt else "Label:")[-1].strip().split("\n")[0].strip()
    return generated_label

# ----------------------
# Evaluation
# ----------------------
def evaluate_cone(texts, labels, embeddings, model_name, k=5, batch_size=3):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()
    print("\n[1] Starting batch evaluation (3 random queries)...")
    indices = random.sample(range(len(texts)), batch_size)
    accs, f1s = [], []

    for idx in indices:
        separator = "Answer:" if "Answer:" in texts[idx] else "Label:"
        query_input = texts[idx].rsplit(f"\n{separator}", 1)[0] + f"\n{separator}"
        query_label = labels[idx]
        query_embedding = embeddings[idx]

        candidate_embeddings = np.delete(embeddings, idx, axis=0)
        candidate_texts = [t for i, t in enumerate(texts) if i != idx]
        candidate_labels = [l for i, l in enumerate(labels) if i != idx]

        selector = TopK_CoNE(candidate_embeddings, candidate_texts, k=k)
        selector.model = model
        selector.tokenizer = tokenizer
        selector.device = device
        demo_indices = selector.select_demonstrations(query_embedding, query_input)

        demos = [candidate_texts[i] + f"\n{separator} {candidate_labels[i]}" for i in demo_indices]
        #print("len demos:", len(demos))
        prompt = "\n\n".join(demos) + f"\n\n{query_input}"
        pred_label = generate_label_from_prompt(prompt, model, tokenizer, device)

        accs.append(int(str(pred_label).strip().lower() == str(query_label).strip().lower()))
        f1s.append(f1_score([str(query_label)], [str(pred_label)], average='macro'))

        #print("pred label:", pred_label)
        #print("gold:", query_label)

    avg_acc = np.mean(accs)
    avg_f1 = np.mean(f1s)
    print(f"\n[2] Evaluation over {batch_size} samples")
    print(f"Average Accuracy: {avg_acc:.4f}")
    print(f"Average F1 Score: {avg_f1:.4f}")
    return avg_acc, avg_f1

# ----------------------
# Main Driver
# ----------------------
def run_demo_selection(model_name='openai-community/gpt2', dataset_name='ag_news'):
    print(f"\n[1] Loading and embedding {dataset_name} dataset...")
    texts, labels = load_dataset_by_name(dataset_name, num_samples=100)
    embeddings = embed_texts(texts)
    evaluate_cone(texts, labels, embeddings, k=8, batch_size=100, model_name=model_name)

if __name__ == "__main__":
    run_demo_selection()
