# Загрузка и предобработка данных

In [None]:
!git lfs install
!git clone https: // huggingface.co/facebook/nllb-200-distilled-600M

In [None]:
from datasets import load_dataset

# Load the FLAN v2 files matching cot_*_train.jsonl.gz (presumably the
# chain-of-thought train shards). Downloads are cached under ./flan_v2.
train_dataset = load_dataset(
    "SirNeural/flan_v2",
    data_files="cot_*_train.jsonl.gz",
    cache_dir="flan_v2",
)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Run on the first CUDA device when one is present, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NLLB tokenizer configured for English source text.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    use_auth_token=True,
    src_lang="eng_Latn",
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    use_auth_token=True,
)
model.to(device)

In [None]:
# Smoke-test the English -> Russian translation on one short phrase.
article = "Hi cutie"
# Put the encoded inputs on the same device as the model; otherwise generate()
# fails with a device-mismatch error whenever the model was moved to CUDA.
inputs = tokenizer(article, return_tensors="pt").to(model.device)

# NLLB chooses the output language through a forced first decoder token. Look the
# language code up as an ordinary token; `lang_code_to_id` is deprecated/removed
# in newer transformers releases.
translated_tokens = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("rus_Cyrl"),
    max_length=30,
)
tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

In [None]:
# Keep only the "train" split. load_dataset() returned a DatasetDict (a dict
# subclass keyed by split name). The guard makes re-running this cell a no-op:
# without it, the already-extracted Dataset would be indexed by a column called
# 'train' and fail.
if isinstance(train_dataset, dict):
    train_dataset = train_dataset['train']

In [None]:
# Show the train split as the cell output for a quick inspection.
train_dataset

In [None]:
from tqdm.auto import tqdm

# Collect the distinct task prefixes, i.e. the text before the first ':' of each
# input, in first-seen order.
# - split(':', 1)[0] returns the whole text when there is no ':'. The old
#   find()-based slice silently dropped the last character in that case.
# - dict.fromkeys de-duplicates in O(1) per item instead of scanning a list.
# - Reading the 'inputs' column once avoids decoding every row separately.
tasks = list(dict.fromkeys(
    text.split(':', 1)[0] for text in tqdm(train_dataset['inputs'])
))
tasks[:20]

In [None]:
# Flatten the dataset into [input_0, target_0, input_1, target_1, ...]. The
# re-pairing cell further down relies on this even/odd interleaving. Reading
# the two columns once is much cheaper than decoding each full row twice via
# train_dataset[i].
sentences = []
for source_text, target_text in tqdm(
        zip(train_dataset['inputs'], train_dataset['targets']),
        total=len(train_dataset),
):
    sentences.extend((source_text, target_text))

# Запуск перевода

In [None]:
import jsonlines
from tqdm.auto import tqdm
from torch.cuda.amp import autocast


def _save_translation_checkpoint(path, samples):
    """Overwrite ``path`` with ``samples`` written as a single JSON line (one list)."""
    with jsonlines.open(path, mode='w') as writer:
        writer.write(samples)


