In [None]:
import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from tqdm.notebook import tqdm
import re
import nltk
from nltk.tokenize import sent_tokenize
import time
import warnings
import os
warnings.filterwarnings('ignore')

# Download NLTK resources for sentence splitting
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
    
# Download additional NLTK resources that may help with text processing
try:
    nltk.data.find('tokenizers/wordnet')
except LookupError:
    nltk.download('wordnet')

class SinhalaEnglishTranslator:
    def __init__(self, model_path= '/Users/pasindumalinda/Downloads/nllb-200-1.3B'):
        """Initialize the translator with the specified model
        
        Args:
            model_path (str): Local path to the model directory
        """
        self.model_path = model_path
        self.src_lang = "sin_Sinh"  # Sinhala
        self.tgt_lang = "eng_Latn"  # English
        
        print(f"Loading tokenizer from local path: {model_path}...")
        
        # First check if config.json exists, if not create a minimal one
        config_path = os.path.join(model_path, "config.json")
        if not os.path.exists(config_path):
            print("Config file not found. Creating minimal config.json...")
            import json
            config = {
                "model_type": "nllb",
                "architectures": ["NllbForConditionalGeneration"],
                "_name_or_path": "facebook/nllb-200-1.3B"
            }
            os.makedirs(model_path, exist_ok=True)
            with open(config_path, 'w') as f:
                json.dump(config, f)
            print(f"Created config.json at {config_path}")
        
        # Try to load the tokenizer
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            print(f"Error loading local tokenizer: {e}")
            print("Falling back to downloading from HuggingFace...")
            self.tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-1.3B")
            # Save tokenizer to local path
            self.tokenizer.save_pretrained(model_path)
            print(f"Saved tokenizer to {model_path}")
        
        print("Tokenizer loaded successfully!")
        
        print(f"Loading model from local path: {model_path}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        
        # Try to load the model
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True)
            self.model = self.model.to(self.device)
        except Exception as e:
            print(f"Error loading local model: {e}")
            print("Falling back to downloading from HuggingFace...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-1.3B")
            self.model = self.model.to(self.device)
            # Save model to local path
            self.model.save_pretrained(model_path)
            print(f"Saved model to {model_path}")
        
        print("Model loaded successfully!")
        
        # Ensure we have language code mapping for NLLB models
        # This is essential for the forced_bos_token_id in translation
        if not hasattr(self.tokenizer, 'lang_code_to_id'):
            print("Adding language code mapping to tokenizer...")
            self.tokenizer.lang_code_to_id = {
                "sin_Sinh": 50264,  # Sinhala
                "eng_Latn": 128022,  # English
                # Add more languages as needed
            }
        
        # Create a fallback pipeline using a different model for challenging cases
        try:
            print("Setting up fallback translator (Helsinki-NLP)...")
            self.fallback_translator = pipeline(
                "translation", 
                model="Helsinki-NLP/opus-mt-si-en",
                device=0 if torch.cuda.is_available() else -1
            )
            self.has_fallback = True
        except Exception as e:
            print(f"Fallback translator not available: {e}")
            self.has_fallback = False
        
        # Common Sinhala-English mixed phrases that should be preserved
        self.special_phrases = [
            "Kohomada", "karanna oni", "api", "mama", "oya", "eyala",
            "ayubowan", "istuti", "bohoma istuti", "karunakara", 
            "oba", "mage", "samahara", "kawadada", "koheda", 
            "mokada", "ai", "kawuru", "kohomada", "monawada", "ATM", "atm", "bank",
            # Add more phrases specific to your dataset
        ]
        
        # Setup a translation cache to avoid re-translating identical text
        self.translation_cache = {}
        
        # Collection of problematic and good translations for analysis
        self.problem_translations = []
        self.good_translations = []
        
        print("Translator initialized successfully")
    
    def preprocess_text(self, text):
        """Clean and prepare text for translation"""
        if pd.isna(text) or text is None:
            return ""
        
        # Convert to string if not already
        text = str(text)
        
        # Remove URLs
        text = re.sub(r'https?://\S+|www\.\S+', '', text)
        
        # Replace multiple spaces, newlines and tabs with single space
        text = re.sub(r'\s+', ' ', text)
        
        # Remove emojis (optional - comment out if you want to keep them)
        emoji_pattern = re.compile("["
                                   u"\U0001F600-\U0001F64F"  # emoticons
                                   u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                                   u"\U0001F680-\U0001F6FF"  # transport & map symbols
                                   u"\U0001F700-\U0001F77F"  # alchemical symbols
                                   u"\U0001F780-\U0001F7FF"  # Geometric Shapes
                                   u"\U0001F800-\U0001F8FF"  # Supplemental Arrows-C
                                   u"\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
                                   u"\U0001FA00-\U0001FA6F"  # Chess Symbols
                                   u"\U0001FA70-\U0001FAFF"  # Symbols and Pictographs Extended-A
                                   u"\U00002702-\U000027B0"  # Dingbats
                                   "]+", flags=re.UNICODE)
        text = emoji_pattern.sub(r'', text)
        
        # Trim whitespace
        text = text.strip()
        
        return text
    
    def identify_special_terms(self, text):
        """Identify terms that should be preserved in the translation"""
        special_terms = []
        
        # Check for common phrases that should be preserved
        for phrase in self.special_phrases:
            if phrase.lower() in text.lower():
                # Find the actual occurrence with original casing
                matches = re.finditer(re.escape(phrase), text, re.IGNORECASE)
                for match in matches:
                    special_terms.append(text[match.start():match.end()])
        
        # Find mixed language terms (improved detection)
        # This regex looks for words with both English and Sinhala characters
        mixed_pattern = r'\b[a-zA-Z]*[\u0D80-\u0DFF]+[a-zA-Z]*\b|\b[a-zA-Z]+[\u0D80-\u0DFF]*[a-zA-Z]*\b'
        mixed_terms = re.findall(mixed_pattern, text)
        special_terms.extend(mixed_terms)
        
        # Find social media expressions (e.g., emoticons, hashtags)
        social_pattern = r'#\w+|@\w+|:\)|:\(|;\)|:D|:P|<3'
        social_terms = re.findall(social_pattern, text)
        special_terms.extend(social_terms)
        
        # Numbers with units (preserve as is)
        number_pattern = r'\b\d+(?:\.\d+)?\s*[a-zA-Z]+\b'  # e.g., "10kg", "5.5cm"
        number_terms = re.findall(number_pattern, text)
        special_terms.extend(number_terms)
        
        # Find proper nouns (simplified approach - words starting with capital letters)
        proper_noun_pattern = r'\b[A-Z][a-zA-Z]*\b'
        proper_nouns = re.findall(proper_noun_pattern, text)
        special_terms.extend(proper_nouns)
        
        # Remove duplicates while preserving order
        unique_terms = []
        for term in special_terms:
            if term not in unique_terms:
                unique_terms.append(term)
        
        return unique_terms
    
    def chunk_long_text(self, text, max_length=512):
        """Break long text into manageable chunks for translation"""
        if len(text) <= max_length:
            return [text]
        
        # Try to split by sentences first
        sentences = sent_tokenize(text)
        chunks = []
        current_chunk = ""
        
        for sentence in sentences:
            # Check if adding this sentence would exceed max_length
            if len(current_chunk) + len(sentence) + 1 <= max_length:  # +1 for the space
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                # If current chunk is not empty, add it to chunks
                if current_chunk:
                    chunks.append(current_chunk)
                
                # If sentence itself is too long, split it by words
                if len(sentence) > max_length:
                    words = sentence.split()
                    current_chunk = ""
                    for word in words:
                        if len(current_chunk) + len(word) + 1 <= max_length:  # +1 for the space
                            current_chunk += " " + word if current_chunk else word
                        else:
                            chunks.append(current_chunk)
                            current_chunk = word
                else:
                    # Sentence is shorter than max_length but doesn't fit in current chunk
                    current_chunk = sentence
        
        # Don't forget to add the last chunk if it's not empty
        if current_chunk:
            chunks.append(current_chunk)
        
        return chunks
    
    def translate_with_fallback(self, chunk):
        """Attempt translation with fallback model if available"""
        if not self.has_fallback:
            return None
            
        try:
            # The Helsinki-NLP model uses a different approach
            result = self.fallback_translator(chunk, max_length=512)
            if result and isinstance(result, list) and len(result) > 0:
                return result[0]['translation_text']
        except Exception as e:
            print(f"Fallback translation error: {e}")
        return None
    
    def translate_text(self, text, max_length=512):
        """Translate text from Sinhala to English with special term handling"""
        if not text or pd.isna(text) or len(text.strip()) == 0:
            return ""
        
        # Check cache first
        if text in self.translation_cache:
            return self.translation_cache[text]
        
        # Preprocess
        text = self.preprocess_text(text)
        if not text:
            return ""
        
        # Identify special terms
        special_terms = self.identify_special_terms(text)
        
        # Create placeholders for special terms
        placeholders = {}
        for i, term in enumerate(special_terms):
            placeholder = f"__SPECIAL_TERM_{i}__"
            placeholders[placeholder] = term
            text = text.replace(term, placeholder)
        
        # Handle long text by chunking
        chunks = self.chunk_long_text(text, max_length)
        translated_chunks = []
        
        for chunk in chunks:
            # Skip empty chunks
            if not chunk.strip():
                continue
                
            # Add source language tag to the input text for NLLB model
            tagged_chunk = f">>{self.src_lang}<< {chunk}"
            
            try:
                # Tokenize
                inputs = self.tokenizer(tagged_chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                
                # Get the forced BOS token ID for the target language
                try:
                    forced_bos_token_id = self.tokenizer.lang_code_to_id[self.tgt_lang]
                except (KeyError, AttributeError):
                    # Fallback for older tokenizers that don't have lang_code_to_id
                    print("Using default forced_bos_token_id for English")
                    forced_bos_token_id = 128022  # Default ID for English in NLLB
                
                # Translate with primary model
                with torch.no_grad():
                    translated_tokens = self.model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                        max_length=max_length,
                        num_beams=5,  # You can increase this for higher quality but slower translation
                        num_return_sequences=1,
                        length_penalty=1.0,
                        early_stopping=True
                    )
                
                # Decode
                translation = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
                
                # Quality check - if translation seems problematic (very short or unchanged),
                # try fallback model if available
                if (len(translation) < 10 and len(chunk) > 20) or translation == chunk:
                    fallback_translation = self.translate_with_fallback(chunk)
                    if fallback_translation:
                        translation = fallback_translation
                        print(f"Used fallback for chunk: {chunk[:30]}...")
                
                translated_chunks.append(translation)
                
            except Exception as e:
                print(f"Error translating chunk: {e}")
                # Try fallback model on error
                fallback_translation = self.translate_with_fallback(chunk)
                if fallback_translation:
                    translated_chunks.append(fallback_translation)
                else:
                    # If all fails, add original chunk
                    translated_chunks.append(chunk)
                    print(f"Failed to translate chunk: {chunk[:50]}...")
        
        # Combine chunks
        translation = " ".join(translated_chunks)
        
        # Replace placeholders with original terms
        for placeholder, term in placeholders.items():
            translation = translation.replace(placeholder, term)
        
        # Apply post-processing
        translation = self.post_process_translation(translation)
        
        # Cache the translation
        self.translation_cache[text] = translation
        
        # Collect statistics (optional)
        if len(translation.split()) < len(text.split()) / 3:
            self.problem_translations.append((text, translation))
        elif len(translation.split()) > 5:
            self.good_translations.append((text, translation))
        
        return translation
    
    def post_process_translation(self, translation):
        """Apply post-processing to improve translation quality"""
        if pd.isna(translation):
            return ""
        
        # Fix common translation artifacts
        translation = re.sub(r'\s+', ' ', translation)  # Remove multiple spaces
        translation = translation.strip()
        
        # Fix common punctuation issues
        translation = re.sub(r'\s+([.,;:!?])', r'\1', translation)  # No space before punctuation
        translation = re.sub(r'([.,;:!?])([^\s])', r'\1 \2', translation)  # Space after punctuation
        
        # Fix capitalization
        sentences = re.split(r'(?<=[.!?])\s+', translation)
        sentences = [s.capitalize() for s in sentences if s]
        translation = ' '.join(sentences)
        
        # Fix common translation errors (you can expand this list)
        common_errors = {
            # Common translation errors from Sinhala to English
            "the the": "the",
            "a the": "the",
            "an the": "the",
            # Add more errors that you observe in your translations
        }
        
        for error, correction in common_errors.items():
            translation = re.sub(r'\b' + re.escape(error) + r'\b', correction, translation, flags=re.IGNORECASE)
        
        return translation
    
    def batch_translate(self, df, column_name, batch_size=16, output_column='english_translation'):
        """Translate a dataframe column in batches with progress tracking and error handling"""
        total_rows = len(df)
        df_result = df.copy()
        
        # Add a new column for translations if it doesn't exist
        if output_column not in df_result.columns:
            df_result[output_column] = ""
        
        # Add error tracking column
        error_column = 'translation_error'
        if error_column not in df_result.columns:
            df_result[error_column] = False
        
        # Add translation confidence scores (based on simple heuristics)
        confidence_column = 'translation_confidence'
        if confidence_column not in df_result.columns:
            df_result[confidence_column] = 0.0
        
        error_count = 0
        
        try:
            for i in tqdm(range(0, total_rows, batch_size), desc="Translating batches"):
                end_idx = min(i + batch_size, total_rows)
                batch = df.iloc[i:end_idx]
                
                for idx, row in tqdm(batch.iterrows(), desc=f"Batch {i//batch_size + 1}", leave=False, total=len(batch)):
                    try:
                        text = row[column_name]
                        
                        # Skip translation if the text is empty
                        if pd.isna(text) or not str(text).strip():
                            df_result.at[idx, output_column] = ""
                            df_result.at[idx, confidence_column] = 0.0
                            continue
                        
                        # Translate
                        translation = self.translate_text(text)
                        
                        # Set translation
                        df_result.at[idx, output_column] = translation
                        df_result.at[idx, error_column] = False
                        
                        # Calculate a rough confidence score based on heuristics
                        src_words = len(str(text).split())
                        tgt_words = len(translation.split())
                        
                        # If target is too short compared to source, lower confidence
                        if src_words > 3 and tgt_words < src_words / 3:
                            confidence = 0.3
                        # If lengths are somewhat proportional, higher confidence
                        elif tgt_words > 0 and 0.5 <= (tgt_words / max(1, src_words)) <= 2.0:
                            confidence = 0.8
                        else:
                            confidence = 0.5
                            
                        df_result.at[idx, confidence_column] = confidence
                        
                    except Exception as e:
                        print(f"Error translating row {idx}: {e}")
                        df_result.at[idx, error_column] = True
                        error_count += 1
                    
                    # Add small delay to avoid potential rate limiting or overheating
                    time.sleep(0.05)
                
                # Save intermediate results every batch
                if i > 0 and i % (batch_size * 5) == 0:
                    temp_file = f"translation_checkpoint_{i}.csv"
                    df_result.to_csv(temp_file, index=False)
                    print(f"Checkpoint saved: {temp_file}")
            
        except KeyboardInterrupt:
            print("Translation interrupted. Saving partial results...")
        
        if error_count > 0:
            print(f"Completed with {error_count} errors out of {total_rows} rows.")
        
        # Generate summary statistics
        translation_stats = {
            "total_rows": total_rows,
            "error_count": error_count,
            "success_rate": (total_rows - error_count) / total_rows * 100,
            "problem_translations": len(self.problem_translations),
            "good_translations": len(self.good_translations)
        }
        
        print(f"Translation statistics: {translation_stats}")
        
        return df_result
    
    def review_sample(self, df, original_column, translation_column, n=5):
        """Display a sample of translations for manual review"""
        if n > len(df):
            n = len(df)
            
        print(f"\nReviewing {n} random samples:")
        sample = df.sample(n)
        for i, (_, row) in enumerate(sample.iterrows()):
            print(f"Sample {i+1}:")
            print(f"Original: {row[original_column]}")
            print(f"Translation: {row[translation_column]}")
            
            # Show confidence if available
            if 'translation_confidence' in row:
                print(f"Confidence: {row['translation_confidence']:.2f}")
                
            print("-" * 80)
    
    def translate_file(self, input_file, output_file, comment_column, batch_size=16):
        """Process an entire CSV file and save the results with error handling"""
        print(f"Reading CSV file: {input_file}")
        
        try:
            # Try to read with UTF-8 encoding first
            df = pd.read_csv(input_file, encoding='utf-8')
        except UnicodeDecodeError:
            # If that fails, try with the more permissive ISO-8859-1 encoding
            print("UTF-8 encoding failed, trying ISO-8859-1...")
            df = pd.read_csv(input_file, encoding='ISO-8859-1')
        
        print(f"Found {len(df)} rows to translate")
        
        # Check if the comment column exists
        if comment_column not in df.columns:
            # Show available columns to help user
            print(f"Available columns: {', '.join(df.columns)}")
            raise ValueError(f"Column '{comment_column}' not found in the CSV file")
        
        # Create backup of original file if we're going to overwrite it
        if os.path.exists(output_file) and input_file != output_file:
            backup_file = output_file + '.bak'
            if os.path.exists(backup_file):
                print(f"Backup file {backup_file} already exists, not overwriting it.")
            else:
                import shutil
                shutil.copy2(output_file, backup_file)
                print(f"Created backup of existing output file: {backup_file}")
        
        # Translate
        df_translated = self.batch_translate(df, comment_column, batch_size)
        
        # Save results
        df_translated.to_csv(output_file, index=False)
        print(f"Translation completed! Results saved to '{output_file}'")
        
        # Show a sample of the results
        print("\nSample of translations:")
        self.review_sample(df_translated, comment_column, 'english_translation', 5)
        
        # Offer to save problem translations for review
        if self.problem_translations:
            problem_file = "problem_translations.csv"
            pd.DataFrame(self.problem_translations, columns=['original', 'translation']).to_csv(problem_file, index=False)
            print(f"Saved {len(self.problem_translations)} problematic translations to '{problem_file}' for review")
        
        return df_translated

