# In [None]:
import fitz  # pymupdf
import re
import os
from pathlib import Path
import pandas as pd
import csv
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from collections import Counter
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Download required NLTK data
# nltk.download() normally reports a failed download by returning False rather
# than raising, so the return values are checked in addition to catching
# unexpected exceptions. A failure here is only a warning (best effort).
try:
    _nltk_failed = [
        resource
        for resource in ('punkt', 'punkt_tab', 'stopwords')
        if not nltk.download(resource, quiet=True)
    ]
    if _nltk_failed:
        print("Warning: Could not download NLTK data")
except Exception:
    print("Warning: Could not download NLTK data")

# ML and Deep Learning imports
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from sklearn.model_selection import train_test_split
from transformers import (
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer
)

# Try to import evaluation metrics
# `evaluate` is an optional dependency: when the import fails, `rouge_available`
# is set to False and a warning is printed instead of aborting the script.
try:
    import evaluate
    rouge_available = True
except ImportError:
    rouge_available = False
    print("Warning: evaluate library not available, ROUGE scores will be skipped")

# Set style for plots
# Use matplotlib's 'default' style sheet and seaborn's "husl" colour palette
# for every figure created later in this file.
plt.style.use('default')
sns.set_palette("husl")

# 1. PDF utilities
def pdf_to_text(path):
    """Extract the plain text of a PDF as a single normalised string.

    Pages whose text is empty or whitespace-only are skipped; the remaining
    pages are joined with a blank line. Runs of horizontal whitespace are
    collapsed to one space and runs of blank lines to exactly one blank line.

    Args:
        path: Path of the PDF file to open.

    Returns:
        The extracted text, or "" if the file could not be opened or read
        (the error is printed, not raised).
    """
    try:
        doc = fitz.open(path)
        try:
            pages = []
            for page in doc:
                # Extract once per page instead of once for the filter and
                # once more for the value.
                page_text = page.get_text("text")
                if page_text.strip():
                    pages.append(page_text)
        finally:
            # Release the document even when text extraction fails midway.
            doc.close()
        full_text = "\n\n".join(pages)
        # Collapse spaces/tabs but leave line breaks alone.
        full_text = re.sub(r"[^\S\r\n]+", " ", full_text)
        # Collapse any run of blank (or whitespace-only) lines to one blank line.
        full_text = re.sub(r"\n\s*\n", "\n\n", full_text)
        return full_text
    except Exception as e:
        print(f"Error processing PDF {path}: {e}")
        return ""

