<a href="https://colab.research.google.com/github/RyuichiSaito1/multilingual-economic-narratives/blob/main/notebooks/classification_test_using_gemini_2_0_flash_lite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the input and output TSV paths used
# by main() live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install required packages
!pip install google-cloud-aiplatform pandas gspread google-auth

import pandas as pd
import numpy as np
from datetime import datetime
import time
import re
import os
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

from google.colab import auth
from google.cloud import aiplatform
import vertexai
from vertexai.generative_models import GenerativeModel
import gspread
from google.auth import default

In [None]:
# Google Cloud Authentication Setup: authorize this Colab session's user
# before any Vertex AI or Google Sheets call is made.
print("Authenticating with Google Cloud...")
auth.authenticate_user()
print("Authentication completed!")

# Vertex AI target. The "###" values appear to be redacted placeholders that
# must be replaced with a real project ID and region before running.
PROJECT_ID = "###"
LOCATION = "###"

print("Initializing Vertex AI...")
try:
    vertexai.init(project=PROJECT_ID, location=LOCATION)
except Exception as init_error:
    # Report, then re-raise: nothing downstream can run without Vertex AI.
    print(f"Error initializing Vertex AI: {init_error}")
    raise
else:
    print("Vertex AI initialized successfully!")

In [None]:
# Initialize the Gemini model. If the primary model cannot be created or does
# not answer a test prompt, try the alternatives in order. `model` is bound
# only after a candidate answers its test prompt, so a failing candidate can
# never be left in `model` and mistaken for a working connection.
print("Initializing Gemini model...")
try:
    model = GenerativeModel("gemini-2.0-flash-lite")
    print("Model initialized successfully!")

    # Test the model connection
    print("Testing model connection...")
    test_response = model.generate_content("Hello, respond with 'OK'")
    print(f"Model test response: {test_response.text}")

except Exception as e:
    print(f"Error initializing model: {e}")
    # Try alternative model names
    alternative_models = [
        "gemini-2.0-flash-exp",
        "gemini-1.5-flash",
        "gemini-1.5-pro"
    ]

    model = None
    for model_name in alternative_models:
        try:
            print(f"Trying {model_name}...")
            candidate = GenerativeModel(model_name)
            test_response = candidate.generate_content("Hello, respond with 'OK'")
        except Exception as model_error:
            # Surface why this candidate failed instead of discarding it.
            print(f"  {model_name} unavailable: {model_error}")
            continue
        # Reached only when the test call succeeded.
        model = candidate
        print(f"Successfully connected to {model_name}")
        break

    if model is None:
        raise Exception("No Gemini model available")

print("Successfully connected to model")


