In [15]:
import os
import time
import re
import warnings
import datetime
import json
from tqdm import tqdm
from copy import copy
from dataclasses import dataclass, field
import traceback

import numpy as np
import pandas as pd

In [16]:
import openai
import tiktoken

ModuleNotFoundError: No module named 'openai'

In [None]:
from tools import embedding_pipeline

In [None]:
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from langchain.vectorstores import Chroma
from langchain.embeddings import AzureOpenAIEmbeddings

In [None]:
TRAINING_DATA_PATH   = './trainset.csv'
TO_LABEL_PATH        = './headlines.csv'
OUT_PATH             = '../data/gpt_classified.csv'
NUM_EXAMPLES         = 6

In [None]:
openai.api_version = '2024-06-01'
openai.api_base = "https://llm.leibmann.org/v1"
openai.api_key = "keyzoned"

Some helpful vLLM commands:
- Start server: `python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --dtype=half --download-dir /projects/bdata/llm_models/`
- Base url: `http://localhost:8000/v1`
- List models: `curl http://localhost:8000/v1/models`

In [None]:
MAX_RETRIES          = 3
MAX_TOKENS           = 300 # only include response; smaller helps with ratelimiting
TEMPERATURE          = 0 # using 0 based on xinyi, down from 1.4 earlier
RETRY_SECS           = 5
PAUSE_SECS           = 1
TIMEOUT_SECS         = 5
LLM_DEPLOYMENT       = 'mistral-7b-instruct-v0.2' # 'gpt-4' or 'GPT-4o'. For vLLM use full path

if 'localhost' in openai.api_base:
    PROMPT_COST = 0 # dollars per 1,000 tokens
    OUTPUT_COST = 0 # dollars per 1,000 tokens
    
# input from here https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
elif LLM_DEPLOYMENT == 'GPT-4o':
    PROMPT_COST = 0.005 # dollars per 1,000 tokens
    OUTPUT_COST = 0.015 # dollars per 1,000 tokens
elif LLM_DEPLOYMENT == 'gpt-4':
    PROMPT_COST = 0.03 # dollars per 1,000 tokens
    OUTPUT_COST = 0.06 # dollars per 1,000 tokens
else: raise ValueError(f"Unknown LLM_DEPLOYMENT: {LLM_DEPLOYMENT}")

In [None]:
# Accounting
NUM_TOKENS_PROMPTED  = 0
NUM_TOKENS_GENERATED = 0
enc = tiktoken.encoding_for_model('gpt-4')
def get_num_tokens(s):
    return len( enc.encode(s) )

In [None]:
# Define Helpers for output parsing
@dataclass
class LLMCol:
    """Represents a single key-value pair in the dataframe. Used for easy parsing"""
    key: str
    allowed_values: set = field(default_factory=set)
    
    def __post_init__(self):
        self.allowed_values = set(map(LLMCol.normalize_str, self.allowed_values))

    def to_str(self, value:str):
        value = LLMCol.normalize_str(value)
        if len(self.allowed_values) > 0 and value not in self.allowed_values:
            raise ValueError(f"Malformed input: Value {value} not in allowed values {self.allowed_values}")
        return f"{self.key}: {value}"
    
    def parse_line(self, line:str) -> dict:
        """Parse a line of text into a key-value pair, performing the inverse of to_str"""
        try:
            k,v = re.split(r'\W*:\W*', line, maxsplit=1)
            if len(self.allowed_values) > 0 and LLMCol.normalize_str(v) not in self.allowed_values:
                raise ValueError(f"Malformed output: value {v} not in allowed values {self.allowed_values}")
            if k.strip().casefold() != self.key.casefold():
                raise ValueError(f"Malformed output: key {k} does not match expected key {self.key}")
            
            return {self.key: v}
        except ValueError:
            raise ValueError(f'Malformed output for line "{line}"')
        
    def parse_row(self, row:pd.Series|dict) -> str:
        """Parse a row of a dataframe into a string, performing the inverse of to_str"""
        if isinstance(row, pd.Series):
            if self.key not in row.index:
                raise ValueError(f"Key {self.key} not found in row")
            return self.to_str(row[self.key])
        elif isinstance(row, dict):
            if self.key not in row:
                raise ValueError(f"Key {self.key} not found in row")
            return self.to_str(row[self.key])
        else:
            raise ValueError("Expected a Series or dict")
        
    @staticmethod
    def normalize_str(s:str) -> str:
        return str(s).strip().casefold()
    
