In [1]:
"""
Unified Content Classification Script
Classifies labor and inflation content across all Fed communication sources
"""

# Colab-only: mount Google Drive at /content/drive, the root that BASE_DIR
# below points into. force_remount=True re-mounts even if already mounted.
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

import os
import pandas as pd
import numpy as np
import re
import json
import random
import time
from glob import glob

# Set random seed for reproducibility.
# A fixed value is used on purpose: seeding from the clock gave every run a
# different validation sample, so results could not be reproduced.
# Change this constant deliberately if a different sample is wanted.
seed = 42
random.seed(seed)
np.random.seed(seed)

# ============================================================================
# CONFIGURATION - SET WHICH SOURCES TO PROCESS
# ============================================================================

# One switch per key of SOURCE_CONFIGS: True runs process_source() for that
# source in the loop at the bottom of the script, False prints a "Skipping"
# message instead.
SOURCES_TO_PROCESS = {
    'transcripts': False,    # Set to False to skip
    'statements': False,
    'minutes': True,
    'speeches': False,
    'press_conferences': False
}

# ============================================================================
# PATHS CONFIGURATION
# ============================================================================

# Root of the project on the mounted Google Drive.
BASE_DIR = '/content/drive/MyDrive/FedComs'
# Folder holding the JSON keyword/pattern dictionaries loaded below.
DICT_DIR = f'{BASE_DIR}/Dictionaries'
# Folder that receives the sampled per-source validation CSVs.
VALIDATION_DIR = f'{BASE_DIR}/Validation_Sets'

# Per-source settings, keyed by the same names as SOURCES_TO_PROCESS.
#
# Keys read by the code in this cell:
#   process_type     'multiple_files' -> every *.csv in 'input_dir' is read;
#                    anything else    -> the single CSV at 'input_file' is read
#   output_dir/file  where the per-record summary CSV is written
#   validation_file  file name (inside VALIDATION_DIR) of the sentence sample
#   text_column      preferred text column; get_text_column() falls back to
#                    the first column whose name contains 'text'
#   date_column      column used to sort the summary before saving
#
# NOTE(review): 'id_column', 'name_column', 'role_column', 'speaker_column'
# and 'group_by' are not read by any code in this cell - presumably
# informational or used elsewhere; confirm before relying on them.
SOURCE_CONFIGS = {
    'transcripts': {
        'input_dir': f'{BASE_DIR}/Transcripts/final_transcripts',
        'output_dir': f'{BASE_DIR}/Transcripts',
        'output_file': 'transcripts_content.csv',
        'validation_file': 'transcripts_validate.csv',
        'text_column': 'Text',
        'id_column': 'id',
        'date_column': 'Date',
        'name_column': 'Name',
        'role_column': 'Role',
        'process_type': 'multiple_files',  # Multiple CSV files
        'group_by': 'row'  # Each row is a separate item
    },
    'statements': {
        'input_file': f'{BASE_DIR}/Statements/fomc_statements_cleaned.csv',
        'output_dir': f'{BASE_DIR}/Statements',
        'output_file': 'statement_content.csv',
        'validation_file': 'statement_validate.csv',
        'text_column': 'text',
        'id_column': 'id',
        'date_column': 'date',
        'process_type': 'single_file',
        'group_by': 'row'
    },
    'minutes': {
        'input_file': f'{BASE_DIR}/Minutes/fomc_minutes_cleaned.csv',
        'output_dir': f'{BASE_DIR}/Minutes',
        'output_file': 'minutes_content.csv',
        'validation_file': 'minutes_validate.csv',
        'text_column': 'text',
        'id_column': 'id',
        'date_column': 'date',
        'process_type': 'single_file',
        'group_by': 'row'
    },
    'speeches': {
        'input_dir': f'{BASE_DIR}/Speeches/fed_speeches_clean',
        'output_dir': f'{BASE_DIR}/Speeches',
        'output_file': 'speeches_content.csv',
        'validation_file': 'speeches_validate.csv',
        'text_column': 'text',  # Falls back to the first column containing 'text'
        'date_column': 'date',
        'name_column': 'official_name',  # NOTE(review): no lookup of this key is visible in this cell
        'process_type': 'multiple_files',
        'group_by': 'row'
    },
    'press_conferences': {
        'input_file': f'{BASE_DIR}/PressConf/fomc_press_conferences.csv',
        'output_dir': f'{BASE_DIR}/PressConf',
        'output_file': 'press_conferences_content.csv',
        'validation_file': 'press_conferences_validate.csv',
        'text_column': 'text',
        'id_column': 'id',
        'date_column': 'date',
        'speaker_column': 'speaker',
        'process_type': 'single_file',
        'group_by': 'row'
    }
}

