# In [None]:
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Global plotting defaults shared by every figure produced below
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 10,
})

# ============================================================================
# CONFIGURATION
# ============================================================================

# Canonical names of the five regional dialects kept from both corpora.
DIALECTS = ['Chittagong', 'Sylhet', 'Barisal', 'Noakhali', 'Mymensingh']

# Canonical dialect name -> accepted spelling(s) found in the source files.
# Every dialect maps to itself, except Barisal, which also appears as "Barishal".
DIALECT_MAPPING = {name: name for name in DIALECTS}
DIALECT_MAPPING['Barisal'] = ['Barisal', 'Barishal']

# Kaggle input locations
BANGLADIAL_PATH = '/kaggle/input/bangladiel/BanglaDial A Merged and Imbalanced text Dataset for Bengali Regional dialect analysis. (1).csv'
VASHANTOR_BASE_PATH = '/kaggle/input/vashantor010'

# Directory (with trailing slash) that every figure and CSV is written into
OUTPUT_DIR = '/kaggle/working/'

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

# Patterns used by clean_text, compiled once at import time rather than on
# every call (clean_text is applied to every row of the corpus).
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags
    "\U00002702-\U000027B0"  # dingbats
    # NOTE(review): very broad range (also covers CJK, Hangul, U+FEFF, ...).
    # Bengali (U+0980-U+09FF) lies below it, so Bangla text is unaffected.
    "\U000024C2-\U0001F251"
    "\U0001F900-\U0001F9FF"  # supplemental symbols
    "\U0001FA70-\U0001FAFF"
    "]+",
    flags=re.UNICODE,
)
_HASHTAG_PATTERN = re.compile(r'#\w+')
_MENTION_PATTERN = re.compile(r'@\w+')
_URL_PATTERN = re.compile(r'http\S+|www\.\S+')
_WHITESPACE_PATTERN = re.compile(r'\s+')
# Punctuation immediately followed by a non-space character. The lookahead
# does not consume that character, so runs like ";:x" are all spaced.
_PUNCT_BEFORE_TOKEN_PATTERN = re.compile(r'([।!?,.;:])(?=\S)')
_SPACE_BEFORE_PUNCT_PATTERN = re.compile(r'\s+([।!?,.])')


def _space_after_punct(match):
    """Return the matched punctuation plus a space, except for '.' or ','
    between two digits (decimal point / thousands separator), which is kept as-is."""
    punct = match.group(1)
    pos = match.start(1)
    source = match.string
    # The (?=\S) lookahead guarantees a character exists at pos + 1.
    if punct in '.,' and pos > 0 and source[pos - 1].isdigit() and source[pos + 1].isdigit():
        return punct
    return punct + ' '


def clean_text(text):
    """
    Clean and normalize a single text value.

    Steps, in order:
    - Emoji removal
    - Hashtag/mention removal
    - URL removal
    - Whitespace normalization (runs of whitespace -> one space)
    - Punctuation spacing: a space is added after punctuation that is directly
      followed by a character. Numbers such as "3.14", "২.৫" or "10,000" are
      left intact.
    - Removal of spaces before '।', '!', '?', ',', '.'
    - Lowercasing and stripping

    Args:
        text: Value to clean. Anything that is not a str (NaN, None, pd.NA,
            numbers, ...) yields "".

    Returns:
        The cleaned string (possibly "").
    """
    # isinstance first: pd.isna() on list-like input returns an array, and
    # using that array in a boolean context raises.
    if not isinstance(text, str):
        return ""

    text = _EMOJI_PATTERN.sub('', text)
    text = _HASHTAG_PATTERN.sub('', text)
    text = _MENTION_PATTERN.sub('', text)
    text = _URL_PATTERN.sub('', text)
    text = _WHITESPACE_PATTERN.sub(' ', text)
    text = _PUNCT_BEFORE_TOKEN_PATTERN.sub(_space_after_punct, text)
    text = _SPACE_BEFORE_PUNCT_PATTERN.sub(r'\1', text)

    return text.lower().strip()


