<a href="https://colab.research.google.com/github/RyuichiSaito1/inflation-reddit-usa/blob/main/notebooks/gemini_2_0_flash_lite_performances.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the test CSV and the results CSV paths under
# /content/drive/MyDrive used later in this notebook are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import auth
from google.cloud import aiplatform
import vertexai
from vertexai.generative_models import GenerativeModel
import time
import re
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Zero-shot

In [None]:
from google.colab import auth
from google.cloud import aiplatform

print("Authenticating with Google Cloud...")
auth.authenticate_user()
print("Authentication completed!")

# Initialize Vertex AI
# Placeholder values ("#####"); replace with the real project ID and region.
PROJECT_ID = "#####"
LOCATION = "#####"

print("Initializing Vertex AI...")
try:
    vertexai.init(project=PROJECT_ID, location=LOCATION)
    print("Vertex AI initialized successfully!")
except Exception as e:
    print(f"Error initializing Vertex AI: {e}")
    print("Please ensure you have:")
    print("1. Set the correct PROJECT_ID")
    print("2. Enabled Vertex AI API in your project")
    print("3. Have proper permissions")
    raise

# Initialize Gemini 2.0 Flash Lite model, falling back to other model names if
# it cannot serve a test request. `model` is the module-level handle that
# classify_text uses for every prediction.
print("Initializing Gemini 2.0 Flash Lite model...")
try:
    model = GenerativeModel("gemini-2.0-flash-lite")
    print("Model initialized successfully!")

    # Test the model with a simple query
    print("Testing model connection...")
    test_response = model.generate_content("Hello, respond with 'OK'")
    print(f"Model test response: {test_response.text}")

except Exception as e:
    print(f"Error initializing model: {e}")
    print("Trying alternative model names...")

    # Try alternative model names
    alternative_models = [
        "gemini-2.0-flash-exp",
        "gemini-1.5-flash",
        "gemini-1.5-pro"
    ]

    model = None
    for model_name in alternative_models:
        try:
            print(f"Trying {model_name}...")
            candidate = GenerativeModel(model_name)
            test_response = candidate.generate_content("Hello, respond with 'OK'")
            print(f"Successfully connected to {model_name}")
            print(f"Test response: {test_response.text}")
        except Exception as model_error:
            print(f"Failed to connect to {model_name}: {model_error}")
            continue
        # Adopt the candidate only once its smoke test has succeeded. Binding
        # `model` before the test would leave a non-working model in place when
        # every fallback fails, and the `model is None` guard below would never
        # fire.
        model = candidate
        break

    if model is None:
        print("Could not connect to any Gemini model. Please check:")
        print("1. Your project has access to Gemini models")
        print("2. The model name is correct")
        print("3. Vertex AI API is enabled")
        # RuntimeError is a subclass of Exception, so existing handlers still
        # catch it; chaining keeps the original failure visible.
        raise RuntimeError("No Gemini model available") from e

# Load test data
print("Loading test data...")
test_data_path = '/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv'
df = pd.read_csv(test_data_path)

