# In [None]:
import pandas as pd
import logging
import csv
from collections import defaultdict
from typing import List, Dict, Tuple, Any

# Set up logging. The format must include %(message)s: with only %(levelname)s,
# every record would print just its level ("INFO", "WARNING", ...) and drop the text.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')


def parse_variant_id(var_id: str, prefix: str = 'file-katl_e_bayan-1507111042', sep: str = '__') -> Tuple[int, int]:
    """
    Parse a variant ID into its reference and sub-variant numbers for numerical sorting.

    The expected shape is ``<prefix><sep><ref>.<sub>``, for example
    'file-katl_e_bayan-1507111042__1.2' -> (1, 2).

    Args:
        var_id (str): Variant ID to parse.
        prefix (str): Prefix that every valid ID starts with.
        sep (str): Separator between the prefix and the '<ref>.<sub>' number part.

    Returns:
        Tuple[int, int]: (reference_id, sub_id). IDs that do not match the expected
        shape yield (9999, 0), so they sort after any ID whose reference number is
        below 9999.
    """
    invalid = (9999, 0)
    if not var_id.startswith(prefix):
        # IDs from another document are ignored without a warning.
        return invalid

    head = prefix + sep
    if not var_id.startswith(head):
        # Check the separator explicitly instead of blindly skipping two characters.
        logging.warning(f"Invalid variant ID format: {var_id}")
        return invalid

    try:
        # Unpacking raises ValueError both for non-integers and for a wrong part count.
        ref_id, sub_id = (int(part) for part in var_id[len(head):].split('.'))
    except ValueError:
        logging.warning(f"Invalid variant ID format: {var_id}")
        return invalid
    return ref_id, sub_id

# Function to read variant sentences from CSV
def read_variant_sentences(variant_file: str) -> Dict[str, List[List[Tuple[str, int]]]]:
    """
    Read variant sentences from a CSV file and return a dictionary of sentence IDs to word-index pairs.

    The file is first parsed with a comma separator. If that fails, the file is
    re-read with a tab separator. "Fails" means a parser error, fewer than two
    columns, or an all-empty sentence column. Each sentence is split on whitespace.
    Tokens of the form ``word_<n>`` become (word, n), split on the LAST underscore
    so words may themselves contain '_'. Tokens without an index become
    (word, None), and '__NULL__' tokens are dropped.

    Args:
        variant_file (str): Path to the variant CSV file.

    Returns:
        Dict[str, List[List[Tuple[str, int]]]]: Mapping from sentence ID to a list of
        sentences (one per occurrence of that ID), each a list of (word, index) pairs.
        The index is None when the token carries no usable index.

    Raises:
        Exception: Any read/parse error is logged and re-raised. This includes
        ValueError when neither separator yields at least two columns.
    """
    try:
        # Log only a short preview; there is no need to load the whole file for it.
        with open(variant_file, 'r', encoding='utf-8') as f:
            preview = f.read(500)
        logging.debug(f"Raw content of {variant_file}:\n{preview}...")

        df = _load_variant_table(variant_file)

        variants = defaultdict(list)
        for _, row in df.iterrows():
            if pd.isna(row['sentence_id']) or pd.isna(row['sentence']):
                logging.warning(f"Skipping empty or invalid row: {row}")
                continue
            sentence_id = str(row['sentence_id']).strip()
            variants[sentence_id].append(_parse_variant_words(str(row['sentence'])))
        return variants
    except Exception as e:
        logging.error(f"Error reading variant file {variant_file}: {e}")
        raise


def _load_variant_table(variant_file: str) -> pd.DataFrame:
    """Load the variant file as a named DataFrame, trying comma first and then tab."""
    read_kwargs = dict(encoding='utf-8', header=None, na_values=[''])
    try:
        df = pd.read_csv(variant_file, sep=',', **read_kwargs)
    except pd.errors.ParserError as e:
        # Ragged comma splits (e.g. commas inside tab-separated sentences) land here.
        logging.warning(f"Comma separator could not parse {variant_file}: {e}")
        df = None

    if df is not None:
        logging.info(f"Parsed {variant_file} with comma separator. Columns: {df.shape[1]}")
        logging.debug(f"First few rows:\n{df.head().to_string()}")
        # A tab-separated file usually parses as a single comma column, which cannot be named.
        if df.shape[1] >= 2:
            df = _name_variant_columns(df, variant_file, '')
            if not df['sentence'].isna().all():
                return df

    logging.warning("Comma separator failed, trying tab separator")
    df = pd.read_csv(variant_file, sep='\t', **read_kwargs)
    logging.info(f"Parsed {variant_file} with tab separator. Columns: {df.shape[1]}")
    return _name_variant_columns(df, variant_file, ' with tab separator')


