In [None]:
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings('ignore')

# Plot appearance: select matplotlib's "default" style sheet, then make
# "husl" the active seaborn colour palette.
plt.style.use("default")
sns.set_palette(palette="husl")

# ----------------------------
# Configuration
# ----------------------------
# Input file for each prompting strategy (keys double as shot-type labels)
json_files = dict(
    zero_shot="Zero-Shot_qa_dataset.json",
    one_shot="One-Shot_qa_dataset.json",
    few_shot="Few-Shot_qa_dataset.json",
)

QUESTION_TYPES = ["factual", "relationship", "comparative", "inferential"]
SHOT_TYPES = ["zero_shot", "one_shot", "few_shot"]

# Fraction of every dataset that goes into the human-evaluation sample
SAMPLE_PERCENTAGE = 0.05
DATASET_SIZES = dict(zero_shot=1127, one_shot=1080, few_shot=1107)

# Per shot type: the overall quota (dataset size * fraction, truncated) and
# the per-question-type quota (overall quota split evenly, never below 1).
SAMPLES_PER_SHOT_TYPE = {}
SAMPLES_PER_COMBINATION = {}
for _shot, _size in DATASET_SIZES.items():
    _quota = int(_size * SAMPLE_PERCENTAGE)
    SAMPLES_PER_SHOT_TYPE[_shot] = _quota
    SAMPLES_PER_COMBINATION[_shot] = max(1, _quota // len(QUESTION_TYPES))

TOTAL_TARGET_SAMPLES = sum(SAMPLES_PER_SHOT_TYPE.values())

# Report the sampling plan
print(f" Target: 5% from each dataset")
print(f" Samples per shot type:")
for shot_type in SAMPLES_PER_SHOT_TYPE:
    quota = SAMPLES_PER_SHOT_TYPE[shot_type]
    share = SAMPLES_PER_COMBINATION[shot_type]
    print(f" {shot_type}: {quota} total ({share} per question type)")
print(f" Total target samples: {TOTAL_TARGET_SAMPLES}")

# ----------------------------
# Extract and Format Triples for Human Readability
# ----------------------------
def extract_all_triples(obj):
    """Collect every subject/predicate/object triple found anywhere in *obj*.

    Dicts and lists are walked recursively; any other value contributes
    nothing. A dict that has all three keys ``subject``, ``predicate`` and
    ``object`` yields one ``"subject → predicate → object"`` string, and its
    values are still searched for further nested triples afterwards.

    Returns:
        A list of formatted triple strings in depth-first, pre-order
        traversal order.
    """
    triple_keys = ("subject", "predicate", "object")

    def _walk(node):
        if isinstance(node, dict):
            if all(key in node for key in triple_keys):
                yield f"{node['subject']} → {node['predicate']} → {node['object']}"
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            # Scalars (and any other container type) hold no triples.
            return
        for child in children:
            yield from _walk(child)

    return list(_walk(obj))

def format_kg_for_humans(triples_string):
    """Convert semicolon-separated triples into a numbered, one-per-line list.

    Each non-blank segment between semicolons becomes one line of the form
    ``"<n>. <triple>"``, numbered from 1. A segment written as
    ``"subject - predicate - object"`` (exactly three `` - ``-separated
    parts) is rewritten with arrows; every other segment is kept verbatim.

    Args:
        triples_string: Triples joined by ``;``. ``None`` and empty or
            whitespace-only strings are accepted.

    Returns:
        The numbered lines joined with newlines, or ``""`` when there is
        nothing to format.
    """
    if not triples_string or not triples_string.strip():
        return ""

    segments = [part.strip() for part in triples_string.split(";") if part.strip()]

    numbered = []
    for position, triple in enumerate(segments, 1):
        # A triple that already contains an arrow is in its final form. Running
        # the dash rewrite on it would tear apart entity names that contain
        # " - " (e.g. "X - Y → p → Z - W"), so only dash-only triples are
        # converted.
        if "→" not in triple:
            parts = triple.split(" - ")
            if len(parts) == 3:
                triple = " → ".join(part.strip() for part in parts)
        numbered.append(f"{position}. {triple}")

    return "\n".join(numbered)

# ----------------------------
# Load Data from JSON Files
# ----------------------------
def _as_text(value):
    """Return *value* as a string, mapping ``None`` (JSON ``null``) to ``""``."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def load_data(file_path, shot_type):
    """Load one QA dataset file and flatten it into a DataFrame.

    The file is expected to hold a JSON object whose ``"queries"`` entry is a
    list of QA items. Items whose ``question_type`` (lower-cased, stripped) is
    not in ``QUESTION_TYPES`` are skipped and tallied. For each kept item the
    knowledge-graph triples are extracted from the whole item and formatted as
    a numbered list, and ``source_context`` is read from ``ground_truth`` with
    a fallback to the item's root level.

    Args:
        file_path: Path of the JSON file to read.
        shot_type: Label stored in the ``shot_type`` column of every row.

    Returns:
        A DataFrame with the columns ``qa_id``, ``question_type``,
        ``shot_type``, ``question``, ``answer``, ``source_context`` and
        ``source_kg``. An empty DataFrame is returned when the file is
        missing, is not valid JSON, is not a JSON object, or yields no rows.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f" File not found: {file_path}")
        return pd.DataFrame()
    except json.JSONDecodeError:
        print(f" Invalid JSON in file: {file_path}")
        return pd.DataFrame()

    # A top-level array or scalar has no "queries" entry to read; report it
    # the same way as the other unusable-file cases instead of crashing.
    if not isinstance(data, dict):
        print(f" Unexpected JSON structure in file: {file_path}")
        return pd.DataFrame()

    # `or []` also covers an explicit `"queries": null`.
    queries = data.get("queries") or []
    rows = []
    skipped_types = Counter()
    total_items = len(queries)

    for item in queries:
        # A null question_type is treated like a missing one.
        qtype = _as_text(item.get("question_type")).lower().strip()

        # Track what question types we're seeing
        if qtype not in QUESTION_TYPES:
            skipped_types[qtype] += 1
            continue

        # Extract triples from anywhere inside the item, then number them
        # for human readers.
        triples = extract_all_triples(item)
        source_kg = format_kg_for_humans("; ".join(triples))

        # The source_context normally lives under ground_truth ...
        source_context = ""
        ground_truth = item.get("ground_truth")
        if isinstance(ground_truth, dict):
            source_context = ground_truth.get("source_context", "")

        # ... with a root-level fallback for backward compatibility.
        if not source_context:
            source_context = item.get("source_context", "")

        # Normalise to text so the empty-context check below cannot fail on a
        # null or non-string value.
        source_context = _as_text(source_context)

        rows.append({
            "qa_id": item.get("id", ""),
            "question_type": qtype,
            "shot_type": shot_type,
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "source_context": source_context,
            "source_kg": source_kg
        })

    df = pd.DataFrame(rows)
    print(f" {shot_type}: Loaded {len(df)} valid QA pairs out of {total_items} total")

    # Debug: Check how many have empty source_context
    empty_context_count = sum(1 for row in rows if not row['source_context'].strip())
    if empty_context_count > 0:
        print(f"   {empty_context_count} entries have empty source_context")
        # Show a sample entry to debug structure
        if rows:
            print(f" Sample entry keys: {list(rows[0].keys())}")
            print(f" Sample source_context length: {len(rows[0]['source_context'])}")

    # Show what was skipped
    if skipped_types:
        print(f"   Skipped question types: {dict(skipped_types)}")

    # Show distribution of question types
    if len(df) > 0:
        type_counts = df["question_type"].value_counts()
        print(f"  Question type distribution: {dict(type_counts)}")

    return df

