In [None]:
import os
import json
import random
from tqdm.notebook import tqdm # Use tqdm.notebook for better display in notebooks
from collections import defaultdict
import math # For ceiling division if needed


In [None]:
def get_text_length_category(text, short_max_words=100, medium_max_words=300):
    """
    Categorize a text by its whitespace-delimited word count.

    Args:
        text: String to categorize. Words are counted with ``str.split()``,
            so any run of whitespace separates words and an empty or
            whitespace-only string counts as zero words.
        short_max_words: Exclusive upper bound on the word count of a
            'short' text. Defaults to 100.
        medium_max_words: Exclusive upper bound on the word count of a
            'medium' text. Defaults to 300.

    Returns:
        'short' when the word count is below ``short_max_words``, 'medium'
        when it is below ``medium_max_words``, and 'long' otherwise.

    Raises:
        ValueError: If ``short_max_words`` is greater than
            ``medium_max_words`` (the 'medium' range would be unreachable).
    """
    if short_max_words > medium_max_words:
        raise ValueError(
            "short_max_words must not exceed medium_max_words "
            f"(got {short_max_words} > {medium_max_words})"
        )

    word_count = len(text.split())
    if word_count < short_max_words:
        return 'short'
    if word_count < medium_max_words:
        return 'medium'
    return 'long'


In [None]:
def _is_selectable_article(article):
    """Return True if `article` is a dict with an id, non-blank string text and both truncation flags falsy."""
    if not isinstance(article, dict):
        return False
    # A missing truncation flag is treated as "truncated", so the article is rejected.
    if article.get('truncated_title', True) or article.get('truncated_description', True):
        return False
    if 'id' not in article:
        return False
    text = article.get('processed_text_finbert')
    # Key presence alone is not enough: a None/non-string value would crash
    # the word count later, and blank text carries nothing to label.
    return isinstance(text, str) and bool(text.strip())


def _drop_duplicate_ids(articles):
    """Return `articles` with only the first occurrence of each 'id', in the original order."""
    seen_ids = set()
    unique_articles = []
    for article in articles:
        if article['id'] not in seen_ids:
            seen_ids.add(article['id'])
            unique_articles.append(article)
    return unique_articles