@dataclass
class LLMSchema:
    """Represents a set of LLMRows, used for easy parsing"""
    cols: list[LLMCol]
    
    def __post_init__(self):
        self.cols = tuple(self.cols)
        
        # Ensure no duplicate keys
        keys_found = set()
        for col in self.cols:
            if col.key in keys_found:
                raise ValueError(f"Duplicate key: {col.key}")
            keys_found.add(col.key)
    
    @property
    def keys(self):
        return {c.key for c in self.cols}
    
    def parse_row(self, row:pd.Series|dict) -> str:
        values = list()
        for c in self.cols:
            try:
                values.append(c.parse_row(row))
            except ValueError:
                continue
        return '\n'.join(values)
    
    def parse_text(self, text:str) -> list[dict]:
        """Parse a text block into a dictionary of key-value pairs"""
        output = list()
        for block in text.split('\n\n'):
            lines = re.split(r'(?<=\n)(?=[^\n]*:)', block)
            pairs = dict()
            for line in lines:
                line = line.strip()
                for col in self.cols:
                    try:
                        pairs.update(col.parse_line(line))
                    except ValueError:
                        continue
            output.append(pairs)
        return output

# Load Training Data

In [None]:
# Define the schema for the LLM output
schema = LLMSchema([
    # LLMCol('subreddit'),
    # LLMCol('rule'),
    # LLMCol('rule_description'),
    
    LLMCol('breach mentioned', {'true', 'false'}),
    LLMCol('company mentioned')
])

In [None]:
# Load training set
train_set = pd.read_csv(TRAINING_DATA_PATH)

train_set['Headline'] = train_set['Headline'].apply(str.strip)
train_set.set_index(['Date', 'Publication', 'Headline', 'URL'], inplace=True)
for col in schema.cols:
    if col.key in train_set.index.names:
        continue
    
    if col.key not in train_set.columns:
        print(f"WARNING: Column {col.key} not found in training set but specified in schema")
        continue
    
    train_set[col.key] = train_set[col.key].apply(lambda x: len(str(x)) > 0 and str(x).lower() not in ['nan','false','0'])

train_set.head(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Prescriptive,Restrictive,Post Content,Post Format,User-Related,Not a Rule,"Spam, Low Quality, Off-Topic, and Reposts",Post Tagging & Flairing,Peer Engagement,Links & External Content,Images,Commercialization,Illegal Content,Divisive Content,Respect for Others,Brigading,Ban Mentioned,Karma/Score Mentioned
subreddit,rule,rule_description,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
soccer,5. No Duplicates,,False,True,True,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False
movies,8. Extraneous Comic Book Movie submission,,False,True,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
trees,1. 1. Age,,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False


# Construct Prompt

In [None]:
# Construct prompt
from itertools import chain

example_prompt = PromptTemplate.from_template(
    '\n'.join(f'{c} : {{{c}}}' for c in chain(train_set.index.names, train_set.columns)),
)

In [None]:
examples = train_set.reset_index().astype(str).to_dict(orient='records')

In [None]:
start_time = time.monotonic()

print('Computing embeddings for examples...', end='')
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,       # This is the list[dict] of examples available to select from.
    embedding_pipeline(),
    # AzureOpenAIEmbeddings(
    #     deployment = EMBEDDING_DEPLOYMENT
    # ),              # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    Chroma,         # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    k=NUM_EXAMPLES  # This is the number of examples to produce.
)
print(f'done in {time.monotonic()-start_time:.1f} seconds.')

Computing embeddings for examples...

<All keys matched successfully>
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


done in 66.9 seconds.


In [None]:
with open('./prompt.txt') as f:
    prompt_prefix = f.read()
prompt_prefix += '\n\nYour answer should follow the format given in the examples:\n'
    
prompt = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt   = example_prompt,
    input_variables = list(train_set.index.names),
    prefix           = prompt_prefix,
    suffix           = '\n'.join(chain((f'{c} : {{{c}}}' for c in train_set.index.names))),
)

In [None]:
print('Prompt formatting configured.\nExample prompt:')

print(
    prompt.format(
        subreddit = 'AskReddit',
        rule='Please mark all posts with their content',
        rule_description=''
    )
)

Prompt formatting configured.
Example prompt:
Given a rule in a specific subreddit, identify topics and qualities about the rule.

If the rule explicitly limits or forbids certain actions, mark it as "Restrictive".

If a rule expresses general guidelines or desires for a community, mark it as "Prescriptive".

If a rule explicitly states desired or undesired content within the subreddit, mark it as "Post Content".

If a rule is related to users or would cause unequal enforcement if two different users posted the same content (including verification and prior approval rules), mark it as "User-Related".

If a rule prescribes post structure, formatting, titling, or references a location to post (such as other subreddits or specific threads), mark it as "Post Format".

If the content is sidebar information and not necessarily a rule, mark it as "Not a Rule".

If a rule pertains to labeling/flairing, marking nsfw, using tags, etc., mark it as "Post Tagging & Flairing".

If a rule encourages 