# ----------------------------
# Stratified Sampling Function
# ----------------------------
def stratified_sample_per_type(df, shot_type, samples_per_type):
    """Draw up to *samples_per_type* rows for every question type in *df*.

    For each entry of ``QUESTION_TYPES`` the matching rows are selected. When
    at least ``samples_per_type`` rows exist, that many are drawn without
    replacement using a fixed seed; when fewer exist, all of them are taken
    and a notice is printed; when none exist, the type is skipped.

    Args:
        df: Rows of a single shot type; must have a ``question_type`` column.
        shot_type: Label used only in the progress messages.
        samples_per_type: Number of rows wanted per question type.

    Returns:
        The selected rows concatenated with a fresh index, or an empty
        DataFrame when no question type had any rows.
    """
    picks = []

    for qtype in QUESTION_TYPES:
        matches = df["question_type"] == qtype
        candidates = df[matches]
        available = len(candidates)

        if not available:
            print(f"  {shot_type} - {qtype}: No samples available")
            continue

        if available >= samples_per_type:
            # Enough rows: seeded random draw without replacement
            chosen = candidates.sample(n=samples_per_type, random_state=42)
        else:
            print(f"  {shot_type} - {qtype}: Only {available} samples available (need {samples_per_type})")
            # Too few rows: keep everything there is
            chosen = candidates.copy()

        picks.append(chosen)
        print(f" {shot_type} - {qtype}: Sampled {len(chosen)} samples")

    return pd.concat(picks, ignore_index=True) if picks else pd.DataFrame()

