In [1]:
from transformers import (T5Tokenizer, 
                          T5Config, 
                          T5ForConditionalGeneration)

In [61]:
T5_PATH = 't5-base' # "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"
DEVICE = 'cpu' #torch.device('cuda' if torch.cuda.is_available() else 'cpu') # My envirnment uses CPU

In [62]:
t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)
t5_config = T5Config.from_pretrained(T5_PATH)
t5_mlm = T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config).to(DEVICE)

In [78]:
# Input text
text = '''Litwo ojczyzno moja ty jesteś jak zdrowie
<extra_id_0> <extra_id_1> <extra_id_2> posłowie</s>'''

text = '''O Lithuania, my native land,
you are like health--so valued when lost
beyond recovery; let these words now stand
<extra_id_0> cost.</s>
'''

encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')
input_ids = encoded['input_ids'].to(DEVICE)

In [79]:
outputs = t5_mlm.generate(input_ids=input_ids, 
                          num_beams=200, num_return_sequences=1,
                          max_length=50)

_0_index = text.index('<extra_id_0>')
_result_prefix = text[:_0_index]
# _2_index = text.index('<extra_id_2>')
_result_suffix = text[_0_index+12:]  # 12 is the length of <extra_id_0>

In [80]:
print(_result_prefix)

O Lithuania, my native land,
you are like health--so valued when lost
beyond recovery; let these words now stand



In [81]:
txt = t5_tokenizer.decode(outputs[0], skip_special_tokens=False, clean_up_tokenization_spaces=False)
print(txt)

<pad> <extra_id_0> the<extra_id_1> the cost. O Lithuania, my native land, you are like health--so valued when lost beyond recovery, so valued when lost beyond recovery, so valued when lost beyond recovery, so valued when lost beyond recovery, so


In [82]:
print(_result_suffix)

 cost.</s>



In [83]:
def _filter(output, end_token='<extra_id_1>'):
    # The first token is <unk> (inidex at 0) and the second token is <extra_id_0> (indexed at 32099)
    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    if end_token in _txt:
        _end_token_index = _txt.index(end_token)
        return _result_prefix + _txt[:_end_token_index] + _result_suffix
    else:
        return _result_prefix + _txt + _result_suffix

results = list(map(_filter, outputs))
print(results[0])

O Lithuania, my native land,
you are like health--so valued when lost
beyond recovery; let these words now stand
the cost.</s>

