<a href="https://colab.research.google.com/github/RyuichiSaito1/multilingual-economic-narratives/blob/main/notebooks/classification_test_using_gpt_5_nano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

!pip install openai pandas gspread google-auth

print("Installation completed")

Mounted at /content/drive
Installation completed!


In [2]:
import pandas as pd
import numpy as np
from datetime import datetime
import time
import re
import os
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

from google.colab import auth
from openai import OpenAI
import gspread
from google.auth import default

print("All libraries imported successfully")

All libraries imported successfully


In [3]:
print("Setting up OpenAI API...")

OPENAI_API_KEY = "key"

client = OpenAI(api_key=OPENAI_API_KEY)
print("OpenAI client initialized successfully!")

Setting up OpenAI API...
OpenAI client initialized successfully!


In [5]:
print("Testing GPT-5 nano model connection...")

try:
    test_response = client.chat.completions.create(
        model="gpt-5-nano",
        messages=[{"role": "user", "content": "Hello, respond with 'OK'"}],
        max_completion_tokens=10
    )
    print(f"Model test response: {test_response.choices[0].message.content}")
    print("Successfully connected to GPT-5 nano")
except Exception as e:
    print(f"Error connecting to GPT-5 nano: {e}")
    print("Note: If GPT-5 nano is not available, you may need to check the model name or your API access")
    raise

Testing GPT-5 nano model connection...
Model test response: 
Successfully connected to GPT-5 nano