def chunk_text(text, max_tokens=400, overlap=50):
    """Split text into overlapping chunks of whitespace-separated words.

    Args:
        text: The text to split.
        max_tokens: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared between consecutive chunks; must be
            smaller than ``max_tokens`` so that each chunk advances.

    Returns:
        A list of chunk strings (words re-joined with single spaces). Empty
        or whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``max_tokens`` is not positive or ``overlap`` is not
            smaller than ``max_tokens`` (either would keep the window from
            ever moving forward).
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if overlap >= max_tokens:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than max_tokens ({max_tokens})"
        )

    words = text.split()
    step = max_tokens - overlap
    chunks = []
    start = 0
    while start < len(words):
        end = min(start + max_tokens, len(words))
        chunks.append(" ".join(words[start:end]))
        # Stop once a chunk reaches the last word, so no trailing chunk that
        # consists only of overlap is emitted.
        if end >= len(words):
            break
        start += step
    return chunks

# 2. CSV Diagnostic Tool
def diagnose_csv_structure(file_path):
    """Inspect a CSV file and load it with the best-scoring pandas strategy.

    Prints the first five lines and counts candidate separators in the first
    line (informational only: the detected separator is not passed on to
    ``pd.read_csv``). It then tries three ``pd.read_csv`` configurations and
    scores each successful load by the number of non-null cells in its first
    two columns (0 when there are fewer than two columns).

    Args:
        file_path: Path of the CSV file.

    Returns:
        ``(DataFrame, method_name)`` for the highest-scoring strategy, where
        ``method_name`` is 'standard', 'quote_minimal' or 'quote_none' (the
        first strategy tried wins ties), or ``(None, None)`` if the file
        could not be read as UTF-8 text or every strategy failed.
    """
    print("="*80)
    print("CSV FILE DIAGNOSTIC TOOL")
    print("="*80)

    # Basic file inspection
    print("\n1. BASIC FILE INSPECTION")
    print("-" * 40)

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            first_lines = [f.readline().strip() for _ in range(5)]

        print("First 5 lines of the file:")
        for i, line in enumerate(first_lines):
            print(f"Line {i+1}: {line[:200]}..." if len(line) > 200 else f"Line {i+1}: {line}")

        separators = [',', ';', '\t', '|']
        separator_counts = {sep: first_lines[0].count(sep) for sep in separators}
        print(f"\nPotential separators found: {separator_counts}")
        likely_separator = max(separator_counts, key=separator_counts.get)
        print(f"Most likely separator: '{likely_separator}'")

    except Exception as e:
        print(f"Error reading file: {e}")
        return None, None

    # Try different loading methods
    print("\n2. TRYING DIFFERENT LOADING METHODS")
    print("-" * 40)

    # Single table of strategies, tried in this order.
    loader_kwargs = {
        'standard': {},
        'quote_minimal': {'quoting': csv.QUOTE_MINIMAL},
        'quote_none': {'quoting': csv.QUOTE_NONE, 'on_bad_lines': 'skip', 'engine': 'python'},
    }

    loading_results = {}
    best_df, best_method, best_score = None, None, None

    for method, kwargs in loader_kwargs.items():
        try:
            df = pd.read_csv(file_path, **kwargs)
            non_null_score = df.iloc[:, :2].notna().sum().sum() if df.shape[1] >= 2 else 0
            loading_results[method] = {
                'success': True,
                'shape': df.shape,
                'columns': df.columns.tolist()[:5],
                'score': non_null_score
            }
        except Exception as e:
            loading_results[method] = {'success': False, 'error': str(e)}
            continue

        # Keep the winning frame as we go so the file is not parsed again at
        # the end; strict '>' lets the earliest strategy win ties.
        if best_method is None or non_null_score > best_score:
            best_df, best_method, best_score = df, method, non_null_score

    for method, result in loading_results.items():
        print(f"\n{method.upper()}:")
        if result['success']:
            print(f"  Shape: {result['shape']}")
            print(f"  Columns: {result['columns']}")
            print(f"  Score: {result['score']}")
        else:
            print(f"  Failed: {result['error']}")

    return best_df, best_method

def analyze_columns_for_summarization(df, min_valid_pairs=1000):
    """Pick the most plausible (document, summary) column pair of a DataFrame.

    Per-column statistics are computed on the first 1000 non-null values of
    each column. For every pair of columns with enough rows where both are
    non-null, the column with the longer average text is treated as the
    document and the other as the summary. Pairs are scored by the number of
    valid rows, weighted towards a summary/document length ratio near 0.15.

    Args:
        df: The loaded DataFrame (may be None or empty).
        min_valid_pairs: Minimum number of rows in which both columns are
            non-null for a pair to be considered.

    Returns:
        A dict with keys 'document_col', 'summary_col', 'valid_pairs',
        'doc_avg_len', 'sum_avg_len', 'compression_ratio' and
        'quality_score' for the best pair, or None if there is no data or no
        pair qualifies.
    """
    from itertools import combinations

    print("\n3. COLUMN ANALYSIS FOR SUMMARIZATION")
    print("-" * 40)

    if df is None or len(df) == 0:
        print("No data to analyze")
        return None

    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")

    # Analyze columns (entirely-null columns are left out)
    column_stats = {}
    for col in df.columns:
        non_null_count = df[col].count()
        if non_null_count > 0:
            sample_texts = df[col].dropna().head(1000).astype(str)
            first_value = str(sample_texts.iloc[0])
            column_stats[col] = {
                'non_null_count': non_null_count,
                'avg_char_length': sample_texts.str.len().mean(),
                'avg_word_count': sample_texts.str.split().str.len().mean(),
                'sample': first_value[:150] + "..." if len(first_value) > 150 else first_value
            }

    print("\nColumn Statistics:")
    for col, stats in column_stats.items():
        print(f"\n{col}:")
        print(f"  Non-null: {stats['non_null_count']}")
        print(f"  Avg length: {stats['avg_char_length']:.1f} chars, {stats['avg_word_count']:.1f} words")
        print(f"  Sample: {stats['sample']}")

    # Find best document-summary pairs. The role assignment depends only on
    # the average lengths, so each unordered pair needs to be scored once.
    best_pairs = []
    for col1, col2 in combinations(column_stats, 2):
        valid_pairs = (df[col1].notna() & df[col2].notna()).sum()
        if valid_pairs < min_valid_pairs:
            continue

        stats1, stats2 = column_stats[col1], column_stats[col2]
        doc_col, sum_col = (col1, col2) if stats1['avg_char_length'] > stats2['avg_char_length'] else (col2, col1)
        doc_len, sum_len = column_stats[doc_col]['avg_char_length'], column_stats[sum_col]['avg_char_length']
        # +1 keeps the ratio finite for zero-length documents.
        length_ratio = sum_len / (doc_len + 1)
        quality_score = valid_pairs * (1 / (abs(length_ratio - 0.15) + 0.1))  # Target ~0.15 compression

        best_pairs.append({
            'document_col': doc_col,
            'summary_col': sum_col,
            'valid_pairs': valid_pairs,
            'doc_avg_len': doc_len,
            'sum_avg_len': sum_len,
            'compression_ratio': length_ratio,
            'quality_score': quality_score
        })

    if best_pairs:
        return max(best_pairs, key=lambda x: x['quality_score'])
    return None

def suggest_corrected_loading(file_path):
    """Run the CSV diagnostic and column analysis and print a recommendation.

    Args:
        file_path: Path of the CSV file to analyse.

    Returns:
        The best column mapping found by the column analysis, extended with a
        'loading_method' entry, or None when the file could not be loaded or
        no suitable document/summary pair exists.
    """
    frame, loader_name = diagnose_csv_structure(file_path)
    if frame is None:
        print("\nCould not load CSV file. Please check file path, format, and encoding.")
        return None

    mapping = analyze_columns_for_summarization(frame)
    if mapping is None:
        print("\nNo suitable document-summary pairs found.")
        return None

    rule = "="*80
    report_lines = [
        "\n" + rule,
        "RECOMMENDED CONFIGURATION",
        rule,
        f"Document column: '{mapping['document_col']}'",
        f"Summary column: '{mapping['summary_col']}'",
        f"Valid pairs: {mapping['valid_pairs']}",
        f"Avg document length: {mapping['doc_avg_len']:.1f} chars",
        f"Avg summary length: {mapping['sum_avg_len']:.1f} chars",
        f"Compression ratio: {mapping['compression_ratio']:.3f}",
    ]
    for report_line in report_lines:
        print(report_line)

    # The loading method goes first; the mapping's own keys follow it.
    recommendation = {'loading_method': loader_name}
    recommendation.update(mapping)
    return recommendation

# 3. Data Loading
def load_and_clean_data_corrected(file_path, force_columns=None, chunksize=10000):
    """Load a CSV in chunks and return cleaned 'document'/'summary' columns.

    Args:
        file_path: Path of the CSV file.
        force_columns: Mapping with 'document_col' and 'summary_col' (and
            optionally 'loading_method'). When falsy, the diagnostic is run
            to find a mapping; if none is found the original loader is used.
        chunksize: Number of rows read per chunk.

    Returns:
        A DataFrame with columns 'document' and 'summary', without missing
        values and without rows that are empty after stripping whitespace.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    print(f"Loading data from: {file_path}")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"CSV file not found: {file_path}")

    if not force_columns:
        print("No column mapping provided. Running diagnostic...")
        detected = suggest_corrected_loading(file_path)
        if detected is None:
            print("Falling back to original method...")
            return load_and_clean_data_original(file_path)
        return load_and_clean_data_corrected(file_path, detected, chunksize)

    loader_name = force_columns.get('loading_method', 'standard')
    doc_col = force_columns['document_col']
    sum_col = force_columns['summary_col']

    print(f"Using columns: {doc_col} -> document, {sum_col} -> summary")
    reader_options = {
        'standard': {},
        'quote_minimal': {'quoting': csv.QUOTE_MINIMAL},
        'quote_none': {'quoting': csv.QUOTE_NONE, 'on_bad_lines': 'skip', 'engine': 'python'}
    }
    read_kwargs = reader_options[loader_name]

    cleaned_parts = []
    for part in pd.read_csv(file_path, chunksize=chunksize, **read_kwargs):
        pair = part[[doc_col, sum_col]].copy()
        pair.columns = ['document', 'summary']
        pair = pair.dropna(subset=['document', 'summary'])
        for field in ('document', 'summary'):
            pair[field] = pair[field].astype(str).str.strip()
        # Drop rows where either side is empty once stripped.
        has_text = (pair['document'] != '') & (pair['summary'] != '')
        cleaned_parts.append(pair[has_text])

    if cleaned_parts:
        df = pd.concat(cleaned_parts, ignore_index=True)
    else:
        df = pd.DataFrame(columns=['document', 'summary'])
    print(f"Final cleaned dataset: {df.shape}")
    return df

