# In [None]:
# -*- coding: utf-8 -*-
"""
1_Data_Exploration.ipynb

This notebook explores and provides statistics for the processed dataset used in the text summarization project.
It loads the processed data, displays basic information, and analyzes text length distributions.
"""

# Import necessary libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys

# Add parent directory to path
sys.path.append('..')

from src.prepare_data import load_processed_data
from datasets import Dataset

# --- Load Data ---
# Load each split through the project helper `load_processed_data`, then
# convert to pandas (via `.to_pandas()`) for easier slicing and describing.
# NOTE(review): presumably the helper returns a Hugging Face `datasets.Dataset`
# (that class is imported above) — confirm against src/prepare_data.py.
print("Loading processed datasets...")
train_dataset = load_processed_data("train")
val_dataset = load_processed_data("validation")
test_dataset = load_processed_data("test")

# Convert to pandas for easier analysis
train_df = train_dataset.to_pandas()
val_df = val_dataset.to_pandas()
test_df = test_dataset.to_pandas()

# These are the processed splits, so the message says so; it must match the
# "Loading processed datasets..." line above.
print("Processed datasets loaded successfully.")

# Pool all splits for corpus-wide length statistics. `ignore_index=True`
# yields a fresh 0..N-1 index instead of repeating each split's own index.
all_data_df = pd.concat([train_df, val_df, test_df], ignore_index=True)

# --- Basic Statistics ---
print("\n--- Basic Statistics ---")

# Report the (rows, columns) shape of each split, then the pooled row count.
split_frames = (
    ("Train", train_df),
    ("Validation", val_df),
    ("Test", test_df),
)
for split_label, split_frame in split_frames:
    print(f"{split_label} dataset shape: {split_frame.shape}")
print(f"Total data points: {len(all_data_df)}")

# Peek at the first rows of the training split, then its per-column dtypes
# and non-null counts (DataFrame.info writes directly to stdout).
print("\n--- Sample Data (Train) ---")
print(train_df.head())

print("\n--- Data Info (Train) ---")
train_df.info()

# --- Text Length Analysis ---
print("\n--- Text Length Analysis ---")

# Word count = number of tokens from whitespace splitting (str.split() with no
# argument). Both length columns are added before any statistics are printed.
# NOTE(review): assumes every 'article'/'summary' value is a str; a missing
# value (None/NaN) would raise AttributeError here — confirm the data is clean.
length_columns = (("article", "article_len"), ("summary", "summary_len"))
for text_column, length_column in length_columns:
    all_data_df[length_column] = all_data_df[text_column].apply(
        lambda text: len(text.split())
    )

stat_sections = (
    ("article_len", "\nArticle Length Statistics (in words):"),
    ("summary_len", "\nSummary Length Statistics (in words):"),
)
for length_column, heading in stat_sections:
    print(heading)
    print(all_data_df[length_column].describe())

# --- Visualization: Length Distributions ---
print("\nGenerating length distribution plots...")

# Side-by-side histograms: article word counts (left), summary word counts
# (right). Explicit Axes objects avoid relying on pyplot's implicit "current
# axes" state.
fig, (article_ax, summary_ax) = plt.subplots(1, 2, figsize=(14, 6))

panels = (
    (article_ax, "article_len", "skyblue", "Distribution of Article Lengths"),
    (summary_ax, "summary_len", "lightcoral", "Distribution of Summary Lengths"),
)
for ax, length_column, bar_color, title in panels:
    sns.histplot(all_data_df[length_column], bins=50, color=bar_color, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Number of Words")
    ax.set_ylabel("Frequency")

fig.tight_layout()

# Save figures to report/figures.
# NOTE(review): this path is relative to the current working directory. The
# `sys.path.append('..')` at the top suggests the notebook runs from a
# subdirectory, in which case the figure lands in <subdir>/report/figures
# rather than a top-level report/figures — confirm which is intended.
figures_dir = os.path.join("report", "figures")
os.makedirs(figures_dir, exist_ok=True)

# Build the output path once and reuse it. Nesting a double-quoted call inside
# a double-quoted f-string is a SyntaxError before Python 3.12.
figure_path = os.path.join(figures_dir, "length_distributions.png")
fig.savefig(figure_path)
print(f"Length distribution plots saved to {figure_path}")
plt.show()

print("Data exploration complete.")