In [None]:
def query(prompt):
    current_tries = 1
    
    while current_tries <= MAX_RETRIES:
        try:
            response = openai.ChatCompletion.create(
                model         = LLM_DEPLOYMENT if 'localhost' in openai.api_base else None,
                deployment_id = LLM_DEPLOYMENT if 'localhost' not in openai.api_base else None,
                messages      = [{'role': 'user', 'content': prompt}],
                max_tokens    = MAX_TOKENS,
                temperature   = TEMPERATURE,
            )
            break
        except openai.error.RateLimitError:
            time.sleep(.1*4**current_tries)
            current_tries += 1
        except Exception as e:
            if 'The response was filtered due to the prompt' in str(e):
                raise RuntimeError('Prompt was filtered.')
            
            print('\tError from OpenAI:', str(e))
            print('\tRetrying...')
            time.sleep(RETRY_SECS)
            current_tries += 1 
    
    if current_tries > MAX_RETRIES:
        raise RuntimeError('No valid response from OpenAI.')
    
    global NUM_TOKENS_PROMPTED
    global NUM_TOKENS_GENERATED
    
    NUM_TOKENS_PROMPTED  += get_num_tokens( prompt )
    NUM_TOKENS_GENERATED += get_num_tokens( response['choices'][0]['message']['content'] )
    
    return response

In [None]:
to_label = pd.read_csv(TO_LABEL_PATH)

to_label['rule'] = to_label['rule'].apply(str.strip)  
to_label['rule_description'] = to_label['rule_description'].apply(lambda x: str(x).strip() if str(x).lower() != 'nan' else '')

to_label = to_label[~to_label.index.isin(train_set.index)] # Dedupe
to_label.set_index(['subreddit','rule','rule_description'], inplace=True) # Match train set schema

# Create columns to label
for col in schema.cols:
    if col.key in ['subreddit','rule','rule_description']:
        continue
    
    if col.key not in to_label.columns:
        to_label[col.key] = False

# Assert that the schema of to_label and train_set match
if not set(chain(train_set.index.names, train_set.columns)) <= set(chain(to_label.index.names, to_label.columns)):
    raise ValueError(f'To Label schema does not match train set schema! \n Missing Columns: {set(chain(train_set.index.names, train_set.columns)) - set(chain(to_label.index.names, to_label.columns))}')

print(f'Loaded {len(to_label)} rules to classify.')

Loaded 3000 rules to classify.


In [None]:
start_time = time.monotonic()

to_bool = lambda x: len(str(x)) > 0 and str(x).lower() not in ['nan','false','0']

for idx, row in tqdm(list(to_label.iterrows())):
    prompt_args = dict(zip(to_label.index.names, idx))
    prompt_str = prompt.format(**prompt_args) # format prompt with row index
    parsed_response = dict()
    
    try:
        response = query(prompt_str)
        
        response_str = response['choices'][0]['message']['content']
        response_str = response_str.replace(prompt_str, '').strip()
        response_str = "Prescriptive : " + response_str if not ':' in response_str.splitlines()[0] else response_str
        response_str = response_str[response_str.find('Prescriptive : '):]
        
        parsed_response = schema.parse_text(response_str)[0]
        
        # Drop all keys in parsed_response that are not in schema
        for k in list(parsed_response.keys()):
            if k not in schema.keys:
                del parsed_response[k]
        
    except Exception as e:
        print(f"Error Classifying {idx}:\n{e}")
        
        # Find line number of error
        print(next((line for line in reversed(traceback.format_exc().split('\n')) if re.search(r'line \d+', line)), 'Unknown'))
        
    # Fill all missing fields with NaN
    if not set(schema.keys) <= set(parsed_response.keys()):
        print(f"Missing keys in response for {idx}: {set(schema.keys) - set(parsed_response.keys())}")
        print('Filling with NaN')
        for k in schema.keys:
            if k not in parsed_response:
                parsed_response[k] = np.nan
    
    # Fill in the row with the parsed response
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for k,v in parsed_response.items():
            to_label.loc[idx,k] = to_bool(v)
    to_label.to_csv(OUT_PATH)
    
elapsed_time = time.monotonic() - start_time

  0%|          | 0/3000 [00:00<?, ?it/s]

 62%|██████▏   | 1874/3000 [1:02:06<53:08,  2.83s/it]

Missing keys in response for ('dankmemes', "15. Yes, we have a karma threshold. Don't ask us about it. This rule literally explains everything. So, to post here, your account needs to have: \n\n \n existed for Some Time \n 500 post karma \n positive comment karma \n \n\n Or, if you're a special enough snowflake who can produce original content, you can go to  r/specialsnowflake  and earn your way in a bit faster. \n\n Check your karma breakdown here:  http://old.reddit.com/u/me/overview \n\n Types of karma are explained here:  https://www.reddit.com/r/help/comments/au5t4i/post_karma_vs_comment_karma/eh5tgck/ \n\n You do not need more information. So don't ask.", ''): {'Ban Mentioned', 'Karma/Score Mentioned'}
Filling with NaN


 63%|██████▎   | 1889/3000 [1:02:49<1:07:57,  3.67s/it]