def load_and_clean_data_original(file_path, chunksize=10000):
    """Load a CSV by position: first column is the document, second the summary.

    The file is read with the Python engine, quoting disabled and bad lines
    skipped, trying several encodings in turn; the first encoding that loads
    without an exception is used.

    Args:
        file_path: Path of the CSV file.
        chunksize: Number of rows read per chunk.

    Returns:
        A DataFrame with columns 'document' and 'summary', without missing
        values and without rows that are empty after stripping whitespace.

    Raises:
        ValueError: If reading failed for every encoding.
    """
    print(f"Using original loading method for: {file_path}")

    candidate_encodings = ('utf-8', 'latin-1', 'cp1252', 'utf-16')
    for encoding in candidate_encodings:
        try:
            reader = pd.read_csv(
                file_path,
                encoding=encoding,
                engine="python",
                quoting=csv.QUOTE_NONE,
                on_bad_lines="skip",
                sep=',',
                chunksize=chunksize
            )
            kept = []
            for part in reader:
                # Chunks with a single column cannot provide a pair.
                if len(part.columns) < 2:
                    continue
                pair = part.iloc[:, :2].copy()
                pair.columns = ['document', 'summary']
                pair = pair.dropna(subset=['document', 'summary'])
                for field in ('document', 'summary'):
                    pair[field] = pair[field].astype(str).str.strip()
                non_empty = (pair['document'] != '') & (pair['summary'] != '')
                kept.append(pair[non_empty])

            if kept:
                df = pd.concat(kept, ignore_index=True)
            else:
                df = pd.DataFrame(columns=['document', 'summary'])
            print(f"Successfully loaded with encoding: {encoding}")
            print(f"Final cleaned dataset: {df.shape}")
            return df
        except Exception as e:
            print(f"Failed with encoding {encoding}: {e}")

    raise ValueError("Could not read CSV with any encoding.")

