transformer.py
import copy

import torch
from torch import nn

from models.containers import ModuleList
from ..captioning_model import CaptioningModel


class Transformer(CaptioningModel):
    def __init__(self, bos_idx, encoder, decoder):
        super(Transformer, self).__init__()
        self.bos_idx = bos_idx
        self.encoder = encoder
        self.decoder = decoder
        # Stateful buffers: cache the encoder output and its mask once per
        # sequence so they are not recomputed at every decoding step.
        self.register_state('enc_output', None)
        self.register_state('mask_enc', None)
        self.init_weights()

    @property
    def d_model(self):
        return self.decoder.d_model

    def init_weights(self):
        # Xavier-initialize all weight matrices; biases (1-D) keep defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, images, seq, *args):
        enc_output, mask_enc = self.encoder(images)
        dec_output = self.decoder(seq, enc_output, mask_enc)
        return dec_output

    def init_state(self, b_s, device):
        return [torch.zeros((b_s, 0), dtype=torch.long, device=device),
                None, None]

    def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
        it = None
        if mode == 'teacher_forcing':
            raise NotImplementedError
        elif mode == 'feedback':
            if t == 0:
                # First step: encode the visual input once, then feed a
                # batch of BOS tokens to the decoder.
                self.enc_output, self.mask_enc = self.encoder(visual)
                if isinstance(visual, torch.Tensor):
                    it = visual.data.new_full((visual.shape[0], 1), self.bos_idx).long()
                else:
                    it = visual[0].data.new_full((visual[0].shape[0], 1), self.bos_idx).long()
            else:
                # Later steps: feed back the token(s) produced previously.
                it = prev_output

        return self.decoder(it, self.enc_output, self.mask_enc)
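
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): greedy decoding driven by
# `step` in 'feedback' mode. This assumes the `statefulness` context manager
# from models.containers (which this class inherits via CaptioningModel) and
# a decoder that returns per-token log-probabilities; the helper name below
# is illustrative only.
def _greedy_decode_sketch(model, images, max_len=20):
    """Decode one token per step, always taking the arg-max word (sketch)."""
    b_s = images.shape[0]
    model.eval()
    with torch.no_grad(), model.statefulness(b_s):
        it = None  # ignored at t == 0; `step` substitutes the BOS token
        words = []
        for t in range(max_len):
            log_probs = model.step(t, it, images, None, mode='feedback')
            it = log_probs[:, -1].argmax(-1, keepdim=True)  # (b_s, 1)
            words.append(it)
    return torch.cat(words, 1)  # (b_s, max_len) token indices
# ---------------------------------------------------------------------------
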
class TransformerEnsemble(CaptioningModel):
    def __init__(self, model: Transformer, weight_files):
        super(TransformerEnsemble, self).__init__()
        self.n = len(weight_files)
        # One deep copy of the base model per checkpoint file.
        self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
        for i in range(self.n):
            state_dict_i = torch.load(weight_files[i])['state_dict']
            self.models[i].load_state_dict(state_dict_i)

    def step(self, t, prev_output, visual, seq, mode='teacher_forcing', **kwargs):
        # Run every member model on the same step and average their outputs.
        out_ensemble = []
        for i in range(self.n):
            out_i = self.models[i].step(t, prev_output, visual, seq, mode, **kwargs)
            out_ensemble.append(out_i.unsqueeze(0))
        return torch.mean(torch.cat(out_ensemble, 0), dim=0)
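
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): building an ensemble from
# saved checkpoints. The checkpoint paths, `bos_idx` value, and the
# `encoder`/`decoder` instances are placeholders, not defined here.
#
#     base_model = Transformer(bos_idx=2, encoder=encoder, decoder=decoder)
#     ensemble = TransformerEnsemble(base_model,
#                                    weight_files=['ckpt_a.pth', 'ckpt_b.pth'])
#
# Since `TransformerEnsemble.step` has the same signature as
# `Transformer.step`, the ensemble can be dropped into the same decoding loop
# as a single model. Note that if the decoders emit log-probabilities, taking
# `torch.mean` over them averages in log space (a geometric-mean ensemble of
# the per-model distributions).
# ---------------------------------------------------------------------------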