In [None]:
import sys
# Allocate the render targets on `device`: an RGB accumulation image indexed
# (x, y, channel) and a per-pixel opacity buffer initialised to 1.
screen_hw = (int(final_w), int(final_h))
screen = torch.zeros((*screen_hw, 3), dtype=torch.float32, device=device)
opacity_buffer = torch.ones(screen_hw, dtype=torch.float32, device=device)

# Report the shape, dtype and memory footprint of the colour buffer.
screen_bytes = screen.element_size() * screen.nelement()
print(screen.shape)
print(screen.dtype)
print(f"{screen_bytes} Bytes")
print(f"{screen_bytes/1024} KBytes")
print(f"{screen_bytes/1024**2} MBytes")
print(opacity_buffer.shape)

In [None]:
cuda_src = cuda_begin + r'''
// Composites n depth-sorted 2D Gaussians onto an RGB image.
//
// One thread owns one pixel and walks the Gaussians in their sorted order
// (index 0 first), so the running opacity of a pixel is only ever touched by
// a single thread: no atomics are needed and the blending order is
// deterministic.
//
// All pointers address contiguous float32 buffers with these layouts:
//   screen           (w, h, 3)  RGB accumulator, read and written
//   opacity_buffer   (w, h)     remaining opacity per pixel, read and written
//   sorted_xy        (n, 2)     Gaussian centre in pixel coordinates
//   sorted_conic     (n, 2, 2)  coefficients of the quadratic form
//   sorted_opacity   (n)        peak opacity of each Gaussian
//   sorted_color     (n, 3)     RGB colour of each Gaussian
//   sorted_rect_min  (n, 2)     inclusive lower corner of the bounding box
//   sorted_rect_max  (n, 2)     exclusive upper corner of the bounding box
// sorted_depths, sorted_final_covariance and sorted_rect_area are accepted so
// the argument list mirrors the host-side tensors; the kernel does not read
// them.
__global__ void splat_kernel(
    float* screen,
    float* opacity_buffer,
    const float* sorted_depths,
    const float* sorted_xy,
    const float* sorted_final_covariance,
    const float* sorted_conic,
    const float* sorted_opacity,
    const float* sorted_color,
    const float* sorted_rect_min,
    const float* sorted_rect_max,
    const float* sorted_rect_area,
    int n,
    int w,
    int h
) {
    // Contributions with alpha at or below this threshold are skipped.
    const float min_alpha = 0.003922f;
    // Alpha is clamped to this ceiling so a pixel never becomes fully opaque
    // from a single Gaussian.
    const float max_alpha = 0.99f;

    // threadIdx.x runs along y because y is the faster-varying image axis in
    // the (w, h, 3) layout; neighbouring threads then touch neighbouring
    // memory.
    const int y = blockIdx.x * blockDim.x + threadIdx.x;
    const int x = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= w || y >= h) return;

    const size_t pixel = static_cast<size_t>(x) * h + y;
    float remaining = opacity_buffer[pixel];
    float red   = screen[3 * pixel + 0];
    float green = screen[3 * pixel + 1];
    float blue  = screen[3 * pixel + 2];

    for (int i = 0; i < n; ++i) {
        // Bounding-box test: lower bound truncated to int and inclusive,
        // upper bound exclusive.
        const bool inside_x = x >= static_cast<int>(sorted_rect_min[2 * i + 0])
                           && static_cast<float>(x) < sorted_rect_max[2 * i + 0];
        const bool inside_y = y >= static_cast<int>(sorted_rect_min[2 * i + 1])
                           && static_cast<float>(y) < sorted_rect_max[2 * i + 1];
        if (!inside_x || !inside_y) continue;

        const float sigma_x   = sorted_conic[4 * i + 0];  // [i][0][0]
        const float sigma_x_y = sorted_conic[4 * i + 1];  // [i][0][1]
        const float sigma_y   = sorted_conic[4 * i + 3];  // [i][1][1]

        const float dx = sorted_xy[2 * i + 0] - static_cast<float>(x);
        const float dy = sorted_xy[2 * i + 1] - static_cast<float>(y);

        const float gaussian_density =
            -0.5f * (sigma_x * dx * dx + sigma_y * dy * dy) - sigma_x_y * dx * dy;

        const float raw_alpha = sorted_opacity[i] * expf(gaussian_density);
        const float alpha = raw_alpha > max_alpha ? max_alpha : raw_alpha;

        // Written as positive conditions so that NaN inputs are rejected.
        const bool valid = (alpha > min_alpha) && (gaussian_density <= 0.0f);
        if (!valid) continue;

        const float weight = alpha * remaining;
        red   += weight * sorted_color[3 * i + 0];
        green += weight * sorted_color[3 * i + 1];
        blue  += weight * sorted_color[3 * i + 2];
        remaining *= (1.0f - alpha);
    }

    screen[3 * pixel + 0] = red;
    screen[3 * pixel + 1] = green;
    screen[3 * pixel + 2] = blue;
    opacity_buffer[pixel] = remaining;
}

// Renders the first n depth-sorted Gaussians into a freshly allocated image.
//
// Every tensor argument is validated with CHECK_INPUT and must be float32
// (data_ptr<float>() throws otherwise).
//
// Returns a (w, h, 3) tensor created with sorted_color's options (dtype and
// device). Fails a TORCH_CHECK when w or h is not positive, or when n is
// negative or larger than the number of rows in sorted_xy.
torch::Tensor render(
    const torch::Tensor sorted_depths,
    const torch::Tensor sorted_xy,
    const torch::Tensor sorted_final_covariance,
    const torch::Tensor sorted_conic,
    const torch::Tensor sorted_opacity,
    const torch::Tensor sorted_color,
    const torch::Tensor sorted_rect_min,
    const torch::Tensor sorted_rect_max,
    const torch::Tensor sorted_rect_area,
    const int w,
    const int h,
    const int n
) {
    CHECK_INPUT(sorted_depths);
    CHECK_INPUT(sorted_xy);
    CHECK_INPUT(sorted_final_covariance);
    CHECK_INPUT(sorted_conic);
    CHECK_INPUT(sorted_opacity);
    CHECK_INPUT(sorted_color);
    CHECK_INPUT(sorted_rect_min);
    CHECK_INPUT(sorted_rect_max);
    CHECK_INPUT(sorted_rect_area);
    // w, h and n are plain ints, so they get value checks instead of
    // CHECK_INPUT. A zero-sized grid is an invalid launch configuration.
    TORCH_CHECK(w > 0 && h > 0, "render: screen size must be positive, got ", w, " x ", h);
    TORCH_CHECK(n >= 0 && n <= sorted_xy.size(0),
                "render: n = ", n, " is outside [0, ", sorted_xy.size(0), "]");

    auto screen = torch::zeros({w, h, 3}, sorted_color.options());
    auto opacity_buffer = torch::ones({w, h}, sorted_color.options());

    // One thread per pixel; grid.x covers the h axis and grid.y the w axis to
    // match the index mapping inside splat_kernel.
    const int tile = 16;
    const dim3 threads(tile, tile);
    const dim3 blocks(cdiv(h, tile), cdiv(w, tile));

    splat_kernel<<<blocks, threads>>>(
        screen.data_ptr<float>(),
        opacity_buffer.data_ptr<float>(),
        sorted_depths.data_ptr<float>(),
        sorted_xy.data_ptr<float>(),
        sorted_final_covariance.data_ptr<float>(),
        sorted_conic.data_ptr<float>(),
        sorted_opacity.data_ptr<float>(),
        sorted_color.data_ptr<float>(),
        sorted_rect_min.data_ptr<float>(),
        sorted_rect_max.data_ptr<float>(),
        sorted_rect_area.data_ptr<float>(),
        n,
        w,
        h
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return screen;
}'''