<a href="https://colab.research.google.com/github/RyuichiSaito1/inflation-reddit-usa/blob/main/notebooks/gemni_2_0_flash_lite_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the input TSV and the results file (paths under
# /content/drive/MyDrive, see main) are reachable from this Colab session.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install google-genai pandas gspread google-auth

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime
import time
import re
import os
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

from google.colab import auth
from google import genai
from google.genai.types import HttpOptions
import gspread
from google.auth import default

In [None]:
# Authenticate with Google Cloud (uses the signed-in Colab user's credentials)
auth.authenticate_user()

# Initialize Vertex AI client
# NOTE: replace the "#####" placeholders with the GCP project ID and the
# Vertex AI location that host the tuning job below.
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,  # use the Vertex AI backend with the project/location above
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1")
)

# Get the tuned model
# Look up the tuning job by resource name and keep its tuned model's endpoint;
# RedditInflationClassifier passes this value as `model=` to generate_content.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)
model_endpoint = tuning_job.tuned_model.endpoint

print(f"Successfully connected to model: {model_endpoint}")

In [None]:
class RedditInflationClassifier:
    """Label Reddit posts by perceived inflation using a tuned Gemini model.

    Labels follow the prompt below: 0 = deflation, 1 = neither, 2 = inflation.

    Args:
        model_endpoint: Value passed as ``model=`` to
            ``client.models.generate_content``.
        client: A ``google.genai`` client, or any object exposing
            ``models.generate_content(model=..., contents=...)``.
    """

    # Tried in order against the raw model output. The first pattern that
    # matches anywhere wins, and its LAST match is used as the label.
    _LABEL_PATTERNS = (
        re.compile(r'Classification.*?([012])'),
        re.compile(r'Answer.*?([012])'),
        re.compile(r'Response.*?([012])'),
        re.compile(r'\b([012])\b'),
    )

    def __init__(self, model_endpoint, client):
        self.model_endpoint = model_endpoint
        self.client = client
        # NOTE(review): the model was presumably tuned on exactly this prompt;
        # keep its wording/whitespace unchanged unless the model is re-tuned.
        self.prompt_template = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:

    0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.

    2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.

    1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.

    Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.

Reddit Post:
'{text}'

