In [1]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Detection Training Script.

This script reads a given config file and runs training or evaluation.
It is an entry point that is made to train standard models in detectron2.

In order to let one script support training of many models,
this script contains logic that is specific to these built-in models and therefore
may not be suitable for your own project.
For example, your research project perhaps only needs a single "evaluator".

Therefore, we recommend that you use detectron2 as a library and take
this file as an example of how to use the library.
You may want to write your own script with your datasets and other customizations.
"""

import logging
import os
from collections import OrderedDict
from typing import Any, Dict, List, Set
import torch
import itertools

import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.utils.events import EventStorage
from detectron2.evaluation import (
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.utils.logger import setup_logger

from adet.data.dataset_mapper import DatasetMapperWithBasis
from adet.config import get_cfg
from adet.checkpoint import AdetCheckpointer
from adet.evaluation import TextEvaluator, TextDetEvaluator


class Trainer(DefaultTrainer):
    """
    `DefaultTrainer` subclass used by this script.

    It overrides hook construction and checkpoint loading so that
    `AdetCheckpointer` is used, supplies its own training loop, train loader,
    evaluator selection and optimizer construction, and adds an evaluation
    helper that wraps the model in `GeneralizedRCNNWithTTA`.
    """

    def build_hooks(self):
        """
        Replace `DetectionCheckpointer` with `AdetCheckpointer`.

        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list: the hooks returned by the parent class, where every
            `hooks.PeriodicCheckpointer` has been swapped for one that is
            backed by a freshly created `AdetCheckpointer` (also stored on
            `self.checkpointer`).
        """
        ret = super().build_hooks()
        for i in range(len(ret)):
            if isinstance(ret[i], hooks.PeriodicCheckpointer):
                self.checkpointer = AdetCheckpointer(
                    self.model,
                    self.cfg.OUTPUT_DIR,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                )
                ret[i] = hooks.PeriodicCheckpointer(
                    self.checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
        return ret

    def resume_or_load(self, resume=True):
        """
        Load weights through `self.checkpointer` and update `self.start_iter`.

        Args:
            resume (bool): forwarded to `self.checkpointer.resume_or_load`
                together with `cfg.MODEL.WEIGHTS`. When it is true and the
                checkpointer reports an existing checkpoint, training restarts
                at the iteration following the stored "iteration" entry
                (iteration 0 if the entry is missing).
        """
        checkpoint = self.checkpointer.resume_or_load(
            self.cfg.MODEL.WEIGHTS, resume=resume)
        if resume and self.checkpointer.has_checkpoint():
            self.start_iter = checkpoint.get("iteration", -1) + 1

    def train_loop(self, start_iter: int, max_iter: int):
        """
        Run the hook-driven loop for iterations `[start_iter, max_iter)`.

        Args:
            start_iter (int): first iteration to execute.
            max_iter (int): iteration at which the loop stops (exclusive).

        Raises:
            Exception: anything raised by a hook or by `run_step` is logged
                and re-raised unchanged. `after_train` is still called in that
                case so that hooks get a chance to flush and clean up.
        """
        logger = logging.getLogger("adet.trainer")
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                # Always give hooks their `after_train` callback, even when
                # the loop above was interrupted by an exception.
                self.after_train()

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
            The results are only verified and returned on the main process.
        """
        self.train_loop(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable

        It calls :func:`detectron2.data.build_detection_train_loader` with a
        `DatasetMapperWithBasis` built in training mode (`is_train=True`).
        """
        mapper = DatasetMapperWithBasis(cfg, True)
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.

        Args:
            cfg: the config node.
            dataset_name (str): name looked up in `MetadataCatalog`.
            output_folder (str or None): where evaluators write their output;
                defaults to `<cfg.OUTPUT_DIR>/inference`.

        Returns:
            A single evaluator, or a `DatasetEvaluators` wrapping several.

        Raises:
            NotImplementedError: if no evaluator matches the dataset's
                "evaluator_type".
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                    ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                    output_dir=output_folder,
                )
            )
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            evaluator_list.append(COCOEvaluator(
                dataset_name, cfg, True, output_folder))
        if evaluator_type == "coco_panoptic_seg":
            evaluator_list.append(COCOPanopticEvaluator(
                dataset_name, output_folder))
        # The following types map to exactly one evaluator and return directly.
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        if evaluator_type == "text":
            if cfg.TEST.DET_ONLY:
                return TextDetEvaluator(dataset_name, cfg, True, output_folder)
            else:
                return TextEvaluator(dataset_name, cfg, True, output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        if len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate `model` wrapped in `GeneralizedRCNNWithTTA` on `cfg.DATASETS.TEST`.

        Evaluator output goes to `<cfg.OUTPUT_DIR>/inference_TTA`.

        Returns:
            OrderedDict: the results of `cls.test`, with "_TTA" appended to
            every key.
        """
        logger = logging.getLogger("adet.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(
                    cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build the optimizer with one parameter group per trainable parameter.

        Learning rate per parameter:
            * `cfg.SOLVER.LR_BACKBONE` if the parameter name contains any of
              `cfg.SOLVER.LR_BACKBONE_NAMES`;
            * `cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_LINEAR_PROJ_MULT` if it
              contains any of `cfg.SOLVER.LR_LINEAR_PROJ_NAMES`;
            * `cfg.SOLVER.BASE_LR` otherwise.

        Returns:
            The "SGD" or "ADAMW" optimizer selected by `cfg.SOLVER.OPTIMIZER`,
            with gradient clipping applied according to
            `cfg.SOLVER.CLIP_GRADIENTS`.

        Raises:
            NotImplementedError: for any other `cfg.SOLVER.OPTIMIZER` value.
        """
        def match_name_keywords(n, name_keywords):
            # True when any keyword is a substring of the parameter name.
            return any(keyword in n for keyword in name_keywords)

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for key, value in model.named_parameters(recurse=True):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)
            lr = cfg.SOLVER.BASE_LR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY

            if match_name_keywords(key, cfg.SOLVER.LR_BACKBONE_NAMES):
                lr = cfg.SOLVER.LR_BACKBONE
            elif match_name_keywords(key, cfg.SOLVER.LR_LINEAR_PROJ_NAMES):
                lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.LR_LINEAR_PROJ_MULT

            params.append({"params": [value], "lr": lr,
                           "weight_decay": weight_decay})

        # optim: the optimizer class
        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    # Clip the norm over all parameters of all groups at once
                    # before delegating to the wrapped optimizer's step.
                    all_params = itertools.chain(
                        *[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        # Any clip type other than "full_model" is delegated to detectron2.
        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer


def setup(args):
    """
    Build the frozen config for this run and initialise logging.

    The config starts from `get_cfg()`, is overridden first by the file at
    `args.config_file` and then by the `args.opts` list, and is frozen before
    `default_setup` runs. An "adet" logger is then set up for the current rank,
    writing to `cfg.OUTPUT_DIR`.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)

    setup_logger(
        config.OUTPUT_DIR,
        distributed_rank=comm.get_rank(),
        name="adet",
    )
    return config


def main(args):
    """
    Entry point: either evaluate a checkpoint or run training.

    Args:
        args: parsed arguments; `eval_only` and `resume` are read here, the
            rest is consumed by `setup`.

    Returns:
        The evaluation results when `args.eval_only` is set, otherwise
        whatever `Trainer.train()` returns.
    """
    cfg = setup(args)

    if not args.eval_only:
        # If you'd like to do anything fancier than the standard training logic,
        # consider writing your own training loop or subclassing the trainer.
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            tta_hook = hooks.EvalHook(
                0, lambda: trainer.test_with_TTA(cfg, trainer.model))
            trainer.register_hooks([tta_hook])
        return trainer.train()

    # Evaluation-only path: build the model, load weights, run the test sets.
    model = Trainer.build_model(cfg)
    checkpointer = AdetCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    res = Trainer.test(cfg, model)  # d2 defaults.py
    if comm.is_main_process():
        verify_results(cfg, res)
    if cfg.TEST.AUG.ENABLED:
        res.update(Trainer.test_with_TTA(cfg, model))
    return res


# Notebook stand-in for the script's `if __name__ == "__main__":` section:
# the arguments are hard-coded and the trainer is built in-process.
import argparse

args = argparse.Namespace(
    config_file='configs/DPText_DETR/CTW1500/R_50_poly.yaml',
    dist_url='tcp://127.0.0.1:49152',
    eval_only=False,
    machine_rank=0,
    num_gpus=1,
    num_machines=1,
    opts=[],
    resume=False,
)
cfg = setup(args)
trainer = Trainer(cfg)

# Pull one batch from the training data loader.
data = next(iter(trainer.data_loader))

# Shortcuts to the sub-modules that the following cells step through.
model = trainer.model
dptext = model.dptext_detr
transformer = dptext.transformer
decoder = transformer.decoder
layer = decoder.layers[0]
deform_attn = layer.attn_cross

# Disabled parts of the original entry point (checkpoint loading and the
# multi-process launcher):
# trainer.resume_or_load(resume=args.resume)
# print("Command Line Args:", args)
# launch(
#     main,
#     args.num_gpus,
#     num_machines=args.num_machines,
#     machine_rank=args.machine_rank,
#     dist_url=args.dist_url,
#     args=(args,),
# )


Config 'configs/DPText_DETR/CTW1500/R_50_poly.yaml' has no VERSION. Assuming it to be compatible with latest v2.


[32m[04/23 15:20:20 detectron2]: [0mRank of current process: 0. World size: 1
[32m[04/23 15:20:21 detectron2]: [0mEnvironment info:
----------------------  ---------------------------------------------------------------------------------------
sys.platform            linux
Python                  3.8.16 (default, Mar  2 2023, 03:21:46) [GCC 11.2.0]
numpy                   1.24.2
detectron2              0.6 @/root/miniconda3/envs/DPText-DETR/lib/python3.8/site-packages/detectron2
Compiler                GCC 7.3
CUDA compiler           CUDA 11.1
detectron2 arch flags   3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.9.1+cu111 @/root/miniconda3/envs/DPText-DETR/lib/python3.8/site-packages/torch
PyTorch debug build     False
GPU available           Yes
GPU 0,1                 NVIDIA GeForce RTX 3090 (arch=8.6)
Driver version          510.73.08
CUDA_HOME               /usr/local/cuda
Pillow                  9.5.0
torchvision       

In [2]:
# Put the model in training mode. `model.train()` also updates every
# sub-module, whereas assigning `model.training = True` only flips the flag on
# the top-level module.
model.train()

# Work on the first sample of the batch only.
batched_inputs = data[:1]
print(batched_inputs[0].keys())
tmp = batched_inputs[0]['instances']
print(tmp.get_fields().keys())

# Image preprocessing as done by the model itself.
images = model.preprocess_image(batched_inputs)
print(images.image_sizes)
print(images.tensor.shape)

# Ground-truth instances moved to the model device, then converted to targets.
gt_instances = [x["instances"].to(model.device) for x in batched_inputs]
print(gt_instances[0].get_fields().keys())
targets = model.prepare_targets(gt_instances)
print(targets[0].keys())
print(targets[0]['ctrl_points'].shape)

# Full forward pass of the detector, for reference.
output = model.dptext_detr(images)
print(output.keys())

dict_keys(['file_name', 'height', 'width', 'image_id', 'image', 'instances'])
dict_keys(['gt_boxes', 'gt_classes', 'text', 'polygons'])
[(840, 800)]
torch.Size([1, 3, 840, 800])
dict_keys(['gt_boxes', 'gt_classes', 'text', 'polygons'])
dict_keys(['labels', 'boxes', 'ctrl_points'])
torch.Size([7, 16, 2])
dict_keys(['pred_logits', 'pred_ctrl_points', 'aux_outputs', 'enc_outputs'])


In [3]:
from adet.utils.misc import NestedTensor
import torch.nn.functional as F
# Step-by-step replay of the first part of the `dptext` forward pass: backbone
# features, per-level input projection, and the control point query embedding.
features, pos = dptext.backbone(images)

if dptext.num_feature_levels == 1:
    raise NotImplementedError

# Project each backbone feature map with its own `input_proj` module and keep
# the matching mask.
srcs = []
masks = []
for l, feat in enumerate(features):
    src, mask = feat.decompose()
    srcs.append(dptext.input_proj[l](src))
    masks.append(mask)
    assert mask is not None
# Build the extra levels when more levels are configured than the backbone
# provides: the first extra level is projected from the last backbone feature,
# each later one from the previously appended level.
if dptext.num_feature_levels > len(srcs):
    _len_srcs = len(srcs)
    for l in range(_len_srcs, dptext.num_feature_levels):
        if l == _len_srcs:
            src = dptext.input_proj[l](features[-1].tensors)
        else:
            src = dptext.input_proj[l](srcs[-1])
        # The mask of the new level is the first-level mask resized to the new
        # spatial size.
        m = masks[0]
        mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
        # `dptext.backbone[1]` is presumably the position-encoding module — TODO confirm.
        pos_l = dptext.backbone[1](NestedTensor(src, mask)).to(src.dtype)
        srcs.append(src)
        masks.append(mask)
        # NOTE: appends to the `pos` list returned by the backbone (in-place).
        pos.append(pos_l)

# n_pts, embed_dim --> n_q, n_pts, embed_dim
ctrl_point_embed = dptext.ctrl_point_embed.weight[None, ...].repeat(dptext.num_proposals, 1, 1)
# The single transformer call below is what the following cells unroll by hand.
# hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = dptext.transformer(
#     srcs, masks, pos, ctrl_point_embed
# )
# Rebind to the names used for the four arguments of the transformer call above.
srcs, masks, pos_embeds, query_embed=srcs, masks, pos, ctrl_point_embed


In [4]:
# Flatten every feature level to a sequence and concatenate the levels, keeping
# track of each level's spatial size so the sequence can be indexed per level.
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
    # print(lvl)
    bs, c, h, w = src.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    # (bs, c, h, w) -> (bs, h*w, c)
    src = src.flatten(2).transpose(1, 2)
    # (bs, h, w) -> (bs, h*w)
    mask = mask.flatten(1)
    pos_embed = pos_embed.flatten(2).transpose(1, 2)
    # Add the per-level embedding (broadcast over batch and positions).
    lvl_pos_embed = pos_embed + transformer.level_embed[lvl].view(1, 1, -1)
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    src_flatten.append(src)
    mask_flatten.append(mask)
# Concatenate all levels along the sequence dimension.
src_flatten = torch.cat(src_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
# Offset of each level in the concatenated sequence: 0 followed by the running
# sum of h*w over the preceding levels.
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
# One `get_valid_ratio` result per level, stacked along dim 1.
valid_ratios = torch.stack([transformer.get_valid_ratio(m) for m in masks], 1)

In [5]:
# Encoder pass over the flattened multi-level features.
memory = transformer.encoder(
    src_flatten,
    spatial_shapes,
    level_start_index,
    valid_ratios,
    lvl_pos_embed_flatten,
    mask_flatten
)
bs, _, c = memory.shape

# Score every encoder position and predict a box (unactivated coordinates) for it.
output_memory, output_proposals = transformer.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
enc_outputs_class = transformer.bbox_class_embed(output_memory)
enc_outputs_coord_unact = transformer.bbox_embed(output_memory) + output_proposals

# Keep the `num_proposals` positions with the highest class-0 score and gather
# their 4 box coordinates; the coordinates are detached from the graph.
topk = transformer.num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()  # (bs, nq, 4)
reference_points = transformer.init_control_points_from_anchor(reference_points)
init_reference_out = reference_points

# learnable control point content queries
# Derived from `ctrl_point_embed` (the un-batched embedding) instead of from
# `query_embed` itself, so that re-running this cell does not stack another
# batch dimension on top of the previous result.
query_embed = ctrl_point_embed.unsqueeze(0).expand(bs, -1, -1, -1)

# Full decoder pass, kept for reference; the next cells unroll its first layer.
hs, inter_references = transformer.decoder(
    query_embed,
    reference_points,
    memory,
    spatial_shapes,
    level_start_index,
    valid_ratios,
    query_pos=None,
    src_padding_mask=mask_flatten
)

In [6]:
from adet.modeling.dptext_detr.utils import MLP, gen_point_pos_embed

# Bind the tensors from the previous cells to the argument names used inside
# the decoder, so the following cells can mirror its body line by line.
tgt = query_embed
src = memory
src_spatial_shapes = spatial_shapes
src_level_start_index = level_start_index
src_valid_ratios = valid_ratios
src_padding_mask = mask_flatten

output = tgt
# Scale the control points by the per-level valid ratios. The points are read
# from `init_reference_out` (same tensor as `reference_points` right after the
# previous cell) because `reference_points` is rebound at the end of this
# cell; reading it here would make the cell fail when run a second time.
reference_points_input = init_reference_out[:, :, :, None] * src_valid_ratios[:, None, None]
# Positional embedding of the points at level index 0.
query_pos = gen_point_pos_embed(reference_points_input[:, :, :, 0, :])
# get the positional queries
query_pos = decoder.ref_point_head(query_pos)  # projection

# The layer call that the next cell unrolls:
# output = layer(
#     output,
#     query_pos,
#     reference_points_input,
#     src,
#     src_spatial_shapes,
#     src_level_start_index,
#     src_padding_mask
# )

# Bind the layer's argument names. `query_pos`, `src`, `src_spatial_shapes`,
# `level_start_index` and `src_padding_mask` already hold the right values.
tgt = output
reference_points = reference_points_input

In [7]:
# Step-by-step replay of one decoder layer (`layer`). Stage descriptions are
# inferred from the sub-module names — TODO confirm against the layer's forward().

# --- attention via `attn_intra` ---
# q and k come from `with_pos_embed(tgt, query_pos)`; the value is `tgt` itself.
# The first two dims are merged and the result transposed before the call, then
# the output is restored to the shape of `q`.
shortcut = tgt
q = k = layer.with_pos_embed(tgt, query_pos)
tgt = layer.attn_intra(
    q.flatten(0, 1).transpose(0, 1),
    k.flatten(0, 1).transpose(0, 1),
    tgt.flatten(0, 1).transpose(0, 1),
)[0].transpose(0, 1).reshape(q.shape)
# Parallel `circonv` branch on the layer input plus positional queries.
tgt_circonv = layer.drop_path(layer.circonv(shortcut+query_pos))
# Residual connection around the normalised sum of both branches.
tgt = shortcut + layer.norm_intra(layer.drop_path(tgt) + tgt_circonv)
tgt = tgt + layer.drop_path(layer.norm_fuse(layer.mlp_fuse(tgt)))

# --- attention via `attn_inter` ---
# Dims 1 and 2 are swapped so the attention runs along the other axis; no
# positional queries are added here.
q_inter = k_inter = tgt_inter = torch.swapdims(tgt, 1, 2)  # (bs, n_pts, n_q, dim)
tgt2_inter = layer.attn_inter(
    q_inter.flatten(0, 1).transpose(0, 1),
    k_inter.flatten(0, 1).transpose(0, 1),
    tgt_inter.flatten(0, 1).transpose(0, 1)
)[0].transpose(0, 1).reshape(q_inter.shape)
tgt_inter = tgt_inter + layer.dropout_inter(tgt2_inter)
# Normalise, then swap dims 1 and 2 back.
tgt_inter = torch.swapdims(layer.norm_inter(tgt_inter), 1, 2)

# --- cross attention via `attn_cross` against the encoder memory (`src`) ---
# Queries and reference points have dims 1 and 2 merged for the call; the
# output is reshaped back to the shape of `tgt_inter`.
reference_points_loc = reference_points
tgt2 = layer.attn_cross(
    layer.with_pos_embed(tgt_inter, query_pos).flatten(1, 2),
    reference_points_loc.flatten(1, 2),
    src,
    src_spatial_shapes,
    level_start_index,
    src_padding_mask
).reshape(tgt_inter.shape)
tgt_inter = tgt_inter + layer.dropout_cross(tgt2)
tgt = layer.norm_cross(tgt_inter)

# ffn
tgt = layer.forward_ffn(tgt)

In [8]:
# Repeat the cross-attention call from the previous cell on its own, only to
# display the shape of the raw output (before it is reshaped to `tgt_inter.shape`).
layer.attn_cross(
    layer.with_pos_embed(tgt_inter, query_pos).flatten(1, 2),
    reference_points_loc.flatten(1, 2),
    src,
    src_spatial_shapes,
    level_start_index,
    src_padding_mask
).shape

torch.Size([1, 1600, 256])

In [9]:
# Manual replay of the deformable cross-attention call, using the sub-modules
# of `deform_attn`, up to the computation of the sampling locations.
query = layer.with_pos_embed(tgt_inter, query_pos).flatten(1, 2)
reference_points = reference_points_loc.flatten(1, 2)
input_flatten = src
input_spatial_shapes = src_spatial_shapes
input_level_start_index = level_start_index
input_padding_mask = src_padding_mask

N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
# The flattened input must contain exactly sum(h * w) positions over all levels.
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

# Project the values, zero out the masked positions, and split into heads.
value = deform_attn.value_proj(input_flatten)
if input_padding_mask is not None:
    value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, deform_attn.n_heads, deform_attn.d_model // deform_attn.n_heads)

# Per-query sampling offsets and attention weights; the softmax is taken
# jointly over all levels and points of a head.
sampling_offsets = deform_attn.sampling_offsets(query).view(
    N, Len_q, deform_attn.n_heads, deform_attn.n_levels, deform_attn.n_points, 2)
attention_weights = deform_attn.attention_weights(query).view(
    N, Len_q, deform_attn.n_heads, deform_attn.n_levels * deform_attn.n_points)
attention_weights = F.softmax(attention_weights, -1).view(
    N, Len_q, deform_attn.n_heads, deform_attn.n_levels, deform_attn.n_points)

# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
    # Point references: offsets are divided by the (w, h) of each level.
    offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
    sampling_locations = reference_points[:, :, None, :, None, :] \
                            + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    # Four-value references: the first two values are the base location, the
    # last two scale the offsets (halved and divided by the number of points).
    sampling_locations = reference_points[:, :, None, :, None, :2] \
                            + sampling_offsets / deform_attn.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
    # Without this branch an unsupported layout would leave `sampling_locations`
    # undefined, or silently stale from an earlier run of this cell.
    raise ValueError(
        "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
            reference_points.shape[-1]))

In [1]:
import math
import torch
def get_proposal_pos_embed(proposals, num_pos_feats=64, temperature=10000):
    """
    Sinusoidal positional embedding of proposal coordinates.

    Args:
        proposals (Tensor): tensor of shape (N, L, C) holding unactivated
            coordinates; a sigmoid is applied before embedding.
        num_pos_feats (int): number of embedding channels per coordinate.
            Must be positive and even. Defaults to 64, which yields 256
            channels for 4 coordinates.
        temperature (float): base of the geometric frequency progression.
            Defaults to 10000.

    Returns:
        Tensor: shape (N, L, C * num_pos_feats). For every coordinate the
        channels alternate between sine and cosine terms.

    Raises:
        ValueError: if `num_pos_feats` is not a positive even number.
    """
    if num_pos_feats <= 0 or num_pos_feats % 2 != 0:
        # The sin/cos halves are interleaved pairwise, so an odd count cannot work.
        raise ValueError(
            "num_pos_feats must be a positive even number, got {}".format(num_pos_feats))

    scale = 2 * math.pi

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
    # Channels 2i and 2i+1 share the same frequency temperature ** (2i / num_pos_feats).
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / num_pos_feats)
    # N, L, C
    proposals = proposals.sigmoid() * scale
    # N, L, C, num_pos_feats
    pos = proposals[:, :, :, None] / dim_t
    # N, L, C * num_pos_feats
    pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
    return pos

In [4]:
# Sanity check: 4 coordinates with 64 channels each should give a last dimension of 256.
get_proposal_pos_embed(torch.rand([8, 100, 4])).shape

torch.Size([8, 100, 256])

In [18]:
# Show the sub-module of `deform_attn` that produces the attention logits.
deform_attn.attention_weights

Linear(in_features=256, out_features=128, bias=True)

In [11]:
        # output = _MSDeformAttnFunction.apply(
        #     value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        # output = self.output_proj(output)

In [12]:
# Value tensor after splitting into heads: (N, Len_in, n_heads, d_model // n_heads).
value.shape

torch.Size([1, 14007, 8, 32])

In [13]:
# Per-level feature map sizes; each row is (h, w) as collected when the levels were flattened.
input_spatial_shapes

tensor([[105, 100],
        [ 53,  50],
        [ 27,  25],
        [ 14,  13]], device='cuda:0')

In [14]:
# Expected layout: (N, Len_q, n_heads, n_levels, n_points, 2).
sampling_locations.shape

torch.Size([1, 1600, 8, 4, 4, 2])

In [15]:
# Expected layout after the softmax and view: (N, Len_q, n_heads, n_levels, n_points).
attention_weights.shape

torch.Size([1, 1600, 8, 4, 4])

In [16]:
# Inspect the `im2col_step` attribute of the deformable attention module.
deform_attn.im2col_step

64

In [None]:

        # Reference notes only — kept as comments because the snippets below are
        # excerpts from inside the model (they use `self` and come from
        # different indentation levels) and cannot run as a notebook cell.
        #
        # Decoder call with the argument shapes noted for a single control point:
        #
        # hs, inter_references = self.decoder(
        #     query_embed,        # (bs, nq, 1, 256)
        #     reference_points,   # (bs, nq, 1, 2)
        #     memory,             # (bs, map_size, 256)
        #     spatial_shapes,
        #     level_start_index,
        #     valid_ratios,
        #     src_padding_mask=mask_flatten
        # )
        #
        # Per-layer call made inside the decoder:
        #
        #     output = layer(
        #         output,                  # (bs, n_q, 1, 256)
        #         query_pos,               # (bs, nq, 1, dim)
        #         reference_points_input,  # (bs, nq, 1, 4, 2)
        #         src,                     # (bs, map_size, 256)
        #         src_spatial_shapes,
        #         src_level_start_index,
        #         src_padding_mask
        #     )

In [9]:
def gen_point_pos_embed(pts_tensor):
    """
    Sinusoidal positional embedding for 2-D points.

    Args:
        pts_tensor (Tensor): points of shape (bs, nq, n_pts, 2); index 0 of the
            last dimension is x, index 1 is y.

    Returns:
        Tensor: shape (bs, nq, n_pts, 256). The first 128 channels encode x and
        the last 128 encode y, each alternating sine and cosine terms.
    """
    two_pi = 2 * math.pi
    channel_idx = torch.arange(128, dtype=torch.float32, device=pts_tensor.device)
    # Channels 2i and 2i+1 share the frequency 10000 ** (2i / 128).
    freqs = 10000 ** (2 * torch.div(channel_idx, 2, rounding_mode='trunc') / 128)

    def _encode(coord):
        # coord: (bs, nq, n_pts) -> (bs, nq, n_pts, 128) of interleaved sin/cos
        angles = (coord * two_pi)[:, :, :, None] / freqs
        sin_part = angles[:, :, :, 0::2].sin()
        cos_part = angles[:, :, :, 1::2].cos()
        return torch.stack((sin_part, cos_part), dim=4).flatten(3)

    x_code = _encode(pts_tensor[:, :, :, 0])
    y_code = _encode(pts_tensor[:, :, :, 1])
    return torch.cat((x_code, y_code), dim=-1)

# Try the embedding on random points shaped (bs=8, nq=100, n_pts=1, 2).
tmp=gen_point_pos_embed(torch.rand([8, 100, 1, 2]))

In [10]:
# Expect (8, 100, 1, 256): 128 channels for x followed by 128 channels for y.
tmp.shape

torch.Size([8, 100, 1, 256])

In [8]:
# Display the raw embedding values.
tmp

tensor([[[[ 1.3424e-01,  9.9095e-01,  1.1633e-01,  ...,  1.0000e+00,
            2.7111e-04,  1.0000e+00],
          [ 2.2331e-02,  9.9975e-01,  1.9338e-02,  ...,  1.0000e+00,
            1.7975e-04,  1.0000e+00]],

         [[ 9.4326e-01,  3.3206e-01,  8.7582e-01,  ...,  1.0000e+00,
            2.1771e-04,  1.0000e+00],
          [ 1.9977e-01, -9.7984e-01,  5.6072e-01,  ...,  1.0000e+00,
            1.4105e-04,  1.0000e+00]],

         [[-2.6215e-01,  9.6503e-01, -8.7809e-01,  ...,  1.0000e+00,
            5.8233e-04,  1.0000e+00],
          [-4.0687e-01,  9.1349e-01, -9.3385e-01,  ...,  1.0000e+00,
            1.5532e-04,  1.0000e+00]],

         ...,

         [[-9.6692e-01,  2.5508e-01, -9.1781e-01,  ...,  1.0000e+00,
            1.3902e-04,  1.0000e+00],
          [ 2.0018e-01,  9.7976e-01,  1.7364e-01,  ...,  1.0000e+00,
            4.6905e-04,  1.0000e+00]],

         [[ 7.5101e-01, -6.6029e-01,  9.1552e-01,  ...,  1.0000e+00,
            4.0957e-04,  1.0000e+00],
          [-7.