# Example usage
if __name__ == "__main__":
    # Initialize the translator with the local model
    model_path = '/Users/pasindumalinda/Downloads/nllb-200-1.3B'
    translator = SinhalaEnglishTranslator(model_path=model_path)
    
    # Set your file paths and column name
    input_csv = "/Volumes/KODAK/folder 02/language_translation/Language_translator/data/filtered_dataset.csv"
    output_csv = "english_translated_comments.csv"
    comment_column = "Comment"  
    
    # Process the file with smaller batch size for more frequent updates
    translated_df = translator.translate_file(input_csv, output_csv, comment_column, batch_size=8)

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/pasindumalinda/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Loading tokenizer from local path: /Users/pasindumalinda/Downloads/nllb-200-1.3B...
Tokenizer loaded successfully!
Loading model from local path: /Users/pasindumalinda/Downloads/nllb-200-1.3B...
Using device: cpu
Model loaded successfully!
Adding language code mapping to tokenizer...
Setting up fallback translator (Helsinki-NLP)...
Fallback translator not available: Helsinki-NLP/opus-mt-si-en is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
Translator initialized successfully
Reading CSV file: /Volumes/KODAK/folder 02/language_translation/Language_translator/data/filtered_dataset.csv
Found 807 rows to translate
Backup file english_translated_comments.csv.bak already exists, not overwriting it.


