## Inference with the Sarvam Translate Model on the FLORES+ Dataset

In [None]:
import os
import json
import torch
import warnings
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
# from evaluate import load


In [None]:
# Notebook-wide setup: silence every warning, pull variables from a local .env
# file into the process environment, then authenticate to the Hugging Face Hub
# with the HUGGINGFACE_TOKEN value.
warnings.simplefilter("ignore")
load_dotenv()
login(token=os.environ.get("HUGGINGFACE_TOKEN"))

In [None]:
# FLORES+ language codes evaluated in this notebook. Dogri ('doi_Deva') is the
# only language from `code_lang_map` left out, because FLORES+ has no data for it.
In_langs = [
    'asm_Beng', 'ben_Beng', 'brx_Deva', 'eng_Latn', 'gom_Deva',
    'guj_Gujr', 'hin_Deva', 'kan_Knda', 'kas_Arab', 'kas_Deva',
    'mai_Deva', 'mal_Mlym', 'mar_Deva', 'mni_Beng', 'mni_Mtei',
    'npi_Deva', 'ory_Orya', 'pan_Guru', 'san_Deva', 'sat_Olck',
    'snd_Arab', 'snd_Deva', 'tam_Taml', 'tel_Telu', 'urd_Arab',
]


In [None]:
# FLORES+ language code (e.g. 'hin_Deva') -> human-readable language name.
# The name is interpolated verbatim into the translation system prompt
# ("Translate the text below from {src} to {tgt}."), so it must be spelled
# correctly: the Kashmiri and Sindhi Devanagari entries previously read
# "Kahmiri Devangiri" / "Sindhi Devangari".
code_lang_map = {
    'asm_Beng': 'Assamese',
    'ben_Beng': 'Bengali',
    'brx_Deva': 'Bodo',
    'doi_Deva': 'Dogri',
    'eng_Latn': 'English',
    'gom_Deva': 'Konkani',
    'guj_Gujr': 'Gujarati',
    'hin_Deva': 'Hindi',
    'kan_Knda': 'Kannada',
    'kas_Arab': 'Kashmiri Arabic Script',
    'kas_Deva': 'Kashmiri Devanagari Script',
    'mai_Deva': 'Maithili',
    'mal_Mlym': 'Malayalam',
    'mar_Deva': 'Marathi',
    'mni_Beng': 'Meiteilon Bengali Script',
    'mni_Mtei': 'Meiteilon Script',
    'npi_Deva': 'Nepali',
    'ory_Orya': 'Odia',
    'pan_Guru': 'Punjabi',
    'san_Deva': 'Sanskrit',
    'sat_Olck': 'Santali',
    'snd_Arab': 'Sindhi Arabic Script',
    'snd_Deva': 'Sindhi Devanagari Script',
    'tam_Taml': 'Tamil',
    'tel_Telu': 'Telugu',
    'urd_Arab': 'Urdu',
}

In [None]:
# Inverse lookup: human-readable language name -> FLORES+ language code.
# Derived from `code_lang_map` instead of being maintained by hand, so the two
# tables can never disagree (every name in `code_lang_map` is unique, so the
# inversion loses nothing).
lang_code_map = {name: code for code, name in code_lang_map.items()}
    

In [None]:
# FLORES+ language code -> that language's `dev` split; filled by the loading loop.
ds = dict()

In [7]:
# Download the FLORES+ `dev` split for every evaluated language, keyed by its code.
for lang_code in In_langs:
    ds[lang_code] = load_dataset("openlanguagedata/flores_plus", lang_code, split='dev')

In [None]:
# Inspect the loaded datasets: a dict of language code -> FLORES+ `dev` split.
ds

In [None]:
# Peek at the first five English source sentences (the `text` column).
ds['eng_Latn']['text'][:5]

In [None]:
# Hugging Face Hub repository id of the translation model evaluated here.
model_name: str = "sarvamai/sarvam-translate"

In [None]:
# Fetch the tokenizer and causal-LM weights for `model_name` from the Hub,
# then place the model on the first CUDA device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to('cuda:0')

In [None]:
# Translation store: results[src_code][f"{src_code}-{tgt_code}"] -> list of
# translated sentences (filled by the translation loops and saved as JSON).
results = dict()

In [None]:
# Resume from a previous run by loading any translations already saved to disk.
# A missing file simply means a fresh run. Any other I/O or parse failure is
# reported with its real cause instead of being mislabelled as "does not exist",
# and `results` keeps its current value. Note that the translation cells open this
# same path in "w" mode, so an unreadable file will be overwritten by the next save.
# A bare `except:` is avoided so KeyboardInterrupt/SystemExit are not swallowed.
try:
    with open("translations_sarvam_flores.json", "r", encoding="utf-8") as f:
        results = json.load(f)
except FileNotFoundError:
    print("File does not exist")
except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
    print(f"Could not load translations_sarvam_flores.json ({type(e).__name__}: {e}); keeping current results")