# 4. EDA
def perform_comprehensive_eda(df, save_plots=True, sample_size=5000):
    """Print and plot exploratory statistics for a document/summary dataset.

    Length statistics, histograms, vocabulary counts and word clouds are
    computed on a random sample of at most ``sample_size`` rows. The input
    frame itself is never modified.

    Args:
        df: DataFrame with 'document' and 'summary' text columns.
        save_plots: When True, figures are also written as PNG files into
            the 'eda_plots' directory (created if necessary).
        sample_size: Maximum number of rows used for the detailed analysis.

    Returns:
        A copy of the analysed sample with the added columns 'doc_length',
        'doc_word_count', 'sum_length', 'sum_word_count' and
        'compression_ratio'.
    """
    print("\n" + "="*60)
    print("EFFICIENT EXPLORATORY DATA ANALYSIS")
    print("="*60)

    if save_plots:
        os.makedirs("eda_plots", exist_ok=True)

    print(f"Dataset shape: {df.shape}")
    # Always work on a private copy: without it, a frame small enough to be
    # analysed in full would get the statistic columns added in place.
    if len(df) > sample_size:
        sample_df = df.sample(sample_size, random_state=42).copy()
    else:
        sample_df = df.copy()
    print(f"Using sample of {len(sample_df)} rows for detailed analysis")

    # Basic info
    print("\n1. DATASET OVERVIEW")
    print("-" * 30)
    print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
    print("\nFirst few samples:")
    for i in range(min(3, len(df))):
        print(f"\nSample {i+1}:")
        print(f"Document (first 150 chars): {str(df.iloc[i]['document'])[:150]}...")
        print(f"Summary (first 100 chars): {str(df.iloc[i]['summary'])[:100]}...")

    # Missing values
    print("\n2. MISSING VALUES ANALYSIS")
    print("-" * 30)
    print(df.isnull().sum())

    # Text length analysis
    print("\n3. TEXT LENGTH ANALYSIS")
    print("-" * 30)

    sample_df['doc_length'] = sample_df['document'].str.len()
    sample_df['doc_word_count'] = sample_df['document'].str.split().str.len()
    sample_df['sum_length'] = sample_df['summary'].str.len()
    sample_df['sum_word_count'] = sample_df['summary'].str.split().str.len()
    sample_df['compression_ratio'] = sample_df['sum_length'] / sample_df['doc_length']

    length_stats = pd.DataFrame({
        'Document_chars': sample_df['doc_length'].describe(),
        'Summary_chars': sample_df['sum_length'].describe(),
        'Document_words': sample_df['doc_word_count'].describe(),
        'Summary_words': sample_df['sum_word_count'].describe(),
        'Compression_ratio': sample_df['compression_ratio'].describe()
    })

    print("Length Statistics:")
    print(length_stats.round(2))

    # Plots: five histograms plus one scatter plot on a 2x3 grid
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Text Length Distributions', fontsize=16)

    # (grid position, column, colour, title, x-axis label)
    histogram_specs = [
        ((0, 0), 'doc_length', 'skyblue', 'Document Character Length', 'Characters'),
        ((0, 1), 'sum_length', 'lightcoral', 'Summary Character Length', 'Characters'),
        ((0, 2), 'compression_ratio', 'lightgreen', 'Compression Ratio', 'Summary Length / Document Length'),
        ((1, 0), 'doc_word_count', 'orange', 'Document Word Count', 'Words'),
        ((1, 1), 'sum_word_count', 'purple', 'Summary Word Count', 'Words'),
    ]
    for position, column, color, title, xlabel in histogram_specs:
        ax = axes[position]
        ax.hist(sample_df[column], bins=50, alpha=0.7, color=color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    axes[1, 2].scatter(sample_df['doc_length'], sample_df['sum_length'], alpha=0.5)
    axes[1, 2].set_title('Document vs Summary Length')
    axes[1, 2].set_xlabel('Document Length')
    axes[1, 2].set_ylabel('Summary Length')

    plt.tight_layout()
    if save_plots:
        plt.savefig('eda_plots/length_distributions.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Vocabulary analysis
    print("\n4. VOCABULARY ANALYSIS")
    print("-" * 30)

    # Ask for the corpus directly and fall back only if it is not installed.
    # The previous check tested for the text 'stopwords' inside the path of
    # the corpora directory, which does not tell whether the corpus exists.
    try:
        stop_words = set(stopwords.words('english'))
    except LookupError:
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'}

    def content_words(texts):
        """Return the lower-cased alphabetic, non-stopword tokens of all texts."""
        kept = []
        for text in texts:
            tokens = word_tokenize(str(text).lower())
            kept.extend(token for token in tokens if token.isalpha() and token not in stop_words)
        return kept

    doc_words = content_words(sample_df['document'])
    sum_words = content_words(sample_df['summary'])

    doc_vocab = Counter(doc_words)
    sum_vocab = Counter(sum_words)

    print(f"Document vocabulary size: {len(doc_vocab)}")
    print(f"Summary vocabulary size: {len(sum_vocab)}")
    print(f"Vocabulary overlap: {len(set(doc_vocab) & set(sum_vocab))}")
    print("\nMost common words in documents:")
    for word, count in doc_vocab.most_common(10):
        print(f"  {word}: {count}")
    print("\nMost common words in summaries:")
    for word, count in sum_vocab.most_common(10):
        print(f"  {word}: {count}")

    # Word clouds need at least one word on each side.
    if doc_words and sum_words:
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        cloud_specs = [
            (axes[0], doc_words, 'Document Word Cloud'),
            (axes[1], sum_words, 'Summary Word Cloud'),
        ]
        for ax, words, title in cloud_specs:
            cloud = WordCloud(width=800, height=400, background_color='white').generate(' '.join(words))
            ax.imshow(cloud, interpolation='bilinear')
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        if save_plots:
            plt.savefig('eda_plots/wordclouds.png', dpi=300, bbox_inches='tight')
        plt.show()

    return sample_df

# 5. Preprocessing
def advanced_preprocessing(df, min_samples=1000):
    """Clean and filter document/summary pairs for training.

    Texts are stripped of two dataset marker tokens and whitespace is
    normalised. Rows are then filtered by word/character length and by
    summary/document compression ratio, and duplicate documents are removed.
    If fewer than ``min_samples`` rows survive, a relaxed set of filters is
    applied instead (again followed by duplicate removal). The input frame
    is not modified.

    Args:
        df: DataFrame with 'document' and 'summary' columns.
        min_samples: Minimum number of rows required after preprocessing.

    Returns:
        A new DataFrame with only the cleaned 'document' and 'summary'
        columns and a fresh integer index.

    Raises:
        ValueError: If fewer than ``min_samples`` rows remain.
    """
    print("\n" + "="*50)
    print("ADVANCED PREPROCESSING")
    print("="*50)

    initial_count = len(df)
    print(f"Initial dataset size: {initial_count}")

    def clean_text(text):
        text = str(text)
        text = text.replace('story_separator_special_tag', '').replace('replace_table_token_7_th', '')
        # \s+ also matches newlines, so one pass collapses all whitespace.
        return re.sub(r'\s+', ' ', text).strip()

    def percent_of_initial(count):
        # Avoid dividing by zero when the input frame is empty.
        return count / initial_count * 100 if initial_count else 0.0

    # Work on a copy so the caller's frame keeps its original text and columns.
    df = df[['document', 'summary']].copy()
    df['document'] = df['document'].apply(clean_text)
    df['summary'] = df['summary'].apply(clean_text)

    # Calculate lengths for filtering
    df['doc_word_count'] = df['document'].str.split().str.len()
    df['sum_word_count'] = df['summary'].str.split().str.len()
    df['doc_length'] = df['document'].str.len()
    df['sum_length'] = df['summary'].str.len()
    df['compression_ratio'] = df['sum_length'] / df['doc_length']

    # Relaxed filters based on EDA
    len_filter = (
        (df['doc_word_count'] >= 1000) &  # Min 1000 words (from EDA: min ~1807)
        (df['doc_word_count'] <= 2500) &  # Max 2500 words (from EDA: max ~2064)
        (df['sum_word_count'] >= 20) &    # Min 20 words (from EDA: min ~24)
        (df['sum_word_count'] <= 600) &   # Max 600 words (from EDA: max ~484)
        (df['doc_length'] >= 5000) &      # Min 5000 chars (from EDA: min ~9357)
        (df['sum_length'] >= 150)         # Min 150 chars (from EDA: min ~178)
    )

    df_filtered = df[len_filter].copy()
    print(f"After length filtering: {len(df_filtered)} ({percent_of_initial(len(df_filtered)):.1f}%)")

    # Compression ratio filter
    compression_filter = (
        (df_filtered['compression_ratio'] >= 0.05) &  # Min 5% (from EDA: min ~0.02)
        (df_filtered['compression_ratio'] <= 0.30)    # Max 30% (from EDA: max ~0.25)
    )

    df_filtered = df_filtered[compression_filter].copy()
    print(f"After compression ratio filtering: {len(df_filtered)} ({percent_of_initial(len(df_filtered)):.1f}%)")

    # Remove duplicates
    initial_filtered = len(df_filtered)
    df_filtered = df_filtered.drop_duplicates(subset=['document'], keep='first').reset_index(drop=True)
    print(f"After duplicate removal: {len(df_filtered)} (removed {initial_filtered - len(df_filtered)} duplicates)")

    # Fallback with more relaxed filters if needed
    if len(df_filtered) < min_samples:
        print("WARNING: Insufficient data after filtering. Using relaxed filters...")
        relaxed_filter = (
            (df['doc_word_count'] >= 500) &
            (df['sum_word_count'] >= 10) &
            (df['doc_length'] >= 3000) &
            (df['sum_length'] >= 100) &
            (df['compression_ratio'] >= 0.01) &
            (df['compression_ratio'] <= 0.50)
        )
        # The fallback starts again from the full frame, so duplicates have
        # to be removed here as well.
        df_filtered = df[relaxed_filter].drop_duplicates(subset=['document'], keep='first')
        print(f"Using relaxed filters: {len(df_filtered)} samples")

    # Clean up temporary columns
    df_filtered = df_filtered[['document', 'summary']].reset_index(drop=True)
    print(f"Final preprocessed dataset size: {len(df_filtered)}")

    if len(df_filtered) < min_samples:
        raise ValueError(f"Insufficient data after preprocessing: {len(df_filtered)} samples")

    return df_filtered

# 6. Data Splitting
def create_balanced_splits(df, train_size=0.7, val_size=0.3, random_state=42, min_samples=1000):
    """Split the data into train and validation sets, stratified by length.

    Documents are binned into up to five quantile bins of character length
    and the split is stratified on those bins. If usable bins cannot be
    built (for example because all documents have the same length), a plain
    random split is used instead. The input frame is not modified.

    Args:
        df: DataFrame with 'document' and 'summary' columns.
        train_size: Reported train fraction. The split itself is driven by
            ``val_size``; the train set receives all remaining rows.
        val_size: Fraction of rows placed in the validation set.
        random_state: Seed passed to ``train_test_split``.
        min_samples: Minimum number of rows required to split.

    Returns:
        ``(train_df, val_df)``, each with only 'document' and 'summary'
        columns and a fresh integer index.

    Raises:
        ValueError: If ``df`` has fewer than ``min_samples`` rows.
    """
    print(f"\nCreating balanced splits: {train_size:.1%} train, {val_size:.1%} validation")
    if abs((train_size + val_size) - 1.0) > 1e-9:
        print(f"NOTE: train_size + val_size != 1; the split uses val_size, so the train share is {1 - val_size:.1%}")

    if len(df) < min_samples:
        raise ValueError(f"Need at least {min_samples} samples, got {len(df)}")

    # Stratify by document length. The bins are kept in a separate Series so
    # that no helper column is added to the caller's frame.
    stratify_labels = None
    try:
        length_bins = pd.qcut(df['document'].str.len(), q=5, duplicates='drop')
        bin_sizes = length_bins.value_counts()
        bin_sizes = bin_sizes[bin_sizes > 0]
        # Stratification needs every row binned and at least two rows per bin.
        if not length_bins.isna().any() and len(bin_sizes) > 0 and bin_sizes.min() >= 2:
            stratify_labels = length_bins
    except ValueError:
        stratify_labels = None
    if stratify_labels is None:
        print("WARNING: Could not build length bins for stratification; using a plain random split")

    train_df, val_df = train_test_split(
        df,
        test_size=val_size,
        random_state=random_state,
        stratify=stratify_labels
    )

    train_df = train_df[['document', 'summary']].reset_index(drop=True)
    val_df = val_df[['document', 'summary']].reset_index(drop=True)

    print(f"Train set size: {len(train_df)}")
    print(f"Validation set size: {len(val_df)}")

    print("\nSplit Statistics:")
    print(f"Train - Avg doc length: {train_df['document'].str.len().mean():.1f}")
    print(f"Train - Avg sum length: {train_df['summary'].str.len().mean():.1f}")
    print(f"Val - Avg doc length: {val_df['document'].str.len().mean():.1f}")
    print(f"Val - Avg sum length: {val_df['summary'].str.len().mean():.1f}")

    return train_df, val_df

# 7. Tokenization
def preprocess_data_enhanced(train_df, val_df, tokenizer, max_input_length=512, max_target_length=128):
    """Tokenize the train and validation frames for seq2seq training.

    Each document is prefixed with "summarize: " and truncated to
    ``max_input_length`` tokens; each summary is tokenized as the target and
    truncated to ``max_target_length`` tokens. No padding is applied here.

    Args:
        train_df: Training frame with 'document' and 'summary' columns.
        val_df: Validation frame with the same columns.
        tokenizer: Tokenizer used for both inputs and targets.
        max_input_length: Token limit for the inputs.
        max_target_length: Token limit for the targets.

    Returns:
        ``(train_tokenized, val_tokenized)`` datasets in which the original
        columns are replaced by the tokenizer outputs plus 'labels'.
    """
    print(f"\nTokenizing data (max_input: {max_input_length}, max_target: {max_target_length})")

    def encode_batch(batch):
        prompts = ["summarize: " + document for document in batch['document']]

        encoded = tokenizer(
            prompts,
            max_length=max_input_length,
            truncation=True,
            padding=False
        )
        target_encoding = tokenizer(
            text_target=batch['summary'],
            max_length=max_target_length,
            truncation=True,
            padding=False
        )

        encoded["labels"] = target_encoding["input_ids"]
        return encoded

    def tokenize_dataset(dataset):
        # Drop the raw text columns so only model inputs remain.
        return dataset.map(
            encode_batch,
            batched=True,
            remove_columns=dataset.column_names
        )

    raw_train = Dataset.from_pandas(train_df)
    raw_val = Dataset.from_pandas(val_df)

    train_tokenized = tokenize_dataset(raw_train)
    val_tokenized = tokenize_dataset(raw_val)

    print(f"Tokenized train size: {len(train_tokenized)}")
    print(f"Tokenized validation size: {len(val_tokenized)}")

    return train_tokenized, val_tokenized

# 8. Trainer Setup
def setup_simple_trainer(train_dataset, val_dataset, model_name="google/mt5-small"):
    """Build a Trainer, tokenizer and model for seq2seq fine-tuning.

    Args:
        train_dataset: Tokenized training dataset.
        val_dataset: Tokenized validation dataset (evaluation is not run
            during training, but the dataset is attached to the Trainer).
        model_name: Name or path of the pretrained checkpoint to load.

    Returns:
        ``(trainer, tokenizer, model)``.
    """
    print(f"Setting up model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100
    )

    # The ROUGE metric is loaded on first use and then reused, instead of
    # being reloaded on every evaluation.
    metric_cache = {}

    def preprocess_logits_for_metrics(logits, labels):
        # Reduce logits to token ids per batch so evaluation does not have to
        # accumulate vocabulary-sized tensors for the whole dataset.
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_pred):
        if not rouge_available:
            return {}

        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        predictions = np.asarray(predictions)
        # A plain Trainer returns logits rather than generated ids; turn
        # them into token ids before decoding.
        if predictions.ndim == 3:
            predictions = predictions.argmax(axis=-1)
        # -100 is used as padding when batches are gathered; it is not a
        # valid token id and must be replaced before decoding.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]

        try:
            if "rouge" not in metric_cache:
                metric_cache["rouge"] = evaluate.load("rouge")
            result = metric_cache["rouge"].compute(
                predictions=decoded_preds,
                references=decoded_labels,
                use_stemmer=True
            )
            return {k: round(v, 4) for k, v in result.items()}
        except Exception as e:
            print(f"Error computing ROUGE: {e}")
            return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}

    # NOTE(review): T5-family checkpoints are reported to overflow under fp16
    # autocast, so bf16 is used where the GPU supports it and full precision
    # otherwise — confirm against the Transformers mixed-precision guidance.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    training_args = TrainingArguments(
        output_dir="./enhanced-mt5-financial",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=3e-5,
        weight_decay=0.01,
        num_train_epochs=3,
        warmup_steps=100,
        eval_strategy="no",  # Disable evaluation during training to avoid memory issues
        save_strategy="steps",
        save_steps=400,
        save_total_limit=2,
        load_best_model_at_end=False,  # Disable since no evaluation
        logging_strategy="steps",
        logging_steps=50,
        report_to="none",
        fp16=False,
        bf16=use_bf16,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,  # Use processing_class instead of tokenizer
        data_collator=data_collator,
        compute_metrics=compute_metrics if rouge_available else None,  # Only if available
        preprocess_logits_for_metrics=preprocess_logits_for_metrics if rouge_available else None,
    )

    return trainer, tokenizer, model

