<a href="https://colab.research.google.com/github/GordonGustafson/cuda-stuff/blob/main/leetgpu/convolution_2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch.utils.cpp_extension import load_inline

# load_inline (used below to compile the CUDA source) builds extensions with the
# Ninja build system, so install it first. This is an IPython shell escape and is
# only valid inside a notebook.
!pip install Ninja

In [None]:
cuda_src = """
#include <cuda_runtime.h>
#include <stdio.h>

#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>

// Argument-validation helpers for the torch-facing entry point.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so CHECK_INPUT(x); expands to a single statement
// (safe under an unbraced if/else).
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

// Side length of the square thread block and of the shared-memory input tile.
#define BLOCK_SIZE 32
// 16 * 1024 floats = 64 KiB, the full size of CUDA constant memory.
#define MAX_KERNEL_AREA (16 * 1024)
// Ceiling division for positive integers: the number of divisor-sized chunks needed
// to cover dividend. Arguments are parenthesized so expressions expand correctly.
#define cdiv(dividend, divisor) (((dividend) + (divisor) - 1) / (divisor))

// Filter weights; filled by conv2d() via cudaMemcpyToSymbol before each launch.
__constant__ float kernel_constant[MAX_KERNEL_AREA];

// 'Valid' 2D cross-correlation (the filter is not flipped) of a row-major
// input_rows x input_cols image with the kernel_rows x kernel_cols filter stored in
// kernel_constant. Writes a row-major
// (input_rows - kernel_rows + 1) x (input_cols - kernel_cols + 1) output.
//
// Launch layout: BLOCK_SIZE x BLOCK_SIZE threads per block; x (blockIdx.x/threadIdx.x)
// runs along columns and y along rows. Each thread in range produces one output pixel.
// Each block stages the input pixels under its own threads in a static shared tile.
// Filter taps that reach past the tile's right/bottom edge are read from global memory.
__global__ void convolution_2d(float const* const input,
                               float* const output,
                               int const input_rows,
                               int const input_cols,
                               int const kernel_rows,
                               int const kernel_cols) {
    int const ty = threadIdx.y;
    int const tx = threadIdx.x;
    int const out_r = blockIdx.y * blockDim.y + ty;
    int const out_c = blockIdx.x * blockDim.x + tx;

    __shared__ float tile[BLOCK_SIZE][BLOCK_SIZE];

    // Stage this thread's input pixel. Every thread reaches the barrier below.
    bool const inside_input = (out_r < input_rows) && (out_c < input_cols);
    if (inside_input) {
        tile[ty][tx] = input[out_r * input_cols + out_c];
    }
    __syncthreads();

    int const out_rows = input_rows - kernel_rows + 1;
    int const out_cols = input_cols - kernel_cols + 1;
    if (out_r >= out_rows || out_c >= out_cols) {
        return;  // No further barriers follow, so exiting early is safe.
    }

    float acc = 0.0f;
    for (int kr = 0; kr < kernel_rows; ++kr) {
        int const tile_r = ty + kr;
        for (int kc = 0; kc < kernel_cols; ++kc) {
            int const tile_c = tx + kc;
            float const weight = kernel_constant[kr * kernel_cols + kc];
            // Taps inside this block's tile come from shared memory. The rest fall back
            // to global memory, which may be served from cache if neighbouring blocks
            // loaded them recently.
            float const pixel = (tile_r < BLOCK_SIZE && tile_c < BLOCK_SIZE)
                                    ? tile[tile_r][tile_c]
                                    : input[(out_r + kr) * input_cols + out_c + kc];
            acc += pixel * weight;
        }
    }
    output[out_r * out_cols + out_c] = acc;
}

// Host launcher. input, kernel and output are device pointers to row-major float
// arrays with these shapes:
//   input:  input_rows x input_cols
//   kernel: kernel_rows x kernel_cols
//   output: (input_rows - kernel_rows + 1) x (input_cols - kernel_cols + 1)
// Copies the filter into kernel_constant, then launches convolution_2d on the default
// stream. Throws (TORCH_CHECK / C10_CUDA_CHECK) if the filter dimensions are not
// positive, the filter does not fit in constant memory, or the copy fails.
void conv2d(const float* input, const float* kernel, float* output,
           int input_rows, int input_cols, int kernel_rows, int kernel_cols) {
    TORCH_CHECK(kernel_rows > 0 && kernel_cols > 0, "Kernel dimensions must be positive");

    int const output_rows = input_rows - kernel_rows + 1;
    int const output_cols = input_cols - kernel_cols + 1;
    if (output_rows <= 0 || output_cols <= 0) {
        // Nothing to compute. This also avoids launching a zero-sized grid, which is an
        // invalid launch configuration.
        return;
    }

    int const kernel_area = kernel_rows * kernel_cols;
    // Fail loudly: returning early here would leave output uninitialized.
    TORCH_CHECK(kernel_area <= MAX_KERNEL_AREA,
                "Kernel is larger than MAX_KERNEL_AREA constant");
    // The filter already lives in device memory, so the copy must be DeviceToDevice.
    // cudaMemcpyToSymbol defaults to HostToDevice, and a mismatched kind is undefined.
    C10_CUDA_CHECK(cudaMemcpyToSymbol(kernel_constant, kernel, kernel_area * sizeof(float),
                                      0, cudaMemcpyDeviceToDevice));

    dim3 const threadsPerBlock(BLOCK_SIZE, BLOCK_SIZE);
    // dim3 is (x, y). The kernel maps x to columns and y to rows, so size the grid to match.
    dim3 const blocksPerGrid(cdiv(input_cols, threadsPerBlock.x),
                             cdiv(input_rows, threadsPerBlock.y));
    convolution_2d<<<blocksPerGrid, threadsPerBlock>>>(input, output, input_rows, input_cols, kernel_rows, kernel_cols);
}

// PyTorch entry point exposed through load_inline. Computes the 'valid' 2D
// cross-correlation of a 2D float32 CUDA tensor `input` (rows x cols) with a 2D
// float32 CUDA tensor `filter` (filter_rows x filter_cols) on the same device.
// Returns a new (rows - filter_rows + 1) x (cols - filter_cols + 1) tensor with
// input's options. Throws c10::Error (RuntimeError in Python) on invalid arguments.
torch::Tensor conv2d_torch_tensors(torch::Tensor input, torch::Tensor filter) {
    CHECK_INPUT(input);
    CHECK_INPUT(filter);
    TORCH_CHECK(input.dim() == 2, "input must be a 2D tensor");
    TORCH_CHECK(filter.dim() == 2, "filter must be a 2D tensor");
    TORCH_CHECK(input.scalar_type() == torch::kFloat32, "input must be float32");
    TORCH_CHECK(filter.scalar_type() == torch::kFloat32, "filter must be float32");
    TORCH_CHECK(input.device() == filter.device(), "input and filter must be on the same device");

    // For a contiguous 2D tensor, size(0) is the row count and size(1) the column
    // count. The kernel indexes elements as row * cols + col.
    int const input_rows = input.size(0);
    int const input_cols = input.size(1);
    int const filter_rows = filter.size(0);
    int const filter_cols = filter.size(1);
    TORCH_CHECK(filter_rows > 0 && filter_cols > 0, "filter must be non-empty");
    TORCH_CHECK(filter_rows <= input_rows && filter_cols <= input_cols,
                "filter must not be larger than input in either dimension");

    torch::Tensor output = torch::empty({input_rows - filter_rows + 1, input_cols - filter_cols + 1}, input.options());
    conv2d(input.data_ptr<float>(), filter.data_ptr<float>(), output.data_ptr<float>(),
           input_rows, input_cols, filter_rows, filter_cols);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

"""

