utils.py
import torch
from torch.nn import functional as F


def top_k_top_p_filter(logits, top_k: int = 0, top_p: float = 0.0):
    """Mask logits outside the top-k / nucleus (top-p) set with -inf.

    Args:
        logits: 2D tensor of shape (batch, vocab_size).
        top_k: keep only the k highest-scoring tokens (0 disables this filter).
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds top_p (0.0 disables this filter).

    Returns:
        The filtered logits (modified in place and returned).
    """
    if top_k > 0:
        # Value of the k-th largest logit per row; everything below it is masked.
        filter = torch.topk(logits, min(top_k, logits.size(-1)))[0]
        logits[logits < filter[:, [-1]]] = float('-inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)
        filter = cumulative_probs > top_p
        # Shift the removal mask right by one so the token that first crosses
        # top_p is kept, and never remove the highest-probability token.
        filter[..., 1:] = filter[..., :-1].clone()
        filter[..., 0] = 0
        # Map the mask from sorted order back to the original vocabulary order.
        indices_to_remove = filter.scatter(1, sorted_indices, filter)
        logits[indices_to_remove] = float('-inf')
    return logits


def pad_collate(batch, tokenizer):
    """Tokenize a list of strings, padding/truncating to the tokenizer's max length."""
    data = tokenizer(batch, padding="max_length",
                     truncation=True, return_tensors="pt")
    return data
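

# --- Usage sketch (not part of the original file) ---
# A minimal example of how these helpers might be called; the vocabulary size,
# sampling parameters, and the commented-out Hugging Face tokenizer below are
# illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    logits = torch.randn(2, 32000)  # dummy (batch, vocab_size) logits
    filtered = top_k_top_p_filter(logits.clone(), top_k=50, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)  # renormalize over the surviving tokens
    next_token = torch.multinomial(probs, num_samples=1)
    print(next_token.shape)  # torch.Size([2, 1])

    # pad_collate expects a list of raw strings plus a tokenizer that supports
    # padding, e.g. (assumption) one loaded via transformers.AutoTokenizer:
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # tokenizer.pad_token = tokenizer.eos_token
    # batch = pad_collate(["hello world", "a longer example sentence"], tokenizer)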