In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [3]:
# Base random seed for the notebook. The seeding cell further down increments
# this value and then passes it to `set_seed`.
SEED = 0

## Прикручиваем LaBSE.

In [4]:
from sentence_transformers import SentenceTransformer

In [5]:
# Hub ids of two candidate sentence-embedding checkpoints. Only
# `LaBSE_small_name` is loaded in this cell; `LaBSE_name` is defined but unused here.
LaBSE_name = 'sentence-transformers/LaBSE'
LaBSE_small_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# Two toy inputs used as a smoke test for the encoder.
sentences = [
    "This is an example sentence",
    "Each sentence is converted",
]

# Load the encoder, embed the toy inputs, and print the resulting array shape
# as a quick sanity check.
sentence_model = SentenceTransformer(LaBSE_small_name)
embeddings = sentence_model.encode(sentences)
print(embeddings.shape)

Downloading (…)0fe39/.gitattributes:   0%|          | 0.00/968 [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)83e900fe39/README.md:   0%|          | 0.00/3.79k [00:00<?, ?B/s]

Downloading (…)e900fe39/config.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/471M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading unigram.json:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

Downloading (…)900fe39/modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

[2023-07-27 17:20:25,028] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(2, 384)


## Учим

In [7]:
from sentence_transformers import ParallelSentencesDataset, losses
from datasets import load_dataset
from torch.utils.data import DataLoader

In [8]:
from transformers import set_seed

In [9]:
# Load the train split of the human-evaluation corpus, then keep only the rows
# from 2022 whose language pair is one of the three listed below.
dataset = load_dataset("RicardoRei/wmt-da-human-evaluation", split="train")
train_dataset = dataset.filter(
    lambda example: example["year"] == 2022
    and example["lp"] in ("en-ru", "zh-en", "en-de")
)

# Shuffled mini-batches of 128 rows over the filtered subset.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

In [10]:
# Seq2seq checkpoint used for text generation; the commented-out lines are
# alternative (larger) checkpoints that can be swapped in.
checkpoint = "bigscience/mt0-small"
# checkpoint = "bigscience/mt0-base"
# checkpoint = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto")

# CosineSimilarityLoss calls its model on tokenized sentence features and reads
# the "sentence_embedding" output, so it has to wrap a SentenceTransformer.
# The seq2seq `model` above does not produce that output, so the loss is built
# around `sentence_model` (the SentenceTransformer loaded earlier in this notebook).
# NOTE(review): confirm that `sentence_model` is the network meant to be fine-tuned.
train_loss = losses.CosineSimilarityLoss(sentence_model)

In [11]:
from transformers import pipeline

In [None]:
# Tokenizer loaded from the same `checkpoint` id as the seq2seq model; it is
# passed to the text2text-generation pipeline constructed in a later cell.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Drop any pipeline left over from a previous execution of this cell so the
# old object can be released before a new one is built. A missing name is fine.
globals().pop("pipe", None)

# CPU text2text-generation pipeline built from the checkpoint id and the
# tokenizer loaded above.
pipe = pipeline(
    task="text2text-generation",
    model=checkpoint,
    tokenizer=tokenizer,
    device="cpu",
)

In [None]:
def get_gimba_prompt(source_lang, source_seg, target_lang, target_seg, reference_seg=None):
    """Build the 0-100 translation-scoring prompt for one segment.

    Args:
        source_lang: Label of the source language, interpolated verbatim.
        source_seg: Source-language text (rendered inside double quotes).
        target_lang: Label of the target language, interpolated verbatim.
        target_seg: Translation to be scored (rendered inside double quotes).
        reference_seg: Human reference translation (rendered without quotes).
            When omitted, the module-level name ``reference_seg`` is used,
            which keeps existing four-argument callers working.

    Returns:
        The prompt string, ending with ``Score:``. Every line after the first
        carries four leading spaces because the template is an indented
        triple-quoted literal.

    Raises:
        NameError: If ``reference_seg`` is not passed and no module-level
            ``reference_seg`` exists.
    """
    if reference_seg is None:
        # Backward compatibility: earlier callers assigned a notebook-level
        # `reference_seg` before calling instead of passing it in.
        try:
            reference_seg = globals()["reference_seg"]
        except KeyError:
            raise NameError("name 'reference_seg' is not defined") from None
    return f'''Score the following translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100, where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".
    {source_lang} source: "{source_seg}"
    {target_lang} human reference: {reference_seg}
    {target_lang} translation: "{target_seg}"
    Score:'''

In [None]:
# Advance the notebook-level seed by one and apply it, so every execution of
# this cell seeds the RNGs with a new value.
SEED = SEED + 1
set_seed(SEED)

In [None]:
import re

# Debug cap: the loop stops after this many dataset rows.
counter = 5

# Numeric scores parsed from the model generations, in iteration order.
score_values_2 = []
for item in train_dataset:
    source_lang, target_lang = item['lp'].split('-')
    source_seg = item['src']
    # Going by the column names, `ref` holds the human reference and `mt` the
    # machine translation being scored (the two were previously swapped).
    # `reference_seg` is assigned at notebook level because get_gimba_prompt is
    # called with four arguments and picks the reference up from this name.
    reference_seg = item['ref']
    target_seg = item['mt']
    prompt = get_gimba_prompt(source_lang, source_seg, target_lang, target_seg)
    try:
        # `temperature` only influences generation when sampling is enabled;
        # without do_sample=True decoding is greedy and the value is ignored.
        data_list = pipe(prompt, do_sample=True, temperature=0.15)
        # Distinct loop variable so the dataset row `item` is not rebound.
        for generation in data_list:
            generated_text = generation['generated_text']
            print("generated_text", generated_text)

            # First integer or decimal number in the generation is the score.
            score_match = re.search(r'\d+(\.\d+)?', generated_text)
            if score_match:
                score = float(score_match.group())
                score_values_2.append(score)
            else:
                # Report the text that failed to parse (the old message
                # referenced an undefined `score_str`).
                print(f"Ошибка: не удалось извлечь число из строки '{generated_text}'")
    except Exception as e:
        # Best effort: log the failure and continue with the next row.
        print(f"Ошибка при обработке данных: {e}")

    print()  # blank line between rows in the output
    counter -= 1
    if counter == 0:
        break  # debug: only score the first few rows