def predict(
        model_name,
        data_,
        max_source_tokens_count=520,
        max_target_tokens_count=520,
        use_cuda=True,
        batch_size=128,
        cuda_device="cuda:1",
        tgt_lang="rus_Cyrl",
        checkpoint_path='flan_traslation_v2.jsonl',
        checkpoint_every=10,
):
    """Translate the English strings in ``data_`` with a seq2seq model, in batches.

    Args:
        model_name: model id or local path passed to ``from_pretrained``.
        data_: sliceable sequence (e.g. a list) of English (``eng_Latn``) strings.
        max_source_tokens_count: inputs are truncated to this many tokens.
        max_target_tokens_count: ``max_length`` passed to ``generate``.
        use_cuda: run on ``cuda_device`` when CUDA is available. Otherwise, or
            when CUDA is not available, run on the CPU.
        batch_size: number of strings translated per ``generate`` call.
        cuda_device: CUDA device used when ``use_cuda`` is true.
        tgt_lang: target language code, forced as the first decoder token.
        checkpoint_path: file overwritten with all translations produced so
            far, written as one JSON line holding the whole list.
        checkpoint_every: write a checkpoint every this many batches. A final
            save is always made after the last batch.

    Returns:
        list[str]: one translation per element of ``data_``, in the same order.
    """
    russian_samples = []

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True, src_lang="eng_Latn")
    # Fall back to the CPU instead of crashing when CUDA is requested but absent.
    use_gpu = use_cuda and torch.cuda.is_available()
    device = torch.device(cuda_device if use_gpu else "cpu")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_auth_token=True).to(device)
    model.eval()
    # NLLB selects the output language through a forced first decoder token.
    # Look it up as a token; `lang_code_to_id` is deprecated/removed in newer
    # transformers releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # fp16 autocast is only enabled on the GPU.
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_gpu), torch.no_grad():
        # Iterate the argument. The previous version read the global `sentences`.
        for start in tqdm(range(0, len(data_), batch_size)):
            batch = list(data_[start:start + batch_size])
            # Pad only to the longest text in the batch, and pass the attention
            # mask along so generate() ignores the padding tokens.
            encoded = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_source_tokens_count,
            ).to(device)

            output_ids = model.generate(
                **encoded,
                max_length=max_target_tokens_count,
                forced_bos_token_id=forced_bos_token_id,
            )
            russian_samples.extend(
                tokenizer.batch_decode(
                    output_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
            )

            # Periodic checkpoint so a crash mid-run does not lose everything.
            if (start // batch_size) % checkpoint_every == 0:
                _save_translation_checkpoint(checkpoint_path, russian_samples)

    # Final save, so the checkpoint file also contains the batches after the
    # last periodic checkpoint.
    if russian_samples:
        _save_translation_checkpoint(checkpoint_path, russian_samples)
    return russian_samples

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Translate every interleaved input/target string in `sentences` to Russian.
russian_samples = predict(
    model_name="facebook/nllb-200-distilled-600M",
    data_=sentences,
)

In [None]:
# Save the complete list of translations. predict() checkpoints this file only
# every few batches, so its copy may lack the last batches. The whole list is
# written as a single JSON line, which the reader cell loads with one read().
with jsonlines.open('flan_traslation_v2.jsonl', mode='w') as writer:
    writer.write(russian_samples)

In [None]:
import jsonlines

# Load the translations back. The file holds one JSON line containing the whole
# list, so a single read() returns it. `f` is therefore the flat list of
# translated strings, presumably alternating translated input / translated
# target in the same order as `sentences`.
with jsonlines.open('flan_traslation_v2.jsonl') as reader:
    f = reader.read()
f[:10]

In [None]:
from tqdm import tqdm

# Re-pair the flat translation list into records. `sentences` alternated
# input/target and predict() preserves order, so even indices hold translated
# inputs and odd indices hold translated targets.
# NOTE(review): the FLAN source column is 'targets'. The key 'target' is kept
# here as-is; confirm whether downstream code expects it.
if len(f) % 2:
    raise ValueError(
        f"expected an even number of translated strings (input/target pairs), got {len(f)}"
    )
ds = [
    {'inputs': translated_input, 'target': translated_target}
    for translated_input, translated_target in tqdm(zip(f[0::2], f[1::2]), total=len(f) // 2)
]
ds[:10]

In [None]:
# Write one JSON object per line (proper JSON Lines). writer.write(ds) would have
# emitted the whole list as a single JSON array on one line.
with jsonlines.open('flan_traslation_v22.jsonl', mode='w') as writer:
    writer.write_all(ds)

In [None]:
import json

# Write `ds` as JSON Lines: one JSON object per line.
# - The handle is named `out_file` rather than `f`, so this cell no longer
#   clobbers the global `f` that holds the loaded translations.
# - json.dumps keeps its default ensure_ascii=True, so Cyrillic text is escaped
#   as \uXXXX and the file is plain ASCII whatever encoding a reader uses.
with open('flan_traslation_v22.jsonl', 'w', encoding='utf-8') as out_file:
    out_file.writelines(json.dumps(item) + '\n' for item in ds)

In [None]:
def jsonl_reader(file_name, encoding="utf-8"):
    """Print every JSON value stored in the JSON Lines file ``file_name``.

    Each line of the file is parsed with ``jsonlines`` and printed as a Python
    object, one per output line.

    Args:
        file_name: path of the ``.jsonl`` file to read.
        encoding: text encoding of the file. It defaults to UTF-8, which JSON
            Lines mandates, rather than the platform-dependent locale default.
    """
    with open(file_name, "r", encoding=encoding) as file, jsonlines.Reader(file) as reader:
        for obj in reader:
            print(obj)

In [None]:
# Spot-check the written dataset by printing every record. This prints one line
# per example, so the output can be very long.
jsonl_reader('flan_traslation_v22.jsonl')