def _name_variant_columns(df: pd.DataFrame, variant_file: str, label: str) -> pd.DataFrame:
    """Name the columns of a headerless variant table according to how many it has."""
    ncols = df.shape[1]
    if ncols < 2:
        raise ValueError(f"Expected at least 2 columns{label} in {variant_file}, found {ncols}")
    if ncols == 2:
        df.columns = ['sentence_id', 'sentence']
        df['index'] = range(len(df))  # synthesize the missing row index
    elif ncols == 3:
        df.columns = ['index', 'sentence_id', 'sentence']
    elif ncols == 4:
        df.columns = ['index', 'sentence_id', 'sentence', 'extra']
    else:
        if ncols < 8:
            logging.warning(f"Unexpected number of columns{label} in {variant_file}: {ncols}. Assuming index, sentence_id, sentence, extra...")
        df.columns = ['index', 'sentence_id', 'sentence'] + [f'extra{i}' for i in range(ncols - 3)]
    return df


def _parse_variant_words(sentence: str) -> List[Tuple[str, Any]]:
    """Split a variant sentence into (word, reference index or None) pairs."""
    pairs = []
    for token in sentence.strip().split():
        if token == '__NULL__':
            continue
        word, underscore, idx = token.rpartition('_')
        if not underscore:
            pairs.append((token, None))
            continue
        try:
            pairs.append((word, int(idx)))
        except ValueError:
            logging.warning(f"Invalid index in word: {token}")
            pairs.append((token, None))
    return pairs

