In [None]:
import os
import json
import random
from collections import defaultdict
from tqdm.notebook import tqdm # Use notebook version

In [None]:
def _filter_valid_posts(posts, subreddit, min_words, max_words):
    """Return the posts from one file that are usable for the selection.

    A post is kept when it is a dict with an ``'id'`` key and a string
    ``'processed_text_finbert'`` value whose whitespace-separated word count is
    within ``[min_words, max_words]`` (inclusive). Kept posts that lack a
    ``'subreddit'`` key get one filled in from ``subreddit`` (derived from the
    filename); note this mutates the post dict in place.
    """
    valid_posts = []
    for post in posts:
        if not (isinstance(post, dict) and 'id' in post):
            continue
        text = post.get('processed_text_finbert')
        # Missing, null or non-string text would crash `.split()`; skip it.
        if not isinstance(text, str):
            continue
        if min_words <= len(text.split()) <= max_words:
            post.setdefault('subreddit', subreddit)
            valid_posts.append(post)
    return valid_posts


def _load_posts_by_subreddit(input_dir, filenames, min_words, max_words):
    """Read every ``finbert_r_<subreddit>.json`` file and filter its posts.

    Returns ``(subreddit_posts, files_processed)``. ``subreddit_posts`` maps each
    subreddit name (taken from the filename) to its list of valid posts, and
    ``files_processed`` counts the filenames that matched the naming pattern.
    Unreadable or malformed files are reported and skipped.
    """
    prefix, suffix = 'finbert_r_', '.json'
    subreddit_posts = defaultdict(list)
    files_processed = 0
    for filename in tqdm(filenames):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        files_processed += 1
        subreddit = filename[len(prefix):-len(suffix)]
        file_path = os.path.join(input_dir, filename)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                posts = json.load(f)
        except FileNotFoundError:
            print(f"\nWARNING: File not found {file_path} (skipped)")
            continue
        except json.JSONDecodeError:
            print(f"\nWARNING: Could not decode JSON from {file_path} (skipped)")
            continue
        except Exception as e:
            # Best-effort: one bad file should not abort the whole selection.
            print(f"\nERROR: Reading file {file_path}: {e} (skipped)")
            continue

        # Each file should hold a JSON list of posts. A dict would otherwise be
        # iterated by its keys, and a scalar would raise TypeError.
        if not isinstance(posts, list):
            print(f"\nWARNING: Expected a JSON list of posts in {file_path}, "
                  f"got {type(posts).__name__} (skipped)")
            continue

        subreddit_posts[subreddit].extend(
            _filter_valid_posts(posts, subreddit, min_words, max_words))
    return subreddit_posts, files_processed


def _sample_per_subreddit(subreddit_posts, target_size_per_subreddit, rng):
    """Randomly pick up to ``target_size_per_subreddit`` posts per subreddit.

    ``rng`` is anything exposing ``sample`` (the ``random`` module or a
    ``random.Random`` instance). Prints per-subreddit availability and returns
    ``(selected_posts, total_available)``.
    """
    selected_posts = []
    total_available = 0
    for subreddit, posts in subreddit_posts.items():
        total_available += len(posts)
        print(f"Subreddit r/{subreddit}: Found {len(posts)} valid posts.")
        count_to_select = min(len(posts), target_size_per_subreddit)

        if count_to_select > 0:
            selected = rng.sample(posts, count_to_select)
            print(f" -> Selecting {len(selected)} posts.")
            selected_posts.extend(selected)
        else:
            print(" -> Not enough valid posts to select.")
    return selected_posts, total_available


