-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
34 lines (24 loc) · 961 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import torch
from transformers import AutoTokenizer, MT5ForConditionalGeneration
# Device used for the model and input tensors: first CUDA GPU if present,
# otherwise CPU. (The original also bound a redundant, unused lowercase
# alias `device`; it has been removed.)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def predict(text, tokenizer, model):
    """Generate an abstractive summary of ``text`` with a seq2seq model.

    Args:
        text: Input document to summarize.
        tokenizer: Tokenizer whose ``encode``/``decode`` pair matches the
            model's vocabulary.
        model: A loaded ``MT5ForConditionalGeneration`` (or any object with
            a compatible ``generate``), already moved to ``DEVICE``.

    Returns:
        The decoded summary string with special tokens stripped.
    """
    tokenized_text = tokenizer.encode(text, return_tensors="pt").to(DEVICE)
    # Inference only: disable autograd so beam search does not build a
    # computation graph and waste memory tracking gradients.
    with torch.no_grad():
        summary_ids = model.generate(
            tokenized_text,
            max_length=150,          # cap the summary at 150 tokens
            num_beams=5,             # beam search for higher-quality output
            repetition_penalty=2.5,  # discourage repeated phrases
            length_penalty=1.0,      # neutral preference on output length
            early_stopping=True,     # stop once all beams emit EOS
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
if __name__ == '__main__':
    # CLI usage: predict.py MODEL_DIR TEXT
    cli = argparse.ArgumentParser()
    cli.add_argument('model_dir', type=str)
    cli.add_argument('text', type=str)
    opts = cli.parse_args()

    # The tokenizer comes from the stock mT5 checkpoint; only the model
    # weights are loaded from the fine-tuned directory — presumably the
    # checkpoint was saved without its tokenizer (TODO confirm).
    tokenizer = AutoTokenizer.from_pretrained('google/mt5-small')
    model = MT5ForConditionalGeneration.from_pretrained(opts.model_dir)
    model.to(DEVICE)

    print(predict(opts.text, tokenizer, model))