# ----------------------------
# Load All Data
# ----------------------------
print("\n Loading data from files...")

# One DataFrame per input file that produced at least one valid row.
all_dfs = []

for shot_type, file_path in json_files.items():
    df = load_data(file_path, shot_type)
    if len(df) > 0:
        all_dfs.append(df)

if not all_dfs:
    print(" No valid data loaded from any file!")
    # `exit()` is an interactive-shell helper installed by the `site` module
    # and is not guaranteed to exist; raising SystemExit stops the run with
    # status 1 without depending on it.
    raise SystemExit(1)

# ----------------------------
# Perform Stratified Sampling
# ----------------------------
print(f"\n Performing stratified sampling (5% from each dataset)...")

# One sampled DataFrame per shot type that yielded at least one row.
sampled_dfs = []

# Sample from the frames already held in `all_dfs` rather than reading and
# parsing every JSON file a second time (which also repeated all of the
# loader's diagnostics). Each frame there is non-empty and carries a single
# shot type in its `shot_type` column, so the label is read from its first row.
for df in all_dfs:
    shot_type = df["shot_type"].iloc[0]
    samples_per_qtype = SAMPLES_PER_COMBINATION[shot_type]
    sampled = stratified_sample_per_type(df, shot_type, samples_per_qtype)
    if len(sampled) > 0:
        sampled_dfs.append(sampled)

