In [1]:
import torch
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
import os
import warnings
warnings.filterwarnings("ignore")

def prepare_translation_data():
    """Prepare translation data from samanantar dataset."""
    print("Loading samanantar dataset...")
    dataset = load_dataset("ai4bharat/samanantar", "hi", split='train')
    
    # Filter valid pairs and create smaller dataset for testing
    valid_pairs = []
    for i, example in enumerate(dataset):
        if i >= 5000:  # Use smaller dataset for debugging
            break
        if example['src'] and example['tgt'] and len(example['src'].strip()) > 0 and len(example['tgt'].strip()) > 0:
            valid_pairs.append({
                'english': example['src'].strip(),
                'hindi': example['tgt'].strip()
            })
    
    print(f"Found {len(valid_pairs)} valid translation pairs")
    
    # Split into train/val
    split_idx = int(0.9 * len(valid_pairs))
    train_pairs = valid_pairs[:split_idx]
    val_pairs = valid_pairs[split_idx:]
    
    return train_pairs, val_pairs

def fine_tune_marian_en_hi():
    """Fine-tune MarianMT model for English-Hindi translation."""
    print("--- Starting Fine-Tuning MarianMT English-Hindi ---")
    
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    output_dir = "results/marian_en_hi_finetuned"
    
    try:
        # Load model and tokenizer
        print(f"Loading {model_name}...")
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        
        # Prepare data
        train_pairs, val_pairs = prepare_translation_data()
        
        def preprocess_function(examples):
            inputs = examples["english"]
            targets = examples["hindi"] 
            model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding=False)
            
            # Setup the tokenizer for targets
            labels = tokenizer(text_target=targets, max_length=128, truncation=True, padding=False)
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        
        # Convert to datasets
        train_dataset = Dataset.from_list(train_pairs)
        val_dataset = Dataset.from_list(val_pairs)
        
        print("Preprocessing datasets...")
        train_dataset = train_dataset.map(preprocess_function, batched=True)
        val_dataset = val_dataset.map(preprocess_function, batched=True)
        
        # Data collator
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
        
        # Training arguments
        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="steps",
            eval_steps=200,
            logging_steps=50,
            save_steps=200,
            save_total_limit=2,
            learning_rate=3e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=3,
            weight_decay=0.01,
            warmup_steps=200,
            predict_with_generate=True,
            fp16=torch.cuda.is_available(),
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
        )
        
        # Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        
        print("Starting English-Hindi training...")
        trainer.train()
        
        print(f"Saving English-Hindi model to {output_dir}")
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        
        return output_dir
        
    except Exception as e:
        print(f"MarianMT English-Hindi failed: {e}")
        return None

def fine_tune_marian_hi_en():
    """Fine-tune MarianMT model for Hindi-English translation."""
    print("--- Starting Fine-Tuning MarianMT Hindi-English ---")
    
    model_name = "Helsinki-NLP/opus-mt-hi-en"
    output_dir = "results/marian_hi_en_finetuned"
    
    try:
        # Load model and tokenizer
        print(f"Loading {model_name}...")
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        
        # Prepare data (reversed for Hindi to English)
        train_pairs, val_pairs = prepare_translation_data()
        
        # Reverse the pairs for Hindi-English translation
        train_pairs_reversed = [{'english': pair['hindi'], 'hindi': pair['english']} for pair in train_pairs]
        val_pairs_reversed = [{'english': pair['hindi'], 'hindi': pair['english']} for pair in val_pairs]
        
        def preprocess_function(examples):
            inputs = examples["english"]  # Now contains Hindi text
            targets = examples["hindi"]   # Now contains English text
            model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding=False)
            
            # Setup the tokenizer for targets
            labels = tokenizer(text_target=targets, max_length=128, truncation=True, padding=False)
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs
        
        # Convert to datasets
        train_dataset = Dataset.from_list(train_pairs_reversed)
        val_dataset = Dataset.from_list(val_pairs_reversed)
        
        print("Preprocessing datasets...")
        train_dataset = train_dataset.map(preprocess_function, batched=True)
        val_dataset = val_dataset.map(preprocess_function, batched=True)
        
        # Data collator
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
        
        # Training arguments
        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="steps",
            eval_steps=200,
            logging_steps=50,
            save_steps=200,
            save_total_limit=2,
            learning_rate=3e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            num_train_epochs=3,
            weight_decay=0.01,
            warmup_steps=200,
            predict_with_generate=True,
            fp16=torch.cuda.is_available(),
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
        )
        
        # Trainer
        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )
        
        print("Starting Hindi-English training...")
        trainer.train()
        
        print(f"Saving Hindi-English model to {output_dir}")
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        
        return output_dir
        
    except Exception as e:
        print(f"MarianMT Hindi-English failed: {e}")
        return None