# Function to read CoNLL text file
def read_conll_file(conll_file: str) -> List[List[List[str]]]:
    """
    Read a CoNLL file and return a list of sentences, where each sentence is a list of token rows.

    Sentences are separated by blank lines, or by lines whose first field is empty.
    Lines starting with '#' are skipped as comments. Each token row is the line
    split on tabs, so every field is a string. Empty fields stay as ''.

    The file is read line by line rather than through pandas. ``pd.read_csv`` drops
    blank lines by default, which merged every sentence into one. It also
    float-converts numeric columns and cuts tokens at any '#'.

    Args:
        conll_file (str): Path to the CoNLL file.

    Returns:
        List[List[List[str]]]: List of sentences, each containing rows of CoNLL columns.

    Raises:
        Exception: Any I/O or decoding error is logged and re-raised.
    """
    try:
        sentences = []
        current_sentence = []
        with open(conll_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip('\r\n')
                if line.startswith('#'):
                    continue
                fields = line.split('\t')
                if not fields[0].strip():
                    # Sentence boundary: flush the sentence collected so far.
                    if current_sentence:
                        sentences.append(current_sentence)
                        current_sentence = []
                    continue
                current_sentence.append(fields)
        if current_sentence:
            sentences.append(current_sentence)
        logging.info(f"Read {len(sentences)} sentences from {conll_file}")
        return sentences
    except Exception as e:
        logging.error(f"Error reading CoNLL file {conll_file}: {e}")
        raise

# Function to generate variant CoNLL entries
def generate_variant_conll(reference_sentence: List[List[str]], variant_words: List[Tuple[str, int]], sentence_id: str) -> List[List[str]]:
    """
    Generate CoNLL entries for a variant sentence, mapping variant words to reference annotations.

    Each variant token is resolved to a reference token. If it has an explicit
    reference index, that index is used. Otherwise the first reference token with
    the same surface form is used. Heads are then rewritten into the variant's own
    1-based positions for EVERY token, including words matched by surface form such
    as '।'. A head whose target is absent from the variant becomes 0 and a warning
    is logged.

    Reference rows are read as row[0] index, row[1] word, row[3] POS tag,
    row[4] head and row[5] relation. Rows with fewer than 6 columns, or with a
    non-integer index or head, are skipped.

    Args:
        reference_sentence (List[List[str]]): CoNLL rows of the reference sentence.
        variant_words (List[Tuple[str, int]]): List of (word, index) tuples for the variant;
            the index may be None.
        sentence_id (str): ID of the variant sentence.

    Returns:
        List[List[str]]: Rows of [position (int), word, pos_tag, head (str), rel].
        Returns [] when either input is empty or an unexpected error occurs.
    """
    try:
        if not reference_sentence or not variant_words:
            logging.warning(f"Empty reference or variant words for {sentence_id}")
            return []

        ref_dict = {}
        for row in reference_sentence:
            if len(row) < 6:
                logging.warning(f"Skipping malformed CoNLL row in {sentence_id}: {row}")
                continue
            try:
                index = int(row[0])
                word = row[1] if row[1] else '_'
                pos_tag = row[3] if row[3] else 'UNK'
                head = int(row[4]) if row[4] != '0' else 0
                rel = row[5] if row[5] else '_'
                ref_dict[index] = {'word': word, 'pos_tag': pos_tag, 'head': head, 'rel': rel}
            except (ValueError, IndexError) as e:
                logging.warning(f"Error processing CoNLL row in {sentence_id}: {row}, {e}")
                continue

        # First reference index for each surface form, for tokens that carry no index.
        first_index_by_word = {}
        for ref_index, entry in ref_dict.items():
            first_index_by_word.setdefault(entry['word'], ref_index)

        # Resolve each token to (word, reference index or None, annotation, explicitly indexed?).
        resolved = []
        for word, idx in variant_words:
            if idx is not None:
                entry = ref_dict.get(idx, {'word': word, 'pos_tag': 'UNK', 'head': 0, 'rel': 'UNK'})
                resolved.append((word, idx, entry, True))
            elif word in first_index_by_word:
                ref_index = first_index_by_word[word]
                resolved.append((word, ref_index, ref_dict[ref_index], False))
            elif word == '।':
                logging.warning(f"Punctuation '।' not found in reference for {sentence_id}")
                resolved.append((word, None, {'word': word, 'pos_tag': 'SYM', 'head': 0, 'rel': 'rsym'}, False))
            else:
                logging.warning(f"Word {word} not found in reference for {sentence_id}, using defaults")
                resolved.append((word, None, {'word': word, 'pos_tag': 'UNK', 'head': 0, 'rel': 'UNK'}, False))

        # Reference index -> 1-based variant position. Explicitly indexed tokens are
        # registered first, so they win over surface-form matches; within each group
        # the first occurrence wins.
        ref_to_position = {}
        for position, (_, ref_index, _, explicit) in enumerate(resolved, start=1):
            if explicit:
                ref_to_position.setdefault(ref_index, position)
        for position, (_, ref_index, _, explicit) in enumerate(resolved, start=1):
            if not explicit and ref_index is not None:
                ref_to_position.setdefault(ref_index, position)

        variant_conll = []
        for position, (word, ref_index, entry, _) in enumerate(resolved, start=1):
            new_head = 0
            if entry['head'] != 0:
                new_head = ref_to_position.get(entry['head'], 0)
                if new_head == 0:
                    logging.warning(f"Head index {entry['head']} not found in variant {sentence_id}, resetting to 0")
            logging.debug(f"Mapping {word} (ref_idx={ref_index}) to head={new_head} in {sentence_id}")
            variant_conll.append([position, word, entry['pos_tag'], str(new_head), entry['rel']])

        return variant_conll
    except Exception as e:
        logging.error(f"Error generating CoNLL for {sentence_id}: {e}")
        return []

# Main function to process files and generate output CSV
def generate_conll_output(conll_file: str, variant_file: str, output_file: str,
                          prefix: str = 'file-katl_e_bayan-1507111042') -> None:
    """
    Process CoNLL and variant files to generate a combined CoNLL output.

    Variant IDs are grouped by reference number ('<prefix>__<ref>.<sub>'). Reference
    number N is paired with the N-th sentence of the CoNLL file. Each group is
    written in numeric order: the reference itself (sub == 0) first, then its
    sub-variants. Every sentence is written as a '# <id>' line, then tab-separated
    (word, pos_tag, head, rel) rows, then an empty row. Sentences that produce no
    rows are omitted entirely, so no orphan '# <id>' lines are left behind.

    Args:
        conll_file (str): Path to the CoNLL input file.
        variant_file (str): Path to the variant CSV file.
        output_file (str): Path to the output CoNLL file.
        prefix (str): Prefix that variant IDs must start with; other IDs are ignored.

    Raises:
        Exception: Errors while reading inputs or opening the output are logged and
        re-raised. Errors for an individual variant are logged, and that variant
        is skipped.
    """
    try:
        sentences = read_conll_file(conll_file)
        variants = read_variant_sentences(variant_file)
        variant_groups = _group_variant_ids(variants.keys(), prefix)

        with open(output_file, 'w', encoding='utf-8', newline='') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['word', 'pos_tag', 'head', 'rel'])

            for ref_id in sorted(variant_groups):
                # Reference number N maps to sentences[N - 1].
                sentence_idx = ref_id - 1
                if not 0 <= sentence_idx < len(sentences):
                    logging.warning(f"No reference sentence found for ref_id {ref_id}. Skipping.")
                    continue
                ref_sentence = sentences[sentence_idx]

                for var_id, is_reference in variant_groups[ref_id]:
                    try:
                        # Build all rows before writing so a failure never leaves partial output.
                        rows = _output_rows(var_id, is_reference, ref_sentence, variants)
                        if not rows:
                            continue
                        writer.writerow([f'# {var_id}'])
                        writer.writerows(rows)
                        writer.writerow([])
                    except Exception as e:
                        logging.error(f"Error processing variant {var_id}: {e}")

        logging.info(f"Output written to {output_file}")
    except Exception as e:
        logging.error(f"Error generating output: {e}")
        raise


