## Fine-tune MarianMT model for English-Hindi translation

In [None]:
## Fine-tune mBART for bidirectional EN-HI translation

In [12]:
import torch
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
import os
import warnings
warnings.filterwarnings("ignore")

def prepare_mbart_data():
    """Prepare data for mBART fine-tuning."""
    print("Loading samanantar dataset for mBART...")
    dataset = load_dataset("ai4bharat/samanantar", "hi", split='train')
    
    train_data = []
    val_data = []
    
    for i, example in enumerate(dataset):
        if i >= 10000:
            break
        if example['src'] and example['tgt']:
            # EN->HI
            train_data.append({
                'source': example['src'].strip(),
                'target': example['tgt'].strip(),
                'src_lang': 'en_XX',
                'tgt_lang': 'hi_IN'
            })
            # HI->EN
            train_data.append({
                'source': example['tgt'].strip(),
                'target': example['src'].strip(),
                'src_lang': 'hi_IN',
                'tgt_lang': 'en_XX'
            })
    
    # Shuffle and split
    import random
    random.seed(42)
    random.shuffle(train_data)
    
    split_idx = int(0.9 * len(train_data))
    train_set = train_data[:split_idx]
    val_set = train_data[split_idx:]
    
    print(f"Prepared {len(train_set)} train, {len(val_set)} val samples")
    return train_set, val_set

def fine_tune_mbart():
    """Fine-tune mBART for bidirectional EN-HI translation."""
    print("--- Fine-tuning mBART-50 ---")
    
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    output_dir = "results/mbart_en_hi_bidirectional"
    
    # Load model and tokenizer
    print(f"Loading {model_name}...")
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    
    # Prepare data
    train_data, val_data = prepare_mbart_data()
    
    def preprocess_function(examples):
        # Set source language
        tokenizer.src_lang = examples['src_lang'][0]  # Assuming batch has same src_lang
        
        # Tokenize sources
        model_inputs = tokenizer(
            examples['source'], 
            max_length=128, 
            truncation=True, 
            padding=False
        )
        
        # Tokenize targets with target language
        tokenizer.tgt_lang = examples['tgt_lang'][0]
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                examples['target'], 
                max_length=128, 
                truncation=True, 
                padding=False
            )
        
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    # Group by language pairs for preprocessing
    train_en_hi = [x for x in train_data if x['src_lang'] == 'en_XX']
    train_hi_en = [x for x in train_data if x['src_lang'] == 'hi_IN']
    val_en_hi = [x for x in val_data if x['src_lang'] == 'en_XX']
    val_hi_en = [x for x in val_data if x['src_lang'] == 'hi_IN']
    
    # Create datasets
    train_en_hi_ds = Dataset.from_list(train_en_hi).map(preprocess_function, batched=True)
    train_hi_en_ds = Dataset.from_list(train_hi_en).map(preprocess_function, batched=True)
    val_en_hi_ds = Dataset.from_list(val_en_hi).map(preprocess_function, batched=True)
    val_hi_en_ds = Dataset.from_list(val_hi_en).map(preprocess_function, batched=True)
    
    # Combine datasets
    from datasets import concatenate_datasets
    train_dataset = concatenate_datasets([train_en_hi_ds, train_hi_en_ds]).shuffle(seed=42)
    val_dataset = concatenate_datasets([val_en_hi_ds, val_hi_en_ds]).shuffle(seed=42)
    
    print(f"Final dataset sizes: train={len(train_dataset)}, val={len(val_dataset)}")
    
    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    
    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        eval_steps=500,
        logging_steps=100,
        save_steps=500,
        save_total_limit=2,
        learning_rate=1e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        weight_decay=0.01,
        warmup_steps=500,
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    
    # Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    print("Starting mBART training...")
    trainer.train()
    
    print(f"Saving mBART model to {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    return output_dir

class MBartTranslator:
    """mBART-based bidirectional translator."""
    
    def __init__(self, model_path):
        print(f"Loading mBART translator from: {model_path}")
        
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_path)
        self.model = MBartForConditionalGeneration.from_pretrained(model_path)
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        
        # Language codes
        self.lang_codes = {
            'en': 'en_XX',
            'hi': 'hi_IN'
        }
        
        print("mBART translator ready!")

    def translate(self, text, src_lang='en', tgt_lang='hi'):
        """Translate using mBART."""
        src_code = self.lang_codes[src_lang]
        tgt_code = self.lang_codes[tgt_lang]
        
        # Set source language
        self.tokenizer.src_lang = src_code
        
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=128,
            truncation=True,
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Generate with target language
        generated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_code],
            max_length=128,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
        )
        
        # Decode
        translation = self.tokenizer.batch_decode(
            generated_tokens, 
            skip_special_tokens=True
        )[0]
        
        return translation.strip()