def fine_tune_t5_model():
    """Fallback: Fine-tune T5 model for bidirectional translation."""
    from transformers import T5ForConditionalGeneration, T5Tokenizer
    
    model_name = "t5-small"
    output_dir = "results/t5_bidirectional_finetuned"
    
    print(f"Loading {model_name}...")
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    
    # Prepare data
    train_pairs, val_pairs = prepare_translation_data()
    
    # Create bidirectional training data with T5 format
    train_data = []
    val_data = []
    
    # English to Hindi
    for pair in train_pairs:
        train_data.append({
            'input_text': f"translate English to Hindi: {pair['english']}",
            'target_text': pair['hindi']
        })
        # Hindi to English (reverse)
        train_data.append({
            'input_text': f"translate Hindi to English: {pair['hindi']}",
            'target_text': pair['english']
        })
    
    for pair in val_pairs:
        val_data.append({
            'input_text': f"translate English to Hindi: {pair['english']}",
            'target_text': pair['hindi']
        })
        val_data.append({
            'input_text': f"translate Hindi to English: {pair['hindi']}",
            'target_text': pair['english']
        })
    
    def preprocess_function(examples):
        inputs = examples["input_text"]
        targets = examples["target_text"]
        model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding=False)
        labels = tokenizer(text_target=targets, max_length=128, truncation=True, padding=False)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    # Convert to datasets
    train_dataset = Dataset.from_list(train_data)
    val_dataset = Dataset.from_list(val_data)
    
    print("Preprocessing datasets...")
    train_dataset = train_dataset.map(preprocess_function, batched=True)
    val_dataset = val_dataset.map(preprocess_function, batched=True)
    
    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    
    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_steps=200,
        logging_steps=50,
        save_steps=200,
        save_total_limit=2,
        learning_rate=3e-4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        warmup_steps=200,
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
    
    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    print("Starting T5 bidirectional training...")
    trainer.train()
    
    print(f"Saving T5 model to {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    return output_dir

class UniversalTranslator:
    """Universal translator that works with different model types."""
    
    def __init__(self, model_path):
        print(f"Loading translator from: {model_path}")
        
        # Detect model type
        config_path = os.path.join(model_path, "config.json")
        self.model_type = "unknown"
        
        if os.path.exists(config_path):
            import json
            with open(config_path, 'r') as f:
                config = json.load(f)
                if "marian" in config.get("architectures", [""])[0].lower():
                    self.model_type = "marian"
                elif "t5" in config.get("architectures", [""])[0].lower():
                    self.model_type = "t5"
        
        print(f"Detected model type: {self.model_type}")
        
        if self.model_type == "marian":
            self.tokenizer = MarianTokenizer.from_pretrained(model_path)
            self.model = MarianMTModel.from_pretrained(model_path)
        else:  # Default to T5
            from transformers import T5ForConditionalGeneration, T5Tokenizer
            self.tokenizer = T5Tokenizer.from_pretrained(model_path)
            self.model = T5ForConditionalGeneration.from_pretrained(model_path)
            self.model_type = "t5"
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        
        print("Translator ready!")

    def translate(self, text, src_lang='en', tgt_lang='hi'):
        """Translate text based on model type."""
        if self.model_type == "marian":
            # MarianMT expects plain text
            input_text = text
        else:  # T5
            lang_map = {'en': 'English', 'hi': 'Hindi'}
            input_text = f"translate {lang_map[src_lang]} to {lang_map[tgt_lang]}: {text}"
        
        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=128,
            truncation=True,
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=128,
                num_beams=4,
                length_penalty=0.6,
                early_stopping=True,
                do_sample=False,
            )
        
        # Decode
        translation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translation.strip()

