<a href="https://colab.research.google.com/github/RyuichiSaito1/inflation-reddit-usa/blob/main/src/phi_2_7_performances.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the test CSV and model checkpoint stored under
# /content/drive/MyDrive/... can be read by the evaluation cells.
from google.colab import drive
drive.mount('/content/drive')

# Authenticate the Colab user with Google.
# NOTE(review): nothing visible in this notebook uses these credentials beyond
# the Drive mount above — confirm whether this call is still needed.
from google.colab import auth
auth.authenticate_user()

In [None]:
# --------------------------------------------------------------------------
# 1. INSTALL REQUIRED PACKAGES
# --------------------------------------------------------------------------
# Install necessary libraries for model evaluation and data handling.
# transformers is pinned to 4.44.0 (presumably the version used for
# fine-tuning — confirm before changing it).
# NOTE(review): seaborn is imported by the evaluation cells but is not
# installed here; this presumably relies on the Colab runtime's preinstalled
# copy.
!pip install transformers==4.44.0 datasets scikit-learn matplotlib torch torchvision torchaudio accelerate bitsandbytes -q

# Fine-tuned model evaluation

In [None]:
# --------------------------------------------------------------------------
# 2. IMPORTS & INITIAL SETUP
# --------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from transformers import (
    AutoTokenizer,
    PhiForSequenceClassification # Import the specific class for Phi-2
)
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# --------------------------------------------------------------------------
# 3. HELPER CLASSES AND FUNCTIONS
# --------------------------------------------------------------------------

class TestDataset(torch.utils.data.Dataset):
    """PyTorch Dataset pairing tokenized inputs with integer class labels.

    Args:
        encodings: Mapping from input name (e.g. ``input_ids``,
            ``attention_mask``) to per-example rows. Rows may be plain Python
            lists or already-built tensors (as produced by a tokenizer called
            with ``return_tensors="pt"``).
        labels: Sequence of integer labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return a dict of tensors for example ``idx``, including ``labels``."""
        # torch.as_tensor reuses rows that are already tensors instead of
        # copy-constructing them (which torch.tensor warns about), while still
        # converting plain lists and Python ints.
        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.as_tensor(self.labels[idx])
        return item

    def __len__(self):
        """Number of examples (one per label)."""
        return len(self.labels)

def read_csv_file(file_path):
    """Load a two-column CSV of posts and labels into a DataFrame.

    The file's first row is treated as a header and replaced by the column
    names ``body`` (read as str) and ``inflation`` (read as int).

    Args:
        file_path: Path to the CSV file.

    Returns:
        pandas.DataFrame | None: The loaded data, or None when the file is
        missing or cannot be parsed (the reason is printed).
    """
    column_names = ['body', 'inflation']
    column_types = {'body': 'str', 'inflation': 'int'}
    try:
        frame = pd.read_csv(file_path, names=column_names, header=0, dtype=column_types)
        print(f"Successfully loaded {len(frame)} records from {file_path}")
    except FileNotFoundError:
        print(f"Error: The file at {file_path} was not found.")
        frame = None
    except Exception as e:
        print(f"An error occurred while reading the CSV: {e}")
        frame = None
    return frame

