In [None]:
from tqdm import tqdm
import torch
import numpy as np

In [None]:
from datasets import load_dataset

# Open the OpenOrca "train" split in streaming mode; the bare name on the
# last line makes the notebook display the resulting dataset object.
train_data = load_dataset(
    path="Open-Orca/OpenOrca",
    split="train",
    streaming=True,
)
train_data

In [None]:
# Limits applied to every example before it is queued for translation.
MAX_TOTAL_WORDS = 512  # whitespace-separated words across all three fields
RESERVED_WORDS = 4  # head-room kept out of the budget (presumably for special tokens; TODO confirm)
MAX_FIELD_CHARS = 500  # per-field limit, measured in characters


def _flatten(text):
    """Remove bracket-plus-newline wrappers and turn remaining newlines into spaces."""
    return text.replace("[\n", "").replace("\n]", "").replace("\n", " ")


# Plain Python lists: list.append is amortised O(1), whereas np.append copies
# the whole array on every call (quadratic over the dataset) and produces
# fixed-width unicode arrays. Downstream cells only need len() and slicing.
promt = []
question = []
response = []

for feature in train_data:
    fields = (feature["system_prompt"], feature["question"], feature["response"])
    word_count = len(" ".join(fields).split())

    fits_word_budget = word_count + RESERVED_WORDS <= MAX_TOTAL_WORDS
    fits_char_limits = all(len(field) <= MAX_FIELD_CHARS for field in fields)

    if fits_word_budget and fits_char_limits:
        # The three lists stay index-aligned: one entry each per kept example.
        promt.append(_flatten(fields[0]))
        question.append(_flatten(fields[1]))
        response.append(_flatten(fields[2]))

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def translate(
    model, tokenizer, sentences, batch_size, tgt_lang="rus_Cyrl", max_length=512
):
    """Translate ``sentences`` into ``tgt_lang`` with a seq2seq model, batch by batch.

    Args:
        model: Seq2seq model exposing ``generate`` and ``device`` (e.g. a
            Hugging Face NLLB checkpoint). Inputs are moved to ``model.device``.
        tokenizer: Tokenizer matching ``model``; it must contain the
            ``tgt_lang`` language-code token.
        sentences: Sequence of source strings (list or NumPy array).
        batch_size: Number of sentences per ``generate`` call.
        tgt_lang: Language-code token forced as the first generated token.
        max_length: Token limit used both to truncate inputs and to cap outputs.

    Returns:
        list[str]: Decoded translations, in the same order as ``sentences``.
    """
    # Follow the model rather than hard-coding "cuda:1": that index does not
    # exist on single-GPU hosts and may differ from where the model was placed.
    device = model.device
    # `tokenizer.lang_code_to_id` is deprecated; looking the language token up
    # through the vocabulary works across transformers versions.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    ru_sent = []
    with torch.no_grad():
        for i in tqdm(range(0, len(sentences), batch_size)):
            # Tokenizers expect a list of plain `str`, not a NumPy slice.
            batch = [str(sentence) for sentence in sentences[i : i + batch_size]]

            # Pad only to the longest sentence in the batch; padding every
            # batch to `max_length` makes generation needlessly expensive.
            encoded = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)

            generated_tokens = model.generate(
                input_ids=encoded["input_ids"],
                # Pass the mask explicitly so padding is never attended to.
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                forced_bos_token_id=forced_bos_token_id,
            )

            # extend() avoids re-copying the accumulated list on every batch.
            ru_sent.extend(
                tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            )

    return ru_sent

In [None]:
checkpoint = "facebook/nllb-200-distilled-600M"

# Keep preferring the second GPU, but only when it actually exists: selecting
# "cuda:1" whenever CUDA is available fails on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()

In [None]:
# Translate the three aligned fields one after another (prompts, questions,
# responses), using the same batch size for every pass.
batch_size = 128
p_ru_sent, q_ru_sent, r_ru_sent = (
    translate(model, tokenizer, texts, batch_size)
    for texts in (promt, question, response)
)

In [None]:
def _blank_if_artifact(text):
    """Return "" when ``text`` is the lone word "отношения" (any casing), else ``text``."""
    # NOTE(review): presumably this word is a spurious translation of empty
    # input; confirm against the model's behaviour.
    is_single_word = len(text.split()) == 1
    return "" if is_single_word and text.lower() == "отношения" else text


pr = list(map(_blank_if_artifact, p_ru_sent))
que = list(map(_blank_if_artifact, q_ru_sent))
res = list(map(_blank_if_artifact, r_ru_sent))

In [None]:
# Preview the first few items of each list; echoing the complete lists would
# flood the notebook output and bloat the saved file.
pr[:5], que[:5], res[:5]

In [None]:
# The three lists must line up one-to-one; fail loudly here instead of
# raising an IndexError mid-way or silently dropping trailing items.
assert len(pr) == len(que) == len(res), "translated fields are misaligned"

res_data = [
    {"promt": p, "question": q, "response": r} for p, q, r in zip(pr, que, res)
]
# Show the record count and a short preview rather than every record.
len(res_data), res_data[:3]

In [None]:
# Spot-check the first assembled record (keys: "promt", "question", "response").
res_data[0]

In [None]:
import jsonlines

# write_all() emits one JSON object per line, which is what a .jsonl file
# should contain; write(res_data) would serialise the entire list as a single
# JSON array on one line.
with jsonlines.open("orca_traslation_2.jsonl", mode="w") as writer:
    writer.write_all(res_data)