In [1]:
import os
import json
import spacy

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import separate_text_and_code, extract_code_block, translate_dataset, CODE_TRANSLATION_CONSTANT

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Tokenizer and model come from the same checkpoint; the tokenizer is
# configured with English (`eng_Latn`) as the source language.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
model = model.to(device)

In [3]:
# spaCy pipeline; `batch_translate` uses it to split text into sentences.
nlp = spacy.load(name="en_core_web_sm")
# Passed as `language=` in the (commented-out) `translate_dataset` call at the
# end of the notebook — presumably the language of the source messages;
# confirm against `utils.translate_dataset`.
LANGUAGE = 'en'

In [4]:
data_path = '../../data/2023-04-12_oasst_ready.trees.jsonl'
# JSON Lines file: one JSON document per line.
# - Read explicitly as UTF-8 so the result does not depend on the platform's
#   default encoding.
# - Skip blank lines (e.g. a trailing newline at the end of the file), which
#   would otherwise make `json.loads` raise.
with open(data_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

In [5]:
# Advance `tree` to the first entry whose prompt text gives a truthy result
# from `extract_code_block`. If no entry matches, `tree` stays bound to the
# last element of `data`.
for tree in data:
    if not extract_code_block(tree['prompt']['text']):
        continue
    break

In [6]:
def batch_translate(text, device, tgt_lang="slv_Latn"):
    """Translate ``text`` sentence by sentence with the loaded seq2seq model.

    The text is split into sentences with the spaCy pipeline ``nlp``, all
    sentences are translated as one batch, and the decoded sentences are
    joined with single spaces.

    Args:
        text: Source text to translate.
        device: Torch device the tokenized inputs are moved to. It should be
            the device the global ``model`` was moved to.
        tgt_lang: Language-code token forced as the first generated token.
            Defaults to ``"slv_Latn"``, the value that used to be hard-coded.

    Returns:
        The translated text. Empty or whitespace-only input is returned
        unchanged, because there is nothing to translate and an empty
        sentence batch cannot be tokenized.
    """
    if not text.strip():
        return text

    sentences = [sent.text for sent in nlp(text).sents]
    # `padding=True` pads only to the longest sentence in this batch.
    # `padding='max_length'` padded every sentence to the tokenizer's maximum
    # length, which made generation needlessly slow and memory-hungry.
    inputs = tokenizer(
        sentences, return_tensors="pt", truncation=True, padding=True
    ).to(device)
    # Look the language token up via `convert_tokens_to_ids`; the
    # `lang_code_to_id` mapping is not available in newer transformers releases.
    translated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=1024,
    )
    out = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
    return ' '.join(out)

In [7]:
# Smoke test: translate the prompt text of the first entry in `data`.
batch_translate(text=data[0]['prompt']['text'], device=device)

'Lahko napišete kratek uvod o pomembnosti izraza "monopsony" v ekonomiji? Uporabite primere, povezane z morebitnimi monopsonijami na trgu dela, in navedite ustrezne raziskave.'

In [8]:
def translate(text):
    """Translate ``text`` while keeping its code blocks untranslated.

    ``separate_text_and_code`` splits the input into the text to translate
    and a list of code blocks. The text is translated with
    ``batch_translate`` on the global ``device``, and each code block is then
    substituted back in, in order, for one occurrence of
    ``CODE_TRANSLATION_CONSTANT``.

    Args:
        text: Message text, possibly containing code blocks.

    Returns:
        The translated text with the original code blocks restored. Text that
        is empty, whitespace-only, or purely numeric (ignoring surrounding
        whitespace) is passed through untranslated.
    """
    # extract code blocks
    text, code_blocks = separate_text_and_code(text)
    # Do not translate numbers or empty text. Strip first so that inputs such
    # as "42\n" are still recognised as numbers, and so that empty text never
    # reaches the model.
    stripped = text.strip()
    if not stripped or stripped.isdigit():
        translation = text
    else:
        translation = batch_translate(text, device)
    # place the code blocks back into the translation
    # NOTE(review): this presumably relies on the placeholder surviving
    # translation verbatim; if the model alters it, the code block is silently
    # dropped — confirm against `separate_text_and_code`.
    for code in code_blocks:
        translation = translation.replace(CODE_TRANSLATION_CONSTANT, code, 1)
    return translation

In [9]:
# Smoke test: check how emoji in the input come out of the translation pipeline.
translate(text='Hello world! 😀 😃 😄 😁 😆 😅 😂 🤣 🥲 🥹 ')

'Živjo svetu! 😀 😃 😄 😁 😆 😅     '

In [10]:
# Create the output directory for the translations; do nothing if it already exists.
os.makedirs(name='../data/nllb', exist_ok=True)

In [11]:
# Full-dataset translation run (currently disabled); uncomment to execute.
#failed_ids = translate_dataset(data, translate, translations_path='../data/nllb', language=LANGUAGE)