In [11]:
class RedditEconomicSentimentClassifier:
    def __init__(self, client):

        self.client = client

        self.prompt_template = """
Vous êtes l'économiste en chef du FMI. Je souhaiterais que vous déduisiez les perceptions économiques des citoyens à partir des publications Reddit. Veuillez classer chaque publication Reddit dans l'une des catégories suivantes :

0 : La publication traite (1) de l'achat ou de la vente de biens ou de services, (2) des budgets ou des actifs des ménages ou des particuliers, (3) des salaires ou de l'emploi, ou (4) de la conjoncture économique. Dans ce cas, la publication n'exprime pas de sentiment négatif ou indique un sentiment positif.

2 : La publication traite (1) de l'achat ou de la vente de biens ou de services, (2) des budgets ou des actifs des ménages ou des particuliers, (3) des salaires ou de l'emploi, ou (4) de la conjoncture économique. Dans ce cas, la publication exprime un sentiment négatif.

#1 : La publication n'aborde pas (1) l'achat ou la vente de biens ou de services, (2) les budgets ou les actifs des ménages ou des particuliers, (3) les salaires ou l'emploi, ou (4) la conjoncture économique.

# Veuillez choisir une position plus forte lorsque le texte contient à la fois 0 et 2 positions. Si ces positions sont de force égale, répondez 1.

Message Reddit à classer : "{text}"

Répondez uniquement avec le chiffre : 0, 1 ou 2
"""

    def create_prompt(self, text):
        return self.prompt_template.format(text=text)

    def extract_classification(self, response_text):
        """Extract classification number from model response"""
        # Clean the response
        response_text = response_text.strip()

        # Look for digits in the response
        digits = re.findall(r'\d', response_text)
        if digits:
            pred_num = int(digits[0])
            if pred_num in [0, 1, 2]:
                return pred_num

        # Try to find valid number in full response
        for valid_class in [0, 1, 2]:
            if str(valid_class) in response_text:
                return valid_class

        # Enhanced pattern matching
        patterns = [
            r'Classification.*?([012])',
            r'Answer.*?([012])',
            r'Response.*?([012])',
            r'\b([012])\b'
        ]

        for pattern in patterns:
            matches = re.findall(pattern, response_text)
            if matches:
                return int(matches[-1])

        return 1  # Default to neutral if unable to extract

    def classify_text(self, text, max_retries=3):
        """Classify a single text using zero-shot prompting"""
        prompt = self.create_prompt(text)

        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-5-nano",
                    messages=[
                        {"role": "system", "content": "You are an expert economic sentiment classifier."},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.1,
                    top_p=0.8,
                    max_tokens=10
                )

                response_text = response.choices[0].message.content
                classification = self.extract_classification(response_text)
                return classification

            except Exception as e:
                print(f"Error in classification attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff

        print("All attempts failed, defaulting to class 1 (neutral)")
        return 1  # Default to neutral on error

    def classify_batch(self, texts, delay=1.0):
        """Classify a batch of texts with rate limiting"""
        results = []
        for i, text in enumerate(texts):
            if i % 50 == 0:
                print(f"Processing {i+1}/{len(texts)} texts...")

            result = self.classify_text(text)
            results.append(result)
            # Delay adjusted for OpenAI API rate limits
            time.sleep(delay)

        return results

print("RedditEconomicSentimentClassifier class defined successfully")

RedditEconomicSentimentClassifier class defined successfully


In [12]:
def load_and_prepare_data(file_path, start_date=None, end_date=None):
    """Load TSV data and prepare for processing with optional date filtering"""
    print(f"Loading data from: {file_path}")

    # Read TSV file
    df = pd.read_csv(file_path, sep='\t', encoding='utf-8')
    print(f"Loaded {len(df)} total records")
    print(f"Columns: {df.columns.tolist()}")

    # Convert created_date to datetime
    df['created_date'] = pd.to_datetime(df['created_date'])

    # Show original date range
    print(f"Original date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Filter by date range if specified
    if start_date and end_date:
        start_date = pd.to_datetime(start_date)
        end_date = pd.to_datetime(end_date)
        print(f"Applying filter: {start_date} <= created_date <= {end_date}")
        df = df[(df['created_date'] >= start_date) & (df['created_date'] <= end_date)]
        print(f"Filtered to date range {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}: {len(df)} records")
    elif start_date:
        start_date = pd.to_datetime(start_date)
        print(f"Applying filter: created_date >= {start_date}")
        df = df[df['created_date'] >= start_date]
        print(f"Filtered from {start_date.strftime('%Y-%m-%d')}: {len(df)} records")
    elif end_date:
        end_date = pd.to_datetime(end_date)
        print(f"Applying filter: created_date <= {end_date}")
        df = df[df['created_date'] <= end_date]
        print(f"Filtered to {end_date.strftime('%Y-%m-%d')}: {len(df)} records")

    # Add year_month column for grouping
    df['year_month'] = df['created_date'].dt.to_period('M')

    print(f"Final date range: {df['created_date'].min()} to {df['created_date'].max()}")

    # Debug: Show the unique year_month values to verify filtering
    unique_months = sorted(df['year_month'].unique())
    print(f"Unique year_month values after filtering: {unique_months}")

    return df

def sample_monthly_data(df, max_samples_per_month=800):
    """Sample up to max_samples_per_month records per month"""
    sampled_dfs = []

    for year_month, group in df.groupby('year_month'):
        if len(group) <= max_samples_per_month:
            sampled_group = group
        else:
            sampled_group = group.sample(n=max_samples_per_month, random_state=42)

        sampled_dfs.append(sampled_group)
        print(f"{year_month}: sampled {len(sampled_group)} out of {len(group)} records")

    result_df = pd.concat(sampled_dfs, ignore_index=True)
    print(f"Total sampled records: {len(result_df)}")

    return result_df

def process_and_classify_data(df, classifier, output_path, spreadsheet_name, worksheet_name):
    """Process data month by month and classify texts"""
    # Sort by date
    df = df.sort_values('created_date')

    # Initialize Google Sheets connection
    try:
        # Use Google Colab authentication
        from google.colab import auth
        auth.authenticate_user()

        # Get default credentials
        from google.auth import default
        creds, _ = default()

        # Create gspread client
        gc = gspread.authorize(creds)

        # Open spreadsheet
        spreadsheet = gc.open(spreadsheet_name)
        worksheet = spreadsheet.worksheet(worksheet_name)

        # Check if worksheet has data and add headers if empty
        existing_data = worksheet.get_all_values()
        if not existing_data:
            headers = ['year_month', 'positive', 'neutral', 'negative', 'total_number', 'average_score']
            worksheet.append_row(headers)

        print(f"Connected to existing spreadsheet: {spreadsheet_name} - {worksheet_name}")
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Spreadsheet '{spreadsheet_name}' not found. Please check the name.")
        worksheet = None
    except gspread.exceptions.WorksheetNotFound:
        print(f"Error: Worksheet '{worksheet_name}' not found. Please check the name.")
        worksheet = None
    except Exception as e:
        print(f"Error connecting to Google Sheet: {type(e).__name__}: {str(e)}")
        worksheet = None

    # Group by year_month for processing
    for year_month, month_group in df.groupby('year_month'):
        print(f"\nProcessing {year_month}...")

        # Create a copy to avoid modifying original
        month_df = month_group.copy()

        # Classify texts
        texts = month_df['body'].fillna('').astype(str).tolist()
        classifications = classifier.classify_batch(texts)

        # Add classifications to dataframe
        month_df['sentiment'] = classifications

        # Remove year_month column before saving to TSV
        month_df_tsv = month_df.drop('year_month', axis=1)

        # Save to TSV file (append mode)
        if os.path.exists(output_path):
            month_df_tsv.to_csv(output_path, sep='\t', mode='a', header=False, index=False)
        else:
            month_df_tsv.to_csv(output_path, sep='\t', mode='w', header=True, index=False)

        # Calculate monthly summary
        positive_count = len(month_df[month_df['sentiment'] == 0])
        neutral_count = len(month_df[month_df['sentiment'] == 1])
        negative_count = len(month_df[month_df['sentiment'] == 2])
        total_count = len(month_df)
        avg_score = (positive_count * 0 + neutral_count * 1 + negative_count * 2) / total_count

        # Output to standard output
        print(f"Month: {year_month}")
        print(f"  Positive (0): {positive_count}")
        print(f"  Neutral (1): {neutral_count}")
        print(f"  Negative (2): {negative_count}")
        print(f"  Total: {total_count}")
        print(f"  Average Score: {avg_score:.4f}")

        # Add to Google Sheet
        if worksheet:
            try:
                row_data = [str(year_month), positive_count, neutral_count, negative_count, total_count, round(avg_score, 4)]
                worksheet.append_row(row_data)
                print(f"  Added to Google Sheet: {spreadsheet_name}")
            except Exception as e:
                print(f"  Error updating Google Sheet: {e}")

        print(f"Saved {len(month_df)} records for {year_month}")

    print(f"\nAll data processed and saved to: {output_path}")

print("Helper functions defined successfully")


Helper functions defined successfully


In [13]:
# File paths
input_file = "/content/drive/MyDrive/multilingual-economic-narratives/data/test/france_comments.tsv"
output_file = "/content/drive/MyDrive/multilingual-economic-narratives/data/test/france_comments_results.tsv"

# Google Sheets configuration
spreadsheet_name = "Classification_Test"
worksheet_name = "france_20251010"

# Date range configuration
start_date = "2020-03-01"
end_date = "2020-04-30"

# Maximum samples per month
max_samples_per_month = 400

print("Configuration completed")
print(f"Input file: {input_file}")
print(f"Output file: {output_file}")
print(f"Spreadsheet: {spreadsheet_name} - {worksheet_name}")
print(f"Date range: {start_date} to {end_date}")
print(f"Max samples per month: {max_samples_per_month}")

Configuration completed
Input file: /content/drive/MyDrive/multilingual-economic-narratives/data/test/france_comments.tsv
Output file: /content/drive/MyDrive/multilingual-economic-narratives/data/test/france_comments_results.tsv
Spreadsheet: Classification_Test - france_20251010
Date range: 2020-03-01 to 2020-04-30
Max samples per month: 400


In [18]:
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(output_file), exist_ok=True)
print("Output directory created/verified")

classifier = RedditEconomicSentimentClassifier(client)
print("Classifier initialized successfully")

df = load_and_prepare_data(input_file, start_date, end_date)
print(f"\nData loaded successfully Shape: {df.shape}")

sampled_df = sample_monthly_data(df, max_samples_per_month=max_samples_per_month)
print(f"\nData sampling completed Final shape: {sampled_df.shape}")


Output directory created/verified
Classifier initialized successfully
Loading data from: /content/drive/MyDrive/multilingual-economic-narratives/data/test/france_comments.tsv
Loaded 1447845 total records
Columns: ['created_date', 'subreddit_id', 'id', 'author', 'parent_id', 'body', 'score']
Original date range: 2016-01-01 00:52:55 to 2022-12-31 23:59:09
Applying filter: 2020-03-01 00:00:00 <= created_date <= 2020-04-30 00:00:00
Filtered to date range 2020-03-01 to 2020-04-30: 44275 records
Final date range: 2020-03-01 00:05:40 to 2020-04-29 23:58:01
Unique year_month values after filtering: [Period('2020-03', 'M'), Period('2020-04', 'M')]

Data loaded successfully Shape: (44275, 8)
2020-03: sampled 400 out of 22554 records
2020-04: sampled 400 out of 21721 records
Total sampled records: 800

Data sampling completed Final shape: (800, 8)


In [None]:
print("\n" + "="*60)
print("STARTING CLASSIFICATION PROCESS")
print("="*60)
print(f"Using Google Spreadsheet: {spreadsheet_name}")
print(f"Using Worksheet: {worksheet_name}")
print(f"Date range: {start_date} to {end_date}")
print("="*60 + "\n")

# Process and classify data with real-time Google Sheets updates
process_and_classify_data(sampled_df, classifier, output_file, spreadsheet_name, worksheet_name)

print("\n" + "="*60)
print("ZERO-SHOT PROCESS COMPLETED SUCCESSFULLY")
print("="*60)
print(f"Results saved to: {output_file}")
print(f"Monthly summaries updated in Google Sheets: {spreadsheet_name} - {worksheet_name}")
print("="*60)