In [None]:
import json
import os
import time
from getpass import getpass

from openai import OpenAI


In [None]:
# Resolve the OpenAI API key without hardcoding it in the notebook.
# The key is taken from the OPENAI_API_KEY environment variable; when that is
# unset or blank, the user is prompted for it instead, so the secret is never
# stored in the notebook source.
api_key = os.environ.get("OPENAI_API_KEY", "").strip()
if not api_key:
    api_key = getpass("Enter your OpenAI API key: ").strip()
if not api_key:
    # Fail here with a clear message rather than on the first API request.
    raise ValueError("No OpenAI API key provided; set OPENAI_API_KEY or enter it when prompted.")
client = OpenAI(api_key=api_key)


In [None]:
# Locations of the unlabeled input and the labeled output under the project root
project_root = '/content/drive/MyDrive/IS450 Project'
ner_data_parts = ('Data', 'Historical News', 'NER_Data')
input_json_path = os.path.join(project_root, *ner_data_parts, 'Unlabeled', '5000_news_for_NER.json')
output_json_path = os.path.join(project_root, *ner_data_parts, 'Labeled', 'ner_labeled_news_dataset.json')

# Create the folder that will hold the labeled output if it is missing
output_dir = os.path.dirname(output_json_path)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Entity labels accepted for annotation: label name -> human-readable description.
# Keys keep their insertion order, so the labels are listed in this order
# whenever the mapping is iterated or serialized.
ALLOWED_LABELS = dict(
    ORG="Organizations, companies, institutions",
    PER="People",
    STOCK="Individual company stock symbols and tickers (e.g., AAPL, MSFT, GOOGL)",
    INDEX="Market indices and benchmarks (e.g., S&P 500, Russell 2000, Dow Jones, NASDAQ Composite)",
    CRYPTO="Cryptocurrency names and symbols",
    MONEY="Actual monetary values (e.g., $100, €50, 2.5 million)",
    METRIC="Quantitative measurements (percentages, numbers, statistics, e.g., 50%, 1000 units, 5x growth)",
    DATE="Dates and time periods",
    PRODUCT="Products, services, or technologies",
    EVENT="Events, meetings, or occurrences",
    LOC="Locations (places, buildings, landmarks)",
    GPE="Geopolitical entities (countries, cities, regions)",
)


In [None]:
# Resume from an existing output file when there is one; otherwise start empty
if not os.path.exists(output_json_path):
    processed_data = []
    processed_ids = set()
    print("Starting fresh processing")
else:
    with open(output_json_path, 'r', encoding='utf-8') as existing_file:
        processed_data = json.load(existing_file)
    # IDs of the articles that already have labels in the output file
    processed_ids = {entry["id"] for entry in processed_data}
    print(f"Found {len(processed_ids)} previously processed articles")
    

In [None]:
# Read the unlabeled articles from the input JSON file
with open(input_json_path, mode='r', encoding='utf-8') as source_file:
    data = json.load(source_file)
article_count = len(data)
print(f"Loaded {article_count} articles to process")


In [None]:
def validate_entities(entities):
    """
    Validate the shape of a parsed model response and that every label is allowed.

    Args:
        entities: Parsed model output. Expected to be a list of dicts that each
            carry a string 'label' key. Other keys (such as 'text') are not
            checked here.

    Returns:
        None. Problems are reported only by raising.

    Raises:
        ValueError: If `entities` is not a list, if any element is not a dict
            with a string 'label', or if any label is not a key of
            ALLOWED_LABELS.
    """
    # A JSON object, scalar or null parses successfully but is not the
    # requested array of entities.
    if not isinstance(entities, list):
        raise ValueError(f"Expected a JSON array of entities, got {type(entities).__name__}")

    invalid_labels = set()
    for position, ent in enumerate(entities):
        # Report malformed items as ValueError (the validation error type)
        # rather than letting an incidental KeyError/TypeError escape.
        if not isinstance(ent, dict) or not isinstance(ent.get("label"), str):
            raise ValueError(f"Entity at index {position} is not an object with a string 'label': {ent!r}")
        if ent["label"] not in ALLOWED_LABELS:
            invalid_labels.add(ent["label"])

    if invalid_labels:
        raise ValueError(f"Invalid labels found: {invalid_labels}. Allowed labels are: {list(ALLOWED_LABELS.keys())}")
    