# 9. Inference
def generate_summary_enhanced(text, model, tokenizer, max_length=128, device="cpu", max_input_length=512):
    """Generate a summary for one text with beam search.

    Args:
        text: The text to summarise; it is prefixed with "summarize: ".
        model: Seq2seq model providing ``generate``.
        tokenizer: Tokenizer matching the model.
        max_length: Maximum length of the generated summary in tokens.
        device: Device the tokenized inputs are moved to; it should be the
            device the model is on.
        max_input_length: Inputs longer than this many tokens are truncated.

    Returns:
        The decoded summary string without special tokens.
    """
    input_text = f"summarize: {text}"

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
        padding=True
    ).to(device)

    # A model that has just been trained is still in training mode, which
    # keeps dropout active. Switch to eval mode for generation and restore
    # the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            summary_ids = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                # Keep the minimum from exceeding a small max_length.
                min_length=min(20, max_length),
                num_beams=4,
                length_penalty=1.0,
                early_stopping=True,
                no_repeat_ngram_size=3,
            )
    finally:
        if was_training:
            model.train()

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# 10. Testing
def test_with_your_data(model, tokenizer, df, n_samples=5, device="cpu"):
    print("\n" + "="*50)
    print("TESTING MODEL ON YOUR DATA")
    print("="*50)

    samples = df.sample(n_samples, random_state=42)
    for i, row in samples.iterrows():
        print(f"\nSample {i+1}:")
        print(f"Input (first 100 chars): {row['document'][:100]}...")
        print(f"Reference (first 100 chars): {row['summary'][:100]}...")
        print(f"Generated: {generate_summary_enhanced(row['document'], model, tokenizer, device=device)}")