# ============================================================================
# LOAD DICTIONARIES
# ============================================================================

print("\n" + "="*70)
print("LOADING DICTIONARIES")
print("="*70)


def _load_dictionary(filename):
    """Load one JSON dictionary file from DICT_DIR and return the parsed object.

    The file is opened as UTF-8 explicitly: without it, open() falls back to
    the platform's default encoding, which can differ between machines and
    corrupt non-ASCII characters inside the keyword/pattern lists.

    Raises:
        FileNotFoundError: if the file does not exist in DICT_DIR.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(os.path.join(DICT_DIR, filename), 'r', encoding='utf-8') as f:
        return json.load(f)


# indicator name -> list of keywords (iterated that way in classify_sentence)
LABOR_INDICATORS = _load_dictionary('labor_indicators.json')

# category -> {pattern name -> list of regex patterns}
INFLATION_INDICATORS = _load_dictionary('inflation_indicators.json')

# pattern name -> indicator label reported in the output
INFLATION_PATTERN_TO_INDICATOR = _load_dictionary('inflation_pattern_mapping.json')

print("Dictionaries loaded successfully!")
print(f"Labor indicators: {list(LABOR_INDICATORS.keys())}")
print(f"Inflation categories: {list(INFLATION_INDICATORS.keys())}")

# ============================================================================
# TEXT PROCESSING FUNCTIONS
# ============================================================================

def fix_text_encoding(text):
    """Normalize mis-encoded punctuation and drop control characters.

    Applies a fixed, ordered list of literal substitutions (mojibake
    sequences, curly quotes, ellipsis) and then removes C0/C1 control
    characters, keeping tab, newline and carriage return.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Order matters: longer mojibake sequences must be handled before the
    # shorter prefix they contain.
    # NOTE(review): the first two entries search for the same string as
    # written here, so the second can never fire - confirm the intended
    # em-dash sequence against real input.
    substitutions = (
        ('Ã¢â‚¬"', '–'),
        ('Ã¢â‚¬"', '—'),
        ('Ã¢â‚¬Å"', '"'),
        ('Ã¢â‚¬', '"'),
        ('\u2013', '–'),
        ('\u2014', '—'),
        ('\u2018', "'"),
        ('\u2019', "'"),
        ('\u201c', '"'),
        ('\u201d', '"'),
        ('\u2026', '...'),
    )
    for broken, replacement in substitutions:
        text = text.replace(broken, replacement)

    # Strip control characters except \t (\x09), \n (\x0a) and \r (\x0d).
    return re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]', '', text)

