In [1]:
import os
import time
import re
import warnings
import datetime
import json

from copy import copy
from dataclasses import dataclass, field
import traceback

import numpy as np
import pandas as pd
from itertools import chain

In [2]:
import openai
import tiktoken

In [3]:
from tqdm.notebook import tqdm
tqdm.pandas()

In [4]:
from tools import embedding_pipeline

In [5]:
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from langchain.vectorstores import Chroma
from langchain.embeddings import AzureOpenAIEmbeddings

In [6]:
# Fix path
os.chdir("/mnt/c/Users/leonl/OneDrive/College/Senior/CSE 481DS/Analysis/classify/gpt")

In [7]:
TRAINING_DATA_PATH   = './trainset.csv'
TO_LABEL_PATH        = '../../data/headlines.csv'
OUT_PATH             = '../../data/gpt_classified.csv'
NUM_EXAMPLES         = 6

In [8]:
openai.api_version = '2024-06-01'
openai.api_base = "https://llm.leibmann.org/v1"
openai.api_key = "keyzoned"

Some helpful vLLM commands:
- Start server: `python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --dtype=half --download-dir /projects/bdata/llm_models/`
- Base url: `http://localhost:8000/v1`
- List models: `curl http://localhost:8000/v1/models`

In [9]:
MAX_RETRIES          = 3
MAX_TOKENS           = 300 # only include response; smaller helps with ratelimiting
TEMPERATURE          = 0 # using 0 based on xinyi, down from 1.4 earlier
RETRY_SECS           = 5
PAUSE_SECS           = 1
TIMEOUT_SECS         = 5
LLM_DEPLOYMENT       = 'mistral-7b-instruct-v0.2' # 'gpt-4' or 'GPT-4o'. For vLLM use full path

if 'localhost' in openai.api_base or 'leibmann.org' in openai.api_base:
    PROMPT_COST = 0 # dollars per 1,000 tokens
    OUTPUT_COST = 0 # dollars per 1,000 tokens
    
# input from here https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
elif LLM_DEPLOYMENT == 'GPT-4o':
    PROMPT_COST = 0.005 # dollars per 1,000 tokens
    OUTPUT_COST = 0.015 # dollars per 1,000 tokens
elif LLM_DEPLOYMENT == 'gpt-4':
    PROMPT_COST = 0.03 # dollars per 1,000 tokens
    OUTPUT_COST = 0.06 # dollars per 1,000 tokens
else: raise ValueError(f"Unknown LLM_DEPLOYMENT: {LLM_DEPLOYMENT}")

In [10]:
# Accounting
NUM_TOKENS_PROMPTED  = 0
NUM_TOKENS_GENERATED = 0
enc = tiktoken.encoding_for_model('gpt-4')
def get_num_tokens(s):
    return len( enc.encode(s) )

In [11]:
# Define Helpers for output parsing
@dataclass
class LLMCol:
    """Represents a single key-value pair in the dataframe. Used for easy parsing"""
    key: str
    allowed_values: set = field(default_factory=set)
    
    def __post_init__(self):
        self.allowed_values = set(map(LLMCol.normalize_str, self.allowed_values))

    def to_str(self, value:str):
        value = LLMCol.normalize_str(value)
        if len(self.allowed_values) > 0 and value not in self.allowed_values:
            raise ValueError(f"Malformed input: Value {value} not in allowed values {self.allowed_values}")
        return f"{self.key}: {value}"
    
    def parse_line(self, line:str) -> dict:
        """Parse a line of text into a key-value pair, performing the inverse of to_str"""
        try:
            k,v = re.split(r'\W*:\W*', line, maxsplit=1)
            if len(self.allowed_values) > 0 and LLMCol.normalize_str(v) not in self.allowed_values:
                raise ValueError(f"Malformed output: value {v} not in allowed values {self.allowed_values}")
            if k.strip().casefold() != self.key.casefold():
                raise ValueError(f"Malformed output: key {k} does not match expected key {self.key}")
            
            return {self.key: v}
        except ValueError:
            raise ValueError(f'Malformed output for line "{line}"')
        
    def parse_row(self, row:pd.Series|dict) -> str:
        """Parse a row of a dataframe into a string, performing the inverse of to_str"""
        if isinstance(row, pd.Series):
            if self.key not in row.index:
                raise ValueError(f"Key {self.key} not found in row")
            return self.to_str(row[self.key])
        elif isinstance(row, dict):
            if self.key not in row:
                raise ValueError(f"Key {self.key} not found in row")
            return self.to_str(row[self.key])
        else:
            raise ValueError("Expected a Series or dict")
        
    @staticmethod
    def normalize_str(s:str) -> str:
        return str(s).strip().casefold()
    