Missing keys in response for ('dankmemes', "15. Yes, we have a karma threshold. Don't ask us about it. This rule literally explains everything. So, to post here, your account needs to have: \n\n \n existed for Some Time \n 1000 post + comment karma \n positive comment karma \n \n\n Or, if you're a special enough snowflake who can produce original content, you can go to  r/specialsnowflake  and bypass the threshold. \n\n Check your karma breakdown here:  http://old.reddit.com/u/me/overview \n\n Types of karma are explained here:  https://www.reddit.com/r/help/comments/au5t4i/post_karma_vs_comment_karma/eh5tgck/ \n\n You do not need more information. So don't ask.", ''): {'Karma/Score Mentioned'}
Filling with NaN


 89%|████████▉ | 2679/3000 [1:28:31<09:56,  1.86s/it]  

Error Classifying ('wallstreetbets', '4. No Gay Shit or Being a Fag', ''):
'content'
  File "/tmp/ipykernel_1484830/1342523222.py", line 33, in query
Missing keys in response for ('wallstreetbets', '4. No Gay Shit or Being a Fag', ''): {'Images', 'Not a Rule', 'Links & External Content', 'Spam, Low Quality, Off-Topic, and Reposts', 'Restrictive', 'Karma/Score Mentioned', 'Peer Engagement', 'Illegal Content', 'Respect for Others', 'Brigading', 'Post Content', 'User-Related', 'Post Format', 'Ban Mentioned', 'Commercialization', 'Divisive Content', 'Prescriptive', 'Post Tagging & Flairing'}
Filling with NaN


 94%|█████████▍| 2816/3000 [1:33:11<05:58,  1.95s/it]

Error Classifying ('Animemes', '12. No lewd Loli/Shota content', ''):
'content'
  File "/tmp/ipykernel_1484830/1342523222.py", line 33, in query
Missing keys in response for ('Animemes', '12. No lewd Loli/Shota content', ''): {'Images', 'Not a Rule', 'Links & External Content', 'Spam, Low Quality, Off-Topic, and Reposts', 'Restrictive', 'Karma/Score Mentioned', 'Peer Engagement', 'Illegal Content', 'Respect for Others', 'Brigading', 'Post Content', 'User-Related', 'Post Format', 'Ban Mentioned', 'Commercialization', 'Divisive Content', 'Prescriptive', 'Post Tagging & Flairing'}
Filling with NaN


100%|██████████| 3000/3000 [1:39:14<00:00,  1.98s/it]


In [None]:
parsed_response = schema.parse_text(response_str)[0]
        
# Drop all keys in parsed_response that are not in schema
for k in list(parsed_response.keys()):
    if k not in schema.keys:
        del parsed_response[k]

In [None]:
response_str

'Prescriptive : False\nRestrictive : True\nPost Content : False\nPost Format : False\nUser-Related : False\nNot a Rule : False\nSpam, Low Quality, Off-Topic, and Reposts : False\nPost Tagging & Flairing : False\nPeer Engagement : True\nLinks & External Content : False\nImages : False\nCommercialization : False\nIllegal Content : False\nDivisive Content : False\nRespect for Others : False\nBrigading : False\nBan Mentioned : False\nKarma/Score Mentioned : True'

In [None]:
print(f'Processed {len(to_label):,d} records in {elapsed_time/60:.1f} minutes.')

print(f'{NUM_TOKENS_PROMPTED:,d} tokens prompted, {NUM_TOKENS_GENERATED:,d} generated.')

run_cost = (NUM_TOKENS_PROMPTED/1000 * PROMPT_COST)+(NUM_TOKENS_GENERATED/1000 * OUTPUT_COST)

print(f'Total run cost ~ ${run_cost:.2f}')

Processed 3,000 records in 99.2 minutes.
4,432,262 tokens prompted, 434,489 generated.
Total run cost ~ $28.68


In [None]:
with open('../../cost_tracking.jsonl', 'a') as f:
    costs = {
        'date'                  : datetime.date.today().isoformat(),
        'time'                  : datetime.datetime.now().strftime('%H:%M'),
        'tokens_prompted'       : NUM_TOKENS_PROMPTED,
        'tokens_generated'      : NUM_TOKENS_GENERATED,
        'run_cost_usd'          : run_cost,
        'elapsed_seconds'       : elapsed_time,
        'records_processed'     : len(to_label),
    }
    
    f.write(json.dumps(costs)+'\n')