class TranslationEvaluator:
    def __init__(self):
        self.smoothing = SmoothingFunction().method1
    
    def calculate_bleu(self, reference, candidate):
        if not reference.strip() or not candidate.strip():
            return 0.0
        ref_tokens = nltk.word_tokenize(reference.lower())
        cand_tokens = nltk.word_tokenize(candidate.lower())
        return round(sentence_bleu([ref_tokens], cand_tokens, smoothing_function=self.smoothing) * 100, 2)
    
    def calculate_meteor(self, reference, candidate):
        if not reference.strip() or not candidate.strip():
            return 0.0
        ref_tokens = nltk.word_tokenize(reference.lower())
        cand_tokens = nltk.word_tokenize(candidate.lower())
        return round(meteor_score([ref_tokens], cand_tokens) * 100, 2)

def run_evaluation(model_paths):
    """Evaluate the trained models."""
    print("\n--- Starting Evaluation ---")
    
    # Download NLTK data
    for corpus in ['punkt', 'wordnet', 'omw-1.4']:
        nltk.download(corpus, quiet=True)

    evaluator = TranslationEvaluator()
    
    # Test cases for both directions
    en_to_hi_cases = [
        {'source': 'Hello', 'reference': 'नमस्ते'},
        {'source': 'How are you?', 'reference': 'आप कैसे हैं?'},
        {'source': 'Good morning', 'reference': 'सुप्रभात'},
        {'source': 'Thank you', 'reference': 'धन्यवाद'},
        {'source': 'I am fine', 'reference': 'मैं ठीक हूं'},
    ]
    
    hi_to_en_cases = [
        {'source': 'नमस्ते', 'reference': 'Hello'},
        {'source': 'आप कैसे हैं?', 'reference': 'How are you?'},
        {'source': 'सुप्रभात', 'reference': 'Good morning'},
        {'source': 'धन्यवाद', 'reference': 'Thank you'},
        {'source': 'मैं ठीक हूं', 'reference': 'I am fine'},
    ]
    
    # Evaluate both directions
    for direction, cases, model_path in [
        ("English-Hindi", en_to_hi_cases, model_paths.get('en_hi')),
        ("Hindi-English", hi_to_en_cases, model_paths.get('hi_en'))
    ]:
        if not model_path or not os.path.exists(model_path):
            print(f"Model not found for {direction}: {model_path}")
            continue
            
        print(f"\n--- Evaluating {direction} Translation ---")
        translator = UniversalTranslator(model_path)
        results = []
        
        for i, case in enumerate(cases, 1):
            print(f"{i:2d}. '{case['source']}'")
            
            try:
                src_lang = 'en' if direction == "English-Hindi" else 'hi'
                tgt_lang = 'hi' if direction == "English-Hindi" else 'en'
                
                prediction = translator.translate(case['source'], src_lang, tgt_lang)
                bleu = evaluator.calculate_bleu(case['reference'], prediction)
                meteor = evaluator.calculate_meteor(case['reference'], prediction)
                
                results.append({
                    'Source': case['source'],
                    'Reference': case['reference'], 
                    'Prediction': prediction,
                    'BLEU': bleu,
                    'METEOR': meteor
                })
                
                print(f"    → {prediction}")
                
            except Exception as e:
                print(f"    → ERROR: {str(e)}")
                results.append({
                    'Source': case['source'],
                    'Reference': case['reference'],
                    'Prediction': f"ERROR: {str(e)}",
                    'BLEU': 0.0,
                    'METEOR': 0.0
                })
        
        # Results summary
        df = pd.DataFrame(results)
        print(f"\n{'='*80}")
        print(f"{direction.upper()} EVALUATION RESULTS")
        print(f"{'='*80}")
        print(df.to_string(index=False, max_colwidth=40))
        
        valid_results = df[df['BLEU'] > 0]
        if len(valid_results) > 0:
            avg_bleu = valid_results['BLEU'].mean()
            avg_meteor = valid_results['METEOR'].mean()
            print(f"\nAverage BLEU: {avg_bleu:.2f}")
            print(f"Average METEOR: {avg_meteor:.2f}")
            print(f"Success rate: {len(valid_results)}/{len(results)} ({100*len(valid_results)/len(results):.1f}%)")
        else:
            print("\nNo successful translations generated.")

