### Text Translation using mBART-50

In [1]:
from typing import List, Optional

import torch
from langdetect import detect
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

class MBart50Translator:
    """Translate text with the many-to-many mBART-50 model.

    The target language is always given as an mBART-50 language code
    (e.g. ``"fr_XX"``). The source language can be given the same way or,
    when omitted, is detected with ``langdetect`` and mapped to an mBART-50
    code through ``self.lang_code_map``.
    """

    # langdetect language codes -> mBART-50 language codes. Lists the languages
    # that langdetect can report and mBART-50 supports.
    DEFAULT_LANG_CODE_MAP = {
        'af': 'af_ZA', 'ar': 'ar_AR', 'bn': 'bn_IN', 'cs': 'cs_CZ', 'de': 'de_DE',
        'en': 'en_XX', 'es': 'es_XX', 'et': 'et_EE', 'fa': 'fa_IR', 'fi': 'fi_FI',
        'fr': 'fr_XX', 'gu': 'gu_IN', 'he': 'he_IL', 'hi': 'hi_IN', 'hr': 'hr_HR',
        'id': 'id_ID', 'it': 'it_IT', 'ja': 'ja_XX', 'ko': 'ko_KR', 'lt': 'lt_LT',
        'lv': 'lv_LV', 'mk': 'mk_MK', 'ml': 'ml_IN', 'mr': 'mr_IN', 'ne': 'ne_NP',
        'nl': 'nl_XX', 'pl': 'pl_PL', 'pt': 'pt_XX', 'ro': 'ro_RO', 'ru': 'ru_RU',
        'sl': 'sl_SI', 'sv': 'sv_SE', 'sw': 'sw_KE', 'ta': 'ta_IN', 'te': 'te_IN',
        'th': 'th_TH', 'tl': 'tl_XX', 'tr': 'tr_TR', 'uk': 'uk_UA', 'ur': 'ur_PK',
        'vi': 'vi_VN', 'zh': 'zh_CN',
    }

    # Source language assumed when detection fails or yields an unmapped language.
    FALLBACK_LANG = 'en_XX'

    def __init__(self, model_name: str = "facebook/mbart-large-50-many-to-many-mmt",
                 device: str = "cpu"):
        """Load the mBART-50 tokenizer and model.

        Args:
            model_name: Hugging Face model id or local path of an mBART-50 checkpoint.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``, ``"cuda"``).
        """
        self.device = torch.device(device)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        # Per-instance copy: callers may add mappings without affecting other instances.
        self.lang_code_map = dict(self.DEFAULT_LANG_CODE_MAP)

    def _to_mbart_code(self, iso_code: str) -> str:
        """Map a langdetect code to an mBART-50 code (FALLBACK_LANG if unmapped).

        langdetect reports Chinese as 'zh-cn' / 'zh-tw'. When the full code is not
        mapped, the primary subtag before '-' is tried, so both resolve via 'zh'.
        """
        if iso_code in self.lang_code_map:
            return self.lang_code_map[iso_code]
        primary = iso_code.split('-', 1)[0]
        return self.lang_code_map.get(primary, self.FALLBACK_LANG)

    def detect_language(self, text: str) -> str:
        """Guess the mBART-50 language code of *text*.

        Returns FALLBACK_LANG ('en_XX') when langdetect cannot detect a language
        or reports one that is not in ``self.lang_code_map``.
        Note: langdetect is non-deterministic unless
        ``langdetect.DetectorFactory.seed`` is set before the first detection.
        """
        try:
            iso_code = detect(text)
        except Exception:  # langdetect raises LangDetectException on undetectable input
            return self.FALLBACK_LANG
        return self._to_mbart_code(iso_code)

    def translate(self, text: str, tgt_lang: str, src_lang: Optional[str] = None,
                  max_length: int = 128) -> str:
        """Translate a single string.

        Args:
            text: Text to translate.
            tgt_lang: mBART-50 code of the target language, e.g. 'fr_XX'.
            src_lang: mBART-50 code of the source language; detected from
                *text* when None.
            max_length: Maximum length, in tokens, of the generated sequence.

        Returns:
            The decoded translation with special tokens removed.

        Raises:
            KeyError: if *tgt_lang* is not in ``tokenizer.lang_code_to_id``.
        """
        if src_lang is None:
            src_lang = self.detect_language(text)

        # The tokenizer's src_lang selects the language token used when encoding.
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(text, return_tensors="pt").to(self.device)

        generated_tokens = self.model.generate(
            **encoded,
            # Forcing the first generated token to the target-language code
            # is how mBART-50 is told which language to produce.
            forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            max_length=max_length,
        )
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def translate_batch(self, texts: List[str], tgt_lang: str, src_lang: Optional[str] = None,
                        max_length: int = 128) -> List[str]:
        """Translate every string in *texts* into *tgt_lang*.

        Each text is passed to translate() individually. When *src_lang* is
        None, each text's source language is therefore detected on its own,
        so mixed-language input is handled.

        Args:
            texts: Strings to translate.
            tgt_lang: mBART-50 code of the target language.
            src_lang: mBART-50 code shared by all texts, or None to detect per text.
            max_length: Maximum length, in tokens, of each generated sequence.

        Returns:
            Translations in the same order as *texts*.
        """
        return [
            self.translate(text, tgt_lang, src_lang, max_length=max_length)
            for text in texts
        ]

# Example usage
# if __name__ == "__main__":
#     translator = MBart50Translator()

#     # Single translation with auto-detection
#     text = "Hello, how are you?"
#     tgt_lang = "fr_XX"  # French
#     translation = translator.translate(text, tgt_lang)
#     print(f"Original: {text}")
#     print(f"Translation: {translation}")

#     # Batch translation with auto-detection
#     texts = ["Hello, how are you?", "Bonjour, comment allez-vous?", "Hola, ¿cómo estás?"]
#     translations = translator.translate_batch(texts, tgt_lang)
#     for original, translation in zip(texts, translations):
#         print(f"Original: {original}")
#         print(f"Translation: {translation}")
#         print()