Classification (0, 1, or 2):"""

    def create_prompt(self, text):
        """Return the classification prompt with ``text`` substituted in."""
        return self.prompt_template.format(text=text)

    def extract_classification(self, response_text):
        """Extract the 0/1/2 label from the model's raw output.

        Args:
            response_text: Text returned by the model. May be None or empty
                (e.g. a response with no text part).

        Returns:
            int: 0, 1 or 2. Falls back to 1 (neutral) when no label is found.
        """
        if not response_text:
            return 1  # Nothing to parse (None/empty) -> neutral, instead of a TypeError

        for pattern in self._LABEL_PATTERNS:
            matches = pattern.findall(response_text)
            if matches:
                return int(matches[-1])

        # No delimited label: take the first 0/1/2 character anywhere
        # (catches outputs such as "x20y" where \b does not match).
        for char in response_text:
            if char in '012':
                return int(char)

        return 1  # Default to neutral if unable to extract

    def classify_text(self, text, max_retries=3, retry_delay=1.0):
        """Classify a single post, retrying transient API failures.

        Previously any exception (e.g. a 429 rate-limit) was immediately
        recorded as label 1, silently biasing monthly averages toward
        "neither". Failed calls are now retried with exponential backoff first.

        Args:
            text: Post body to classify.
            max_retries: Extra attempts after the first failed call
                (negative values are treated as 0).
            retry_delay: Seconds to wait before the first retry; the wait
                doubles on each subsequent retry.

        Returns:
            int: 0, 1 or 2. If every attempt fails, 1 (neutral) is returned
            and the last error is printed, so callers always get a label.
        """
        prompt = self.create_prompt(text)
        attempts = max(0, max_retries) + 1
        for attempt in range(attempts):
            try:
                response = self.client.models.generate_content(
                    model=self.model_endpoint,
                    contents=prompt,
                )
                return self.extract_classification(response.text)
            except Exception as e:
                if attempt + 1 < attempts:
                    wait = retry_delay * (2 ** attempt)
                    print(f"Classification attempt {attempt + 1}/{attempts} failed ({e}); retrying in {wait:.1f}s")
                    time.sleep(wait)
                    continue
                print(f"Error in classification: {e}")
        return 1  # Default to neutral once all attempts have failed

    def classify_batch(self, texts, delay=0.1):
        """Classify texts one by one, sleeping ``delay`` seconds after each call.

        Returns:
            list[int]: One label per input text, in input order.
        """
        results = []
        for i, text in enumerate(texts):
            if i % 50 == 0:
                print(f"Processing {i+1}/{len(texts)} texts...")

            results.append(self.classify_text(text))
            time.sleep(delay)  # Simple client-side rate limiting

        return results

In [None]:
def _to_filter_bound(value, dates):
    """Parse a date-filter bound, adopting the timezone of ``dates`` if it has one."""
    bound = pd.to_datetime(value)
    tz = dates.dt.tz
    if tz is not None and bound.tzinfo is None:
        # A naive bound cannot be compared with a tz-aware column (TypeError).
        bound = bound.tz_localize(tz)
    return bound


def load_and_prepare_data(file_path, start_date=None, end_date=None):
    """Load the Reddit TSV, optionally filter by date, and add ``year_month``.

    Args:
        file_path: Path to a tab-separated UTF-8 file with a ``created_date`` column.
        start_date: Optional inclusive lower bound (anything ``pd.to_datetime``
            accepts). Falsy values disable the bound.
        end_date: Optional upper bound. A date-only value (time 00:00:00), such
            as "2022-12-31", includes that whole day. A value with a time of
            day is an inclusive instant. Falsy values disable the bound.

    Returns:
        pandas.DataFrame: The (filtered) rows with ``created_date`` parsed to
        datetimes and a monthly ``year_month`` Period column added.
    """
    print(f"Loading data from: {file_path}")

    # Read TSV file
    df = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    print(f"Loaded {len(df)} total records")
    print(f"Columns: {df.columns.tolist()}")

    df['created_date'] = pd.to_datetime(df['created_date'])
    print(f"Original date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Debug: show some sample dates before filtering
    print("Sample dates from original data:")
    for i, date in enumerate(df['created_date'].head(10).tolist(), start=1):
        print(f"  {i}: {date}")

    if start_date or end_date:
        created = df['created_date']
        start_ts = _to_filter_bound(start_date, created) if start_date else None
        end_ts = _to_filter_bound(end_date, created) if end_date else None

        mask = pd.Series(True, index=df.index)
        conditions = []
        if start_ts is not None:
            mask &= created >= start_ts
            conditions.append(f"created_date >= {start_ts}")
        if end_ts is not None:
            if end_ts == end_ts.normalize():
                # Date-only end bound: cover the whole day. `<= end_ts` (midnight)
                # would silently drop every post after 00:00:00 on the last day.
                end_exclusive = end_ts + pd.Timedelta(days=1)
                mask &= created < end_exclusive
                conditions.append(f"created_date < {end_exclusive}")
            else:
                mask &= created <= end_ts
                conditions.append(f"created_date <= {end_ts}")

        print(f"Applying filter: {' and '.join(conditions)}")
        df = df[mask].copy()  # copy so adding year_month below is not a chained assignment

        if start_ts is not None and end_ts is not None:
            print(f"Filtered to date range {start_ts.strftime('%Y-%m-%d')} to {end_ts.strftime('%Y-%m-%d')}: {len(df)} records")
        elif start_ts is not None:
            print(f"Filtered from {start_ts.strftime('%Y-%m-%d')}: {len(df)} records")
        else:
            print(f"Filtered to {end_ts.strftime('%Y-%m-%d')}: {len(df)} records")

    # Add year_month column for grouping
    df['year_month'] = df['created_date'].dt.to_period('M')

    print(f"Final date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Debug: show the unique year_month values to verify filtering
    unique_months = sorted(df['year_month'].unique())
    print(f"Unique year_month values after filtering: {unique_months}")

    return df

def sample_monthly_data(df, max_samples_per_month=800, random_state=42):
    """Cap each ``year_month`` group at ``max_samples_per_month`` rows.

    Months with at most the cap keep all their rows unchanged. Larger months
    are randomly down-sampled without replacement.

    Args:
        df: DataFrame with a ``year_month`` column (as added by
            load_and_prepare_data).
        max_samples_per_month: Upper bound on rows kept per month.
        random_state: Seed passed to ``DataFrame.sample`` for reproducible draws.

    Returns:
        pandas.DataFrame: Kept rows, month by month, with a fresh RangeIndex.
        An empty frame with the same columns is returned when there is
        nothing to sample.
    """
    sampled_dfs = []

    for year_month, group in df.groupby('year_month'):
        if len(group) > max_samples_per_month:
            sampled_group = group.sample(n=max_samples_per_month, random_state=random_state)
        else:
            sampled_group = group

        sampled_dfs.append(sampled_group)
        print(f"{year_month}: sampled {len(sampled_group)} out of {len(group)} records")

    if sampled_dfs:
        result_df = pd.concat(sampled_dfs, ignore_index=True)
    else:
        # pd.concat([]) raises ValueError; return an empty frame with the same columns.
        result_df = df.iloc[0:0].reset_index(drop=True)
    print(f"Total sampled records: {len(result_df)}")

    return result_df

def _connect_worksheet(spreadsheet_name, worksheet_name):
    """Open an existing Google Sheets worksheet, writing the header row if it is empty.

    Returns the gspread worksheet, or None if authentication or lookup fails.
    The failure is printed, and processing continues without Sheets output.
    """
    try:
        # Use Google Colab authentication
        from google.colab import auth
        auth.authenticate_user()

        # Get default credentials
        from google.auth import default
        creds, _ = default()

        gc = gspread.authorize(creds)
        spreadsheet = gc.open(spreadsheet_name)
        worksheet = spreadsheet.worksheet(worksheet_name)

        # Add headers only if the worksheet has no data yet
        if not worksheet.get_all_values():
            headers = ['year_month', 'deflation', 'neither', 'inflation', 'total_number', 'average_score']
            worksheet.append_row(headers)

        print(f"Connected to existing spreadsheet: {spreadsheet_name} - {worksheet_name}")
        return worksheet
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Spreadsheet '{spreadsheet_name}' not found. Please check the name.")
    except gspread.exceptions.WorksheetNotFound:
        print(f"Error: Worksheet '{worksheet_name}' not found. Please check the name.")
    except Exception as e:
        print(f"Error connecting to Google Sheet: {type(e).__name__}: {str(e)}")
    return None


def _append_month_tsv(month_df, output_path):
    """Append one month of labelled rows (minus ``year_month``) to the results TSV."""
    month_df_tsv = month_df.drop('year_month', axis=1)
    # Also write the header when the file exists but is empty (e.g. left behind
    # by an interrupted run); otherwise the TSV would have no column names.
    write_header = not os.path.exists(output_path) or os.path.getsize(output_path) == 0
    month_df_tsv.to_csv(
        output_path,
        sep='\t',
        mode='w' if write_header else 'a',
        header=write_header,
        index=False,
    )


def _summarize_labels(labels):
    """Return (deflation, neither, inflation, total, average_score) for 0/1/2 labels."""
    deflation_count = labels.count(0)
    neither_count = labels.count(1)
    inflation_count = labels.count(2)
    total_count = len(labels)
    # Mean label: 0 = deflation, 1 = neither, 2 = inflation. Guarded against an empty list.
    avg_score = (neither_count + 2 * inflation_count) / total_count if total_count else 0.0
    return deflation_count, neither_count, inflation_count, total_count, avg_score


def process_and_classify_data(df, classifier, output_path, spreadsheet_name, worksheet_name):
    """Classify posts month by month, saving rows to TSV and summaries to Sheets.

    For each ``year_month`` group (in chronological order), the ``body`` texts
    are labelled with ``classifier.classify_batch``. The rows plus an
    ``inflation`` label column are appended to ``output_path``. The monthly
    counts and mean label are printed and appended as a row to the worksheet,
    if the connection succeeded.

    Args:
        df: DataFrame with ``created_date``, ``year_month`` and ``body`` columns.
        classifier: Object with ``classify_batch(list[str]) -> list[int]``.
        output_path: TSV path. Rows are appended; re-running appends again.
        spreadsheet_name: Name of an existing Google Spreadsheet.
        worksheet_name: Name of an existing worksheet inside it.
    """
    df = df.sort_values('created_date')

    worksheet = _connect_worksheet(spreadsheet_name, worksheet_name)

    for year_month, month_group in df.groupby('year_month'):
        print(f"\nProcessing {year_month}...")

        # Work on a copy so the caller's frame is not modified
        month_df = month_group.copy()

        texts = month_df['body'].fillna('').astype(str).tolist()
        classifications = list(classifier.classify_batch(texts))
        month_df['inflation'] = classifications

        _append_month_tsv(month_df, output_path)

        # Plain Python ints/floats: gspread JSON-serializes the row
        deflation_count, neither_count, inflation_count, total_count, avg_score = _summarize_labels(classifications)

        print(f"Month: {year_month}")
        print(f"  Deflation (0): {deflation_count}")
        print(f"  Neither (1): {neither_count}")
        print(f"  Inflation (2): {inflation_count}")
        print(f"  Total: {total_count}")
        print(f"  Average Score: {avg_score:.4f}")

        if worksheet is not None:
            try:
                row_data = [str(year_month), deflation_count, neither_count, inflation_count, total_count, round(avg_score, 4)]
                worksheet.append_row(row_data)
                print(f"  Added to Google Sheet: {spreadsheet_name}")
            except Exception as e:
                print(f"  Error updating Google Sheet: {e}")

        print(f"Saved {len(month_df)} records for {year_month}")

    print(f"\nAll data processed and saved to: {output_path}")

In [None]:
def main():
    """Run the full pipeline for one subreddit dump.

    Load and date-filter the TSV, cap each month's sample, classify every
    sampled post, and write per-post results plus monthly summaries. Values
    marked with ☆ are the per-run settings to edit.
    """
    # ☆ Input dump and per-post results file
    input_file = "/content/drive/MyDrive/world-inflation/data/reddit/production/Frugal_submissions_2012_2022.tsv"
    output_file = "/content/drive/MyDrive/world-inflation/result/tsv/Frugal_submissions_results.tsv"

    # ☆ Google Sheets destination for monthly summaries - MODIFY THESE VALUES
    spreadsheet_name, worksheet_name = "monthly_classification_result", "subfrugal"

    # ☆ Date window forwarded to load_and_prepare_data
    start_date, end_date = "2012-01-01", "2022-12-31"

    # ☆ Maximum number of posts classified per month
    samples_per_month = 200

    results_dir = os.path.dirname(output_file)
    os.makedirs(results_dir, exist_ok=True)

    classifier = RedditInflationClassifier(model_endpoint, client)

    filtered_df = load_and_prepare_data(input_file, start_date, end_date)
    sampled_df = sample_monthly_data(filtered_df, max_samples_per_month=samples_per_month)

    print(f"Using Google Spreadsheet: {spreadsheet_name}")
    print(f"Using Worksheet: {worksheet_name}")
    if start_date:
        range_end = end_date if end_date else 'latest'
        print(f"Date range: {start_date} to {range_end}")

    # Classify month by month; each month's summary goes to Sheets as it finishes
    process_and_classify_data(sampled_df, classifier, output_file, spreadsheet_name, worksheet_name)

    rule = "=" * 60
    print("\n" + rule)
    print("PROCESS COMPLETED SUCCESSFULLY")
    print(rule)
    print(f"Results saved to: {output_file}")
    print(f"Monthly summaries updated in Google Sheets: {spreadsheet_name} - {worksheet_name}")

# Execute the main function (also true when run as a Colab/Jupyter cell)
if __name__ == "__main__":
    main()