def _group_variant_ids(variant_ids, prefix: str) -> Dict[int, List[Tuple[str, bool]]]:
    """Group variant IDs by reference number, ordered by sub-number; flag sub == 0 as the reference."""
    groups = defaultdict(list)
    skipped = 0
    for var_id in variant_ids:
        if not var_id.startswith(prefix):
            skipped += 1
            continue
        ref_id, sub_id = parse_variant_id(var_id, prefix)
        groups[ref_id].append((sub_id, var_id))
    if skipped:
        logging.warning(f"Ignored {skipped} variant ID(s) not starting with '{prefix}'")
    # Stable sort on the sub-number only, so ties keep their input order.
    return {
        ref_id: [(var_id, sub_id == 0) for sub_id, var_id in sorted(members, key=lambda m: m[0])]
        for ref_id, members in groups.items()
    }


def _output_rows(var_id: str, is_reference: bool, ref_sentence: List[List[str]],
                 variants: Dict[str, List[List[Tuple[str, int]]]]) -> List[List[Any]]:
    """Return the (word, pos_tag, head, rel) rows to write for one sentence ID."""
    rows = []
    if is_reference:
        for row in ref_sentence:
            if len(row) >= 6:
                rows.append([row[1], row[3], row[4], row[5]])
            else:
                logging.warning(f"Skipping malformed row in {var_id}: {row}")
        return rows

    # Only the first sentence recorded under this ID is used.
    variant_conll = generate_variant_conll(ref_sentence, variants[var_id][0], var_id)
    if not variant_conll:
        logging.warning(f"No CoNLL data generated for {var_id}")
        return rows
    for row in variant_conll:
        if len(row) >= 5:
            rows.append([row[1], row[2], row[3], row[4]])
        else:
            logging.warning(f"Invalid row in {var_id}: {row}")
    return rows

# Default input/output locations used when this module is executed as a script.
conll_file, variant_file, output_file = (
    'hutb_conll.txt',                # reference sentences read by read_conll_file
    'hutb_katle_final.csv',          # variant sentences read by read_variant_sentences
    'katl_e_bayan_conll_final.csv',  # combined tab-separated output
)

if __name__ == "__main__":
    # Run the pipeline only on direct execution, never on import.
    generate_conll_output(conll_file=conll_file, variant_file=variant_file, output_file=output_file)