-
Notifications
You must be signed in to change notification settings - Fork 0
/
Batch.py
56 lines (49 loc) · 2.07 KB
/
Batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# TODO: change use_cuda to false at some point and set it correctly (in main)
#       (was a print() executed at import time — kept as a comment so importing
#       this module no longer spams stdout)
from pyro.distributions import Bernoulli
class Batch:
    """Object for holding a batch of data with mask during training.

    Input is a batch from a torch text iterator: ``src`` and ``trg`` are
    ``(tensor, lengths)`` pairs where the tensor is ``(batch, seq_len)``.

    Attributes set here:
      src / src_lengths / src_mask  -- encoder input, lengths, and a
        broadcastable padding mask of shape ``(batch, 1, seq_len)``.
      trg        -- decoder input: ``trg[:, :-1]`` with word dropout applied.
      trg_y      -- decoder targets: ``trg[:, 1:]`` (never word-dropped).
      trg_mask   -- True where ``trg_y`` is a real (non-pad) token.
      trg_lengths, ntokens, nseqs -- bookkeeping for loss normalization.
    """
    def __init__(self, src, trg, pad_index=0, word_drop=0.0, unk_indx=0, use_cuda=False):
        src, src_lengths = src
        self.src = src
        self.src_lengths = src_lengths
        # Broadcastable attention mask: True where src is a real token.
        self.src_mask = (src != pad_index).unsqueeze(-2)
        self.nseqs = src.size(0)
        self.trg = None
        self.trg_y = None
        self.trg_mask = None
        self.trg_lengths = None
        self.ntokens = None
        if trg is not None:
            trg, trg_lengths = trg
            # Decoder input is everything but the last token. clone() is
            # essential: a plain slice is a view sharing storage with `trg`,
            # so the in-place masked_fill_ below would also overwrite the
            # gold targets taken from `trg` afterwards.
            self.trg = trg[:, :-1].clone()
            # Word dropout proposed in Bowman et al. 2016: replace a random
            # subset of decoder-input tokens with <unk>; targets stay intact.
            drop_probs = trg.new_zeros(self.trg.size(0), self.trg.size(1)).float().fill_(word_drop)
            mask = Bernoulli(drop_probs).sample().byte()
            try:
                mask = mask.bool()
            except AttributeError:
                # Older PyTorch has no Tensor.bool(); a byte mask works there.
                pass
            self.trg.masked_fill_(mask, unk_indx)
            self.trg_lengths = trg_lengths
            # Decoder targets: shifted by one, taken from the untouched `trg`.
            self.trg_y = trg[:, 1:]
            self.trg_mask = (self.trg_y != pad_index)
            self.ntokens = (self.trg_y != pad_index).data.sum().item()
        if use_cuda:
            # NOTE(review): lengths are intentionally(?) left where they are
            # on the CUDA path — pack_padded_sequence historically wants CPU
            # lengths. Confirm against the callers before "fixing" this.
            self.src = self.src.cuda()
            self.src_mask = self.src_mask.cuda()
            if trg is not None:
                self.trg = self.trg.cuda()
                self.trg_y = self.trg_y.cuda()
                self.trg_mask = self.trg_mask.cuda()
        else:
            self.src = self.src.cpu()
            self.src_mask = self.src_mask.cpu()
            self.src_lengths = self.src_lengths.cpu()
            if trg is not None:
                self.trg = self.trg.cpu()
                self.trg_y = self.trg_y.cpu()
                self.trg_mask = self.trg_mask.cpu()
                self.trg_lengths = self.trg_lengths.cpu()