# C++ declaration of the entry point defined in cuda_src; load_inline generates the
# Python binding for each name listed in `functions`.
cpp_src = "torch::Tensor conv2d_torch_tensors(torch::Tensor input, torch::Tensor filter);"
functions = ['conv2d_torch_tensors']

# Compile and load the extension as "inline_ext". verbose=True prints the build log.
module = load_inline(
    name="inline_ext",
    cpp_sources=[cpp_src],
    cuda_sources=[cuda_src],
    functions=functions,
    extra_cuda_cflags=["-O2"],
    verbose=True,
)

In [None]:
import torch.nn.functional as F

input = torch.ones((128, 128), dtype=torch.float32, device='cuda')
filter = torch.ones((5, 5), dtype=torch.float32, device='cuda')
result = module.conv2d_torch_tensors(input, filter)
# torch.set_printoptions(profile="full")
print(result)


def _reference_conv2d(x, w):
    """'Valid' 2D cross-correlation of 2D tensors x and w, computed on the CPU as ground truth."""
    # F.conv2d expects (batch, channels, H, W) input and (out_ch, in_ch, kH, kW) weights,
    # and does not flip the filter, matching the custom kernel.
    return F.conv2d(x.cpu()[None, None], w.cpu()[None, None])[0, 0]


# With all-ones data every output element is the sum of a 5x5 window, i.e. 25.
torch.testing.assert_close(result.cpu(), _reference_conv2d(input, filter))

# A non-square input and filter catch row/column mix-ups that the square case hides.
torch.manual_seed(0)
rect_input = torch.randn((70, 45), dtype=torch.float32, device='cuda')
rect_filter = torch.randn((3, 7), dtype=torch.float32, device='cuda')
torch.testing.assert_close(
    module.conv2d_torch_tensors(rect_input, rect_filter).cpu(),
    _reference_conv2d(rect_input, rect_filter),
    rtol=1e-4,
    atol=1e-4,
)

In [None]:
# For the 128x128 input and 5x5 filter above, expect 128 - 5 + 1 = 124 per side.
print(result.size())