/
fast_rcnn.py
855 lines (753 loc) · 36.4 KB
/
fast_rcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Dict, Union
import torch
import os
import math
import shortuuid
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F
from torch.distributions.normal import Normal
import detectron2.utils.comm as comm
from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec, batched_nms, cat, nonzero_tuple
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.store import Store
# Names exported by ``from <this module> import *``.
__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
"""
Shape shorthand in this module:
N: number of images in the minibatch
R: number of ROIs, combined over all images, in the minibatch
Ri: number of ROIs in image i
K: number of foreground classes. E.g., there are 80 foreground classes in COCO.
Naming convention:
deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box
transform (see :class:`box_regression.Box2BoxTransform`).
pred_class_logits: predicted class scores in [-inf, +inf]; use
softmax(pred_class_logits) to estimate P(class).
gt_classes: ground-truth classification labels in [0, K], where [0, K) represent
foreground object classes and K represents the background class.
pred_proposal_deltas: predicted box2box transform deltas for transforming proposals
to detection box predictions.
gt_proposal_deltas: ground-truth box2box transform deltas
"""
def fast_rcnn_inference(boxes, scores, image_shapes, predictions, score_thresh, nms_thresh, topk_per_image):
    """
    Run `fast_rcnn_inference_single_image` once per image and collect the results.

    The four per-image sequences (`scores`, `boxes`, `image_shapes`, `predictions`)
    are iterated in lockstep, so iteration stops at the shortest of them.

    Args:
        boxes (list[Tensor]): Predicted class-specific or class-agnostic boxes, one
            Tensor per image. Element i has shape (Ri, K * 4) for class-specific
            regression or (Ri, 4) for class-agnostic regression, where Ri is the
            number of predicted objects for image i.
            Compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`.
        scores (list[Tensor]): Predicted class scores, one Tensor of shape
            (Ri, K + 1) per image.
            Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`.
        image_shapes (list[tuple]): One image-size tuple per image; each is handed
            unchanged to the single-image routine.
        predictions: Iterable whose i-th element is passed to the single-image
            routine as its ``prediction`` argument, where it is treated as the
            logits of image i and indexed with the same row masks as ``scores``.
        score_thresh (float): Only detections whose score exceeds this value are kept.
        nms_thresh (float): Threshold used for box non-maximum suppression.
            Value in [0, 1].
        topk_per_image (int): Number of top-scoring detections to keep per image.
            A negative value keeps all detections.

    Returns:
        instances (list[Instances]): One `Instances` per processed image holding the
            retained detections.
        kept_indices (list[Tensor]): One 1D tensor per processed image; each entry is
            the row index in [0, Ri) of the input boxes/scores that produced the
            corresponding detection.
    """
    instances_per_image = []
    kept_indices_per_image = []
    for scores_i, boxes_i, shape_i, prediction_i in zip(scores, boxes, image_shapes, predictions):
        single_result = fast_rcnn_inference_single_image(
            boxes_i, scores_i, shape_i, score_thresh, nms_thresh, topk_per_image, prediction_i
        )
        instances_per_image.append(single_result[0])
        kept_indices_per_image.append(single_result[1])
    return instances_per_image, kept_indices_per_image
def fast_rcnn_inference_single_image(
    boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image, prediction
):
    """
    Produce the detections of one image: threshold the scores, then run per-class
    non-maximum suppression (NMS) and optionally keep only the top-k survivors.

    Args:
        Same as `fast_rcnn_inference`, but `boxes`, `scores`, `image_shape` and
        `prediction` refer to a single image. `prediction` is used as the logits
        tensor and is filtered row-for-row together with `scores`.

    Returns:
        Same as `fast_rcnn_inference`, but for a single image: an `Instances` with
        fields ``pred_boxes``, ``scores``, ``pred_classes`` and ``logits``, plus a 1D
        tensor of the input row index of every kept detection.
    """
    logits = prediction

    # Rows containing a non-finite box coordinate or score are discarded up front.
    finite_rows = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    if not finite_rows.all():
        boxes, scores, logits = boxes[finite_rows], scores[finite_rows], logits[finite_rows]

    # Remove the last column (the background class, see the module shorthand).
    scores, logits = scores[:, :-1], logits[:, :-1]

    boxes_per_row = boxes.shape[1] // 4
    # Wrap in Boxes only to reuse its `clip` method.
    clipped = Boxes(boxes.reshape(-1, 4))
    clipped.clip(image_shape)
    boxes = clipped.tensor.view(-1, boxes_per_row, 4)  # R x C x 4

    # Step 1: score thresholding (also shrinks the NMS input).
    above_thresh = scores > score_thresh  # R x K
    # R' x 2: column 0 is the prediction (row) index, column 1 the class index.
    row_and_class = above_thresh.nonzero()
    row_index = row_and_class[:, 0]
    if boxes_per_row == 1:
        boxes = boxes[row_index, 0]
    else:
        boxes = boxes[above_thresh]
    scores = scores[above_thresh]
    logits = logits[row_index]

    # Step 2: NMS, run independently per class.
    survivors = batched_nms(boxes, scores, row_and_class[:, 1], nms_thresh)
    if topk_per_image >= 0:
        survivors = survivors[:topk_per_image]
    boxes = boxes[survivors]
    scores = scores[survivors]
    row_and_class = row_and_class[survivors]
    logits = logits[survivors]

    detections = Instances(image_shape)
    detections.pred_boxes = Boxes(boxes)
    detections.scores = scores
    detections.pred_classes = row_and_class[:, 1]
    detections.logits = logits
    return detections, row_and_class[:, 0]
