# In [None]:
import os
from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

BASE_DIR = "XED/"

# Input folders (annotated data and projected data) plus the output folder,
# all resolved relative to BASE_DIR.
ANNOTATED_PATH, PROJECTIONS_PATH, OUTPUT_DIR = (
    os.path.join(BASE_DIR, sub_dir)
    for sub_dir in ("AnnotatedData", "Projections", "processed")
)

# Language codes to load, split and save, in processing order.
LANGUAGES = ["en", "fi", "fr", "es"]

# Create the output folder up front; exist_ok makes this a no-op on reruns.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Functions ---

def load_xed_file(lang: str, file_path: Optional[str] = None) -> pd.DataFrame:
    """
    Load and clean XED data for a given language.

    Unless ``file_path`` is given, the path is derived from the language:
    annotated data (en, fi) is read from ``ANNOTATED_PATH`` and projected
    data (every other language) from ``PROJECTIONS_PATH``.

    Args:
        lang: The language code (e.g., 'en', 'fr'). It is also stored in
            the 'language' column of the result.
        file_path: Optional explicit path to the tab-separated file. When
            None (the default), the path is derived from ``lang`` as above.

    Returns:
        A new DataFrame with exactly the columns 'text', 'labels' and
        'language', in that order.

    Raises:
        ValueError: If the file has fewer than two columns, so separate text
            and label columns cannot be chosen.
        Errors raised by ``pd.read_csv`` (e.g. FileNotFoundError) propagate.
    """
    if file_path is None:
        if lang in ["en", "fi"]:
            file_path = os.path.join(ANNOTATED_PATH, f"{lang}-annotated.tsv")
        else:
            file_path = os.path.join(PROJECTIONS_PATH, f"{lang}-projections.tsv")

    # NOTE(review): read_csv treats the first line as a header row. If the
    # TSV files have no header, the first sentence is consumed as column
    # names and only the positional fallback below applies. Confirm the
    # file format and pass header=None if that is the case.
    data_frame = pd.read_csv(file_path, sep="\t")

    # Standardize column names (lowercase, strip whitespace); str() guards
    # against non-string column labels.
    columns = [str(col).strip().lower() for col in data_frame.columns]
    if len(columns) < 2:
        raise ValueError(
            f"{file_path}: expected at least 2 columns (text and labels), "
            f"found {len(columns)}"
        )

    # Text column: first header mentioning text/utterance, else column 0.
    text_idx = next(
        (i for i, col in enumerate(columns) if "text" in col or "utterance" in col),
        0,
    )

    # Label column: first header mentioning label/emotion, other than the
    # text column. This keeps a header such as "label_text" from being used
    # for both roles.
    label_idx = next(
        (
            i
            for i, col in enumerate(columns)
            if i != text_idx and ("label" in col or "emotion" in col)
        ),
        None,
    )
    if label_idx is None:
        # Positional fallback: column 1, unless that is already the text column.
        label_idx = 1 if text_idx != 1 else 0

    # Select by position, so the selection is robust to headers that collide
    # after lowercasing. Copy so that adding 'language' writes to an
    # independent frame rather than a slice (avoids SettingWithCopyWarning).
    result = data_frame.iloc[:, [text_idx, label_idx]].copy()
    result.columns = ["text", "labels"]
    result["language"] = lang
    return result

# --- Data Loading and Concatenation ---

# Load every configured language. A language whose file cannot be loaded is
# reported and skipped, so the remaining languages are still processed.
data_frames: List[pd.DataFrame] = []
for lang_code in LANGUAGES:
    try:
        df = load_xed_file(lang_code)
    except Exception as e:  # best-effort at the top-level loop: report and continue
        print(f"Error loading {lang_code}: {e}")
        continue
    print(f"Loaded {lang_code}: {df.shape}")
    data_frames.append(df)

# pd.concat([]) fails with an unhelpful "No objects to concatenate"; fail
# early with a message that points at the input files instead.
if not data_frames:
    raise RuntimeError(
        f"No XED data could be loaded for any of {LANGUAGES}; "
        f"check the input files under {BASE_DIR!r}."
    )