def quick_test_enhanced(model, tokenizer, device="cpu"):
    """Print generated summaries for two built-in example passages.

    Args:
        model: Seq2seq model used for generation.
        tokenizer: Tokenizer matching the model.
        device: Device passed on to the generation helper.
    """
    divider = "="*50
    print("\n" + divider)
    print("QUICK MODEL TEST")
    print(divider)

    passages = (
        """The company reported strong quarterly results with revenue of $2.8 billion, representing a 15% increase year-over-year.
        Net income reached $450 million, up from $380 million in the same quarter last year. The growth was primarily driven by
        increased demand in the cloud services division, which saw a 28% revenue increase.""",
        """The Federal Reserve announced its decision to maintain the current federal funds rate at 5.25-5.50%, marking the
        fourth consecutive meeting without a rate change. The central bank cited recent economic data showing moderate inflation
        trends and stable employment levels as key factors in the decision.""",
    )

    for number, passage in enumerate(passages, start=1):
        print(f"\nTest Example {number}:")
        print(f"Input: {passage[:100]}...")
        generated = generate_summary_enhanced(passage, model, tokenizer, device=device)
        print(f"Generated: {generated}")

# 11. Main Pipeline
def main_pipeline_improved(file_path="merged.csv", save_plots=True):
    """Run the whole pipeline: load, explore, preprocess, train, save, test.

    Any exception raised along the way is caught, reported with its
    traceback, and turned into a tuple of five None values.

    Args:
        file_path: Path of the CSV file with document/summary data.
        save_plots: Passed to the EDA step to control saving of figures.

    Returns:
        ``(trainer, tokenizer, model, train_df, val_df)`` on success,
        otherwise ``(None, None, None, None, None)``.
    """
    wide_rule = "="*80
    print(wide_rule)
    print("IMPROVED FINANCIAL SUMMARIZATION PIPELINE")
    print(wide_rule)

    try:
        # Step 1: Load data
        print("\nStep 1: Loading data efficiently...")
        column_mapping = suggest_corrected_loading(file_path)
        if column_mapping is None:
            raise ValueError("Failed to identify suitable document-summary columns")

        raw_df = load_and_clean_data_corrected(file_path, column_mapping)
        loaded_rows = raw_df.shape[0]
        if loaded_rows < 1000:
            raise ValueError(f"Insufficient data loaded: {loaded_rows} samples")

        # Step 2: EDA (run for its printed report and plots only)
        print("\nStep 2: Performing efficient EDA...")
        perform_comprehensive_eda(raw_df, save_plots)

        # Step 3: Preprocessing
        print("\nStep 3: Preprocessing for training...")
        clean_df = advanced_preprocessing(raw_df)

        # Step 4: Splitting
        print("\nStep 4: Creating balanced splits...")
        train_df, val_df = create_balanced_splits(clean_df)

        # Step 5: Tokenization
        print("\nStep 5: Tokenizing data...")
        tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
        train_tokenized, val_tokenized = preprocess_data_enhanced(train_df, val_df, tokenizer)

        # Step 6: Training
        print("\nStep 6: Setting up and training model...")
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        print(f"Using device: {device}")
        trainer, tokenizer, model = setup_simple_trainer(train_tokenized, val_tokenized)
        trainer.train()

        # Save model and tokenizer side by side
        model_save_path = "./enhanced-mt5-financial"
        trainer.save_model(model_save_path)
        tokenizer.save_pretrained(model_save_path)
        print(f"Model saved to: {model_save_path}")

        # Step 7: Testing
        print("\nStep 7: Testing model...")
        quick_test_enhanced(model, tokenizer, device)
        test_with_your_data(model, tokenizer, val_df, n_samples=5, device=device)

        narrow_rule = "="*60
        print("\n" + narrow_rule)
        print("PIPELINE COMPLETED SUCCESSFULLY!")
        print(narrow_rule)
        return trainer, tokenizer, model, train_df, val_df

    except Exception as e:
        import traceback
        print(f"Pipeline failed: {e}")
        traceback.print_exc()
        return (None,) * 5