Translating batches:   0%|          | 0/101 [00:00<?, ?it/s]

Batch 1:   0%|          | 0/8 [00:00<?, ?it/s]

Translation interrupted. Saving partial results...
Translation statistics: {'total_rows': 807, 'error_count': 0, 'success_rate': 100.0, 'problem_translations': 1, 'good_translations': 0}
Translation completed! Results saved to 'english_translated_comments.csv'

Sample of translations:

Reviewing 5 random samples:
Sample 1:
Original: රුපියල් 700ට මොන කාඩ්ද හලෝ ඒ මනුස්සයට උදේ පාන්දරම cash වලින් දුන්නනම් ඉවරයිනෙ මමත් කැමතිම නිලියක් තමයි ඔයා ඒත් මේකනම් මහ අසික්කිත වනචාරි වැඩක් ඔයා අනිවාර්යයෙන් මේ වවීඩියෝ එක අයින් කරගනී ඔයාට ඔයාගෙ වැරැද්ද තේරුනාම ඒත් මේකෙන් ඔයා ඔයාටම කරගත්ත ඩැමේජ් එක හැමදාටම තියෙයි
Translation: 
Confidence: 0.00
--------------------------------------------------------------------------------
Sample 2:
Original: දුවන්ද කට්ටියට විතරි දන්නවා සල්ලි කොච්චරක් හම්බු කරුත් එකේ card payment දෙමමුත් අර සල්ලි ඉතුරු වෙන්න නෙහ ඔක්කොම companya කපාගන්නවා එකී වෙඩිම කට්ටිය කෙමත්තක් වෙන්න නෙහ customers ලා ට එක හොදි but ද්රිවෙර්ස් ලා ට එක ගොඩක් වෙලාව එක හැමවෙන්න නෙහ
Translation: 
Confidence: 