print(f"Test data shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print(f"Class distribution:\n{df['inflation'].value_counts()}")

In [None]:
# Define classification prompt template
def create_classification_prompt(text: str) -> str:
    """Return the zero-shot prompt asking for a 0/1/2 inflation-perception label.

    The post is embedded in double quotes after the category definitions, and
    the prompt asks the model to answer with a single digit. The literal keeps
    the 4-space indentation of its source lines, so the returned string starts
    with a newline and every line is indented.

    Args:
        text: The Reddit post to classify.

    Returns:
        The full prompt string.
    """
    return f"""
    You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:

    0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.

    2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.

    1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.

    Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.

    Reddit post to classify: "{text}"

    Respond with only the number: 0, 1, or 2
    """

def classify_text(text: str, max_retries: int = 3) -> int:
    """Ask the zero-shot Gemini model for the class of a single post.

    The prompt from ``create_classification_prompt`` is sent to the
    module-level ``model`` up to ``max_retries`` times.

    A reply is parsed in two steps:
    - If its first digit is 0, 1 or 2, that digit is the answer.
    - Otherwise, the first of 0, 1, 2 (checked in that order) that appears
      anywhere in the reply is used.

    Retry behaviour:
    - A reply that cannot be parsed is retried immediately.
    - A request error is retried after an exponential back-off of 1 s, 2 s,
      ...; there is no sleep after the final attempt.

    Args:
        text: The Reddit post to classify.
        max_retries: Number of attempts before giving up.

    Returns:
        0, 1 or 2. Returns 1 (neutral) when every attempt fails, so failures
        are indistinguishable from genuine neutral predictions downstream.
    """
    valid_classes = (0, 1, 2)
    # Low temperature and a tiny output budget: only a single digit is expected.
    generation_config = {
        "temperature": 0.1,
        "top_p": 0.8,
        "top_k": 40,
        "max_output_tokens": 10,
    }
    prompt = create_classification_prompt(text)

    for attempt_idx in range(max_retries):
        try:
            reply = model.generate_content(
                prompt,
                generation_config=generation_config,
            ).text.strip()

            # Preferred: the first digit in the reply, if it is a valid class.
            first_digit = re.search(r'\d', reply)
            if first_digit is not None and int(first_digit.group()) in valid_classes:
                return int(first_digit.group())

            # Fallback: any valid class digit anywhere in the reply.
            fallback = next((c for c in valid_classes if str(c) in reply), None)
            if fallback is not None:
                return fallback

            print(f"Invalid prediction: '{reply}', retrying...")
        except Exception as e:
            print(f"Error in attempt {attempt_idx + 1}: {str(e)}")
            is_last_attempt = attempt_idx == max_retries - 1
            if not is_last_attempt:
                time.sleep(2 ** attempt_idx)  # exponential back-off

    print("All attempts failed, defaulting to class 1 (neutral)")
    return 1

def batch_classify(texts: List[str], batch_size: int = 10) -> List[int]:
    """Classify ``texts`` one at a time, grouped into batches for pacing and logging.

    Each text goes through ``classify_text``. Any exception that escapes it is
    logged and recorded as class 1 (neutral), so the result always holds
    exactly one label per input, in input order. The loop sleeps 1 s after
    each successfully classified item and 2 s between batches (not after the
    last one) to respect API rate limits.

    Args:
        texts: Posts to classify.
        batch_size: Posts per batch; only affects pacing and progress output.

    Returns:
        Predicted classes aligned with ``texts``.
    """
    n_texts = len(texts)
    total_batches = -(-n_texts // batch_size)  # ceiling division
    print(f"Processing {n_texts} texts in {total_batches} batches...")

    predictions: List[int] = []
    for batch_no, start in enumerate(range(0, n_texts, batch_size), start=1):
        batch = texts[start:start + batch_size]
        for item_no, text in enumerate(batch, start=1):
            try:
                label = classify_text(text)
                predictions.append(label)
                print(f"Batch {batch_no}/{total_batches}, Item {item_no}/{len(batch)}: {label}")
                # Pause between requests to respect API rate limits.
                time.sleep(1)
            except Exception as e:
                print(f"Error processing text {start + item_no - 1}: {str(e)}")
                predictions.append(1)  # neutral fallback keeps outputs aligned

        # Longer pause between batches, skipped after the final batch.
        if start + batch_size < n_texts:
            time.sleep(2)

    return predictions

In [None]:
# Run the zero-shot classifier over every post in the test set and time it.
print("\nStarting classification...")
start_time = time.time()

# One prediction per post. With batch_size=5, batch_classify pauses 2 s after
# every 5 posts on top of its 1 s per-request delay.
test_texts = df['body'].tolist()
predictions = batch_classify(texts=test_texts, batch_size=5)

end_time = time.time()
print("\nClassification completed in {:.2f} seconds".format(end_time - start_time))

In [None]:
# Prepare data for evaluation: gold labels come from the test CSV's
# 'inflation' column; `predictions` is the output of the batch_classify run.
y_true = df['inflation'].tolist()
y_pred = predictions

# Ensure both lists have the same length
# (batch_classify appends exactly one label per input text, so this is a
# safety net; on mismatch it silently drops trailing samples).
if len(y_true) != len(y_pred):
    min_len = min(len(y_true), len(y_pred))
    y_true = y_true[:min_len]
    y_pred = y_pred[:min_len]
    print(f"Truncated to {min_len} samples for evaluation")

# Calculate metrics
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

# Overall accuracy
accuracy = accuracy_score(y_true, y_pred)
print(f"Overall Accuracy: {accuracy:.4f}")

# Macro-averaged metrics (average across classes)
# NOTE: no `labels=` is passed, so sklearn averages only over classes that
# occur in y_true or y_pred; a class absent from both is not counted.
precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

print(f"\nMacro-averaged Metrics:")
print(f"Precision: {precision_macro:.4f}")
print(f"Recall: {recall_macro:.4f}")
print(f"F1 Score: {f1_macro:.4f}")

# Micro-averaged metrics (overall across all samples)
# NOTE: for single-label multiclass data with all labels included, micro
# precision, recall and F1 all equal the overall accuracy.
precision_micro = precision_score(y_true, y_pred, average='micro', zero_division=0)
recall_micro = recall_score(y_true, y_pred, average='micro', zero_division=0)
f1_micro = f1_score(y_true, y_pred, average='micro', zero_division=0)

print(f"\nMicro-averaged Metrics:")
print(f"Precision: {precision_micro:.4f}")
print(f"Recall: {recall_micro:.4f}")
print(f"F1 Score: {f1_micro:.4f}")

# Per-class metrics
# Passing labels=classes fixes the output order (0, 1, 2) and keeps one entry
# per class even if a class never occurs.
classes = [0, 1, 2]  # Ensure consistent numeric ordering
precision_per_class = precision_score(y_true, y_pred, labels=classes, average=None, zero_division=0)
recall_per_class = recall_score(y_true, y_pred, labels=classes, average=None, zero_division=0)
f1_per_class = f1_score(y_true, y_pred, labels=classes, average=None, zero_division=0)

class_names = {0: "Deflation", 1: "Neutral", 2: "Inflation"}
print(f"\nPer-class Metrics:")
for i, class_num in enumerate(classes):
    print(f"\nClass {class_num} ({class_names[class_num]}):")
    print(f"  Precision: {precision_per_class[i]:.4f}")
    print(f"  Recall: {recall_per_class[i]:.4f}")
    print(f"  F1 Score: {f1_per_class[i]:.4f}")

# Detailed classification report
print(f"\n\nDetailed Classification Report:")
target_names = ["Deflation (0)", "Neutral (1)", "Inflation (2)"]
print(classification_report(y_true, y_pred, labels=classes, target_names=target_names, zero_division=0))

# Confusion Matrix (rows = actual class, columns = predicted class, in `classes` order)
print(f"\nConfusion Matrix:")
cm = confusion_matrix(y_true, y_pred, labels=classes)
print(cm)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix - Gemini 2.0 Flash Lite\nInflation Perception Classification')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.tight_layout()
plt.show()

# Create results summary DataFrame
results_summary = pd.DataFrame({
    'Metric': ['Accuracy', 'Macro Precision', 'Macro Recall', 'Macro F1',
               'Micro Precision', 'Micro Recall', 'Micro F1'],
    'Score': [accuracy, precision_macro, recall_macro, f1_macro,
              precision_micro, recall_micro, f1_micro]
})

print(f"\n\nResults Summary:")
print(results_summary.to_string(index=False, float_format='%.4f'))

# Save results to CSV
# One row per evaluated post (text, gold label, prediction); to_csv overwrites
# any existing file at this path.
results_path = '/content/drive/MyDrive/world-inflation/data/reddit/production/gemini_evaluation_results.csv'
results_df = pd.DataFrame({
    'body': test_texts[:len(y_pred)],
    'true_label': y_true,
    'predicted_label': y_pred
})

results_df.to_csv(results_path, index=False)
print(f"\nDetailed results saved to: {results_path}")

# Additional analysis: prediction distribution
print(f"\n\nPrediction Distribution:")
pred_counts = pd.Series(y_pred).value_counts()
true_counts = pd.Series(y_true).value_counts()

# NOTE: value_counts only lists classes that occur; aligning the two Series
# yields NaN for a class missing on one side, and after fillna(0) that column
# is float rather than int.
comparison_df = pd.DataFrame({
    'True_Count': true_counts,
    'Predicted_Count': pred_counts
}).fillna(0)

print(comparison_df)

# Calculate class-wise accuracy
# Share of samples of each gold class that were predicted correctly; this is
# the same quantity as recall_per_class above.
print(f"\n\nClass-wise Accuracy:")
for class_num in classes:
    class_mask = [i for i, label in enumerate(y_true) if label == class_num]
    if class_mask:
        class_accuracy = sum(1 for i in class_mask if y_pred[i] == y_true[i]) / len(class_mask)
        print(f"Class {class_num} ({class_names[class_num]}): {class_accuracy:.4f}")
    else:
        print(f"Class {class_num} ({class_names[class_num]}): No samples in test set")

print(f"\n{'='*50}")
print("EVALUATION COMPLETED SUCCESSFULLY!")
print(f"{'='*50}")


# Fine-tuning

In [None]:
# Install required packages
!pip install google-genai pandas scikit-learn

# Import libraries
import pandas as pd
import numpy as np
from google import genai
from google.genai.types import HttpOptions
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
import re
import time

# Authenticate the Colab user with Google Cloud (run this first), before the
# Vertex AI genai.Client is created in the cells below.
from google.colab import auth
auth.authenticate_user()

# 1,040

In [None]:
# Vertex AI connection settings: placeholder values ("#####"); replace them
# with the real project ID and region before running.
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1"),
)

# Fetch the tuning job; its tuned_model.endpoint is evaluated further below.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)

# Same test CSV as the zero-shot evaluation earlier in the notebook.
print("Loading test data...")
test_data_path = "/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv"
df = pd.read_csv(test_data_path)
print(
    f"Loaded {len(df)} test samples\n"
    f"Columns: {df.columns.tolist()}\n"
    "First few rows:"
)
print(df.head())

In [None]:
# Define the prompt template
def create_prompt(reddit_post):
    """Build the classification prompt for one Reddit post.

    Args:
        reddit_post: The post; interpolated verbatim between single quotes.

    Returns:
        The prompt, ending with the cue "Classification (0, 1, or 2):" that
        the model is expected to complete with a single digit.
    """
    instructions = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:
0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.
2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.
1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.
Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
"""
    return f"{instructions}Reddit Post:\n'{reddit_post}'\nClassification (0, 1, or 2):"

# Function to extract classification from response
def extract_classification(response_text):
    """Extract the predicted class (0, 1 or 2) from a model reply.

    Labelled forms ("Classification: 2", "Answer ... 1", "Response ... 0") are
    tried first, then a standalone 0/1/2; for each pattern the last match wins.
    As a final fallback, the first 0/1/2 character anywhere is used.

    Args:
        response_text: Raw reply text. ``None`` or ``""`` counts as unparseable.

    Returns:
        The class as an ``int``, or ``None`` if no class can be found.
    """
    if not response_text:
        return None

    # The prompt ends with "Classification (0, 1, or 2):". If the model echoes
    # that cue, 'Classification.*?([012])' would match the "0" inside the option
    # list and return 0 regardless of the real answer, so drop the option list
    # before matching.
    cleaned = response_text.replace("(0, 1, or 2)", "")

    patterns = [
        r'Classification.*?([012])',
        r'Answer.*?([012])',
        r'Response.*?([012])',
        r'\b([012])\b'
    ]

    for pattern in patterns:
        matches = re.findall(pattern, cleaned)
        if matches:
            return int(matches[-1])  # Take the last match

    # If no clear pattern, look for the first occurrence of 0, 1, or 2
    for char in cleaned:
        if char in ('0', '1', '2'):
            return int(char)

    return None  # Unable to extract classification

# Function to evaluate model
def evaluate_model(model_endpoint, test_df, model_name="Model"):
    """Classify every row of ``test_df`` with ``model_endpoint`` and print metrics.

    Each post is wrapped with ``create_prompt`` and sent through the
    module-level genai ``client`` (``client.models.generate_content``). The
    reply is parsed with ``extract_classification``. Rows whose reply cannot
    be parsed, or whose request raises, are counted as failures and excluded
    from the metrics.

    Args:
        model_endpoint: Model/endpoint name passed as ``model=`` to
            ``generate_content`` (e.g. a tuned model's endpoint).
        test_df: DataFrame with one post per row.
            - Text column: ``'text'``, else ``'body'``, else the first column.
            - Label column: ``'label'``, else ``'inflation'``, else the second
              column.
        model_name: Display name used in the printed report.

    Returns:
        ``None`` if no reply could be parsed. Otherwise a dict with:
        - ``accuracy``
        - support-weighted ``precision``, ``recall`` and ``f1``
        - the parsed ``predictions`` and the aligned ``true_labels``
        - the ``failed_predictions`` count
    """
    print(f"\n{'='*60}")
    print(f"🔹 Evaluating {model_name}")
    print(f"{'='*60}")

    # Fixed label set so the report and confusion matrix always have three
    # rows/columns, even when a class never occurs in this run (otherwise
    # classification_report rejects the 3 target names and the 3-column
    # matrix printout below indexes past the end of a smaller matrix).
    class_labels = [0, 1, 2]
    target_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

    predictions = []
    true_labels = []
    failed_predictions = 0

    # Prefer named columns over positions: the zero-shot section of this
    # notebook reads the same CSV via df['body'] and df['inflation'], so a
    # purely positional fallback could pick an unrelated column.
    if 'text' in test_df.columns:
        text_column = 'text'
    elif 'body' in test_df.columns:
        text_column = 'body'
    else:
        text_column = test_df.columns[0]

    if 'label' in test_df.columns:
        label_column = 'label'
    elif 'inflation' in test_df.columns:
        label_column = 'inflation'
    else:
        label_column = test_df.columns[1]

    print(f"Using text column: '{text_column}' and label column: '{label_column}'")

    total = len(test_df)
    # Count rows by 1-based position rather than by DataFrame index, which need
    # not be 0..n-1.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        reddit_post = row[text_column]
        true_label = row[label_column]

        try:
            prompt = create_prompt(reddit_post)
            response = client.models.generate_content(
                model=model_endpoint,
                contents=prompt,
            )
            # response.text may be None when the reply carries no text; treat
            # that as an unparseable answer rather than a parser crash.
            response_text = response.text or ""
            predicted_label = extract_classification(response_text)

            if predicted_label is not None:
                predictions.append(predicted_label)
                true_labels.append(true_label)
            else:
                failed_predictions += 1
                print(f"Failed to extract classification from response: {response_text[:100]}...")

        except Exception as e:
            failed_predictions += 1
            print(f"Error processing sample {position}: {e}")

        # Report progress every 10th row, whether or not that row succeeded.
        if position % 10 == 0:
            print(f"Processed {position}/{total} samples...")

        # Add small delay to avoid rate limiting
        time.sleep(0.1)

    if not predictions:
        print(f"❌ No valid predictions obtained for {model_name}")
        return None

    # Precision/recall/F1 are support-weighted averages over classes.
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0
    )

    print(f"\n📊 RESULTS for {model_name}:")
    print(f"Total samples processed: {len(predictions)}/{total}")
    print(f"Failed predictions: {failed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Detailed classification report
    print(f"\n📋 Detailed Classification Report:")
    print(classification_report(
        true_labels,
        predictions,
        labels=class_labels,
        target_names=target_names,
        zero_division=0,
    ))

    # Confusion Matrix (rows = true class, columns = predicted class)
    print(f"\n🔢 Confusion Matrix:")
    cm = confusion_matrix(true_labels, predictions, labels=class_labels)
    print("True\\Pred    0    1    2")
    for i, counts in enumerate(cm):
        print(f"    {i}    {counts[0]:4d} {counts[1]:4d} {counts[2]:4d}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels,
        'failed_predictions': failed_predictions
    }

In [None]:
# Evaluate the tuned model's default checkpoint (its endpoint) on the test set.
default_results = evaluate_model(
    tuning_job.tuned_model.endpoint,
    df,
    "DEFAULT Checkpoint",
)

print("\n✅ Evaluation completed!")
print("=" * 80)

# 520

In [None]:
# Vertex AI connection settings: placeholder values ("#####"); replace them
# with the real project ID and region before running.
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1"),
)

# Fetch the tuning job; its tuned_model.endpoint is evaluated further below.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)

# Same test CSV as the zero-shot evaluation earlier in the notebook.
print("Loading test data...")
test_data_path = "/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv"
df = pd.read_csv(test_data_path)
print(
    f"Loaded {len(df)} test samples\n"
    f"Columns: {df.columns.tolist()}\n"
    "First few rows:"
)
print(df.head())

In [None]:
# Define the prompt template
def create_prompt(reddit_post):
    """Build the classification prompt for one Reddit post.

    Args:
        reddit_post: The post; interpolated verbatim between single quotes.

    Returns:
        The prompt, ending with the cue "Classification (0, 1, or 2):" that
        the model is expected to complete with a single digit.
    """
    instructions = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:
0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.
2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.
1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.
Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
"""
    return f"{instructions}Reddit Post:\n'{reddit_post}'\nClassification (0, 1, or 2):"

# Function to extract classification from response
def extract_classification(response_text):
    """Extract the predicted class (0, 1 or 2) from a model reply.

    Labelled forms ("Classification: 2", "Answer ... 1", "Response ... 0") are
    tried first, then a standalone 0/1/2; for each pattern the last match wins.
    As a final fallback, the first 0/1/2 character anywhere is used.

    Args:
        response_text: Raw reply text. ``None`` or ``""`` counts as unparseable.

    Returns:
        The class as an ``int``, or ``None`` if no class can be found.
    """
    if not response_text:
        return None

    # The prompt ends with "Classification (0, 1, or 2):". If the model echoes
    # that cue, 'Classification.*?([012])' would match the "0" inside the option
    # list and return 0 regardless of the real answer, so drop the option list
    # before matching.
    cleaned = response_text.replace("(0, 1, or 2)", "")

    patterns = [
        r'Classification.*?([012])',
        r'Answer.*?([012])',
        r'Response.*?([012])',
        r'\b([012])\b'
    ]

    for pattern in patterns:
        matches = re.findall(pattern, cleaned)
        if matches:
            return int(matches[-1])  # Take the last match

    # If no clear pattern, look for the first occurrence of 0, 1, or 2
    for char in cleaned:
        if char in ('0', '1', '2'):
            return int(char)

    return None  # Unable to extract classification

# Function to evaluate model
def evaluate_model(model_endpoint, test_df, model_name="Model"):
    """Classify every row of ``test_df`` with ``model_endpoint`` and print metrics.

    Each post is wrapped with ``create_prompt`` and sent through the
    module-level genai ``client`` (``client.models.generate_content``). The
    reply is parsed with ``extract_classification``. Rows whose reply cannot
    be parsed, or whose request raises, are counted as failures and excluded
    from the metrics.

    Args:
        model_endpoint: Model/endpoint name passed as ``model=`` to
            ``generate_content`` (e.g. a tuned model's endpoint).
        test_df: DataFrame with one post per row.
            - Text column: ``'text'``, else ``'body'``, else the first column.
            - Label column: ``'label'``, else ``'inflation'``, else the second
              column.
        model_name: Display name used in the printed report.

    Returns:
        ``None`` if no reply could be parsed. Otherwise a dict with:
        - ``accuracy``
        - support-weighted ``precision``, ``recall`` and ``f1``
        - the parsed ``predictions`` and the aligned ``true_labels``
        - the ``failed_predictions`` count
    """
    print(f"\n{'='*60}")
    print(f"🔹 Evaluating {model_name}")
    print(f"{'='*60}")

    # Fixed label set so the report and confusion matrix always have three
    # rows/columns, even when a class never occurs in this run (otherwise
    # classification_report rejects the 3 target names and the 3-column
    # matrix printout below indexes past the end of a smaller matrix).
    class_labels = [0, 1, 2]
    target_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

    predictions = []
    true_labels = []
    failed_predictions = 0

    # Prefer named columns over positions: the zero-shot section of this
    # notebook reads the same CSV via df['body'] and df['inflation'], so a
    # purely positional fallback could pick an unrelated column.
    if 'text' in test_df.columns:
        text_column = 'text'
    elif 'body' in test_df.columns:
        text_column = 'body'
    else:
        text_column = test_df.columns[0]

    if 'label' in test_df.columns:
        label_column = 'label'
    elif 'inflation' in test_df.columns:
        label_column = 'inflation'
    else:
        label_column = test_df.columns[1]

    print(f"Using text column: '{text_column}' and label column: '{label_column}'")

    total = len(test_df)
    # Count rows by 1-based position rather than by DataFrame index, which need
    # not be 0..n-1.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        reddit_post = row[text_column]
        true_label = row[label_column]

        try:
            prompt = create_prompt(reddit_post)
            response = client.models.generate_content(
                model=model_endpoint,
                contents=prompt,
            )
            # response.text may be None when the reply carries no text; treat
            # that as an unparseable answer rather than a parser crash.
            response_text = response.text or ""
            predicted_label = extract_classification(response_text)

            if predicted_label is not None:
                predictions.append(predicted_label)
                true_labels.append(true_label)
            else:
                failed_predictions += 1
                print(f"Failed to extract classification from response: {response_text[:100]}...")

        except Exception as e:
            failed_predictions += 1
            print(f"Error processing sample {position}: {e}")

        # Report progress every 10th row, whether or not that row succeeded.
        if position % 10 == 0:
            print(f"Processed {position}/{total} samples...")

        # Add small delay to avoid rate limiting
        time.sleep(0.1)

    if not predictions:
        print(f"❌ No valid predictions obtained for {model_name}")
        return None

    # Precision/recall/F1 are support-weighted averages over classes.
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0
    )

    print(f"\n📊 RESULTS for {model_name}:")
    print(f"Total samples processed: {len(predictions)}/{total}")
    print(f"Failed predictions: {failed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Detailed classification report
    print(f"\n📋 Detailed Classification Report:")
    print(classification_report(
        true_labels,
        predictions,
        labels=class_labels,
        target_names=target_names,
        zero_division=0,
    ))

    # Confusion Matrix (rows = true class, columns = predicted class)
    print(f"\n🔢 Confusion Matrix:")
    cm = confusion_matrix(true_labels, predictions, labels=class_labels)
    print("True\\Pred    0    1    2")
    for i, counts in enumerate(cm):
        print(f"    {i}    {counts[0]:4d} {counts[1]:4d} {counts[2]:4d}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels,
        'failed_predictions': failed_predictions
    }

In [None]:
# Evaluate the tuned model's default checkpoint (its endpoint) on the test set.
default_results = evaluate_model(
    tuning_job.tuned_model.endpoint,
    df,
    "DEFAULT Checkpoint",
)

print("\n✅ Evaluation completed!")
print("=" * 80)

# 260

In [None]:
# Vertex AI connection settings: placeholder values ("#####"); replace them
# with the real project ID and region before running.
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1"),
)

# Fetch the tuning job; its tuned_model.endpoint is evaluated further below.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)

# Same test CSV as the zero-shot evaluation earlier in the notebook.
print("Loading test data...")
test_data_path = "/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv"
df = pd.read_csv(test_data_path)
print(
    f"Loaded {len(df)} test samples\n"
    f"Columns: {df.columns.tolist()}\n"
    "First few rows:"
)
print(df.head())

In [None]:
# Define the prompt template
def create_prompt(reddit_post):
    """Build the classification prompt for one Reddit post.

    Args:
        reddit_post: The post; interpolated verbatim between single quotes.

    Returns:
        The prompt, ending with the cue "Classification (0, 1, or 2):" that
        the model is expected to complete with a single digit.
    """
    instructions = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:
0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.
2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.
1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.
Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
"""
    return f"{instructions}Reddit Post:\n'{reddit_post}'\nClassification (0, 1, or 2):"

# Function to extract classification from response
def extract_classification(response_text):
    """Extract the predicted class (0, 1 or 2) from a model reply.

    Labelled forms ("Classification: 2", "Answer ... 1", "Response ... 0") are
    tried first, then a standalone 0/1/2; for each pattern the last match wins.
    As a final fallback, the first 0/1/2 character anywhere is used.

    Args:
        response_text: Raw reply text. ``None`` or ``""`` counts as unparseable.

    Returns:
        The class as an ``int``, or ``None`` if no class can be found.
    """
    if not response_text:
        return None

    # The prompt ends with "Classification (0, 1, or 2):". If the model echoes
    # that cue, 'Classification.*?([012])' would match the "0" inside the option
    # list and return 0 regardless of the real answer, so drop the option list
    # before matching.
    cleaned = response_text.replace("(0, 1, or 2)", "")

    patterns = [
        r'Classification.*?([012])',
        r'Answer.*?([012])',
        r'Response.*?([012])',
        r'\b([012])\b'
    ]

    for pattern in patterns:
        matches = re.findall(pattern, cleaned)
        if matches:
            return int(matches[-1])  # Take the last match

    # If no clear pattern, look for the first occurrence of 0, 1, or 2
    for char in cleaned:
        if char in ('0', '1', '2'):
            return int(char)

    return None  # Unable to extract classification

# Function to evaluate model
def evaluate_model(model_endpoint, test_df, model_name="Model"):
    """Classify every row of ``test_df`` with ``model_endpoint`` and print metrics.

    Each post is wrapped with ``create_prompt`` and sent through the
    module-level genai ``client`` (``client.models.generate_content``). The
    reply is parsed with ``extract_classification``. Rows whose reply cannot
    be parsed, or whose request raises, are counted as failures and excluded
    from the metrics.

    Args:
        model_endpoint: Model/endpoint name passed as ``model=`` to
            ``generate_content`` (e.g. a tuned model's endpoint).
        test_df: DataFrame with one post per row.
            - Text column: ``'text'``, else ``'body'``, else the first column.
            - Label column: ``'label'``, else ``'inflation'``, else the second
              column.
        model_name: Display name used in the printed report.

    Returns:
        ``None`` if no reply could be parsed. Otherwise a dict with:
        - ``accuracy``
        - support-weighted ``precision``, ``recall`` and ``f1``
        - the parsed ``predictions`` and the aligned ``true_labels``
        - the ``failed_predictions`` count
    """
    print(f"\n{'='*60}")
    print(f"🔹 Evaluating {model_name}")
    print(f"{'='*60}")

    # Fixed label set so the report and confusion matrix always have three
    # rows/columns, even when a class never occurs in this run (otherwise
    # classification_report rejects the 3 target names and the 3-column
    # matrix printout below indexes past the end of a smaller matrix).
    class_labels = [0, 1, 2]
    target_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

    predictions = []
    true_labels = []
    failed_predictions = 0

    # Prefer named columns over positions: the zero-shot section of this
    # notebook reads the same CSV via df['body'] and df['inflation'], so a
    # purely positional fallback could pick an unrelated column.
    if 'text' in test_df.columns:
        text_column = 'text'
    elif 'body' in test_df.columns:
        text_column = 'body'
    else:
        text_column = test_df.columns[0]

    if 'label' in test_df.columns:
        label_column = 'label'
    elif 'inflation' in test_df.columns:
        label_column = 'inflation'
    else:
        label_column = test_df.columns[1]

    print(f"Using text column: '{text_column}' and label column: '{label_column}'")

    total = len(test_df)
    # Count rows by 1-based position rather than by DataFrame index, which need
    # not be 0..n-1.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        reddit_post = row[text_column]
        true_label = row[label_column]

        try:
            prompt = create_prompt(reddit_post)
            response = client.models.generate_content(
                model=model_endpoint,
                contents=prompt,
            )
            # response.text may be None when the reply carries no text; treat
            # that as an unparseable answer rather than a parser crash.
            response_text = response.text or ""
            predicted_label = extract_classification(response_text)

            if predicted_label is not None:
                predictions.append(predicted_label)
                true_labels.append(true_label)
            else:
                failed_predictions += 1
                print(f"Failed to extract classification from response: {response_text[:100]}...")

        except Exception as e:
            failed_predictions += 1
            print(f"Error processing sample {position}: {e}")

        # Report progress every 10th row, whether or not that row succeeded.
        if position % 10 == 0:
            print(f"Processed {position}/{total} samples...")

        # Add small delay to avoid rate limiting
        time.sleep(0.1)

    if not predictions:
        print(f"❌ No valid predictions obtained for {model_name}")
        return None

    # Precision/recall/F1 are support-weighted averages over classes.
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted', zero_division=0
    )

    print(f"\n📊 RESULTS for {model_name}:")
    print(f"Total samples processed: {len(predictions)}/{total}")
    print(f"Failed predictions: {failed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Detailed classification report
    print(f"\n📋 Detailed Classification Report:")
    print(classification_report(
        true_labels,
        predictions,
        labels=class_labels,
        target_names=target_names,
        zero_division=0,
    ))

    # Confusion Matrix (rows = true class, columns = predicted class)
    print(f"\n🔢 Confusion Matrix:")
    cm = confusion_matrix(true_labels, predictions, labels=class_labels)
    print("True\\Pred    0    1    2")
    for i, counts in enumerate(cm):
        print(f"    {i}    {counts[0]:4d} {counts[1]:4d} {counts[2]:4d}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels,
        'failed_predictions': failed_predictions
    }

In [None]:
# Evaluate the tuned model's default checkpoint (its endpoint) on the test set.
default_results = evaluate_model(
    tuning_job.tuned_model.endpoint,
    df,
    "DEFAULT Checkpoint",
)

print("\n✅ Evaluation completed!")
print("=" * 80)

# 130

In [None]:
# Vertex AI connection settings: placeholder values ("#####"); replace them
# with the real project ID and region before running.
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1"),
)

# Fetch the tuning job; its tuned_model.endpoint is evaluated further below.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)

# Same test CSV as the zero-shot evaluation earlier in the notebook.
print("Loading test data...")
test_data_path = "/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv"
df = pd.read_csv(test_data_path)
print(
    f"Loaded {len(df)} test samples\n"
    f"Columns: {df.columns.tolist()}\n"
    "First few rows:"
)
print(df.head())

In [None]:
# Define the prompt template
def create_prompt(reddit_post):
    """Build the classification prompt for one Reddit post.

    Args:
        reddit_post: The post; interpolated verbatim between single quotes.

    Returns:
        The prompt, ending with the cue "Classification (0, 1, or 2):" that
        the model is expected to complete with a single digit.
    """
    instructions = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories:
0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market.
2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble.
1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text.
Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
"""
    return f"{instructions}Reddit Post:\n'{reddit_post}'\nClassification (0, 1, or 2):"

# Function to extract classification from response
def extract_classification(response_text):
    """Extract the 0/1/2 class label from a model reply.

    Patterns are tried in order:

    1. ``Classification ... N``
    2. ``Answer ... N``
    3. ``Response ... N``
    4. A standalone ``0``/``1``/``2`` token.

    The first pattern that matches anywhere wins, and its *last* match is
    returned. As a last resort, the first ``0``/``1``/``2`` character anywhere
    in the text is used (e.g. ``"label_2"``).

    Args:
        response_text: Raw text returned by the model. ``None`` or an empty
            string is treated as "no answer".

    Returns:
        The class as an ``int`` (0, 1 or 2), or ``None`` if no label digit
        can be found.
    """
    if not response_text:
        return None

    # The prompt ends with "Classification (0, 1, or 2):". If the model echoes
    # that cue, the lazy keyword pattern below would stop at the "0" of the
    # option list and return 0 regardless of the real answer, so drop the
    # echoed option list before matching.
    text = re.sub(r'\(\s*0\s*,\s*1\s*,?\s*or\s*2\s*\)', ' ', response_text)

    patterns = [
        r'Classification.*?([012])',
        r'Answer.*?([012])',
        r'Response.*?([012])',
        r'\b([012])\b',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return int(matches[-1])  # Take the last match

    # No pattern matched: fall back to the first 0/1/2 character, even one
    # embedded in a longer token.
    for char in text:
        if char in '012':
            return int(char)

    return None  # Unable to extract classification

# Function to evaluate model
def evaluate_model(model_endpoint, test_df, model_name="Model"):
    """Classify every post in ``test_df`` with a model endpoint and report metrics.

    For each row, the function:

    - builds a prompt with ``create_prompt``;
    - sends it to ``model_endpoint`` via the module-level ``client``
      (``client.models.generate_content``);
    - parses the reply with ``extract_classification``.

    Rows whose request raises, or whose reply cannot be parsed, are counted in
    ``failed_predictions`` and excluded from every metric. The scores describe
    only the successfully classified rows.

    Args:
        model_endpoint: Value passed as ``model=`` to
            ``client.models.generate_content``.
        test_df: DataFrame of posts and gold labels. Uses the ``'text'`` and
            ``'label'`` columns when present; otherwise it falls back to the
            first and second columns.
        model_name: Label used in the printed headings.

    Returns:
        A dict with:

        - ``accuracy``;
        - weighted ``precision``, ``recall`` and ``f1``;
        - the parallel ``predictions`` and ``true_labels`` lists;
        - ``failed_predictions``.

        Returns ``None`` if no row could be classified.
    """
    # Local import: the notebook header imports only the individual scorers.
    from sklearn.metrics import precision_recall_fscore_support

    # Pin the label set so the report and the confusion matrix are always
    # 3-class. Without it, a run where one class never appears (in truth or
    # predictions) makes classification_report reject the 3 target_names and
    # leaves confusion-matrix rows with fewer than 3 entries.
    class_labels = [0, 1, 2]
    class_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

    print(f"\n{'='*60}")
    print(f"🔹 Evaluating {model_name}")
    print(f"{'='*60}")

    predictions = []
    true_labels = []
    failed_predictions = 0

    # Fall back to the first two columns when the expected names are absent.
    text_column = 'text' if 'text' in test_df.columns else test_df.columns[0]
    label_column = 'label' if 'label' in test_df.columns else test_df.columns[1]

    print(f"Using text column: '{text_column}' and label column: '{label_column}'")

    total = len(test_df)
    # Track the 1-based row position explicitly. The DataFrame index need not
    # be 0..n-1 (e.g. after filtering), so idx + 1 would mislabel progress.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        try:
            # Normalise the gold label to a plain int so sklearn sees a single
            # label type. An unparseable label counts as a failed sample.
            true_label = int(row[label_column])
            prompt = create_prompt(row[text_column])

            response = client.models.generate_content(
                model=model_endpoint,
                contents=prompt,
            )

            # NOTE(review): response.text is presumably None when the reply
            # has no text part. Treat that as an unparseable reply rather than
            # passing None on to regex matching and slicing.
            response_text = response.text or ""
            predicted_label = extract_classification(response_text)

            if predicted_label is not None:
                predictions.append(predicted_label)
                true_labels.append(true_label)
            else:
                failed_predictions += 1
                print(f"Failed to extract classification from response: {response_text[:100]}...")

        except Exception as e:
            # Best-effort evaluation: log and continue so one bad request
            # does not abort the whole run.
            failed_predictions += 1
            print(f"Error processing sample {position}: {e}")

        if position % 10 == 0:
            print(f"Processed {position}/{total} samples...")

        # Add small delay to avoid rate limiting
        time.sleep(0.1)

    if not predictions:
        print(f"❌ No valid predictions obtained for {model_name}")
        return None

    # Metrics cover only successfully classified samples.
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions,
        labels=class_labels, average='weighted', zero_division=0,
    )

    print(f"\n📊 RESULTS for {model_name}:")
    print(f"Total samples processed: {len(predictions)}/{total}")
    print(f"Failed predictions: {failed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Detailed classification report
    print("\n📋 Detailed Classification Report:")
    print(classification_report(
        true_labels, predictions,
        labels=class_labels, target_names=class_names, zero_division=0,
    ))

    # Confusion Matrix (always 3x3 thanks to the pinned labels)
    print("\n🔢 Confusion Matrix:")
    cm = confusion_matrix(true_labels, predictions, labels=class_labels)
    print("True\\Pred    0    1    2")
    for i, cm_row in enumerate(cm):
        print(f"    {i}    {cm_row[0]:4d} {cm_row[1]:4d} {cm_row[2]:4d}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels,
        'failed_predictions': failed_predictions
    }

In [None]:
# Score the tuning job's default checkpoint (its tuned-model endpoint) on df.
default_results = evaluate_model(
    tuning_job.tuned_model.endpoint,
    df,
    "DEFAULT Checkpoint",
)

print("\n✅ Evaluation completed!")
print("=" * 80)

# 65  (NOTE(review): unlabeled marker from the original run — meaning not shown here; confirm)

In [None]:
# Vertex AI client (project/location are redacted placeholders).
project_id = "#####"
location = "#####"
client = genai.Client(
    vertexai=True,
    project=project_id,
    location=location,
    http_options=HttpOptions(api_version="v1"),
)

# Look up the tuning job; only the job object is fetched here.
tuning_job_name = "#####"
tuning_job = client.tunings.get(name=tuning_job_name)

# Held-out test set.
print("Loading test data...")
test_data_path = (
    "/content/drive/MyDrive/world-inflation/data/reddit/production/"
    "test-data-200.csv"
)
df = pd.read_csv(test_data_path)
# One print call emitting the same three lines as three separate prints.
print(
    f"Loaded {len(df)} test samples\n"
    f"Columns: {df.columns.tolist()}\n"
    "First few rows:"
)
print(df.head())

In [None]:
# Define the prompt template
def create_prompt(reddit_post):
    """Build the classification prompt for a single Reddit post.

    The instructions define three classes (0 = deflation, 2 = inflation,
    1 = neither). The post is embedded between single quotes, and the prompt
    ends with the answer cue ``Classification (0, 1, or 2):``.

    Keep the wording byte-for-byte stable. Any change alters what the model
    is asked. NOTE(review): the tuned model was presumably trained on this
    exact template; confirm before editing.
    """
    instruction_lines = [
        ('You are a chief economist at the IMF. I would like you to infer the '
         'public perception of inflation from Reddit posts. Please classify '
         'each Reddit post into one of the following categories:'),
        ('0: The post indicates deflation, such as the lower price of goods or '
         'services (e.g., "the prices are not bad"), affordable services '
         '(e.g., "this champagne is cheap and delicious"), sales information '
         '(e.g., "you can get it for only 10 dollars."), or a declining and '
         "buyer's market."),
        ('2: The post indicates or includes inflation, such as the higher price '
         'of goods or services (e.g., "it\'s not cheap"), the unreasonable cost '
         'of goods or services (e.g., "the food is overpriced and cold"), '
         'consumers struggling to afford necessities (e.g., "items are too '
         'expensive to buy"), shortage of goods of services, or mention about '
         'an asset bubble.'),
        ('1: The post indicates neither deflation (0) nor inflation (2). This '
         'category also includes just questions to a community, social '
         'statements not personal experience, factual observations, references '
         'to originally expensive or cheap goods or services (e.g., "a gorgeous '
         'and costly dinner" or "an affordable Civic"), website promotion, '
         "authors' wishes, or illogical text."),
        ('Please choose a stronger stance when the text includes both 0 and 2 '
         'stances. If these stances are of the same degree, answer 1.'),
    ]
    post_lines = [
        "Reddit Post:",
        f"'{reddit_post}'",
        "Classification (0, 1, or 2):",
    ]
    return "\n".join(instruction_lines + post_lines)

# Function to extract classification from response
def extract_classification(response_text):
    """Extract the 0/1/2 class label from a model reply.

    Patterns are tried in order:

    1. ``Classification ... N``
    2. ``Answer ... N``
    3. ``Response ... N``
    4. A standalone ``0``/``1``/``2`` token.

    The first pattern that matches anywhere wins, and its *last* match is
    returned. As a last resort, the first ``0``/``1``/``2`` character anywhere
    in the text is used (e.g. ``"label_2"``).

    Args:
        response_text: Raw text returned by the model. ``None`` or an empty
            string is treated as "no answer".

    Returns:
        The class as an ``int`` (0, 1 or 2), or ``None`` if no label digit
        can be found.
    """
    if not response_text:
        return None

    # The prompt ends with "Classification (0, 1, or 2):". If the model echoes
    # that cue, the lazy keyword pattern below would stop at the "0" of the
    # option list and return 0 regardless of the real answer, so drop the
    # echoed option list before matching.
    text = re.sub(r'\(\s*0\s*,\s*1\s*,?\s*or\s*2\s*\)', ' ', response_text)

    patterns = [
        r'Classification.*?([012])',
        r'Answer.*?([012])',
        r'Response.*?([012])',
        r'\b([012])\b',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        if matches:
            return int(matches[-1])  # Take the last match

    # No pattern matched: fall back to the first 0/1/2 character, even one
    # embedded in a longer token.
    for char in text:
        if char in '012':
            return int(char)

    return None  # Unable to extract classification

# Function to evaluate model
def evaluate_model(model_endpoint, test_df, model_name="Model"):
    """Classify every post in ``test_df`` with a model endpoint and report metrics.

    For each row, the function:

    - builds a prompt with ``create_prompt``;
    - sends it to ``model_endpoint`` via the module-level ``client``
      (``client.models.generate_content``);
    - parses the reply with ``extract_classification``.

    Rows whose request raises, or whose reply cannot be parsed, are counted in
    ``failed_predictions`` and excluded from every metric. The scores describe
    only the successfully classified rows.

    Args:
        model_endpoint: Value passed as ``model=`` to
            ``client.models.generate_content``.
        test_df: DataFrame of posts and gold labels. Uses the ``'text'`` and
            ``'label'`` columns when present; otherwise it falls back to the
            first and second columns.
        model_name: Label used in the printed headings.

    Returns:
        A dict with:

        - ``accuracy``;
        - weighted ``precision``, ``recall`` and ``f1``;
        - the parallel ``predictions`` and ``true_labels`` lists;
        - ``failed_predictions``.

        Returns ``None`` if no row could be classified.
    """
    # Local import: the notebook header imports only the individual scorers.
    from sklearn.metrics import precision_recall_fscore_support

    # Pin the label set so the report and the confusion matrix are always
    # 3-class. Without it, a run where one class never appears (in truth or
    # predictions) makes classification_report reject the 3 target_names and
    # leaves confusion-matrix rows with fewer than 3 entries.
    class_labels = [0, 1, 2]
    class_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

    print(f"\n{'='*60}")
    print(f"🔹 Evaluating {model_name}")
    print(f"{'='*60}")

    predictions = []
    true_labels = []
    failed_predictions = 0

    # Fall back to the first two columns when the expected names are absent.
    text_column = 'text' if 'text' in test_df.columns else test_df.columns[0]
    label_column = 'label' if 'label' in test_df.columns else test_df.columns[1]

    print(f"Using text column: '{text_column}' and label column: '{label_column}'")

    total = len(test_df)
    # Track the 1-based row position explicitly. The DataFrame index need not
    # be 0..n-1 (e.g. after filtering), so idx + 1 would mislabel progress.
    for position, (_, row) in enumerate(test_df.iterrows(), start=1):
        try:
            # Normalise the gold label to a plain int so sklearn sees a single
            # label type. An unparseable label counts as a failed sample.
            true_label = int(row[label_column])
            prompt = create_prompt(row[text_column])

            response = client.models.generate_content(
                model=model_endpoint,
                contents=prompt,
            )

            # NOTE(review): response.text is presumably None when the reply
            # has no text part. Treat that as an unparseable reply rather than
            # passing None on to regex matching and slicing.
            response_text = response.text or ""
            predicted_label = extract_classification(response_text)

            if predicted_label is not None:
                predictions.append(predicted_label)
                true_labels.append(true_label)
            else:
                failed_predictions += 1
                print(f"Failed to extract classification from response: {response_text[:100]}...")

        except Exception as e:
            # Best-effort evaluation: log and continue so one bad request
            # does not abort the whole run.
            failed_predictions += 1
            print(f"Error processing sample {position}: {e}")

        if position % 10 == 0:
            print(f"Processed {position}/{total} samples...")

        # Add small delay to avoid rate limiting
        time.sleep(0.1)

    if not predictions:
        print(f"❌ No valid predictions obtained for {model_name}")
        return None

    # Metrics cover only successfully classified samples.
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions,
        labels=class_labels, average='weighted', zero_division=0,
    )

    print(f"\n📊 RESULTS for {model_name}:")
    print(f"Total samples processed: {len(predictions)}/{total}")
    print(f"Failed predictions: {failed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")

    # Detailed classification report
    print("\n📋 Detailed Classification Report:")
    print(classification_report(
        true_labels, predictions,
        labels=class_labels, target_names=class_names, zero_division=0,
    ))

    # Confusion Matrix (always 3x3 thanks to the pinned labels)
    print("\n🔢 Confusion Matrix:")
    cm = confusion_matrix(true_labels, predictions, labels=class_labels)
    print("True\\Pred    0    1    2")
    for i, cm_row in enumerate(cm):
        print(f"    {i}    {cm_row[0]:4d} {cm_row[1]:4d} {cm_row[2]:4d}")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': predictions,
        'true_labels': true_labels,
        'failed_predictions': failed_predictions
    }

In [None]:
# Score the tuning job's default checkpoint (its tuned-model endpoint) on df.
default_results = evaluate_model(
    tuning_job.tuned_model.endpoint,
    df,
    "DEFAULT Checkpoint",
)

print("\n✅ Evaluation completed!")
print("=" * 80)