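# Grayscale Image

Convert an RGB image to grayscale with a custom CUDA kernel (`rgb_to_grayscale`) compiled at runtime via `torch.utils.cpp_extension.load_inline`, then save the result as a PNG.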

In [43]:
from pathlib import Path
from typing import Optional, Sequence, Union

import cv2
import torch
from torch.utils.cpp_extension import load_inline

In [6]:
# constants
SIZE = 1000  # NOTE(review): not referenced by any visible cell — confirm it is still needed
KERNEL_DIR = Path("../kernels")  # directory holding the .cu kernel sources (relative to the CWD)
DATA_DIR = Path("../data")  # input image and generated outputs (relative to the CWD)
IMAGE_PATH = DATA_DIR / "puppy.png"  # image converted to grayscale below

In [10]:
# utils

def compile_ext(
    cuda_source: Union[str, Path],
    cpp_headers: str,
    ext_name: str,
    func: list,
    extra_cuda_cflags: Optional[Sequence[str]] = None,
):
    """Compile a CUDA source file into a Python extension via ``load_inline``.

    Args:
        cuda_source: Path to the ``.cu`` file whose text is compiled. Accepts a
            ``str`` or a ``pathlib.Path`` (the notebook passes a ``Path``).
        cpp_headers: C++ source containing the declarations of the functions
            to export (passed as ``cpp_sources``).
        ext_name: Name of the generated extension module.
        func: Names of the functions to expose to Python.
        extra_cuda_cflags: Extra flags forwarded to the CUDA compiler.
            Defaults to ``["-O2"]`` when ``None``.

    Returns:
        The extension module returned by ``load_inline``.
    """
    # Keep the path argument intact and read the kernel text into its own name;
    # decode explicitly so the result does not depend on the platform locale.
    cuda_code = Path(cuda_source).read_text(encoding="utf-8")
    # Build a fresh list per call so no mutable default is shared between calls.
    cuda_flags = ["-O2"] if extra_cuda_cflags is None else list(extra_cuda_cflags)

    return load_inline(
        name=ext_name,
        cpp_sources=cpp_headers,
        cuda_sources=cuda_code,
        functions=func,
        with_cuda=True,
        extra_cuda_cflags=cuda_flags,
    )


def tensor_details(tensor: torch.Tensor, name: str, head: int = 10):
    """Print a short report on ``tensor``.

    The report is a separator line, the tensor's name, shape, dtype and device,
    followed by ``tensor[:head]`` (the first ``head`` slices along dim 0).
    """
    summary = (
        "*" * 50,
        f"Tensor {name}",
        f"\t Shape: {tensor.shape}",
        f"\t Dtype: {tensor.dtype}",
        f"\t Device: {tensor.device}",
    )
    print(*summary, sep="\n")
    print(f"Sample:\n {tensor[:head]}\n")

In [11]:
# Read images
from torchvision import io

In [12]:
image = io.read_image(str(IMAGE_PATH))

In [14]:
tensor_details(image, name="image", head=1)

**************************************************
Tensor image
	 Shape: torch.Size([3, 1536, 2048])
	 Dtype: torch.uint8
	 Device: cpu
Sample:
 tensor([[[91, 91, 91,  ..., 92, 94, 95],
         [91, 91, 90,  ..., 92, 94, 95],
         [90, 90, 90,  ..., 93, 95, 95],
         ...,
         [82, 73, 46,  ..., 26, 27, 26],
         [88, 83, 55,  ..., 23, 23, 20],
         [65, 68, 48,  ..., 18, 20, 21]]], dtype=torch.uint8)



In [30]:
# Kernel file to compile, plus the C++ declaration that load_inline exposes to Python.
cuda_source = KERNEL_DIR.joinpath("rgb_to_grayscale.cu")
cpp_source = "torch::Tensor rgb_to_grayscale(torch::Tensor input);"

In [31]:
# Compile the extension; rgb_to_grayscale becomes callable as ext.rgb_to_grayscale.
ext = compile_ext(
    cuda_source,
    cpp_source,
    ext_name="rgb_to_grayscale",
    func=["rgb_to_grayscale"],
)

In [32]:
# Dense memory layout first (the kernel presumably indexes raw memory — confirm),
# then move the tensor to the current CUDA device.
image = image.contiguous()
image = image.cuda()

In [33]:
# Run the custom kernel, then sanity-check its result: the channel dimension is
# collapsed while dtype and device are preserved (matches the output shown below).
output = ext.rgb_to_grayscale(image)
assert output.shape == image.shape[1:], f"unexpected output shape {tuple(output.shape)}"
assert output.dtype == image.dtype, f"unexpected output dtype {output.dtype}"
assert output.device == image.device, f"unexpected output device {output.device}"

In [44]:
tensor_details(output, name="Output", head=1)

**************************************************
Tensor Output
	 Shape: torch.Size([1536, 2048])
	 Dtype: torch.uint8
	 Device: cuda:0
Sample:
 tensor([[86, 86, 86,  ..., 94, 96, 97]], device='cuda:0', dtype=torch.uint8)



In [42]:
# Save the grayscale result next to the source image (puppy.png -> puppy_gray.png);
# deriving the name from IMAGE_PATH avoids hand-typed filename mistakes.
# cv2.imwrite reports failure only through its boolean return value, so check it.
gray_path = IMAGE_PATH.with_name(f"{IMAGE_PATH.stem}_gray.png")
saved = cv2.imwrite(str(gray_path), output.cpu().numpy())
assert saved, f"cv2.imwrite failed for {gray_path}"
saved

True