# Prompt template applied to every post; `{post}` is filled in by
# format_with_prompt. Labels: 0 = deflation, 1 = neutral, 2 = inflation.
# The prompt must be IDENTICAL to the one used during fine-tuning, so do not
# edit its wording here (including the existing typo "shortage of goods of
# services") unless the model is retrained with the same change.
INFLATION_PROMPT = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories: 0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market. 2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble. 1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text. Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
Reddit Post: {post}
Classification:"""

def format_with_prompt(post):
    """Embed a single post in the shared classification prompt.

    The value is converted with ``str()`` first, so non-string inputs
    (e.g. a missing value) are substituted as their string form.
    """
    post_text = str(post)
    return INFLATION_PROMPT.format(post=post_text)

# --------------------------------------------------------------------------
# 4. MAIN EVALUATION SCRIPT
# --------------------------------------------------------------------------

TEST_DATA_PATH = '/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv'
# NOTE(review): the checkpoint directory is named "Phi-3.5-fine-tuning", but
# the tokenizer ('microsoft/phi-2') and PhiForSequenceClassification used
# below are Phi-2 components — confirm this checkpoint is a fine-tuned Phi-2.
MODEL_PATH = '/content/drive/MyDrive/world-inflation/data/model/Phi-3.5-fine-tuning/checkpoint-192'

# --- Load and Prepare Data ---
test_data = read_csv_file(TEST_DATA_PATH)

if test_data is not None:
    # Apply the same prompt formatting as used in training
    test_data['formatted_body'] = test_data['body'].apply(format_with_prompt)

    # --- Initialize Tokenizer ---
    print("\nInitializing tokenizer for Phi-2...")
    tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2', trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Set padding token for Phi-2

    # --- Tokenize Test Data ---
    print("Tokenizing test data...")
    test_encodings = tokenizer(
        test_data['formatted_body'].tolist(),
        truncation=True,
        padding=True,
        max_length=512,  # Use the same max_length as in training
        return_tensors="pt",
    )
    test_labels = test_data['inflation'].tolist()
    test_dataset = TestDataset(test_encodings, test_labels)

    # --- Load Fine-Tuned Model ---
    print(f"Loading fine-tuned model from: {MODEL_PATH}")
    try:
        model = PhiForSequenceClassification.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16,  # Use bfloat16 for L4 GPU compatibility
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()  # Set the model to evaluation mode
        print("Model loaded successfully.")
    except OSError:
        print(f"Error: Model not found at {MODEL_PATH}.")
        print("Please ensure the path is correct and points to a valid checkpoint folder.")
        model = None

    if model is not None:
        # Sequence-classification heads locate each row's last real token via
        # config.pad_token_id and reject batches larger than 1 when it is
        # unset; align it with the tokenizer's pad token before batching.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id

        # --- Run Evaluation ---
        test_loader = DataLoader(test_dataset, batch_size=8)  # Batch size for evaluation
        true_labels = []
        predicted_labels = []

        print("\nStarting evaluation...")
        with torch.no_grad():  # Disable gradient calculations for inference
            for i, batch in enumerate(test_loader):
                inputs = {key: val.to(model.device) for key, val in batch.items() if key != 'labels'}

                outputs = model(**inputs)
                predictions = torch.argmax(outputs.logits, dim=-1)

                # Labels are only collected for scoring, so they stay on the CPU.
                true_labels.extend(batch['labels'].tolist())
                predicted_labels.extend(predictions.cpu().tolist())

                if (i + 1) % 10 == 0:
                    # Cap at the dataset size so the last batch is not overcounted.
                    processed = min((i + 1) * test_loader.batch_size, len(test_dataset))
                    print(f"  Processed {processed} samples...")

        print("Evaluation finished.")

        # --- Display Results ---
        accuracy = accuracy_score(true_labels, predicted_labels)
        print(f"\nOverall Accuracy: {accuracy:.4f}")

        # Label ids 0, 1, 2 and their readable names. Passing `labels`
        # explicitly keeps the report and the 3x3 matrix aligned with
        # class_names even when a class is absent from both label lists.
        label_ids = [0, 1, 2]
        class_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

        # Generate and print the classification report
        print("\nClassification Report:")
        report = classification_report(
            true_labels,
            predicted_labels,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
        print(report)

        # Generate and plot the confusion matrix
        print("\nConfusion Matrix:")
        cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix')
        plt.show()

else:
    print("\nEvaluation stopped because the test data could not be loaded.")

# Zero-shot evaluation (base Phi-2 model)

In [None]:
# --------------------------------------------------------------------------
# 2. IMPORTS & INITIAL SETUP
# --------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline
)
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import re

In [None]:
# --------------------------------------------------------------------------
# 3. HELPER FUNCTIONS
# --------------------------------------------------------------------------

def read_csv_file(file_path):
    """Read the labelled test CSV into a DataFrame.

    The first row of the file is discarded as a header; the two columns are
    named ``body`` (str) and ``inflation`` (int).

    Args:
        file_path: Location of the CSV file.

    Returns:
        pandas.DataFrame | None: The parsed rows, or None after printing an
        error if the file does not exist or fails to parse.
    """
    options = {
        'names': ['body', 'inflation'],
        'header': 0,
        'dtype': {'body': 'str', 'inflation': 'int'},
    }
    try:
        loaded = pd.read_csv(file_path, **options)
        print(f"Successfully loaded {len(loaded)} records from {file_path}")
        return loaded
    except FileNotFoundError:
        print(f"Error: The file at {file_path} was not found.")
    except Exception as e:
        print(f"An error occurred while reading the CSV: {e}")
    return None

# Prompt template for zero-shot inference; `{post}` is filled in by
# format_with_prompt. It is kept IDENTICAL to the prompt used during
# fine-tuning so the base model's zero-shot results can be compared directly
# with the fine-tuned model's. Do not edit its wording (including the existing
# typo "shortage of goods of services") independently of the training prompt.
INFLATION_PROMPT = """You are a chief economist at the IMF. I would like you to infer the public perception of inflation from Reddit posts. Please classify each Reddit post into one of the following categories: 0: The post indicates deflation, such as the lower price of goods or services (e.g., "the prices are not bad"), affordable services (e.g., "this champagne is cheap and delicious"), sales information (e.g., "you can get it for only 10 dollars."), or a declining and buyer's market. 2: The post indicates or includes inflation, such as the higher price of goods or services (e.g., "it's not cheap"), the unreasonable cost of goods or services (e.g., "the food is overpriced and cold"), consumers struggling to afford necessities (e.g., "items are too expensive to buy"), shortage of goods of services, or mention about an asset bubble. 1: The post indicates neither deflation (0) nor inflation (2). This category also includes just questions to a community, social statements not personal experience, factual observations, references to originally expensive or cheap goods or services (e.g., "a gorgeous and costly dinner" or "an affordable Civic"), website promotion, authors' wishes, or illogical text. Please choose a stronger stance when the text includes both 0 and 2 stances. If these stances are of the same degree, answer 1.
Reddit Post: {post}
Classification:"""

def format_with_prompt(post):
    """Return the classification prompt with ``post`` inserted.

    ``post`` is passed through ``str()`` before substitution, so any value
    (including a missing one) yields a prompt string.
    """
    return INFLATION_PROMPT.format(**{'post': str(post)})

def extract_classification(response_text):
    """
    Extract the classification label (0, 1, or 2) from a model response.

    ``response_text`` is expected to be only the text generated after the
    prompt's trailing "Classification:" cue. Strategies, tried in order:

    1. A label at the very start of the response (the direct answer).
    2. An explicit "Classification: X" elsewhere in the text.
    3. Phrasings such as "answer is X" or "category X".
    4. The last standalone 0, 1, or 2 anywhere in the text.
    5. Keywords: deflation-like words -> 0, inflation-like words -> 2.
    6. Fallback to 1 (neutral).

    Args:
        response_text: Generated text to parse.

    Returns:
        int: 0 (deflation), 1 (neutral), or 2 (inflation).
    """
    if not response_text:
        return 1  # Nothing to parse: treat as neutral.

    # Strategy 1: the response continues "Classification:", so a leading label
    # answers *this* post. Checking it first avoids picking up a label from a
    # hallucinated follow-on example later in the generation.
    leading = re.match(r'\s*([012])\b', response_text)
    if leading:
        return int(leading.group(1))

    # Strategies 2-3: explicit cue phrases; the first occurrence wins.
    # The trailing \b rejects multi-digit numbers such as "10" or "20".
    cue_patterns = (
        r'Classification:\s*([012])\b',  # "Classification: 0"
        r'answer is\s*([012])\b',        # "answer is 0"
        r'category\s*([012])\b',         # "category 0"
    )
    for pattern in cue_patterns:
        match = re.search(pattern, response_text, re.IGNORECASE)
        if match:
            return int(match.group(1))

    # Strategy 4: the last standalone 0/1/2 across the whole text. findall
    # spans newlines, unlike a `(?!.*[012])` look-ahead whose `.` stops at the
    # end of a line and so does not find the true last occurrence.
    standalone = re.findall(r'\b([012])\b', response_text)
    if standalone:
        return int(standalone[-1])

    # Strategy 5: look for keywords indicating the class.
    response_lower = response_text.lower()
    if any(word in response_lower for word in ['deflation', 'cheap', 'affordable', 'declining']):
        return 0
    if any(word in response_lower for word in ['inflation', 'expensive', 'overpriced', 'costly']):
        return 2
    return 1  # Default to neutral if unclear

def zero_shot_predict(model, tokenizer, formatted_posts, batch_size=4):
    """
    Generate zero-shot class predictions (0, 1, or 2) for prompt-formatted posts.

    Posts are generated one at a time; ``batch_size`` only sets the chunk size
    used for progress reporting (a line is printed whenever the processed
    count is a multiple of 20, and at the end).

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        formatted_posts: List of full prompt strings, one per post.
        batch_size: Number of posts per progress chunk.

    Returns:
        list[int]: One prediction per post, in input order. A post whose
        generation raises is assigned 1 (neutral) and the error is printed.
    """
    predictions = []
    total_posts = len(formatted_posts)

    print(f"Generating predictions for {total_posts} posts...")

    for start in range(0, total_posts, batch_size):
        for post in formatted_posts[start:start + batch_size]:
            try:
                # NOTE(review): with the tokenizer's default right-side
                # truncation, prompts longer than 512 tokens lose their
                # trailing "Classification:" cue — consider truncating the
                # post text itself instead.
                inputs = tokenizer(post, return_tensors="pt", truncation=True, max_length=512)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
                prompt_length = inputs["input_ids"].shape[-1]

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=50,  # Limit response length
                        do_sample=False,    # Greedy decoding: reproducible evaluation
                        pad_token_id=tokenizer.eos_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                    )

                # Decode only the newly generated tokens. Slicing the decoded
                # text by len(post) misaligns whenever truncation shortened the
                # prompt or decoding normalized its whitespace.
                generated_ids = outputs[0][prompt_length:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                predictions.append(extract_classification(response))

            except Exception as e:
                # Deliberate best-effort: one failing post must not abort the run.
                print(f"Error processing post: {e}")
                predictions.append(1)  # Default to neutral on error

        # Progress update
        processed = min(start + batch_size, total_posts)
        if processed % 20 == 0 or processed == total_posts:
            print(f"  Processed {processed}/{total_posts} samples...")

    return predictions

In [None]:
# --------------------------------------------------------------------------
# 4. MAIN ZERO-SHOT EVALUATION SCRIPT
# --------------------------------------------------------------------------

TEST_DATA_PATH = '/content/drive/MyDrive/world-inflation/data/reddit/production/test-data-200.csv'

# --- Load and Prepare Data ---
test_data = read_csv_file(TEST_DATA_PATH)

if test_data is not None:
    # Apply the same prompt formatting as used in training
    test_data['formatted_body'] = test_data['body'].apply(format_with_prompt)

    # --- Initialize Tokenizer and Model ---
    print("\nInitializing tokenizer and model for Phi-2...")
    tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2', trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the base Phi-2 model for zero-shot inference
    try:
        model = AutoModelForCausalLM.from_pretrained(
            'microsoft/phi-2',
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()
        print("Base Phi-2 model loaded successfully for zero-shot evaluation.")
    except Exception as e:
        print(f"Error loading model: {e}")
        model = None

    if model is not None:
        # --- Run Zero-Shot Evaluation ---
        true_labels = test_data['inflation'].tolist()
        formatted_posts = test_data['formatted_body'].tolist()

        print("\nStarting zero-shot evaluation...")
        predicted_labels = zero_shot_predict(model, tokenizer, formatted_posts, batch_size=4)

        print("Zero-shot evaluation finished.")

        # --- Display Results ---
        accuracy = accuracy_score(true_labels, predicted_labels)
        print(f"\nZero-Shot Overall Accuracy: {accuracy:.4f}")

        # Label ids 0, 1, 2 and their readable names. Passing `labels`
        # explicitly keeps the report and the 3x3 matrix aligned with
        # class_names even when a class is absent from both label lists.
        label_ids = [0, 1, 2]
        class_names = ['Deflation (0)', 'Neutral (1)', 'Inflation (2)']

        # Generate and print the classification report
        print("\nZero-Shot Classification Report:")
        report = classification_report(
            true_labels,
            predicted_labels,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        )
        print(report)

        # Generate and plot the confusion matrix
        print("\nZero-Shot Confusion Matrix:")
        cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Zero-Shot Confusion Matrix - Base Phi-2')
        plt.show()

        # --- Additional Analysis ---
        print("\n" + "="*50)
        print("COMPARISON SUMMARY")
        print("="*50)
        print("This is the zero-shot performance of the base Phi-2 model.")
        print("Compare these results with your fine-tuned model to measure")
        print("the improvement gained from fine-tuning.")

        # Show some example predictions for manual inspection
        print("\n" + "="*50)
        print("SAMPLE PREDICTIONS (First 5)")
        print("="*50)
        for i in range(min(5, len(true_labels))):
            # str() guards against a missing (NaN) body, which cannot be sliced.
            post_preview = str(test_data['body'].iloc[i])[:100]
            print(f"\nSample {i+1}:")
            print(f"Post: {post_preview}...")
            print(f"True Label: {true_labels[i]} ({class_names[true_labels[i]]})")
            print(f"Predicted: {predicted_labels[i]} ({class_names[predicted_labels[i]]})")
            print(f"Correct: {'✓' if true_labels[i] == predicted_labels[i] else '✗'}")

else:
    print("\nEvaluation stopped because the test data could not be loaded.")