Skip to content

Commit

Permalink
Proper max len
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikTromp committed Jan 8, 2024
1 parent 24301ff commit e85b713
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions translate_oasst.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,18 @@ def batch_translate(texts, source_lang, target_lang, intermediate_lang = 'en'):
translated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in translated_outputs]
return translated_texts

def translate_madlad(texts, target_lang):
    """Translate a list of texts to `target_lang` with the cached MADLAD model.

    Each text is translated individually (not as a padded batch) so every
    sequence gets the full `max_new_tokens` generation budget on its own.

    Args:
        texts: list of source-language strings to translate.
        target_lang: target language code understood by MADLAD
            (used as the '<2xx>' prefix token, e.g. 'nl' -> '<2nl>').

    Returns:
        list of translated strings, in the same order as `texts`.
    """
    # Get madlad model/tokenizer from the module-level cache
    model, tokenizer = model_cache['madlad']
    translated_texts = []
    for text in texts:
        # MADLAD selects the target language via a '<2xx>' prefix token;
        # newlines are flattened to spaces before tokenization
        # (NOTE(review): presumably because the model treats input as a
        # single sequence — confirm against the model card)
        madlad_text = f'<2{target_lang}> ' + text.replace("\n", " ")
        input_ids = tokenizer(madlad_text, return_tensors="pt").input_ids.to(device)
        outputs = model.generate(input_ids=input_ids, max_new_tokens=1024)
        # Decoding outputs: drop special tokens such as </s> from the result
        translated_texts.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return translated_texts


Expand Down Expand Up @@ -197,7 +199,7 @@ def main():
if not(use_madlad):
translated_batch = batch_translate(texts_to_translate, source_lang, target_lang)
else:
translated_batch = batch_translate_madlad(texts_to_translate, target_lang)
translated_batch = translate_madlad(texts_to_translate, target_lang)

if translated_batch is not None:
# Combine original record with translated text
Expand Down

0 comments on commit e85b713

Please sign in to comment.