# Second Try using small model 'Helsinki-NLP/opus-mt-si-en'

## Step 1: Setting up the environment

In [1]:
# Import necessary libraries
import pandas as pd
import torch
from transformers import MarianMTModel, MarianTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time

## Step 2: Loading and exploring the CSV file

In [2]:
def load_and_explore_csv(file_path):
    """
    Load the CSV file and explore its structure.
    
    Args:
        file_path (str): Path to the CSV file containing Sinhala comments
        
    Returns:
        DataFrame: Pandas DataFrame containing the CSV data
    """
    try:
        # Read the CSV file
        # Note: You might need to adjust parameters based on your CSV structure
        df = pd.read_csv(file_path, encoding='utf-8')
        
        # Display basic information
        print(f"CSV file loaded successfully.")
        print(f"Shape of the data: {df.shape}")
        print(f"Column names: {df.columns.tolist()}")
        
        # Display first few rows
        print("\nFirst few rows of the data:")
        print(df.head())
        
        # Check for missing values
        print("\nMissing values in each column:")
        print(df.isnull().sum())
        
        return df
        
    except Exception as e:
        print(f"Error loading CSV file: {e}")
        return None

# Example usage
df = load_and_explore_csv('/Volumes/KODAK/folder 02/language_translation/Language_translator/data/filtered_dataset.csv')