def split_into_sentences(text):
    """Split text into sentences, preserving initials and abbreviations.

    The text is first cleaned with fix_text_encoding(). Periods that do not
    end a sentence are then hidden behind placeholders (``<ABBR_n>`` for
    known abbreviations, ``<NAME>`` for the dots of personal initials,
    ``<DEC>`` for decimal points, ``<VOTE_n>`` for "Voting for/against ..."
    lists), the text is split after ``.``, ``!`` or ``?`` followed by
    whitespace and a capital letter, and the placeholders are restored.

    Args:
        text: The raw document text.

    Returns:
        A list of non-empty, stripped sentence strings.
    """
    text = fix_text_encoding(text)

    abbreviations = [
        r'\bU\.S\.A\.', r'\bU\.S\.', r'\bU\.K\.', r'\bE\.U\.',
        r'\bSt\.', r'\bMr\.', r'\bMrs\.', r'\bMs\.', r'\bDr\.',
        r'\bProf\.', r'\bSr\.', r'\bJr\.', r'\bvs\.', r'\betc\.',
        r'\bi\.e\.', r'\be\.g\.', r'\bVol\.', r'\bNo\.', r'\bpp\.',
        r'\bCo\.', r'\bInc\.', r'\bLtd\.', r'\bCorp\.',
        r'\bPh\.D\.', r'\bM\.A\.', r'\bM\.S\.', r'\bB\.A\.',
        r'\bD\.C\.', r'\bA\.M\.', r'\bP\.M\.',
        # Plural honorifics: without these a roll call such as
        # "Voting for this action: Messrs. A, B" is cut after the honorific.
        r'\bMessrs\.', r'\bMses\.'
    ]

    # Store the text that was actually matched, not the pattern, so the
    # original casing ("a.m.", "U.S.") survives the round trip even though
    # matching is case-insensitive.
    protected_abbreviations = []

    def store_abbreviation(match):
        protected_abbreviations.append(match.group(0))
        return f'<ABBR_{len(protected_abbreviations) - 1}>'

    for abbr in abbreviations:
        text = re.sub(abbr, store_abbreviation, text, flags=re.IGNORECASE)

    # Protect runs of initials before a capitalized name ("J. K. Rowling",
    # "J.K. Rowling"). The repeated part must be a group, not a character
    # class, otherwise only the last initial of a spaced run is protected.
    text = re.sub(
        r'\b[A-Z]\.(?:\s*[A-Z]\.)*(?=\s+[A-Z][a-z]+)',
        lambda m: m.group(0).replace('.', '<NAME>'),
        text,
    )
    text = re.sub(r'\b\d+\.\d+\b', lambda m: m.group(0).replace('.', '<DEC>'), text)

    voting_pattern = r'((?:Voting for|Voting against)\s+[^.!?]+?)([.!?]+\s+|$)'
    voting_matches = []

    def store_voting_match(match):
        voting_matches.append(match.group(1))
        # Keep the terminator (group 2): dropping it would delete the final
        # punctuation and glue the list onto whatever follows.
        return f'<VOTE_{len(voting_matches) - 1}>{match.group(2)}'

    text = re.sub(voting_pattern, store_voting_match, text)

    # '<' is accepted as a sentence start so that a sentence beginning with a
    # placeholder (e.g. "Mr. ..." or "Voting for ...") is still split off.
    sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z<]|$)', text)
    sentences = [s.strip() for s in sentences if s.strip()]

    def restorer(stored):
        """Build a re.sub callback mapping a numbered placeholder to its text."""
        def replace(match):
            index = int(match.group(1))
            # Leave look-alike text from the input itself untouched.
            return stored[index] if index < len(stored) else match.group(0)
        return replace

    restore_vote = restorer(voting_matches)
    restore_abbreviation = restorer(protected_abbreviations)

    restored_sentences = []
    for sentence in sentences:
        # Votes first: the stored vote text was captured after the other
        # substitutions and can itself contain <ABBR_n>/<NAME> placeholders.
        sentence = re.sub(r'<VOTE_(\d+)>', restore_vote, sentence)
        sentence = re.sub(r'<ABBR_(\d+)>', restore_abbreviation, sentence)
        sentence = sentence.replace('<NAME>', '.')
        sentence = sentence.replace('<DEC>', '.')
        restored_sentences.append(sentence)

    return restored_sentences

# ============================================================================
# CLASSIFICATION FUNCTIONS
# ============================================================================

def check_keywords_in_sentence(sentence, keywords):
    """Return True if any keyword occurs in the sentence as a whole word.

    Matching is case-insensitive (both sides are lower-cased) and each
    keyword is matched literally between word boundaries.
    """
    lowered = sentence.lower()
    return any(
        re.search(r'\b' + re.escape(term.lower()) + r'\b', lowered)
        for term in keywords
    )

def check_employment_indicator(sentence, keywords):
    """Whole-word keyword check for the Employment indicator.

    A sentence that mentions "maximum employment", "full employment" or
    "employment goal" is rejected outright, whatever else it contains.
    Otherwise returns True when any keyword appears as a whole word
    (case-insensitive).
    """
    lowered = sentence.lower()

    # Mandate language, not a labor-market reading: reject the sentence.
    excluded_phrases = (
        r'\b(?:maximum|full)\s+employment\b',
        r'\bemployment\s+goal\b',
    )
    if any(re.search(phrase, lowered) for phrase in excluded_phrases):
        return False

    return any(
        re.search(r'\b' + re.escape(term.lower()) + r'\b', lowered)
        for term in keywords
    )

def check_general_labor_term(sentence):
    """Return True if the sentence contains a "General Labor" keyword.

    Keywords come from LABOR_INDICATORS["General Labor"] (none if the key is
    absent) and are matched as whole words, case-insensitively.
    """
    lowered = sentence.lower()
    general_terms = LABOR_INDICATORS.get("General Labor", [])
    return any(
        re.search(r'\b' + re.escape(term.lower()) + r'\b', lowered)
        for term in general_terms
    )