# Fallback: Use pre-trained model without fine-tuning
def use_pretrained_mbart():
    """Use pre-trained mBART without fine-tuning as fallback."""
    print("--- Using Pre-trained mBART (No Fine-tuning) ---")
    model_name = "facebook/mbart-large-50-many-to-many-mmt"
    
    # Create a simple wrapper
    class PretrainedMBart:
        def __init__(self):
            self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
            self.model = MBartForConditionalGeneration.from_pretrained(model_name)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            self.model.eval()
            
        def translate(self, text, src_lang='en', tgt_lang='hi'):
            lang_codes = {'en': 'en_XX', 'hi': 'hi_IN'}
            src_code = lang_codes[src_lang]
            tgt_code = lang_codes[tgt_lang]
            
            self.tokenizer.src_lang = src_code
            inputs = self.tokenizer(text, return_tensors="pt", max_length=128, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            generated_tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_code],
                max_length=128,
                num_beams=5,
            )
            
            return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
    
    return PretrainedMBart()

class TranslationEvaluator:
    def __init__(self):
        self.smoothing = SmoothingFunction().method1
    
    def calculate_bleu(self, reference, candidate):
        if not reference.strip() or not candidate.strip():
            return 0.0
        ref_tokens = nltk.word_tokenize(reference.lower())
        cand_tokens = nltk.word_tokenize(candidate.lower())
        return round(sentence_bleu([ref_tokens], cand_tokens, smoothing_function=self.smoothing) * 100, 2)
    
    def calculate_meteor(self, reference, candidate):
        if not reference.strip() or not candidate.strip():
            return 0.0
        ref_tokens = nltk.word_tokenize(reference.lower())
        cand_tokens = nltk.word_tokenize(candidate.lower())
        return round(meteor_score([ref_tokens], cand_tokens) * 100, 2)