# ----------------------------
# Combine and Shuffle
# ----------------------------
# Final stage: merge the per-shot-type samples, export them, print a balance
# report, draw a six-panel overview figure and export a per-combination
# statistics table. Nothing is written when no samples were drawn.
if sampled_dfs:
    final_df = pd.concat(sampled_dfs, ignore_index=True)

    # Shuffle with a fixed seed so rows are not grouped by shot type while
    # the exported order stays the same for the same input.
    final_df = final_df.sample(frac=1, random_state=42).reset_index(drop=True)

    # Reorder columns for better readability
    column_order = ["qa_id", "question_type", "shot_type", "question", "answer", "source_context", "source_kg"]
    final_df = final_df[column_order]

    # ----------------------------
    # Save to CSV
    # ----------------------------
    output_file = "QA_Human_Eval_Stratified_5percent.csv"
    final_df.to_csv(output_file, index=False, encoding='utf-8')

    # ----------------------------
    # Analysis and Summary
    # ----------------------------
    print("\n STRATIFIED SAMPLING RESULTS:")
    print("="*50)

    # Cross-tabulation showing distribution (with row/column totals)
    cross_tab = pd.crosstab(final_df['shot_type'], final_df['question_type'], margins=True)
    print("\n Distribution Matrix:")
    print(cross_tab)

    # Rows per (shot type, question type) pair; one entry per pair present
    combination_counts = final_df.groupby(['shot_type', 'question_type']).size()

    # Summary statistics
    print("\n Summary:")
    print(f"   • Total samples: {len(final_df)}")
    print(f"   • Target samples: {TOTAL_TARGET_SAMPLES}")
    print(f"   • Shot types: {final_df['shot_type'].nunique()}")
    print(f"   • Question types: {final_df['question_type'].nunique()}")
    print(f"   • Combinations covered: {len(combination_counts)}")

    # Check balance
    print("\n⚖️  Balance check:")
    print(f"   • Min samples per combination: {combination_counts.min()}")
    print(f"   • Max samples per combination: {combination_counts.max()}")
    print(f"   • Mean samples per combination: {combination_counts.mean():.1f}")

    # Check if balance is within expected range for proportional sampling
    max_diff = combination_counts.max() - combination_counts.min()
    if max_diff <= 1:  # Allow for 1 sample difference due to proportional sampling
        print(" Excellent stratified balance achieved!")
    elif max_diff <= 2:
        print(" Good stratified balance achieved!")
    else:
        print("  Significant imbalance detected - check data distribution")

    print(f"\n Saved to: {output_file}")
    print(" Stratified sampling completed successfully!")

    # ----------------------------
    # Create Visualizations
    # ----------------------------
    print("\n Creating visualizations...")

    # One fixed colour per question type, so a colour means the same thing in
    # every panel regardless of how each panel happens to order its data.
    base_palette = ['#96CEB4', '#FECA57', '#FF9FF3', '#54A0FF']
    qtype_palette = {
        qtype: base_palette[i % len(base_palette)]
        for i, qtype in enumerate(QUESTION_TYPES)
    }

    # Question-type totals in the canonical QUESTION_TYPES order (types that
    # did not make it into the sample are left out), plus matching colours.
    question_counts = final_df['question_type'].value_counts()
    present_qtypes = [q for q in QUESTION_TYPES if q in question_counts.index]
    question_counts = question_counts.reindex(present_qtypes)
    qtype_colors = [qtype_palette[q] for q in present_qtypes]

    # Create a figure with multiple subplots
    fig = plt.figure(figsize=(16, 12))

    # 1. Distribution Matrix Heatmap
    plt.subplot(2, 3, 1)
    cross_tab_no_margins = pd.crosstab(final_df['shot_type'], final_df['question_type'])
    sns.heatmap(cross_tab_no_margins, annot=True, fmt='d', cmap='Blues',
                cbar_kws={'label': 'Number of Samples'})
    plt.title('Sample Distribution Matrix\n(Shot Type × Question Type)', fontsize=12, fontweight='bold')
    plt.xlabel('Question Type', fontweight='bold')
    plt.ylabel('Shot Type', fontweight='bold')

    # 2. Samples per Shot Type
    plt.subplot(2, 3, 2)
    shot_counts = final_df['shot_type'].value_counts()
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    bars = plt.bar(shot_counts.index, shot_counts.values, color=colors, alpha=0.8, edgecolor='black')
    plt.title('Samples per Shot Type\n(5% from each dataset)', fontsize=12, fontweight='bold')
    plt.xlabel('Shot Type', fontweight='bold')
    plt.ylabel('Number of Samples', fontweight='bold')
    plt.xticks(rotation=45)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{int(height)}', ha='center', va='bottom', fontweight='bold')

    # 3. Samples per Question Type
    plt.subplot(2, 3, 3)
    bars = plt.bar(question_counts.index, question_counts.values, color=qtype_colors,
                   alpha=0.8, edgecolor='black')
    plt.title('Samples per Question Type\n(Equal distribution)', fontsize=12, fontweight='bold')
    plt.xlabel('Question Type', fontweight='bold')
    plt.ylabel('Number of Samples', fontweight='bold')
    plt.xticks(rotation=45)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{int(height)}', ha='center', va='bottom', fontweight='bold')

    # 4. Stacked Bar Chart
    plt.subplot(2, 3, 4)
    pivot_data = final_df.pivot_table(index='shot_type', columns='question_type',
                                     values='qa_id', aggfunc='count', fill_value=0)
    # Colours are looked up by column name, not by column position.
    pivot_data.plot(kind='bar', stacked=True, ax=plt.gca(),
                   color=[qtype_palette[q] for q in pivot_data.columns], alpha=0.8)
    plt.title('Stacked Distribution\n(Question Types within Shot Types)', fontsize=12, fontweight='bold')
    plt.xlabel('Shot Type', fontweight='bold')
    plt.ylabel('Number of Samples', fontweight='bold')
    plt.legend(title='Question Type', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)

    # 5. Comparison: Original vs Sampled
    plt.subplot(2, 3, 5)
    shot_keys = ['zero_shot', 'one_shot', 'few_shot']
    shot_types_labels = ['Zero-Shot', 'One-Shot', 'Few-Shot']
    original_sizes = [DATASET_SIZES[key] for key in shot_keys]
    # Plot the rows that were actually sampled, not the planned quota: the
    # two differ whenever the quota is not divisible by the number of
    # question types or a question type ran short of rows.
    sampled_sizes = [int(shot_counts.get(key, 0)) for key in shot_keys]

    x = np.arange(len(shot_types_labels))
    width = 0.35

    bars1 = plt.bar(x - width/2, original_sizes, width, label='Original Dataset', color='lightcoral', alpha=0.8)
    bars2 = plt.bar(x + width/2, sampled_sizes, width, label='5% Sample', color='lightblue', alpha=0.8)

    plt.title('Dataset Size Comparison\n(Original vs 5% Sample)', fontsize=12, fontweight='bold')
    plt.xlabel('Shot Type', fontweight='bold')
    plt.ylabel('Number of Samples', fontweight='bold')
    plt.xticks(x, shot_types_labels)
    plt.legend()
    plt.yscale('log')  # Log scale to show both large and small numbers

    # Add value labels
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            if height <= 0:
                # A zero-height bar has no finite position on a log axis.
                continue
            plt.text(bar.get_x() + bar.get_width()/2., height * 1.1,
                    f'{int(height)}', ha='center', va='bottom', fontsize=9)

    # 6. Pie Chart of Question Type Distribution
    plt.subplot(2, 3, 6)
    plt.pie(question_counts.values, labels=question_counts.index, autopct='%1.1f%%',
            colors=qtype_colors, startangle=90)
    plt.title('Question Type Distribution\n(Percentage of Total Sample)', fontsize=12, fontweight='bold')

    plt.tight_layout()

    # Save the visualization
    viz_filename = "QA_Stratified_Sampling_Visualization.png"
    plt.savefig(viz_filename, dpi=300, bbox_inches='tight')
    print(f" Visualizations saved to: {viz_filename}")

    # Show the plot
    plt.show()

    # ----------------------------
    # Create Summary Statistics Table
    # ----------------------------
    print("\n Creating summary statistics table...")

    # One row per (shot type, question type) pair, including pairs with no
    # sampled rows, each with its share of the whole sample.
    summary_stats = []
    for shot_type in SHOT_TYPES:
        for qtype in QUESTION_TYPES:
            count = len(final_df[(final_df['shot_type'] == shot_type) & (final_df['question_type'] == qtype)])
            original_total = DATASET_SIZES[shot_type]
            percentage = (count / len(final_df)) * 100

            summary_stats.append({
                'Shot Type': shot_type,
                'Question Type': qtype,
                'Samples': count,
                'Original Dataset Size': original_total,
                'Percentage of Total Sample': f"{percentage:.1f}%"
            })

    stats_df = pd.DataFrame(summary_stats)

    # Save summary statistics
    stats_filename = "QA_Sampling_Summary_Statistics.csv"
    stats_df.to_csv(stats_filename, index=False)
    print(f" Summary statistics saved to: {stats_filename}")

    # Display the table
    print("\n Detailed Sampling Statistics:")
    print(stats_df.to_string(index=False))

    # Show first few rows as preview
    print("\n👀 Preview of first 5 rows:")
    print(final_df.head()[['qa_id', 'question_type', 'shot_type', 'question']].to_string())

else:
    print(" No samples could be extracted from any file!")