# BhaashikTantraE2EPrototype: Translation for Low-Resource Indic Languages

This notebook demonstrates the end-to-end process of loading a tokenizer, preparing datasets, fine-tuning pretrained models, translating texts, and evaluating the results using the BhaashikTantraE2EPrototype library.


This notebook is set up for Google Colab and covers the full machine-translation workflow with BhaashikTantraE2EPrototype. To run it, copy the cells into a Jupyter notebook file (`.ipynb`) and open it in Google Colab. Adapt paths and hyperparameters to your specific requirements.

## Setup

Please run the cells below to install the necessary dependencies.

In [1]:
!pip install datasets sacrebleu peft indic-nlp-library



In [2]:
!git clone https://github.com/VarunGumma/IndicTransToolkit.git
%cd IndicTransToolkit
!python3 -m pip install --editable ./
%cd ..

fatal: destination path 'IndicTransToolkit' already exists and is not an empty directory.
/content/IndicTransToolkit
Obtaining file:///content/IndicTransToolkit
  Preparing metadata (setup.py) ... done
Collecting indic-nlp-library-IT2@ git+https://github.com/VarunGumma/indic_nlp_library (from IndicTransToolkit==1.0.2)
  Cloning https://github.com/VarunGumma/indic_nlp_library to /tmp/pip-install-k23x4gm8/indic-nlp-library-it2_234dd9d25da0495aa5adc57506742261
  Running command git clone --filter=blob:none --quiet https://github.com/VarunGumma/indic_nlp_library /tmp/pip-install-k23x4gm8/indic-nlp-library-it2_234dd9d25da0495aa5adc57506742261
  Resolved https://github.com/VarunGumma/indic_nlp_library to commit 601521e05ed0ed8f2165ac317a47d186e25b6f0d
  Preparing metadata (setup.py) ... done
Installing collected packages: IndicTransToolkit
  Attempting uninstall: IndicTransToolkit
    Found existing installation: IndicTransToolkit 1.0.2
    Uninstalling IndicTransTool

**IMPORTANT: Restart the runtime before running the cells below.**

## Import Libraries and Initialize Metrics

This section of the notebook imports the required libraries and initializes evaluation metrics for fine-tuning and evaluating the IndicTrans2 model.

### Libraries Imported
1. **General-purpose libraries**:
   - `os`: For file path operations.
   - `argparse`: For parsing command-line arguments (imported but unused here; the notebook uses a plain `Args` class instead).

2. **Data handling and metrics**:
   - `pandas`: For data manipulation.
   - `datasets.Dataset`: For working with Hugging Face-compatible datasets.
   - `sacrebleu.metrics.BLEU`: To evaluate translation quality using the BLEU metric.
   - `sacrebleu.metrics.CHRF`: To evaluate translation quality using the chrF (character n-gram F-score) metric.

3. **Model fine-tuning**:
   - `peft`: For applying parameter-efficient fine-tuning using LoRA (Low-Rank Adaptation).

4. **Indic language processing**:
   - `IndicTransToolkit`: Provides `IndicProcessor` (Indic-specific pre- and post-processing) and `IndicDataCollator` (batch padding for training).

5. **Hugging Face Transformers**:
   - `Seq2SeqTrainer`: For training sequence-to-sequence models.
   - `Seq2SeqTrainingArguments`: For defining training configurations.
   - `AutoModelForSeq2SeqLM`: For loading pre-trained sequence-to-sequence models.
   - `AutoTokenizer`: For tokenizing text data.
   - `EarlyStoppingCallback`: For stopping training early when the monitored evaluation metric (here, validation loss) stops improving.

### Metric Initialization
- **BLEU Metric**: Initialized with `BLEU()` from `sacrebleu`.
- **chrF Metric**: Initialized with `CHRF()` from `sacrebleu`.

### Notes
- Ensure all libraries are installed before executing this cell. Install missing packages with the following command:
  ```bash
  pip install pandas datasets sacrebleu peft IndicTransToolkit transformers
  ```


In [3]:
import os
import argparse
import pandas as pd
from datasets import Dataset
from sacrebleu.metrics import BLEU, CHRF
from peft import LoraConfig, get_peft_model
from IndicTransToolkit import IndicProcessor, IndicDataCollator
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    EarlyStoppingCallback,
)

bleu_metric = BLEU()
chrf_metric = CHRF()

## Dataset Loading, Preprocessing, and Evaluation

This section defines functions to load and preprocess the translation dataset, compute evaluation metrics, and tokenize input data for the model.

---

### `load_and_process_translation_dataset`
**Purpose**:  
Load a parallel dataset for translation, preprocess it, and tokenize it for model training.

**Parameters**:
- `data_dir` (str): Path to the directory containing the dataset.
- `split` (str): Data split to use (`train`, `dev`, or `test`); `main` uses `train` and `dev`.
- `tokenizer`: Tokenizer object used to process text data.
- `processor`: Instance of `IndicProcessor` for Indic-specific preprocessing.
- `src_lang_list` (list): List of source language codes (FLORES-200 style, for example `eng_Latn`).
- `tgt_lang_list` (list): List of target language codes (for example `hin_Deva`).
- `num_proc` (int): Number of parallel processes for tokenization (`None` tokenizes in the main process).
- `seed` (int): Random seed for dataset shuffling.

**Functionality**:
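1. Builds the source and target file paths as `{data_dir}/{split}/{src_lang}-{tgt_lang}/{split}.{lang}` for every pair of distinct source and target languages.
2. Reads both files and checks that they have the same number of lines.
3. Preprocesses the text with `processor.preprocess_batch`.
4. Builds a Hugging Face `Dataset` from all pairs, shuffles it with `seed`, and tokenizes it with `preprocess_fn`.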

**Returns**:  
A tokenized and shuffled `Dataset` object ready for training.

**Notes**:
- Raises `FileNotFoundError` if a source or target file is missing.
- Raises `AssertionError` if the source and target files have different numbers of lines.
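The loader expects one subdirectory per language pair inside each split directory. As a rough illustration, the sketch below creates a throwaway directory with that layout and runs the same two checks (both files exist, equal line counts). `check_pair_files` and the sample sentences are hypothetical; only the path pattern is taken from `load_and_process_translation_dataset`, and the sketch raises `ValueError` where the original uses `assert`.

```python
import os
import tempfile


def check_pair_files(data_dir, split, src_lang, tgt_lang):
    """Hypothetical helper: mirror the loader's path construction and line-count check."""
    pair_dir = os.path.join(data_dir, split, f"{src_lang}-{tgt_lang}")
    src_path = os.path.join(pair_dir, f"{split}.{src_lang}")
    tgt_path = os.path.join(pair_dir, f"{split}.{tgt_lang}")
    if not (os.path.exists(src_path) and os.path.exists(tgt_path)):
        raise FileNotFoundError(f"Missing {src_path} or {tgt_path}")
    # Count lines the same way readlines() would split them
    with open(src_path, encoding="utf-8") as f:
        n_src = sum(1 for _ in f)
    with open(tgt_path, encoding="utf-8") as f:
        n_tgt = sum(1 for _ in f)
    if n_src != n_tgt:
        raise ValueError(f"{n_src} source lines vs {n_tgt} target lines")
    return n_src


# Example layout: <data_dir>/train/eng_Latn-hin_Deva/train.eng_Latn and train.hin_Deva
with tempfile.TemporaryDirectory() as data_dir:
    pair_dir = os.path.join(data_dir, "train", "eng_Latn-hin_Deva")
    os.makedirs(pair_dir)
    with open(os.path.join(pair_dir, "train.eng_Latn"), "w", encoding="utf-8") as f:
        f.write("Hello.\nHow are you?\n")
    with open(os.path.join(pair_dir, "train.hin_Deva"), "w", encoding="utf-8") as f:
        f.write("नमस्ते।\nआप कैसे हैं?\n")
    n_pairs = check_pair_files(data_dir, "train", "eng_Latn", "hin_Deva")

print(n_pairs)  # 2
```

The evaluation split follows the same pattern with `dev` in place of `train`.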

---

### `compute_metrics_factory`
**Purpose**:  
Factory function to create a metric computation function for evaluating translation quality.

**Parameters**:
- `tokenizer`: Tokenizer used to decode model predictions.
- `metric_dict` (dict): Dictionary of metric objects (e.g., BLEU, CHRF).
- `print_samples` (bool): Whether to print a sample of predictions and references.
- `n_samples` (int): Number of samples to print if `print_samples` is `True`.

**Functionality**:
1. Replaces `-100` (the ignored-label padding value) with the pad token ID, then decodes predictions and references into text.
2. Computes evaluation metrics (e.g., BLEU, CHRF) on the predictions.
3. Optionally prints a random sample of predictions and their corresponding references.

**Returns**:  
A dictionary containing the scores for each metric.
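Why the `-100` replacement comes first: `-100` marks label positions the loss should ignore, but it is not a valid vocabulary index, so it has to be swapped for the pad token ID before `batch_decode` can run. The plain-list sketch below shows the same replacement (the notebook does it with NumPy boolean indexing) and then drops pad IDs as a rough stand-in for `skip_special_tokens=True`. `PAD_ID = 1` and the token IDs are made up for illustration.

```python
def restore_pad_ids(batch, pad_token_id, ignore_index=-100):
    """Replace ignore-index positions with the pad token ID so the batch can be decoded."""
    return [[pad_token_id if tok == ignore_index else tok for tok in seq] for seq in batch]


def strip_pad(seq, pad_token_id):
    """Rough stand-in for skip_special_tokens=True (only pad IDs are dropped here)."""
    return [tok for tok in seq if tok != pad_token_id]


PAD_ID = 1  # hypothetical; the real value comes from tokenizer.pad_token_id
labels = [[45, 87, 12, 2, -100, -100], [33, 2, -100, -100, -100, -100]]

restored = restore_pad_ids(labels, PAD_ID)
print(restored)  # [[45, 87, 12, 2, 1, 1], [33, 2, 1, 1, 1, 1]]
print([strip_pad(seq, PAD_ID) for seq in restored])  # [[45, 87, 12, 2], [33, 2]]
```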

---

### `preprocess_fn`
**Purpose**:  
Tokenize source and target text into input IDs for the model.

**Parameters**:
- `example` (dict): A batch of examples (a dict of lists, since `map` is called with `batched=True`) containing source and target sentences.
- `tokenizer`: Tokenizer object used to convert text into model-readable input.
- `**kwargs`: Additional keyword arguments (accepted but currently unused).

**Functionality**:
1. Tokenizes the source sentences, truncating to 256 tokens without padding (padding is left to the data collator).
2. Tokenizes the target sentences the same way, with the tokenizer in target mode.
3. Adds the tokenized target input IDs to the model inputs as `labels`.

**Returns**:  
A dictionary containing tokenized inputs and labels.
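Because `map` is called with `batched=True`, `preprocess_fn` receives a dict of lists and returns lists of variable-length ID sequences. The sketch below reproduces that data flow with a hypothetical whitespace "tokenizer" that maps each word to its length, so the shapes are easy to follow. It stands in for the IndicTrans2 tokenizer and uses `max_length=4` instead of 256.

```python
def toy_tokenize(texts, max_length):
    """Hypothetical whitespace 'tokenizer': one fake ID per word, truncated, no padding."""
    input_ids = [[len(word) for word in text.split()][:max_length] for text in texts]
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def toy_preprocess_fn(batch, max_length=4):
    """Same flow as preprocess_fn: tokenize sources, tokenize targets, attach labels."""
    model_inputs = toy_tokenize(batch["sentence_SRC"], max_length)
    labels = toy_tokenize(batch["sentence_TGT"], max_length)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# With batched=True, datasets passes a dict of lists rather than a single row
batch = {
    "sentence_SRC": ["a short source sentence here", "hi"],
    "sentence_TGT": ["target one", "target number two"],
}
out = toy_preprocess_fn(batch)
print(out["input_ids"])  # [[1, 5, 6, 8], [2]]  (first row truncated to 4 tokens)
print(out["labels"])     # [[6, 3], [6, 6, 3]]
```

Sequence lengths differ across the batch on purpose; padding happens later, per batch, in the data collator.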

---

### Notes
- Ensure the `processor` and `tokenizer` are correctly initialized for your language pair before using these functions.
- Missing source/target files or mismatched line counts will cause errors during dataset loading.



In [4]:
def load_and_process_translation_dataset(
    data_dir,
    split="train",
    tokenizer=None,
    processor=None,
    src_lang_list=None,
    tgt_lang_list=None,
    num_proc=None,
    seed=42
):
    complete_dataset = {
        "sentence_SRC": [],
        "sentence_TGT": [],
    }

    for src_lang in src_lang_list:
        for tgt_lang in tgt_lang_list:
            if src_lang == tgt_lang:
                continue
            src_path = os.path.join(
                data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{src_lang}"
            )
            tgt_path = os.path.join(
                data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{tgt_lang}"
            )
            if not os.path.exists(src_path) or not os.path.exists(tgt_path):
                raise FileNotFoundError(
                    f"Source ({split}.{src_lang}) or Target ({split}.{tgt_lang}) file not found in {data_dir}"
                )
            with open(src_path, encoding="utf-8") as src_file, open(
                tgt_path, encoding="utf-8"
            ) as tgt_file:
                src_lines = src_file.readlines()
                tgt_lines = tgt_file.readlines()

            # Ensure both files have the same number of lines
            assert len(src_lines) == len(
                tgt_lines
            ), f"Source and Target files have different number of lines for {split}.{src_lang} and {split}.{tgt_lang}"

            complete_dataset["sentence_SRC"] += processor.preprocess_batch(
                src_lines, src_lang=src_lang, tgt_lang=tgt_lang, is_target=False
            )

            complete_dataset["sentence_TGT"] += processor.preprocess_batch(
                tgt_lines, src_lang=tgt_lang, tgt_lang=src_lang, is_target=True
            )

    complete_dataset = Dataset.from_dict(complete_dataset).shuffle(seed=seed)

    return complete_dataset.map(
        lambda example: preprocess_fn(
            example,
            tokenizer=tokenizer
        ),
        batched=True,
        num_proc=num_proc,
    )


def compute_metrics_factory(
    tokenizer, metric_dict=None, print_samples=False, n_samples=10
):
    def compute_metrics(eval_preds):
        preds, labels = eval_preds

        labels[labels == -100] = tokenizer.pad_token_id
        preds[preds == -100] = tokenizer.pad_token_id

        with tokenizer.as_target_tokenizer():
            preds = [
                x.strip()
                for x in tokenizer.batch_decode(
                    preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]
            labels = [
                x.strip()
                for x in tokenizer.batch_decode(
                    labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            ]

        assert len(preds) == len(
            labels
        ), "Predictions and Labels have different lengths"

        if print_samples:
            # Sample at most n_samples rows so small evaluation sets do not raise an error
            df = pd.DataFrame({"Predictions": preds, "References": labels}).sample(
                n=min(n_samples, len(preds))
            )
            for pred, label in zip(df["Predictions"].values, df["References"].values):
                print(f" | > Prediction: {pred}")
                print(f" | > Reference: {label}\n")

        return {
            metric_name: metric.corpus_score(preds, [labels]).score
            for (metric_name, metric) in metric_dict.items()
        }

    return compute_metrics


def preprocess_fn(example, tokenizer, **kwargs):
    model_inputs = tokenizer(
        example["sentence_SRC"], truncation=True, padding=False, max_length=256
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["sentence_TGT"], truncation=True, padding=False, max_length=256
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

## Mount Google Drive

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Configuration
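Several settings in the `Args` class interact. `max_steps = 100` overrides `num_train_epochs` (the training log confirms this), and a non-zero `warmup_steps` takes precedence over `warmup_ratio`. With the `inverse_sqrt` scheduler the learning rate rises linearly during warmup, so with `warmup_steps = 4000` a 100-step run never reaches the peak learning rate. The sketch below computes the schedule using our reading of the `inverse_sqrt` schedule in `transformers` (linear warmup, then decay proportional to 1/sqrt(step), with the timescale defaulting to the warmup length); treat the exact formula as an assumption and check it against your installed version.

```python
import math


def inverse_sqrt_lr(step, base_lr, warmup_steps, timescale=None):
    """Learning rate at `step`: linear warmup, then inverse-square-root decay.

    Assumption: mirrors transformers' inverse_sqrt schedule, where the timescale
    defaults to warmup_steps (or 10_000 when warmup_steps is 0).
    """
    if timescale is None:
        timescale = warmup_steps or 10_000
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    shift = timescale - warmup_steps
    return base_lr / math.sqrt((step + shift) / timescale)


base_lr, warmup = 5e-4, 4000
for step in (100, 4000, 16000):
    print(step, inverse_sqrt_lr(step, base_lr, warmup))
# Step 100 (the last step when max_steps=100) gives about 1.25e-05, i.e. 100/4000 of the peak;
# the LR peaks at 5e-4 at step 4000 and halves by step 16000.
```

For short runs like this one, a much smaller `warmup_steps` may be more appropriate.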


In [10]:
class Args:
    def __init__(self):
        self.model = "ai4bharat/indictrans2-en-indic-dist-200M"
        self.src_lang_list = "eng_Latn"  # comma-separated source language codes
        self.tgt_lang_list = "hin_Deva"  # comma-separated target language codes
        self.data_dir = "/content/drive/MyDrive/en-indic-exp"
        self.output_dir = "/content/drive/MyDrive/output"
        self.save_steps = 100
        self.eval_steps = 100
        self.batch_size = 32
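        self.num_train_epochs = 100  # ignored: max_steps takes precedence
        self.max_steps = 100
        self.grad_accum_steps = 4
        self.warmup_steps = 4000  # longer than max_steps, so the LR never leaves warmup
        self.warmup_ratio = 0.0  # ignored when warmup_steps > 0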
        self.max_grad_norm = 1.0
        self.learning_rate = 5e-4
        self.weight_decay = 0.0
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.98
        self.dropout = 0.0
        self.print_samples = True
        self.optimizer = "adamw_torch"
        self.lr_scheduler = "inverse_sqrt"
        self.label_smoothing = 0.0
        self.num_workers = 8
        self.metric_for_best_model = "eval_loss"
        self.greater_is_better = False  # lower eval_loss is better
        self.lora_target_modules = "q_proj,k_proj"
        self.lora_dropout = 0.1
        self.lora_r = 16
        self.lora_alpha = 32
        self.report_to = "none"
        self.patience = 5
        self.threshold = 1e-3

args = Args()


## Main Training and Evaluation Workflow

The `main` function orchestrates the training and evaluation of the IndicTrans2 translation model using the provided arguments and configurations.

---

### Function: `main(args)`

**Purpose**:  
To load the model, tokenizer, datasets, and metrics, and to set up and train a sequence-to-sequence model with parameter-efficient fine-tuning (LoRA).

---

### Workflow

#### 1. **Loading Model and Tokenizer**
- Loads the pre-trained IndicTrans2 model specified by `args.model`.
- Initializes the tokenizer for text preprocessing.
- Prepares an `IndicProcessor` for language-specific preprocessing before tokenization.

#### 2. **Data Collation**
- Creates an `IndicDataCollator` that pads each batch to its longest sequence (rounded up to a multiple of 8 for fp16 efficiency) and pads labels with `-100` so the loss ignores them.
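The collator settings can be read as a simple padding rule: pad to the longest sequence in the batch, round that length up to a multiple of 8, and fill label positions with `-100`. The sketch below implements only that rule on plain lists. It is not `IndicDataCollator`'s code (which also builds attention masks and follows the tokenizer's padding side); it right-pads for simplicity, and `PAD_ID = 1` is hypothetical.

```python
def pad_batch(seqs, pad_value, multiple_of=8):
    """Pad every sequence to the batch's longest length, rounded up to a multiple."""
    longest = max(len(s) for s in seqs)
    target = ((longest + multiple_of - 1) // multiple_of) * multiple_of
    return [s + [pad_value] * (target - len(s)) for s in seqs]


PAD_ID = 1  # hypothetical pad token ID
input_ids = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15, 16, 17]]
labels = [[21, 22, 2], [23, 2]]

padded_inputs = pad_batch(input_ids, PAD_ID)  # longest = 10 -> padded to 16
padded_labels = pad_batch(labels, -100)       # longest = 3  -> padded to 8
print([len(s) for s in padded_inputs])  # [16, 16]
print(padded_labels[1])  # [23, 2, -100, -100, -100, -100, -100, -100]
```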

#### 3. **Dataset Preparation**
- Loads and preprocesses the training and evaluation datasets using the `load_and_process_translation_dataset` function.
- Handles missing data directory or files by raising errors.

#### 4. **LoRA Configuration**
- Configures LoRA fine-tuning with the specified parameters:
  - `lora_r`, `lora_alpha`, and `lora_dropout` control the low-rank adaptation.
  - `target_modules` specifies which modules in the model to fine-tune.
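The number of trainable LoRA parameters follows from simple arithmetic: an adapter on a `d_in x d_out` linear layer adds a down-projection `A` (`r x d_in`) and an up-projection `B` (`d_out x r`), and the update `B A` is scaled by `lora_alpha / r`. The sketch below assumes `d_model = 512` with 18 encoder and 18 decoder layers, and that `q_proj`/`k_proj` match encoder self-attention plus decoder self- and cross-attention. These dimensions are not stated in this notebook; they are our reconstruction, which happens to reproduce the 1,769,472 trainable parameters printed by `print_trainable_parameters()` in the training log. Check `model.config` for the actual values.

```python
def lora_params_per_linear(d_in, d_out, r):
    """A LoRA adapter adds A (r x d_in) and B (d_out x r) beside a frozen Linear layer."""
    return r * d_in + d_out * r


# Assumed dimensions (not stated in the notebook): d_model = 512, 18 encoder + 18 decoder layers
D_MODEL, N_ENC, N_DEC, R, ALPHA = 512, 18, 18, 16, 32

# Assumed matches for target_modules="q_proj,k_proj":
#   encoder self-attention q_proj, k_proj                -> 2 per encoder layer
#   decoder self-attention + cross-attention (2 + 2)     -> 4 per decoder layer
n_modules = 2 * N_ENC + 4 * N_DEC
trainable = n_modules * lora_params_per_linear(D_MODEL, D_MODEL, R)
print(n_modules, trainable)  # 108 1769472

# The low-rank update is scaled by lora_alpha / r
print(ALPHA / R)  # 2.0
```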

#### 5. **Metrics Setup**
- Initializes a metric computation factory using BLEU and chrF metrics to evaluate translation performance.

#### 6. **Training Arguments**
- Sets up training configurations using `Seq2SeqTrainingArguments`, including:
  - Hyperparameters (learning rate, batch size, warmup steps).
  - Optimization settings (`adam_beta1`, `adam_beta2`, weight decay).
  - Evaluation and logging frequency.
  - Mixed precision (`fp16=True`) for faster training; this requires a CUDA GPU runtime.
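Two of these settings are worth checking together. Gradient accumulation multiplies the per-device batch, so each optimizer step here sees 32 * 4 = 128 sequences on a single GPU. And because `load_best_model_at_end=True`, checkpoints have to coincide with evaluations; as we understand `transformers`, step-based runs then require `save_steps` to be a multiple of `eval_steps`. The helpers below are hypothetical sanity checks, not part of any library.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, n_devices=1):
    """Sequences contributing to each optimizer step (single Colab GPU assumed by default)."""
    return per_device_batch * grad_accum_steps * n_devices


def check_best_model_cadence(save_steps, eval_steps):
    """With load_best_model_at_end=True and step-based strategies, every checkpoint
    should land on an evaluation step, i.e. save_steps is a multiple of eval_steps."""
    return save_steps % eval_steps == 0


print(effective_batch_size(32, 4))         # 128
print(check_best_model_cadence(100, 100))  # True
print(check_best_model_cadence(150, 100))  # False: per our reading, TrainingArguments rejects this
```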

#### 7. **Trainer Initialization**
- Creates a `Seq2SeqTrainer` instance with:
  - Model, arguments, data collator, datasets, and metrics.
  - An `EarlyStoppingCallback` that stops training when `metric_for_best_model` fails to improve by more than `args.threshold` for `args.patience` consecutive evaluations.
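The early-stopping rule can be sketched as a loop over evaluation results: an evaluation counts as an improvement only if it beats the best value so far, in the direction set by `greater_is_better`, by more than `threshold`, and training stops after `patience` evaluations without improvement. This is a simplified re-implementation for illustration, not `EarlyStoppingCallback`'s source. With `metric_for_best_model="eval_loss"`, lower values are better, so `greater_is_better` should be `False`; the example shows a steadily improving loss being cut off when the direction is reversed.

```python
def early_stopping_step(metric_history, patience, threshold, greater_is_better):
    """Return the evaluation index at which training would stop, or None if it never stops."""
    best = None
    bad_evals = 0
    for i, value in enumerate(metric_history):
        if best is None:
            improved = True
        elif greater_is_better:
            improved = value > best and value - best > threshold
        else:
            improved = value < best and best - value > threshold
        if improved:
            best, bad_evals = value, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None


eval_losses = [2.10, 1.95, 1.91, 1.88, 1.86, 1.85, 1.84]  # steadily improving
print(early_stopping_step(eval_losses, patience=5, threshold=1e-3, greater_is_better=False))  # None
print(early_stopping_step(eval_losses, patience=5, threshold=1e-3, greater_is_better=True))   # 5
```

With the direction reversed, the first evaluation stays "best", so `load_best_model_at_end` would likely restore the earliest checkpoint as well.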

#### 8. **Training and Saving**
- Trains the model and saves the LoRA adapter weights to `args.output_dir`.

---

### Key Features
- **Fine-tuning with LoRA**: Efficient parameter tuning for sequence-to-sequence models.
- **Dynamic Dataset Loading**: Supports multiple language pairs with validation.
- **Metrics for Evaluation**: BLEU and chrF metrics for assessing translation quality.
- **Early Stopping**: Ends training once the monitored metric stops improving, which limits overfitting and saves compute.

---

### Notes
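- Interrupting training with `Ctrl+C` (or by stopping the cell in Colab) is caught, and the LoRA adapter weights at that point are still saved to `args.output_dir`; checkpoints written every `save_steps` steps remain there as well.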
- The model saves only the LoRA adapter weights, making the saved model lightweight and portable.



In [11]:
def main(args):
    print(f" | > Loading {args.model} and tokenizer ...")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        attn_implementation="eager",
        dropout=args.dropout
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    processor = IndicProcessor(inference=False) # pre-process before tokenization

    data_collator = IndicDataCollator(
        tokenizer=tokenizer,
        model=model,
        padding="longest", # saves padding tokens
        pad_to_multiple_of=8, # better to have it as 8 when using fp16
        label_pad_token_id=-100
    )

    if args.data_dir is not None:
        train_dataset = load_and_process_translation_dataset(
            args.data_dir,
            split="train",
            tokenizer=tokenizer,
            processor=processor,
            src_lang_list=args.src_lang_list.split(","),
            tgt_lang_list=args.tgt_lang_list.split(","),
        )
        print(f" | > Loaded train dataset from {args.data_dir}. Size: {len(train_dataset)} ...")

        eval_dataset = load_and_process_translation_dataset(
            args.data_dir,
            split="dev",
            tokenizer=tokenizer,
            processor=processor,
            src_lang_list=args.src_lang_list.split(","),
            tgt_lang_list=args.tgt_lang_list.split(","),
        )
        print(f" | > Loaded eval dataset from {args.data_dir}. Size: {len(eval_dataset)} ...")
    else:
        raise ValueError(" | > Data directory not provided")

    lora_config = LoraConfig(
        r=args.lora_r,
        bias="none",
        inference_mode=False,
        task_type="SEQ_2_SEQ_LM",
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.lora_target_modules.split(","),
    )

    model.set_label_smoothing(args.label_smoothing)

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    print(f" | > Loading metrics factory with BLEU and chrF ...")
    seq2seq_compute_metrics = compute_metrics_factory(
        tokenizer=tokenizer,
        print_samples=args.print_samples,
        metric_dict={"BLEU": bleu_metric, "chrF": chrf_metric},
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        do_train=True,
        do_eval=True,
        fp16=True, # use fp16 for faster training
        logging_strategy="steps",
        eval_strategy="steps", # called `evaluation_strategy` in older transformers releases
        save_strategy="steps",
        logging_steps=5,
        save_total_limit=1,
        predict_with_generate=True,
        load_best_model_at_end=True,
        max_steps=args.max_steps, # max_steps overrides num_train_epochs
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        eval_accumulation_steps=args.grad_accum_steps,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        max_grad_norm=args.max_grad_norm,
        optim=args.optimizer,
        lr_scheduler_type=args.lr_scheduler,
        warmup_ratio=args.warmup_ratio,
        warmup_steps=args.warmup_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        dataloader_num_workers=args.num_workers,
        metric_for_best_model=args.metric_for_best_model,
        greater_is_better=args.greater_is_better,
        report_to=args.report_to,
        generation_max_length=256,
        generation_num_beams=5,
        sortish_sampler=True,
        group_by_length=True,
        include_tokens_per_second=True,
        include_num_input_tokens_seen=True,
        dataloader_prefetch_factor=2,
    )

    # Create Trainer instance
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=seq2seq_compute_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=args.patience,
                early_stopping_threshold=args.threshold,
            )
        ],
    )

    print(f" | > Starting training ...")

    try:
        trainer.train()
    except KeyboardInterrupt:
        print(f" | > Training interrupted ...")

    # this will only save the LoRA adapter weights
    model.save_pretrained(args.output_dir)

In [12]:
if __name__ == "__main__":

    main(args)

 | > Loading ai4bharat/indictrans2-en-indic-dist-200M and tokenizer ...



 | > Loaded train dataset from /content/drive/MyDrive/en-indic-exp. Size: 3 ...



max_steps is given, it will override any value given in num_train_epochs


 | > Loaded eval dataset from /content/drive/MyDrive/en-indic-exp. Size: 520 ...
trainable params: 1,769,472 || all params: 276,354,048 || trainable%: 0.6403
 | > Loading metrics factory with BLEU and chrF ...
 | > Starting training ...


| Step | Training Loss | Validation Loss | BLEU | chrF | Input Tokens Seen |
|---:|---:|---:|---:|---:|---:|
| 100 | 1.7775 | 1.913523 | 24.456055 | 52.107819 | 16800 |


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


 | > Prediction: लोगों को प्रेरित करते हुए बलवीर सिंह ने कहा कि अगर कोई अधिकारी या कर्मचारी किसी काम के लिए रिश्वत मांगता है तो सतर्कता विभाग को शिकायत भेजी जा सकती है ।
 | > Reference: लोगों को प्रेरित करते हुए कहा कि यदि कोई अधिकारी या कर्मचारी किसी भी काम के लिए रिश्वत की मांग करता है तो उसकी शिकायत विजिलेंस विभाग से की जा सकती है ।

 | > Prediction: कुछ टीवी द्वारा बनाए गए हैं और ज्यादातर राजनीतिक पूर्वाग्रह के साथ पक्षपाती हैं ।
 | > Reference: कुछ टीवी के गढ़े हुए और च्यादातर राजनीतिक पक्षपात से रंगे हुए ।

 | > Prediction: वे सभी नवीनीकरण और वैध ड्राइविंग लाइसेंस ( डी. एल. ) तैयार हैं जहां 30 सितंबर तक तस्वीरें जमा की गई हैं, बाकी अगले सप्ताह तक प्रदान कर दी जाएंगी ।
 | > Reference: रिन्युअल व पक्के ड्राइविंग लाइसेंस ( डीएल ) जिनकी फोटो 30 सितंबर तक हो चुकी है वह तैयार हो चुके हैं, बाकी लाइसेंस अगले सप्ताह मिलेंगे ।

 | > Prediction: इस वजह से तेज रफ्तार टैक्सी राजमार्ग के किनारे एक नीलगिरी के पेड़ से टकरा गई ।
 | > Reference: इससे तेज रफ्तार टैक्सी राजमार्ग के किनारे सफेदे के प