def run_evaluation(translator):
    """Evaluate translation model."""
    print("\n--- Starting Evaluation ---")
    
    # Download NLTK data
    for corpus in ['punkt', 'wordnet', 'omw-1.4']:
        nltk.download(corpus, quiet=True)

    evaluator = TranslationEvaluator()
    
    test_cases = [
        # English to Hindi
        {'source': 'Hello', 'reference': 'नमस्ते', 'src_lang': 'en', 'tgt_lang': 'hi'},
        {'source': 'How are you?', 'reference': 'आप कैसे हैं?', 'src_lang': 'en', 'tgt_lang': 'hi'},
        {'source': 'Good morning', 'reference': 'सुप्रभात', 'src_lang': 'en', 'tgt_lang': 'hi'},
        {'source': 'Thank you', 'reference': 'धन्यवाद', 'src_lang': 'en', 'tgt_lang': 'hi'},
        {'source': 'I am fine', 'reference': 'मैं ठीक हूं', 'src_lang': 'en', 'tgt_lang': 'hi'},
        
        # Hindi to English  
        {'source': 'नमस्ते', 'reference': 'Hello', 'src_lang': 'hi', 'tgt_lang': 'en'},
        {'source': 'आप कैसे हैं?', 'reference': 'How are you?', 'src_lang': 'hi', 'tgt_lang': 'en'},
        {'source': 'धन्यवाद', 'reference': 'Thank you', 'src_lang': 'hi', 'tgt_lang': 'en'},
        {'source': 'मैं ठीक हूं', 'reference': 'I am fine', 'src_lang': 'hi', 'tgt_lang': 'en'},
        {'source': 'सुप्रभात', 'reference': 'Good morning', 'src_lang': 'hi', 'tgt_lang': 'en'},
    ]
    
    results = []
    print("Evaluating translations...")
    
    for i, case in enumerate(test_cases, 1):
        print(f"{i:2d}. '{case['source']}' ({case['src_lang']}->{case['tgt_lang']})")
        
        try:
            prediction = translator.translate(case['source'], case['src_lang'], case['tgt_lang'])
            bleu = evaluator.calculate_bleu(case['reference'], prediction)
            meteor = evaluator.calculate_meteor(case['reference'], prediction)
            
            results.append({
                'Source': case['source'],
                'Reference': case['reference'], 
                'Prediction': prediction,
                'BLEU': bleu,
                'METEOR': meteor
            })
            
            print(f"    → {prediction}")
            
        except Exception as e:
            print(f"    → ERROR: {str(e)}")
            results.append({
                'Source': case['source'],
                'Reference': case['reference'],
                'Prediction': f"ERROR: {str(e)}",
                'BLEU': 0.0,
                'METEOR': 0.0
            })
    
    # Results summary
    df = pd.DataFrame(results)
    print(f"\n{'='*80}")
    print("EVALUATION RESULTS")
    print(f"{'='*80}")
    print(df.to_string(index=False, max_colwidth=40))
    
    valid_results = df[df['BLEU'] > 0]
    if len(valid_results) > 0:
        avg_bleu = valid_results['BLEU'].mean()
        avg_meteor = valid_results['METEOR'].mean()
        print(f"\nAverage BLEU: {avg_bleu:.2f}")
        print(f"Average METEOR: {avg_meteor:.2f}")
        print(f"Success rate: {len(valid_results)}/{len(results)} ({100*len(valid_results)/len(results):.1f}%)")
    else:
        print("\nNo successful translations generated.")

def main():
    """Main execution - try fine-tuning, fallback to pretrained."""
    os.makedirs("results", exist_ok=True)
    
    try:
        print("Attempting mBART fine-tuning...")
        model_path = fine_tune_mbart()
        translator = MBartTranslator(model_path)
        print("Using fine-tuned mBART model")
    except Exception as e:
        print(f"Fine-tuning failed: {e}")
        print("Using pre-trained mBART model...")
        translator = use_pretrained_mbart()
    
    # Evaluate
    run_evaluation(translator)
    
    return translator

if __name__ == "__main__":
    main()

Attempting mBART fine-tuning...
--- Fine-tuning mBART-50 ---
Loading facebook/mbart-large-50-many-to-many-mmt...


tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Loading samanantar dataset for mBART...
Prepared 18000 train, 2000 val samples


Map:   0%|          | 0/8978 [00:00<?, ? examples/s]

Map:   0%|          | 0/9022 [00:00<?, ? examples/s]

Map:   0%|          | 0/1022 [00:00<?, ? examples/s]

Map:   0%|          | 0/978 [00:00<?, ? examples/s]

Final dataset sizes: train=18000, val=2000
Starting mBART training...


Step,Training Loss,Validation Loss
500,1.9573,1.967723
1000,1.9499,1.899083
1500,1.7823,1.877854
2000,1.7003,1.861613
2500,1.589,1.865013
3000,1.5526,1.858425


Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].
Non-default generation parameters: {'max_length': 200, 'early_stopping': True, 'num_beams': 5, 'forced_eos_token_id': 2}


Saving mBART model to results/mbart_en_hi_bidirectional
Loading mBART translator from: results/mbart_en_hi_bidirectional
mBART translator ready!
Using fine-tuned mBART model

