-
Notifications
You must be signed in to change notification settings - Fork 266
/
view_transformer.py
271 lines (229 loc) · 10.8 KB
/
view_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from ..builder import NECKS
from mmdet3d.ops import bev_pool
def gen_dx_bx(xbound, ybound, zbound):
dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
bx = torch.Tensor([row[0] + row[2]/2.0 for row in [xbound, ybound, zbound]])
nx = torch.Tensor([(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]])
return dx, bx, nx
def cumsum_trick(x, geom_feats, ranks):
x = x.cumsum(0)
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
kept[:-1] = (ranks[1:] != ranks[:-1])
x, geom_feats = x[kept], geom_feats[kept]
x = torch.cat((x[:1], x[1:] - x[:-1]))
return x, geom_feats
class QuickCumsum(torch.autograd.Function):
@staticmethod
def forward(ctx, x, geom_feats, ranks):
x = x.cumsum(0)
kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
kept[:-1] = (ranks[1:] != ranks[:-1])
x, geom_feats = x[kept], geom_feats[kept]
x = torch.cat((x[:1], x[1:] - x[:-1]))
# save kept for backward
ctx.save_for_backward(kept)
# no gradient for geom_feats
ctx.mark_non_differentiable(geom_feats)
return x, geom_feats
@staticmethod
def backward(ctx, gradx, gradgeom):
kept, = ctx.saved_tensors
back = torch.cumsum(kept, 0)
back[kept] -= 1
val = gradx[back]
return val, None, None
@NECKS.register_module()
class ViewTransformerLiftSplatShoot(BaseModule):
def __init__(self, grid_config=None, data_config=None,
numC_input=512, numC_Trans=64, downsample=16,
accelerate=False, max_drop_point_rate=0.0, use_bev_pool=True,
**kwargs):
super(ViewTransformerLiftSplatShoot, self).__init__()
if grid_config is None:
grid_config = {
'xbound': [-51.2, 51.2, 0.8],
'ybound': [-51.2, 51.2, 0.8],
'zbound': [-10.0, 10.0, 20.0],
'dbound': [1.0, 60.0, 1.0],}
self.grid_config = grid_config
dx, bx, nx = gen_dx_bx(self.grid_config['xbound'],
self.grid_config['ybound'],
self.grid_config['zbound'],
)
self.dx = nn.Parameter(dx, requires_grad=False)
self.bx = nn.Parameter(bx, requires_grad=False)
self.nx = nn.Parameter(nx, requires_grad=False)
if data_config is None:
data_config = {'input_size': (256, 704)}
self.data_config = data_config
self.downsample = downsample
self.frustum = self.create_frustum()
self.D, _, _, _ = self.frustum.shape
self.numC_input = numC_input
self.numC_Trans = numC_Trans
self.depthnet = nn.Conv2d(self.numC_input, self.D + self.numC_Trans, kernel_size=1, padding=0)
self.geom_feats = None
self.accelerate = accelerate
self.max_drop_point_rate = max_drop_point_rate
self.use_bev_pool = use_bev_pool
def get_depth_dist(self, x):
return x.softmax(dim=1)
def create_frustum(self):
# make grid in image plane
ogfH, ogfW = self.data_config['input_size']
fH, fW = ogfH // self.downsample, ogfW // self.downsample
ds = torch.arange(*self.grid_config['dbound'], dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW)
D, _, _ = ds.shape
xs = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW)
ys = torch.linspace(0, ogfH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW)
# D x H x W x 3
frustum = torch.stack((xs, ys, ds), -1)
return nn.Parameter(frustum, requires_grad=False)
def get_geometry(self, rots, trans, intrins, post_rots, post_trans, offset=None):
"""Determine the (x,y,z) locations (in the ego frame)
of the points in the point cloud.
Returns B x N x D x H/downsample x W/downsample x 3
"""
B, N, _ = trans.shape
# undo post-transformation
# B x N x D x H x W x 3
points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
if offset is not None:
_,D,H,W = offset.shape
points[:,:,:,:,:,2] = points[:,:,:,:,:,2]+offset.view(B,N,D,H,W)
points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
# cam_to_ego
points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
points[:, :, :, :, :, 2:3]
), 5)
if intrins.shape[3]==4: # for KITTI
shift = intrins[:,:,:3,3]
points = points - shift.view(B,N,1,1,1,3,1)
intrins = intrins[:,:,:3,:3]
combine = rots.matmul(torch.inverse(intrins))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
points += trans.view(B, N, 1, 1, 1, 3)
# points_numpy = points.detach().cpu().numpy()
return points
def voxel_pooling(self, geom_feats, x):
B, N, D, H, W, C = x.shape
Nprime = B * N * D * H * W
nx = self.nx.to(torch.long)
# flatten x
x = x.reshape(Nprime, C)
# flatten indices
geom_feats = ((geom_feats - (self.bx - self.dx / 2.)) / self.dx).long()
geom_feats = geom_feats.view(Nprime, 3)
batch_ix = torch.cat([torch.full([Nprime // B, 1], ix,
device=x.device, dtype=torch.long) for ix in range(B)])
geom_feats = torch.cat((geom_feats, batch_ix), 1)
# filter out points that are outside box
kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0]) \
& (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1]) \
& (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
x = x[kept]
geom_feats = geom_feats[kept]
if self.max_drop_point_rate > 0.0 and self.training:
drop_point_rate = torch.rand(1)*self.max_drop_point_rate
kept = torch.rand(x.shape[0])>drop_point_rate
x, geom_feats = x[kept], geom_feats[kept]
if self.use_bev_pool:
final = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0],
self.nx[1])
final = final.transpose(dim0=-2, dim1=-1)
else:
# get tensors from the same voxel next to each other
ranks = geom_feats[:, 0] * (self.nx[1] * self.nx[2] * B) \
+ geom_feats[:, 1] * (self.nx[2] * B) \
+ geom_feats[:, 2] * B \
+ geom_feats[:, 3]
sorts = ranks.argsort()
x, geom_feats, ranks = x[sorts], geom_feats[sorts], ranks[sorts]
# cumsum trick
x, geom_feats = QuickCumsum.apply(x, geom_feats, ranks)
# griddify (B x C x Z x X x Y)
final = torch.zeros((B, C, nx[2], nx[1], nx[0]), device=x.device)
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 1], geom_feats[:, 0]] = x
# collapse Z
final = torch.cat(final.unbind(dim=2), 1)
return final
def voxel_pooling_accelerated(self, rots, trans, intrins, post_rots, post_trans, x):
B, N, D, H, W, C = x.shape
Nprime = B * N * D * H * W
nx = self.nx.to(torch.long)
# flatten x
x = x.reshape(Nprime, C)
max = 300
# flatten indices
if self.geom_feats is None:
geom_feats = self.get_geometry(rots, trans, intrins, post_rots, post_trans)
geom_feats = ((geom_feats - (self.bx - self.dx / 2.)) / self.dx).long()
geom_feats = geom_feats.view(Nprime, 3)
batch_ix = torch.cat([torch.full([Nprime // B, 1], ix,
device=x.device, dtype=torch.long) for ix in range(B)])
geom_feats = torch.cat((geom_feats, batch_ix), 1)
# filter out points that are outside box
kept1 = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0]) \
& (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1]) \
& (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
idx = torch.range(0, x.shape[0] - 1, dtype=torch.long)
x = x[kept1]
idx = idx[kept1]
geom_feats = geom_feats[kept1]
# get tensors from the same voxel next to each other
ranks = geom_feats[:, 0] * (self.nx[1] * self.nx[2] * B) \
+ geom_feats[:, 1] * (self.nx[2] * B) \
+ geom_feats[:, 2] * B \
+ geom_feats[:, 3]
sorts = ranks.argsort()
x, geom_feats, ranks, idx = x[sorts], geom_feats[sorts], ranks[sorts], idx[sorts]
repeat_id = torch.ones(geom_feats.shape[0], device=geom_feats.device, dtype=geom_feats.dtype)
curr = 0
repeat_id[0] = 0
curr_rank = ranks[0]
for i in range(1, ranks.shape[0]):
if curr_rank == ranks[i]:
curr += 1
repeat_id[i] = curr
else:
curr_rank = ranks[i]
curr = 0
repeat_id[i] = curr
kept2 = repeat_id < max
repeat_id, geom_feats, x, idx = repeat_id[kept2], geom_feats[kept2], x[kept2], idx[kept2]
geom_feats = torch.cat([geom_feats, repeat_id.unsqueeze(-1)], dim=-1)
self.geom_feats = geom_feats
self.idx = idx
else:
geom_feats = self.geom_feats
idx = self.idx
x = x[idx]
# griddify (B x C x Z x X x Y)
final = torch.zeros((B, C, nx[2], nx[1], nx[0], max), device=x.device)
final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 1], geom_feats[:, 0], geom_feats[:, 4]] = x
final = final.sum(-1)
# collapse Z
final = torch.cat(final.unbind(dim=2), 1)
return final
def forward(self, input):
x, rots, trans, intrins, post_rots, post_trans = input
B, N, C, H, W = x.shape
x = x.view(B * N, C, H, W)
x = self.depthnet(x)
depth = self.get_depth_dist(x[:, :self.D])
img_feat = x[:, self.D:(self.D + self.numC_Trans)]
# Lift
volume = depth.unsqueeze(1) * img_feat.unsqueeze(2)
volume = volume.view(B, N, self.numC_Trans, self.D, H, W)
volume = volume.permute(0, 1, 3, 4, 5, 2)
# Splat
if self.accelerate:
bev_feat = self.voxel_pooling_accelerated(rots, trans, intrins, post_rots, post_trans, volume)
else:
geom = self.get_geometry(rots, trans, intrins, post_rots, post_trans)
bev_feat = self.voxel_pooling(geom, volume)
return bev_feat