@dataclass
class LLMSchema:
    """Represents a set of LLMRows, used for easy parsing"""
    cols: list[LLMCol]
    
    def __post_init__(self):
        self.cols = tuple(self.cols)
        
        # Ensure no duplicate keys
        keys_found = set()
        for col in self.cols:
            if col.key in keys_found:
                raise ValueError(f"Duplicate key: {col.key}")
            keys_found.add(col.key)
    
    @property
    def keys(self):
        return {c.key for c in self.cols}
    
    def parse_row(self, row:pd.Series|dict) -> str:
        values = list()
        for c in self.cols:
            try:
                values.append(c.parse_row(row))
            except ValueError:
                continue
        return '\n'.join(values)
    
    def parse_text(self, text:str) -> list[dict]:
        """Parse a text block into a dictionary of key-value pairs"""
        output = list()
        for block in text.split('\n\n'):
            lines = re.split(r'(?<=\n)(?=[^\n]*:)', block)
            pairs = dict()
            for line in lines:
                line = line.strip()
                for col in self.cols:
                    try:
                        pairs.update(col.parse_line(line))
                    except ValueError:
                        continue
            output.append(pairs)
        return output

# Load Training Data

In [12]:
# Define the schema for the LLM output
schema = LLMSchema([
    LLMCol('BreachMentioned', {'true', 'false'}),
    LLMCol('CompanyMentioned')
])

In [13]:
# Load training set
train_set = pd.read_csv(TRAINING_DATA_PATH)
train_set = train_set.loc[:, ~train_set.columns.str.contains('^Unnamed')] # Drop all Unnamed: columns
train_set.drop_duplicates(subset=['Headline'], inplace=True)
train_set.drop_duplicates(subset=['URL'], inplace=True)

train_set['Headline'] = train_set['Headline'].apply(str.strip)
train_set.set_index(['Date', 'Publication', 'Headline', 'URL'], inplace=True)
for col in schema.cols:
    if col.key in train_set.index.names:
        continue
    
    if col.key not in train_set.columns:
        print(f"WARNING: Column {col.key} not found in training set but specified in schema")
        continue
    
    #train_set[col.key] = train_set[col.key].apply(lambda x: len(str(x)) > 0 and str(x).lower() not in ['nan','false','0'])

train_set.fillna(False, inplace=True)
train_set.head(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,BreachMentioned,CompanyMentioned
Date,Publication,Headline,URL,Unnamed: 4_level_1,Unnamed: 5_level_1
20160329,Washington Post,You can soon get unlimited data on AT&T U-verse — but it comes with a big catch,https://www.washingtonpost.com/news/the-switch/wp/2016/03/29/you-can-soon-get-unlimited-data-on-att-u-verse-but-it-comes-with-a-big-catch/,False,AT&T
20180212,Washington Post,"Lending by big banks to small businesses hits a record high, study finds",https://www.washingtonpost.com/news/on-small-business/wp/2018/02/12/lending-by-big-banks-to-small-businesses-hits-a-record-high-study-finds/,False,False
20220614,The Guardian,Yellowstone National park closed after record rain and major flooding,https://www.theguardian.com/us-news/2022/jun/14/yellowstone-national-park-flooding-rain,False,Government


In [14]:
# Load to_label EXACTLY as above
to_label = pd.read_csv(TO_LABEL_PATH)

to_label.dropna(subset=['Headline'], inplace=True)
to_label['Headline'] = to_label['Headline'].apply(str.strip)
to_label.set_index(train_set.index.names, inplace=True)

# Create columns to label
for col in schema.cols:
    if col.key in train_set.index.names:
        continue
    
    if col.key not in to_label.columns:
        # If we are trying to label a column that doesn't exist, create it
        to_label[col.key] = False

# Assert that the schema of to_label and train_set match
if not set(chain(train_set.index.names, train_set.columns)) <= set(chain(to_label.index.names, to_label.columns)):
    raise ValueError(f'To Label schema does not match train set schema! \n Missing Columns: {set(chain(train_set.index.names, train_set.columns)) - set(chain(to_label.index.names, to_label.columns))}')

print(f'Loaded {len(to_label)} rules to classify.')

Loaded 4361388 rules to classify.


# Construct Prompt

In [15]:
# Construct prompt
example_prompt = PromptTemplate.from_template(
    '\n'.join(f'{c} : {{{c}}}' for c in chain(train_set.index.names, train_set.columns)),
)

In [16]:
examples = train_set.reset_index().astype(str).to_dict(orient='records')
example_prompt.format(**examples[5])

'Date : 20120330\nPublication : CNBC\nHeadline : MasterCard, Visa Warn of Possible Security Breach\nURL : http://www.cnbc.com/id/46904168\nBreachMentioned : Yes\nCompanyMentioned : MasterCard, Visa'

In [17]:
start_time = time.monotonic()

print('Computing embeddings for examples...', end='')
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,       # This is the list[dict] of examples available to select from.
    embedding_pipeline(use_gpu=False),
    # AzureOpenAIEmbeddings(
    #     deployment = EMBEDDING_DEPLOYMENT
    # ),              # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    Chroma,         # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    k=NUM_EXAMPLES  # This is the number of examples to produce.
)
print(f'done in {time.monotonic()-start_time:.1f} seconds.')

