-
Notifications
You must be signed in to change notification settings - Fork 568
/
groundingdino.py
412 lines (362 loc) · 17 KB
/
groundingdino.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Conditional DETR model and criterion classes.
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops.boxes import nms
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import (
NestedTensor,
accuracy,
get_world_size,
interpolate,
inverse_sigmoid,
is_dist_avail_and_initialized,
nested_tensor_from_tensor_list,
)
from groundingdino.util.utils import get_phrases_from_posmap
from groundingdino.util.visualizer import COCOVisualizer
from groundingdino.util.vl_utils import create_positive_map_from_span
from ..registry import MODULE_BUILD_FUNCS
from .backbone import build_backbone
from .bertwarper import (
BertModelWarper,
generate_masks_with_special_tokens,
generate_masks_with_special_tokens_and_transfer_map,
)
from .transformer import build_transformer
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss
class GroundingDINO(nn.Module):
    """Cross-attention detector that conditions object detection on text captions.

    The module combines an image ``backbone``, a ``transformer`` and a BERT text
    encoder. Backbone outputs can be cached on the instance (``set_image_tensor`` /
    ``set_image_features``) so that several captions can be evaluated against the
    same image without re-running the backbone.
    """

    def __init__(
        self,
        backbone,
        transformer,
        num_queries,
        aux_loss=False,
        iter_update=False,
        query_dim=2,
        num_feature_levels=1,
        nheads=8,
        # two stage
        two_stage_type="no",  # ['no', 'standard']
        dec_pred_bbox_embed_share=True,
        two_stage_class_embed_share=True,
        two_stage_bbox_embed_share=True,
        num_patterns=0,
        dn_number=100,
        dn_box_noise_scale=0.4,
        dn_label_noise_ratio=0.5,
        dn_labelbook_size=100,
        text_encoder_type="bert-base-uncased",
        sub_sentence_present=True,
        max_text_len=256,
    ):
        """Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. It must expose
                ``num_channels``, return ``(features, position_encodings)`` when
                called, and provide the position-encoding module at index 1.
            transformer: torch module of the transformer architecture. It must expose
                ``d_model``, ``num_decoder_layers`` and ``decoder``.
            num_queries: number of object queries, i.e. detection slots.
            aux_loss: stored on the instance; ``forward`` currently does not emit
                auxiliary outputs.
            iter_update: must be truthy (asserted).
            query_dim: must be 4 (asserted).
            num_feature_levels: number of feature levels handed to the transformer.
                Levels beyond the backbone outputs are produced by stride-2 convolutions.
            nheads: stored on the instance.
            two_stage_type: ``"no"`` or ``"standard"``.
            dec_pred_bbox_embed_share: share one box head across all decoder layers
                instead of giving each layer its own deep copy.
            two_stage_class_embed_share: reuse the decoder class head for the encoder
                output (only consulted when ``two_stage_type != "no"``).
            two_stage_bbox_embed_share: reuse the decoder box head for the encoder
                output (only consulted when ``two_stage_type != "no"``).
            num_patterns, dn_number, dn_box_noise_scale, dn_label_noise_ratio,
            dn_labelbook_size: stored on the instance; not otherwise used by this class.
            text_encoder_type: name forwarded to the tokenizer / language-model loaders.
            sub_sentence_present: when true, BERT receives the attention masks and
                position ids produced by
                ``generate_masks_with_special_tokens_and_transfer_map``.
            max_text_len: maximum number of text tokens kept per caption batch.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        # Honour the constructor argument (it used to be hard-coded to 256).
        self.max_text_len = max_text_len
        self.sub_sentence_present = sub_sentence_present

        # setting query dim
        self.query_dim = query_dim
        assert query_dim == 4

        # for dn training
        self.num_patterns = num_patterns
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_labelbook_size = dn_labelbook_size

        # bert
        self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type)
        self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type)
        # The pooler weights are excluded from gradient updates.
        self.bert.pooler.dense.weight.requires_grad_(False)
        self.bert.pooler.dense.bias.requires_grad_(False)
        self.bert = BertModelWarper(bert_model=self.bert)

        # Linear map from the BERT hidden size to the transformer width.
        self.feat_map = nn.Linear(self.bert.config.hidden_size, self.hidden_dim, bias=True)
        nn.init.constant_(self.feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.feat_map.weight.data)

        # special tokens
        self.specical_tokens = self.tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?"])

        # prepare input projection layers
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            # One 1x1 projection per backbone output.
            for in_channels in backbone.num_channels:
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            # Extra levels: the first one consumes the last backbone channel count
            # (``in_channels`` left over from the loop above), later ones consume
            # ``hidden_dim``.
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == "no", "two_stage_type should be no if num_feature_levels=1 !!!"
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = None
        self.iter_update = iter_update
        assert iter_update, "Why not iter_update?"

        # prepare pred layers
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
        # prepare class & box embed
        _class_embed = ContrastiveEmbed()
        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        # Zero-init the last box layer so the initial box delta is zero.
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        num_decoder_layers = transformer.num_decoder_layers
        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
        else:
            box_embed_layerlist = [copy.deepcopy(_bbox_embed) for _ in range(num_decoder_layers)]
        # The class head is always shared across decoder layers.
        class_embed_layerlist = [_class_embed for _ in range(num_decoder_layers)]
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        # two stage
        self.two_stage_type = two_stage_type
        assert two_stage_type in ["no", "standard"], "unknown param {} of two_stage_type".format(
            two_stage_type
        )
        if two_stage_type != "no":
            if two_stage_bbox_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if two_stage_class_embed_share:
                assert dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

            self.refpoint_embed = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the conv weight and zero the conv bias of every input projection."""
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def set_image_tensor(self, samples: "NestedTensor | List[torch.Tensor] | torch.Tensor"):
        """Run the backbone on ``samples`` and cache the result on the instance.

        A list or tensor is first wrapped with ``nested_tensor_from_tensor_list``.
        The backbone outputs are stored as ``self.features`` and ``self.poss``.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        self.features, self.poss = self.backbone(samples)

    def unset_image_tensor(self):
        """Drop the cached backbone outputs, if any. Safe to call repeatedly."""
        if hasattr(self, "features"):
            del self.features
        if hasattr(self, "poss"):
            del self.poss

    def set_image_features(self, features, poss):
        """Install precomputed backbone outputs so ``forward`` skips the backbone."""
        self.features = features
        self.poss = poss

    def init_ref_points(self, use_num_queries):
        """Create ``self.refpoint_embed`` as an ``nn.Embedding(use_num_queries, query_dim)``."""
        self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

    def _encode_text(self, captions, device):
        """Tokenize ``captions``, run the text encoder and return the transformer's text dict.

        The returned dict has the keys ``encoded_text``, ``text_token_mask``,
        ``position_ids`` and ``text_self_attention_masks``; everything is truncated to
        at most ``self.max_text_len`` tokens.
        """
        tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to(device)
        (
            text_self_attention_masks,
            position_ids,
            _cate_to_token_mask_list,
        ) = generate_masks_with_special_tokens_and_transfer_map(
            tokenized, self.specical_tokens, self.tokenizer
        )

        limit = self.max_text_len
        if text_self_attention_masks.shape[1] > limit:
            text_self_attention_masks = text_self_attention_masks[:, :limit, :limit]
            position_ids = position_ids[:, :limit]
            tokenized["input_ids"] = tokenized["input_ids"][:, :limit]
            tokenized["attention_mask"] = tokenized["attention_mask"][:, :limit]
            tokenized["token_type_ids"] = tokenized["token_type_ids"][:, :limit]

        # extract text embeddings
        if self.sub_sentence_present:
            # Swap the tokenizer's attention mask for the generated self-attention
            # mask and pass the generated position ids.
            tokenized_for_encoder = {k: v for k, v in tokenized.items() if k != "attention_mask"}
            tokenized_for_encoder["attention_mask"] = text_self_attention_masks
            tokenized_for_encoder["position_ids"] = position_ids
        else:
            tokenized_for_encoder = tokenized

        bert_output = self.bert(**tokenized_for_encoder)
        encoded_text = self.feat_map(bert_output["last_hidden_state"])
        # text_token_mask: True for nomask, False for mask
        # text_self_attention_masks: True for nomask, False for mask
        text_token_mask = tokenized.attention_mask.bool()

        # Guard kept from the original code; normally a no-op after the truncation above.
        if encoded_text.shape[1] > limit:
            encoded_text = encoded_text[:, :limit, :]
            text_token_mask = text_token_mask[:, :limit]
            position_ids = position_ids[:, :limit]
            text_self_attention_masks = text_self_attention_masks[:, :limit, :limit]

        return {
            "encoded_text": encoded_text,
            "text_token_mask": text_token_mask,
            "position_ids": position_ids,
            "text_self_attention_masks": text_self_attention_masks,
        }

    def _prepare_multiscale_inputs(self, samples):
        """Project the cached backbone features and synthesise any extra feature levels.

        Returns ``(srcs, masks, poss)``. ``poss`` is a fresh list, so the cached
        ``self.poss`` is never modified and stays valid for later calls.
        """
        features = self.features
        # Copy: appending to the cached list would grow it on every forward pass
        # whenever the image features are kept between calls.
        poss = list(self.poss)

        srcs = []
        masks = []
        for level, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[level](src))
            masks.append(mask)
            assert mask is not None

        if self.num_feature_levels > len(srcs):
            first_extra = len(srcs)
            for level in range(first_extra, self.num_feature_levels):
                # The first extra level is computed from the raw last backbone
                # feature; subsequent ones from the previously produced level.
                if level == first_extra:
                    src = self.input_proj[level](features[-1].tensors)
                else:
                    src = self.input_proj[level](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)

        return srcs, masks, poss

    def forward(
        self,
        samples: "NestedTensor | List[torch.Tensor] | torch.Tensor",
        targets: List = None,
        **kw,
    ):
        """Detect objects described by the captions.

        The forward expects a NestedTensor, which consists of:
           - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        A tensor or a list of tensors is also accepted and converted with
        ``nested_tensor_from_tensor_list``.

        Parameters:
            samples: the images (see above).
            targets: optional list of dicts; each dict's ``"caption"`` is used as text.
            **kw: ``captions`` (required when ``targets`` is None) and
                ``unset_image_tensor`` (default True: drop the cached backbone
                outputs before returning).

        It returns a dict with the following elements:
           - "pred_logits": the classification logits of the last decoder layer.
           - "pred_boxes": the box coordinates of the last decoder layer, passed
                           through a sigmoid, represented as
                           (center_x, center_y, width, height).
        Auxiliary and encoder ("interm") outputs are currently not returned.
        """
        if targets is None:
            captions = kw["captions"]
        else:
            captions = [t["caption"] for t in targets]

        # A plain list has no ``.device``; take it from its first image instead.
        device = samples[0].device if isinstance(samples, list) else samples.device
        text_dict = self._encode_text(captions, device)

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        # Only run the backbone when no features have been cached beforehand.
        if not hasattr(self, "features") or not hasattr(self, "poss"):
            self.set_image_tensor(samples)

        srcs, masks, poss = self._prepare_multiscale_inputs(samples)

        input_query_bbox = input_query_label = attn_mask = None
        hs, reference, _hs_enc, _ref_enc, _init_box_proposal = self.transformer(
            srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
        )

        # deformable-detr-like anchor update
        outputs_coord_list = torch.stack(
            [
                (layer_bbox_embed(layer_hs) + inverse_sigmoid(layer_ref_sig)).sigmoid()
                for layer_ref_sig, layer_bbox_embed, layer_hs in zip(
                    reference[:-1], self.bbox_embed, hs
                )
            ]
        )

        # output
        outputs_class = torch.stack(
            [
                layer_cls_embed(layer_hs, text_dict)
                for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
            ]
        )
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]}

        # NOTE: "aux_outputs" (via ``_set_aux_loss``) and the encoder "interm_outputs"
        # are intentionally not added to ``out``; that code path is disabled.

        if kw.get("unset_image_tensor", True):
            self.unset_image_tensor()  # If necessary

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair up the per-layer predictions of every decoder layer except the last."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino")
def build_groundingdino(args):
    """Assemble a ``GroundingDINO`` model from the configuration object ``args``.

    The backbone and the transformer are built from ``args`` first. The remaining
    constructor options are read from ``args``, except for ``aux_loss``,
    ``iter_update``, ``query_dim`` and ``dn_number``, which are fixed here.
    """
    image_backbone = build_backbone(args)
    detr_transformer = build_transformer(args)

    # Keyword arguments for the constructor; the attributes of ``args`` are read
    # top to bottom.
    model_options = {
        "dn_labelbook_size": args.dn_labelbook_size,
        "dec_pred_bbox_embed_share": args.dec_pred_bbox_embed_share,
        "sub_sentence_present": args.sub_sentence_present,
        "num_queries": args.num_queries,
        "aux_loss": True,
        "iter_update": True,
        "query_dim": 4,
        "num_feature_levels": args.num_feature_levels,
        "nheads": args.nheads,
        "two_stage_type": args.two_stage_type,
        "two_stage_bbox_embed_share": args.two_stage_bbox_embed_share,
        "two_stage_class_embed_share": args.two_stage_class_embed_share,
        "num_patterns": args.num_patterns,
        "dn_number": 0,
        "dn_box_noise_scale": args.dn_box_noise_scale,
        "dn_label_noise_ratio": args.dn_label_noise_ratio,
        "text_encoder_type": args.text_encoder_type,
        "max_text_len": args.max_text_len,
    }
    return GroundingDINO(image_backbone, detr_transformer, **model_options)