all_data_combined = pd.concat(data_frames, ignore_index=True)
print(f"\nCombined shape: {all_data_combined.shape}")
print("Sample data head:")
print(all_data_combined.head())

# --- Data Cleaning ---

# Drop rows missing 'text' or 'labels' BEFORE deduplicating. Otherwise
# drop_duplicates (which keeps the first occurrence) could keep an unlabeled
# copy of a sentence and discard its labeled twin, and the NA filter would
# then remove the sentence entirely.
missing_labels_count = all_data_combined["labels"].isna().sum()
print(f"\nMissing label rows: {missing_labels_count}")
all_data_combined = all_data_combined.dropna(subset=["text", "labels"])
print(f"After dropping NA rows: {all_data_combined.shape}")

print(f"\nDuplicates before deduplication: {all_data_combined.duplicated(subset=['text']).sum()}")

# Drop duplicates based on the 'text' column only. This is across all
# languages, so an identical string cannot end up in one language's train
# split and another language's test split of the multilingual files.
all_data_combined = all_data_combined.drop_duplicates(subset=["text"]).reset_index(drop=True)
print(f"After removing duplicates: {all_data_combined.shape}")

# --- Splitting and Saving ---

train_data_frames: List[pd.DataFrame] = []
test_data_frames: List[pd.DataFrame] = []
# Languages that were actually split. This list is kept parallel to the two
# lists above, so the save loop pairs each split with the right language even
# when a language is skipped.
split_languages: List[str] = []
TEST_SIZE = 0.2
RANDOM_STATE = 42

for lang_code in LANGUAGES:
    lang_df = all_data_combined[all_data_combined["language"] == lang_code]

    # A language can be missing (its file failed to load) or emptied by the
    # NA/duplicate filtering. At TEST_SIZE = 0.2, train_test_split raises for
    # fewer than 2 rows, so report and skip instead of aborting the whole run.
    if len(lang_df) < 2:
        print(f"Skipping {lang_code}: only {len(lang_df)} row(s) available after cleaning")
        continue

    # Split each language separately so every language keeps the same
    # train/test ratio. The fixed random_state makes the split reproducible.
    train_split, test_split = train_test_split(
        lang_df,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
    )

    split_languages.append(lang_code)
    train_data_frames.append(train_split)
    test_data_frames.append(test_split)
    print(f"{lang_code}: train={train_split.shape[0]}, test={test_split.shape[0]}")

if not split_languages:
    raise RuntimeError("No language has enough rows to split into train and test sets.")

print("\nSaving individual language splits...")
for lang_code, train_df, test_df in zip(split_languages, train_data_frames, test_data_frames):
    train_df.to_csv(os.path.join(OUTPUT_DIR, f"train_{lang_code}.csv"), index=False)
    test_df.to_csv(os.path.join(OUTPUT_DIR, f"test_{lang_code}.csv"), index=False)

# Multilingual files: every language's train (resp. test) rows concatenated.
train_multi = pd.concat(train_data_frames, ignore_index=True)
test_multi = pd.concat(test_data_frames, ignore_index=True)

train_multi.to_csv(os.path.join(OUTPUT_DIR, "train_multilingual.csv"), index=False)
test_multi.to_csv(os.path.join(OUTPUT_DIR, "test_multilingual.csv"), index=False)

print("\nSaved multilingual splits:")
print(f"Train: {train_multi.shape} | Test: {test_multi.shape}")

# --- Final Checks ---
# Compute the summaries first, then report them along with the output location.
languages_in_train = train_multi["language"].unique()
top_label_counts = train_multi["labels"].value_counts().head()

print("\nUnique languages in combined training set:", languages_in_train)
print("Sample labels distribution (Top 5 in training set):")
print(top_label_counts)
print(f"\nPreprocessing complete. Files saved in: {OUTPUT_DIR}")