CSV file loaded successfully.
Shape of the data: (807, 1)
Column names: ['Comment']

First few rows of the data:
                                             Comment
0  මොනා උනත් පොත්ත සුදු කෑල්ලක් දැකලා දෙකට නැවුනෙ...
1  බැරිනං PickMe එකෙන් අයින් වෙලා නිකං හයර් දුවපං...
2  ට් ‍ රිප් එක මැදදී කෑශ් හයර් එක කාඩ් හයර් එකට ...
3  me දුවපු කොල්ලෙක්ට විතරයි seen එක තේරෙන්නෙඋදේම...
4  මාත් හයර් දුවන්නෙ කස්ටමර්ගෙ පැත්තෙන් බලනකොට සා...

Missing values in each column:
Comment    0
dtype: int64


## Step 3: Setting up the translation model

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

def setup_translation_model():
    """
    Set up the Facebook NLLB model for Sinhala to English translation.
    
    Returns:
        tuple: (model, tokenizer, device) - The loaded model, tokenizer, and device
    """
    try:
        # Define model name
        model_name = "facebook/nllb-200-distilled-600M"
        
        # Load tokenizer and model
        print(f"Loading the tokenizer for {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        print(f"Loading the model {model_name}...")
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            device_map="auto",  # Automatically handles GPU placement
            torch_dtype=torch.float16  # Uses less memory
        )
        
        # Check if GPU is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Model loaded successfully and running on {device}")
        
        return model, tokenizer, device
        
    except Exception as e:
        print(f"Error setting up translation model: {e}")
        print("Try running: pip install --upgrade transformers flash-attn --no-build-isolation")
        return None, None, None

# Example usage
# model, tokenizer, device = setup_translation_model()

## Step 4: Creating the translation function

In [None]:
def translate_text(text, model, tokenizer, device):
    """
    Translate a Sinhala text to English using the loaded model and tokenizer.
    
    Args:
        text (str): The Sinhala text to translate
        model: The MarianMT translation model
        tokenizer: The MarianTokenizer
        device: The device (CPU/GPU) to use for translation
        
    Returns:
        str: The translated English text
    """
    try:
        # Skip translation if text is empty or None
        if not text or pd.isna(text):
            return ""
        
        # Tokenize the text
        batch = tokenizer([text], return_tensors="pt", padding=True)
        
        # Move input tensors to the device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        
        # Generate translation
        translated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,  # Adjust this based on your needs
            num_beams=4,     # Beam search for better translations
            early_stopping=True
        )
        
        # Decode the translated output
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
        
        return translated_text
        
    except Exception as e:
        print(f"Error translating text: {e}")
        print(f"Problematic text: {text}")
        return f"ERROR: {str(e)}"

# Example usage
# translated = translate_text("අසරන මනුස්සයෙක්ගෙ දවසම කාලා", model, tokenizer, device)
# print(translated)

## Step 5: Batch processing function for translation

In [None]:
def translate_dataframe(df, text_column, model, tokenizer, device, batch_size=10):
    """
    Translate all texts in a specific column of a DataFrame.
    
    Args:
        df (DataFrame): The DataFrame containing the text to translate
        text_column (str): The name of the column containing Sinhala text
        model: The MarianMT translation model
        tokenizer: The MarianTokenizer
        device: The device (CPU/GPU) to use for translation
        batch_size (int): Number of translations to process before showing progress
        
    Returns:
        DataFrame: The DataFrame with an additional column containing translations
    """
    # Create a copy of the DataFrame to avoid modifying the original
    result_df = df.copy()
    
    # Add a new column for translations
    result_df['translated_text'] = ""
    
    # Get total number of rows
    total_rows = len(df)
    print(f"Starting translation of {total_rows} rows...")
    
    # Track time
    start_time = time.time()
    last_update_time = start_time
    
    # Process each row
    for i, row in df.iterrows():
        # Get the text to translate
        text = str(row[text_column])
        
        # Translate the text
        translated = translate_text(text, model, tokenizer, device)
        
        # Store the translation
        result_df.at[i, 'translated_text'] = translated
        
        # Show progress every batch_size rows
        if (i + 1) % batch_size == 0 or (i + 1) == total_rows:
            current_time = time.time()
            elapsed = current_time - start_time
            batch_elapsed = current_time - last_update_time
            last_update_time = current_time
            
            progress = (i + 1) / total_rows * 100
            rows_per_sec = batch_size / batch_elapsed if batch_elapsed > 0 else 0
            
            # Estimate remaining time
            remaining_rows = total_rows - (i + 1)
            eta_seconds = remaining_rows / rows_per_sec if rows_per_sec > 0 else 0
            eta_min = eta_seconds / 60
            
            print(f"Progress: {i+1}/{total_rows} ({progress:.2f}%) - "
                  f"Speed: {rows_per_sec:.2f} rows/sec - "
                  f"Elapsed: {elapsed/60:.2f} min - "
                  f"ETA: {eta_min:.2f} min")
            
            # Show a sample of the translation for verification
            if i < 5:  # Show only for first few rows
                print(f"Original: {text[:100]}...")
                print(f"Translated: {translated[:100]}...")
                print("-" * 50)
    
    print(f"Translation completed. Total time: {(time.time() - start_time)/60:.2f} minutes")
    return result_df

# Example usage
# translated_df = translate_dataframe(df, 'comment_column', model, tokenizer, device)

## Step 6: Handling mixed language content

In [None]:
import re