In [None]:
# Label every article that is not yet in the output, checkpointing after each success.
total_processed = 0
total_errors = 0
articles_to_process = [item for item in data if item.get("id") not in processed_ids]
total_remaining = len(articles_to_process)
print(f"Articles remaining to process: {total_remaining}")

# The instructions, label list and example are the same for every article,
# so build them once instead of re-serializing ALLOWED_LABELS on each iteration.
prompt_header = (
    "Label entities in the text using ONLY these labels:\n"
    f"{json.dumps(ALLOWED_LABELS, indent=2)}\n\n"
    "Rules:\n"
    "1. Use only the labels above\n"
    "2. Output as JSON array with 'text' and 'label' fields\n"
    "3. The 'text' field must be the exact substring from the original text\n"
    "4. Entities must not overlap. If a shorter entity is part of a longer one, only label the longer entity (e.g., label \"S&P 500 Index\" as INDEX, not \"S&P 500\").\n\n"
    "Example:\n"
    '[\n'
    '  {"text": "Microsoft", "label": "ORG"},\n'
    '  {"text": "AAPL", "label": "STOCK"},\n'
    '  {"text": "S&P 500 Index", "label": "INDEX"},\n'
    '  {"text": "New York", "label": "GPE"}\n'
    ']\n\n'
)
system_message = "You are a NER labeling assistant. Respond with a JSON array of objects containing 'text' and 'label' fields. The 'text' field must be the exact substring. Use only the specified labels."

# Checkpoints are written to a sibling temp file and then renamed over the real
# output, so an interrupted or failed write cannot leave the output file
# truncated and unparseable (which would break resuming on the next run).
temp_output_path = output_json_path + '.tmp'

for item in articles_to_process:
    post_id = item.get("id")
    text = item.get("text", "")
    prompt = f"{prompt_header}Text:\n{text}\n"

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1
        )

        # Clean the response and ensure it's valid JSON
        response_text = response.choices[0].message.content.strip()
        # Remove any markdown code block markers if present
        response_text = response_text.replace('```json', '').replace('```', '').strip()

        try:
            entities = json.loads(response_text)
            # Validate that all labels are allowed
            validate_entities(entities)
        except (json.JSONDecodeError, ValueError) as e:
            print(f"Error parsing JSON or validating labels for {post_id}: {e}")
            print(f"Response text received: {response_text}")
            total_errors += 1
            continue # Skip saving this item

        # Add to processed data
        processed_data.append({
            "id": post_id,
            "text": text,
            "entities": entities
        })

        # Save after each successful processing: write the full dataset to the
        # temp file first, then swap it into place with a single rename.
        with open(temp_output_path, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, indent=2, ensure_ascii=False)
        os.replace(temp_output_path, output_json_path)

        processed_ids.add(post_id) # Add here only after successful save
        total_processed += 1
        print(f"Processed {post_id} ({total_processed}/{total_remaining})")

    except Exception as e:
        # Any other failure (API call, unexpected response shape, file write)
        # is counted and the loop moves on to the next article.
        print(f"General error processing {post_id}: {e}")
        total_errors += 1
        time.sleep(1) # Simple delay before the next request
        continue


In [None]:
# Report the outcome of this run
attempted_count = total_processed + total_errors
print("\nProcessing complete!")
print(f"Total articles successfully processed in this run: {total_processed}")
print(f"Total errors encountered in this run: {total_errors}")
if attempted_count <= 0:
    print("No new articles processed in this run.")
else:
    # Share of attempted articles that were labeled and saved, as a percentage
    success_rate = (total_processed / attempted_count) * 100
    print(f"Success rate for this run: {success_rate:.2f}%")
print(f"Total labeled articles in output file: {len(processed_data)}")