Computing embeddings for examples...

<All keys matched successfully>


done in 21.0 seconds.


In [18]:
with open('./prompt.txt') as f:
    prompt_prefix = f.read()
prompt_prefix += '\n\nYour answer should follow the format given in the examples:\n'
    
prompt = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt   = example_prompt,
    input_variables  = list(train_set.index.names),
    prefix           = prompt_prefix,
    suffix           = '\n'.join(chain((f'{c} : {{{c}}}' for c in train_set.index.names))),
)

In [19]:
print('Prompt formatting configured.\nExample prompt:')

print(
    prompt.format(
        **examples[5]
    )
)

Prompt formatting configured.
Example prompt:
Given a news article's headline, determine the following information about it:

If the article is about a data breach, make sure you respond with "BreachMentioned: true".
Otherwise, respond with "BreachMentioned: false".

If the headline of the article implies the main subject of the article will be a single distinct company, make sure you respond with "CompanyMentioned: <company name>", filling in the specific company name.
If the article implies that the government will be the main subject of the article, respond with "CompanyMentioned: government"
If the article mentions two or more companies, or the main subject is not clear, respond with "CompanyMentioned: false"

Below are some examples of labelled headlines. Follow the format given exactly.

Your answer should follow the format given in the examples:


Date : 20120330
Publication : CNBC
Headline : MasterCard, Visa Warn of Possible Security Breach
URL : http://www.cnbc.com/id/46904168

In [33]:
openai.ChatCompletion.create(
    model         = LLM_DEPLOYMENT,
    #deployment_id = LLM_DEPLOYMENT,
    messages      = [{'role': 'user', 'content': prompt.format(**examples[5])}],
    max_tokens    = MAX_TOKENS,
    temperature   = TEMPERATURE,
)

