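# Introduction

This notebook classifies AG News articles into four topics (World, Sports, Business, Sci/Tech) by few-shot prompting a locally served LLM, using Instructor to constrain the model's reply to a single valid category, and reports overall and per-class accuracy on a random subsample of the test split.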

In [1]:
# Import Modules
import instructor
from collections import Counter
from datasets import load_dataset
from openai import OpenAI
from pprint import pprint
from pydantic import BaseModel, Field
from tqdm.notebook import tqdm
from typing import Literal

## Dataset and Labels

In [2]:
# Fetch the AG News corpus
dataset = load_dataset("ag_news")

# Materialise both splits as pandas DataFrames
df_train, df_test = (dataset[split].to_pandas() for split in ("train", "test"))

# Readable topic names, positioned by their integer label id
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']
for frame in (df_train, df_test):
    frame['label_name'] = frame['label'].apply(lambda label_id: class_names[label_id])

# Report how many rows each split holds
print("Train size:", len(df_train))
print("Test size:", len(df_test))

Train size: 120000
Test size: 7600


## Setup Instructor and LLM

In [3]:
# Structured-output schema for the classifier's reply.
# The single `category` field is a Literal, so validation only accepts one of the
# four listed strings; any other value fails validation.
# The Literal values duplicate `class_names` from the dataset cell -- keep the two in sync.
# NOTE(review): no class docstring is used here on purpose; pydantic presumably copies a
# docstring into the generated JSON schema, which would change what the model is shown -- confirm
# before adding one.
class Category(BaseModel):
    category: Literal['World', 'Sports', 'Business', 'Sci/Tech'] = Field(
        ...,  # Ellipsis marks the field as required (no default)
        description = "Choose exactly one category of: 'World', 'Sports', 'Business' or 'Sci/Tech'"
    )

In [4]:
# Wrap an OpenAI-compatible client, pointed at the local endpoint, with Instructor in JSON mode
client = instructor.from_openai(
    OpenAI(
        base_url = "http://localhost:11434/v1",
        api_key = "ollama",
    ),
    mode = instructor.Mode.JSON,
)

## Classify AG News Subset

In [5]:
# Subsample
# The classification loop issues one LLM call per row of `df_test`, so evaluate on a
# fixed-size random subset instead of the full test split.
# A fixed `random_state` keeps the subset -- and therefore the reported accuracies --
# identical across kernel restarts; without it every run scores a different sample.
N_TEST_SAMPLES = 1000
SUBSAMPLE_SEED = 7

# Clamp to the frame length: DataFrame.sample raises ValueError when n exceeds the
# number of rows (and replace=False).
df_test = df_test.sample(n = min(N_TEST_SAMPLES, len(df_test)), random_state = SUBSAMPLE_SEED)

In [6]:
# Few-shot pool: one training article per category, drawn with a fixed seed
SEED = 7
examples = {
    label: (
        df_train.loc[df_train["label_name"] == label, "text"]
        .sample(n=1, random_state=SEED)
        .iloc[0]
    )
    for label in class_names
}

# Render every (category, article) pair and separate the pairs with a blank line
FEWSHOT_TEXT = "\n\n".join(
    f"Category: {label}\nNews Article: {article}"
    for label, article in examples.items()
)

# Show the rendered few-shot section
print(FEWSHOT_TEXT)

Category: World
News Article: Kerry touts job-creation plans (AFP) AFP - Democratic White House hopeful John Kerry promised to help boost US job growth and take a harder line on alleged abuses by China as he touted his economic plan.

Category: Sports
News Article: Smit lauds Boks resolve South Africa skipper John Smit paid tribute to his team #39;s resilience after the 23-19 victory over Australia, which won them the 2004 Tri-Nations.

Category: Business
News Article: Dollar Rises; Traders Drop Bets Currency to Reach One-Month Low Aug. 20 (Bloomberg) -- The dollar climbed against the euro after some traders abandoned bets that a slowdown in growth reflected in economic reports this week would push the US currency to a one-month low. 

Category: Sci/Tech
News Article: Separate Genes Responsible for Drinking, Alcoholism By Randy Dotinga, HealthDay Reporter    HealthDayNews -- Some people can drink a lot of alcohol without becoming addicted, and specific genes may help explain why, resea

In [7]:
# Build the chat messages for one article
def create_message(text):
    """Assemble the system and user messages for classifying a single article.

    Parameters
    ----------
    text
        The news article to classify; interpolated verbatim at the end of the
        user prompt.

    Returns
    -------
    list
        Two dicts with ``role`` and ``content`` keys: the system instruction
        first, then the user prompt, which embeds the module-level
        ``FEWSHOT_TEXT`` examples followed by ``text``.
    """
    system_turn = {
        "role": "system",
        "content": "You are a precise news-topic classifier. Classify each article into exactly one of: World, Sports, Business, Sci/Tech. ",
    }

    # The prompt body sits at column 0 so no indentation leaks into the text sent to the model
    user_prompt = f"""
You are a precise news-topic classifier. Choose exactly one category from:
World, Sports, Business, Sci/Tech.

Here are four labeled examples (one per category):
{FEWSHOT_TEXT}

Now classify the following news article.

News Article:
{text}
"""
    user_turn = {"role": "user", "content": user_prompt}

    return [system_turn, user_turn]

In [8]:
# Preview the messages built for the first training article
pprint(create_message(df_train['text'].iloc[0]))

[{'content': 'You are a precise news-topic classifier. Classify each article '
             'into exactly one of: World, Sports, Business, Sci/Tech. ',
  'role': 'system'},
 {'content': '\n'
             'You are a precise news-topic classifier. Choose exactly one '
             'category from:\n'
             'World, Sports, Business, Sci/Tech.\n'
             '\n'
             'Here are four labeled examples (one per category):\n'
             'Category: World\n'
             'News Article: Kerry touts job-creation plans (AFP) AFP - '
             'Democratic White House hopeful John Kerry promised to help boost '
             'US job growth and take a harder line on alleged abuses by China '
             'as he touted his economic plan.\n'
             '\n'
             'Category: Sports\n'
             'News Article: Smit lauds Boks resolve South Africa skipper John '
             'Smit paid tribute to his team #39;s resilience after the 23-19 '
             'victory over Australia

In [9]:
# Ground-truth and predicted labels, kept aligned by position
y_true, y_pred = [], []

# Walk the sampled test set one (article, label) pair at a time
labelled_articles = zip(df_test['text'], df_test['label_name'])
for text, gt_label in tqdm(labelled_articles, total = len(df_test)):
    # Ask the model for a reply shaped like `Category`
    response = client.chat.completions.create(
        model = "gemma3:12b",
        messages = create_message(text),
        response_model = Category,
        temperature = 0.01,
        top_p = 0.98,
        strict = True,
        max_retries = 15,
    )

    # Debug aid -- uncomment to inspect each reply
    #print(response.model_dump_json(indent=2))
    #print(f"GT Label: {gt_label}     Model Label: {response.category}")

    # Record the pair for scoring
    y_true.append(gt_label)
    y_pred.append(response.category)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [10]:
# Average and per Class Accuracy
# `total` counts ground-truth occurrences per class; `correct` counts only the
# positions where the prediction matches the ground truth.
total = Counter(y_true)
correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)

# Guard the denominator: with no predictions there is no accuracy to report,
# so fall back to NaN instead of raising ZeroDivisionError.
overall_acc = sum(correct.values()) / len(y_true) if y_true else float("nan")
print("\nOverall accuracy: {:.2%}".format(overall_acc))

print("Per-class accuracy:")
for c in class_names:
    # A class can be absent from the evaluated subsample; Counter returns 0 for it,
    # so report it explicitly rather than dividing by zero.
    if total[c] == 0:
        print(f"  {c:8}: n/a (n=0)")
        continue
    acc = correct[c] / total[c]
    print(f"  {c:8}: {acc:.2%} (n={total[c]})")


Overall accuracy: 83.60%
Per-class accuracy:
  World   : 90.33% (n=269)
  Sports  : 97.07% (n=205)
  Business: 87.65% (n=243)
  Sci/Tech: 63.96% (n=283)


## Conclusion

