# Text Classification using OpenAI and Pydantic

This notebook demonstrates how to implement text classification, specifically single-label and multi-label classification, using the OpenAI API, the `instructor` library, Python's `typing.Literal`, and Pydantic models.  
This recipe has been updated to use `create_iterable` from instructor 1.0, which fixes previous issues in multi-label classification.

## Motivation

Text classification is a common problem in many NLP applications, such as spam detection or support ticket categorization. The goal is to provide a systematic way to handle these cases using OpenAI's GPT models in combination with Python data structures.

In [27]:
import instructor
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel
from typing import Literal

In [19]:
load_dotenv(dotenv_path='../api_keys.env')

True
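The `True` above is the return value of `load_dotenv`. The `OpenAI()` client created later reads its key from the `OPENAI_API_KEY` environment variable, so a small check right after loading can make a missing or misnamed key in the `.env` file obvious early. This is a minimal sketch: `require_env` is a hypothetical helper (not part of python-dotenv or openai), and it assumes `../api_keys.env` defines `OPENAI_API_KEY`.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, raising a clear error if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"{name} is not set; check the .env file passed to load_dotenv")
    return value
```

Calling `require_env("OPENAI_API_KEY")` after `load_dotenv(...)` either returns the key or stops with a message pointing at the `.env` file.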

## Single-Label Classification

### Defining the Structures

For single-label classification, we define a Pydantic model whose `label` field is restricted with `Literal` to the allowed values, `"spam"` or `"not_spam"`.

In [57]:
class SpamPrediction(BaseModel):
    label: Literal[
        "spam",
        "not_spam",
    ]
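Because `label` is a `Literal`, Pydantic rejects any value outside the allowed set, so an unexpected label from the model produces a validation error instead of passing through silently. A quick offline check, assuming Pydantic v2 (the model is redefined so the snippet runs on its own):

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class SpamPrediction(BaseModel):
    label: Literal["spam", "not_spam"]


# A valid label parses normally.
ok = SpamPrediction(label="spam")

# An unknown label raises ValidationError instead of being accepted.
try:
    SpamPrediction(label="ham")
    rejected = False
except ValidationError:
    rejected = True

print(ok.label, rejected)
```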

### Classifying Text

The function `classify` performs single-label classification by passing `SpamPrediction` as the `response_model`.

In [58]:
# Wrap the OpenAI client with instructor;
# this enables the response_model keyword
client = instructor.from_openai(OpenAI())

def classify(data: str) -> SpamPrediction:
    """Perform single-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o",
        response_model=SpamPrediction,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: {data}",
            },
        ],
    )  # type: ignore
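instructor derives the function/tool definition it sends to the API from the `response_model`'s JSON schema, so inspecting that schema shows roughly what constraint the model receives. A sketch assuming Pydantic v2's `model_json_schema` (the exact payload instructor builds may differ in details):

```python
import json
from typing import Literal

from pydantic import BaseModel


class SpamPrediction(BaseModel):
    label: Literal["spam", "not_spam"]


# The Literal becomes an "enum" constraint on the "label" property.
schema = SpamPrediction.model_json_schema()
print(json.dumps(schema, indent=2))
```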

### Testing and Evaluation

Let's run an example to see if it correctly identifies a spam message.

In [60]:
# Test single-label classification
prediction = classify("Hello there I'm a Nigerian prince and I want to give you money")
assert prediction.label == "spam"
print(f"Prediction: {prediction.label}")

Prediction: spam
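A single assertion checks only one example. A more systematic check runs the classifier over a small labeled set and computes accuracy. In this sketch, `evaluate_accuracy`, the sample texts, and the keyword-based `keyword_classify` are all hypothetical; the stand-in classifier exists only so the example runs without an API key.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Tuple


@dataclass
class Prediction:
    label: str


def evaluate_accuracy(
    classify_fn: Callable[[str], Prediction],
    examples: Iterable[Tuple[str, str]],
) -> float:
    """Fraction of (text, expected_label) pairs that classify_fn labels correctly."""
    examples = list(examples)
    if not examples:
        return 0.0
    correct = sum(classify_fn(text).label == expected for text, expected in examples)
    return correct / len(examples)


def keyword_classify(text: str) -> Prediction:
    """Offline stand-in for `classify`: flags a few hard-coded spammy phrases."""
    spammy = ("prince", "lottery", "wire transfer")
    return Prediction("spam" if any(k in text.lower() for k in spammy) else "not_spam")


examples = [
    ("Hello there I'm a Nigerian prince and I want to give you money", "spam"),
    ("Can we move our meeting to 3pm?", "not_spam"),
    ("Please wire transfer the fee to claim your prize", "spam"),
    ("Limited offer: click here for free money", "spam"),  # the keyword rule misses this one
]
print(evaluate_accuracy(keyword_classify, examples))
```

To score the real model, pass `classify` instead of `keyword_classify`; each example then costs one API call.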


## Multi-Label Classification

### Defining the Structures

For multi-label classification, we define a model that holds a single support-ticket label; the model can then return several instances of it, one per applicable label.

In [63]:
class CustomerSupportType(BaseModel):
    label: Literal[
        "tech_issue",
        "billing",
        "general_query",
    ]
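An alternative design, not the one this notebook uses, is a single model with a list field that carries all labels for a ticket. A sketch assuming Pydantic v2; `TicketLabels` and `SupportLabel` are hypothetical names:

```python
from typing import List, Literal

from pydantic import BaseModel, ValidationError

SupportLabel = Literal["tech_issue", "billing", "general_query"]


class TicketLabels(BaseModel):
    """Hypothetical alternative: all labels for one ticket in a single object."""

    labels: List[SupportLabel]


parsed = TicketLabels.model_validate({"labels": ["tech_issue", "billing"]})
print(parsed.labels)

# Each list element is still checked against the allowed values.
try:
    TicketLabels.model_validate({"labels": ["refund"]})
    rejected = False
except ValidationError:
    rejected = True
print(rejected)
```

The notebook instead keeps one label per object and lets `create_iterable` return as many objects as apply.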

### Classifying Text

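The function `multi_classify` performs multi-label classification. It calls `create_iterable`, which returns an iterable of `CustomerSupportType` objects rather than a single one.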

In [64]:
from typing import Iterable

def multi_classify(data: str) -> Iterable[CustomerSupportType]:
    """Perform multi-label classification on the input text."""
    return client.chat.completions.create_iterable(
        model="gpt-4o",
        response_model=CustomerSupportType,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: {data}",
            },
        ],
    )  # type: ignore
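Since the result is an iterable of objects, the caller decides how to collect them. Nothing in `CustomerSupportType` prevents the model from emitting the same label twice, so an order-preserving deduplication step can help. In this sketch `collect_labels` is a hypothetical helper, and the hard-coded generator stands in for a real `multi_classify(ticket)` response:

```python
from typing import Iterable, List, Literal

from pydantic import BaseModel


class CustomerSupportType(BaseModel):
    label: Literal["tech_issue", "billing", "general_query"]


def collect_labels(items: Iterable[CustomerSupportType]) -> List[str]:
    """Return labels in first-seen order, dropping duplicates."""
    seen = set()
    labels = []
    for item in items:
        if item.label not in seen:
            seen.add(item.label)
            labels.append(item.label)
    return labels


# Stand-in for multi_classify(ticket): a generator that repeats one label.
fake_response = (
    CustomerSupportType(label=name) for name in ["tech_issue", "billing", "tech_issue"]
)
labels = collect_labels(fake_response)
print(labels)
```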

### Testing and Evaluation

Finally, we test the multi-label classification function using a sample support ticket.

In [82]:
# Test multi-label classification
ticket = "My account is locked and I can't access my billing info."
predictions = [x.label for x in multi_classify(ticket)]
assert "tech_issue" in predictions
assert "billing" in predictions
print(f"Predictions: {predictions}")

Predictions: ['tech_issue', 'billing']
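For multi-label output, exact-match accuracy is strict. A common alternative is precision and recall over label sets, micro-averaged across tickets by summing true positives, false positives, and false negatives. A minimal sketch; `precision_recall` is a hypothetical helper and the second (predicted, gold) pair is invented for illustration, not output from this notebook:

```python
from typing import Iterable, List, Set, Tuple


def precision_recall(pairs: Iterable[Tuple[Set[str], Set[str]]]) -> Tuple[float, float]:
    """Micro-averaged precision and recall over (predicted, gold) label sets."""
    tp = fp = fn = 0
    for predicted, gold in pairs:
        tp += len(predicted & gold)  # labels predicted and correct
        fp += len(predicted - gold)  # labels predicted but wrong
        fn += len(gold - predicted)  # correct labels that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


pairs: List[Tuple[Set[str], Set[str]]] = [
    ({"tech_issue", "billing"}, {"tech_issue", "billing"}),
    ({"general_query"}, {"general_query", "billing"}),
]
print(precision_recall(pairs))
```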
