In [1]:
from pathlib import Path
import os
from functools import partial
import numpy as np
import torch
import time
import spconv.pytorch as spconv
import torch.nn as nn

class input_backbone(nn.Module):
    """Input stem of a sparse 3D network: submanifold conv -> BatchNorm -> ReLU.

    The stem maps ``input_channels`` voxel features to 16 channels with a
    single 3x3x3 ``SubMConv3d`` (padding 1, no bias, ``indice_key="subm0"``).
    """

    def __init__(self, input_channels, grid_size, device=0):
        """Build the stem and record the sparse spatial shape.

        Args:
            input_channels: number of feature channels per voxel fed to the
                convolution.
            grid_size: the three grid dimensions (ndarray, list or tuple).
                They are reversed and ``[1, 0, 0]`` is added element-wise to
                form ``self.sparse_shape``.
                NOTE(review): presumably given as (x, y, z) so the result is
                (z + 1, y, x) -- confirm against the voxelizer.
            device: device the conv/BatchNorm layers are moved to. Defaults
                to ``0`` (the previously hard-coded value); pass ``None`` to
                leave the layers where PyTorch created them.
        """
        super().__init__()

        # Coerce to an ndarray first: with a plain list, ``+ [1, 0, 0]`` would
        # concatenate into a 6-element shape instead of adding element-wise.
        self.sparse_shape = np.asarray(grid_size)[::-1] + [1, 0, 0]

        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        self.conv_input = spconv.SparseSequential(
            spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key="subm0"),
            norm_fn(16),
            nn.ReLU()
        )
        # Move the whole stem at once; ReLU has no parameters, so this is
        # equivalent to moving the conv and norm layers individually.
        if device is not None:
            self.conv_input.to(device)

    def forward(self, voxel_features, voxel_coords, batch_size=1):
        """Wrap the voxels in a ``SparseConvTensor`` and run the stem.

        Args:
            voxel_features: per-voxel feature tensor passed as ``features``.
                NOTE(review): presumably shaped (num_voxels, input_channels)
                -- confirm.
            voxel_coords: per-voxel index tensor passed as ``indices``.
                NOTE(review): presumably (num_voxels, 4) integer indices with
                the batch index first -- confirm.
            batch_size: number of samples contained in ``voxel_coords``.
                Defaults to ``1`` (the previously hard-coded value).

        Returns:
            The output of ``self.conv_input`` applied to the sparse tensor.
        """
        sp_tensor = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords,
            spatial_shape=self.sparse_shape,
            batch_size=batch_size
        )
        return self.conv_input(sp_tensor)


In [2]:
# --- Dataset / voxelization settings ---------------------------------------
# Class name -> integer id.
label_dict = dict(Car=0, Pedestrian=1, Rider=2, Truck=3, Van=4)

# NOTE(review): presumably the voxel edge lengths along x, y, z -- confirm units.
vsize_xyz = np.array([0.5, 0.5, 1.0])
# First three entries are subtracted from the last three below,
# i.e. laid out as (x_min, y_min, z_min, x_max, y_max, z_max).
coors_range_xyz = np.array([-70.4, -40.0, -3.0, 70.4, 40.0, 1.0])

num_point_features = 4
max_num_points_per_voxel = 5
max_num_voxels = 16000
input_channels = 4

# Number of voxels per axis: range extent divided by voxel size,
# rounded to the nearest integer.
grid_size = np.round(
    (coors_range_xyz[3:] - coors_range_xyz[:3]) / vsize_xyz
).astype(np.int64)

# --- Pretrained checkpoint --------------------------------------------------
ckpt_path = './../../output/kitti_models/voxel_rcnn_car/A_S22/ckpt/checkpoint_epoch_160.pth'
# torch.load unpickles the file; only point this at checkpoints you trust.
trained_dict = torch.load(ckpt_path)

In [3]:
# Instantiate the input stem with the channel count and voxel grid size
# defined in the configuration cell.
model = input_backbone(input_channels, grid_size)

In [4]:
# Copy the pretrained weights for this sub-network out of the checkpoint.
#
# Comparing the checkpoint's top-level keys directly with the model's
# parameter names only works for a flat state dict with identical names.
# If the weights are wrapped (e.g. under "model_state") or carry a
# parent-module prefix, nothing matches and load_state_dict() would just
# reload the model's own random initial weights while still reporting
# "All keys matched successfully".
model_dict = model.state_dict()

# NOTE(review): assumes an OpenPCDet-style checkpoint that stores the weights
# under "model_state"; a flat state dict is used as-is -- confirm the layout.
trained_state = trained_dict.get('model_state', trained_dict)

state_dict = {}
for param_name in model_dict:
    # Prefer an exact name match.
    if param_name in trained_state:
        state_dict[param_name] = trained_state[param_name]
        continue
    # Otherwise accept the same name nested under a parent-module prefix
    # (e.g. "<parent>.conv_input.0.weight").
    nested_keys = [
        k for k in trained_state
        if isinstance(k, str) and k.endswith('.' + param_name)
    ]
    if len(nested_keys) == 1:
        state_dict[param_name] = trained_state[nested_keys[0]]
    elif nested_keys:
        raise KeyError(
            f"Ambiguous checkpoint entries for '{param_name}': {nested_keys}"
        )

# Fail loudly instead of silently keeping the random initial weights.
if not state_dict:
    raise KeyError(
        "No checkpoint entry matches any model parameter; "
        f"model expects {list(model_dict)}, checkpoint has {list(trained_state)[:10]} ..."
    )

# Parameters without a checkpoint entry keep their current values.
model_dict.update(state_dict)
model.load_state_dict(model_dict)

<All keys matched successfully>