def detect_language_mix(text):
    """
    Detect if the text contains significant amounts of English already.
    
    Args:
        text (str): The text to check
        
    Returns:
        bool: True if the text contains significant English, False otherwise
    """
    if not text or pd.isna(text):
        return False
    
    # Count English words (simple heuristic)
    # English words typically use Latin script
    english_pattern = r'[a-zA-Z]+\s*[a-zA-Z]*'
    english_words = re.findall(english_pattern, text)
    
    # Count words that are likely Sinhala
    # Sinhala Unicode range: U+0D80 to U+0DFF
    sinhala_pattern = r'[\u0D80-\u0DFF]+\s*[\u0D80-\u0DFF]*'
    sinhala_words = re.findall(sinhala_pattern, text)
    
    # If no words are found, return False
    if not english_words and not sinhala_words:
        return False
    
    # Calculate the percentage of English words
    total_words = len(english_words) + len(sinhala_words)
    english_percentage = len(english_words) / total_words if total_words > 0 else 0
    
    # If more than 70% of words are English, consider it mixed/mostly English
    return english_percentage > 0.7

def smart_translate(text, model, tokenizer, device):
    """
    Intelligently translate text, handling mixed language content.
    
    Args:
        text (str): The text to translate
        model: The MarianMT translation model
        tokenizer: The MarianTokenizer
        device: The device (CPU/GPU) to use for translation
        
    Returns:
        str: The translated text
    """
    if not text or pd.isna(text):
        return ""
    
    # Check if the text is already mostly English
    if detect_language_mix(text):
        # Text is already mostly English, no need for translation
        return text
    
    # Text is primarily Sinhala, translate it
    return translate_text(text, model, tokenizer, device)

# Modify the dataframe translation function to use smart_translate
def smart_translate_dataframe(df, text_column, model, tokenizer, device, batch_size=10):
    """
    Smartly translate all texts in a specific column of a DataFrame.
    
    Args:
        df (DataFrame): The DataFrame containing the text to translate
        text_column (str): The name of the column containing Sinhala text
        model: The MarianMT translation model
        tokenizer: The MarianTokenizer
        device: The device (CPU/GPU) to use for translation
        batch_size (int): Number of translations to process before showing progress
        
    Returns:
        DataFrame: The DataFrame with an additional column containing translations
    """
    # Create a copy of the DataFrame to avoid modifying the original
    result_df = df.copy()
    
    # Add a new column for translations
    result_df['translated_text'] = ""
    
    # Get total number of rows
    total_rows = len(df)
    print(f"Starting smart translation of {total_rows} rows...")
    
    # Track time
    start_time = time.time()
    last_update_time = start_time
    
    # Counters for statistics
    translated_count = 0
    already_english_count = 0
    
    # Process each row
    for i, row in df.iterrows():
        # Get the text to translate
        text = str(row[text_column])
        
        # Check if already mostly English
        is_mostly_english = detect_language_mix(text)
        
        if is_mostly_english:
            # Keep as is
            result_df.at[i, 'translated_text'] = text
            already_english_count += 1
        else:
            # Translate the text
            translated = translate_text(text, model, tokenizer, device)
            result_df.at[i, 'translated_text'] = translated
            translated_count += 1
        
        # Show progress every batch_size rows
        if (i + 1) % batch_size == 0 or (i + 1) == total_rows:
            current_time = time.time()
            elapsed = current_time - start_time
            batch_elapsed = current_time - last_update_time
            last_update_time = current_time
            
            progress = (i + 1) / total_rows * 100
            rows_per_sec = batch_size / batch_elapsed if batch_elapsed > 0 else 0
            
            # Estimate remaining time
            remaining_rows = total_rows - (i + 1)
            eta_seconds = remaining_rows / rows_per_sec if rows_per_sec > 0 else 0
            eta_min = eta_seconds / 60
            
            print(f"Progress: {i+1}/{total_rows} ({progress:.2f}%) - "
                  f"Speed: {rows_per_sec:.2f} rows/sec - "
                  f"Elapsed: {elapsed/60:.2f} min - "
                  f"ETA: {eta_min:.2f} min")
    
    print(f"Translation completed. Total time: {(time.time() - start_time)/60:.2f} minutes")
    print(f"Statistics: {translated_count} texts translated, {already_english_count} already mostly English")
    return result_df

# Example usage
# translated_df = smart_translate_dataframe(df, 'comment_column', model, tokenizer, device)

## Step 7: Saving the translated data to a new CSV file


In [None]:
def save_translated_csv(df, output_path, include_original=True):
    """
    Save the DataFrame with translations to a CSV file.
    
    Args:
        df (DataFrame): The DataFrame containing the original text and translations
        output_path (str): Path to save the output CSV file
        include_original (bool): Whether to include the original text in the output
        
    Returns:
        bool: True if saved successfully, False otherwise
    """
    try:
        # Create a copy to avoid modifying the original
        output_df = df.copy()
        
        # Reorder columns to put translations at the end if needed
        # This is just for better readability of the CSV
        if include_original:
            # Keep all columns but ensure translation is at the end
            cols = [col for col in output_df.columns if col != 'translated_text'] + ['translated_text']
            output_df = output_df[cols]
        else:
            # Replace the original text column with the translated text
            # Find the column that was translated (assuming it's stored somewhere)
            # For now, we'll just keep all columns including translated_text
            pass
        
        # Save to CSV
        output_df.to_csv(output_path, index=False, encoding='utf-8-sig')  # utf-8-sig includes BOM for Excel compatibility
        
        print(f"Translated data saved to {output_path}")
        print(f"Total rows saved: {len(output_df)}")
        
        return True
        
    except Exception as e:
        print(f"Error saving translated CSV: {e}")
        return False

# Example usage
# success = save_translated_csv(translated_df, 'translated_comments.csv')

## Step 8: Error handling and validation