def check_general_inflation_terms(sentence):
    """Return True if the sentence matches a general inflation pattern.

    Patterns are the regexes stored under
    INFLATION_INDICATORS["General Inflation"]["general_patterns"]; a missing
    key at either level means no patterns. Matching is case-insensitive.
    """
    lowered = sentence.lower()
    general_section = INFLATION_INDICATORS.get("General Inflation", {})
    return any(
        re.search(regex, lowered, re.IGNORECASE)
        for regex in general_section.get("general_patterns", [])
    )

def check_inflation_sentence(sentence):
    """Return the set of inflation indicator labels mentioned in the sentence.

    Every pattern group in INFLATION_INDICATORS is tested; a group counts
    once as soon as any of its regexes matches. The group name is mapped to
    its label through INFLATION_PATTERN_TO_INDICATOR ("Other" when unmapped).
    """
    lowered = sentence.lower()
    found = set()

    for subcategories in INFLATION_INDICATORS.values():
        for pattern_name, pattern_list in subcategories.items():
            if any(re.search(regex, lowered, re.IGNORECASE) for regex in pattern_list):
                found.add(INFLATION_PATTERN_TO_INDICATOR.get(pattern_name, "Other"))

    # A generic label is redundant once a specific variant is present.
    if found & {"Core_CPI", "Core_PCE"}:
        found.discard("Core")
    if found & {"Headline_CPI", "Headline_PCE"}:
        found.discard("Headline")

    return found

def classify_sentence(sentence):
    """Classify one sentence as "Labor", "Inflation", "Both" or "Neither".

    Returns:
        A dict with keys 'classification' (the label), 'labor_indicators'
        (list of specific labor indicators matched; "General Labor" is never
        listed) and 'inflation_indicators' (list of labels returned by
        check_inflation_sentence).
    """
    matched_labor = set()
    for name, terms in LABOR_INDICATORS.items():
        if name == "General Labor":
            # General terms only affect the label, not the indicator list.
            continue
        # "Employment" has its own matcher that rejects mandate language.
        matcher = check_employment_indicator if name == "Employment" else check_keywords_in_sentence
        if matcher(sentence, terms):
            matched_labor.add(name)

    general_labor_hit = check_general_labor_term(sentence)
    has_labor = bool(matched_labor) or general_labor_hit

    matched_inflation = check_inflation_sentence(sentence)
    general_inflation_hit = check_general_inflation_terms(sentence)
    has_inflation = bool(matched_inflation) or general_inflation_hit

    labels = {
        (True, True): "Both",
        (True, False): "Labor",
        (False, True): "Inflation",
        (False, False): "Neither",
    }

    return {
        'classification': labels[(has_labor, has_inflation)],
        'labor_indicators': list(matched_labor),
        'inflation_indicators': list(matched_inflation)
    }

