<a href="https://colab.research.google.com/github/JTStephens18/CUDA_Playground/blob/main/CUDA_Playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import random
import numpy as np
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from scipy.interpolate import interpn

In [2]:
# Slows things down, but good for development since it stops when there is an error
os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [3]:
# Wurlitzer allows things to be printed from C++/CUDA code in a notebook
# Ninja is a build tool required by pytorch to compile C++/CUDA code
%pip install -q wurlitzer ninja

In [4]:
%load_ext wurlitzer

In [5]:
"""
load_inline is a great function that takes in
  a list of any of the cuda code strings you want to compile (cuda_sources)
  any plain cpp strings you want to compile (cpp_sources)
  any functions in the cpp strings you want to make available to pytorch (functions)
that compiles it all and turns it into a python module
"""
from torch.utils.cpp_extension import load_inline

In [6]:
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
  return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                     extra_cuda_cflags=["-O2"] if opt else [], verbose=verbose, name='inline_ext')

In [7]:
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
// Checks input is contiguous in memory
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Ceiling division - which we can use to figure out how many blocks we need
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

In [None]:
"""
  Timestep: 33:30 in Getting started with CUDA for Python Programmers
  Use __global__ anytime we want to call something from the CPU to run on the GPU
  Ex: __global__ void func(int x) {}

  To call a CUDA kernel:
  func<<<numBlocks, numThreads>>> (
    arguments
  );

  To check for an error call:
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  Always call it after calling a kernel to make sure no errors

  Must be careful when running a function from a CUDA kernel that it has finished
  Can check this by printing a value or .cpu() will wait for the kernel to finish and put it onto cpu
"""

## Neural Kernel Code

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def createGramMatrix(points_i, points_j):
  """
    Input: A set of points_i and points_j [batch_size, num_points, num_values] Ex: [16, 2048, 3] and [16, 5000, 3]
    Output: A matrix of size [batch_size, num_points_i, num_points_j, 2, 3]
            Which represents a set of values at each ij index
  """
  jdx = torch.arange(points_j.shape[1]).unsqueeze(0).repeat(points_i.shape[1], 1)
  idx = torch.arange(points_i.shape[1]).unsqueeze(1).repeat(1, points_j.shape[1])
  # pairs = torch.cat((points_i[:, idx, :], points_j[:, jdx, :]), dim=3).reshape(points_i.shape[0], points_i.shape[1], points_j.shape[1], 2, -1)
  print(idx.shape)
  pairs = torch.cat((points_i[:, idx, :], points_j[:, jdx, :]), dim=3)
  print(pairs.shape)
  return pairs

In [None]:
input1 = torch.randn(2, 2000, 3).to(device)
input2 = torch.randn(2, 2000, 3).to(device)
# pairsTensor = torch.zeros(4, 2000, 2000, 32).to(device)
matrix = createGramMatrix(input1, input2)
torch.cuda.empty_cache()
# print(matrix.shape)

torch.Size([2000, 2000])
torch.Size([2, 2000, 2000, 64])


In [None]:
"""
  To modify createGramMatrix to use a CUDA kernel:
    First it should not return anything, only modify values in a premade tensor
      Could hopefully only create this tensor once for the entire training loop
      and just continuously update the values in the tensor.
      Ideally would help GPU memory since it is only allocated once
    For loop/logic to that concatenates values and adds them into the premade tensor
"""

In [26]:
batch_size = 4
pairs = torch.randn(batch_size, 200, 200, 6).to(device)
input1 = torch.randn(batch_size, 200, 3).to(device)
input2 = torch.randn(batch_size, 200, 3).to(device)

In [7]:
print(pairs[0,0,0])

tensor([ 0.0756, -1.7596, -0.7277,  0.9198, -0.5163,  0.3295], device='cuda:0')


In [16]:
print(pairs[1,12,17, 2])
print(input1[1,12,2])
print(input2[1,17,2])
flat = pairs.flatten()
print(flat.shape)

flat1 = input1.flatten()
flat2 = input2.flatten()

# for i in range(pairs.shape[0]):
#   for j in range(pairs.shape[1]):
#     for k in range(pairs.shape[2]):
#       for l in range(pairs.shape[3]):
#         idx = i * (pairs.shape[1] * pairs.shape[2] * pairs.shape[3]) +
#         j * (pairs.shape[2] * pairs.shape[3]) + k * (pairs.shape[3]) + l