In [None]:
def validate_translations(df, original_column, translated_column):
    """
    Validate the translations to ensure quality.
    
    Args:
        df (DataFrame): The DataFrame containing original and translated text
        original_column (str): The name of the column containing original text
        translated_column (str): The name of the column containing translated text
        
    Returns:
        DataFrame: A DataFrame containing problematic translations for review
    """
    # Create a copy for validation results
    validation_df = pd.DataFrame(columns=['row_index', 'original', 'translated', 'issue'])
    
    # List to collect problematic rows
    issues = []
    
    # Check each row
    for i, row in df.iterrows():
        original = str(row[original_column])
        translated = str(row[translated_column])
        
        # Check for empty translations of non-empty originals
        if original and not translated:
            issues.append({
                'row_index': i,
                'original': original,
                'translated': translated,
                'issue': 'Empty translation'
            })
            continue
        
        # Check for very short translations of long originals
        # This could indicate truncation or incomplete translation
        if len(original) > 50 and len(translated) < 10:
            issues.append({
                'row_index': i,
                'original': original,
                'translated': translated,
                'issue': 'Suspiciously short translation'
            })
            continue
            
        # Check for translations that contain error messages
        if "ERROR:" in translated:
            issues.append({
                'row_index': i,
                'original': original,
                'translated': translated,
                'issue': 'Contains error message'
            })
            continue
    
    # Create DataFrame from issues list
    if issues:
        validation_df = pd.DataFrame(issues)
        print(f"Found {len(issues)} potentially problematic translations")
    else:
        print("No translation issues found")
    
    return validation_df

def retry_failed_translations(df, validation_df, text_column, model, tokenizer, device):
    """
    Retry the failed translations.
    
    Args:
        df (DataFrame): The original DataFrame with translations
        validation_df (DataFrame): The DataFrame containing failed translations
        text_column (str): The name of the column containing original text
        model: The MarianMT translation model
        tokenizer: The MarianTokenizer
        device: The device (CPU/GPU) to use for translation
        
    Returns:
        DataFrame: The updated DataFrame with retried translations
    """
    # Create a copy to avoid modifying the original
    result_df = df.copy()
    
    if validation_df.empty:
        print("No failed translations to retry")
        return result_df
    
    print(f"Retrying {len(validation_df)} failed translations...")
    
    # Retry each failed translation
    for _, row in validation_df.iterrows():
        row_index = row['row_index']
        original_text = row['original']
        
        # Try translation with different parameters
        try:
            # Simple retry with same parameters
            translated = translate_text(original_text, model, tokenizer, device)
            
            # Update the DataFrame
            result_df.at[row_index, 'translated_text'] = translated
            print(f"Successfully retried translation for row {row_index}")
            print(f"Original: {original_text[:50]}...")
            print(f"New translation: {translated[:50]}...")
            print("-" * 50)
            
        except Exception as e:
            print(f"Still failed to translate row {row_index}: {e}")
    
    return result_df

# Example usage
# validation_results = validate_translations(translated_df, 'comment_column', 'translated_text')
# if not validation_results.empty:
#     translated_df = retry_failed_translations(translated_df, validation_results, 'comment_column', model, tokenizer, device)

## Step 9: Main execution script

In [None]:
def main():
    """
    Main function to execute the translation process.
    """
    import argparse
    
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Translate Sinhala comments to English for NLP analysis')
    parser.add_argument('--input', type=str, required=True, help='Path to input CSV file')
    parser.add_argument('--output', type=str, required=True, help='Path to output CSV file')
    parser.add_argument('--column', type=str, required=True, help='Name of column containing text to translate')
    parser.add_argument('--batch_size', type=int, default=10, help='Batch size for progress updates')
    
    args = parser.parse_args()
    
    print("Starting translation process...")
    print(f"Input file: {args.input}")
    print(f"Output file: {args.output}")
    print(f"Column to translate: {args.column}")
    
    # Step 1: Load the CSV file
    df = load_and_explore_csv(args.input)
    if df is None:
        print("Failed to load CSV file. Exiting.")
        return
    
    # Verify that the specified column exists
    if args.column not in df.columns:
        print(f"Column '{args.column}' not found in the CSV file.")
        print(f"Available columns: {df.columns.tolist()}")
        return
    
    # Step 2: Set up the translation model
    model, tokenizer, device = setup_translation_model()
    if model is None or tokenizer is None:
        print("Failed to set up translation model. Exiting.")
        return
    
    # Step 3: Translate the texts
    print("Starting translation process...")
    translated_df = smart_translate_dataframe(df, args.column, model, tokenizer, device, args.batch_size)
    
    # Step 4: Validate translations
    validation_results = validate_translations(translated_df, args.column, 'translated_text')
    
    # Step 5: Retry failed translations if needed
    if not validation_results.empty:
        print("Retrying failed translations...")
        translated_df = retry_failed_translations(translated_df, validation_results, args.column, model, tokenizer, device)
        
        # Re-validate to check how many issues were resolved
        new_validation = validate_translations(translated_df, args.column, 'translated_text')
        if not new_validation.empty:
            print(f"After retry, {len(new_validation)} translations still have issues")
            print("Saving problematic translations for manual review...")
            new_validation.to_csv(args.output.replace('.csv', '_issues.csv'), index=False)
    
    # Step 6: Save the results
    success = save_translated_csv(translated_df, args.output)
    if success:
        print(f"Translation process completed successfully. Results saved to {args.output}")
    else:
        print("Failed to save results.")

if __name__ == "__main__":
    main()

## Step 10: Test and example script for direct use

In [None]:
"""
Example script to demonstrate the usage of the translation functions.
"""
# Import all necessary functions from the above modules
# In a real implementation, these would be imported from properly organized modules