<OpenAIObject chat.completion id=chatcmpl-19e7e21f-fe12-46cb-8d2e-489d173ed503 at 0x7f6c251cc950> JSON: {
  "id": "chatcmpl-19e7e21f-fe12-46cb-8d2e-489d173ed503",
  "object": "chat.completion",
  "created": 1728781755,
  "model": "mistral-7b-instruct-v0.2",
  "choices": [
    {
      "index": 0,
      "message": {
        "content": " Date : 20170915\nPublication : Reuters\nHeadline : Hackers steal data from credit reporting agency Equifax, potentially impacting 143 million people\nURL : http://www.reuters.com/article/us-equifax-cybersecurity-idUSKCN1BZ256\nBreachMentioned : Yes\nCompanyMentioned : Equifax",
        "role": "assistant"
      },
      "logprobs": null,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 806,
    "completion_tokens": 96,
    "total_tokens": 902
  }
}

In [30]:
def query(prompt):
    current_tries = 1
    
    while current_tries <= MAX_RETRIES:
        try:
            response = openai.ChatCompletion.create(
                model         = LLM_DEPLOYMENT,
                #deployment_id = LLM_DEPLOYMENT,
                messages      = [{'role': 'user', 'content': prompt}],
                max_tokens    = MAX_TOKENS,
                temperature   = TEMPERATURE,
            )
            break
        except openai.error.RateLimitError:
            time.sleep(.1*4**current_tries)
            current_tries += 1
        except Exception as e:
            if 'The response was filtered due to the prompt' in str(e):
                raise RuntimeError('Prompt was filtered.')
            
            print('\tError from OpenAI:', str(e))
            print('\tRetrying...')
            time.sleep(RETRY_SECS)
            current_tries += 1 
    
    if current_tries > MAX_RETRIES:
        raise RuntimeError('No valid response from OpenAI.')
    
    global NUM_TOKENS_PROMPTED
    global NUM_TOKENS_GENERATED
    
    NUM_TOKENS_PROMPTED  += get_num_tokens( prompt )
    NUM_TOKENS_GENERATED += get_num_tokens( response['choices'][0]['message']['content'] )
    
    return response

In [32]:
start_time = time.monotonic()

to_bool = lambda x: len(str(x)) > 0 and str(x).lower() not in ['nan','false','0']

for idx, row in tqdm(to_label.iterrows(), total=to_label.shape[0]):
    prompt_args = dict(zip(to_label.index.names, idx))
    prompt_str = prompt.format(**{k: str(v) for k, v in chain(prompt_args.items(), row.items())})
    parsed_response = dict()
    
    try:
        response = query(prompt_str)
        
        response_str = response['choices'][0]['message']['content']
        response_str = response_str.replace(prompt_str, '').strip()
        response_str = "Prescriptive : " + response_str if not ':' in response_str.splitlines()[0] else response_str
        response_str = response_str[response_str.find('Prescriptive : '):]
        
        parsed_response = schema.parse_text(response_str)[0]
        
        # Drop all keys in parsed_response that are not in schema
        for k in list(parsed_response.keys()):
            if k not in schema.keys:
                del parsed_response[k]
        
    except Exception as e:
        print(f"Error Classifying {idx}:\n{e}")
        
        # Find line number of error
        print(next((line for line in reversed(traceback.format_exc().split('\n')) if re.search(r'line \d+', line)), 'Unknown'))
        
    # Fill all missing fields with NaN
    if not set(schema.keys) <= set(parsed_response.keys()):
        print(f"Missing keys in response for {idx}: {set(schema.keys) - set(parsed_response.keys())}")
        print('Filling with NaN')
        for k in schema.keys:
            if k not in parsed_response:
                parsed_response[k] = np.nan
    
    # Fill in the row with the parsed response
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for k,v in parsed_response.items():
            to_label.loc[idx,k] = to_bool(v)
    to_label.to_csv(OUT_PATH)
    
elapsed_time = time.monotonic() - start_time

  0%|          | 0/4361388 [00:00<?, ?it/s]

Missing keys in response for (20070101, 'New York Times', 'Rush to Hang Hussein Was  Questioned', 'http://www.nytimes.com/2007/01/01/world/middleeast/01iraq.html?hp&ex=1167714000&en=85dae91ed8178e3a&ei=5094&partner=homepage'): {'BreachMentioned', 'CompanyMentioned'}
Filling with NaN


KeyboardInterrupt: 

In [None]:
parsed_response = schema.parse_text(response_str)[0]
        
# Drop all keys in parsed_response that are not in schema
for k in list(parsed_response.keys()):
    if k not in schema.keys:
        del parsed_response[k]

In [None]:
response_str

'Prescriptive : False\nRestrictive : True\nPost Content : False\nPost Format : False\nUser-Related : False\nNot a Rule : False\nSpam, Low Quality, Off-Topic, and Reposts : False\nPost Tagging & Flairing : False\nPeer Engagement : True\nLinks & External Content : False\nImages : False\nCommercialization : False\nIllegal Content : False\nDivisive Content : False\nRespect for Others : False\nBrigading : False\nBan Mentioned : False\nKarma/Score Mentioned : True'

In [None]:
print(f'Processed {len(to_label):,d} records in {elapsed_time/60:.1f} minutes.')

print(f'{NUM_TOKENS_PROMPTED:,d} tokens prompted, {NUM_TOKENS_GENERATED:,d} generated.')

run_cost = (NUM_TOKENS_PROMPTED/1000 * PROMPT_COST)+(NUM_TOKENS_GENERATED/1000 * OUTPUT_COST)

print(f'Total run cost ~ ${run_cost:.2f}')

Processed 3,000 records in 99.2 minutes.
4,432,262 tokens prompted, 434,489 generated.
Total run cost ~ $28.68


In [None]:
with open('../../cost_tracking.jsonl', 'a') as f:
    costs = {
        'date'                  : datetime.date.today().isoformat(),
        'time'                  : datetime.datetime.now().strftime('%H:%M'),
        'tokens_prompted'       : NUM_TOKENS_PROMPTED,
        'tokens_generated'      : NUM_TOKENS_GENERATED,
        'run_cost_usd'          : run_cost,
        'elapsed_seconds'       : elapsed_time,
        'records_processed'     : len(to_label),
    }
    
    f.write(json.dumps(costs)+'\n')