In [None]:
import getpass
import json
import os
import time

from openai import OpenAI


In [None]:
# Obtain the OpenAI API key without storing it in the notebook: saved notebooks
# (and their outputs) end up in Drive/version control, so a pasted key would leak.
# Order of precedence: the OPENAI_API_KEY environment variable, then an
# interactive prompt whose input is not echoed.
api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")
if not api_key:
    # Fail here rather than on the first API call deep inside the labeling loop.
    raise ValueError("No OpenAI API key provided; set OPENAI_API_KEY or enter it at the prompt.")
client = OpenAI(api_key=api_key)


In [None]:
# Input/output locations. The root presumably points at a mounted Google Drive
# folder (Colab-style path) — adjust if running elsewhere.
project_root = '/content/drive/MyDrive/IS450 Project'
ner_data_dir = os.path.join(project_root, 'Data', 'Historical Reddit', 'NER')
input_json_path = os.path.join(ner_data_dir, 'Unlabeled', '5000_reddit_posts_for_NER.json')
output_json_path = os.path.join(ner_data_dir, 'Labeled', 'ner_labeled_reddit_dataset.json')

# Create the output folder up front so the per-post saves in the labeling loop
# never fail because the directory is missing.
os.makedirs(os.path.dirname(output_json_path), exist_ok=True)


In [None]:
# Entity label set for Reddit finance posts, as label name -> description.
# The whole mapping is embedded (as JSON) in the model prompt, and its keys are
# the only labels accepted when validating model output, so both the wording and
# the insertion order of these entries are significant.
ALLOWED_LABELS = dict([
    ("ORG", "Organizations, companies, institutions"),
    ("PER", "People"),
    ("STOCK", "Individual company stock symbols and tickers (e.g., AAPL, MSFT, GOOGL)"),
    ("INDEX", "Market indices and benchmarks (e.g., S&P 500, Russell 2000, Dow Jones, NASDAQ Composite)"),
    ("CRYPTO", "Cryptocurrency names and symbols"),
    ("MONEY", "Actual monetary values (e.g., $100, €50, 2.5 million)"),
    ("METRIC", "Quantitative measurements (percentages, numbers, statistics, e.g., 50%, 1000 units, 5x growth)"),
    ("DATE", "Dates and time references"),
    ("PRODUCT", "Specific financial instruments and services (e.g., 401k, bonds, options, ETFs)"),
    ("EVENT", "Specific market events and economic conditions (e.g., bull run, recession, market crash)"),
    ("CONCEPT", "General financial concepts and terms (e.g., diversification, value investing)"),
    ("GPE", "Countries, cities, states, or administrative regions"),
    ("LOC", "Physical locations or places"),
])


In [None]:
# Resume support: when a previous run already wrote labeled posts, reload them
# and remember their ids so those posts are not sent to the API again.
if not os.path.exists(output_json_path):
    processed_data, processed_ids = [], set()
    print("Starting fresh processing")
else:
    with open(output_json_path, 'r', encoding='utf-8') as fh:
        processed_data = json.load(fh)
    processed_ids = {record["id"] for record in processed_data}
    print(f"Found {len(processed_ids)} previously processed posts")


In [None]:
# Read the unlabeled posts. Each item is expected to carry an "id" and a
# "processed_text_finbert" field, both of which the labeling loop reads.
with open(input_json_path, encoding='utf-8') as fh:
    data = json.load(fh)
print(f"Loaded {len(data)} posts to process")


In [None]:
def validate_entities(entities, allowed_labels=None):
    """
    Validate the entity list parsed from a model response.

    Every item must be a dict with a non-empty string "text" and a string
    "label" drawn from the allowed label set.

    Args:
        entities: Parsed JSON value; expected to be a list of entity dicts.
        allowed_labels: Collection of permitted label names. Defaults to the
            keys of the notebook-level ALLOWED_LABELS mapping.

    Raises:
        ValueError: If `entities` is not a list, if any item is malformed, or
            if any label is not allowed. Every failure is reported as
            ValueError (never TypeError) so the labeling loop's
            `except (json.JSONDecodeError, ValueError)` handler catches it.
    """
    if allowed_labels is None:
        allowed_labels = ALLOWED_LABELS
    if not isinstance(entities, list):
        raise ValueError(f"Expected a JSON array of entities, got {type(entities).__name__}")

    malformed = []
    invalid_labels = set()
    for ent in entities:
        text = ent.get("text") if isinstance(ent, dict) else None
        label = ent.get("label") if isinstance(ent, dict) else None
        # Checking types before touching the set avoids a TypeError on
        # unhashable labels (e.g. a list) and rejects entities with no usable text.
        if not (isinstance(text, str) and text.strip() and isinstance(label, str)):
            malformed.append(ent)
        elif label not in allowed_labels:
            invalid_labels.add(label)

    problems = []
    if malformed:
        problems.append(f"Malformed entities (need non-empty string 'text' and string 'label'): {malformed}")
    if invalid_labels:
        problems.append(
            f"Invalid labels found: {sorted(invalid_labels)}. Allowed labels are: {list(allowed_labels)}"
        )
    if problems:
        raise ValueError("; ".join(problems))


In [None]:
# Label every not-yet-processed post with the LLM. Results are written to
# output_json_path after each successful post so an interrupted run can resume.