i = 1
j = 12
k = 17
l = 2
# Formula to calculate index of a flattened tensor given the dimensions of original multidimensional tensor and indexes
idx = i * (pairs.shape[1] * pairs.shape[2] * pairs.shape[3]) + j * (pairs.shape[2] * pairs.shape[3]) + k * (pairs.shape[3]) + l
input_idx = i * (input1.shape[1] * input1.shape[2]) + j * (input1.shape[2]) + l
input_idx2 = i * (input2.shape[1] * input2.shape[2]) + k * (input2.shape[2]) + l
print(idx)
print(flat[idx])
print(input_idx)
print(flat1[input_idx])
print(input_idx2)
print(flat2[input_idx2])

tensor(-1.0341, device='cuda:0')
tensor(1.3791, device='cuda:0')
tensor(0.9200, device='cuda:0')
torch.Size([9600])
3944
tensor(-1.0341, device='cuda:0')
98
tensor(1.3791, device='cuda:0')
113
tensor(0.9200, device='cuda:0')


In [27]:
%%time
pairs[:, :200, :, :3] = input1.unsqueeze(2)
pairs[:, :, :200, 3:] = input2.unsqueeze(1)

CPU times: user 8.69 ms, sys: 2.99 ms, total: 11.7 ms
Wall time: 15.6 ms


In [None]:
print(pairs.shape)
print(pairs[0,0,0])

torch.Size([4, 20, 20, 6])
tensor([-1.0355, -1.0061,  0.2309,  0.7025,  0.1076,  1.2185], device='cuda:0')


In [22]:
cuda_src = cuda_begin + r'''

#define BATCH 4
#define DIM1 200
#define DIM2 200
#define DIM3 3

__global__ void concat_kernel(float* input1, float* input2, float* out, int h, int w, int batch)  {
    int r = blockIdx.y * blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    if(r>=h || c>= w) return;
    for (int i = 0; i < batch; i++) {
      for (int j = 0; j < DIM1; j++) {
        for (int k = 0; k < DIM2; k++) {
          for (int l = 0; l < DIM3; l++) {
            int idx = i * (DIM1 * DIM2 * DIM3) + j * (DIM2 * DIM3) + k * (2*DIM3) + l;
            int idx2 = i * (DIM1 * DIM2 * DIM3) + j * (DIM2 * DIM3) + k * (2*DIM3) + l+3;
            int input_idx_1 = i * (DIM1 * DIM3) + j * (DIM3) + l;
            int input_idx_2 = i * (DIM2 * DIM3) + k * (DIM3) + l;
            out[idx] = input1[input_idx_1];
            out[idx2] = input2[input_idx_2];
          }
        }
      }
    }
}

torch::Tensor concat(torch::Tensor input1, torch::Tensor input2, torch::Tensor output) {
  CHECK_INPUT(input1);
  CHECK_INPUT(input2);
  CHECK_INPUT(output);
  int batch = BATCH;
  int h = DIM1;
  int w = DIM2;
  dim3 tpb(16, 16);
  dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
  concat_kernel<<<blocks, tpb>>>(
    input1.data_ptr<float>(), input2.data_ptr<float>(), output.data_ptr<float>(), h,w, batch);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}
'''

In [23]:
cpp_src = "torch::Tensor concat(torch::Tensor input1, torch::Tensor input2, torch::Tensor output);"

module = load_cuda(cuda_src, cpp_src, ['concat'])

In [24]:
out = torch.zeros(4, 200, 200, 6).contiguous().to(device)
input1 = torch.randn(4, 200, 3).contiguous().to(device)
input2 = torch.randn(4, 200, 3).contiguous().to(device)

In [29]:
%%time
res = module.concat(input1.flatten(), input2.flatten(), out.flatten()).cpu()
print(res.shape)

torch.Size([960000])
CPU times: user 215 ms, sys: 2.31 ms, total: 217 ms
Wall time: 217 ms