def analyze_text(text):
    """Analyze a single text for labor and inflation content.

    The text is split with split_into_sentences() and each sentence is
    labelled with classify_sentence().

    Args:
        text: The document text.

    Returns:
        A ``(summary_results, sentence_data_list)`` tuple.

        ``summary_results`` is a dict holding, in this order: sentence counts
        ('sentences_on_labor', 'sentences_on_inflation', 'sentences_on_both',
        'total_sentences'), 'labor_share_of_labor_inflation_sentences', then
        per-indicator counts, emphasis shares (indicator count divided by all
        indicator mentions of that topic) and sentence shares (indicator
        count divided by total sentences). All ratios are 0 when their
        denominator is 0.

        ``sentence_data_list`` has one dict per sentence with
        'sentence_number' (1-based), 'sentence_text', 'classification',
        'labor_indicators' and 'inflation_indicators' (sorted, comma-joined).
    """
    def ratio(numerator, denominator):
        """Return numerator / denominator, or 0 for a zero denominator."""
        return numerator / denominator if denominator > 0 else 0

    sentences = split_into_sentences(text)
    total_sentences = len(sentences)

    labor_sentences = 0
    inflation_sentences = 0
    both_sentences = 0

    labor_indicator_counts = {indicator: 0 for indicator in LABOR_INDICATORS if indicator != "General Labor"}
    inflation_indicator_list = sorted({
        indicator for indicator in INFLATION_PATTERN_TO_INDICATOR.values()
        if indicator not in ["General_Inflation", "Other"]
    })
    inflation_indicator_counts = {indicator: 0 for indicator in inflation_indicator_list}

    sentence_data_list = []

    for sent_number, sentence in enumerate(sentences, start=1):
        classification_result = classify_sentence(sentence)
        classification = classification_result['classification']

        labor_indicators_filtered = [ind for ind in classification_result['labor_indicators']
                                      if ind != "General Labor"]
        inflation_indicators_filtered = [ind for ind in classification_result['inflation_indicators']
                                          if ind not in ["General_Inflation", "Other"]]

        sentence_data_list.append({
            'sentence_number': sent_number,
            'sentence_text': sentence,
            'classification': classification,
            'labor_indicators': ', '.join(sorted(labor_indicators_filtered)) if labor_indicators_filtered else '',
            'inflation_indicators': ', '.join(sorted(inflation_indicators_filtered)) if inflation_indicators_filtered else ''
        })

        # The label already encodes whether labor and/or inflation content
        # was found, so count from it instead of re-running the keyword scans
        # classify_sentence() has just performed.
        if classification == "Both":
            both_sentences += 1
            labor_sentences += 1
            inflation_sentences += 1
        elif classification == "Labor":
            labor_sentences += 1
        elif classification == "Inflation":
            inflation_sentences += 1

        # Labels without a counter (e.g. "Other") are deliberately ignored.
        for indicator in classification_result['labor_indicators']:
            if indicator in labor_indicator_counts:
                labor_indicator_counts[indicator] += 1

        for indicator in classification_result['inflation_indicators']:
            if indicator in inflation_indicator_counts:
                inflation_indicator_counts[indicator] += 1

    total_labor_mentions = sum(labor_indicator_counts.values())
    total_inflation_mentions = sum(inflation_indicator_counts.values())

    labor_emphasis = {
        f"labor_emphasis_{indicator}": ratio(count, total_labor_mentions)
        for indicator, count in labor_indicator_counts.items()
    }
    inflation_emphasis = {
        f"inflation_emphasis_{indicator}": ratio(count, total_inflation_mentions)
        for indicator, count in inflation_indicator_counts.items()
    }
    labor_sentence_share = {
        f"labor_share_total_sentences_{indicator}": ratio(count, total_sentences)
        for indicator, count in labor_indicator_counts.items()
    }
    inflation_sentence_share = {
        f"inflation_share_total_sentences_{indicator}": ratio(count, total_sentences)
        for indicator, count in inflation_indicator_counts.items()
    }

    # "Both" sentences are in both tallies; subtract once to get the union.
    labor_inflation_total = labor_sentences + inflation_sentences - both_sentences

    summary_results = {
        'sentences_on_labor': labor_sentences,
        'sentences_on_inflation': inflation_sentences,
        'sentences_on_both': both_sentences,
        'total_sentences': total_sentences,
        'labor_share_of_labor_inflation_sentences': ratio(labor_sentences, labor_inflation_total)
    }

    for indicator, count in labor_indicator_counts.items():
        summary_results[f'labor_{indicator}_count'] = count

    for indicator, count in inflation_indicator_counts.items():
        summary_results[f'inflation_{indicator}_count'] = count

    summary_results.update(labor_emphasis)
    summary_results.update(inflation_emphasis)
    summary_results.update(labor_sentence_share)
    summary_results.update(inflation_sentence_share)

    return summary_results, sentence_data_list

# ============================================================================
# HELPER FUNCTIONS FOR DATA PROCESSING
# ============================================================================

def find_column(df, possible_names):
    """Return the first column of *df* that matches one of the given names.

    Candidates are tried in order; for each one an exact match wins,
    otherwise the first column equal to it ignoring case. Returns None when
    no candidate matches.
    """
    for candidate in possible_names:
        if candidate in df.columns:
            return candidate
        # Case-insensitive fallback for this candidate before moving on.
        wanted = candidate.lower()
        for column in df.columns:
            if column.lower() == wanted:
                return column
    return None

def get_text_column(df, config):
    """Return the name of the column holding the text, or None.

    The configured 'text_column' is used when it exists in *df*; otherwise
    the first column whose lower-cased name contains 'text' is returned.
    """
    configured = config.get('text_column')
    if configured and configured in df.columns:
        return configured

    # Fallback: first column with 'text' anywhere in its name.
    return next((column for column in df.columns if 'text' in column.lower()), None)