def main():
    """Main execution function."""
    os.makedirs("results", exist_ok=True)
    
    model_paths = {}
    
    # Train English-Hindi model
    print("Training English-Hindi model...")
    en_hi_path = fine_tune_marian_en_hi()
    if en_hi_path:
        model_paths['en_hi'] = en_hi_path
    
    # Train Hindi-English model
    print("Training Hindi-English model...")
    hi_en_path = fine_tune_marian_hi_en()
    if hi_en_path:
        model_paths['hi_en'] = hi_en_path
    
    # If both MarianMT models fail, train T5 as fallback
    if not model_paths:
        print("Both MarianMT models failed. Training T5 fallback...")
        t5_path = fine_tune_t5_model()
        if t5_path:
            model_paths['en_hi'] = t5_path
            model_paths['hi_en'] = t5_path
    
    # Evaluate models
    if model_paths:
        run_evaluation(model_paths)
    else:
        print("No models were successfully trained.")
    
    return model_paths

if __name__ == "__main__":
    main()

Training English-Hindi model...
--- Starting Fine-Tuning MarianMT English-Hindi ---
Loading Helsinki-NLP/opus-mt-en-hi...
Loading samanantar dataset...
Found 5000 valid translation pairs
Preprocessing datasets...


Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Starting English-Hindi training...


Step,Training Loss,Validation Loss
200,4.1681,3.757415
400,3.6922,3.588024
600,3.3591,3.52431
800,3.3135,3.505749


Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].
Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[61949]], 'forced_eos_token_id': 0}


Saving English-Hindi model to results/marian_en_hi_finetuned
Training Hindi-English model...
--- Starting Fine-Tuning MarianMT Hindi-English ---
Loading Helsinki-NLP/opus-mt-hi-en...
Loading samanantar dataset...
Found 5000 valid translation pairs
Preprocessing datasets...


Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Starting Hindi-English training...


Step,Training Loss,Validation Loss
200,4.1523,3.755752
400,3.6103,3.590309
600,3.2306,3.539832
800,3.2081,3.515731