class FastRCNNOutputs:
    """
    An internal implementation that stores information about outputs of a Fast R-CNN head,
    and provides methods that are used to decode the outputs of a Fast R-CNN head.
    """

    def __init__(
        self,
        box2box_transform,
        pred_class_logits,
        pred_proposal_deltas,
        proposals,
        invalid_class_range,
        smooth_l1_beta=0.0,
        box_reg_loss_type="smooth_l1",
    ):
        """
        Args:
            box2box_transform (Box2BoxTransform/Box2BoxTransformRotated):
                box2box transform instance for proposal-to-detection transformations.
            pred_class_logits (Tensor): A tensor of shape (R, K + 1) storing the predicted class
                logits for all R predicted object instances.
                Each row corresponds to a predicted object instance.
            pred_proposal_deltas (Tensor): A tensor of shape (R, K * B) or (R, B) for
                class-specific or class-agnostic regression. It stores the predicted deltas that
                transform proposals into final box detections.
                B is the box dimension (4 or 5).
                When B is 4, each row is [dx, dy, dw, dh (, ....)].
                When B is 5, each row is [dx, dy, dw, dh, da (, ....)].
            proposals (list[Instances]): A list of N Instances, where Instances i stores the
                proposals for image i, in the field "proposal_boxes".
                When training, each Instances must have ground-truth labels
                stored in the field "gt_classes" and "gt_boxes".
                The total number of all instances must be equal to R.
            invalid_class_range: column indices of ``pred_class_logits`` that are
                overwritten with a large negative value before the classification
                loss is computed (see :meth:`softmax_cross_entropy_loss`).
            smooth_l1_beta (float): The transition point between L1 and L2 loss in
                the smooth L1 loss function. When set to 0, the loss becomes L1. When
                set to +inf, the loss becomes constant 0.
            box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou"
        """
        self.box2box_transform = box2box_transform
        self.num_preds_per_image = [len(p) for p in proposals]
        self.pred_class_logits = pred_class_logits
        self.pred_proposal_deltas = pred_proposal_deltas
        self.smooth_l1_beta = smooth_l1_beta
        self.box_reg_loss_type = box_reg_loss_type
        self.image_shapes = [x.image_size for x in proposals]
        self.invalid_class_range = invalid_class_range

        if len(proposals):
            box_type = type(proposals[0].proposal_boxes)
            # cat(..., dim=0) concatenates over all images in the batch
            self.proposals = box_type.cat([p.proposal_boxes for p in proposals])
            assert (
                not self.proposals.tensor.requires_grad
            ), "Proposals should not require gradients!"

            # The following fields should exist only when training.
            if proposals[0].has("gt_boxes"):
                self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
                assert proposals[0].has("gt_classes")
                self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
        else:
            self.proposals = Boxes(torch.zeros(0, 4, device=self.pred_proposal_deltas.device))
        self._no_instances = len(proposals) == 0  # no instances found

    def _log_accuracy(self):
        """
        Log the accuracy metrics to EventStorage.
        """
        num_instances = self.gt_classes.numel()
        pred_classes = self.pred_class_logits.argmax(dim=1)
        # The background class occupies the last logit column.
        bg_class_ind = self.pred_class_logits.shape[1] - 1

        fg_inds = (self.gt_classes >= 0) & (self.gt_classes < bg_class_ind)
        num_fg = fg_inds.nonzero().numel()
        fg_gt_classes = self.gt_classes[fg_inds]
        fg_pred_classes = pred_classes[fg_inds]

        num_false_negative = (fg_pred_classes == bg_class_ind).nonzero().numel()
        num_accurate = (pred_classes == self.gt_classes).nonzero().numel()
        fg_num_accurate = (fg_pred_classes == fg_gt_classes).nonzero().numel()

        storage = get_event_storage()
        if num_instances > 0:
            storage.put_scalar("fast_rcnn/cls_accuracy", num_accurate / num_instances)
            if num_fg > 0:
                storage.put_scalar("fast_rcnn/fg_cls_accuracy", fg_num_accurate / num_fg)
                storage.put_scalar("fast_rcnn/false_negative", num_false_negative / num_fg)

    def softmax_cross_entropy_loss(self):
        """
        Compute the softmax cross entropy loss for box classification.

        Accuracy metrics are logged first, on the unmodified logits. The columns
        listed in ``invalid_class_range`` are then overwritten with a large negative
        value so they receive (almost) no probability mass.

        Returns:
            scalar Tensor
        """
        if self._no_instances:
            return 0.0 * self.pred_class_logits.sum()
        else:
            self._log_accuracy()
            # NOTE: this assignment modifies ``pred_class_logits`` in place, i.e. the
            # tensor that was handed to the constructor is changed as well.
            self.pred_class_logits[:, self.invalid_class_range] = -10e10
            # self.log_logits(self.pred_class_logits, self.gt_classes)
            return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")

    def log_logits(self, logits, cls):
        """
        Debug helper: dump ``(logits, cls)`` to a uniquely named ``.pkl`` file.

        Nothing in this class calls it; the only call site (in
        :meth:`softmax_cross_entropy_loss`) is commented out.
        NOTE(review): the destination directory is a hard-coded absolute path.
        """
        data = (logits, cls)
        location = '/home/fk1/workspace/OWOD/output/logits/' + shortuuid.uuid() + '.pkl'
        torch.save(data, location)

    def box_reg_loss(self):
        """
        Compute the box regression loss selected by ``box_reg_loss_type``
        ("smooth_l1" or "giou").

        Returns:
            scalar Tensor

        Raises:
            ValueError: if ``box_reg_loss_type`` is neither "smooth_l1" nor "giou".
        """
        if self._no_instances:
            return 0.0 * self.pred_proposal_deltas.sum()

        box_dim = self.gt_boxes.tensor.size(1)  # 4 or 5
        cls_agnostic_bbox_reg = self.pred_proposal_deltas.size(1) == box_dim
        device = self.pred_proposal_deltas.device

        bg_class_ind = self.pred_class_logits.shape[1] - 1

        # Box delta loss is only computed between the prediction for the gt class k
        # (if 0 <= k < bg_class_ind) and the target; there is no loss defined on predictions
        # for non-gt classes and background.
        # Empty fg_inds produces a valid loss of zero as long as the size_average
        # arg to smooth_l1_loss is False (otherwise it uses torch.mean internally
        # and would produce a nan loss).
        fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0]
        if cls_agnostic_bbox_reg:
            # pred_proposal_deltas only corresponds to foreground class for agnostic
            gt_class_cols = torch.arange(box_dim, device=device)
        else:
            fg_gt_classes = self.gt_classes[fg_inds]
            # pred_proposal_deltas for class k are located in columns [b * k : b * k + b],
            # where b is the dimension of box representation (4 or 5)
            # Note that compared to Detectron1,
            # we do not perform bounding box regression for background classes.
            gt_class_cols = box_dim * fg_gt_classes[:, None] + torch.arange(box_dim, device=device)

        if self.box_reg_loss_type == "smooth_l1":
            gt_proposal_deltas = self.box2box_transform.get_deltas(
                self.proposals.tensor, self.gt_boxes.tensor
            )
            loss_box_reg = smooth_l1_loss(
                self.pred_proposal_deltas[fg_inds[:, None], gt_class_cols],
                gt_proposal_deltas[fg_inds],
                self.smooth_l1_beta,
                reduction="sum",
            )
        elif self.box_reg_loss_type == "giou":
            loss_box_reg = giou_loss(
                self._predict_boxes()[fg_inds[:, None], gt_class_cols],
                self.gt_boxes.tensor[fg_inds],
                reduction="sum",
            )
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        # The loss is normalized using the total number of regions (R), not the number
        # of foreground regions even though the box regression loss is only defined on
        # foreground regions. Why? Because doing so gives equal training influence to
        # each foreground example. To see how, consider two different minibatches:
        #  (1) Contains a single foreground region
        #  (2) Contains 100 foreground regions
        # If we normalize by the number of foreground regions, the single example in
        # minibatch (1) will be given 100 times as much influence as each foreground
        # example in minibatch (2). Normalizing by the total number of regions, R,
        # means that the single example in minibatch (1) and each of the 100 examples
        # in minibatch (2) are given equal influence.
        loss_box_reg = loss_box_reg / self.gt_classes.numel()
        return loss_box_reg

    def _predict_boxes(self):
        """
        Returns:
            Tensor: A Tensors of predicted class-specific or class-agnostic boxes
                for all images in a batch. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
                the number of predicted objects for image i and B is the box dimension (4 or 5)
        """
        return self.box2box_transform.apply_deltas(self.pred_proposal_deltas, self.proposals.tensor)

    """
    A subclass is expected to have the following methods because
    they are used to query information about the head predictions.
    """

    def losses(self):
        """
        Compute the default losses for box head in Fast(er) R-CNN,
        with softmax cross entropy loss and smooth L1 loss.

        Returns:
            A dict of losses (scalar tensors) containing keys "loss_cls" and "loss_box_reg".
        """
        return {"loss_cls": self.softmax_cross_entropy_loss(), "loss_box_reg": self.box_reg_loss()}

    def predict_boxes(self):
        """
        Deprecated

        Returns:
            tuple[Tensor]: the output of :meth:`_predict_boxes` split per image.
        """
        return self._predict_boxes().split(self.num_preds_per_image, dim=0)

    def predict_probs(self):
        """
        Deprecated

        Returns:
            tuple[Tensor]: softmax of the class logits, split per image.
        """
        probs = F.softmax(self.pred_class_logits, dim=-1)
        return probs.split(self.num_preds_per_image, dim=0)

    def inference(self, score_thresh, nms_thresh, topk_per_image):
        """
        Deprecated

        Args:
            score_thresh (float): forwarded to `fast_rcnn_inference`.
            nms_thresh (float): forwarded to `fast_rcnn_inference`.
            topk_per_image (int): forwarded to `fast_rcnn_inference`.

        Returns:
            Same as `fast_rcnn_inference`.
        """
        boxes = self.predict_boxes()
        scores = self.predict_probs()
        image_shapes = self.image_shapes
        # `fast_rcnn_inference` takes a per-image `predictions` argument (the logits)
        # between `image_shapes` and `score_thresh`. It used to be omitted here, which
        # made this call fail with a TypeError for a missing positional argument.
        logits_per_image = self.pred_class_logits.split(self.num_preds_per_image, dim=0)
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            logits_per_image,
            score_thresh,
            nms_thresh,
            topk_per_image,
        )
