# rank_loss.py (forked from hfxunlp/transformer)
#encoding: utf-8
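# usage: python rank_loss.py rsf h5f models...
#   rsf: output file; one loss value per line, in the order the sentences are stored in h5f
#   h5f: HDF5 test set holding "ndata", "nword", "src" and "tgt" entries
#   models: one or more checkpoint files; two or more are scored as an ensemble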
# When True, each sentence's summed loss is divided by its number of
# non-padding target tokens before it is written out.
norm_token = True
import sys
import torch
from torch.cuda.amp import autocast
from tqdm import tqdm
import h5py
import cnfg.base as cnfg
from cnfg.ihyp import *
from transformer.NMT import NMT
from transformer.EnsembleNMT import NMT as Ensemble
from parallel.parallelMT import DataParallelMT
from parallel.base import DataParallelCriterion
from loss.base import LabelSmoothingLoss
from utils.base import *
from utils.fmt.base import pad_id
from utils.fmt.base4torch import parse_cuda
def load_fixing(module):
    """Invoke the optional ``fix_load`` hook of ``module``.

    Passed to ``Module.apply`` right after a checkpoint is loaded, so every
    submodule that defines ``fix_load`` gets to run it; submodules without
    the hook are left untouched.
    """
    if not hasattr(module, "fix_load"):
        return
    module.fix_load()
# Open the HDF5 test set given as the second CLI argument and read its
# metadata: how many batches it stores and the vocabulary sizes.
td = h5py.File(sys.argv[2], "r")
ntest = td["ndata"][:].item()
nword = td["nword"][:].tolist()
# First entry sizes the source side, last entry the target side
# (nwordt is also the class count of the loss).
nwordi = nword[0]
nwordt = nword[-1]
def _load_nmt(modelf):
    """Build one NMT model from ``cnfg`` and load the checkpoint ``modelf``.

    Weights are loaded on CPU through ``load_model_cpu``; afterwards every
    submodule's ``fix_load`` hook is run via ``load_fixing``.
    Returns the loaded model.
    """
    model = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
    model = load_model_cpu(modelf, model)
    model.apply(load_fixing)
    return model


# Every argument after rsf and h5f is a model checkpoint.
model_files = sys.argv[3:]
if not model_files:
    # Fail fast: otherwise the ensemble branch would be handed an empty
    # model list.
    sys.exit("usage: python rank_loss.py rsf h5f models...")
if len(model_files) == 1:
    mymodel = _load_nmt(model_files[0])
else:
    # Several checkpoints are combined into a single ensemble model.
    mymodel = Ensemble([_load_nmt(modelf) for modelf in model_files])
# Inference only: switch the model to evaluation mode.
mymodel.eval()
# Unreduced (reduction='none') label-smoothed loss; padding positions are
# ignored and the values are summed per sentence in the scoring loop.
lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes)
use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid)
# Mixed precision is only requested when CUDA is in use.
use_amp = cnfg.use_amp and use_cuda
# Important to make cudnn methods deterministic
set_random_seed(cnfg.seed, use_cuda)
if cuda_device:
    mymodel.to(cuda_device)
    lossf.to(cuda_device)
    # Multi-GPU wrapping needs cuda_device.index, so it stays inside this branch.
    if multi_gpu:
        # Model and criterion are replicated over the same device layout.
        _dp_devices = dict(device_ids=cuda_devices, output_device=cuda_device.index)
        mymodel = DataParallelMT(mymodel, host_replicate=True, gather_output=False, **_dp_devices)
        lossf = DataParallelCriterion(lossf, replicate_once=True, **_dp_devices)
# Written after each batch so consecutive batches do not run together.
ens = "\n".encode("utf-8")
src_grp, tgt_grp = td["src"], td["tgt"]
try:
    with open(sys.argv[1], "wb") as f, torch.no_grad():
        for i in tqdm(range(ntest)):
            # Batches are stored under the keys "0" .. str(ntest - 1).
            _curid = str(i)
            seq_batch = torch.from_numpy(src_grp[_curid][:])
            seq_o = torch.from_numpy(tgt_grp[_curid][:])
            if cuda_device:
                seq_batch = seq_batch.to(cuda_device)
                seq_o = seq_o.to(cuda_device)
            seq_batch, seq_o = seq_batch.long(), seq_o.long()
            # The model is fed seq_o[:, :-1] and scored against seq_o[:, 1:].
            lo = seq_o.size(1) - 1
            ot = seq_o.narrow(1, 1, lo).contiguous()
            with autocast(enabled=use_amp):
                output = mymodel(seq_batch, seq_o.narrow(1, 0, lo))
                # Collapse the unreduced loss to one summed value per batch row.
                loss = lossf(output, ot).sum(-1).view(-1, lo).sum(-1)
            if norm_token:
                # Average over the non-padding target tokens of each row.
                lenv = ot.ne(pad_id).int().sum(-1).to(loss)
                loss = loss / lenv
            f.write("\n".join(map(str, loss.tolist())).encode("utf-8"))
            # Drop references so this batch's tensors can be freed early.
            loss = output = ot = seq_batch = seq_o = None
            f.write(ens)
finally:
    # Release the HDF5 handle even if scoring fails part-way through.
    td.close()