<a href="https://colab.research.google.com/github/Darkunquie/FMML_PROJECT_2024/blob/main/document_to_document.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install transformers[sentencepiece] sacrebleu rouge-score nltk gradio  # version 4.31.0, 2.3.1, 1.0.1, 3.8.1, 3.38.0

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer
import gradio as gr
import io
import sacrebleu
from rouge_score import rouge_scorer
import nltk

# Download NLTK's 'punkt' data package when the cell runs.
# NOTE(review): no nltk tokenizer call is visible in this cell — confirm
# whether this download is still required.
nltk.download('punkt')

# Default translation checkpoint, and the language codes exposed to the user.
MODEL_NAME = "facebook/m2m100_1.2B"
SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "de",
    "es",
    "ar",  # Arabic
    "ja",  # Japanese
]
def fine_tune_model(model_name, train_dataset, eval_dataset, output_dir):
    """Fine-tune a pretrained seq2seq checkpoint and save it to ``output_dir``.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        train_dataset: Dataset given to ``Trainer`` as ``train_dataset``.
        eval_dataset: Dataset given to ``Trainer`` as ``eval_dataset``.
        output_dir: Directory used both as the ``TrainingArguments`` output
            directory and as the destination of ``Trainer.save_model``.
    """
    # Build the run configuration first, then load the weights.
    run_config = TrainingArguments(
        per_device_eval_batch_size=64,  # Adjust as needed
        per_device_train_batch_size=8,  # Adjust as needed
        output_dir=output_dir,
        # ... other training arguments ...
    )
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    seq2seq_trainer = Trainer(
        model=base_model,
        args=run_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    seq2seq_trainer.train()
    seq2seq_trainer.save_model(output_dir)

# Loaded (tokenizer, model) pairs keyed by checkpoint name, so the large
# checkpoint is read from disk once instead of on every request.
_MODEL_CACHE = {}


def _load_translation_model(model_name):
    """Return a cached ``(tokenizer, model)`` pair, loading it on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModelForSeq2SeqLM.from_pretrained(model_name),
        )
    return _MODEL_CACHE[model_name]


def _translate_text(text, tokenizer, model, source_lang, target_lang):
    """Translate ``text`` one line at a time, keeping blank lines and order."""
    # M2M100 needs the source language set on the tokenizer; otherwise the
    # input is encoded with the tokenizer's default source language.
    tokenizer.src_lang = source_lang
    target_lang_id = tokenizer.get_lang_id(target_lang)

    translated_lines = []
    for line in text.splitlines():
        if not line.strip():
            translated_lines.append("")
            continue
        # Encoding each line separately keeps inputs within the model's
        # length limit; truncation guards against a single over-long line.
        encoded = tokenizer(line, return_tensors="pt", truncation=True)
        generated_ids = model.generate(**encoded, forced_bos_token_id=target_lang_id)
        translated_lines.append(
            tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        )
    return "\n".join(translated_lines)


def translate_document(input_file, source_lang, target_lang, model_name=MODEL_NAME, reference_text=""):
    """Translate an uploaded UTF-8 text file and optionally score the result.

    Args:
        input_file: Uploaded file: either an object exposing the path as
            ``.name`` or the path string itself. ``None`` (nothing uploaded)
            is reported as a failed translation.
        source_lang: Language code of the input text.
        target_lang: Language code to translate into.
        model_name: Checkpoint to use; an empty value falls back to
            ``MODEL_NAME``.
        reference_text: Optional reference translation. When it is empty or
            whitespace only, no metrics are computed.

    Returns:
        A ``(translation, bleu, rouge)`` tuple. ``bleu`` is the BLEU score as
        a float and ``rouge`` is a formatted string; both are ``"N/A"`` when
        translation failed, no reference was given, or scoring failed.
    """
    failure_message = "Translation failed. Please check your input or try a different model."
    if input_file is None:
        return failure_message, "N/A", "N/A"

    # Accept both a file wrapper with ``.name`` and a plain path string.
    input_path = getattr(input_file, "name", input_file)

    # Broad catch is deliberate: this is the UI boundary, and any failure
    # should surface as a message rather than a stack trace.
    try:
        tokenizer, model = _load_translation_model(model_name or MODEL_NAME)
        with open(input_path, 'r', encoding='utf-8') as f:
            input_text = f.read()
        translation = _translate_text(input_text, tokenizer, model, source_lang, target_lang)
    except Exception as e:
        print(f"M2M100 translation error: {e}")
        return failure_message, "N/A", "N/A"

    # Scoring against an empty reference would yield meaningless numbers.
    if not (reference_text or "").strip():
        return translation, "N/A", "N/A"

    # Metrics are scored separately so a scoring failure cannot discard a
    # translation that succeeded.
    try:
        bleu = sacrebleu.corpus_bleu([translation], [[reference_text]])
        # NOTE(review): rouge_score's default tokenizer appears to keep only
        # ASCII alphanumerics, so ROUGE for ar/ja output may be 0 — confirm.
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
        scores = scorer.score(reference_text, translation)
        rouge_scores_str = f"ROUGE-1: {scores['rouge1'].fmeasure:.4f}, ROUGE-L: {scores['rougeL'].fmeasure:.4f}"
    except Exception as e:
        print(f"Metric computation error: {e}")
        return translation, "N/A", "N/A"

    return translation, bleu.score, rouge_scores_str

# translate_document returns the translation, the BLEU score, and the ROUGE scores, in that order.

# Input widgets, in the same order as translate_document's parameters.
_input_components = [
    gr.File(label="Input File"),
    gr.Dropdown(choices=SUPPORTED_LANGUAGES, label="Source Language"),
    gr.Dropdown(choices=SUPPORTED_LANGUAGES, label="Target Language"),
    gr.Textbox(label="Model Name (Optional)", placeholder=MODEL_NAME, value=MODEL_NAME),
    # Reference translation used only for the BLEU/ROUGE evaluation.
    gr.Textbox(lines=5, placeholder="Enter reference text here...", label="Reference Text (for evaluation)"),
]

# Output widgets, in the same order as translate_document's return tuple.
_output_components = [
    gr.Textbox(label="Translated Document", lines=10),
    gr.Textbox(label="BLEU Score"),
    gr.Textbox(label="ROUGE Scores"),
]

iface = gr.Interface(
    fn=translate_document,
    inputs=_input_components,
    outputs=_output_components,
    title="Document Translator",
    description="Translate documents between supported languages using the M2M100 model.",
)

# Start the web UI.
iface.launch()



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://77455473164bae903a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