import pandas as pd
import torch
from transformers import MarianMTModel, MarianTokenizer
import time
import re

# Import our functions (in reality, these would be defined in the script or imported)
# You would need to replace these imports with the actual functions from Steps 1-9

# Here's a self-contained example that shows how to use the code:

def example_with_sample_data():
    """
    Example using the sample data provided in the requirements.
    """
    # Create a sample DataFrame with the provided example
    sample_data = pd.DataFrame({
        'comment': [
            'අසරන මනුස්සයෙක්ගෙ දවසම කාලා නිලිය sooting යන්නෙත් pikme weel එකක ඒකට දෙන්න මුංට සල්ලිත් නැ සල්ලි නැත්තම් atm එකකින් බැහැල කාඩ් එකෙන් සල්ලි අරන් දෙන්න තිබ්බනෙ මෙයා රටෙ මිනිස්සුන්ගෙන් මදිවට වීල් වල යන අයගෙනුත් කුණු බැනුම් අහගන්නවා',
            'PickMe this is a serious nonsense aththatama gaman cancel karana ekai gewanna wei kiala nathnm complaint ekak dammoth ape numbers walata call aran banina ekai lata kisima safety ekak naha complaint ekak daaddi ape address ekath ekkalu yanne U all have to update ur system and choose professional quality drivers n give them strict warnings about the service u all providing',
            'நெருக்கடியான சூழ்நிலையில். .. .. .. .. .. .. .. .. .. .. .. .. .. நெருக்கடியான சூழ்நிலையில். .. .. .. நெருக்கடியான சூழ்நிலையில். .. .. .. . நெருக்கடியான சூழ்நிலையில். .. .. .. .. .. .. .. .. .. .. . நெருக்கடிகள். .. .. .. .. .. .. .. .. .. .. . நெருக்கடியான சூழ்நிலையில். .. .. .. நெருக்கடியான சூழ்நிலையில். .. .. .. .. .. .. .. .. .. நெருக்கடியான சூழ்நிலையில். .. .. .. .. .. .. .. .. .. .'
        ]
    })
    
    # Save the sample data to a temporary CSV file
    sample_csv_path = 'sample_comments.csv'
    sample_data.to_csv(sample_csv_path, index=False)
    
    print("Sample CSV created with the provided examples.")
    
    # Set up the translation model
    print("Setting up the translation model...")
    model_name = "facebook/nllb-200-distilled-600M"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # Function to detect if text is mostly English
    def detect_language_mix(text):
        if not text or pd.isna(text):
            return False
        
        english_pattern = r'[a-zA-Z]+\s*[a-zA-Z]*'
        english_words = re.findall(english_pattern, text)
        
        sinhala_pattern = r'[\u0D80-\u0DFF]+\s*[\u0D80-\u0DFF]*'
        sinhala_words = re.findall(sinhala_pattern, text)
        
        # Also check for Tamil script (since one example contains Tamil)
        tamil_pattern = r'[\u0B80-\u0BFF]+\s*[\u0B80-\u0BFF]*'
        tamil_words = re.findall(tamil_pattern, text)
        
        non_english_words = len(sinhala_words) + len(tamil_words)
        
        # If no words are found, return False
        if not english_words and non_english_words == 0:
            return False
        
        # Calculate the percentage of English words
        total_words = len(english_words) + non_english_words
        english_percentage = len(english_words) / total_words if total_words > 0 else 0
        
        # If more than 70% of words are English, consider it mixed/mostly English
        return english_percentage > 0.7
    
    # Function to translate text
    def translate_text(text, model, tokenizer, device):
        try:
            if not text or pd.isna(text):
                return ""
            
            # Skip translation if text contains Tamil script
            # Since our model is only for Sinhala-English
            if re.search(r'[\u0B80-\u0BFF]', text):
                return "[Tamil text detected - requires a different translation model]"
            
            # Tokenize the text
            batch = tokenizer([text], return_tensors="pt", padding=True)
            
            # Move input tensors to the device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            
            # Generate translation
            translated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=512,
                num_beams=4,
                early_stopping=True
            )
            
            # Decode the translated output
            translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
            
            return translated_text
            
        except Exception as e:
            print(f"Error translating text: {e}")
            print(f"Problematic text: {text}")
            return f"ERROR: {str(e)}"
    
    # Function to smartly translate text
    def smart_translate(text, model, tokenizer, device):
        if not text or pd.isna(text):
            return ""
        
        # Check if the text is already mostly English
        if detect_language_mix(text):
            # Text is already mostly English, no need for translation
            return text
        
        # Text is primarily Sinhala or another language, translate it
        return translate_text(text, model, tokenizer, device)
    
    # Process each example
    print("\nTranslating sample texts:")
    for i, comment in enumerate(sample_data['comment']):
        print(f"\nExample {i+1}:")
        print(f"Original: {comment[:100]}...")
        
        # Translate the text
        translated = smart_translate(comment, model, tokenizer, device)
        print(f"Translated: {translated[:100]}...")
        
        # Update the DataFrame with the translation
        sample_data.at[i, 'translated_text'] = translated
    
    # Save the results
    output_csv_path = 'translated_sample_comments.csv'
    sample_data.to_csv(output_csv_path, index=False, encoding='utf-8-sig')
    
    print(f"\nTranslations completed and saved to {output_csv_path}")
    print("Review of translations:")
    
    # Show the final results
    for i, row in sample_data.iterrows():
        print(f"\nExample {i+1}:")
        print(f"Original: {row['comment'][:100]}...")
        print(f"Translated: {row['translated_text'][:100]}...")

# Run the example
if __name__ == "__main__":
    print("Running example with sample data...")
    example_with_sample_data()