--- Starting Evaluation ---
Evaluating translations...
 1. 'Hello' (en->hi)
    → नमस्ते
 2. 'How are you?' (en->hi)
    → आप कैसे हैं?
 3. 'Good morning' (en->hi)
    → नमस्ते
 4. 'Thank you' (en->hi)
    → धन्यवाद
 5. 'I am fine' (en->hi)
    → मैं ठीक हूं।
 6. 'नमस्ते' (hi->en)
    → goodbye
 7. 'आप कैसे हैं?' (hi->en)
    → How are you?
 8. 'धन्यवाद' (hi->en)
    → thank you
 9. 'मैं ठीक हूं' (hi->en)
    → I'm fine.
10. 'सुप्रभात' (hi->en)
    → sunrise

EVALUATION RESULTS
      Source    Reference   Prediction   BLEU  METEOR
       Hello       नमस्ते       नमस्ते  17.78   50.00
How are you? आप कैसे हैं? आप कैसे हैं? 100.00   99.22
Good morning     सुप्रभात       नमस्ते   0.00    0.00
   Thank you      धन्यवाद      धन्यवाद  17.78   50.00
   I am fine  मैं ठीक हूं मैं ठीक हूं।  24.03   62.50
      नमस्ते      

In [None]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import time
import threading

app = Flask(__name__)
CORS(app)

# Global translator instance
translator = None
LANGUAGES = {'en': 'English', 'hi': 'Hindi'}

def initialize_translator():
    global translator
    print("Initializing mBART translator...")
    try:
        # Try fine-tuned model first
        model_path = "results/mbart_en_hi_bidirectional"
        translator = MBartTranslator(model_path)
        print("Fine-tuned mBART translator initialized.")
    except Exception as e:
        print(f"Fine-tuned model not found: {e}")
        print("Using pre-trained mBART model...")
        try:
            # Fallback to pre-trained mBART
            class PretrainedMBart:
                def __init__(self):
                    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
                    model_name = "facebook/mbart-large-50-many-to-many-mmt"
                    self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
                    self.model = MBartForConditionalGeneration.from_pretrained(model_name)
                    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                    self.model.to(self.device)
                    self.model.eval()
                    
                def translate(self, text, src_lang='en', tgt_lang='hi'):
                    lang_codes = {'en': 'en_XX', 'hi': 'hi_IN'}
                    src_code = lang_codes[src_lang]
                    tgt_code = lang_codes[tgt_lang]
                    
                    self.tokenizer.src_lang = src_code
                    inputs = self.tokenizer(text, return_tensors="pt", max_length=128, truncation=True)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    
                    generated_tokens = self.model.generate(
                        **inputs,
                        forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_code],
                        max_length=128,
                        num_beams=5,
                    )
                    
                    return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
            
            translator = PretrainedMBart()
            print("Pre-trained mBART translator initialized.")
        except Exception as e2:
            print(f"Failed to initialize any translator: {e2}")
            translator = None

@app.route('/api/health', methods=['GET'])
def health_check():
    return jsonify({
        'status': 'healthy',
        'translator_ready': translator is not None,
        'supported_languages': LANGUAGES
    })

@app.route('/api/translate', methods=['POST'])
def translate_text():
    try:
        if translator is None:
            return jsonify({'error': 'Translator not initialized. Please check model path.'}), 503
        
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided.'}), 400
            
        text = data.get('text', '').strip()
        src_lang = data.get('src_lang', 'en')
        tgt_lang = data.get('tgt_lang', 'hi')

        if not text:
            return jsonify({'error': 'No text provided for translation.'}), 400
        if src_lang not in LANGUAGES or tgt_lang not in LANGUAGES:
            return jsonify({'error': 'Unsupported language selected.'}), 400
        if src_lang == tgt_lang:
            return jsonify({'error': 'Source and target languages are the same.'}), 400

        start_time = time.time()
        translation = translator.translate(text, src_lang, tgt_lang)
        end_time = time.time()

        return jsonify({
            'source_text': text,
            'source_language': src_lang,
            'target_language': tgt_lang,
            'translation': translation,
            'processing_time': round(end_time - start_time, 3)
        })

    except Exception as e:
        print(f"Translation error: {e}")
        return jsonify({'error': 'An internal server error occurred.'}), 500

