In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
splits = {'train': 'data/train-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
# df_train = pd.read_parquet("hf://datasets/wangrongsheng/ag_news/" + splits["train"])
df_test = pd.read_parquet("hf://datasets/wangrongsheng/ag_news/" + splits["test"])

In [3]:
df_test.shape

(7600, 2)

In [4]:
# df_train_5000 = df_train.sample(n=5000, random_state=42).reset_index(drop=True)
# df_train_5000['uid'] = df_train_5000.index
df_test_1000 = df_test.sample(n=1000, random_state=42)

In [5]:
label_dict = {0: "World", 1: "Sports", 2: "Business", 3: "Technology"}

In [7]:
print(df_test_1000.iloc[0]['text'])
print(df_test_1000.iloc[0]['label'])

Fan v Fan: Manchester City-Tottenham Hotspur This weekend Manchester City entertain Spurs, and with last seasons seven-goal FA Cup epic between the two teams still fresh in the memory, entertain could be the operative word.
1


In [8]:
from sentence_transformers import SentenceTransformer

def encode_text(text):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model.encode(text)

In [9]:
from sklearn.metrics.pairwise import cosine_similarity

def select_top_k(query_embedding, candidate_embeddings, k):
    similarities = cosine_similarity([query_embedding], candidate_embeddings)[0]
    top_k_indices = np.argsort(similarities)[-k:][::-1]
    return top_k_indices

In [None]:
# from openai import OpenAI
# client = OpenAI(api_key="your_apikey")

def zero_shot_cot_gpt_4_o(input_):
    completion = client.chat.completions.create(
      model="gpt-4o-mini",
      messages=[
        {"role": "user", "content": "What is the topic of the input? World, sports, business or technology?\nThe response should follow the format: Topic:{world, sports, business or technology}\nReason:{reason}" + f"\nInput: {input_}" + "\nLet's think step by step."}
      ] 
    )
    return completion.choices[0].message.content

def few_shot_cot_gpt_4_o(examples, input_):
    completion = client.chat.completions.create(
      model="gpt-4o-mini",
      messages=[
        {"role": "user", "content": "What is the topic of the input? World, sports, business or technology?" + f"\n{examples}" +"\nThe response should follow the format: Topic:{world, sports, business or technology}\nReason:{reason}\nHere is the test data" + f"\nInput: {input_}" + "\nLet's think step by step."}
      ] 
    )
    return completion.choices[0].message.content

In [12]:
ans = zero_shot_cot_gpt_4_o(df_test_1000.iloc[0]['text'])
print(ans)

Topic:sports  
Reason:The input discusses a specific match between two football teams, Manchester City and Tottenham Hotspur, which is clearly related to sports.


In [14]:
# tqdm.pandas()
# df_train_5000['embedding'] = df_train_5000['text'].progress_apply(encode_text)

In [15]:
# df_train_5000.to_pickle("train_5000_with_embeddings.pkl")

In [13]:
df_train_5000_with_embeddings = pd.read_pickle("train_5000_with_embeddings.pkl")

In [16]:
from collections import Counter

def iterative_demonstration_selection(test_sample, train_samples, k=4, q=2):

    label_dict = {0: "World", 1: "Sports", 2: "Business", 3: "Technology"}
    train_embeddings = np.stack(train_samples['embedding'].to_numpy())
    all_answers = []
    reasoning_path = zero_shot_cot_gpt_4_o(test_sample)

    for _ in range(q):
        query_embedding = encode_text(reasoning_path)
        selected_indices = select_top_k(query_embedding, train_embeddings, k)
        demonstrations = [train_samples.iloc[i] for i in selected_indices]

        examples_prompt = "\n".join(
            [f"Input: {row['text']}\nTopic: {label_dict[row['label']]}" for row in demonstrations]
        )

        result = few_shot_cot_gpt_4_o(examples_prompt, test_sample)

        try:
            topic_line = next(line for line in result.split('\n') if line.lower().startswith("topic"))
            reason_line = next(line for line in result.split('\n') if line.lower().startswith("reason"))
        except StopIteration:
            topic_line = "Topic: unknown"
            reason_line = "Reason: unknown"

        reasoning_path = reason_line
        answer = topic_line.split(":", 1)[-1].strip().lower()
        all_answers.append(answer)

    final_answer = Counter(all_answers).most_common(1)[0][0]
    return final_answer

In [17]:
def run_ids_on_test_set(df_test, df_train, k=4, q=3):
    predictions = []
    reasoning_paths = []

    for i, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Running IDS on test set"):
        test_text = row['text']
        true_label = row['label']
        
        # Run IDS
        try:
            final_answer = iterative_demonstration_selection(test_text, df_train, k=k, q=q)
        except Exception as e:
            final_answer = "error"
            print(f"Error on test sample {i}: {e}")

        predictions.append(final_answer)
        reasoning_paths.append(test_text)

    df_test_result = df_test.copy()
    df_test_result['prediction'] = predictions
    df_test_result['input'] = reasoning_paths

    return df_test_result

In [21]:
from sklearn.metrics import accuracy_score

df_test_result = run_ids_on_test_set(df_test_1000, df_train_5000_with_embeddings, k=4, q=3)

In [23]:
label_dict = {0: "World", 1: "Sports", 2: "Business", 3: "Technology"}
df_test_result['label_str'] = df_test_result['label'].map(label_dict).str.lower()
accuracy = accuracy_score(df_test_result['label_str'], df_test_result['prediction'])
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 88.30%


In [25]:
df_test_result[['input','label_str','prediction']].to_csv("agnews_ids_results.csv", index=False)

In [24]:
df_test_result

Unnamed: 0,text,label,prediction,input,label_str
7094,Fan v Fan: Manchester City-Tottenham Hotspur T...,1,sports,Fan v Fan: Manchester City-Tottenham Hotspur T...,sports
1017,Paris Tourists Search for Key to 'Da Vinci Cod...,0,world,Paris Tourists Search for Key to 'Da Vinci Cod...,world
2850,Net firms: Don't tax VoIP The Spanish-American...,3,business,Net firms: Don't tax VoIP The Spanish-American...,technology
1452,Dependent species risk extinction The global e...,3,world,Dependent species risk extinction The global e...,technology
457,EDS Is Charter Member of Siebel BPO Alliance (...,3,business,EDS Is Charter Member of Siebel BPO Alliance (...,technology
...,...,...,...,...,...
4127,Britain Charges Cleric Sought by US for Aiding...,0,world,Britain Charges Cleric Sought by US for Aiding...,world
4801,Hobbit-sized Humans Called Homo floresiensis D...,3,technology,Hobbit-sized Humans Called Homo floresiensis D...,technology
4919,Tar Heels beat Miami For the second time this ...,1,sports,Tar Heels beat Miami For the second time this ...,sports
1721,Microsoft Eyes Video for Business IM Software ...,3,technology,Microsoft Eyes Video for Business IM Software ...,technology