Non-default generation parameters: {'max_length': 512, 'num_beams': 6, 'bad_words_ids': [[61126]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 6, 'bad_words_ids': [[61126]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 6, 'bad_words_ids': [[61126]], 'forced_eos_token_id': 0}
Non-default generation parameters: {'max_length': 512, 'num_beams': 6, 'bad_words_ids': [[61126]], 'forced_eos_token_id': 0}
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].
Non-default generation parameters: {'max_length': 512, 'num_beams': 6, 'bad_words_ids': [[61126]], 'forced_eos_token_id': 0}


Saving Hindi-English model to results/marian_hi_en_finetuned

--- Starting Evaluation ---

--- Evaluating English-Hindi Translation ---
Loading translator from: results/marian_en_hi_finetuned
Detected model type: marian
Translator ready!
 1. 'Hello'
    → सलाम
 2. 'How are you?'
    → आप कैसे हैं?
 3. 'Good morning'
    → सुप्रभात
 4. 'Thank you'
    → धन्यवाद
 5. 'I am fine'
    → मैं ठीक हूं

ENGLISH-HINDI EVALUATION RESULTS
      Source    Reference   Prediction   BLEU  METEOR
       Hello       नमस्ते         सलाम   0.00    0.00
How are you? आप कैसे हैं? आप कैसे हैं? 100.00   99.22
Good morning     सुप्रभात     सुप्रभात  17.78   50.00
   Thank you      धन्यवाद      धन्यवाद  17.78   50.00
   I am fine  मैं ठीक हूं  मैं ठीक हूं  56.23   98.15

Average BLEU: 47.95
Average METEOR: 74.34
Success rate: 4/5 (80.0%)

--- Evaluating Hindi-English Translation ---
Loading translator from: results/marian_hi_en_finetuned
Detected model type: marian
Translator ready!
 1. 'नमस्ते'
    → Hi.
 2. '

In [5]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import time
import threading
import torch
import os
import json
from transformers import MarianMTModel, MarianTokenizer, T5ForConditionalGeneration, T5Tokenizer

app = Flask(__name__)
CORS(app)

# Global translator cache and configuration
translator_cache = {}
LANGUAGES = {'en': 'English', 'hi': 'Hindi'}
MODEL_PATHS = {
    'en_hi': 'results/marian_en_hi_finetuned',
    'hi_en': 'results/marian_hi_en_finetuned',
    'fallback': 'results/t5_bidirectional_finetuned'
}

class DynamicTranslator:
    """Dynamic translator that loads models on demand based on language pair."""
    
    def __init__(self, model_path):
        self.model_path = model_path
        self.model_type = self._detect_model_type()
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()
    
    def _detect_model_type(self):
        """Detect model type from config.json"""
        config_path = os.path.join(self.model_path, "config.json")
        
        if os.path.exists(config_path):
            try:
                with open(config_path, 'r') as f:
                    config = json.load(f)
                    architecture = config.get("architectures", [""])[0].lower()
                    if "marian" in architecture:
                        return "marian"
                    elif "t5" in architecture:
                        return "t5"
            except Exception as e:
                print(f"Error reading config: {e}")
        
        # Fallback detection based on path
        if "marian" in self.model_path.lower():
            return "marian"
        elif "t5" in self.model_path.lower():
            return "t5"
        
        return "unknown"
    
    def _load_model(self):
        """Load model and tokenizer based on detected type."""
        print(f"Loading {self.model_type} model from: {self.model_path}")
        
        try:
            if self.model_type == "marian":
                self.tokenizer = MarianTokenizer.from_pretrained(self.model_path)
                self.model = MarianMTModel.from_pretrained(self.model_path)
            else:  # T5 or unknown (default to T5)
                self.tokenizer = T5Tokenizer.from_pretrained(self.model_path)
                self.model = T5ForConditionalGeneration.from_pretrained(self.model_path)
                self.model_type = "t5"
            
            self.model.to(self.device)
            self.model.eval()
            print(f"Model loaded successfully: {self.model_type}")
            
        except Exception as e:
            print(f"Error loading model: {e}")
            raise e
    
    def translate(self, text, src_lang='en', tgt_lang='hi'):
        """Translate text based on model type."""
        if not self.model or not self.tokenizer:
            raise Exception("Model not loaded")
        
        # Prepare input based on model type
        if self.model_type == "marian":
            input_text = text
        else:  # T5
            lang_map = {'en': 'English', 'hi': 'Hindi'}
            input_text = f"translate {lang_map[src_lang]} to {lang_map[tgt_lang]}: {text}"
        
        # Tokenize
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=128,
            truncation=True,
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Generate translation
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=128,
                num_beams=4,
                length_penalty=0.6,
                early_stopping=True,
                do_sample=False,
            )
        
        # Decode and return
        translation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return translation.strip()

