Inf in attached conv layers #8

Closed
chyohoo opened this issue Jan 26, 2022 · 9 comments

@chyohoo

chyohoo commented Jan 26, 2022

Hi, I have recently been using SST to train on the nuScenes dataset. Everything worked fine at first, but after several epochs I got NaN in bbox_loss and dir_loss. I found that the NaN is caused by inf outputs from the attached conv layers at the end of SSTv1: the output of recover_bev is fp32, while the intermediate feature maps from conv2d in attached_conv are fp16, and as training goes on their values become inf.
I printed the conv weights, and they look normal.
I tried clamping the inf values in the feature map, but inf values then appear in the following layers.
What could be wrong?
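
For readers hitting the same symptom, here is a minimal sketch of why normal-looking weights can still produce inf in fp16. fp16 cannot represent magnitudes above 65504, so a conv output that sums many moderately large products overflows once it is stored in half precision. The helper names (`to_fp16`, `conv_output_fp16`) and the activation and weight magnitudes are illustrative assumptions, not values from this training run; the fan-in of 128 × 3 × 3 corresponds to a 3×3 conv with 128 input channels.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite fp16 value


def to_fp16(x: float) -> float:
    """Round a Python float to fp16; out-of-range values become +/-inf."""
    try:
        # struct's 'e' format is IEEE 754 half precision.
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def conv_output_fp16(activation: float, weight: float, fan_in: int) -> float:
    """Worst case for one output pixel: every product has the same sign.

    The sum is accumulated in higher precision and only the result is
    stored as fp16 (an assumption about how the layer runs).
    """
    return to_fp16(activation * weight * fan_in)


fan_in = 128 * 3 * 3  # in_channels * kernel_h * kernel_w

small = conv_output_fp16(activation=2.0, weight=0.05, fan_in=fan_in)
large = conv_output_fp16(activation=300.0, weight=0.2, fan_in=fan_in)

print(small)  # finite
print(large)  # inf: 300 * 0.2 * 1152 = 69120 exceeds FP16_MAX
```

Clamping one feature map to the fp16 range keeps that tensor finite but leaves its values close to 65504, so the next conv is likely to overflow again, which would match the behaviour described above.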

@happynear
Collaborator

It seems to be a float precision problem. Maybe you can try to run the training in FP32 mode?
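
As a rough sketch of what FP32 mode usually means for an MMDetection-style config of this generation: mixed precision is typically enabled by a top-level `fp16` entry, and removing or commenting out that entry falls back to FP32. The `loss_scale` value below is a placeholder for illustration, not necessarily what the SST configs use.

```python
# Mixed precision on (placeholder loss_scale, an assumption for illustration):
# fp16 = dict(loss_scale=32.0)

# With the line above commented out, the rest of the optimizer settings
# can stay unchanged:
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
```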

@chyohoo
Author

chyohoo commented Jan 27, 2022

> It seems to be a float precision problem. Maybe you can try to run the training in FP32 mode?

Thanks for the advice. I am trying to disable fp16 training by commenting out this line in the config file, but I got:

Traceback (most recent call last):
  File "./tools/train.py", line 235, in <module>
    main()
  File "./tools/train.py", line 232, in main
    meta=meta)
  File "/home/chen/WS_SST/SST/mmdet3d/apis/train.py", line 34, in train_model
    meta=meta)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmdet/apis/train.py", line 181, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/home/che/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 237, in train_step
    losses = self(**data)
  File "/home/chene/anaconda3/envs/SST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/base.py", line 58, in forward
    return self.forward_train(**kwargs)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/voxelnet.py", line 90, in forward_train
    x = self.extract_feat(points, img_metas)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/dynamic_voxelnet.py", line 50, in extract_feat
    x = self.middle_encoder(voxel_features, feature_coors, batch_size) # SSTInputLayer
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
TypeError: forward() takes 3 positional arguments but 4 were given

I then removed the auto_fp16 decorators from SSTInputLayer and Base3DDetector, but I still get a traceback ending in the same TypeError:

Traceback (most recent call last):
  File "./tools/train.py", line 235, in <module>
    main()
  File "./tools/train.py", line 232, in main
    meta=meta)
  File "/home/chen/WS_SST/SST/mmdet3d/apis/train.py", line 34, in train_model
    meta=meta)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmdet/apis/train.py", line 181, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 237, in train_step
    losses = self(**data)
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/base.py", line 58, in forward
    return self.forward_train(**kwargs)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/voxelnet.py", line 90, in forward_train
    x = self.extract_feat(points, img_metas)
  File "/home/chen/WS_SST/SST/mmdet3d/models/detectors/dynamic_voxelnet.py", line 50, in extract_feat
    x = self.middle_encoder(voxel_features, feature_coors, batch_size) # SSTInputLayer
  File "/home/chen/anaconda3/envs/SST/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() takes 3 positional arguments but 4 were given

Did I miss something?
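
A minimal reproduction of this kind of error, for reference. Both tracebacks fail at the same call, `self.middle_encoder(voxel_features, feature_coors, batch_size)`, and the fp16 wrapper in the first one simply forwards `*args, **kwargs`, so removing the decorator cannot change how many arguments reach `forward`. The class below is a hypothetical stand-in, not the real SSTInputLayer; it only assumes a `forward` that accepts two inputs while the caller passes three.

```python
class InputLayerSketch:
    """Hypothetical stand-in for a layer whose forward() accepts two inputs."""

    def forward(self, voxel_feats, voxel_coors):
        return voxel_feats, voxel_coors

    def __call__(self, *args, **kwargs):
        # Like torch.nn.Module, pass every argument straight to forward().
        return self.forward(*args, **kwargs)


layer = InputLayerSketch()

message = ""
try:
    # The caller passes a third positional argument (batch_size).
    layer([0.1, 0.2], [(0, 0, 0), (0, 0, 1)], 1)
except TypeError as exc:
    message = str(exc)

# Counting self, forward() accepts 3 positional arguments and receives 4.
print(message)
```

The fix therefore has to make the caller and the layer agree on the signature, either by changing the call site or by using a layer whose `forward` accepts the extra argument.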

@Abyssaledge
Collaborator

Could you show me your model config?

@chyohoo
Author

chyohoo commented Jan 27, 2022

model = dict(
    type='DynamicVoxelNet',
    neck=dict(
        type='SECONDFPN',
        norm_cfg=dict(type='naiveSyncBN2d', eps=0.001, momentum=0.01),
        in_channels=[128],
        upsample_strides=[1],
        out_channels=[384]),
    bbox_head=dict(
        type='Anchor3DHead',
        num_classes=3,
        in_channels=384,
        feat_channels=384,
        use_direction_classifier=True,
        anchor_generator=dict(
            type='AlignedAnchor3DRangeGenerator',
            ranges=[[-49.6, 0, -1.80032795, 49.6, 49.6, -1.80032795],
                    [-49.6, 0, -1.74440365, 49.6, 49.6, -1.74440365],
                    [-49.6, 0, -1.68526504, 49.6, 49.6, -1.68526504]],
            sizes=[[1.95017717, 4.60718145, 1.72270761],
                   [2.4560939, 6.73778078, 2.73004906],
                   [2.87427237, 12.01320693, 3.81509561]],
            rotations=[0, 1.57],
            reshape_out=False),
        diff_rad_by_sin=True,
        dir_offset=0.7854,
        dir_limit_offset=0,
        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=7),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=0.5),
        loss_dir=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
    train_cfg=dict(
        assigner=[
            dict(
                type='MaxIoUAssigner',
                iou_calculator=dict(type='BboxOverlapsNearest3D'),
                pos_iou_thr=0.5,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                ignore_iof_thr=-1),
            dict(
                type='MaxIoUAssigner',
                iou_calculator=dict(type='BboxOverlapsNearest3D'),
                pos_iou_thr=0.5,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                ignore_iof_thr=-1),
            dict(
                type='MaxIoUAssigner',
                iou_calculator=dict(type='BboxOverlapsNearest3D'),
                pos_iou_thr=0.5,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                ignore_iof_thr=-1)
        ],
        allowed_border=0,
        code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        use_rotate_nms=True,
        nms_across_levels=False,
        nms_pre=4096,
        nms_thr=0.25,
        score_thr=0.1,
        min_bbox_size=0,
        max_num=500),
    voxel_layer=dict(
        voxel_size=(0.2, 0.2, 8),
        max_num_points=-1,
        point_cloud_range=[-50, 0, -5, 50, 50, 3],
        max_voxels=(-1, -1)),
    voxel_encoder=dict(
        type='DynamicVFE',
        in_channels=4,
        feat_channels=[64, 128],
        with_distance=False,
        voxel_size=(0.2, 0.2, 8),
        with_cluster_center=True,
        with_voxel_center=True,
        point_cloud_range=[-50, 0, -5, 50, 50, 3],
        norm_cfg=dict(type='naiveSyncBN1d', eps=0.001, momentum=0.01)),
    middle_encoder=dict(
        type='SSTInputLayer',
        window_shape=(10, 10),
        shifts_list=[(0, 0), (5, 5)],
        point_cloud_range=[-50, 0, -5, 50, 50, 3],
        voxel_size=(0.2, 0.2, 8),
        shuffle_voxels=True,
        debug=True,
        drop_info=({
            0: {
                'max_tokens': 30,
                'drop_range': (0, 30)
            },
            1: {
                'max_tokens': 60,
                'drop_range': (30, 60)
            },
            2: {
                'max_tokens': 100,
                'drop_range': (60, 100000)
            }
        }, {
            0: {
                'max_tokens': 30,
                'drop_range': (0, 30)
            },
            1: {
                'max_tokens': 60,
                'drop_range': (30, 60)
            },
            2: {
                'max_tokens': 100,
                'drop_range': (60, 100000)
            }
        })),
    backbone=dict(
        type='SSTv1',
        d_model=[128, 128, 128, 128, 128, 128],
        nhead=[8, 8, 8, 8, 8, 8],
        num_blocks=6,
        dim_feedforward=[256, 256, 256, 256, 256, 256],
        output_shape=[250, 500],
        num_attached_conv=3,
        conv_kwargs=[
            dict(kernel_size=3, dilation=1, padding=1, stride=1),
            dict(kernel_size=3, dilation=1, padding=1, stride=1),
            dict(kernel_size=3, dilation=1, padding=1, stride=1)
        ],
        conv_in_channel=128,
        conv_out_channel=128,
        debug=True,
        drop_info=({
            0: {
                'max_tokens': 30,
                'drop_range': (0, 30)
            },
            1: {
                'max_tokens': 60,
                'drop_range': (30, 60)
            },
            2: {
                'max_tokens': 100,
                'drop_range': (60, 100000)
            }
        }, {
            0: {
                'max_tokens': 30,
                'drop_range': (0, 30)
            },
            1: {
                'max_tokens': 60,
                'drop_range': (30, 60)
            },
            2: {
                'max_tokens': 100,
                'drop_range': (60, 100000)
            }
        }),
        pos_temperature=10000,
        normalize_pos=False,
        window_shape=(10, 10),
        checkpoint_blocks=[0, 1, 2]))
point_cloud_range = [-50, 0, -5, 50, 50, 3]
class_names = ['car', 'truck', 'trailer']
dataset_type = 'NuScenesDataset'
data_root = '/home/algo_zf/Data/nuscence/'
input_modality = dict(
    use_lidar=True,
    use_camera=False,
    use_radar=False,
    use_map=False,
    use_external=False)
file_client_args = dict(backend='disk')
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=dict(backend='disk')),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        load_dim=5,
        use_dim=[0, 1, 2, 3],
        file_client_args=dict(backend='disk')),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.3925, 0.3925],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0, 0, 0]),
    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=[-50, 0, -5, 50, 50, 3]),
    dict(type='ObjectRangeFilter', point_cloud_range=[-50, 0, -5, 50, 50, 3]),
    dict(type='ObjectNameFilter', classes=['car', 'truck', 'trailer']),
    dict(type='PointShuffle'),
    dict(
        type='DefaultFormatBundle3D', class_names=['car', 'truck', 'trailer']),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=dict(backend='disk')),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        load_dim=5,
        use_dim=[0, 1, 2, 3],
        file_client_args=dict(backend='disk')),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1.0, 1.0],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter',
                point_cloud_range=[-50, 0, -5, 50, 50, 3]),
            dict(
                type='DefaultFormatBundle3D',
                class_names=['car', 'truck', 'trailer'],
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]
eval_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=5,
        file_client_args=dict(backend='disk')),
    dict(
        type='LoadPointsFromMultiSweeps',
        load_dim=5,
        use_dim=[0, 1, 2, 3],
        sweeps_num=10,
        file_client_args=dict(backend='disk')),
    dict(
        type='DefaultFormatBundle3D',
        class_names=['car', 'truck', 'trailer'],
        with_label=False),
    dict(type='Collect3D', keys=['points'])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=dict(
        type='RepeatDataset',
        times=1,
        dataset=dict(
            type='NuScenesDataset',
            data_root='/home/algo_zf/Data/nuscence/',
            ann_file=
            '/home/algo_zf/Data/nuscence/nuscenes-lidar_infos_train.pkl',
            pipeline=[
                dict(
                    type='LoadPointsFromFile',
                    coord_type='LIDAR',
                    load_dim=5,
                    use_dim=5,
                    file_client_args=dict(backend='disk')),
                dict(
                    type='LoadPointsFromMultiSweeps',
                    sweeps_num=10,
                    load_dim=5,
                    use_dim=[0, 1, 2, 3],
                    file_client_args=dict(backend='disk')),
                dict(
                    type='LoadAnnotations3D',
                    with_bbox_3d=True,
                    with_label_3d=True),
                dict(
                    type='GlobalRotScaleTrans',
                    rot_range=[-0.3925, 0.3925],
                    scale_ratio_range=[0.95, 1.05],
                    translation_std=[0, 0, 0]),
                dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
                dict(
                    type='PointsRangeFilter',
                    point_cloud_range=[-50, 0, -5, 50, 50, 3]),
                dict(
                    type='ObjectRangeFilter',
                    point_cloud_range=[-50, 0, -5, 50, 50, 3]),
                dict(
                    type='ObjectNameFilter',
                    classes=['car', 'truck', 'trailer']),
                dict(type='PointShuffle'),
                dict(
                    type='DefaultFormatBundle3D',
                    class_names=['car', 'truck', 'trailer']),
                dict(
                    type='Collect3D',
                    keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
            ],
            with_velocity=False,
            modality=dict(
                use_lidar=True,
                use_camera=False,
                use_radar=False,
                use_map=False,
                use_external=False),
            classes=['car', 'truck', 'trailer'],
            test_mode=False,
            box_type_3d='LiDAR',
            load_interval=1)),
    val=dict(
        type='NuScenesDataset',
        data_root='/home/algo_zf/Data/nuscence/',
        ann_file='/home/algo_zf/Data/nuscence/nuscenes-lidar_infos_val.pkl',
        pipeline=[
            dict(
                type='LoadPointsFromFile',
                coord_type='LIDAR',
                load_dim=5,
                use_dim=5,
                file_client_args=dict(backend='disk')),
            dict(
                type='LoadPointsFromMultiSweeps',
                sweeps_num=10,
                load_dim=5,
                use_dim=[0, 1, 2, 3],
                file_client_args=dict(backend='disk')),
            dict(
                type='MultiScaleFlipAug3D',
                img_scale=(1333, 800),
                pts_scale_ratio=1,
                flip=False,
                transforms=[
                    dict(
                        type='GlobalRotScaleTrans',
                        rot_range=[0, 0],
                        scale_ratio_range=[1.0, 1.0],
                        translation_std=[0, 0, 0]),
                    dict(type='RandomFlip3D'),
                    dict(
                        type='PointsRangeFilter',
                        point_cloud_range=[-50, 0, -5, 50, 50, 3]),
                    dict(
                        type='DefaultFormatBundle3D',
                        class_names=['car', 'truck', 'trailer'],
                        with_label=False),
                    dict(type='Collect3D', keys=['points'])
                ])
        ],
        modality=dict(
            use_lidar=True,
            use_camera=False,
            use_radar=False,
            use_map=False,
            use_external=False),
        with_velocity=False,
        classes=['car', 'truck', 'trailer'],
        test_mode=True,
        box_type_3d='LiDAR'),
    test=dict(
        type='NuScenesDataset',
        data_root='/home/algo_zf/Data/nuscence/',
        ann_file='/home/algo_zf/Data/nuscence/nuscenes-lidar_infos_val.pkl',
        pipeline=[
            dict(
                type='LoadPointsFromFile',
                coord_type='LIDAR',
                load_dim=5,
                use_dim=5,
                file_client_args=dict(backend='disk')),
            dict(
                type='LoadPointsFromMultiSweeps',
                sweeps_num=10,
                load_dim=5,
                use_dim=[0, 1, 2, 3],
                file_client_args=dict(backend='disk')),
            dict(
                type='MultiScaleFlipAug3D',
                img_scale=(1333, 800),
                pts_scale_ratio=1,
                flip=False,
                transforms=[
                    dict(
                        type='GlobalRotScaleTrans',
                        rot_range=[0, 0],
                        scale_ratio_range=[1.0, 1.0],
                        translation_std=[0, 0, 0]),
                    dict(type='RandomFlip3D'),
                    dict(
                        type='PointsRangeFilter',
                        point_cloud_range=[-50, 0, -5, 50, 50, 3]),
                    dict(
                        type='DefaultFormatBundle3D',
                        class_names=['car', 'truck', 'trailer'],
                        with_label=False),
                    dict(type='Collect3D', keys=['points'])
                ])
        ],
        modality=dict(
            use_lidar=True,
            use_camera=False,
            use_radar=False,
            use_map=False,
            use_external=False),
        with_velocity=False,
        classes=['car', 'truck', 'trailer'],
        test_mode=True,
        box_type_3d='LiDAR'))
evaluation = dict(
    interval=12,
    pipeline=[
        dict(
            type='LoadPointsFromFile',
            coord_type='LIDAR',
            load_dim=5,
            use_dim=5,
            file_client_args=dict(backend='disk')),
        dict(
            type='LoadPointsFromMultiSweeps',
            load_dim=5,
            use_dim=[0, 1, 2, 3],
            sweeps_num=10,
            file_client_args=dict(backend='disk')),
        dict(
            type='DefaultFormatBundle3D',
            class_names=['car', 'truck', 'trailer'],
            with_label=False),
        dict(type='Collect3D', keys=['points'])
    ])
lr = 0.0001
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(custom_keys=dict(norm=dict(decay_mult=0.0))))
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 0.001),
    cyclic_times=1,
    step_ratio_up=0.4)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.8947368421052632, 1),
    cyclic_times=1,
    step_ratio_up=0.4)
runner = dict(type='EpochBasedRunner', max_epochs=48)
checkpoint_config = dict(interval=1)
log_config = dict(
    interval=1,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/sst_nuscenes_1x1_test_lidar_debug'
load_from = None
resume_from = None
workflow = [('train', 1)]
voxel_size = (0.2, 0.2, 8)
window_shape = (10, 10)
drop_info_training = dict({
    0: dict(max_tokens=30, drop_range=(0, 30)),
    1: dict(max_tokens=60, drop_range=(30, 60)),
    2: dict(max_tokens=100, drop_range=(60, 100000))
})
drop_info_test = dict({
    0: dict(max_tokens=30, drop_range=(0, 30)),
    1: dict(max_tokens=60, drop_range=(30, 60)),
    2: dict(max_tokens=100, drop_range=(60, 100000))
})
drop_info = ({
    0: {
        'max_tokens': 30,
        'drop_range': (0, 30)
    },
    1: {
        'max_tokens': 60,
        'drop_range': (30, 60)
    },
    2: {
        'max_tokens': 100,
        'drop_range': (60, 100000)
    }
}, {
    0: {
        'max_tokens': 30,
        'drop_range': (0, 30)
    },
    1: {
        'max_tokens': 60,
        'drop_range': (30, 60)
    },
    2: {
        'max_tokens': 100,
        'drop_range': (60, 100000)
    }
})
shifts_list = [(0, 0), (5, 5)]
gpu_ids = range(0, 1)

@chyohoo
Author

chyohoo commented Jan 27, 2022

I have a guess about why the inf occurs during training. I checked the log: when the inf happened, the LR was high (almost 1e-3). The large LR probably caused the weights to grow during backpropagation, and eventually the feature values overflowed in FP16.
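
The peak LR can be checked against the config posted above (`lr=0.0001`, `target_ratio=(10, 0.001)`, `step_ratio_up=0.4`, `max_epochs=48`). The following is a sketch of a one-cycle schedule with cosine annealing between the endpoints; the function names are my own and the cosine shape is an assumption about the `cyclic` policy, so treat the intermediate values as approximate.

```python
import math


def annealing_cos(start: float, end: float, factor: float) -> float:
    """Cosine interpolation from start (factor=0) to end (factor=1)."""
    return end + 0.5 * (start - end) * (math.cos(math.pi * factor) + 1)


def cyclic_lr(progress: float,
              base_lr: float = 1e-4,
              target_ratio=(10, 1e-3),
              step_ratio_up: float = 0.4) -> float:
    """LR at a given fraction of training (0.0 to 1.0), single cycle."""
    peak = base_lr * target_ratio[0]
    final = base_lr * target_ratio[1]
    if progress < step_ratio_up:
        # Warm-up phase: base LR rises to the peak.
        return annealing_cos(base_lr, peak, progress / step_ratio_up)
    # Decay phase: peak falls to the final LR.
    decay_progress = (progress - step_ratio_up) / (1 - step_ratio_up)
    return annealing_cos(peak, final, decay_progress)


max_epochs = 48
peak_epoch = 0.4 * max_epochs  # about 19.2

print(f"peak LR {cyclic_lr(0.4):.1e} around epoch {peak_epoch:.1f}")
for epoch in (0, 12, 19, 30, 48):
    print(epoch, f"{cyclic_lr(epoch / max_epochs):.2e}")
```

Under these assumptions the LR reaches 1e-3 about 40% of the way through training, which is consistent with seeing the inf after several epochs rather than at the start.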

@Abyssaledge
Collaborator

Abyssaledge commented Jan 27, 2022

There is a small difference between:
https://github.com/TuSimple/SST/blob/main/mmdet3d/models/middle_encoders/sst_input_layer.py#L52
and
https://github.com/TuSimple/SST/blob/main/mmdet3d/models/middle_encoders/sst_input_layer_v2.py#L55
Your error should be caused by this difference.
So I suggest you use the config in sst_refactor, which is more convenient for you to modify. In sst_refactor, we replace SSTInputLayer and SST with SSTInputLayerV2 and SSTv2.
You can also use 3D windows by modifying voxel_size, sparse_shape and window_shape in the refactored config, but remember to write your own recover_bev function when using 3D windows.
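
When changing voxel_size, the grid shape has to stay consistent with point_cloud_range. Here is a small sketch of that arithmetic using the range from the config posted above. The helper name `grid_shape`, the (x, y, z) ordering of the result, and the 0.5 m z voxel size are assumptions for illustration; check the refactored config for the ordering that sparse_shape actually expects.

```python
def grid_shape(point_cloud_range, voxel_size):
    """Number of voxels along x, y, z for a given range and voxel size."""
    x_min, y_min, z_min, x_max, y_max, z_max = point_cloud_range
    extents = (x_max - x_min, y_max - y_min, z_max - z_min)
    # round() guards against floating-point error in e.g. 100 / 0.2.
    return tuple(round(extent / size) for extent, size in zip(extents, voxel_size))


pc_range = [-50, 0, -5, 50, 50, 3]

# Pillar-style voxels as in the posted config: one voxel along z.
shape_2d = grid_shape(pc_range, (0.2, 0.2, 8))
print(shape_2d)  # (500, 250, 1)

# Hypothetical 3D setting: split the 8 m height into 0.5 m voxels.
shape_3d = grid_shape(pc_range, (0.2, 0.2, 0.5))
print(shape_3d)  # (500, 250, 16)
```

The 2D result agrees with `output_shape=[250, 500]` in the posted config, read as (y, x). With more than one voxel along z, the BEV recovery step has to decide how to collapse or stack the z dimension, which is why a custom recover_bev is needed.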

@chyohoo
Author

chyohoo commented Jan 27, 2022

Thanks, I will try v2. BTW, have you tried training without fp16?

@Abyssaledge
Collaborator

I have not tried fp32 with these configs, but I have tried fp32 with our customized model and did not see any problems.

@chyohoo closed this as completed Feb 7, 2022

@Zoeeeing

@chyohoo Hi, I am also trying to use SST to train on the nuScenes dataset. Could you please share your results on nuScenes? Thanks!