@app.route('/api/batch-translate', methods=['POST'])
def batch_translate():
    try:
        if translator is None:
            return jsonify({'error': 'Translator not initialized.'}), 503
            
        data = request.json
        if not data:
            return jsonify({'error': 'No JSON data provided.'}), 400
            
        texts = data.get('texts', [])
        src_lang = data.get('src_lang', 'en')
        tgt_lang = data.get('tgt_lang', 'hi')
        
        if not texts or not isinstance(texts, list):
            return jsonify({'error': 'No texts array provided.'}), 400
        if len(texts) > 100:
            return jsonify({'error': 'Maximum 100 texts allowed per batch.'}), 400

        start_time = time.time()
        translations = []
        
        for text in texts:
            if isinstance(text, str) and text.strip():
                try:
                    translation = translator.translate(text.strip(), src_lang, tgt_lang)
                    translations.append({
                        'source': text.strip(),
                        'translation': translation,
                        'success': True
                    })
                except Exception as e:
                    translations.append({
                        'source': text.strip(),
                        'translation': None,
                        'success': False,
                        'error': str(e)
                    })
            else:
                translations.append({
                    'source': text,
                    'translation': None,
                    'success': False,
                    'error': 'Invalid text format'
                })
        
        end_time = time.time()
        
        return jsonify({
            'translations': translations,
            'total_count': len(translations),
            'success_count': sum(1 for t in translations if t['success']),
            'processing_time': round(end_time - start_time, 3)
        })
        
    except Exception as e:
        print(f"Batch translation error: {e}")
        return jsonify({'error': 'An internal server error occurred.'}), 500

@app.route('/api/languages', methods=['GET'])
def get_languages():
    return jsonify({
        'supported_languages': LANGUAGES,
        'translation_pairs': [
            {'source': 'en', 'target': 'hi'},
            {'source': 'hi', 'target': 'en'}
        ]
    })

def run_flask_app():
    initialize_translator()
    app.run(debug=False, use_reloader=False, host='0.0.0.0', port=5000, threaded=True)

def start_api_server():
    thread = threading.Thread(target=run_flask_app, daemon=True)
    thread.start()
    return thread

# For Jupyter notebook usage
def start_translation_api():
    print("Starting translation API server...")
    thread = start_api_server()
    print("API server started on http://localhost:5000")
    print("\nAvailable endpoints:")
    print("  GET  /api/health - Check server status")
    print("  GET  /api/languages - Get supported languages")
    print("  POST /api/translate - Translate single text")
    print("  POST /api/batch-translate - Translate multiple texts")
    print("\nExample usage:")
    print("curl -X POST http://localhost:5000/api/translate \\")
    print("  -H 'Content-Type: application/json' \\")
    print("  -d '{\"text\": \"Hello\", \"src_lang\": \"en\", \"tgt_lang\": \"hi\"}'")
    return thread

if __name__ == '__main__':
    run_flask_app()

Initializing mBART translator...
Loading mBART translator from: results/mbart_en_hi_bidirectional
mBART translator ready!
Fine-tuned mBART translator initialized.
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://192.168.202.209:5000
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [19/Aug/2025 09:05:02] "OPTIONS /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:03] "POST /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:05] "POST /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:16] "OPTIONS /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:16] "POST /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:23] "OPTIONS /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:23] "POST /api/translate HTTP/1.1" 200 -
127.0.0.1 - - [19/Aug/2025 09:05:26] "POST /api/translate HTTP/1.1" 200 -


: 