diff --git a/translate_oasst.py b/translate_oasst.py
index 42a104d..73116d5 100644
--- a/translate_oasst.py
+++ b/translate_oasst.py
@@ -86,16 +86,18 @@ def batch_translate(texts, source_lang, target_lang, intermediate_lang = 'en'):
     translated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in translated_outputs]
     return translated_texts
 
-def batch_translate_madlad(texts, target_lang):
+def translate_madlad(texts, target_lang):
     # Get madlad
     model, tokenizer = model_cache['madlad']
 
-    # Add the target language to the texts
-    madlad_texts = [f'<2{target_lang}> ' + text.replace("\n", " ") for text in texts]
-    input_ids = tokenizer(madlad_texts, max_length=1024, return_tensors="pt").to(device)
-    outputs = model.generate(**input_ids, max_new_tokens=1024)
+    translated_texts = []
+    for text in texts:
+        # Add the target language to the text
+        madlad_text = f'<2{target_lang}> ' + text.replace("\n", " ")
+        input_ids = tokenizer(madlad_text, return_tensors="pt").input_ids.to(device)
+        outputs = model.generate(input_ids=input_ids, max_new_tokens=1024)
+        # Decoding outputs
+        translated_texts.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
-    # Decoding outputs
-    translated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
     return translated_texts
@@ -197,7 +199,7 @@ def main():
         if not(use_madlad):
             translated_batch = batch_translate(texts_to_translate, source_lang, target_lang)
         else:
-            translated_batch = batch_translate_madlad(texts_to_translate, target_lang)
+            translated_batch = translate_madlad(texts_to_translate, target_lang)
 
         if translated_batch is not None:
             # Combine original record with translated text
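
Note on the change: the old batched call breaks as soon as a batch contains texts of unequal length, because the Hugging Face tokenizer is asked for PyTorch tensors without padding (and max_length is ignored without truncation=True); translating one text at a time sidesteps this at some throughput cost. If batching is wanted back later, a padded variant along the following lines should also work. This is a minimal sketch under assumptions: translate_madlad_batched is a hypothetical name, and model_cache / device are taken from the surrounding script.

# Hypothetical padded-batch variant (sketch, not part of this diff).
# Assumes model_cache['madlad'] and device exist as in translate_oasst.py.
def translate_madlad_batched(texts, target_lang):
    model, tokenizer = model_cache['madlad']
    madlad_texts = [f'<2{target_lang}> ' + text.replace("\n", " ") for text in texts]
    # padding=True packs variable-length texts into one tensor;
    # truncation=True makes max_length take effect.
    inputs = tokenizer(madlad_texts, padding=True, truncation=True,
                       max_length=1024, return_tensors="pt").to(device)
    # inputs unpacks into input_ids and attention_mask, so padded
    # positions are masked out during generation.
    outputs = model.generate(**inputs, max_new_tokens=1024)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)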