In [None]:
!pip install openai datasets

In [None]:
import json
import os
from getpass import getpass
from typing import Literal

from datasets import load_dataset
from openai import OpenAI
from pydantic import BaseModel

In [None]:
# Model name; `classify` reads it back with os.getenv("MODEL").
os.environ["MODEL"] = "gpt-4o"

# Never hardcode the API key in the notebook. Assigning any literal here (even "")
# would overwrite a key already exported in the shell and make every request fail
# authentication. Keep an existing key; otherwise prompt for it without echoing it.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

## Classify a research paper title with an OpenAI GPT model

In [None]:
def classify(sample, client, temperature=0.6, top_p=0.0, model=None):
    """
    Classifies the title of a research paper into one or more predefined subjects.

    Parameters
    ==========
    sample (dict): A dictionary containing a key 'text' with the title of the research paper.
    client (OpenAI): An instance of the OpenAI client to generate predictions.
    temperature (float): Sampling temperature for the model.
    top_p (float): Top-p sampling value for nucleus sampling. With the default of 0.0 only
        the most likely token(s) are considered, so `temperature` has little practical effect.
    model (str, optional): Model name. Defaults to the MODEL environment variable, or
        "gpt-4o" if that is unset.

    Returns
    =======
    dict: The input `sample` dictionary (modified in place) augmented with a 'Predictions'
        key containing the list of predicted subjects.

    Raises
    ======
    ValueError: If the response contains no parsed structured output (e.g. the model refused).
    """
    # The only labels the classifier may emit.
    subjects = (
        "Computer Science",
        "Physics",
        "Mathematics",
        "Statistics",
        "Quantitative Biology",
        "Quantitative Finance",
    )

    # Schema for the structured response: a list of subjects restricted to `subjects`.
    # The schema enforces this, so the model cannot return a label outside the allowed set.
    class classify_research_paper(BaseModel):
        Subjects: list[Literal[subjects]]

    # Instruction prompt. The subject list is rendered from `subjects` so that the
    # prompt and the schema always agree.
    system_msg = (
        "Given the title of a research paper, classify it into one or more of the "
        f"following subjects based on the content: {list(subjects)}.\n\n"
        "Return only the list with only the most appropriate subjects (1-3) from the list.\n"
        "Do not include subjects outside the provided list.\n"
        "Avoid selecting all subjects. Select subjects most relevant to the content.\n"
    )

    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": sample["text"]},
    ]

    # Call the OpenAI API; `parse` validates the reply against the schema class.
    completion = client.beta.chat.completions.parse(
        model=model or os.getenv("MODEL", "gpt-4o"),
        messages=messages,
        response_format=classify_research_paper,
        temperature=temperature,
        top_p=top_p,
    )

    # `.parsed` is already an instance of the schema class, not a JSON string, so it is
    # read directly. It is None when the model refuses (then `.refusal` carries the reason).
    message = completion.choices[0].message
    parsed = message.parsed
    if parsed is None:
        refusal = getattr(message, "refusal", None)
        raise ValueError(
            f"No structured classification for title {sample['text']!r}"
            + (f": model refused ({refusal})" if refusal else "")
        )

    # Add the predicted subjects to the sample.
    sample["Predictions"] = list(parsed.Subjects)

    return sample


## Inference

In [None]:
# Constructed without arguments, the client takes its credentials from the
# environment (presumably the OPENAI_API_KEY configured earlier in this notebook).
client = OpenAI()

# Load the "test" split of the research-paper-title dataset.
dataset = load_dataset(
    "bhujith10/multi_class_classification_dataset",
    split="test",
)


def _predict_subjects(example):
    """Classify one dataset row with the shared `client` and return the augmented row."""
    return classify(example, client)


# One API request per row; `map` gathers the returned rows (now carrying a
# 'Predictions' field) into a new dataset.
dataset_with_predictions = dataset.map(_predict_subjects)