# model_motifs.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from maskrcnn_benchmark.modeling import registry
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import PackedSequence
from torch.nn import functional as F
from maskrcnn_benchmark.modeling.utils import cat
from .utils_motifs import obj_edge_vectors, center_x, sort_by_score, to_onehot, get_dropout_mask, nms_overlaps, encode_box_info
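
# This module implements the MOTIFS-style context layers used by the relation
# head: FrequencyBias (a learned lookup of predicate statistics for each
# (subject, object) label pair), DecoderRNN (an LSTM cell with a highway
# connection that decodes object labels one ROI at a time), and LSTMContext
# (bidirectional LSTMs that build object- and edge-level contextual features,
# modified from the neural-motifs codebase).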


class FrequencyBias(nn.Module):
    """
    The goal of this is to provide a simplified way of computing
    P(predicate | obj1, obj2, img).
    """

    def __init__(self, cfg, statistics, eps=1e-3):
        super(FrequencyBias, self).__init__()
        pred_dist = statistics['pred_dist'].float()
        assert pred_dist.size(0) == pred_dist.size(1)

        self.num_objs = pred_dist.size(0)
        self.num_rels = pred_dist.size(2)
        pred_dist = pred_dist.view(-1, self.num_rels)

        self.obj_baseline = nn.Embedding(self.num_objs * self.num_objs, self.num_rels)
        with torch.no_grad():
            self.obj_baseline.weight.copy_(pred_dist, non_blocking=True)

    def index_with_labels(self, labels):
        """
        :param labels: [batch_size, 2]
        :return:
        """
        return self.obj_baseline(labels[:, 0] * self.num_objs + labels[:, 1])

    def index_with_probability(self, pair_prob):
        """
        :param pair_prob: [batch_size, num_obj, 2]
        :return:
        """
        batch_size, num_obj, _ = pair_prob.shape

        # Joint pair probability is the outer product of the two label marginals,
        # which then soft-indexes the baseline table.
        joint_prob = pair_prob[:, :, 0].contiguous().view(batch_size, num_obj, 1) * \
                     pair_prob[:, :, 1].contiguous().view(batch_size, 1, num_obj)

        return joint_prob.view(batch_size, num_obj * num_obj) @ self.obj_baseline.weight

    def forward(self, labels):
        # implemented through index_with_labels
        return self.index_with_labels(labels)
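
# Usage sketch for FrequencyBias (hypothetical label values; `statistics` is the
# dataset statistics dict this codebase passes into the relation predictor, and
# the bias row for a pair lives at index subj_label * num_objs + obj_label):
#
#   freq_bias = FrequencyBias(cfg, statistics)
#   pair_labels = torch.tensor([[5, 7], [2, 3]])          # [num_pairs, 2] (subj, obj)
#   rel_bias = freq_bias.index_with_labels(pair_labels)   # [num_pairs, num_rels] logits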


class DecoderRNN(nn.Module):
    def __init__(self, config, obj_classes, embed_dim, inputs_dim, hidden_dim, rnn_drop):
        super(DecoderRNN, self).__init__()
        self.cfg = config
        self.obj_classes = obj_classes
        self.embed_dim = embed_dim

        obj_embed_vecs = obj_edge_vectors(['start'] + self.obj_classes, wv_dir=self.cfg.GLOVE_DIR, wv_dim=embed_dim)
        self.obj_embed = nn.Embedding(len(self.obj_classes) + 1, embed_dim)
        with torch.no_grad():
            self.obj_embed.weight.copy_(obj_embed_vecs, non_blocking=True)

        self.hidden_size = hidden_dim
        self.inputs_dim = inputs_dim
        self.input_size = self.inputs_dim + self.embed_dim
        self.nms_thresh = self.cfg.TEST.RELATION.LATER_NMS_PREDICTION_THRES
        self.rnn_drop = rnn_drop

        self.input_linearity = torch.nn.Linear(self.input_size, 6 * self.hidden_size, bias=True)
        self.state_linearity = torch.nn.Linear(self.hidden_size, 5 * self.hidden_size, bias=True)
        self.out_obj = nn.Linear(self.hidden_size, len(self.obj_classes))
        self.init_parameters()

    def init_parameters(self):
        # Use sensible default initializations for parameters.
        with torch.no_grad():
            torch.nn.init.constant_(self.state_linearity.bias, 0.0)
            torch.nn.init.constant_(self.input_linearity.bias, 0.0)

    def lstm_equations(self, timestep_input, previous_state, previous_memory, dropout_mask=None):
        """
        Does the hairy LSTM math.
        :param timestep_input: [batch_size, input_size] input for this timestep
        :param previous_state: [batch_size, hidden_size] previous hidden state
        :param previous_memory: [batch_size, hidden_size] previous cell memory
        :param dropout_mask: optional recurrent dropout mask applied to the output
        :return: (timestep_output, memory) for this timestep
        """
        # Do the projections for all the gates all at once.
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
                                   projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
        forget_gate = torch.sigmoid(projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
                                    projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
        memory_init = torch.tanh(projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
                                 projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
        output_gate = torch.sigmoid(projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
                                    projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
        memory = input_gate * memory_init + forget_gate * previous_memory
        timestep_output = output_gate * torch.tanh(memory)

        # Highway connection: mix the LSTM output with a direct projection of the input.
        highway_gate = torch.sigmoid(projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                                     projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
        highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
        timestep_output = highway_gate * timestep_output + (1 - highway_gate) * highway_input_projection

        # Only apply dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None and self.training:
            timestep_output = timestep_output * dropout_mask
        return timestep_output, memory
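
    # In equation form, one timestep (the six/five chunks of input_linearity and
    # state_linearity act as the per-gate weight matrices W_x* and W_h*):
    #   i_t = sigmoid(W_xi x_t + W_hi h_{t-1})    input gate
    #   f_t = sigmoid(W_xf x_t + W_hf h_{t-1})    forget gate
    #   g_t = tanh(W_xg x_t + W_hg h_{t-1})       candidate memory
    #   o_t = sigmoid(W_xo x_t + W_ho h_{t-1})    output gate
    #   c_t = i_t * g_t + f_t * c_{t-1}
    #   r_t = sigmoid(W_xr x_t + W_hr h_{t-1})    highway gate
    #   h_t = r_t * o_t * tanh(c_t) + (1 - r_t) * W_xh x_t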

    def forward(self, inputs, initial_state=None, labels=None, boxes_for_nms=None):
        if not isinstance(inputs, PackedSequence):
            raise ValueError('inputs must be PackedSequence but got %s' % (type(inputs)))

        sequence_tensor, batch_lengths, _, _ = inputs
        batch_size = batch_lengths[0]

        # We're just doing an LSTM decoder here so ignore states, etc
        if initial_state is None:
            previous_memory = sequence_tensor.new_zeros(batch_size, self.hidden_size)
            previous_state = sequence_tensor.new_zeros(batch_size, self.hidden_size)
        else:
            assert len(initial_state) == 2
            previous_memory = initial_state[1].squeeze(0)
            previous_state = initial_state[0].squeeze(0)

        # The 'start' token embedding (row 0 of obj_embed) feeds the first timestep.
        previous_obj_embed = self.obj_embed.weight[0, None].expand(batch_size, self.embed_dim)

        if self.rnn_drop > 0.0:
            dropout_mask = get_dropout_mask(self.rnn_drop, previous_memory.size(), previous_memory.device)
        else:
            dropout_mask = None

        # Only accumulating label predictions here, discarding everything else
        out_dists = []
        out_commitments = []

        end_ind = 0
        for i, l_batch in enumerate(batch_lengths):
            start_ind = end_ind
            end_ind = end_ind + l_batch

            if previous_memory.size(0) != l_batch:
                previous_memory = previous_memory[:l_batch]
                previous_state = previous_state[:l_batch]
                previous_obj_embed = previous_obj_embed[:l_batch]
                if dropout_mask is not None:
                    dropout_mask = dropout_mask[:l_batch]

            timestep_input = torch.cat((sequence_tensor[start_ind:end_ind], previous_obj_embed), 1)

            previous_state, previous_memory = self.lstm_equations(timestep_input, previous_state,
                                                                  previous_memory, dropout_mask=dropout_mask)

            pred_dist = self.out_obj(previous_state)
            out_dists.append(pred_dist)

            if self.training:
                labels_to_embed = labels[start_ind:end_ind].clone()
                # Whenever labels are 0 set input to be our max prediction
                nonzero_pred = pred_dist[:, 1:].max(1)[1] + 1
                is_bg = (labels_to_embed == 0).nonzero()
                if is_bg.dim() > 0:
                    labels_to_embed[is_bg.squeeze(1)] = nonzero_pred[is_bg.squeeze(1)]
                out_commitments.append(labels_to_embed)
                previous_obj_embed = self.obj_embed(labels_to_embed + 1)
            else:
                assert l_batch == 1
                out_dist_sample = F.softmax(pred_dist, dim=1)
                best_ind = out_dist_sample[:, 1:].max(1)[1] + 1
                out_commitments.append(best_ind)
                previous_obj_embed = self.obj_embed(best_ind + 1)

        # Do NMS here as a post-processing step
        if boxes_for_nms is not None and not self.training:
            is_overlap = nms_overlaps(boxes_for_nms).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0), boxes_for_nms.size(1)
            ).cpu().numpy() >= self.nms_thresh

            out_dists_sampled = F.softmax(torch.cat(out_dists, 0), 1).cpu().numpy()
            out_dists_sampled[:, 0] = 0  # background never wins

            out_commitments = out_commitments[0].new(len(out_commitments)).fill_(0)

            # Greedily pick the highest-scoring (box, class) pair, then suppress
            # that class on all heavily-overlapping boxes.
            for i in range(out_commitments.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape)
                out_commitments[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind, :, cls_ind], cls_ind] = 0.0
                out_dists_sampled[box_ind] = -1.0  # This way we won't re-sample
        else:
            out_commitments = torch.cat(out_commitments, 0)

        return torch.cat(out_dists, 0), out_commitments
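
# A minimal shape sketch for DecoderRNN (hypothetical sizes): given a
# PackedSequence over N = sum(batch_lengths) score-sorted ROIs with feature
# size inputs_dim, forward() returns
#   out_dists:       [N, num_obj_classes]  per-ROI class logits
#   out_commitments: [N]                   the label each timestep commits to
# (ground-truth labels where available during training, otherwise the argmax
# over foreground classes, optionally de-duplicated by per-class NMS at test).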


class LSTMContext(nn.Module):
    """
    Modified from neural-motifs to encode contexts for each object.
    """

    def __init__(self, config, obj_classes, rel_classes, in_channels):
        super(LSTMContext, self).__init__()
        self.cfg = config
        self.obj_classes = obj_classes
        self.rel_classes = rel_classes
        self.num_obj_classes = len(obj_classes)

        # mode
        if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX:
            if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL:
                self.mode = 'predcls'
            else:
                self.mode = 'sgcls'
        else:
            self.mode = 'sgdet'

        # word embedding
        self.embed_dim = self.cfg.MODEL.ROI_RELATION_HEAD.EMBED_DIM
        obj_embed_vecs = obj_edge_vectors(self.obj_classes, wv_dir=self.cfg.GLOVE_DIR, wv_dim=self.embed_dim)
        self.obj_embed1 = nn.Embedding(self.num_obj_classes, self.embed_dim)
        self.obj_embed2 = nn.Embedding(self.num_obj_classes, self.embed_dim)
        with torch.no_grad():
            self.obj_embed1.weight.copy_(obj_embed_vecs, non_blocking=True)
            self.obj_embed2.weight.copy_(obj_embed_vecs, non_blocking=True)

        # position embedding
        self.pos_embed = nn.Sequential(*[
            nn.Linear(9, 32), nn.BatchNorm1d(32, momentum=0.001),
            nn.Linear(32, 128), nn.ReLU(inplace=True),
        ])

        # object & relation context
        self.obj_dim = in_channels
        self.dropout_rate = self.cfg.MODEL.ROI_RELATION_HEAD.CONTEXT_DROPOUT_RATE
        self.hidden_dim = self.cfg.MODEL.ROI_RELATION_HEAD.CONTEXT_HIDDEN_DIM
        self.nl_obj = self.cfg.MODEL.ROI_RELATION_HEAD.CONTEXT_OBJ_LAYER
        self.nl_edge = self.cfg.MODEL.ROI_RELATION_HEAD.CONTEXT_REL_LAYER
        assert self.nl_obj > 0 and self.nl_edge > 0

        # TODO Kaihua Tang
        # AlternatingHighwayLSTM is invalid for pytorch 1.0
        self.obj_ctx_rnn = torch.nn.LSTM(
            input_size=self.obj_dim + self.embed_dim + 128,
            hidden_size=self.hidden_dim,
            num_layers=self.nl_obj,
            dropout=self.dropout_rate if self.nl_obj > 1 else 0,
            bidirectional=True)
        self.decoder_rnn = DecoderRNN(self.cfg, self.obj_classes, embed_dim=self.embed_dim,
                                      inputs_dim=self.hidden_dim + self.obj_dim + self.embed_dim + 128,
                                      hidden_dim=self.hidden_dim,
                                      rnn_drop=self.dropout_rate)
        self.edge_ctx_rnn = torch.nn.LSTM(
            input_size=self.embed_dim + self.hidden_dim + self.obj_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.nl_edge,
            dropout=self.dropout_rate if self.nl_edge > 1 else 0,
            bidirectional=True)
        # map bidirectional hidden states of dimension self.hidden_dim*2 to self.hidden_dim
        self.lin_obj_h = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.lin_edge_h = nn.Linear(self.hidden_dim * 2, self.hidden_dim)

        # untreated average features
        self.average_ratio = 0.0005
        self.effect_analysis = config.MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_ANALYSIS

        if self.effect_analysis:
            self.register_buffer("untreated_dcd_feat",
                                 torch.zeros(self.hidden_dim + self.obj_dim + self.embed_dim + 128))
            self.register_buffer("untreated_obj_feat",
                                 torch.zeros(self.obj_dim + self.embed_dim + 128))
            self.register_buffer("untreated_edg_feat",
                                 torch.zeros(self.embed_dim + self.obj_dim))

    def sort_rois(self, proposals):
        c_x = center_x(proposals)
        # left-to-right order
        scores = c_x / (c_x.max() + 1)
        return sort_by_score(proposals, scores)
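
    # Note: sort_by_score is expected to return (perm, inv_perm, ls_transposed),
    # where perm arranges the ROIs of every image into one score-sorted sequence,
    # inv_perm undoes that arrangement, and ls_transposed holds the per-timestep
    # batch sizes consumed by PackedSequence below (inferred from how the tuple
    # is used in obj_ctx and edge_ctx).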

    def obj_ctx(self, obj_feats, proposals, obj_labels=None, boxes_per_cls=None, ctx_average=False):
        """
        Object context and object classification.
        :param obj_feats: [num_obj, img_dim + object embedding0 dim]
        :param proposals: BoxList proposals for the batch
        :param obj_labels: [num_obj] the GT labels of the image
        :param boxes_per_cls: [num_obj, num_cls, 4] per-class boxes, used for NMS
        :param ctx_average: if True (test-time causal analysis), feed the decoder
                            its running-average input instead of the real one
        :return: obj_dists: [num_obj, #classes] new probability distribution.
                 obj_preds: argmax of that distribution.
                 encoder_rep: [num_obj, #feats] object context for later!
                 plus (perm, inv_perm, ls_transposed) for reuse in edge_ctx.
        """
        # Sort by the confidence of the maximum detection.
        perm, inv_perm, ls_transposed = self.sort_rois(proposals)
        # Pass object features, sorted by score, into the encoder LSTM
        obj_inp_rep = obj_feats[perm].contiguous()
        input_packed = PackedSequence(obj_inp_rep, ls_transposed)
        encoder_rep = self.obj_ctx_rnn(input_packed)[0][0]
        encoder_rep = self.lin_obj_h(encoder_rep)  # map bidirectional states to hidden_dim

        # untreated decoder input
        batch_size = encoder_rep.shape[0]

        if (not self.training) and self.effect_analysis and ctx_average:
            decoder_inp = self.untreated_dcd_feat.view(1, -1).expand(batch_size, -1)
        else:
            decoder_inp = torch.cat((obj_inp_rep, encoder_rep), 1)

        if self.training and self.effect_analysis:
            self.untreated_dcd_feat = self.moving_average(self.untreated_dcd_feat, decoder_inp)

        # Decode in order
        if self.mode != 'predcls':
            decoder_inp = PackedSequence(decoder_inp, ls_transposed)
            obj_dists, obj_preds = self.decoder_rnn(
                decoder_inp,  # obj_dists[perm],
                labels=obj_labels[perm] if obj_labels is not None else None,
                boxes_for_nms=boxes_per_cls[perm] if boxes_per_cls is not None else None,
            )
            obj_preds = obj_preds[inv_perm]
            obj_dists = obj_dists[inv_perm]
        else:
            assert obj_labels is not None
            obj_preds = obj_labels
            obj_dists = to_onehot(obj_preds, self.num_obj_classes)
        encoder_rep = encoder_rep[inv_perm]

        return obj_dists, obj_preds, encoder_rep, perm, inv_perm, ls_transposed

    def edge_ctx(self, inp_feats, perm, inv_perm, ls_transposed):
        """
        Edge-level context over the already-classified objects.
        :param inp_feats: [num_obj, embed_dim + hidden_dim + obj_dim]
        :return: edge_ctx: [num_obj, #feats] For later!
        """
        edge_input_packed = PackedSequence(inp_feats[perm], ls_transposed)
        edge_reps = self.edge_ctx_rnn(edge_input_packed)[0][0]
        edge_reps = self.lin_edge_h(edge_reps)  # map bidirectional states to hidden_dim

        edge_ctx = edge_reps[inv_perm]
        return edge_ctx

    def moving_average(self, holder, input):
        assert len(input.shape) == 2
        with torch.no_grad():
            holder = holder * (1 - self.average_ratio) + self.average_ratio * input.mean(0).view(-1)
        return holder
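
    # moving_average keeps an exponential moving average of batch-mean features:
    #   holder <- (1 - alpha) * holder + alpha * mean(input, dim=0),  alpha = 0.0005
    # e.g. with holder = 0.0 and a constant batch mean of 1.0, each call closes
    # 0.05% of the remaining gap toward 1.0.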

    def forward(self, x, proposals, rel_pair_idxs, logger=None, all_average=False, ctx_average=False):
        # labels are used by DecoderRNN during training (teacher forcing);
        # boxes_per_cls is used for its NMS post-processing at test time
        if self.training or self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX:
            obj_labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)
        else:
            obj_labels = None

        if self.cfg.MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL:
            obj_embed = self.obj_embed1(obj_labels.long())
        else:
            obj_logits = cat([proposal.get_field("predict_logits") for proposal in proposals], dim=0).detach()
            obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed1.weight

        assert proposals[0].mode == 'xyxy'
        pos_embed = self.pos_embed(encode_box_info(proposals))

        batch_size = x.shape[0]
        if all_average and self.effect_analysis and (not self.training):
            obj_pre_rep = self.untreated_obj_feat.view(1, -1).expand(batch_size, -1)
        else:
            obj_pre_rep = cat((x, obj_embed, pos_embed), -1)

        boxes_per_cls = None
        if self.mode == 'sgdet' and not self.training:
            boxes_per_cls = cat([proposal.get_field('boxes_per_cls') for proposal in proposals], dim=0)  # comes from post process of box_head

        # object level contextual feature
        obj_dists, obj_preds, obj_ctx, perm, inv_perm, ls_transposed = self.obj_ctx(
            obj_pre_rep, proposals, obj_labels, boxes_per_cls, ctx_average=ctx_average)

        # edge level contextual feature
        obj_embed2 = self.obj_embed2(obj_preds.long())

        if (all_average or ctx_average) and self.effect_analysis and (not self.training):
            obj_rel_rep = cat((self.untreated_edg_feat.view(1, -1).expand(batch_size, -1), obj_ctx), dim=-1)
        else:
            obj_rel_rep = cat((obj_embed2, x, obj_ctx), -1)

        edge_ctx = self.edge_ctx(obj_rel_rep, perm=perm, inv_perm=inv_perm, ls_transposed=ls_transposed)

        # memorize average feature
        if self.training and self.effect_analysis:
            self.untreated_obj_feat = self.moving_average(self.untreated_obj_feat, obj_pre_rep)
            self.untreated_edg_feat = self.moving_average(self.untreated_edg_feat, cat((obj_embed2, x), -1))

        return obj_dists, obj_preds, edge_ctx, None
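

# A minimal call sketch for LSTMContext (hypothetical shapes; x is the pooled
# ROI feature of size in_channels per proposal, and rel_pair_idxs is not used
# inside this module, apparently kept for interface parity with other context
# layers in the codebase):
#
#   context_layer = LSTMContext(cfg, obj_classes, rel_classes, in_channels=4096)
#   obj_dists, obj_preds, edge_ctx, _ = context_layer(x, proposals, rel_pair_idxs)
#   # obj_dists: [num_obj, num_obj_classes], obj_preds: [num_obj],
#   # edge_ctx:  [num_obj, hidden_dim] features consumed by the relation predictor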