MODEL_NAME = "gpt-4o-mini"
TEMPERATURE = 0.1        # low temperature for more deterministic output
DELAY_AFTER_SUCCESS = 1  # seconds; crude throttle to help avoid rate limits
DELAY_AFTER_ERROR = 5    # seconds; longer back-off after an API/unexpected error

SYSTEM_PROMPT = (
    "You are a NER labeling assistant for Reddit financial posts. Respond with a JSON array of "
    "objects containing 'text' and 'label' fields. Use only the specified labels. Pay attention "
    "to informal expressions and Reddit-specific terminology."
)
_LABELS_JSON = json.dumps(ALLOWED_LABELS, indent=2)  # loop-invariant: build once


def _build_prompt(text):
    """Return the user prompt asking the model to label entities in `text`."""
    return (
        "Label entities in the Reddit post using ONLY these labels:\n"
        f"{_LABELS_JSON}\n\n"
        "Rules:\n"
        "1. Use only the labels above\n"
        "2. Include repeated entities\n"
        "3. Output as JSON array with 'text' and 'label' fields\n"
        "4. Pay attention to Reddit-specific terminology (e.g. stonks, tendies, diamond hands - do not label these unless they refer to actual stocks/concepts)\n"
        "5. Do not label general sentiment expressions (e.g., 'to the moon', 'HODL') or emojis as entities\n"
        "6. For market indices (like S&P 500, Dow Jones), use INDEX label, not STOCK or ORG\n\n"
        "Example:\n"
        '[\n'
        '  {"text": "AAPL", "label": "STOCK"},\n'
        '  {"text": "S&P 500", "label": "INDEX"},\n'
        '  {"text": "the Fed", "label": "ORG"},\n'
        '  {"text": "Wall Street", "label": "LOC"},\n'
        '  {"text": "24%", "label": "METRIC"}\n'
        ']\n\n'
        f"Text:\n{text}\n"
    )


def _parse_entities(raw):
    """
    Extract and validate the JSON entity array from a raw model response.

    Strips markdown code fences and any prose around the outermost [...] span.
    An explicit "[]" is a valid "no entities" answer. A response with no array
    is an error, not an empty result, so the post is retried on a later run
    instead of being saved as entity-free.

    Raises:
        ValueError: No JSON array found, invalid JSON (json.JSONDecodeError is
            a ValueError subclass), or entities rejected by validate_entities.
    """
    cleaned = (raw or "").replace('```json', '').replace('```', '').strip()
    start, end = cleaned.find('['), cleaned.rfind(']')
    if start == -1 or end < start:
        raise ValueError("response contains no JSON array")
    entities = json.loads(cleaned[start:end + 1])
    validate_entities(entities)
    return entities


def _save_json_atomic(path, obj):
    """
    Write `obj` as JSON to `path` via a temporary file and os.replace.

    The final file is never left half-written if the kernel is interrupted
    mid-dump. os.replace is atomic on POSIX filesystems; behaviour on network
    or FUSE mounts may be weaker — NOTE(review): confirm for Drive.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as fh:
        json.dump(obj, fh, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


total_processed = 0
total_errors = 0
posts_to_process = [item for item in data if item.get("id") not in processed_ids]
total_remaining = len(posts_to_process)
print(f"Posts remaining to process: {total_remaining}")

for item in posts_to_process:
    post_id = item.get("id")
    if post_id is None:
        # Without an id the post cannot be tracked for resume/deduplication.
        print("Skipping post without an 'id' field.")
        continue
    # Not redundant: the input may contain duplicate ids, and processed_ids grows
    # during this loop while posts_to_process was computed before it.
    if post_id in processed_ids:
        continue

    text = item.get("processed_text_finbert", "")
    if not text:
        print(f"Skipping {post_id}: Empty text field ('processed_text_finbert').")
        continue

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _build_prompt(text)},
            ],
            temperature=TEMPERATURE,
        )
        # content can be None (e.g. a refusal); treat that as an empty response.
        response_text = response.choices[0].message.content or ""

        try:
            entities = _parse_entities(response_text)
        except ValueError as e:
            print(f"Error parsing JSON or validating labels for {post_id}: {e}")
            print(f"Response text received: {response_text}")
            total_errors += 1
            continue  # not saved -> retried on the next run

        record = {
            "id": post_id,
            "text": text,  # store the exact text that was labeled
            "entities": entities,
        }
        # Save first, then update in-memory state, so memory and disk never
        # disagree about which posts are done if the save fails.
        _save_json_atomic(output_json_path, processed_data + [record])
        processed_data.append(record)
        processed_ids.add(post_id)
        total_processed += 1
        print(f"Processed {post_id} ({total_processed}/{total_remaining})")

        time.sleep(DELAY_AFTER_SUCCESS)

    except Exception as e:
        # API/network/IO failures: count, back off, and move on; the post stays
        # unprocessed and will be retried on the next run.
        print(f"General error processing {post_id}: {e}")
        total_errors += 1
        time.sleep(DELAY_AFTER_ERROR)


In [None]:
# Summary of this run only: the counters are reset when the labeling cell starts,
# while processed_data also includes posts restored from earlier runs.
attempted = total_processed + total_errors
print("\nProcessing complete!")
print(f"Total posts successfully processed in this run: {total_processed}")
print(f"Total errors encountered in this run: {total_errors}")
if attempted <= 0:
    print("No new posts processed in this run.")
else:
    success_rate = (total_processed / attempted) * 100
    print(f"Success rate for this run: {success_rate:.2f}%")
print(f"Total labeled posts in output file: {len(processed_data)}")