def get_model_path(src_lang, tgt_lang):
    """Get appropriate model path based on language pair."""
    language_pair = f"{src_lang}_{tgt_lang}"
    
    # Check if specific directional model exists
    if language_pair in MODEL_PATHS and os.path.exists(MODEL_PATHS[language_pair]):
        return MODEL_PATHS[language_pair]
    
    # Check if fallback model exists
    if os.path.exists(MODEL_PATHS['fallback']):
        return MODEL_PATHS['fallback']
    
    # Try to find any available model
    for path in MODEL_PATHS.values():
        if os.path.exists(path):
            return path
    
    return None

def get_translator(src_lang, tgt_lang):
    """Get or create translator for specific language pair."""
    cache_key = f"{src_lang}_{tgt_lang}"
    
    # Return cached translator if available
    if cache_key in translator_cache:
        return translator_cache[cache_key]
    
    # Get model path
    model_path = get_model_path(src_lang, tgt_lang)
    
    if not model_path:
        raise Exception(f"No model available for {src_lang} -> {tgt_lang} translation")
    
    # Create and cache new translator
    try:
        translator = DynamicTranslator(model_path)
        translator_cache[cache_key] = translator
        print(f"Cached translator for {src_lang} -> {tgt_lang}")
        return translator
    except Exception as e:
        raise Exception(f"Failed to load translator: {str(e)}")

def initialize_translators():
    """Pre-initialize common translators."""
    print("Initializing translators...")
    
    common_pairs = [('en', 'hi'), ('hi', 'en')]
    
    for src_lang, tgt_lang in common_pairs:
        try:
            get_translator(src_lang, tgt_lang)
            print(f"✓ Initialized {src_lang} -> {tgt_lang} translator")
        except Exception as e:
            print(f"✗ Failed to initialize {src_lang} -> {tgt_lang}: {e}")
    
    print("Translator initialization complete.")

