# Model predictions

In [1]:
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch

## T5

Choose the trained model:

In [2]:
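# Fine-tuned T5 checkpoint stored locally
model = T5ForConditionalGeneration.from_pretrained(r'weights/t5_small_taiga_aggl__results/t5_small_cl_train_80000')

# The checkpoint is based on t5-small, so the base model's tokenizer is reused
base_model_name = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(base_model_name)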

Prediction function:

In [3]:
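def generate(text, model, n=None, max_length='auto', temperature=1.0, beams=3, do_sample=False):
    """Generate predictions for a single string or a list of strings.

    Uses the global `tokenizer` loaded in the previous cell.
    `temperature` only has an effect when `do_sample=True`.
    """
    texts = [text] if isinstance(text, str) else text
    
    # Keep the attention mask so that padded positions in a batch are ignored
    inputs = tokenizer(texts, return_tensors='pt', padding=True).to(model.device)
    
    if max_length == 'auto':
        # 20% more tokens than the longest input, plus a margin of 10
        max_length = int(inputs['input_ids'].shape[1] * 1.2) + 10
    
    num_return = n or 1
    gen_kwargs = dict(
        num_return_sequences=num_return,
        max_length=max_length,
        repetition_penalty=3.0,
        # Beam search cannot return more sequences than it has beams
        num_beams=max(beams, num_return),
        do_sample=do_sample,
    )
    if do_sample:
        # Temperature is only meaningful when sampling
        gen_kwargs['temperature'] = temperature
    
    result = model.generate(**inputs, **gen_kwargs)
    
    texts = [tokenizer.decode(r, skip_special_tokens=True) for r in result]
    
    if not n and isinstance(text, str):
        return texts[0]
    return texts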
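The `'auto'` setting for `max_length` in `generate` budgets the output relative to the longest tokenized input: 20% headroom plus 10 extra tokens. The rule can be checked in isolation with plain Python; `auto_max_length` is a hypothetical helper name used only for this illustration, not part of transformers.

```python
def auto_max_length(n_input_tokens: int) -> int:
    """Same arithmetic as the 'auto' branch of generate(): int(n * 1.2) + 10."""
    return int(n_input_tokens * 1.2) + 10


# Output budgets for a few input lengths (in tokens)
for n in (0, 5, 10):
    print(n, "->", auto_max_length(n))
```

Because `int` truncates, the budget grows in whole-token steps, and even a very short input is allowed at least 10 output tokens.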

Print the model's prediction:

In [5]:
# Beam search without sampling ignores temperature, so it is not passed here
print(generate(['дипломат'], model, beams=10))

['диломат,NOUN,Inan,Nom,M']
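To work with the prediction programmatically, the comma-separated string can be split into fields. This is a minimal sketch that assumes (hypothetically, judging only from the single output above) that the first field is the word form and every following field is a grammatical tag; `parse_t5_output` is an illustrative name.

```python
def parse_t5_output(output: str) -> dict:
    """Split 'form,TAG,TAG,...' into the word form and a list of tags.

    Assumption: the first comma-separated field is the word form and the
    remaining fields are grammatical tags.
    """
    word, *tags = [field.strip() for field in output.split(",")]
    return {"word": word, "tags": tags}


parsed = parse_t5_output("диломат,NOUN,Inan,Nom,M")
print(parsed)

# Comparing the predicted form with the input shows the model did not copy it exactly
source = "дипломат"
print(parsed["word"] == source)
```

The comparison prints `False`: the predicted form "диломат" is missing the "п" of the input "дипломат", so checking the form against the input is a cheap sanity check on the model's output.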


## RuPrompts

In [6]:
from transformers import GPT2LMHeadModel, AutoTokenizer
from transformers import pipeline
from ruprompts import Prompt

In [7]:
backbone_id = "sberbank-ai/rugpt3large_based_on_gpt2"

model = GPT2LMHeadModel.from_pretrained(backbone_id)
tokenizer = AutoTokenizer.from_pretrained(backbone_id, pad_token="<pad>", eos_token="<pad>")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
prompt = Prompt.from_pretrained(r'weights/rupr_not_clustered_syn/checkpoint-80850')

ppln = pipeline("text2text-generation-with-prompt", prompt=prompt, model=model, tokenizer=tokenizer)
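With a prompt-tuned model, the pipeline input is a dict whose keys (here `form`) fill named slots in the prompt's format, and the remaining positions are trainable prompt tokens. The toy function below illustrates that idea in plain Python. It is not ruprompts code; the `<P*k>` notation is modeled on, to our understanding, ruprompts' prompt-format syntax, and the template `"<P*3>{form}<P*2>"` is hypothetical (the actual format of this checkpoint is not shown here).

```python
import re


def render_prompt(template: str, fields: dict, token: str = "<P>") -> str:
    """Toy prompt expansion (illustration only, not the ruprompts implementation).

    Each "<P*k>" becomes k copies of a placeholder standing for a trainable
    prompt token; each "{name}" slot is filled from the input dict.
    """
    expanded = re.sub(r"<P\*(\d+)>", lambda m: token * int(m.group(1)), template)
    return expanded.format(**fields)


print(render_prompt("<P*3>{form}<P*2>", {"form": "дипломат"}))
```

This prints `<P><P><P>дипломат<P><P>`. In real prompt tuning the placeholders correspond to learned embeddings rather than literal text, which is why the input dict's keys have to match the slot names the prompt was trained with.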

In [11]:
# do_sample=True: the generated tags can differ between runs
ppln({"form": "дипломат"}, do_sample=True)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


[{'generated_text': 'NOUN,Inan,Nom,Masc,Sing'}]