In [None]:
def trilinear_interpolation(query_points, grid):
    # Extract the coordinates of the eight surrounding vertices
    grid = grid.permute(0,2,3,4,1)
    query_points_floor = query_points.floor().long() - 2
    # print(query_points_floor.min())
    x0, y0, z0 = query_points_floor[:,:,0], query_points_floor[:,:,1], query_points_floor[:,:,2]
    x1, y1, z1 = x0 + 1, y0 + 1, z0 + 1

    batch_enum = torch.arange(query_points.shape[0]).unsqueeze(1)

    # Extract the values at the eight surrounding vertices
    # c000 = grid[:, :, x0, y0, z0]
    c000 = grid[batch_enum, x0, y0, z0]
    c001 = grid[batch_enum, x0, y0, z1]
    c010 = grid[batch_enum, x0, y1, z0]
    c011 = grid[batch_enum, x0, y1, z1]
    c100 = grid[batch_enum, x1, y0, z0]
    c101 = grid[batch_enum, x1, y0, z1]
    c110 = grid[batch_enum, x1, y1, z0]
    c111 = grid[batch_enum, x1, y1, z1]

    # print(c000.shape)
    # print(c001.shape)
    # print(c010.shape)
    # print(c011.shape)
    # print(c100.shape)
    # print(c101.shape)
    # print(c110.shape)
    # print(c111.shape)

    # Compute the interpolation weights and add 1s to match the last dimension of c000 ... c111
    u = (query_points[:,:,0] - x0.float()).unsqueeze(-1).expand(query_points.shape[0],query_points.shape[1], grid.shape[4])
    v = (query_points[:,:,1] - y0.float()).unsqueeze(-1).expand(query_points.shape[0],query_points.shape[1], grid.shape[4])
    w = (query_points[:,:,2] - z0.float()).unsqueeze(-1).expand(query_points.shape[0],query_points.shape[1], grid.shape[4])

    # print(u.shape)
    # print(v.shape)
    # print(w.shape)

    # Perform trilinear interpolation
    interpolated_value = (1 - u) * (1 - v) * (1 - w) * c000 + \
                         (1 - u) * (1 - v) * w * c001 + \
                         (1 - u) * v * (1 - w) * c010 + \
                         (1 - u) * v * w * c011 + \
                         u * (1 - v) * (1 - w) * c100 + \
                         u * (1 - v) * w * c101 + \
                         u * v * (1 - w) * c110 + \
                         u * v * w * c111
    return interpolated_value

In [None]:
def calculateTheta(x_tilde, x_tilde_prime):
  norm = torch.linalg.norm(x_tilde, dim=-1).unsqueeze(3)
  print(norm.shape)
  norm_prime = torch.linalg.norm(x_tilde_prime, dim=-1).unsqueeze(3)
  print(norm_prime.shape)
  numerator = torch.linalg.norm(norm_prime * x_tilde - norm * x_tilde_prime, dim=-1)
  denominator = torch.linalg.norm(norm_prime * x_tilde + norm * x_tilde_prime, dim=-1)
  theta = torch.atan2(numerator, denominator)
  return theta

In [None]:
def calculateNeuralSpline(x):
    # Convert x to float type if it's not already
    # x = x.permute(0,1,3,2)
    # Reshape x to have the desired shape [points, points, 2, 3]
    #x = x.unsqueeze(3)  # Add a singleton dimension to enable broadcasting
    # x_tilde = torch.cat((x[..., 0], torch.ones_like(x[..., :0])), dim=-1)
    # x_tilde_prime = torch.cat((x[..., 1], torch.ones_like(x[..., :1])), dim=-1)

    x_tilde = x[..., 0]  # Extract x_tilde
    x_tilde_prime = x[..., 1]  # Extract x_tilde_prime

    theta = calculateTheta(x_tilde, x_tilde_prime)
    firstTerm = (torch.linalg.norm(x_tilde, dim=-1) * torch.linalg.norm(x_tilde_prime, dim=-1) / np.pi)
    secondTerm = (torch.sin(theta) + 2 * (np.pi - theta) * torch.cos(theta))
    kernelVal = firstTerm * secondTerm
    return kernelVal.squeeze()  # Remove singleton dimensions

In [None]:
def calculateKernel(points_i, points_j, grid):
  # Get features for each point by trilinearly interpolating from output grid
  features_i = trilinear_interpolation(points_i, grid)
  features_j = trilinear_interpolation(points_j, grid)
  # Concat features with points
  concat_points_i = torch.cat((points_i, features_i), dim=2)
  concat_points_j = torch.cat((points_j, features_j), dim=2)
  # Calculate gram matrix
  matrix = createGramMatrix(concat_points_i, concat_points_j)
  # Pass matrix into calculateNeuralSpline
  Kns = calculateNeuralSpline(matrix)
  # Return values
  return Kns

In [None]:
def f_x(alpha, new_points, original_points, grid):
  return alpha * calculateKernel(new_points, original_points, grid)

In [None]:
'''
/* __global__ void concat_kernel(float* input1, float* input2, float* out, int h, int w)  {
#     int r = blockIdx.y * blockDim.y + threadIdx.y;
#     int c = blockIdx.x*blockDim.x + threadIdx.x;
#     if(r>=h || c>= w) return;
#     out[:, :h, :, :3] = input1;
#     out[:, :, :w, 3:] = input2;
*/ }

 for (int i = 0; i < input1.size(0); i++){
      for (int j = 0; j < h; j++) {
        for (int k = 0; k < w; k++) {
          for (int z = 0; z < 6; z++) {
            out[i, j, k, z] = input1[i, j, z];
            out[i, j, k, z+3] = input2[i, k, z];
          }
        }
      }
    }
'''