def select_reddit_posts(input_dir, output_file, target_size_per_subreddit=850,
                        min_words=50, max_words=500, seed=None):
    """
    Load FinBERT-processed Reddit posts, filter them by length, and save a
    per-subreddit balanced random subset as a JSON list.

    Only files named ``finbert_r_<subreddit>.json`` directly inside
    ``input_dir`` are read, in sorted filename order. Posts are filtered by
    ``_filter_valid_posts``. Up to ``target_size_per_subreddit`` posts are
    sampled from each subreddit, the combined selection is shuffled, and it is
    written to ``output_file`` as ``[{"id": ..., "text": ...}, ...]``.

    Args:
        input_dir: Directory containing the ``finbert_r_*.json`` files.
        output_file: Path of the JSON file to write. Its parent directories are
            created if needed.
        target_size_per_subreddit: Maximum number of posts sampled per subreddit.
        min_words: Minimum word count (inclusive) for a post to be kept.
        max_words: Maximum word count (inclusive) for a post to be kept.
        seed: Optional seed for a private ``random.Random``, which makes the
            selection reproducible without touching global RNG state. When
            ``None`` (the default), the module-level ``random`` functions are
            used, so a prior ``random.seed(...)`` still applies.

    Returns:
        None. Progress and errors are reported with ``print``. If the input
        directory is missing or unreadable, or no valid posts are found, the
        function returns early without writing ``output_file``.
    """
    rng = random if seed is None else random.Random(seed)

    print(f"Loading posts from directory: {input_dir}")
    try:
        if not os.path.isdir(input_dir):
            print(f"ERROR: Input directory not found: {input_dir}")
            return
        # Sorted so the processing order, and therefore a seeded selection,
        # does not depend on the filesystem's arbitrary listdir() order.
        filenames = sorted(os.listdir(input_dir))
    except Exception as e:
        print(f"ERROR: Could not list files in input directory {input_dir}: {e}")
        return

    # --- Load and Filter Posts ---
    print("Processing files...")
    subreddit_posts, files_processed = _load_posts_by_subreddit(
        input_dir, filenames, min_words, max_words)
    print(f"\nProcessed {files_processed} potential subreddit files.")
    # Every matching file creates an entry, even when all of its posts were
    # filtered out, so check the lists' contents rather than the dict itself.
    if not any(subreddit_posts.values()):
        print("ERROR: No valid posts found in any subreddit file. Cannot proceed.")
        return

    # --- Select Balanced Subset ---
    print(f"\n--- Selecting Posts (Target per subreddit: {target_size_per_subreddit}) ---")
    selected_posts, total_available = _sample_per_subreddit(
        subreddit_posts, target_size_per_subreddit, rng)
    print(f"\nTotal valid posts available across subreddits: {total_available}")

    # --- Final Steps ---
    print("\n--- Finalizing Selection ---")
    # Shuffle so the saved list is not grouped by subreddit.
    rng.shuffle(selected_posts)
    print(f"Total posts selected across all subreddits: {len(selected_posts)}")

    # Keep only id and text, like the news selection script's output.
    simplified_posts = [
        {"id": post["id"], "text": post["processed_text_finbert"]}
        for post in selected_posts
    ]
    print(f"Saving {len(simplified_posts)} posts with 'id' and 'text' fields.")

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(simplified_posts, f, indent=2, ensure_ascii=False)
    print(f"\nSaved selected posts to {output_file}")

    # Count from selected_posts because the saved records drop 'subreddit'.
    print("\nFinal Posts per subreddit in selection:")
    final_subreddit_counts = defaultdict(int)
    for post in selected_posts:
        final_subreddit_counts[post.get('subreddit', 'unknown')] += 1
    for subreddit, count in final_subreddit_counts.items():
        print(f"- r/{subreddit}: {count} posts")

In [None]:
# --- Configuration ---
# Relative paths resolve against the notebook's working directory
# (by default, the folder containing this notebook).
# !! Adjust this path to point to your FinBERT processed Reddit data directory !!
input_reddit_dir = "../../../Data/Historical Reddit/FinBERT_Data/"

# The selection is written into Data/Historical Reddit/NER_Data/Unlabeled/.
# Missing directories are created by select_reddit_posts.
output_reddit_file = "../../../Data/Historical Reddit/NER_Data/Unlabeled/5000_reddit_posts_for_NER.json"

target_posts_per_sub = 850  # Maximum number of posts sampled per subreddit
min_word_count = 50         # Inclusive lower bound on post length, in words
max_word_count = 500        # Inclusive upper bound on post length, in words
random_seed = 42            # Fixes the random sample so the NER set can be regenerated

# --- Run Selection ---
# select_reddit_posts samples and shuffles with the global `random` RNG,
# so seeding it here makes the selection reproducible across re-runs.
random.seed(random_seed)
select_reddit_posts(input_reddit_dir, output_reddit_file, target_posts_per_sub, min_word_count, max_word_count)