class AE(nn.Module):
    """
    Minimal auto-encoder: one linear layer followed by ReLU for encoding
    (``input_size`` -> ``z_dim``) and one linear layer followed by ReLU for
    decoding (``z_dim`` -> ``input_size``).
    """

    def __init__(self, input_size, z_dim):
        super().__init__()
        self.e1 = nn.Linear(input_size, z_dim)
        self.d1 = nn.Linear(z_dim, input_size)

    def encoder(self, x):
        """Map `x` to its latent code: ``relu(e1(x))``."""
        return torch.relu(self.e1(x))

    def decoder(self, z):
        """Map a latent code back to input space: ``relu(d1(z))``."""
        return torch.relu(self.d1(z))

    def forward(self, x):
        """Encode and then decode `x`."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction
def Xavier(m):
    """
    Re-initialise a linear layer in place: weights are drawn uniformly from
    ``[-a, a]`` with ``a = sqrt(3) * sqrt(2 / (fan_in + fan_out))`` and the bias
    is set to zero.

    The layer is recognised by its class *name* being ``'Linear'``; any other
    module is left untouched. Intended for use with ``nn.Module.apply``.
    """
    if m.__class__.__name__ != 'Linear':
        return
    weight = m.weight.data
    fan_in = weight.size(1)
    fan_out = weight.size(0)
    std = 1.0 * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    weight.uniform_(-bound, bound)
    m.bias.data.fill_(0.0)
class FastRCNNOutputLayers(nn.Module):
    """
    Two linear layers for predicting Fast R-CNN outputs:

    1. proposal-to-detection box regression deltas
    2. classification scores

    The module additionally keeps a per-class feature store and a list of per-class
    mean feature vectors (``self.means``) that drive an optional clustering loss.
    """

    @configurable
    def __init__(
        self,
        input_shape: ShapeSpec,
        *,
        box2box_transform,
        clustering_items_per_class,
        clustering_start_iter,
        clustering_update_mu_iter,
        clustering_momentum,
        clustering_z_dimension,
        enable_clustering,
        prev_intro_cls,
        curr_intro_cls,
        max_iterations,
        output_dir,
        feat_store_path,
        margin,
        num_classes: int,
        test_score_thresh: float = 0.0,
        test_nms_thresh: float = 0.5,
        test_topk_per_image: int = 100,
        cls_agnostic_bbox_reg: bool = False,
        smooth_l1_beta: float = 0.0,
        box_reg_loss_type: str = "smooth_l1",
        loss_weight: Union[float, Dict[str, float]] = 1.0,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature to this module
            box2box_transform (Box2BoxTransform or Box2BoxTransformRotated):
            clustering_items_per_class: second argument of ``Store(...)`` when a new
                feature store is created.
            clustering_start_iter: iteration at which the class means are first computed
                from the feature store; the clustering loss is 0 before it.
            clustering_update_mu_iter: after the start iteration the class means are
                refreshed whenever ``iter % clustering_update_mu_iter == 0``.
            clustering_momentum: weight of the previous mean in the momentum update
                (the freshly computed mean gets ``1 - clustering_momentum``).
            clustering_z_dimension: accepted for configuration compatibility; only
                referenced by the commented-out auto-encoder code.
            enable_clustering: when false, :meth:`get_clustering_loss` returns 0.
            prev_intro_cls, curr_intro_cls: their sum is stored as ``seen_classes``;
                logit columns ``range(seen_classes, num_classes - 1)`` are treated as
                invalid during the classification loss.
            max_iterations: the feature store is saved at iteration ``max_iterations - 1``.
            output_dir, feat_store_path: joined with ``'feat.pt'`` to form the feature
                store file location.
            margin: passed as the norm order ``p`` to ``torch.cdist`` in the
                clustering loss.
            num_classes (int): number of foreground classes
            test_score_thresh (float): threshold to filter predictions results.
            test_nms_thresh (float): NMS threshold for prediction results.
            test_topk_per_image (int): number of top predictions to produce per image.
            cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression
            smooth_l1_beta (float): transition point from L1 to L2 loss. Only used if
                `box_reg_loss_type` is "smooth_l1"
            box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou"
            loss_weight (float|dict): weights to use for losses. Can be single float for weighting
                all losses, or a dict of individual weightings. Keys that are missing from
                the dict default to a weight of 1.0. Dict keys used in this file are:
                    * "loss_cls": applied to classification loss
                    * "loss_box_reg": applied to box regression loss
                    * "loss_clustering": applied to the clustering loss
        """
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
        # prediction layer for num_classes foreground classes and one background class (hence + 1)
        self.cls_score = Linear(input_size, num_classes + 1)
        num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes
        box_dim = len(box2box_transform.weights)
        self.bbox_pred = Linear(input_size, num_bbox_reg_classes * box_dim)

        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        for layer in [self.cls_score, self.bbox_pred]:
            nn.init.constant_(layer.bias, 0)

        self.box2box_transform = box2box_transform
        self.smooth_l1_beta = smooth_l1_beta
        self.test_score_thresh = test_score_thresh
        self.test_nms_thresh = test_nms_thresh
        self.test_topk_per_image = test_topk_per_image
        self.box_reg_loss_type = box_reg_loss_type
        if isinstance(loss_weight, float):
            loss_weight = {"loss_cls": loss_weight, "loss_box_reg": loss_weight}
        self.loss_weight = loss_weight

        self.num_classes = num_classes
        self.clustering_start_iter = clustering_start_iter
        self.clustering_update_mu_iter = clustering_update_mu_iter
        self.clustering_momentum = clustering_momentum

        # NOTE(review): the hinge margin is the constant 2 while the configured
        # `margin` is used as the norm order of torch.cdist in
        # `clstr_loss_l2_cdist` -- confirm the two were not swapped by accident.
        self.hingeloss = nn.HingeEmbeddingLoss(2)
        self.enable_clustering = enable_clustering

        self.prev_intro_cls = prev_intro_cls
        self.curr_intro_cls = curr_intro_cls
        self.seen_classes = self.prev_intro_cls + self.curr_intro_cls
        self.invalid_class_range = list(range(self.seen_classes, self.num_classes-1))
        log = logging.getLogger(__name__)
        log.info("Invalid class range: " + str(self.invalid_class_range))

        self.max_iterations = max_iterations
        self.feature_store_is_stored = False
        self.output_dir = output_dir
        self.feat_store_path = feat_store_path
        self.feature_store_save_loc = os.path.join(self.output_dir, self.feat_store_path, 'feat.pt')

        if os.path.isfile(self.feature_store_save_loc):
            log.info('Trying to load feature store from ' + self.feature_store_save_loc)
            # NOTE: torch.load unpickles arbitrary objects; only point this at
            # feature-store files produced by a trusted training run.
            self.feature_store = torch.load(self.feature_store_save_loc)
        else:
            log.info('Feature store not found in ' +
                     self.feature_store_save_loc + '. Creating new feature store.')
            self.feature_store = Store(num_classes + 1, clustering_items_per_class)
        # One mean feature vector per class slot; None until a mean is available.
        self.means = [None for _ in range(num_classes + 1)]
        self.margin = margin

        # self.ae_model = AE(input_size, clustering_z_dimension)
        # self.ae_model.apply(Xavier)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Translate a config node into the keyword arguments of ``__init__``."""
        return {
            "input_shape": input_shape,
            "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS),
            # fmt: off
            "num_classes"               : cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            "cls_agnostic_bbox_reg"     : cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG,
            "smooth_l1_beta"            : cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA,
            "test_score_thresh"         : cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
            "test_nms_thresh"           : cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            "test_topk_per_image"       : cfg.TEST.DETECTIONS_PER_IMAGE,
            "box_reg_loss_type"         : cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE,
            "loss_weight"               : {"loss_box_reg": cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT, "loss_clustering": 0.1},
            "clustering_items_per_class" : cfg.OWOD.CLUSTERING.ITEMS_PER_CLASS,
            "clustering_start_iter"     : cfg.OWOD.CLUSTERING.START_ITER,
            "clustering_update_mu_iter" : cfg.OWOD.CLUSTERING.UPDATE_MU_ITER,
            "clustering_momentum"       : cfg.OWOD.CLUSTERING.MOMENTUM,
            "clustering_z_dimension"    : cfg.OWOD.CLUSTERING.Z_DIMENSION,
            "enable_clustering"         : cfg.OWOD.ENABLE_CLUSTERING,
            "prev_intro_cls"            : cfg.OWOD.PREV_INTRODUCED_CLS,
            "curr_intro_cls"            : cfg.OWOD.CUR_INTRODUCED_CLS,
            "max_iterations"            : cfg.SOLVER.MAX_ITER,
            "output_dir"                : cfg.OUTPUT_DIR,
            "feat_store_path"           : cfg.OWOD.FEATURE_STORE_SAVE_PATH,
            "margin"                    : cfg.OWOD.CLUSTERING.MARGIN,
            # fmt: on
        }

    def forward(self, x):
        """
        Args:
            x: per-region features of shape (N, ...) for N bounding boxes to predict.

        Returns:
            (Tensor, Tensor):
            First tensor: shape (N,K+1), scores for each of the N box. Each row contains the
            scores for K object categories and 1 background class.

            Second tensor: bounding box regression deltas for each box. Shape is shape (N,Kx4),
            or (N,4) for class-agnostic regression.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        scores = self.cls_score(x)
        proposal_deltas = self.bbox_pred(x)
        return scores, proposal_deltas

    def update_feature_store(self, features, proposals):
        """
        Add `features` to the feature store under the ground-truth class of the
        matching proposal and, on the main process at the last training iteration,
        save the store to ``self.feature_store_save_loc`` (at most once).
        """
        # cat(..., dim=0) concatenates over all images in the batch
        gt_classes = torch.cat([p.gt_classes for p in proposals])
        self.feature_store.add(features, gt_classes)

        storage = get_event_storage()

        if storage.iter == self.max_iterations-1 and self.feature_store_is_stored is False and comm.is_main_process():
            logging.getLogger(__name__).info('Saving image store at iteration ' + str(storage.iter) + ' to ' + self.feature_store_save_loc)
            # Make sure the target directory exists; otherwise the save would fail at
            # the very end of training and the collected features would be lost.
            save_dir = os.path.dirname(self.feature_store_save_loc)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.feature_store, self.feature_store_save_loc)
            self.feature_store_is_stored = True

        # self.feature_store.add(F.normalize(features, dim=0), gt_classes)
        # self.feature_store.add(self.ae_model.encoder(features), gt_classes)

    def clstr_loss_l2_cdist(self, input_features, proposals):
        """
        Hinge-embedding loss on the distances between foreground features and the
        per-class mean vectors in ``self.means``.

        For every foreground feature the distance to each class mean is computed with
        ``torch.cdist(..., p=self.margin)``. The target is +1 for the feature's own
        ground-truth class and -1 for every other class. Classes that have no mean yet
        are represented by an all-zero vector, purely so that the means can be stacked;
        ``self.means`` itself is never modified here.

        :param input_features: features, one row per proposal (all images concatenated)
        :param proposals: list of Instances providing ``gt_classes``
        :return: scalar loss tensor; zero when there are no foreground features or no
            class mean is available yet.
        """
        gt_classes = torch.cat([p.gt_classes for p in proposals])
        # The label equal to `num_classes` marks background proposals.
        fg_mask = gt_classes != self.num_classes
        fg_features = input_features[fg_mask]
        classes = gt_classes[fg_mask]
        # fg_features = F.normalize(fg_features, dim=0)
        # fg_features = self.ae_model.encoder(fg_features)

        reference = next((mu for mu in self.means if mu is not None), None)
        if reference is None or fg_features.shape[0] == 0:
            # Without any class mean there is nothing to compare against, and a mean
            # reduction over zero rows would be NaN; return a zero that is still
            # attached to the graph and lives on the right device.
            return fg_features.sum() * 0.0

        # Build the stack from a local list so the placeholders never leak into
        # self.means (they would otherwise take part in the momentum update).
        prototypes = torch.stack(
            [torch.zeros_like(reference) if mu is None else mu for mu in self.means]
        ).to(fg_features.device)
        distances = torch.cdist(fg_features, prototypes, p=self.margin)

        # labels[i, c] is +1 where c is the ground-truth class of feature i, else -1.
        class_ids = torch.arange(len(self.means), device=classes.device)
        labels = (classes.unsqueeze(1) == class_ids.unsqueeze(0)).long() * 2 - 1

        return self.hingeloss(distances, labels.to(distances.device))

    def _class_means_from_store(self):
        """Return one mean feature vector per class slot of the feature store (None if empty)."""
        means = [None for _ in range(self.num_classes + 1)]
        for index, item in enumerate(self.feature_store.retrieve(-1)):
            if len(item) != 0:
                means[index] = torch.tensor(item).mean(dim=0)
        return means

    def get_clustering_loss(self, input_features, proposals):
        """
        Return the clustering loss for the current iteration.

        * Clustering disabled, or before ``clustering_start_iter``: returns 0.
        * At ``clustering_start_iter``: class means are initialised from the store.
        * Afterwards: every ``clustering_update_mu_iter`` iterations the means are
          refreshed with a momentum update; the loss is computed on every iteration.
        """
        if not self.enable_clustering:
            return 0

        storage = get_event_storage()
        c_loss = 0
        if storage.iter == self.clustering_start_iter:
            for index, mu in enumerate(self._class_means_from_store()):
                self.means[index] = mu
            c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
            # Freeze the parameters when clustering starts
            # for param in self.ae_model.parameters():
            #     param.requires_grad = False
        elif storage.iter > self.clustering_start_iter:
            if storage.iter % self.clustering_update_mu_iter == 0:
                # Compute new MUs
                new_means = self._class_means_from_store()
                # Update the MUs
                for i, old_mu in enumerate(self.means):
                    new_mu = new_means[i]
                    if new_mu is None:
                        continue
                    if old_mu is None:
                        # First mean ever available for this class: adopt it as is.
                        self.means[i] = new_mu
                    else:
                        self.means[i] = self.clustering_momentum * old_mu + \
                                        (1 - self.clustering_momentum) * new_mu

            c_loss = self.clstr_loss_l2_cdist(input_features, proposals)
        return c_loss

    # def get_ae_loss(self, input_features):
    #     # storage = get_event_storage()
    #     # ae_loss = 0
    #     # if storage.iter < self.clustering_start_iter :
    #     features_hat = self.ae_model(input_features)
    #     ae_loss = F.mse_loss(features_hat, input_features)
    #     return ae_loss

    # TODO: move the implementation to this class.
    def losses(self, predictions, proposals, input_features=None):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were used
                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
                ``gt_classes`` are expected.
            input_features: optional per-proposal features; when given, the entry
                "loss_clustering" is added to the result.

        Returns:
            Dict[str, Tensor]: dict of losses, each multiplied by its entry in
            ``self.loss_weight`` (1.0 when the key is absent).
        """
        scores, proposal_deltas = predictions
        losses = FastRCNNOutputs(
            self.box2box_transform,
            scores,
            proposal_deltas,
            proposals,
            self.invalid_class_range,
            self.smooth_l1_beta,
            self.box_reg_loss_type,
        ).losses()
        if input_features is not None:
            # losses["loss_cluster_encoder"] = self.get_ae_loss(input_features)
            losses["loss_clustering"] = self.get_clustering_loss(input_features, proposals)
        return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}

    def inference(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions. The ``proposal_boxes`` field is expected.

        Returns:
            list[Instances]: same as `fast_rcnn_inference`.
            list[Tensor]: same as `fast_rcnn_inference`.
        """
        boxes = self.predict_boxes(predictions, proposals)
        scores = self.predict_probs(predictions, proposals)
        image_shapes = [x.image_size for x in proposals]
        # `fast_rcnn_inference` iterates its `predictions` argument per image. Passing
        # the raw (scores, deltas) tuple is only right for a single image: with two
        # images the second one receives the delta tensor as its logits, and any
        # further image is dropped. Hand over the logits split per image instead.
        logits, _ = predictions
        logits_per_image = logits.split([len(p) for p in proposals], dim=0)
        return fast_rcnn_inference(
            boxes,
            scores,
            image_shapes,
            logits_per_image,
            self.test_score_thresh,
            self.test_nms_thresh,
            self.test_topk_per_image,
        )

    def predict_boxes_for_gt_classes(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were used
                to compute predictions. The fields ``proposal_boxes``, ``gt_classes`` are expected.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted boxes for GT classes in case of
                class-specific box head. Element i of the list has shape (Ri, B), where Ri is
                the number of proposals for image i and B is the box dimension (4 or 5)
        """
        if not len(proposals):
            return []
        scores, proposal_deltas = predictions
        proposal_boxes = [p.proposal_boxes for p in proposals]
        proposal_boxes = proposal_boxes[0].cat(proposal_boxes).tensor
        N, B = proposal_boxes.shape
        predict_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas, proposal_boxes
        )  # Nx(KxB)

        K = predict_boxes.shape[1] // B
        if K > 1:
            gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0)
            # Some proposals are ignored or have a background class. Their gt_classes
            # cannot be used as index.
            gt_classes = gt_classes.clamp_(0, K - 1)

            predict_boxes = predict_boxes.view(N, K, B)[
                torch.arange(N, dtype=torch.long, device=predict_boxes.device), gt_classes
            ]
        num_prop_per_image = [len(p) for p in proposals]
        return predict_boxes.split(num_prop_per_image)

    def predict_boxes(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions. The ``proposal_boxes`` field is expected.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted class-specific or class-agnostic boxes
                for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is
                the number of proposals for image i and B is the box dimension (4 or 5)
        """
        if not len(proposals):
            return []
        _, proposal_deltas = predictions
        num_prop_per_image = [len(p) for p in proposals]
        proposal_boxes = [p.proposal_boxes for p in proposals]
        proposal_boxes = proposal_boxes[0].cat(proposal_boxes).tensor
        predict_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas, proposal_boxes
        )  # Nx(KxB)
        return predict_boxes.split(num_prop_per_image)

    def predict_probs(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were
                used to compute predictions.

        Returns:
            list[Tensor]:
                A list of Tensors of predicted class probabilities for each image.
                Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i.
        """
        scores, _ = predictions
        num_inst_per_image = [len(p) for p in proposals]
        probs = F.softmax(scores, dim=-1)
        return probs.split(num_inst_per_image, dim=0)
# def clstr_loss(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# # fg_features = self.ae_model.encoder(fg_features)
#
# # Distribution per class
# log_prob = [None for _ in range(self.num_classes + 1)]
# # https://github.com/pytorch/pytorch/issues/23780
# for cls_index, mu in enumerate(self.means):
# if mu is not None:
# dist = Normal(loc=mu.cuda(), scale=torch.ones_like(mu.cuda()))
# log_prob[cls_index] = dist.log_prob(fg_features).mean(dim=1)
# # log_prob[cls_index] = torch.distributions.multivariate_normal. \
# # MultivariateNormal(mu.cuda(), torch.eye(len(mu)).cuda()).log_prob(fg_features)
# # MultivariateNormal(mu, torch.eye(len(mu))).log_prob(fg_features.cpu())
# # MultivariateNormal(mu[:2], torch.eye(len(mu[:2]))).log_prob(fg_features[:,:2].cpu())
# else:
# log_prob[cls_index] = torch.zeros((len(fg_features))).cuda()
#
# log_prob = torch.stack(log_prob).T # num_of_fg_proposals x num_of_classes
# for i, p in enumerate(log_prob):
# weight = torch.ones_like(p) * -1
# weight[classes[i]] = 1
# p = p * weight
# loss += p.mean()
# return loss
# def clstr_loss_l2(self, input_features, proposals):
# """
# Get the foreground input_features, generate distributions for the class,
# get probability of each feature from each distribution;
# Compute loss: if belonging to a class -> likelihood should be higher
# else -> lower
# :param input_features:
# :param proposals:
# :return:
# """
# loss = 0
# gt_classes = torch.cat([p.gt_classes for p in proposals])
# mask = gt_classes != self.num_classes
# fg_features = input_features[mask]
# classes = gt_classes[mask]
# fg_features = self.ae_model.encoder(fg_features)
#
# for index, feature in enumerate(fg_features):
# for cls_index, mu in enumerate(self.means):
# if mu is not None and feature is not None:
# mu = mu.cuda()
# if classes[index] == cls_index:
# loss -= F.mse_loss(feature, mu)
# else:
# loss += F.mse_loss(feature, mu)
#
# return loss