def select_balanced_news(input_file, output_file, target_size=5000, seed=None):
    """
    Select a length-balanced random subset of news articles and save it as JSON.

    The input JSON is loaded, articles are filtered (dict, has 'id', has
    non-blank string 'processed_text_finbert', 'truncated_title' and
    'truncated_description' both present and falsy), de-duplicated by 'id',
    and bucketed into 'short' / 'medium' / 'long' by word count. Roughly a
    third of `target_size` is sampled from each bucket; if some bucket is too
    small, the shortfall is topped up at random from the remaining articles.
    The result is shuffled, truncated to `target_size`, reduced to
    ``{"id", "text"}`` records and written to `output_file`.

    Args:
        input_file: Path to the JSON file holding the list of articles.
        output_file: Path of the JSON file to write. Its parent directory is
            created if needed; a bare filename (no directory) is allowed.
        target_size: Desired number of articles. Lowered to the number of
            valid articles when fewer are available.
        seed: Optional seed for a private random generator, making the
            selection reproducible. When None, the global `random` module is
            used, so a caller's own ``random.seed(...)`` is still honoured.

    Returns:
        None. Progress and problems are printed; a missing input file,
        invalid JSON, or an empty filtered set ends the run without writing.
    """
    print(f"Loading news from {input_file}...")
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            news_data = json.load(f)
    except FileNotFoundError:
        print(f"ERROR: Input file not found at {input_file}")
        return
    except json.JSONDecodeError:
        print(f"ERROR: Could not decode JSON from {input_file}")
        return

    print(f"Loaded {len(news_data)} total articles.")

    # The module itself exposes sample()/shuffle(), so it doubles as the
    # default generator when no explicit seed is requested.
    rng = random if seed is None else random.Random(seed)

    # --- Filtering ---
    print("Filtering news articles (keeping non-truncated title/description)...")
    filtered_news = [
        article for article in tqdm(news_data)
        if _is_selectable_article(article)
    ]
    print(f"\nFound {len(filtered_news)} valid (non-truncated) articles with text and ID.")

    if not filtered_news:
        print("ERROR: No valid articles found after filtering. Cannot proceed.")
        return

    # --- Deduplication (just in case) ---
    unique_filtered_news = _drop_duplicate_ids(filtered_news)
    if len(unique_filtered_news) < len(filtered_news):
        print(f"Removed {len(filtered_news) - len(unique_filtered_news)} duplicate articles based on ID.")
    filtered_news = unique_filtered_news

    if len(filtered_news) < target_size:
        print(f"WARNING: Only {len(filtered_news)} valid articles available, which is less than the target size of {target_size}. Selecting all available.")
        target_size = len(filtered_news)  # Adjust target size

    # --- Categorize by Length ---
    length_categories = defaultdict(list)
    for article in filtered_news:
        category = get_text_length_category(article['processed_text_finbert'])
        length_categories[category].append(article)

    print("\nInitial Text length distribution of valid articles:")
    for category, articles in length_categories.items():
        print(f"- {category.capitalize()}: {len(articles)} articles")

    # --- Selection Phase 1: Balanced Selection ---
    print(f"\n--- Selection Phase 1: Aiming for balance towards {target_size} articles ---")
    # Rounding up may overshoot the target by up to 2; the excess is trimmed
    # after the final shuffle, so no category is systematically penalised.
    target_per_category = math.ceil(target_size / 3)
    print(f"Target per category (approx): {target_per_category}")

    selected_news = []
    selected_ids = set()

    for category in ['short', 'medium', 'long']:  # Process in defined order
        # IDs are unique after deduplication, so buckets cannot overlap.
        candidates = length_categories.get(category, [])
        count_to_select = min(len(candidates), target_per_category)

        if count_to_select > 0:
            selected_from_category = rng.sample(candidates, count_to_select)
            selected_news.extend(selected_from_category)
            selected_ids.update(a['id'] for a in selected_from_category)
            print(f"Selected {count_to_select} {category} articles.")
        else:
            print(f"No {category} articles available or needed in this phase.")

    # --- Selection Phase 2: Top Up if Needed ---
    current_count = len(selected_news)
    remaining_needed = target_size - current_count
    print(f"\n--- Selection Phase 2: Currently selected {current_count} articles ---")

    if remaining_needed > 0:
        print(f"Need to select {remaining_needed} more articles.")
        # Everything valid that Phase 1 did not pick, regardless of length.
        remaining_available_articles = [
            a for a in filtered_news if a['id'] not in selected_ids
        ]
        count_to_select_phase2 = min(len(remaining_available_articles), remaining_needed)

        if count_to_select_phase2 > 0:
            print(f"Sampling {count_to_select_phase2} articles randomly from the remaining {len(remaining_available_articles)} available articles.")
            additional_selection = rng.sample(remaining_available_articles, count_to_select_phase2)
            selected_news.extend(additional_selection)
            print(f"Selected {count_to_select_phase2} additional articles.")
        else:
            print("No more available articles to select in Phase 2.")
    else:
        print("Target size reached or exceeded in Phase 1. No Phase 2 needed.")

    # --- Final Steps ---
    print("\n--- Finalizing Selection ---")
    # Shuffle first so that truncation below drops random articles.
    rng.shuffle(selected_news)

    # Truncate if slightly over target (e.g., if target=5000, target_per_category=1667, 3*1667=5001)
    if len(selected_news) > target_size:
        print(f"Selection slightly over target ({len(selected_news)}). Truncating to {target_size}.")
        selected_news = selected_news[:target_size]

    print(f"Final number of selected articles: {len(selected_news)}")

    # Simplify the selected news to only include id and text
    simplified_news = [
        {
            "id": article["id"],
            "text": article["processed_text_finbert"]
        }
        for article in selected_news
    ]

    # Final length distribution check
    final_categories = defaultdict(int)
    for article_simple in simplified_news:
        final_categories[get_text_length_category(article_simple["text"])] += 1

    print("\nFinal length distribution of selected articles:")
    for category, count in final_categories.items():
        print(f"- {category.capitalize()}: {count} articles")

    # Save selected news
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(simplified_news, f, indent=2, ensure_ascii=False)
    print(f"\nSaved {len(simplified_news)} selected articles to {output_file}")


In [None]:
# --- Configuration ---
# !! Adjust this path to point to your actual FinBERT processed news file !!
input_news_file = "../../../Data/Historical News/FinBERT_Data/news.finbert.json"

# Output path, resolved relative to the notebook's working directory.
# The selection is written under Data/Historical News/NER_Data/Unlabeled/.
output_news_file = "../../../Data/Historical News/NER_Data/Unlabeled/5000_news_for_NER.json"
target_news_count = 5000

# Fixed seed so that Restart Kernel -> Run All reproduces the same selection.
random_seed = 42

# --- Run Selection ---
# The selection draws from the global `random` module, so seed it right
# before the call to keep the result independent of earlier cells.
random.seed(random_seed)
select_balanced_news(input_news_file, output_news_file, target_news_count)