def create_validation_set(sentences_df, source_name):
    """Draw a shuffled, stratified sample of sentences for manual validation.

    Up to 15 'Labor', 15 'Inflation', 5 'Both' and 10 'Neither' rows are
    taken from *sentences_df* (by its 'classification' column). A class with
    fewer rows than its quota contributes all of them. Sampling and the final
    shuffle use the module-level ``seed``.

    Args:
        sentences_df: DataFrame with a 'classification' column.
        source_name: Name of the source; not used by this function.

    Returns:
        The sampled DataFrame with a fresh index, or None if every class is
        empty.
    """
    quotas = (('Labor', 15), ('Inflation', 15), ('Both', 5), ('Neither', 10))

    picks = []
    for label, quota in quotas:
        pool = sentences_df[sentences_df['classification'] == label]
        if len(pool) >= quota:
            picks.append(pool.sample(n=quota, random_state=seed))
        elif len(pool) > 0:
            # Not enough rows to sample from: keep the whole class.
            picks.append(pool)

    if not picks:
        return None

    combined = pd.concat(picks, ignore_index=True)
    # Shuffle so the classes are interleaved in the output file.
    return combined.sample(frac=1, random_state=seed).reset_index(drop=True)

# ============================================================================
# MAIN PROCESSING FUNCTION
# ============================================================================

def _row_metadata(row, columns, text_col):
    """Return the non-text columns of *row* as strings ('' for missing values)."""
    return {
        col: str(row[col]) if pd.notna(row[col]) else ''
        for col in columns
        if col != text_col
    }


def _analyze_rows(df, text_col, results_list, all_sentences, progress_every=None):
    """Analyze every non-empty text in *df*, appending to the two output lists.

    One summary dict per row goes to *results_list* and one dict per sentence
    to *all_sentences*; both carry the row's metadata columns. When
    *progress_every* is set, a progress line is printed for each row whose
    index is a multiple of it.
    """
    for idx, row in df.iterrows():
        if progress_every and idx % progress_every == 0:
            print(f"Processing record {idx+1}/{len(df)}")

        text = str(row[text_col]) if pd.notna(row[text_col]) else ''
        if not text.strip():
            continue

        summary_results, sentence_data_list = analyze_text(text)
        metadata = _row_metadata(row, df.columns, text_col)

        summary_results.update(metadata)
        results_list.append(summary_results)

        for sentence_data in sentence_data_list:
            sentence_data.update(metadata)
            all_sentences.append(sentence_data)


def _sort_results_by_date(results_df, date_col):
    """Return *results_df* ordered by *date_col*, chronologically if possible.

    The column holds stringified values, so a plain sort is lexicographic and
    misorders formats such as "March 1, 2020". If every value parses as a
    date the rows are ordered by the parsed dates; otherwise the plain string
    sort is used. If that also fails, the frame is returned unsorted with a
    warning rather than failing silently.
    """
    try:
        parsed = pd.to_datetime(results_df[date_col], errors='coerce')
    except (ValueError, TypeError, OverflowError):
        parsed = None

    if parsed is not None and parsed.notna().all():
        return results_df.loc[parsed.sort_values(kind='stable').index]

    try:
        return results_df.sort_values(date_col)
    except TypeError as e:
        print(f"  Warning: could not sort by {date_col}: {e}")
        return results_df