In [None]:
class RedditInflationClassifier:
    """Zero-shot classifier that labels Reddit posts by perceived inflation.

    Wraps a generative model: any object exposing
    ``generate_content(prompt, generation_config=...)`` whose result has a
    ``.text`` attribute. Each post is mapped to one of three labels:
    0 = deflation, 1 = neither, 2 = inflation.
    """

    # A 0/1/2 digit that is not part of a longer number (e.g. not "10"/"2023").
    _STANDALONE_LABEL = re.compile(r'(?<!\d)[012](?!\d)')
    # Fallback: any 0/1/2 digit, chosen by its position in the response.
    _ANY_LABEL_DIGIT = re.compile(r'[012]')

    def __init__(self, model):
        """Store the model and build the decoding config and prompt template.

        Args:
            model: The generative model used for every classification call.
        """
        self.model = model

        # Low temperature for near-deterministic labels; the expected answer
        # is a single digit, so a small output budget is sufficient.
        self.generation_config = {
            "temperature": 0.1,
            "top_p": 0.8,
            "top_k": 40,
            "max_output_tokens": 10,
        }

        # Greek-language zero-shot prompt; "{text}" is replaced with the post.
        # NOTE(review): "όχι προσωπικές. εμπειρία" looks like a stray period
        # (presumably meant "όχι προσωπική εμπειρία"); left unchanged so the
        # prompt stays identical to the one used for any earlier results.
        self.prompt_template = """Είστε επικεφαλής οικονομολόγος στο ΔΝΤ. Θα ήθελα να συμπεράνετε την αντίληψη του κοινού για τον πληθωρισμό από αναρτήσεις στο Reddit. Παρακαλώ ταξινομήστε κάθε ανάρτηση στο Reddit σε μία από τις ακόλουθες κατηγορίες:

0: Η ανάρτηση υποδεικνύει αποπληθωρισμό, όπως τη χαμηλότερη τιμή αγαθών ή υπηρεσιών (π.χ., "οι τιμές δεν είναι κακές"), προσιτές υπηρεσίες (π.χ., "αυτή η σαμπάνια είναι φθηνή και νόστιμη"), πληροφορίες πωλήσεων (π.χ., "μπορείτε να την αποκτήσετε μόνο με 10 δολάρια.") ή μια φθίνουσα αγορά αγοραστή.

2: Η ανάρτηση υποδεικνύει ή περιλαμβάνει πληθωρισμό, όπως την υψηλότερη τιμή αγαθών ή υπηρεσιών (π.χ., "δεν είναι φθηνή"), το παράλογο κόστος αγαθών ή υπηρεσιών (π.χ., "το φαγητό είναι υπερτιμημένο και κρύο"), καταναλωτές που αγωνίζονται να αγοράσουν τα απαραίτητα (π.χ., "τα είδη είναι πολύ ακριβά για να αγοραστούν"), έλλειψη αγαθών ή υπηρεσιών ή αναφορά σε μια φούσκα περιουσιακών στοιχείων.

1: Η ανάρτηση δεν υποδεικνύει ούτε αποπληθωρισμό (0) ούτε πληθωρισμό (2). Αυτή η κατηγορία περιλαμβάνει επίσης μόνο ερωτήσεις προς μια κοινότητα, κοινωνικές δηλώσεις, όχι προσωπικές. εμπειρία, πραγματικές παρατηρήσεις, αναφορές σε αρχικά ακριβά ή φθηνά αγαθά ή υπηρεσίες (π.χ., "ένα πανέμορφο και ακριβό δείπνο" ή "ένα προσιτό Civic"), προώθηση ιστοσελίδας, επιθυμίες συγγραφέων ή παράλογο κείμενο.

Επιλέξτε μια πιο ισχυρή στάση όταν το κείμενο περιλαμβάνει τόσο 0 όσο και 2 θέσεις. Εάν αυτές οι θέσεις είναι του ίδιου βαθμού, απαντήστε 1.

Δημοσίευση Reddit για ταξινόμηση: "{text}"

Απαντήστε μόνο με τον αριθμό: 0, 1 ή 2"""

    def create_prompt(self, text):
        """Return the full prompt with ``text`` substituted for ``{text}``."""
        return self.prompt_template.format(text=text)

    def extract_classification(self, response_text):
        """Map a raw model response to a label in {0, 1, 2}.

        Preference order:
        1. the first standalone 0/1/2 (a digit not embedded in a longer number);
        2. otherwise the first 0/1/2 digit anywhere in the response;
        3. otherwise 1 (neutral).

        Labels are picked by their position in the response, never by their
        value, so "5 points: 2, not 0" yields 2 rather than 0.

        Args:
            response_text: The model's text output; ``None`` is treated as "".

        Returns:
            The extracted label as an int.
        """
        response_text = (response_text or "").strip()

        for pattern in (self._STANDALONE_LABEL, self._ANY_LABEL_DIGIT):
            match = pattern.search(response_text)
            if match:
                return int(match.group())

        return 1  # Default to neutral if unable to extract

    def classify_text(self, text, max_retries=3):
        """Classify a single text using zero-shot prompting.

        Any exception from the model call or from label extraction triggers a
        retry with exponential backoff (1s, 2s, ...) between attempts.

        Args:
            text: The post to classify.
            max_retries: Maximum number of model calls to attempt.

        Returns:
            The label 0, 1 or 2; 1 (neutral) if every attempt failed.
        """
        prompt = self.create_prompt(text)

        for attempt in range(max_retries):
            try:
                response = self.model.generate_content(
                    prompt,
                    generation_config=self.generation_config
                )
                return self.extract_classification(response.text)

            except Exception as e:
                print(f"Error in classification attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff

        print("All attempts failed, defaulting to class 1 (neutral)")
        return 1  # Default to neutral on error

    def classify_batch(self, texts, delay=1.0):
        """Classify ``texts`` sequentially, sleeping ``delay`` seconds after each.

        The sleep spaces out model calls to stay under API rate limits.
        Progress is printed every 50 texts.

        Returns:
            A list of labels, in the same order as ``texts``.
        """
        results = []
        for i, text in enumerate(texts):
            if i % 50 == 0:
                print(f"Processing {i+1}/{len(texts)} texts...")

            results.append(self.classify_text(text))
            time.sleep(delay)

        return results

In [None]:
def _date_bound(value, tz):
    """Parse a filter bound and align its timezone-awareness with the data column.

    pandas raises ``TypeError`` when comparing tz-aware and tz-naive
    timestamps. A naive bound is therefore localized to the column's timezone
    ``tz``. An aware bound compared against a naive column is converted to
    naive UTC (assumes naive ``created_date`` values are UTC -- TODO confirm).
    """
    bound = pd.to_datetime(value)
    if tz is not None and bound.tzinfo is None:
        bound = bound.tz_localize(tz)
    elif tz is None and bound.tzinfo is not None:
        bound = bound.tz_convert(None)
    return bound


def load_and_prepare_data(file_path, start_date=None, end_date=None):
    """Load TSV data, optionally filter it by date, and add a monthly key.

    Args:
        file_path: Path (or any buffer ``pd.read_csv`` accepts) of a
            tab-separated UTF-8 file with at least a ``created_date`` column.
        start_date: Optional inclusive lower bound, in any form
            ``pd.to_datetime`` parses. Falsy values disable the bound.
        end_date: Optional inclusive upper bound. A bound with no time of day
            (e.g. ``"2022-12-31"``) covers that whole day, so posts
            timestamped later on the end date are kept. Falsy disables it.

    Returns:
        The (filtered) DataFrame, with ``created_date`` parsed to datetimes
        and an added ``year_month`` column of monthly ``Period`` values.
    """
    print(f"Loading data from: {file_path}")

    # Read TSV file
    df = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    print(f"Loaded {len(df)} total records")
    print(f"Columns: {df.columns.tolist()}")

    df['created_date'] = pd.to_datetime(df['created_date'])
    column_tz = df['created_date'].dt.tz  # None for naive timestamps

    print(f"Original date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Debug: show some sample dates before filtering
    print("Sample dates from original data:")
    for i, date in enumerate(df['created_date'].head(10), start=1):
        print(f"  {i}: {date}")

    # Build one boolean mask from whichever bounds were given.
    mask = pd.Series(True, index=df.index)
    start_bound = end_bound = None
    if start_date:
        start_bound = _date_bound(start_date, column_tz)
        print(f"Applying filter: created_date >= {start_bound}")
        mask &= df['created_date'] >= start_bound
    if end_date:
        end_bound = _date_bound(end_date, column_tz)
        if end_bound == end_bound.normalize():
            # Date-only bound: keep everything before the next midnight so the
            # whole end day is included, not just its 00:00:00 instant.
            next_day = end_bound + pd.Timedelta(days=1)
            print(f"Applying filter: created_date < {next_day}")
            mask &= df['created_date'] < next_day
        else:
            print(f"Applying filter: created_date <= {end_bound}")
            mask &= df['created_date'] <= end_bound

    if start_bound is not None or end_bound is not None:
        # .copy() so adding year_month below does not write into a view.
        df = df[mask].copy()
        lower = start_bound.strftime('%Y-%m-%d') if start_bound is not None else 'earliest'
        upper = end_bound.strftime('%Y-%m-%d') if end_bound is not None else 'latest'
        print(f"Filtered to date range {lower} to {upper}: {len(df)} records")

    # Add year_month column for grouping
    df['year_month'] = df['created_date'].dt.to_period('M')

    print(f"Final date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Debug: show the unique year_month values to verify filtering
    unique_months = sorted(df['year_month'].unique())
    print(f"Unique year_month values after filtering: {unique_months}")

    return df

def sample_monthly_data(df, max_samples_per_month=800, random_state=42):
    """Down-sample ``df`` to at most ``max_samples_per_month`` rows per month.

    Groups on the ``year_month`` column (as added by ``load_and_prepare_data``).
    Months with at most ``max_samples_per_month`` rows are kept whole. Larger
    months are sampled without replacement, seeded with ``random_state``
    for reproducibility.

    Returns:
        A new DataFrame with a fresh 0..n-1 index. An input with no rows
        yields an empty frame with the same columns, instead of the
        ``ValueError`` that ``pd.concat`` raises on an empty list.
    """
    sampled_dfs = []

    for year_month, group in df.groupby('year_month'):
        if len(group) > max_samples_per_month:
            sampled_group = group.sample(n=max_samples_per_month, random_state=random_state)
        else:
            sampled_group = group

        sampled_dfs.append(sampled_group)
        print(f"{year_month}: sampled {len(sampled_group)} out of {len(group)} records")

    if sampled_dfs:
        result_df = pd.concat(sampled_dfs, ignore_index=True)
    else:
        result_df = df.iloc[0:0].reset_index(drop=True)
    print(f"Total sampled records: {len(result_df)}")

    return result_df

In [None]:
def _open_summary_worksheet(spreadsheet_name, worksheet_name):
    """Open the Google Sheets worksheet for monthly summaries, or return None.

    Authenticates through Colab and writes the header row into an empty
    worksheet. Connection problems are printed rather than raised, so that
    classification still runs and results are still saved to TSV.
    """
    try:
        # Use Google Colab authentication and default credentials
        from google.colab import auth
        from google.auth import default

        auth.authenticate_user()
        creds, _ = default()

        gc = gspread.authorize(creds)
        worksheet = gc.open(spreadsheet_name).worksheet(worksheet_name)

        # Add headers if the worksheet is empty
        if not worksheet.get_all_values():
            headers = ['year_month', 'deflation', 'neither', 'inflation', 'total_number', 'average_score']
            worksheet.append_row(headers)
    except ImportError as e:
        # Listed first: except-clause expressions are evaluated only when
        # reached, so the gspread-specific clauses below are skipped when the
        # Colab/Google auth libraries are missing (e.g. outside Colab).
        print(f"Google Sheets unavailable (Colab/Google auth libraries missing): {e}")
        return None
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Spreadsheet '{spreadsheet_name}' not found. Please check the name.")
        return None
    except gspread.exceptions.WorksheetNotFound:
        print(f"Error: Worksheet '{worksheet_name}' not found. Please check the name.")
        return None
    except Exception as e:
        print(f"Error connecting to Google Sheet: {type(e).__name__}: {str(e)}")
        return None

    print(f"Connected to existing spreadsheet: {spreadsheet_name} - {worksheet_name}")
    return worksheet


def _append_month_tsv(month_df, output_path):
    """Append one month's labelled rows (without year_month) to ``output_path``.

    The header is written only when the file does not exist yet.
    """
    month_df_tsv = month_df.drop('year_month', axis=1)
    write_header = not os.path.exists(output_path)
    month_df_tsv.to_csv(
        output_path,
        sep='\t',
        mode='w' if write_header else 'a',
        header=write_header,
        index=False,
    )


def _summarize_classifications(labels):
    """Return (deflation, neither, inflation, total, average_score) for labels.

    ``labels`` must be a non-empty sequence of 0/1/2. The counts are plain
    Python ints, which the Sheets row payload can serialize.
    """
    labels = list(labels)
    deflation_count = labels.count(0)
    neither_count = labels.count(1)
    inflation_count = labels.count(2)
    total_count = len(labels)
    avg_score = (neither_count * 1 + inflation_count * 2) / total_count
    return deflation_count, neither_count, inflation_count, total_count, avg_score


def process_and_classify_data(df, classifier, output_path, spreadsheet_name, worksheet_name):
    """Classify posts month by month, saving rows to TSV and summaries to Sheets.

    For each ``year_month`` group (in date order), the function:
    - classifies the ``body`` texts (missing bodies become "");
    - appends the rows plus an ``inflation`` label column to ``output_path``;
    - prints the monthly label counts and average score;
    - appends them as a row to the worksheet, if it could be opened.

    NOTE: ``output_path`` is appended to when it already exists, so re-running
    adds a second copy of the rows (and the Sheet gets duplicate month rows).

    Args:
        df: DataFrame with ``created_date``, ``body`` and ``year_month`` columns.
        classifier: Object with ``classify_batch(texts) -> list[int]``.
        output_path: TSV file for the per-row results.
        spreadsheet_name: Google Sheets spreadsheet title.
        worksheet_name: Worksheet (tab) title within that spreadsheet.
    """
    df = df.sort_values('created_date')

    worksheet = _open_summary_worksheet(spreadsheet_name, worksheet_name)

    for year_month, month_group in df.groupby('year_month'):
        print(f"\nProcessing {year_month}...")

        # Copy so the caller's frame is never modified.
        month_df = month_group.copy()

        texts = month_df['body'].fillna('').astype(str).tolist()
        classifications = classifier.classify_batch(texts)
        month_df['inflation'] = classifications

        _append_month_tsv(month_df, output_path)

        (deflation_count, neither_count, inflation_count,
         total_count, avg_score) = _summarize_classifications(classifications)

        print(f"Month: {year_month}")
        print(f"  Deflation (0): {deflation_count}")
        print(f"  Neither (1): {neither_count}")
        print(f"  Inflation (2): {inflation_count}")
        print(f"  Total: {total_count}")
        print(f"  Average Score: {avg_score:.4f}")

        if worksheet is not None:
            try:
                row_data = [str(year_month), deflation_count, neither_count, inflation_count, total_count, round(avg_score, 4)]
                worksheet.append_row(row_data)
                print(f"  Added to Google Sheet: {spreadsheet_name}")
            except Exception as e:
                print(f"  Error updating Google Sheet: {e}")

        print(f"Saved {len(month_df)} records for {year_month}")

    print(f"\nAll data processed and saved to: {output_path}")

In [None]:
def main():
    """Run the zero-shot classification pipeline for the Greek comments file.

    Steps:
    - Load the Greek comments TSV from Google Drive.
    - Keep 2016-01-01 through 2022-12-31.
    - Sample at most 300 comments per month.
    - Classify them with the module-level ``model``.
    - Write per-comment results to a TSV and monthly summaries to the
      "Classification_Test" / "greece" worksheet.
    """
    source_tsv = "/content/drive/MyDrive/multilingual-economic-narratives/data/test/greece_comments.tsv"
    results_tsv = "/content/drive/MyDrive/multilingual-economic-narratives/data/test/greece_comments_results.tsv"

    # Google Sheets destination and the date window to keep.
    sheet_title, tab_title = "Classification_Test", "greece"
    window_start, window_end = "2016-01-01", "2022-12-31"

    os.makedirs(os.path.dirname(results_tsv), exist_ok=True)

    inflation_classifier = RedditInflationClassifier(model)
    filtered = load_and_prepare_data(source_tsv, window_start, window_end)
    monthly_sample = sample_monthly_data(filtered, max_samples_per_month=300)

    print(f"Using Google Spreadsheet: {sheet_title}")
    print(f"Using Worksheet: {tab_title}")
    if window_start:
        print(f"Date range: {window_start} to {window_end or 'latest'}")

    # Classify month by month, updating Google Sheets as each month finishes.
    process_and_classify_data(monthly_sample, inflation_classifier, results_tsv, sheet_title, tab_title)

    banner = "=" * 60
    print("\n" + banner)
    print("ZERO-SHOT PROCESS COMPLETED SUCCESSFULLY")
    print(banner)
    print(f"Results saved to: {results_tsv}")
    print(f"Monthly summaries updated in Google Sheets: {sheet_title} - {tab_title}")

# Execute the main function (in a notebook kernel __name__ is "__main__").
if __name__ == "__main__":
    main()