<a href="https://colab.research.google.com/github/daKeshra/Extending-AfriXNLI/blob/main/Languge_translation_for_AfriXNLI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U transformers

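## Local inference (GPU if available, otherwise CPU)
Model page: https://huggingface.co/facebook/nllb-200-distilled-1.3B (this 1.3B checkpoint is used by the quick `pipeline` cell below; the batch translation further down uses the smaller `facebook/nllb-200-distilled-600M` checkpoint set in `model_name`).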

⚠️ If the generated code snippets do not work, please open an issue on the [model repo](https://huggingface.co/facebook/nllb-200-distilled-1.3B) and/or on [huggingface.js](https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries-snippets.ts) 🙏

In [None]:
!pip install transformers sentencepiece sacremoses
!pip install torch

In [None]:
# Use a pipeline as a high-level helper for quick one-off translations, e.g. pipe("Hello").
# NOTE: this loads the 1.3B checkpoint, whereas the batch translation below uses the
# checkpoint in `model_name`; `pipe` itself is not used by the batch workflow.
from transformers import pipeline

# NLLB translation needs source and target language codes; setting them here makes
# `pipe(text)` callable without passing them on every call.
pipe = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-1.3B",
    src_lang="eng_Latn",
    tgt_lang="hau_Latn",
)

In [None]:
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from tqdm import tqdm
import os

# Checkpoint used for the batch translation below.
model_name = "facebook/nllb-200-distilled-600M"
# NLLB language codes: English source, target language for the translated columns.
src_lang = "eng_Latn"
tgt_lang = "hau_Latn"

# English dev/test splits; process_file reads their promptID, sentence1 and sentence2 columns.
dev_file = "https://raw.githubusercontent.com/Tonative-Research/Extending-AfriXNLI/refs/heads/main/dev/xnli.dev%20-%20eng_dataset.csv"
test_file = "https://raw.githubusercontent.com/Tonative-Research/Extending-AfriXNLI/refs/heads/main/test/xnli.test%20-%20english.csv"

# Output folders relative to the working directory. exist_ok makes the cell safe to re-run.
# (An absolute "/dev/..." path would point at the system device directory, not this folder.)
dev_dir = "dev"
test_dir = "test"
os.makedirs(dev_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

sentence1_dev = os.path.join(dev_dir, "sentence1_translated.csv")
sentence2_dev = os.path.join(dev_dir, "sentence2_translated.csv")

sentence1_test = os.path.join(test_dir, "sentence1_translated.csv")
sentence2_test = os.path.join(test_dir, "sentence2_translated.csv")

In [None]:
# Preview the dev split and fail fast if the columns process_file relies on are missing.
df = pd.read_csv(dev_file, sep=",")
missing_cols = {"promptID", "sentence1", "sentence2"} - set(df.columns)
assert not missing_cols, f"dev file is missing expected columns: {sorted(missing_cols)}"
df.head()

## Load model

In [None]:
# Is a CUDA GPU visible to PyTorch? If this shows False, the model below runs on the CPU.
bool(torch.cuda.is_available())

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model...")
# The tokenizer is configured with the source language (src_lang) of the inputs it will encode.
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=src_lang)
# Load the seq2seq weights and place them on the selected device in one step.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)


In [None]:
def translate_batch(texts, batch_size=8, target_lang=None, max_length=512):
    """Translate a sequence of sentences with the notebook's loaded tokenizer/model.

    Args:
        texts: Iterable of source sentences (list, pandas Series, ...). Missing values
            (NaN/None) are sent as empty strings so the output stays aligned with the input.
        batch_size: Number of sentences tokenized and generated per step.
        target_lang: Language code forced as the first generated token. Defaults to
            the notebook-level ``tgt_lang``.
        max_length: Token limit used both for input truncation and for generation.

    Returns:
        List of translated strings, one per input, in input order.

    Raises:
        ValueError: If the target language code is not a token in the tokenizer's vocabulary.
    """
    # Materialize as a list so slicing below is positional even for an indexed Series.
    texts = list(texts)
    lang = tgt_lang if target_lang is None else target_lang

    # Look up the target-language token once via the public API (not per batch).
    forced_bos_id = tokenizer.convert_tokens_to_ids(lang)
    if forced_bos_id is None or forced_bos_id == tokenizer.unk_token_id:
        # An unknown code would otherwise silently map to <unk> and produce garbage output.
        raise ValueError(f"Unknown target language code for this tokenizer: {lang!r}")

    translations = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Translating"):
        batch = [str(t) if pd.notna(t) else "" for t in texts[start:start + batch_size]]
        inputs = tokenizer(
            batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length
        ).to(device)
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_id,
            max_length=max_length,
        )
        translations.extend(tokenizer.batch_decode(translated_tokens, skip_special_tokens=True))
    return translations

In [None]:
def _translate_unique_column(df, column, batch_size):
    """Return the unique (promptID, column) rows of df plus a '<column>_translated' column."""
    # Deduplicate before translating so repeated sentences go through the model only once.
    # The result has the same rows and order as translating every row and deduplicating after.
    unique_rows = df[["promptID", column]].drop_duplicates().reset_index(drop=True)
    translated = translate_batch(unique_rows[column].tolist(), batch_size=batch_size)
    return unique_rows.assign(**{f"{column}_translated": translated})


def process_file(input_path, output_path1, output_path2, batch_size=8):
    """Translate the sentence1 and sentence2 columns of a CSV and save each to its own CSV.

    Args:
        input_path: Path or URL of a CSV with promptID, sentence1 and sentence2 columns.
        output_path1: Destination CSV for [promptID, sentence1, sentence1_translated].
        output_path2: Destination CSV for [promptID, sentence2, sentence2_translated].
        batch_size: Batch size passed to translate_batch.

    Each output contains one row per distinct (promptID, sentence) pair, in order of first
    appearance. Parent directories of the output paths are created if needed.
    """
    print(f"Processing {input_path}...")
    df = pd.read_csv(input_path, sep=",")

    print("Translating sentence...")
    sentence1_translated_df = _translate_unique_column(df, "sentence1", batch_size)
    sentence2_translated_df = _translate_unique_column(df, "sentence2", batch_size)

    for out_df, out_path in (
        (sentence1_translated_df, output_path1),
        (sentence2_translated_df, output_path2),
    ):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        out_df.to_csv(out_path, index=False)

    print(f"Saved translated file {output_path1} and {output_path2}")

In [None]:
# Translate the dev split into its two per-sentence output CSVs.
process_file(input_path=dev_file, output_path1=sentence1_dev, output_path2=sentence2_dev)

In [None]:
# Translate the test split into its two per-sentence output CSVs.
process_file(input_path=test_file, output_path1=sentence1_test, output_path2=sentence2_test)