def process_source(source_name, config):
    """Process a single source.

    Reads the source's CSV input (every ``*.csv`` in config['input_dir'] when
    config['process_type'] is 'multiple_files', otherwise the single file
    config['input_file']), analyzes each row's text with analyze_text(), and
    writes two files: the per-record summary to
    config['output_dir']/config['output_file'] and a sampled validation set
    to VALIDATION_DIR/config['validation_file'].

    Errors are reported on stdout, not raised: a failing file is skipped in
    multi-file mode, and in single-file mode the function returns without
    writing anything.

    Args:
        source_name: Key of the source, used for log output.
        config: The source's entry from SOURCE_CONFIGS.

    Returns:
        None.
    """
    print("\n" + "="*70)
    print(f"PROCESSING: {source_name.upper()}")
    print("="*70)

    results_list = []
    all_sentences = []

    # Handle multiple files vs single file
    if config['process_type'] == 'multiple_files':
        input_dir = config['input_dir']
        # glob() returns files in arbitrary order; sort for repeatable runs.
        csv_files = sorted(glob(os.path.join(input_dir, '*.csv')))
        print(f"Found {len(csv_files)} files in {input_dir}")

        for idx, csv_file in enumerate(csv_files):
            filename = os.path.basename(csv_file)
            if idx % 5 == 0:
                print(f"Processing file {idx+1}/{len(csv_files)}: {filename}")

            # Best effort per file: one bad file must not stop the batch.
            try:
                df = pd.read_csv(csv_file, encoding='utf-8', encoding_errors='replace')
                text_col = get_text_column(df, config)

                if text_col is None:
                    print(f"  Warning: No text column found in {filename}")
                    continue

                _analyze_rows(df, text_col, results_list, all_sentences)

            except Exception as e:
                print(f"  Error processing {filename}: {e}")
                continue

    else:  # single_file
        input_file = config['input_file']
        print(f"Reading from: {input_file}")

        try:
            df = pd.read_csv(input_file, encoding='utf-8', encoding_errors='replace')
            print(f"Loaded {len(df)} records")

            text_col = get_text_column(df, config)
            if text_col is None:
                print(f"ERROR: No text column found")
                return

            _analyze_rows(df, text_col, results_list, all_sentences, progress_every=10)

        except Exception as e:
            print(f"Error reading file: {e}")
            return

    # Save results
    results_df = pd.DataFrame(results_list)

    if len(results_df) > 0:
        date_col = config.get('date_column')
        if date_col and date_col in results_df.columns:
            results_df = _sort_results_by_date(results_df, date_col)

        # to_csv() does not create missing folders.
        os.makedirs(config['output_dir'], exist_ok=True)
        output_file = os.path.join(config['output_dir'], config['output_file'])
        results_df.to_csv(output_file, index=False)
        print(f"\n✓ Summary saved to: {output_file}")
        print(f"  Shape: {results_df.shape}")

    # Create validation set
    sentences_df = pd.DataFrame(all_sentences)

    if len(sentences_df) > 0:
        print(f"\n✓ Total sentences: {len(sentences_df)}")
        print(f"  Classification distribution:")
        print(sentences_df['classification'].value_counts().to_string())

        validation_df = create_validation_set(sentences_df, source_name)

        if validation_df is not None:
            os.makedirs(VALIDATION_DIR, exist_ok=True)
            validation_file = os.path.join(VALIDATION_DIR, config['validation_file'])
            validation_df.to_csv(validation_file, index=False)
            print(f"\n✓ Validation set saved to: {validation_file}")
            print(f"  Size: {len(validation_df)}")

# ============================================================================
# RUN PROCESSING
# ============================================================================

print("\n" + "=" * 70)
print("STARTING UNIFIED CLASSIFICATION")
print("=" * 70)

# Walk the switches in declaration order; disabled sources are only reported.
for source_name, should_process in SOURCES_TO_PROCESS.items():
    if not should_process:
        print(f"\nSkipping {source_name} (disabled in config)")
        continue
    config = SOURCE_CONFIGS[source_name]
    process_source(source_name, config)

print("\n" + "=" * 70)
print("ALL PROCESSING COMPLETE!")
print("=" * 70)

Mounted at /content/drive

LOADING DICTIONARIES
Dictionaries loaded successfully!
Labor indicators: ['General Labor', 'Employment', 'Unemployment', 'Participation', 'Wages', 'Vacancies', 'Quits', 'Layoffs', 'Hiring']
Inflation categories: ['General Inflation', 'Core Measures', 'Headline Measures', 'Sectoral Measures', 'Producer Price Index', 'Wage Inflation', 'Inflation Expectations', 'Commodity Prices']

STARTING UNIFIED CLASSIFICATION

Skipping transcripts (disabled in config)

Skipping statements (disabled in config)

PROCESSING: MINUTES
Reading from: /content/drive/MyDrive/FedComs/Minutes/fomc_minutes_cleaned.csv
Loaded 199 records
Processing record 1/199
Processing record 11/199
Processing record 21/199
Processing record 31/199
Processing record 41/199
Processing record 51/199
Processing record 61/199
Processing record 71/199
Processing record 81/199
Processing record 91/199
Processing record 101/199
Processing record 111/199
Processing record 121/199
Processing record 131/199
Pro