@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check endpoint with detailed model status."""
    model_status = {}
    
    for pair, path in MODEL_PATHS.items():
        model_status[pair] = {
            'path': path,
            'exists': os.path.exists(path),
            'cached': f"{pair.replace('_', '_')}" in translator_cache if '_' in pair else False
        }
    
    return jsonify({
        'status': 'healthy',
        'supported_languages': LANGUAGES,
        'model_status': model_status,
        'cached_translators': list(translator_cache.keys()),
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    })

@app.route('/api/translate', methods=['POST'])
def translate_text():
    """Single text translation endpoint."""
    try:
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided.'}), 400
            
        text = data.get('text', '').strip()
        src_lang = data.get('src_lang', 'en')
        tgt_lang = data.get('tgt_lang', 'hi')

        # Validation
        if not text:
            return jsonify({'error': 'No text provided for translation.'}), 400
        if src_lang not in LANGUAGES or tgt_lang not in LANGUAGES:
            return jsonify({'error': 'Unsupported language selected.'}), 400
        if src_lang == tgt_lang:
            return jsonify({'error': 'Source and target languages are the same.'}), 400

        # Get translator and translate
        start_time = time.time()
        
        try:
            translator = get_translator(src_lang, tgt_lang)
            translation = translator.translate(text, src_lang, tgt_lang)
        except Exception as e:
            return jsonify({'error': f'Translation failed: {str(e)}'}), 503
        
        end_time = time.time()

        return jsonify({
            'source_text': text,
            'source_language': src_lang,
            'source_language_name': LANGUAGES[src_lang],
            'target_language': tgt_lang,
            'target_language_name': LANGUAGES[tgt_lang],
            'translation': translation,
            'model_used': get_model_path(src_lang, tgt_lang),
            'processing_time': round(end_time - start_time, 3)
        })

    except Exception as e:
        print(f"Translation error: {e}")
        return jsonify({'error': 'An internal server error occurred.'}), 500

@app.route('/api/batch-translate', methods=['POST'])
def batch_translate():
    """Batch translation endpoint."""
    try:
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided.'}), 400
            
        texts = data.get('texts', [])
        src_lang = data.get('src_lang', 'en')
        tgt_lang = data.get('tgt_lang', 'hi')
        
        # Validation
        if not texts or not isinstance(texts, list):
            return jsonify({'error': 'No texts array provided.'}), 400
        if len(texts) > 100:
            return jsonify({'error': 'Maximum 100 texts allowed per batch.'}), 400
        if src_lang not in LANGUAGES or tgt_lang not in LANGUAGES:
            return jsonify({'error': 'Unsupported language selected.'}), 400
        if src_lang == tgt_lang:
            return jsonify({'error': 'Source and target languages are the same.'}), 400

        start_time = time.time()
        
        # Get translator
        try:
            translator = get_translator(src_lang, tgt_lang)
        except Exception as e:
            return jsonify({'error': f'Failed to load translator: {str(e)}'}), 503
        
        translations = []
        
        for text in texts:
            if isinstance(text, str) and text.strip():
                try:
                    translation = translator.translate(text.strip(), src_lang, tgt_lang)
                    translations.append({
                        'source': text.strip(),
                        'translation': translation,
                        'success': True
                    })
                except Exception as e:
                    translations.append({
                        'source': text.strip(),
                        'translation': None,
                        'success': False,
                        'error': str(e)
                    })
            else:
                translations.append({
                    'source': text,
                    'translation': None,
                    'success': False,
                    'error': 'Invalid text format'
                })
        
        end_time = time.time()
        
        return jsonify({
            'translations': translations,
            'source_language': src_lang,
            'target_language': tgt_lang,
            'model_used': get_model_path(src_lang, tgt_lang),
            'total_count': len(translations),
            'success_count': sum(1 for t in translations if t['success']),
            'processing_time': round(end_time - start_time, 3)
        })
        
    except Exception as e:
        print(f"Batch translation error: {e}")
        return jsonify({'error': 'An internal server error occurred.'}), 500

@app.route('/api/languages', methods=['GET'])
def get_languages():
    """Get supported languages and available translation pairs."""
    available_pairs = []
    
    # Check which models are actually available
    for src in LANGUAGES.keys():
        for tgt in LANGUAGES.keys():
            if src != tgt:
                model_path = get_model_path(src, tgt)
                if model_path:
                    available_pairs.append({
                        'source': src,
                        'target': tgt,
                        'source_name': LANGUAGES[src],
                        'target_name': LANGUAGES[tgt],
                        'model_path': model_path
                    })
    
    return jsonify({
        'supported_languages': LANGUAGES,
        'available_translation_pairs': available_pairs,
        'total_pairs': len(available_pairs)
    })

@app.route('/api/models', methods=['GET'])
def get_models():
    """Get information about available models."""
    models_info = {}
    
    for model_key, path in MODEL_PATHS.items():
        model_info = {
            'path': path,
            'exists': os.path.exists(path),
            'type': 'unknown',
            'cached': False
        }
        
        if os.path.exists(path):
            # Detect model type
            config_path = os.path.join(path, "config.json")
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        config = json.load(f)
                        architecture = config.get("architectures", [""])[0].lower()
                        if "marian" in architecture:
                            model_info['type'] = "marian"
                        elif "t5" in architecture:
                            model_info['type'] = "t5"
                except:
                    pass
            
            # Check if cached
            cache_key = model_key.replace('_', '_')
            model_info['cached'] = cache_key in translator_cache
        
        models_info[model_key] = model_info
    
    return jsonify({
        'models': models_info,
        'cache_status': list(translator_cache.keys())
    })

@app.route('/api/clear-cache', methods=['POST'])
def clear_cache():
    """Clear translator cache."""
    global translator_cache
    
    # Get current cache size
    cache_size = len(translator_cache)
    
    # Clear cache
    translator_cache.clear()
    
    # Force garbage collection
    import gc
    gc.collect()
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return jsonify({
        'message': 'Cache cleared successfully',
        'cleared_translators': cache_size,
        'current_cache_size': len(translator_cache)
    })

@app.route('/api/warmup', methods=['POST'])
def warmup_translators():
    """Warm up translators for better performance."""
    data = request.json or {}
    language_pairs = data.get('pairs', [('en', 'hi'), ('hi', 'en')])
    
    results = []
    
    for src_lang, tgt_lang in language_pairs:
        try:
            start_time = time.time()
            translator = get_translator(src_lang, tgt_lang)
            
            # Perform a dummy translation to warm up
            test_text = "Hello" if src_lang == 'en' else "नमस्ते"
            translator.translate(test_text, src_lang, tgt_lang)
            
            end_time = time.time()
            
            results.append({
                'pair': f"{src_lang} -> {tgt_lang}",
                'success': True,
                'warmup_time': round(end_time - start_time, 3)
            })
            
        except Exception as e:
            results.append({
                'pair': f"{src_lang} -> {tgt_lang}",
                'success': False,
                'error': str(e)
            })
    
    return jsonify({
        'message': 'Warmup completed',
        'results': results,
        'total_cached': len(translator_cache)
    })

def run_flask_app():
    """Run Flask application."""
    initialize_translators()
    app.run(debug=False, use_reloader=False, host='0.0.0.0', port=5000, threaded=True)

def start_api_server():
    """Start API server in a thread."""
    thread = threading.Thread(target=run_flask_app, daemon=True)
    thread.start()
    return thread

def start_translation_api():
    """Start translation API for Jupyter notebook usage."""
    print("Starting enhanced translation API server...")
    thread = start_api_server()
    print("API server started on http://localhost:5000")
    print("\nAvailable endpoints:")
    print("  GET  /api/health - Check server status and model availability")
    print("  GET  /api/languages - Get supported languages and available pairs")
    print("  GET  /api/models - Get detailed model information")
    print("  POST /api/translate - Translate single text")
    print("  POST /api/batch-translate - Translate multiple texts")
    print("  POST /api/warmup - Warm up translators for better performance")
    print("  POST /api/clear-cache - Clear translator cache")
    
    print("\nExample usage:")
    print("# English to Hindi")
    print("curl -X POST http://localhost:5000/api/translate \\")
    print("  -H 'Content-Type: application/json' \\")
    print("  -d '{\"text\": \"Hello world\", \"src_lang\": \"en\", \"tgt_lang\": \"hi\"}'")
    
    print("\n# Hindi to English")
    print("curl -X POST http://localhost:5000/api/translate \\")
    print("  -H 'Content-Type: application/json' \\")
    print("  -d '{\"text\": \"नमस्ते दुनिया\", \"src_lang\": \"hi\", \"tgt_lang\": \"en\"}'")
    
    print("\n# Batch translation")
    print("curl -X POST http://localhost:5000/api/batch-translate \\")
    print("  -H 'Content-Type: application/json' \\")
    print("  -d '{\"texts\": [\"Hello\", \"Good morning\"], \"src_lang\": \"en\", \"tgt_lang\": \"hi\"}'")
    
    return thread

if __name__ == '__main__':
    run_flask_app()

Initializing translators...
Loading marian model from: results/marian_en_hi_finetuned
Model loaded successfully: marian
Cached translator for en -> hi
✓ Initialized en -> hi translator
Loading marian model from: results/marian_hi_en_finetuned
Model loaded successfully: marian
Cached translator for hi -> en
✓ Initialized hi -> en translator
Translator initialization complete.
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.202.209:5000
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [19/Aug/2025 09:24:37] "OPTIONS /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:24:37] "POST /api/translate HTTP/1.1" 200 -
