/
soft_rasterize_cuda_kernel.cu
781 lines (679 loc) · 28.2 KB
/
soft_rasterize_cuda_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
// for the older gpus atomicAdd with double arguments does not exist
#if __CUDA_ARCH__ < 600 and defined(__CUDA_ARCH__)
static __inline__ __device__ double atomicAdd(double* address, double val) {
unsigned long long int* address_as_ull = (unsigned long long int*)address;
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) } while (assumed != old);
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
namespace{
template <typename scalar_t>
__device__ __forceinline__ void barycentric_coordinate(scalar_t *w, const scalar_t x, const scalar_t y, const scalar_t *face_info) {
w[0] = face_info[3 * 0 + 0] * x + face_info[3 * 0 + 1] * y + face_info[3 * 0 + 2];
w[1] = face_info[3 * 1 + 0] * x + face_info[3 * 1 + 1] * y + face_info[3 * 1 + 2];
w[2] = face_info[3 * 2 + 0] * x + face_info[3 * 2 + 1] * y + face_info[3 * 2 + 2];
}
template <typename scalar_t>
__device__ __forceinline__ bool check_border(const scalar_t x, const scalar_t y, const scalar_t *face, const scalar_t threshold) {
return (x > max(max(face[0], face[3]), face[6]) + threshold ||
x < min(min(face[0], face[3]), face[6]) - threshold ||
y > max(max(face[1], face[4]), face[7]) + threshold ||
y < min(min(face[1], face[4]), face[7]) - threshold);
}
template <typename scalar_t>
__device__ __forceinline__ bool check_face_frontside(const scalar_t *face) {
return (face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]);
}
template <typename scalar_t>
__device__ __forceinline__ bool check_pixel_inside(const scalar_t *w) {
return w[0] <= 1 && w[0] >= 0 && w[1] <= 1 && w[1] >= 0 && w[2] <= 1 && w[2] >= 0;
}
template <typename scalar_t>
__device__ __forceinline__ void barycentric_clip(scalar_t *w) {
for (int k = 0; k < 3; k++) w[k] = max(min(w[k], 1.), 0.);
const scalar_t w_sum = max(w[0] + w[1] + w[2], 1e-5);
for (int k = 0; k < 3; k++) w[k] /= w_sum;
}
template <typename scalar_t>
__device__ __forceinline__ void euclidean_p2f_distance(scalar_t &sign, scalar_t &dis_x, scalar_t &dis_y,
scalar_t *w, scalar_t *t,
const scalar_t* face, const scalar_t *face_info,
const scalar_t xp, const scalar_t yp) {
const scalar_t *face_sym = face_info + 9;
const scalar_t *face_obt = face_info + 18;
if (w[0] > 0 && w[1] > 0 && w[2] > 0 &&
w[0] < 1 && w[1] < 1 && w[2] < 1) {
// inside the triangle, w[0] + w[1] + w[2] = 0
scalar_t dis_min = 100000000;
scalar_t dis_x_min = 0;
scalar_t dis_y_min = 0;
scalar_t a0[3];
scalar_t t0[3];
for (int k = 0; k < 3; k++) {
int v0 = k;
int v1 = (k + 1) % 3;
int v2 = (k + 2) % 3;
a0[0] = face_sym[3 * v0 + 0] - face_sym[3 * v1 + 0];
a0[1] = face_sym[3 * v0 + 1] - face_sym[3 * v1 + 1];
a0[2] = face_sym[3 * v0 + 2] - face_sym[3 * v1 + 2];
t0[v0] = (w[0] * a0[0] + w[1] * a0[1] + w[2] * a0[2] - a0[v1]) / (a0[v0] - a0[v1]);
t0[v1] = 1 - t0[v0];
t0[v2] = 0;
t0[0] -= w[0];
t0[1] -= w[1];
t0[2] -= w[2];
// calculate distance
dis_x = t0[0] * face[0] + t0[1] * face[3] + t0[2] * face[6];
dis_y = t0[0] * face[1] + t0[1] * face[4] + t0[2] * face[7];
scalar_t dis = dis_x * dis_x + dis_y * dis_y;
if (dis < dis_min) {
dis_min = dis;
dis_x_min = dis_x;
dis_y_min = dis_y;
t[0] = t0[0];
t[1] = t0[1];
t[2] = t0[2];
}
}
dis_x = dis_x_min;
dis_y = dis_y_min;
sign = 1;
} else {
int v0 = -1;
if (w[1] <= 0 && w[2] <= 0) {
v0 = 0;
if (face_obt[0] == 1 && (xp - face[0]) * (face[6] - face[0]) + (yp - face[1]) * (face[7] - face[1]) > 0) v0 = 2;
} else if (w[2] <= 0 && w[0] <= 0) {
v0 = 1;
if (face_obt[1] == 1 && (xp - face[3]) * (face[0] - face[3]) + (yp - face[4]) * (face[1] - face[4]) > 0) v0 = 0;
} else if (w[0] <= 0 && w[1] <= 0) {
v0 = 2;
if (face_obt[2] == 1 && (xp - face[6]) * (face[3] - face[6]) + (yp - face[7]) * (face[4] - face[7]) > 0) v0 = 1;
} else
if (w[0] <= 0) v0 = 1;
else if (w[1] <= 0) v0 = 2;
else if (w[2] <= 0) v0 = 0;
const int v1 = (v0 + 1) % 3;
const int v2 = (v0 + 2) % 3;
scalar_t a0[3];
a0[0] = face_sym[3 * v0 + 0] - face_sym[3 * v1 + 0];
a0[1] = face_sym[3 * v0 + 1] - face_sym[3 * v1 + 1];
a0[2] = face_sym[3 * v0 + 2] - face_sym[3 * v1 + 2];
t[v0] = (w[0] * a0[0] + w[1] * a0[1] + w[2] * a0[2] - a0[v1]) / (a0[v0] - a0[v1]);
t[v1] = 1 - t[v0];
t[v2] = 0;
// clamp to [0, 1]
for (int k = 0; k < 3; k++) {
t[k] = min(max(t[k], 0.), 1.);
t[k] -= w[k];
}
// calculate distance
dis_x = t[0] * face[0] + t[1] * face[3] + t[2] * face[6];
dis_y = t[0] * face[1] + t[1] * face[4] + t[2] * face[7];
sign = -1;
}
}
template <typename scalar_t>
__device__ __forceinline__ void forward_barycentric_p2f_distance(scalar_t &dis, const scalar_t *w) {
dis = w[0] > w[1] ? (w[1] > w[2] ? w[2] : w[1]) : (w[0] > w[2] ? w[2] : w[0]);
dis = dis > 0 ? pow(dis, 2) : -pow(dis, 2);
}
template <typename scalar_t>
__device__ __forceinline__ void backward_barycentric_p2f_distance(scalar_t grad_v[3][3], const scalar_t *w, const scalar_t *face_info, const scalar_t xp, const scalar_t yp, const scalar_t dis, const scalar_t C) {
const int p = w[0] > w[1] ? (w[1] > w[2] ? 2 : 1) : (w[0] > w[2] ? 2 : 0);
const scalar_t *face_inv = face_info;
for (int l = 0; l < 2; l++) {
for (int k = 0; k < 3; k++) {
scalar_t grad_kl = 0;
for (int q = 0; q < 3; q++) {
grad_kl += -face_inv[3*p+l] * face_inv[3*k+q] * (q == 0 ? xp : (q == 1 ? yp : 1));
}
grad_v[k][l] = grad_kl * C;
grad_v[k][l] *= dis > 0 ? (2. * sqrt(dis)) : (2. * sqrt(-dis));
}
}
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t forward_sample_texture(const scalar_t *texture, const scalar_t *w, const int R, const int k, const int texture_sample_type) {
scalar_t texture_k = 0;
if (texture_sample_type == 0) { // sample surface color with resolution as R
const int w_x = w[0] * R;
const int w_y = w[1] * R;
if ((w[0] + w[1]) * R - w_x - w_y <= 1) {
texture_k = texture[(w_y * R + w_x) * 3 + k];
} else {
texture_k = texture[((R - 1 - w_y) * R + (R - 1 - w_x)) * 3 + k];
}
} else
if (texture_sample_type == 1) { // sample vertex color
texture_k = w[0] * texture[k] + w[1] * texture[3+k] + w[2] * texture[6+k];
}
return texture_k;
}
template <typename scalar_t>
__device__ __forceinline__ scalar_t backward_sample_texture(const scalar_t grad_color, const scalar_t *w, const int R, const int k, const int texture_sample_type) {
scalar_t grad_texture_k = 0;
if (texture_sample_type == 0) { // sample surface color with resolution as R
const int w_x = w[0] * R;
const int w_y = w[1] * R;
if ((w[0] + w[1]) * R - w_x - w_y <= 1) {
if (k == w_y * R + w_x) {
grad_texture_k = grad_color;
}
} else {
if (k == (R - 1 - w_y) * R + (R - 1 - w_x)) {
grad_texture_k = grad_color;
}
}
} else
if (texture_sample_type == 1) {
grad_texture_k = w[k] * grad_color;
}
return grad_texture_k;
}
// triangle preprocessing
template <typename scalar_t>
__global__ void forward_soft_rasterize_inv_cuda_kernel(
const scalar_t* __restrict__ faces,
scalar_t* faces_info,
int batch_size,
int num_faces,
int image_size) {
/* batch number, face, number, image size, face[v012][RGB] */
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * num_faces) {
return;
}
// const int is = image_size;
const scalar_t* face = &faces[i * 9];
scalar_t* face_inv = &faces_info[i * 27];
scalar_t* face_sym = &faces_info[i * 27+9];
scalar_t* face_obt = &faces_info[i * 27+18];
/* return if backside */
// if ((face[7] - face[1]) * (face[3] - face[0]) < (face[4] - face[1]) * (face[6] - face[0]))
// return;
/* p[num][xy]: x, y is (-1, 1). */
scalar_t p[3][2];
for (int num = 0; num < 3; num++) {
for (int dim = 0; dim < 2; dim++) {
p[num][dim] = face[3 * num + dim]; // no normalize
}
}
/* compute face_inv */
scalar_t face_inv_star[9] = {
p[1][1] - p[2][1], p[2][0] - p[1][0], p[1][0] * p[2][1] - p[2][0] * p[1][1],
p[2][1] - p[0][1], p[0][0] - p[2][0], p[2][0] * p[0][1] - p[0][0] * p[2][1],
p[0][1] - p[1][1], p[1][0] - p[0][0], p[0][0] * p[1][1] - p[1][0] * p[0][1]};
scalar_t face_inv_determinant = (
p[2][0] * (p[0][1] - p[1][1]) +
p[0][0] * (p[1][1] - p[2][1]) +
p[1][0] * (p[2][1] - p[0][1]));
face_inv_determinant = face_inv_determinant > 0 ? max(face_inv_determinant, 1e-10) : min(face_inv_determinant, -1e-10);
/* set to global memory */
for (int k = 0; k < 9; k++) {
face_inv[k] = face_inv_star[k] / face_inv_determinant;
}
/* F * F.T */
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
face_sym[j * 3 + k] = face[j * 3 + 0] * face[k * 3 + 0] +
face[j * 3 + 1] * face[k * 3 + 1] +
1;
}
}
/* check if one arc is obt arc */
for (int k = 0; k < 3; k++) {
const int k0 = k;
const int k1 = (k + 1) % 3;
const int k2 = (k + 2) % 3;
if ((p[k1][0] - p[k0][0]) * (p[k2][0] - p[k0][0]) + (p[k1][1] - p[k0][1]) * (p[k2][1] - p[k0][1]) < 0) {
face_obt[k0] = 1;
break;
}
}
}
template <typename scalar_t>
__global__ void forward_soft_rasterize_cuda_kernel(
const scalar_t* __restrict__ faces,
const scalar_t* __restrict__ textures,
const scalar_t* __restrict__ faces_info,
scalar_t* aggrs_info,
scalar_t* soft_colors,
int batch_size,
int num_faces,
int image_size,
int texture_size,
int texture_res,
float near,
float far,
float eps,
float sigma_val,
int func_id_dist,
float dist_eps,
float gamma_val,
int func_id_rgb,
int func_id_alpha,
int texture_sample_type,
bool double_side) {
////////////////////////
////////////////////////
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * image_size * image_size) {
return;
}
const int is = image_size;
const int nf = num_faces;
const int bn = i / (is * is);
const int pn = i % (is * is);
const int yi = is - 1 - (pn / is);
const int xi = pn % is;
const scalar_t yp = (2. * yi + 1. - is) / is;
const scalar_t xp = (2. * xi + 1. - is) / is;
const scalar_t *face = &faces[bn * nf * 9] - 9;
const scalar_t *texture = &textures[bn * nf * texture_size * 3] - texture_size * 3;
const scalar_t *face_info = &faces_info[bn * nf * 27] - 27;
const scalar_t threshold = dist_eps * sigma_val;
// Initialize pixel color
scalar_t soft_color[4] = {1., 1., 1., 0.};
if (func_id_alpha == 2) soft_color[3] = 1.;
scalar_t softmax_sum = exp(eps / gamma_val);
scalar_t softmax_max = eps;
for (int k = 0; k < 3; k++) {
if (func_id_rgb == 0) { // hard assign, set to background
soft_color[k] = soft_colors[(bn * 4 + k) * (is * is) + pn];
} else
if (func_id_rgb == 1) {
soft_color[k] = soft_colors[(bn * 4 + k) * (is * is) + pn] * softmax_sum; // initialize background color
}
}
scalar_t depth_min = 10000000;
int face_index_min = -1;
for (int fn = 0; fn < nf; fn++) {
face += 9;
texture += texture_size * 3;
face_info += 27;
if (check_border(xp, yp, face, sqrt(threshold))) continue; // triangle too far away from pixel
scalar_t dis;
scalar_t dis_x;
scalar_t dis_y;
scalar_t t[3];
scalar_t w[3];
scalar_t w_clip[3];
scalar_t sign;
scalar_t soft_fragment;
// compute barycentric coordinate w
barycentric_coordinate(w, xp, yp, face_info);
// compute probability map based on distance functions
if (func_id_dist == 0) { // hard assign
soft_fragment = check_pixel_inside(w) ? 1. : 0.;
if (soft_fragment == 0.) continue; // ignore triangle outside of the pixel
} else
if (func_id_dist == 1) { // barycentric distance
forward_barycentric_p2f_distance(dis, w);
if (-dis >= threshold) continue; // ignore triangle far away from the pixel
soft_fragment = 1. / (1. + exp(-dis / sigma_val));
} else
if (func_id_dist == 2) { // euclidean distance
euclidean_p2f_distance(sign, dis_x, dis_y, w, t, face, face_info, xp, yp);
dis = dis_x * dis_x + dis_y * dis_y;
if (sign < 0 && dis >= threshold) continue; // ignore triangle far away from the pixel
soft_fragment = 1. / (1. + exp(-sign * dis / sigma_val));
}
/////////////////////////////////////////////////////
// aggragate for alpha channel
if (func_id_alpha == 0) { // hard assign
if (soft_fragment > 0.5) soft_color[3] = 1.;
} else
if (func_id_alpha == 1) { // Sum
soft_color[3] += soft_fragment;
} else
if (func_id_alpha == 2) { // Logical-Or
soft_color[3] *= 1. - soft_fragment;
}
/////////////////////////////////////////////////////
for (int k = 0; k < 3; k++) w_clip[k] = w[k];
barycentric_clip(w_clip);
const scalar_t zp = 1. / (w_clip[0] / face[2] + w_clip[1] / face[5] + w_clip[2] / face[8]);
if (zp < near || zp > far) continue; // triangle out of screen, pass
/////////////////////////////////////////////////////
// aggregate for rgb channels
if (func_id_rgb == 0) { // Hard assign
if (zp < depth_min && check_pixel_inside(w) && (double_side || check_face_frontside(face))) {
depth_min = zp;
face_index_min = fn;
for (int k = 0; k < 3; k++) {
soft_color[k] = forward_sample_texture(texture, w_clip, texture_res, k, texture_sample_type);
}
}
} else
if (func_id_rgb == 1) { // D * Softmax (Z)
if (check_face_frontside(face) || double_side) {
const scalar_t zp_norm = (far - zp) / (far - near);
scalar_t exp_delta_zp = 1.;
if (zp_norm > softmax_max) {
exp_delta_zp = exp((softmax_max - zp_norm) / gamma_val);
softmax_max = zp_norm;
}
const scalar_t exp_z = exp((zp_norm - softmax_max) / gamma_val);
softmax_sum = exp_delta_zp * softmax_sum + exp_z * soft_fragment;
for (int k = 0; k < 3; k++) {
const scalar_t color_k = forward_sample_texture(texture, w_clip, texture_res, k, texture_sample_type);
soft_color[k] = exp_delta_zp * soft_color[k] + exp_z * soft_fragment * color_k;// * soft_fragment;
}
}
}
}
//////////////////////////////////////////////
// finalize aggregation
if (func_id_alpha == 0) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = soft_color[3];
} else
if (func_id_alpha == 1) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = soft_color[3] / nf;
} else
if (func_id_alpha == 2) {
soft_colors[(bn * 4 + 3) * (is * is) + pn] = 1. - soft_color[3];
}
if (func_id_rgb == 0) {
if (face_index_min != -1)
for (int k = 0; k < 3; k++) {
soft_colors[(bn * 4 + k) * (is * is) + pn] = soft_color[k];
}
aggrs_info[(bn * 2 + 0) * (is * is) + pn] = depth_min;
aggrs_info[(bn * 2 + 1) * (is * is) + pn] = face_index_min;
} else
if (func_id_rgb == 1) {
for (int k = 0; k < 3; k++) {
soft_colors[(bn * 4 + k) * (is * is) + pn] = soft_color[k] / softmax_sum;
}
aggrs_info[(bn * 2 + 0) * (is * is) + pn] = softmax_sum;
aggrs_info[(bn * 2 + 1) * (is * is) + pn] = softmax_max;
}
}
template <typename scalar_t>
__global__ void backward_soft_rasterize_cuda_kernel(
const scalar_t* __restrict__ faces,
const scalar_t* __restrict__ textures,
const scalar_t* __restrict__ soft_colors,
const scalar_t* __restrict__ faces_info,
const scalar_t* __restrict__ aggrs_info, // 0: sum, 1: max z*D
scalar_t* grad_faces,
scalar_t* grad_textures,
scalar_t* grad_soft_colors,
int batch_size,
int num_faces,
int image_size,
int texture_size,
int texture_res,
float near,
float far,
float eps,
float sigma_val,
int func_id_dist,
float dist_eps,
float gamma_val,
int func_id_rgb,
int func_id_alpha,
int texture_sample_type,
bool double_side) {
////////////////////////
////////////////////////
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= batch_size * image_size * image_size) {
return;
}
const int is = image_size;
const int nf = num_faces;
const int bn = i / (is * is);
const int pn = i % (is * is);
const int yi = is - 1 - (pn / is);
const int xi = pn % is;
const scalar_t yp = (2. * yi + 1 - is) / is;
const scalar_t xp = (2. * xi + 1 - is) / is;
const scalar_t* face = &faces[bn * nf * 9] - 9;
const scalar_t* texture = &textures[bn * nf * texture_size * 3] - texture_size * 3;
const scalar_t* face_info = &faces_info[bn * nf * 27] - 27;
const scalar_t threshold = dist_eps * sigma_val;
const scalar_t softmax_sum = aggrs_info[(bn * 2 + 0) * (is * is) + pn];
const scalar_t softmax_max = aggrs_info[(bn * 2 + 1) * (is * is) + pn];
for (int fn = 0; fn < nf; fn++) {
face += 9;
texture += texture_size * 3;
face_info += 27;
if (check_border(xp, yp, face, sqrt(threshold))) continue;
scalar_t dis;
scalar_t dis_x;
scalar_t dis_y;
scalar_t t[3];
scalar_t w[3];
scalar_t w0[3];
scalar_t sign;
scalar_t soft_fragment;
barycentric_coordinate(w, xp, yp, face_info);
// compute probability map based on distance functions
if (func_id_dist == 0) { // hard assign
soft_fragment = check_pixel_inside(w) ? 1. : 0.;
if (soft_fragment == 0.) continue; // ???
} else
if (func_id_dist == 1) { // barycentric distance
forward_barycentric_p2f_distance(dis, w);
for (int k = 0; k < 3; k++) t[k] = w[k];
if (-dis >= threshold) continue; // ignore triangle far away from the pixel
soft_fragment = 1. / (1. + exp(-dis / sigma_val));
} else
if (func_id_dist == 2) { // euclidean distance
euclidean_p2f_distance(sign, dis_x, dis_y, w, t, face, face_info, xp, yp);
dis = dis_x * dis_x + dis_y * dis_y;
if (sign < 0 && dis >= threshold) continue; // ignore triangle too far away from the pixel
soft_fragment = 1. / (1. + exp(-sign * dis / sigma_val));
}
scalar_t* grad_face = &grad_faces[(bn * nf + fn) * 9];
scalar_t* grad_texture = &grad_textures[(bn * nf + fn) * texture_size * 3];
scalar_t grad_v[3][3] = {0};
scalar_t C_grad_xy = 0;
/////////////////////////////////////////////////////
// aggragate for alpha channel
scalar_t C_grad_xy_alpha = grad_soft_colors[(bn * 4 + 3) * (is * is) + pn];
if (func_id_alpha == 0) { // hard assign
// hard assign alpha channels does not have gradient
} else
if (func_id_alpha == 1) { // Sum
C_grad_xy_alpha /= nf;
} else
if (func_id_alpha == 2) { // Logical-Or
C_grad_xy_alpha *= (1 - soft_colors[(bn * 4 + 3) * (is * is) + pn]) / max(1 - soft_fragment, 1e-6);
}
C_grad_xy += C_grad_xy_alpha;
/////////////////////////////////////////////////////
for (int k = 0; k < 3; k++) w0[k] = w[k];
barycentric_clip(w);
const scalar_t zp = 1. / (w[0] / face[2] + w[1] / face[5] + w[2] / face[8]);
if (zp < near || zp > far) continue; // triangle out of screen, pass
// aggregate for rgb channels
if (func_id_rgb == 0) { // Hard assign, no gradient to xyz
if (fn == softmax_max) {
for (int k = 0; k < 3; k++) {
for (int j = 0; j < texture_size; j++) {
atomicAdd(&grad_texture[3 * j + k], backward_sample_texture(grad_soft_colors[(bn * 4 + k) * (is * is) + pn], w, texture_res, j, texture_sample_type));
}
}
}
} else
if (func_id_rgb == 1 && (check_face_frontside(face) || double_side)) { // Softmax (Z * D)
scalar_t C_grad_xyz_rgb = 0.;
const scalar_t zp_norm = (far - zp) / (far - near);
const scalar_t zp_softmax = soft_fragment * exp((zp_norm - softmax_max) / gamma_val) / softmax_sum;
for (int k = 0; k < 3; k++) {
const scalar_t grad_soft_color_k = grad_soft_colors[(bn * 4 + k) * (is * is) + pn];
for (int j = 0; j < texture_size; j++) {
const scalar_t grad_t = backward_sample_texture(grad_soft_color_k, w, texture_res, j, texture_sample_type);
atomicAdd(&grad_texture[3 * j + k], zp_softmax * grad_t);
}
const scalar_t color_k = forward_sample_texture(texture, w, texture_res, k, texture_sample_type);
C_grad_xyz_rgb += grad_soft_color_k * (color_k - soft_colors[(bn * 4 + k) * (is * is) + pn]);
}
C_grad_xyz_rgb *= zp_softmax;
C_grad_xy += C_grad_xyz_rgb / soft_fragment;
const scalar_t C_grad_z_rgb = C_grad_xyz_rgb / gamma_val / (near - far) * zp * zp;
grad_v[0][2] = C_grad_z_rgb * w[0] / face[2] / face[2];
grad_v[1][2] = C_grad_z_rgb * w[1] / face[5] / face[5];
grad_v[2][2] = C_grad_z_rgb * w[2] / face[8] / face[8];
}
/////////////////////////////////////////////////////
C_grad_xy *= soft_fragment * (1 - soft_fragment) / sigma_val; // sigmoid gradient
// compute probability map gradient based on distance functions
if (func_id_dist == 1) { // barycentric distance
backward_barycentric_p2f_distance(grad_v, t, face_info, xp, yp, dis, C_grad_xy);
} else
if (func_id_dist == 2) { // euclidean distance
for (int k = 0; k < 3; k++) {
for (int l = 0; l < 2; l++) {
grad_v[k][l] = 2 * sign * C_grad_xy * (t[k] + w0[k]) * (l == 0 ? dis_x : dis_y);
}
}
}
atomicAdd(&grad_face[0], grad_v[0][0]);
atomicAdd(&grad_face[1], grad_v[0][1]);
atomicAdd(&grad_face[3], grad_v[1][0]);
atomicAdd(&grad_face[4], grad_v[1][1]);
atomicAdd(&grad_face[6], grad_v[2][0]);
atomicAdd(&grad_face[7], grad_v[2][1]);
atomicAdd(&grad_face[2], grad_v[0][2]);
atomicAdd(&grad_face[5], grad_v[1][2]);
atomicAdd(&grad_face[8], grad_v[2][2]);
}
}
}
std::vector<at::Tensor> forward_soft_rasterize_cuda(
at::Tensor faces,
at::Tensor textures,
at::Tensor faces_info,
at::Tensor aggrs_info,
at::Tensor soft_colors,
int image_size,
float near,
float far,
float eps,
float sigma_val,
int func_id_dist,
float dist_eps,
float gamma_val,
int func_id_rgb,
int func_id_alpha,
int texture_sample_type,
bool double_side) {
const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const auto texture_size = textures.size(2);
const auto texture_res = int(sqrt(texture_size));
const int threads = 512;
const dim3 blocks_1 ((batch_size * num_faces - 1) / threads +1);
AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_soft_rasterize_inv_cuda", ([&] {
forward_soft_rasterize_inv_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
faces.data<scalar_t>(),
faces_info.data<scalar_t>(),
batch_size,
num_faces,
image_size);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in forward_transform_inv_triangle: %s\n", cudaGetErrorString(err));
const dim3 blocks_2 ((batch_size * image_size * image_size - 1) / threads +1);
AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_eff_soft_rasterize_cuda", ([&] {
forward_soft_rasterize_cuda_kernel<scalar_t><<<blocks_2, threads>>>(
faces.data<scalar_t>(),
textures.data<scalar_t>(),
faces_info.data<scalar_t>(),
aggrs_info.data<scalar_t>(),
soft_colors.data<scalar_t>(),
batch_size,
num_faces,
image_size,
texture_size,
texture_res,
near,
far,
eps,
sigma_val,
func_id_dist,
dist_eps,
gamma_val,
func_id_rgb,
func_id_alpha,
texture_sample_type,
double_side);
}));
err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in forward_soft_rasterize: %s\n", cudaGetErrorString(err));
return {faces_info, aggrs_info, soft_colors};
}
std::vector<at::Tensor> backward_soft_rasterize_cuda(
at::Tensor faces,
at::Tensor textures,
at::Tensor soft_colors,
at::Tensor faces_info,
at::Tensor aggrs_info,
at::Tensor grad_faces,
at::Tensor grad_textures,
at::Tensor grad_soft_colors,
int image_size,
float near,
float far,
float eps,
float sigma_val,
int func_id_dist,
float dist_eps,
float gamma_val,
int func_id_rgb,
int func_id_alpha,
int texture_sample_type,
bool double_side) {
const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const auto texture_size = textures.size(2);
const auto texture_res = int(sqrt(texture_size));
const int threads = 512;
const dim3 blocks ((batch_size * image_size * image_size - 1) / threads + 1);
AT_DISPATCH_FLOATING_TYPES(faces.type(), "backward_soft_rasterize_cuda", ([&] {
backward_soft_rasterize_cuda_kernel<scalar_t><<<blocks, threads>>>(
faces.data<scalar_t>(),
textures.data<scalar_t>(),
soft_colors.data<scalar_t>(),
faces_info.data<scalar_t>(),
aggrs_info.data<scalar_t>(),
grad_faces.data<scalar_t>(),
grad_textures.data<scalar_t>(),
grad_soft_colors.data<scalar_t>(),
batch_size,
num_faces,
image_size,
texture_size,
texture_res,
near,
far,
eps,
sigma_val,
func_id_dist,
dist_eps,
gamma_val,
func_id_rgb,
func_id_alpha,
texture_sample_type,
double_side);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
printf("Error in backward_soft_rasterize: %s\n", cudaGetErrorString(err));
return {grad_faces, grad_textures};
}