# In [27]:
# Import Required Libraries
import os
import json
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd
from transformers import DebertaV2TokenizerFast

def load_data(file_path):
    """Return the object parsed from the JSON file at ``file_path``.

    The file is opened in text mode and decoded as UTF-8.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def analyze_data(data, tokenizer):
    """Gather raw per-instance statistics for one dataset split.

    Each element of ``data`` is read through its ``'tokenized_text'`` and
    ``'ner'`` keys and, when present, its ``'negatives'`` key.

    Returns a dict with:
      * ``'num_instances'``  - ``len(data)``
      * ``'token_lengths'``  - per instance, the number of tokens produced by
        ``tokenizer.tokenize`` on the pre-split text (no special tokens)
      * ``'num_entities'``   - per instance, ``len(instance['ner'])``
      * ``'entity_types'``   - Counter over the third element of every
        ``'ner'`` entry (used downstream as the entity type)
      * ``'negative_types'`` - Counter over every item of ``'negatives'``
    """
    instance_total = len(data)

    lengths = []
    entity_counts = []
    positive_counter = Counter()
    negative_counter = Counter()

    for record in data:
        pieces = tokenizer.tokenize(
            record['tokenized_text'],
            is_split_into_words=True,
            add_special_tokens=False,
        )
        lengths.append(len(pieces))

        spans = record['ner']
        entity_counts.append(len(spans))
        positive_counter.update(span[2] for span in spans)

        # 'negatives' is optional; a missing key contributes nothing.
        negative_counter.update(record.get('negatives', []))

    return {
        'num_instances': instance_total,
        'token_lengths': lengths,
        'num_entities': entity_counts,
        'entity_types': positive_counter,
        'negative_types': negative_counter,
    }

def collect_summary_statistics(stats, split_name):
    """Build the one-row summary for a split.

    Args:
        stats: Mapping with ``'num_instances'``, ``'token_lengths'`` and
            ``'num_entities'`` entries (the latter two are sequences of
            numbers).
        split_name: Label stored under the ``'Split'`` key.

    Returns:
        A dict with the split name, the instance count, and the mean token
        length and mean entity count. For an empty split both means are
        reported as ``0.0``.
    """
    def _mean(values):
        # An empty split (e.g. a JSON file holding []) has nothing to
        # average; report 0.0 rather than raising ZeroDivisionError.
        return sum(values) / len(values) if values else 0.0

    return {
        'Split': split_name,
        'Number of Instances': stats['num_instances'],
        'Average Token Length': _mean(stats['token_lengths']),
        'Average Number of Entities': _mean(stats['num_entities']),
    }

def collect_detailed_statistics(stats, split_name):
    """Return one row per entity type found in ``stats['entity_types']``.

    Each row carries the split name, the entity type, its count from
    ``stats['entity_types']`` and its count from ``stats['negative_types']``
    (0 when the type is absent there). Types that occur only in
    ``stats['negative_types']`` produce no row.
    """
    positives = stats['entity_types']
    negatives = stats['negative_types']
    return [
        {
            'Split': split_name,
            'Entity Type': label,
            'Number of Positive Instances': positive_total,
            'Number of Negative Instances': negatives.get(label, 0),
        }
        for label, positive_total in positives.items()
    ]

def _save_histogram_panel(all_stats, splits, output_path):
    """Save a 2-row histogram panel (token lengths, entity counts) as one PNG.

    One column is drawn per entry of ``splits``; a split with no entry in
    ``all_stats`` leaves its column of axes empty.
    """
    fig, axes = plt.subplots(2, len(splits), figsize=(18, 10), squeeze=False)
    for column, split in enumerate(splits):
        if split not in all_stats:
            continue
        stats = all_stats[split]

        # Token Length Histogram
        length_ax = axes[0, column]
        length_ax.hist(stats['token_lengths'], bins=20, edgecolor='black')
        length_ax.set_title(f'Token Lengths in {split} Split')
        length_ax.set_xlabel('Token Length')
        length_ax.set_ylabel('Frequency')

        # Number of Entities Histogram
        entity_ax = axes[1, column]
        entity_ax.hist(stats['num_entities'], bins=20, edgecolor='black')
        entity_ax.set_title(f'Number of Entities in {split} Split')
        entity_ax.set_xlabel('Number of Entities')
        entity_ax.set_ylabel('Frequency')

    # Act on the figure object explicitly instead of relying on pyplot's
    # implicit "current figure", so unrelated open figures are unaffected.
    fig.tight_layout()
    fig.savefig(output_path, format='png')
    plt.close(fig)


def main(input_folder, output_folder):
    """Run the EDA for the train/val/test splits found in ``input_folder``.

    For every ``<split>.json`` that exists, statistics are computed and
    collected; a missing file is reported on stdout and skipped. The
    following files are written to ``output_folder`` (created if needed),
    each prefixed with the name of the input folder:

      * ``<name>_summary.csv``           - one row per available split
      * ``<name>_<split>_detailed.csv``  - per-entity-type counts, only for
        splits that produced at least one row
      * ``<name>_panel_histograms.png``  - histogram panel for all splits

    Args:
        input_folder: Directory containing ``train.json``, ``val.json``
            and/or ``test.json``.
        output_folder: Directory that receives the CSV and PNG outputs.

    Returns:
        ``(summary_df, detailed_dfs)`` where ``detailed_dfs`` maps split
        name to its detailed DataFrame.
    """
    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    # basename() is '' for a path with a trailing separator ("data/x/"),
    # which would yield files such as "_summary.csv"; fall back to the last
    # real component in that case.
    dataset_name = (
        os.path.basename(input_folder)
        or os.path.basename(os.path.dirname(input_folder))
    )

    tokenizer = DebertaV2TokenizerFast.from_pretrained("microsoft/deberta-v3-large")
    splits = ['train', 'val', 'test']
    all_stats = {}
    summary_stats = []
    detailed_stats = {split: [] for split in splits}

    for split in splits:
        file_path = os.path.join(input_folder, f"{split}.json")
        if not os.path.exists(file_path):
            print(f"File {file_path} does not exist.")
            continue
        data = load_data(file_path)
        stats = analyze_data(data, tokenizer)
        all_stats[split] = stats
        summary_stats.append(collect_summary_statistics(stats, split))
        detailed_stats[split].extend(collect_detailed_statistics(stats, split))

    # Create and save the summary table
    summary_df = pd.DataFrame(summary_stats)
    summary_output_path = os.path.join(output_folder, f"{dataset_name}_summary.csv")
    summary_df.to_csv(summary_output_path, index=False)

    # Create and save the detailed tables for each split
    detailed_dfs = {
        split: pd.DataFrame(detailed_stats[split])
        for split in splits
        if detailed_stats[split]
    }
    for split, df in detailed_dfs.items():
        detailed_output_path = os.path.join(
            output_folder, f"{dataset_name}_{split}_detailed.csv"
        )
        df.to_csv(detailed_output_path, index=False)

    # Create and save the panel as a single image
    panel_output_path = os.path.join(
        output_folder, f"{dataset_name}_panel_histograms.png"
    )
    _save_histogram_panel(all_stats, splits, panel_output_path)

    return summary_df, detailed_dfs

# Where the split JSON files are read from, and where the tables/figure go
input_folder = "./data/tac_chunked"
output_folder = "./eda_output/tac_eda"

# Run the analysis and keep the resulting tables available at module level
summary_df, detailed_dfs = main(
    input_folder=input_folder,
    output_folder=output_folder,
)