# Any single character from the Unicode Bengali block (U+0980-U+09FF).
_BANGLA_CHAR_PATTERN = re.compile(r'[\u0980-\u09FF]')


def is_english_only(text):
    """
    Return True if text contains no Bengali-script character.

    Despite the name, this does not verify that the text is English. Any text
    without a character in U+0980-U+09FF counts as "English-only". That includes
    romanised "Banglish", pure digits or punctuation, and other scripts.
    Falsy input (None, "") and whitespace-only strings also return True.

    Args:
        text: String to test (or None).

    Returns:
        bool: True when no Bengali character is present.
    """
    if not text or not text.strip():
        return True
    # search() stops at the first hit instead of collecting every match.
    return _BANGLA_CHAR_PATTERN.search(text) is None


# Zero-width space, ZWNJ, ZWJ (U+200B-U+200D) and BOM / ZWNBSP (U+FEFF).
_ZERO_WIDTH_PATTERN = re.compile(r'[\u200B-\u200D\uFEFF]')


def normalize_bangla(text):
    """
    Remove zero-width characters (U+200B-U+200D, U+FEFF) from text.

    This is the only normalisation performed; no other spelling variants are
    changed. Removal can leave adjacent spaces (e.g. "a <ZWSP> b" becomes
    "a  b"), so run whitespace normalisation afterwards if that matters.

    NOTE(review): U+200C (ZWNJ) and U+200D (ZWJ) can be orthographically
    meaningful in Bengali (e.g. ZWJ selects the ra + ya-phala rendering).
    Stripping them trades that distinction for fewer spelling variants;
    confirm this is intended.

    Args:
        text: String to normalise. Non-string input (NaN, None, ...) returns ""
            (consistent with clean_text) instead of raising TypeError.

    Returns:
        The text without zero-width characters.
    """
    if not isinstance(text, str):
        return ""
    return _ZERO_WIDTH_PATTERN.sub('', text)


