# In [1]:
import torch
from tqdm import tqdm

def generate(category_with_word, tokenizer, model, device, num=5, max_length=32):
    """Sample ``num`` text continuations of the prompt ``category_with_word + "|"``.

    Each sample is an independent ``model.generate`` call. Sampling is on
    (``do_sample=True``), with ``repetition_penalty=2.0`` and ``use_cache=True``.
    Each output sequence is decoded with special tokens skipped. The text after
    the first ``"|"`` is kept; if there is no ``"|"``, the whole decoded text is
    kept. Any remaining ``"|"`` characters become spaces, and the result is
    stripped.

    Side effect: if ``tokenizer.pad_token_id`` is None, it is set to
    ``tokenizer.eos_token_id`` on the caller's tokenizer object.

    Args:
        category_with_word: Prompt prefix; ``"|"`` is appended before encoding.
        tokenizer: Tokenizer that is callable with ``return_tensors="pt"`` and
            provides ``decode`` plus pad/eos/bos token ids.
        model: Object exposing ``generate(...)``. Presumably a Hugging Face
            model; confirm against the caller.
        device: Target the encoded prompt is moved to via ``.to(device)``.
        num: Number of samples to draw, with a tqdm progress bar.
        max_length: Forwarded unchanged as ``max_length`` to
            ``model.generate``. For HF causal LMs this presumably includes
            the prompt tokens; TODO confirm.

    Returns:
        List of ``num`` post-processed strings, in generation order.
    """
    # Fall back to EOS for padding. This persists on the caller's tokenizer.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    encoded = tokenizer(category_with_word + "|", return_tensors="pt").to(device)
    prompt_ids, prompt_mask = encoded["input_ids"], encoded["attention_mask"]

    def _sample_once():
        # Draw one sampled sequence and reduce it to the text after the prompt.
        output = model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            max_length=max_length,
            repetition_penalty=2.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            use_cache=True,
            do_sample=True,
        )
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep what follows the first "|" (the prompt separator), else everything.
        _, separator, remainder = text.partition("|")
        kept = remainder if separator else text
        return kept.replace("|", " ").strip()

    return [_sample_once() for _ in tqdm(range(num))]