# 12. Compatibility Tests
def run_compatibility_tests():
    """Print whether each required library can be imported and whether CUDA is usable.

    Nothing is raised for a missing library; each one is reported with the
    matching ``pip install`` hint.

    Returns:
        None. Results are printed.
    """
    import importlib

    print("Running compatibility tests...")

    # (display name, importable module, pip distribution name). The pip name
    # is listed explicitly because it cannot be derived from the display
    # name: PyTorch is installed as "torch".
    requirements = [
        ("PyTorch", "torch", "torch"),
        ("Transformers", "transformers", "transformers"),
        ("Datasets", "datasets", "datasets"),
        ("Pandas", "pandas", "pandas"),
        ("NumPy", "numpy", "numpy"),
        ("Matplotlib", "matplotlib", "matplotlib"),
        ("NLTK", "nltk", "nltk"),
        ("Scikit-learn", "sklearn", "scikit-learn"),
    ]

    for display_name, module_name, pip_name in requirements:
        try:
            importlib.import_module(module_name)
        except ImportError:
            print(f"✗ {display_name} - Missing (install with: pip install {pip_name})")
        else:
            print(f"✓ {display_name} - OK")

    try:
        import torch
        if torch.cuda.is_available():
            print(f"✓ CUDA available - {torch.cuda.get_device_name(0)}")
        else:
            print("ℹ CUDA not available - will use CPU")
    except Exception:
        # Covers a missing torch as well as driver/runtime errors.
        print("✗ Could not check CUDA status")

    print("Compatibility test completed.")

# Run the pipeline
# Script entry point: report which libraries are importable, then run the full
# pipeline on "merged.csv". The pipeline's five results are kept in
# module-level names; all of them are None when the pipeline fails.
if __name__ == "__main__":
    run_compatibility_tests()
    trainer, tokenizer, model, train_df, val_df = main_pipeline_improved("merged.csv")