In [None]:
def translate(sentences, src_lang, tgt_lang, batch_size=20):
    """Translate ``sentences`` from ``src_lang`` to ``tgt_lang`` with the global ``model``.

    Each sentence is wrapped in a chat prompt (a system message naming the
    language pair, then the sentence as the user turn), rendered with the
    tokenizer's chat template, and decoded greedily in batches.

    Args:
        sentences: Sliceable sequence of source sentences (e.g. a FLORES+ ``text`` column).
        src_lang: Human-readable source-language name inserted into the prompt.
        tgt_lang: Human-readable target-language name inserted into the prompt.
        batch_size: Number of sentences per ``model.generate`` call.

    Returns:
        list[str]: One whitespace-stripped translation per input sentence, in order.
    """
    model.eval()
    translations = []
    print("**"*20, end=" ")
    print(f"Translating from {src_lang} to {tgt_lang}", end=" ")
    print("**"*20)
    with torch.inference_mode():
        for i in range(0, len(sentences), batch_size):
            batch = sentences[i:i+batch_size]

            # Build one chat conversation per sentence in the batch.
            messages = [
                [
                    {"role": "system", "content": f"Translate the text below from {src_lang} to {tgt_lang}."},
                    {"role": "user", "content": sent}
                ]
                for sent in batch
            ]

            # Render each conversation to a prompt string ending in the generation cue.
            inputs = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]

            # Decoder-only models must be LEFT-padded for batched generation: with
            # right padding the pad tokens sit between a prompt and its first
            # generated token, which corrupts the output of shorter prompts.
            # Pad only to the longest prompt in the batch instead of a fixed
            # 1024 tokens, which wasted memory and compute on every batch.
            # NOTE: truncation keeps the default (right) side, so a prompt longer
            # than 1024 tokens would lose its generation cue; FLORES sentences are short.
            model_inputs = tokenizer(
                inputs,
                return_tensors="pt",
                padding="longest",
                padding_side="left",
                max_length=1024,
                truncation=True,
            ).to(model.device)

            # Greedy decoding, one sequence per prompt.
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=1024,
                do_sample=False,
                num_return_sequences=1
            )

            # All rows share the same padded prompt length, so new tokens start there.
            prompt_len = model_inputs["input_ids"].shape[1]
            for gen in generated_ids:
                output_text = tokenizer.decode(gen[prompt_len:].tolist(), skip_special_tokens=True)
                translations.append(output_text.strip())

            print(f"Processed {src_lang}-{tgt_lang} {i + len(batch)}/{len(sentences)} sentences", end='\r')

            # Release this batch's tensors before the next one to cap GPU memory.
            del model_inputs, generated_ids
            torch.cuda.empty_cache()

    # Terminate the '\r' progress line so later output does not overwrite it.
    print()
    return translations

This model does not support translation between two Indian languages; only Indian-to-English and English-to-Indian translation is possible.

## English to Indian Language Translation

In [None]:
# English -> every language in In_langs. Each finished pair is stored under
# results['eng_Latn']['eng_Latn-<tgt>'], and the whole dict is re-written to
# disk right away. Pairs that already hold a non-empty result (e.g. loaded from
# an earlier run) are skipped, so an interrupted run can simply be resumed.
# NOTE(review): In_langs contains 'eng_Latn', so this also runs an
# English->English pass; confirm that is intended.
for src_lang in ['eng_Latn']:
    per_source = results.setdefault(src_lang, {})
    for tgt_lang in In_langs:
        pair_key = f'{src_lang}-{tgt_lang}'
        if per_source.get(pair_key):
            print(f"Skipping {src_lang} to {tgt_lang}, already exists.")
            continue
        per_source[pair_key] = translate(
            ds[src_lang]['text'],
            code_lang_map[src_lang],
            code_lang_map[tgt_lang],
        )

        # Checkpoint after every completed translation pair.
        try:
            with open("translations_sarvam_flores.json", "w", encoding="utf-8") as fh:
                json.dump(results, fh, indent=4, ensure_ascii=False)
                print(f"File saved as of {tgt_lang}")
        except IOError as err:
            print(f"Error saving file: {err}")

## Indian Languages to English Translation

In [None]:
# Every language in In_langs -> English. Results land in
# results[<src>]['<src>-eng_Latn'], and the whole dict is re-saved after each
# pair. Pairs that already hold a non-empty result are skipped, so the cell can
# be re-run to resume.
for src_lang in In_langs:
    per_source = results.setdefault(src_lang, {})
    for tgt_lang in ['eng_Latn']:
        pair_key = f'{src_lang}-{tgt_lang}'
        if per_source.get(pair_key):
            print(f"Skipping {src_lang} to {tgt_lang}, already exists.")
            continue
        per_source[pair_key] = translate(
            ds[src_lang]['text'],
            code_lang_map[src_lang],
            code_lang_map[tgt_lang],
        )

        # Checkpoint after every completed translation pair.
        try:
            with open("translations_sarvam_flores.json", "w", encoding="utf-8") as fh:
                json.dump(results, fh, indent=4, ensure_ascii=False)
                print(f"File saved as of {tgt_lang}")
        except IOError as err:
            print(f"Error saving file: {err}")

In [None]:
# List the translation pairs stored under the English source; the translation
# loops write these keys as f"{src_code}-{tgt_code}".
results['eng_Latn'].keys()

## Evaluation with chrF++