def plot_dialect_distribution(df, title, filename):
    """
    Draw a bar chart of sample counts per dialect and save it.

    Bars are ordered by descending count, and each bar is labelled with its
    count. The figure is saved to OUTPUT_DIR + filename at 300 dpi, shown, and
    the path is printed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    dialect_counts = df['dialect'].value_counts()

    palette = sns.color_palette("husl", len(dialect_counts))
    bar_container = ax.bar(dialect_counts.index, dialect_counts.values,
                           color=palette, edgecolor='black', linewidth=1.2)

    # Write the integer height just above the centre of each bar
    for rect in bar_container:
        y_top = rect.get_height()
        x_mid = rect.get_x() + rect.get_width() / 2.
        ax.text(x_mid, y_top, f'{int(y_top)}',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    axis_font = {'fontsize': 12, 'fontweight': 'bold'}
    ax.set_xlabel('Dialect', **axis_font)
    ax.set_ylabel('Number of Samples', **axis_font)
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    out_path = OUTPUT_DIR + filename
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: {out_path}")


def plot_text_length_distribution(df, title, filename):
    """
    Plot text lengths: a histogram with a mean line, plus per-dialect box plots.

    The figure is saved to OUTPUT_DIR + filename at 300 dpi, shown, and the
    path is printed.

    Args:
        df: DataFrame with a string 'text' column and a 'dialect' column.
            It is NOT modified.
        title: Figure-level title.
        filename: File name appended to OUTPUT_DIR.
    """
    # Work on a copy. The previous version assigned df['text_length'] in place,
    # which leaked an extra column into the caller's DataFrame and from there
    # into the saved train/val/test CSVs.
    df = df.assign(text_length=df['text'].str.len())
    mean_length = df['text_length'].mean()

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Histogram
    axes[0].hist(df['text_length'], bins=50, color='skyblue', edgecolor='black', alpha=0.7)
    axes[0].set_xlabel('Text Length (characters)', fontsize=11, fontweight='bold')
    axes[0].set_ylabel('Frequency', fontsize=11, fontweight='bold')
    axes[0].set_title('Text Length Distribution', fontsize=12, fontweight='bold')
    axes[0].axvline(mean_length, color='red', linestyle='--',
                    linewidth=2, label=f'Mean: {mean_length:.0f}')
    axes[0].legend()

    # Box plot by dialect, most frequent dialect first
    dialect_order = df['dialect'].value_counts().index.tolist()
    sns.boxplot(data=df, y='dialect', x='text_length', order=dialect_order,
                palette='Set2', ax=axes[1])
    axes[1].set_xlabel('Text Length (characters)', fontsize=11, fontweight='bold')
    axes[1].set_ylabel('Dialect', fontsize=11, fontweight='bold')
    axes[1].set_title('Text Length by Dialect', fontsize=12, fontweight='bold')

    plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR + filename, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: {OUTPUT_DIR}{filename}")


def plot_split_comparison(train_df, val_df, test_df, filename):
    """
    Draw three side-by-side dialect-count bar charts (train / validation / test).

    Each panel title includes that split's total size. The figure is saved to
    OUTPUT_DIR + filename at 300 dpi, shown, and the path is printed.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    panel_titles = ('Training Set', 'Validation Set', 'Test Set')

    for ax, split_df, panel_title in zip(axes, (train_df, val_df, test_df), panel_titles):
        per_dialect = split_df['dialect'].value_counts()
        palette = sns.color_palette("husl", len(per_dialect))
        rects = ax.bar(per_dialect.index, per_dialect.values, color=palette,
                       edgecolor='black', linewidth=1.2)

        # Count label on top of each bar
        for rect in rects:
            top = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2., top,
                    f'{int(top)}',
                    ha='center', va='bottom', fontsize=9, fontweight='bold')

        ax.set_xlabel('Dialect', fontsize=10, fontweight='bold')
        ax.set_ylabel('Number of Samples', fontsize=10, fontweight='bold')
        ax.set_title(f'{panel_title}\n(Total: {len(split_df)})', fontsize=11, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

    plt.suptitle('Train/Validation/Test Split Distribution', fontsize=14, fontweight='bold')
    plt.tight_layout()
    saved_to = OUTPUT_DIR + filename
    plt.savefig(saved_to, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: {saved_to}")


def plot_data_source_comparison(bangladial_df, vashantor_df, filename):
    """
    Compare the dialect counts of the two source corpora side by side.

    The left panel shows BanglaDial and the right panel Vashantor; each title
    includes the corpus size. The figure is saved to OUTPUT_DIR + filename at
    300 dpi, shown, and the path is printed.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    sources = (('BanglaDial', bangladial_df), ('Vashantor', vashantor_df))

    for ax, (source_name, source_df) in zip(axes, sources):
        per_dialect = source_df['dialect'].value_counts()
        rects = ax.bar(per_dialect.index, per_dialect.values,
                       color=sns.color_palette("husl", len(per_dialect)),
                       edgecolor='black', linewidth=1.2)

        # Count label on top of each bar
        for rect in rects:
            top = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2., top,
                    f'{int(top)}',
                    ha='center', va='bottom', fontsize=10, fontweight='bold')

        ax.set_xlabel('Dialect', fontsize=11, fontweight='bold')
        ax.set_ylabel('Number of Samples', fontsize=11, fontweight='bold')
        ax.set_title(f'{source_name} Dataset\n(Total: {len(source_df)})',
                     fontsize=12, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

    plt.suptitle('Dataset Source Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    saved_to = OUTPUT_DIR + filename
    plt.savefig(saved_to, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved: {saved_to}")


# ============================================================================
# LOAD BANGLADIAL DATASET
# ============================================================================

print("=" * 80)
print("LOADING BANGLADIAL DATASET")
print("=" * 80)

# Read as UTF-8, falling back ONLY on decoding errors. The old bare `except:`
# also swallowed FileNotFoundError, parser errors, KeyboardInterrupt, etc.
# cp1252 is tried before latin-1 because latin-1 decodes any byte sequence and
# never raises, which made the previous cp1252 branch unreachable. Neither
# fallback can represent Bengali script, so warn loudly if one is used.
bangladial_df = None
for encoding in ('utf-8', 'cp1252', 'latin-1'):
    try:
        bangladial_df = pd.read_csv(BANGLADIAL_PATH, encoding=encoding)
    except UnicodeDecodeError:
        continue
    if encoding != 'utf-8':
        print(f"Warning: BanglaDial is not valid UTF-8; decoded as {encoding}. "
              "Bengali text will be garbled and is likely to be filtered out later.")
    break

print(f"Original BanglaDial shape: {bangladial_df.shape}")
print(f"Columns: {bangladial_df.columns.tolist()}")

# Detect text and dialect columns by keyword. As before, the LAST matching
# column wins. The dialect column must now differ from the text column: a name
# like "dialect_text" used to be chosen for both, and the rename below then
# left no 'text' column at all.
text_keywords = ('text', 'sentence', 'content')
dialect_keywords = ('dialect', 'label', 'class', 'region')

text_col = None
for col in bangladial_df.columns:
    if any(k in str(col).lower() for k in text_keywords):
        text_col = col

dialect_col = None
for col in bangladial_df.columns:
    if col != text_col and any(k in str(col).lower() for k in dialect_keywords):
        dialect_col = col

# Fill in only what detection missed, taking the remaining columns in order.
# Previously BOTH were overwritten with columns[0]/[1] even if one was found.
unused_cols = [c for c in bangladial_df.columns if c not in (text_col, dialect_col)]
if text_col is None:
    text_col = unused_cols.pop(0)
if dialect_col is None:
    dialect_col = unused_cols.pop(0)

print(f"\nUsing columns: text='{text_col}', dialect='{dialect_col}'")

# Rename columns for consistency
bangladial_df = bangladial_df.rename(columns={text_col: 'text', dialect_col: 'dialect'})

# Map every accepted spelling in DIALECT_MAPPING (case-insensitively, after
# stripping) to its canonical name. Rows whose dialect is not one of our five
# map to NaN and are dropped. This also folds "Barishal" into "Barisal".
spelling_to_dialect = {}
for canonical, spellings in DIALECT_MAPPING.items():
    if isinstance(spellings, str):
        spellings = [spellings]
    for spelling in spellings:
        spelling_to_dialect[spelling.lower()] = canonical

bangladial_df['dialect'] = (
    bangladial_df['dialect'].astype(str).str.strip().str.lower().map(spelling_to_dialect)
)
bangladial_df = bangladial_df.dropna(subset=['dialect']).copy()

print(f"\nFiltered BanglaDial shape (5 dialects): {bangladial_df.shape}")
print(f"Dialect distribution:\n{bangladial_df['dialect'].value_counts()}")

# ============================================================================
# LOAD VASHANTOR DATASET
# ============================================================================

print("\n" + "=" * 80)
print("LOADING VASHANTOR DATASET")
print("=" * 80)

vashantor_dfs = []

for dialect in tqdm(DIALECTS, desc="Loading Vashantor files"):
    # Barisal files may be spelled either way on disk
    dialect_variants = ['Barisal', 'Barishal'] if dialect == 'Barisal' else [dialect]

    # Loop over splits FIRST, then over spellings. Previously the `break` after
    # a successful load exited the split loop, so once "<Dialect> Train
    # Translation.csv" was read, the Validation and Test files were never loaded.
    for split in ['Train', 'Validation', 'Test']:
        for variant in dialect_variants:
            filename = f"{variant} {split} Translation.csv"
            filepath = f"{VASHANTOR_BASE_PATH}/{filename}"

            try:
                df = pd.read_csv(filepath, encoding='utf-8')
            except FileNotFoundError:
                continue  # try the next spelling
            except Exception as e:
                print(f"Error loading {filename}: {e}")
                continue

            # Keep columns whose name contains bangla_speech / banglish_speech
            # (but not "english").
            # NOTE(review): 'bangla_speech' is a substring match, so it may also
            # pick up a standard-Bangla (non-dialectal) column if one exists.
            # Confirm against the actual Vashantor column names.
            bangla_cols = [
                col for col in df.columns
                if any(term in col.lower() for term in ['bangla_speech', 'banglish_speech'])
                and 'english' not in col.lower()
            ]

            # If no specific columns found, use all except English
            if not bangla_cols:
                bangla_cols = [col for col in df.columns
                               if 'english' not in col.lower()]

            # Each selected column becomes its own block of (text, dialect) rows.
            # `dialect` is already the canonical name (e.g. 'Barisal' for both spellings).
            for col in bangla_cols:
                vashantor_dfs.append(pd.DataFrame({'text': df[col], 'dialect': dialect}))

            print(f"✓ Loaded: {filename} ({len(df)} rows, {len(bangla_cols)} text columns)")
            break  # this split is loaded; don't load it again under another spelling
        else:
            print(f"Warning: could not load Vashantor {split} file for {dialect}")

# Combine all Vashantor data
if vashantor_dfs:
    vashantor_df = pd.concat(vashantor_dfs, ignore_index=True)
    print(f"\nVashantor combined shape: {vashantor_df.shape}")
    print(f"Dialect distribution:\n{vashantor_df['dialect'].value_counts()}")
else:
    print("Warning: No Vashantor data loaded!")
    vashantor_df = pd.DataFrame(columns=['text', 'dialect'])

# ============================================================================
# VISUALIZE ORIGINAL DATASETS
# ============================================================================

print("\n" + "=" * 80, "VISUALIZING ORIGINAL DATASETS", "=" * 80, sep="\n")

# The side-by-side comparison only makes sense when both corpora have rows
if min(len(bangladial_df), len(vashantor_df)) > 0:
    plot_data_source_comparison(
        bangladial_df,
        vashantor_df,
        'viz_01_source_comparison.png',
    )

# ============================================================================
# MERGE DATASETS
# ============================================================================

print("\n" + "=" * 80, "MERGING DATASETS", "=" * 80, sep="\n")

# Stack both corpora, keeping only the shared text/dialect columns
_shared_cols = ['text', 'dialect']
combined_df = pd.concat(
    [source[_shared_cols] for source in (bangladial_df, vashantor_df)],
    ignore_index=True,
)

print(f"Combined dataset shape: {combined_df.shape}")
print(f"Dialect distribution:\n{combined_df['dialect'].value_counts()}")

# Class balance of the raw merged data, before any cleaning
plot_dialect_distribution(
    combined_df,
    'Combined Dataset (Before Cleaning)',
    'viz_02_combined_before_cleaning.png',
)

# ============================================================================
# DATA CLEANING
# ============================================================================

print("\n" + "=" * 80)
print("CLEANING DATA")
print("=" * 80)

# (step name, rows remaining) after each cleaning step; plotted below
cleaning_stats = []

initial_count = len(combined_df)
cleaning_stats.append(('Initial', initial_count))

# Drop rows with missing text
print(f"Missing values before cleaning: {combined_df['text'].isna().sum()}")
combined_df = combined_df.dropna(subset=['text'])
cleaning_stats.append(('After dropping NaN', len(combined_df)))

# Strip zero-width characters BEFORE clean_text, so its whitespace
# normalisation and strip() tidy up any gaps they leave. Doing it afterwards
# left double or trailing spaces that also defeated deduplication.
# Non-string cells pass through untouched; clean_text turns them into "".
tqdm.pandas(desc="Normalizing Bangla")
combined_df['text'] = combined_df['text'].progress_apply(
    lambda t: normalize_bangla(t) if isinstance(t, str) else t
)

# Apply text cleaning
tqdm.pandas(desc="Cleaning text")
combined_df['text'] = combined_df['text'].progress_apply(clean_text)

# Remove empty strings
combined_df = combined_df[combined_df['text'].str.len() > 0]
cleaning_stats.append(('After removing empty', len(combined_df)))

# Remove sentences with no Bengali-script character
print(f"\nRows before English removal: {len(combined_df)}")
tqdm.pandas(desc="Filtering English-only")
combined_df = combined_df[~combined_df['text'].progress_apply(is_english_only)]
print(f"Rows after English removal: {len(combined_df)}")
cleaning_stats.append(('After English removal', len(combined_df)))

# Remove exact duplicates (same text AND same dialect)
print(f"\nRows before deduplication: {len(combined_df)}")
combined_df = combined_df.drop_duplicates(subset=['text', 'dialect'], keep='first')
print(f"Rows after deduplication: {len(combined_df)}")
cleaning_stats.append(('After deduplication', len(combined_df)))

# Drop texts that still appear under more than one dialect. Deduplicating on
# text alone with keep='first' used to silently assign such sentences to
# whichever dialect happened to be concatenated first, which is label noise.
labels_per_text = combined_df.groupby('text')['dialect'].transform('nunique')
n_conflicting = int((labels_per_text > 1).sum())
combined_df = combined_df[labels_per_text == 1]
print(f"Rows removed with conflicting dialect labels: {n_conflicting}")
cleaning_stats.append(('After removing label conflicts', len(combined_df)))

# Reset index
combined_df = combined_df.reset_index(drop=True)

print(f"\nFinal cleaned dataset shape: {combined_df.shape}")
print(f"Final dialect distribution:\n{combined_df['dialect'].value_counts()}")

# Visualize cleaning impact: rows remaining after each step, first step on top
fig, ax = plt.subplots(figsize=(10, 6))
steps = [step for step, _ in cleaning_stats]
counts = [n_rows for _, n_rows in cleaning_stats]
colors = sns.color_palette("RdYlGn_r", len(steps))

bars = ax.barh(steps, counts, color=colors, edgecolor='black', linewidth=1.2)

# Label each bar with its thousands-separated row count, just past the bar end
for bar, count in zip(bars, counts):
    ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2.,
            f' {count:,}',
            ha='left', va='center', fontsize=11, fontweight='bold')

ax.set_xlabel('Number of Samples', fontsize=12, fontweight='bold')
ax.set_title('Data Cleaning Pipeline Impact', fontsize=14, fontweight='bold', pad=20)
ax.invert_yaxis()  # barh draws the first category at the bottom; flip it
plt.tight_layout()
pipeline_plot_path = OUTPUT_DIR + 'viz_03_cleaning_pipeline.png'
plt.savefig(pipeline_plot_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Saved: {pipeline_plot_path}")

# Visualize cleaned dataset
plot_dialect_distribution(combined_df,
                          'Combined Dataset (After Cleaning)',
                          'viz_04_combined_after_cleaning.png')

plot_text_length_distribution(combined_df,
                              'Text Length Analysis (Cleaned Dataset)',
                              'viz_05_text_length_analysis.png')

# ============================================================================
# TRAIN/VAL/TEST SPLIT
# ============================================================================

print("\n" + "=" * 80)
print("SPLITTING DATA")
print("=" * 80)

# Fractions of the WHOLE cleaned dataset reserved for test and validation.
# (The old if/else assigned these same values in both branches while printing
# "Adjusting split strategy"; nothing was actually adjusted.)
test_size = 0.15
val_size = 0.15

# Check if we have enough samples per dialect
min_samples = combined_df['dialect'].value_counts().min()
print(f"Minimum samples per dialect: {min_samples}")
if min_samples < 10:
    print("Warning: Some dialects have very few samples; "
          "their validation/test counts will be tiny.")


def _stratify_labels(df):
    """Return df['dialect'] for a stratified split, or None (plain random split)
    when some dialect has fewer than 2 rows, which train_test_split cannot stratify."""
    if df['dialect'].value_counts().min() >= 2:
        return df['dialect']
    print("Warning: a dialect has fewer than 2 samples; using a non-stratified split.")
    return None


# First split: train+val vs test
train_val_df, test_df = train_test_split(
    combined_df,
    test_size=test_size,
    stratify=_stratify_labels(combined_df),
    random_state=42
)

# Second split: train vs val. Rescale so that val is val_size of the WHOLE
# dataset rather than of the remaining train+val portion.
val_size_adjusted = val_size / (1 - test_size)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=val_size_adjusted,
    stratify=_stratify_labels(train_val_df),
    random_state=42
)

print(f"\nTrain set: {len(train_df)} samples")
print(train_df['dialect'].value_counts())
print(f"\nValidation set: {len(val_df)} samples")
print(val_df['dialect'].value_counts())
print(f"\nTest set: {len(test_df)} samples")
print(test_df['dialect'].value_counts())

# Visualize splits
plot_split_comparison(train_df, val_df, test_df, 'viz_06_split_distribution.png')

# ============================================================================
# SAVE DATASETS
# ============================================================================

print("\n" + "=" * 80, "SAVING DATASETS", "=" * 80, sep="\n")

# Output file name -> split it holds (written as UTF-8 without the index)
_split_files = {
    'cleaned_bangla_train.csv': train_df,
    'cleaned_bangla_val.csv': val_df,
    'cleaned_bangla_test.csv': test_df,
}

for _file_name, _frame in _split_files.items():
    _frame.to_csv(OUTPUT_DIR + _file_name, index=False, encoding='utf-8')

for _file_name in _split_files:
    print(f"✓ Saved: {OUTPUT_DIR}{_file_name}")

# ============================================================================
# SUMMARY STATISTICS
# ============================================================================

print("\n" + "=" * 80, "SUMMARY STATISTICS", "=" * 80, sep="\n")

def get_stats(df, name):
    """Print the sample count and the character-length statistics of df['text'],
    under the heading `name`."""
    print(f"\n{name}:")
    print(f"  Total samples: {len(df)}")
    char_lengths = df['text'].str.len()
    print(f"  Avg text length: {char_lengths.mean():.2f} chars")
    print(f"  Min text length: {char_lengths.min()} chars")
    print(f"  Max text length: {char_lengths.max()} chars")
    print(f"  Median text length: {char_lengths.median():.2f} chars")
    
for _split_label, _split_frame in (("Training Set", train_df),
                                   ("Validation Set", val_df),
                                   ("Test Set", test_df)):
    get_stats(_split_frame, _split_label)


def _per_dialect_rows(split_name, split_df):
    """Yield one summary-row dict per dialect (in DIALECTS order) that has samples in split_df."""
    for dialect_name in DIALECTS:
        subset = split_df[split_df['dialect'] == dialect_name]
        if subset.empty:
            continue
        lengths = subset['text'].str.len()
        yield {
            'Split': split_name,
            'Dialect': dialect_name,
            'Count': len(subset),
            'Avg Length': lengths.mean(),
            'Min Length': lengths.min(),
            'Max Length': lengths.max(),
        }


# Create comprehensive summary table (split-major, then dialect order)
summary_data = [
    row
    for split_name, split_df in (('Train', train_df), ('Validation', val_df), ('Test', test_df))
    for row in _per_dialect_rows(split_name, split_df)
]

summary_df = pd.DataFrame(summary_data)
print("\n" + "=" * 80, "DETAILED SUMMARY TABLE", "=" * 80, sep="\n")
print(summary_df.to_string(index=False))

# Save summary
summary_path = OUTPUT_DIR + 'dataset_summary.csv'
summary_df.to_csv(summary_path, index=False)
print(f"\n✓ Saved: {summary_path}")

print("\n" + "=" * 80, "PREPROCESSING COMPLETE!", "=" * 80, sep="\n")

# Display the first three rows of each split
for _heading, _frame in (("training", train_df), ("validation", val_df), ("test", test_df)):
    print(f"\nSample from {_heading} set:")
    print(_frame.head(3).to_string())