## Setup and Imports

In [None]:
# Install required packages (uncomment if needed)
# %pip install groq python-dotenv pandas scikit-learn seaborn tqdm matplotlib

In [None]:
import json
import os
import re
import time
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix
)
from tqdm import tqdm

print("✓ Libraries imported")

## Data Loading

In [None]:
# Load the dataset
DATA_PATH = '../data/aggregated_data.csv'
df = pd.read_csv(DATA_PATH)

# Fail fast on schema problems: later cells read these columns, and discovering a
# missing one only after the (slow) API sweep would waste the whole run.
required_columns = {'id', 'text', 'lang', 'split', 'label_sexist'}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"Dataset is missing columns: {sorted(missing_columns)}"

# Create label mappings
unique_labels = sorted(df['label_sexist'].unique().tolist())
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

# classify_with_llm returns the strings 'sexist' / 'not sexist', and the metrics compare
# those strings to `label_sexist` directly, so gold labels must use exactly that spelling.
unexpected_labels = set(unique_labels) - {'sexist', 'not sexist'}
assert not unexpected_labels, f"Unexpected values in label_sexist: {sorted(unexpected_labels)}"

print(f"Total samples: {len(df)}")
print(f"Labels: {unique_labels}")
print("\nLanguage distribution:")
print(df['lang'].value_counts())
print("\nSplit distribution:")
print(df['split'].value_counts())
print("\nLabel distribution:")
print(df['label_sexist'].value_counts())

In [None]:
# Partition rows by the dataset's `split` column. `.copy()` gives each subset its own
# frame, so later edits to a subset never write through to `df`.
train_df, dev_df, test_df = (
    df[df['split'] == split_name].copy() for split_name in ('train', 'dev', 'test')
)

for split_title, split_frame in (("Train", train_df), ("Dev", dev_df), ("Test", test_df)):
    print(f"{split_title}: {len(split_frame)} samples")

# Peek at the first test example
first_example = test_df.iloc[0]
print("\nExample:")
print(f"  Text: {first_example['text']}")
print(f"  Label: {first_example['label_sexist']}")
print(f"  Language: {first_example['lang']}")

## Initialize Groq API

In [None]:
from groq import Groq
from dotenv import load_dotenv

# Read variables from a local .env file (if present) into the process environment,
# where the Groq client looks for its API key.
load_dotenv()

client = Groq()

# One short round-trip so bad credentials or connectivity fail here, not mid-sweep.
test_completion = client.chat.completions.create(
    model="llama-3.3-70b-versatile",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello! Just testing the connection."},
    ],
    max_completion_tokens=50,
    temperature=0.5,
)

connection_reply = test_completion.choices[0].message.content
print("✓ Groq API connection successful")
print(f"Test response: {connection_reply}")

## Define Model and Prompt Configurations

In [None]:
# Groq model ids to benchmark; configurations are evaluated in this order.
# NOTE(review): Groq retires models over time, and llama-3.1-70b-versatile /
# mixtral-8x7b-32768 may no longer be served — confirm against the current model list.
# If requests to a model fail, that model's scores reflect only the classifier's
# fallback label.
models_to_test = [
    "llama-3.3-70b-versatile",  # same model as the connection check and smoke test
    "llama-3.1-70b-versatile",
    "mixtral-8x7b-32768",
]

print("Models to test:")
print("\n".join(f"  {rank}. {model_id}" for rank, model_id in enumerate(models_to_test, start=1)))

In [None]:
# Zero-shot prompt strategies, from terse to explicit. Each entry pairs a system
# message with a user template whose `{text}` placeholder is filled per tweet.
prompt_templates = dict(
    simple=dict(
        system=(
            "You are a sexism detection system. "
            "Classify tweets as 'sexist' or 'not sexist'."
        ),
        user=(
            "Classify this tweet as 'sexist' or 'not sexist'. "
            "Return ONLY the label, nothing else.\n\n"
            "Tweet: {text}\n\nLabel:"
        ),
    ),
    # Spells out what counts as sexist before asking for the label.
    detailed=dict(
        system=(
            "You are an expert in detecting sexism in social media content. "
            "Your task is to classify tweets as 'sexist' or 'not sexist' based on their content."
        ),
        user="""Classify the following tweet as either 'sexist' or 'not sexist'.

A tweet is 'sexist' if it:
- Contains gender-based discrimination or stereotyping
- Objectifies or demeans individuals based on gender
- Promotes gender-based violence or inequality
- Uses derogatory language targeting a specific gender

Tweet: {text}

Respond with ONLY 'sexist' or 'not sexist', nothing else.

Classification:""",
    ),
    # Task/Input/Output framing with the tweet quoted.
    structured=dict(
        system=(
            "You are a classifier for detecting sexism in social media posts. "
            "You must respond with ONLY 'sexist' or 'not sexist'."
        ),
        user="""Task: Binary classification of sexism
Input: \"{text}\"
Output (ONLY 'sexist' or 'not sexist'):""",
    ),
)

print("Prompt strategies defined:")
print("\n".join(f"  - {strategy}" for strategy in prompt_templates))

## Classification Function

In [None]:
# Negated answers ("not sexist", "not-sexist", "non-sexist", "nonsexist", "isn't sexist")
# all contain the substring "sexist", so they must be recognised before the bare word.
_NOT_SEXIST_PATTERN = re.compile(r"\b(?:not|non|isn'?t)[\s_-]*sexist")


def _parse_sexism_label(response_text):
    """Map a raw LLM reply to 'sexist' / 'not sexist'; return None if neither is recognisable."""
    normalized = response_text.strip().lower()
    if _NOT_SEXIST_PATTERN.search(normalized):
        return 'not sexist'
    if 'sexist' in normalized:
        return 'sexist'
    return None


def _is_retryable(error):
    """Return False for HTTP client errors a retry cannot fix (e.g. 401, 404).

    Errors without a `status_code` attribute (network problems, timeouts, ...) and
    HTTP 408/409/429/5xx responses are treated as transient.
    """
    status = getattr(error, 'status_code', None)
    if status is None:
        return True
    return status in (408, 409, 429) or status >= 500


def classify_with_llm(texts, model_name, prompt_template, temperature=0.3, max_retries=3,
                      fallback_label='not sexist', backoff_seconds=1.0,
                      show_progress=True, llm_client=None):
    """
    Classify texts as 'sexist' / 'not sexist' using an LLM via the Groq API.

    Each text is sent as one chat completion. The reply is mapped to a label with
    negated forms ("not sexist", "non-sexist", "nonsexist", ...) checked before the
    bare word "sexist".

    Args:
        texts: Iterable of text strings to classify
        model_name: Name of the Groq model to use
        prompt_template: Dictionary with 'system' and 'user' prompt templates; the
            'user' template is filled with str.format(text=...)
        temperature: Sampling temperature (lower = more deterministic)
        max_retries: Maximum number of API attempts per text
        fallback_label: Label used when a reply has no recognisable label or when
            no attempt succeeded (defaults to the original 'not sexist')
        backoff_seconds: Base delay before retrying a failed request; doubles after
            each failure. 0 retries immediately.
        show_progress: Show a tqdm progress bar
        llm_client: Object exposing `chat.completions.create(...)`; defaults to the
            notebook-level Groq `client`

    Returns:
        List of predicted labels, one per input text, in input order

    Raises:
        KeyError / IndexError if the 'user' template has placeholders other than {text}
    """
    if llm_client is None:
        llm_client = client  # notebook-level Groq client, looked up at call time

    items = tqdm(texts, desc=f"Classifying with {model_name}") if show_progress else texts

    predictions = []
    n_failed = 0      # texts for which no request succeeded
    n_unparsed = 0    # replies that contained no recognisable label
    for text in items:
        # Formatting happens outside the retry loop: a broken template is a bug to
        # surface, not a transient API error to retry.
        user_prompt = prompt_template["user"].format(text=text)
        messages = [
            {"role": "system", "content": prompt_template["system"]},
            {"role": "user", "content": user_prompt},
        ]

        predicted_label = None
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                chat_completion = llm_client.chat.completions.create(
                    messages=messages,
                    model=model_name,
                    temperature=temperature,
                    max_completion_tokens=10,  # We only need a short response
                )
                # `content` may be None (e.g. an empty completion); treat that as no label.
                response_text = chat_completion.choices[0].message.content or ""
            except Exception as e:
                last_error = e
                if attempt == max_retries or not _is_retryable(e):
                    break
                if backoff_seconds > 0:
                    # Exponential backoff gives rate limits (HTTP 429) time to reset.
                    time.sleep(backoff_seconds * 2 ** (attempt - 1))
                continue

            predicted_label = _parse_sexism_label(response_text)
            if predicted_label is None:
                n_unparsed += 1
                predicted_label = fallback_label
            break

        if predicted_label is None:
            n_failed += 1
            if n_failed == 1:
                print(f"\nRequest failed for {model_name}: {last_error!r} "
                      f"(further failures are only counted)")
            predicted_label = fallback_label

        predictions.append(predicted_label)

    if n_failed or n_unparsed:
        print(f"\n⚠ {model_name}: {n_failed} text(s) had no successful request and "
              f"{n_unparsed} reply(ies) had no recognisable label; all were labelled "
              f"'{fallback_label}'.")

    return predictions

print("✓ Classification function defined")

## Test on Small Sample

First, let's test on a small sample to verify the approach works.

In [None]:
# Smoke-test the classifier on the first 10 dev examples before the full sweep.
test_sample = dev_df.head(10)
test_texts, test_labels = test_sample['text'].tolist(), test_sample['label_sexist'].tolist()

print("Testing on 10 samples...")
sample_predictions = classify_with_llm(
    test_texts,
    model_name="llama-3.3-70b-versatile",
    prompt_template=prompt_templates["simple"],
    temperature=0.3,
)

row_divider = "-" * 80
print("\nSample Results:")
print("=" * 80)
for text, gold, predicted in zip(test_texts, test_labels, sample_predictions):
    print(f"{'✓' if gold == predicted else '✗'} Text: {text[:60]}...")
    print(f"  True: {gold:12s} | Predicted: {predicted}")
    print(row_divider)

# Accuracy on this tiny sample is only a sanity check, not an estimate
sample_acc = accuracy_score(test_labels, sample_predictions)
print(f"\nSample Accuracy: {sample_acc:.4f}")

## Run Full Evaluation on Dev Set

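Now we evaluate every model × prompt combination on the development set. Each configuration sends at least one API request per dev example, so the full sweep can be slow and may run into Groq rate limits.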

In [None]:
# Optional cap on the number of dev examples to evaluate (e.g. 100 for a quick run).
# None evaluates the full dev set.
DEV_EVAL_LIMIT = None
dev_df_eval = dev_df.head(DEV_EVAL_LIMIT).copy() if DEV_EVAL_LIMIT is not None else dev_df.copy()

dev_texts = dev_df_eval['text'].tolist()
dev_labels = dev_df_eval['label_sexist'].tolist()

# classify_with_llm issues one request per text per configuration (more on retries)
n_configs = len(models_to_test) * len(prompt_templates)
print(f"Evaluating on {len(dev_texts)} dev samples")
print(f"Testing {len(models_to_test)} models × {len(prompt_templates)} prompts = {n_configs} configurations")
print(f"Minimum API requests: {n_configs * len(dev_texts)}\n")

In [None]:
# Evaluate every (model, prompt) combination on the dev set and keep per-configuration results.
all_results = {}

for model_name in models_to_test:
    for prompt_name, prompt_template in prompt_templates.items():
        config_name = f"{model_name}_{prompt_name}"

        print("\n" + "=" * 80)
        print(f"Configuration: {config_name}")
        print("=" * 80)

        # Get predictions
        predictions = classify_with_llm(
            dev_texts,
            model_name=model_name,
            prompt_template=prompt_template,
            temperature=0.3
        )

        # Support-weighted averages (the headline numbers used for ranking below).
        # zero_division=0 keeps sklearn's default value for a never-predicted class
        # but silences the UndefinedMetricWarning it would otherwise emit.
        precision, recall, f1, _ = precision_recall_fscore_support(
            dev_labels, predictions, average='weighted', zero_division=0
        )
        # Macro F1 weighs both classes equally, so it exposes configurations that
        # predict (almost) only the majority label, which weighted F1 can hide.
        _, _, f1_macro, _ = precision_recall_fscore_support(
            dev_labels, predictions, average='macro', zero_division=0
        )
        acc = accuracy_score(dev_labels, predictions)

        # Store results ('f1_macro' is an extra key; existing keys are unchanged)
        all_results[config_name] = {
            'model': model_name,
            'prompt': prompt_name,
            'predictions': predictions,
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'f1_macro': f1_macro,
        }

        print("\nResults:")
        print(f"  Accuracy:  {acc:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall:    {recall:.4f}")
        print(f"  F1-Score:  {f1:.4f}")
        print(f"  Macro-F1:  {f1_macro:.4f}")

        # A single predicted label may mean every request failed (e.g. an unavailable
        # model) and the classifier's fallback label was used throughout.
        if len(set(predictions)) < 2:
            print("  ⚠ Only one label was predicted; check for failed API calls "
                  "before trusting these scores.")

print("\n" + "=" * 80)
print("All configurations evaluated!")
print("=" * 80)

## Compare Results

In [None]:
# One row per evaluated configuration, best (support-weighted) F1 first.
comparison_df = (
    pd.DataFrame([
        {
            'Configuration': name,
            'Model': res['model'],
            'Prompt': res['prompt'],
            'Accuracy': res['accuracy'],
            'Precision': res['precision'],
            'Recall': res['recall'],
            'F1-Score': res['f1'],
        }
        for name, res in all_results.items()
    ])
    .sort_values('F1-Score', ascending=False)
)

table_rule = "=" * 100
print(table_rule)
print("COMPARISON OF ALL CONFIGURATIONS")
print(table_rule)
print(comparison_df.to_string(index=False))
print(table_rule)

# The first row after sorting is the winner
best_config = comparison_df.iloc[0]
print(f"\nBest Configuration: {best_config['Configuration']}")
for field in ('Model', 'Prompt'):
    print(f"  {field}: {best_config[field]}")
print(f"  F1-Score: {best_config['F1-Score']:.4f}")

## Visualization

In [None]:
# Four-panel summary of the sweep: F1 and accuracy per configuration (coloured by prompt),
# F1 grouped by model and prompt, and all metrics for the top-ranked configuration.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Ascending order puts the best configuration at the top of the horizontal bar charts
sorted_df = comparison_df.sort_values('F1-Score', ascending=True)

# Colour map by prompt type, shared by every panel. Prompts without a predefined colour
# get one from tab10 instead of raising a KeyError.
prompt_colors = {'simple': '#1f77b4', 'detailed': '#ff7f0e', 'structured': '#2ca02c'}
present_prompts = list(comparison_df['Prompt'].unique())
prompt_order = [p for p in prompt_colors if p in present_prompts]
prompt_order += [p for p in present_prompts if p not in prompt_colors]
for color_idx, prompt in enumerate(prompt_order):
    prompt_colors.setdefault(prompt, plt.cm.tab10(color_idx % 10))
colors = [prompt_colors[p] for p in sorted_df['Prompt']]


def _plot_metric_barh(ax, frame, column, bar_colors):
    """Draw one horizontal bar per configuration row of `frame` for `column`, annotated with its value."""
    positions = np.arange(len(frame))
    ax.barh(positions, frame[column], color=bar_colors, alpha=0.8)
    ax.set_yticks(positions)
    ax.set_yticklabels([f"{model.split('/')[-1][:15]}\n({prompt})"
                        for model, prompt in zip(frame['Model'], frame['Prompt'])], fontsize=8)
    ax.set_xlabel(column, fontsize=11)
    ax.set_title(f'{column} by Configuration', fontsize=13, fontweight='bold')
    ax.set_xlim([0, 1])
    ax.grid(True, alpha=0.3, axis='x')
    for pos, value in zip(positions, frame[column]):
        ax.text(value + 0.01, pos, f"{value:.3f}", va='center', fontsize=9)


# 1-2. F1-Score and Accuracy per configuration (bar colours follow the panel-3 legend)
_plot_metric_barh(axes[0, 0], sorted_df, 'F1-Score', colors)
_plot_metric_barh(axes[0, 1], sorted_df, 'Accuracy', colors)

# 3. F1-Score grouped by model, one bar per prompt; missing combinations plot as 0
models_unique = comparison_df['Model'].unique()
x = np.arange(len(models_unique))
# 0.2 per bar as before; shrinks only when more than four prompts share a group
width = 0.8 / max(len(prompt_order), 4)
f1_by_config = {
    (model, prompt): score
    for model, prompt, score in zip(comparison_df['Model'], comparison_df['Prompt'],
                                    comparison_df['F1-Score'])
}

for i, prompt in enumerate(prompt_order):
    f1_scores = [f1_by_config.get((m, prompt), 0) for m in models_unique]
    axes[1, 0].bar(x + i * width, f1_scores, width, label=prompt,
                   color=prompt_colors[prompt], alpha=0.8)

axes[1, 0].set_ylabel('F1-Score', fontsize=11)
axes[1, 0].set_title('F1-Score by Model and Prompt', fontsize=13, fontweight='bold')
# Centre each tick under its group of bars
axes[1, 0].set_xticks(x + width * (len(prompt_order) - 1) / 2)
axes[1, 0].set_xticklabels([m.split('/')[-1][:20] for m in models_unique],
                           rotation=45, ha='right', fontsize=9)
axes[1, 0].legend()
axes[1, 0].set_ylim([0, 1])
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. All metrics for best configuration
best_config_name = comparison_df.iloc[0]['Configuration']
best_results = all_results[best_config_name]
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
values = [best_results['accuracy'], best_results['precision'],
          best_results['recall'], best_results['f1']]

axes[1, 1].bar(metrics, values, color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'], alpha=0.8)
axes[1, 1].set_ylabel('Score', fontsize=11)
axes[1, 1].set_title(f'Best Configuration Metrics\n({best_config_name})',
                     fontsize=13, fontweight='bold')
axes[1, 1].set_ylim([0, 1])
axes[1, 1].grid(True, alpha=0.3, axis='y')
for i, v in enumerate(values):
    axes[1, 1].text(i, v + 0.02, f'{v:.4f}', ha='center', fontsize=10)

plt.tight_layout()
plt.savefig('llm_zero_shot_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nVisualization saved as 'llm_zero_shot_comparison.png'")

## Detailed Analysis of Best Model

In [None]:
# Per-class report and confusion matrix for the top-ranked configuration.
best_config_name = comparison_df.iloc[0]['Configuration']
best_predictions = all_results[best_config_name]['predictions']

# One explicit label order for the report, the matrix and the heatmap axes. Passing
# `labels=` keeps `target_names` aligned with the label values themselves instead of
# relying on sklearn's sorted order of whatever labels happen to occur.
class_labels = ['not sexist', 'sexist']

print("=" * 80)
print(f"DETAILED ANALYSIS: {best_config_name}")
print("=" * 80)

# Classification report
print("\nClassification Report:")
print("-" * 80)
print(classification_report(dev_labels, best_predictions,
                            labels=class_labels, target_names=class_labels,
                            zero_division=0))

# Confusion matrix (rows = true label, columns = predicted label)
cm = confusion_matrix(dev_labels, best_predictions, labels=class_labels)
print("\nConfusion Matrix:")
print(cm)

# Visualize confusion matrix
fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels, ax=ax_cm)
ax_cm.set_ylabel('True Label')
ax_cm.set_xlabel('Predicted Label')
ax_cm.set_title(f'Confusion Matrix - {best_config_name}')
fig_cm.tight_layout()
fig_cm.savefig('confusion_matrix_best_llm.png', dpi=300, bbox_inches='tight')
plt.show()

## Save Results

In [None]:
# Save comparison results
comparison_df.to_csv('llm_zero_shot_comparison_results.csv', index=False)
print("✓ Comparison results saved to 'llm_zero_shot_comparison_results.csv'")

# Save best model predictions
best_config_name = comparison_df.iloc[0]['Configuration']
best_predictions = all_results[best_config_name]['predictions']

# Built column-wise; dev_labels and best_predictions follow dev_df_eval's row order.
# NOTE(review): relies on the dataset having an 'id' column — confirm.
output_df = pd.DataFrame({
    'id': dev_df_eval['id'].to_numpy(),
    'text': dev_df_eval['text'].to_numpy(),
    'true_label': dev_labels,
    'predicted_label': best_predictions,
})
output_df['correct'] = output_df['true_label'] == output_df['predicted_label']

# The config name embeds the model id; characters such as '/' would turn it into a
# nonexistent sub-directory, so keep only filename-safe characters.
safe_config_name = re.sub(r'[^A-Za-z0-9._-]+', '_', best_config_name)
predictions_path = f'predictions_llm_{safe_config_name}.csv'
output_df.to_csv(predictions_path, index=False)
print(f"✓ Best model predictions saved to '{predictions_path}'")

## Example Predictions Analysis

In [None]:
# Show some example predictions
def _print_examples(title, rows, keep, limit, show_prediction=False):
    """Print `title` and the first `limit` (text, true, pred) rows for which keep(true, pred) holds."""
    print(title)
    print("-" * 100)
    shown = 0
    for text, true, pred in rows:
        if shown >= limit:
            break  # stop scanning once enough examples have been printed
        if not keep(true, pred):
            continue
        print(f"Text: {text[:80]}...")
        if show_prediction:
            print(f"True: {true:12s} | Predicted: {pred}\n")
        else:
            print(f"Label: {true}\n")
        shown += 1


example_rows = list(zip(dev_texts, dev_labels, best_predictions))

print("Example Predictions from Best Model:")
print("=" * 100)

_print_examples("\n✓ CORRECT PREDICTIONS (Sexist):", example_rows,
                lambda true, pred: true == pred == 'sexist', limit=3)
_print_examples("\n✓ CORRECT PREDICTIONS (Not Sexist):", example_rows,
                lambda true, pred: true == pred == 'not sexist', limit=3)
_print_examples("\n✗ INCORRECT PREDICTIONS:", example_rows,
                lambda true, pred: true != pred, limit=5, show_prediction=True)