In [1]:
import numpy as np

import wgpu
from wgpu.utils.compute import compute_with_buffers
%load_ext line_profiler

In [2]:
data_file = np.load("dct_data_jul11.npz")
dct = data_file['dct']
# mask = dct != 0
# dct -= np.amin(dct)
# dct *= mask

In [3]:
dct[:8, :8]

array([[244,  16,   9,   7,  -1,   0,   0,   0],
       [-64, -16,  20,   2,   2,   0,   0,   0],
       [-39,   1,   5,   2,   0,   0,   0,   0],
       [ -7,   0,   2,   0,   0,   0,   0,   0],
       [ -3,   0,   1,   0,   0,   0,   0,   0],
       [ -1,   0,   0,   0,   0,   0,   0,   0],
       [ -1,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0]], dtype=int16)

In [4]:
def read_rle_shader_preview():
    return open("./rle_preview.wgsl").read()

def read_rle_shader():
    return open("./rle.wgsl").read()

def read_cumulative_sum_shader():
    return open("./cumulative_sum.wgsl").read()

rle_preview = read_rle_shader_preview()
rle_full = read_rle_shader()

cumsum_shader = read_cumulative_sum_shader()

In [5]:

# def generate_basis():
zigzag = np.array(
    [
        [0, 0],
        [0, 1],
        [1, 0],
        [2, 0],
        [1, 1],
        [0, 2],
        [0, 3],
        [1, 2],
        [2, 1],
        [3, 0],
        [4, 0],
        [3, 1],
        [2, 2],
        [1, 3],
        [0, 4],
        [0, 5],
        [1, 4],
        [2, 3],
        [3, 2],
        [4, 1],
        [5, 0],
        [6, 0],
        [5, 1],
        [4, 2],
        [3, 3],
        [2, 4],
        [1, 5],
        [0, 6],
        [0, 7],
        [1, 6],
        [2, 5],
        [3, 4],
        [4, 3],
        [5, 2],
        [6, 1],
        [7, 0],
        [7, 1],
        [6, 2],
        [5, 3],
        [4, 4],
        [3, 5],
        [2, 6],
        [1, 7],
        [2, 7],
        [3, 6],
        [4, 5],
        [5, 4],
        [6, 3],
        [7, 2],
        [7, 3],
        [6, 4],
        [5, 5],
        [4, 6],
        [3, 7],
        [4, 7],
        [5, 6],
        [6, 5],
        [7, 4],
        [7, 5],
        [6, 6],
        [5, 7],
        [6, 7],
        [7, 6],
        [7, 7],
    ]
)
zig_dim1 = zigzag[:, 0].astype(np.int32)
zig_dim2 = zigzag[:, 1].astype(np.int32)

In [6]:
zigzag.shape

(64, 2)

In [7]:
dtype_mapping = {
    np.int8: "b",
    np.int16: "h",
    np.int32: "i",
    np.float16: "e",
    np.float32: "f",
}

In [8]:
dispatch_counts = np.array([dct.shape[0] // 8, dct.shape[1] // 8]).astype(np.int16)

In [9]:
dct_flat = dct.flatten().astype(np.int32)

# Define the bindings to pipeline everything

In [33]:
## Make zig_dim1 and zig_dim1 buffers

# Define a device
device = wgpu.utils.get_default_device()

#Define buffers
zig_dim1_buffer = device.create_buffer_with_data(data=zig_dim1, usage=wgpu.BufferUsage.STORAGE)
zig_dim2_buffer = device.create_buffer_with_data(data=zig_dim2, usage=wgpu.BufferUsage.STORAGE)
dct_buffer = device.create_buffer_with_data(data=dct_flat, usage=wgpu.BufferUsage.STORAGE)

out_buffer = device.create_buffer(
    size=dct_flat.nbytes, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC
)


## Define binding layouts for each step:
# Setup layout and bindings
binding_layouts = [
    {
        "binding": 0,
        "visibility": wgpu.ShaderStage.COMPUTE,
        "buffer": {
            "type": wgpu.BufferBindingType.read_only_storage,
        },
    },
    {
        "binding": 1,
        "visibility": wgpu.ShaderStage.COMPUTE,
        "buffer": {
            "type": wgpu.BufferBindingType.read_only_storage,
        },
    },
    {
        "binding": 2,
        "visibility": wgpu.ShaderStage.COMPUTE,
        "buffer": {
            "type": wgpu.BufferBindingType.read_only_storage,
        },
    },
    {
        "binding": 3,
        "visibility": wgpu.ShaderStage.COMPUTE,
        "buffer": {
            "type": wgpu.BufferBindingType.storage,
        },
    },    
]
bindings = [
    {
        "binding": 0,
        "resource": {"buffer": dct_buffer, "offset": 0, "size": dct_buffer.size},
    },
    {
        "binding": 1,
        "resource": {"buffer": zig_dim1_buffer, "offset": 0, "size": zig_dim1_buffer.size},
    },
    {
        "binding": 2,
        "resource": {"buffer": zig_dim2_buffer, "offset": 0, "size": zig_dim2_buffer.size},
    },
    {
        "binding": 3,
        "resource": {"buffer": out_buffer, "offset": 0, "size": out_buffer.size},
    },
]

# Put everything together
bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)

cshader = device.create_shader_module(code=open("rle_preview.wgsl").read())
# Create and run the pipeline
compute_pipeline = device.create_compute_pipeline(
    layout=pipeline_layout,
    compute={"module": cshader, "entry_point": "main"},
)


command_encoder = device.create_command_encoder()
compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group)
compute_pass.dispatch_workgroups(64, 64, 1)  # x y z
compute_pass.end()
device.queue.submit([command_encoder.finish()])

out = device.queue.read_buffer(out_buffer).cast("i")
out = np.array(out)


0

In [11]:
import time
start_time = time.time()
def my_func():
    out_shape = ((512*512) / (64*64),)
    out = compute_with_buffers(
        input_arrays={0: dct_flat,
                     1: zig_dim1,
                     2: zig_dim2},
        output_arrays={3: (*out_shape, dtype_mapping[np.int32])},
        shader=rle_preview,
        n=(64,64, 1)
    )
    
    # print("done!")
    rle_lengths = np.frombuffer(out[3], dtype=np.int32)
    out_shape = (rle_lengths.shape[0],)
    out = compute_with_buffers(
        input_arrays={0: rle_lengths},
        output_arrays={1: (*out_shape, dtype_mapping[np.int32]),
                       2: (1, dtype_mapping[np.int32])},
        shader=cumsum_shader,
        n=(1,1,1)
    )

    start_pts = np.frombuffer(out[1], dtype=np.int32)
    total_length = np.frombuffer(out[2], dtype=np.int32)[0]
    
    out = compute_with_buffers(
        input_arrays={0: dct_flat,
                      1: zig_dim1,
                      2:zig_dim2,
                      3: start_pts,
                      4: rle_lengths},
        
        output_arrays={5: (total_length, dtype_mapping[np.int32])},
        shader=rle_full,
        n=(64,64,1)
    )

    return np.frombuffer(out[5], dtype=np.int32)
x = my_func()

ValueError: Invalid shape for output array 5: (0,)

In [25]:
def my_func():
    out_shape = ((512*512) / (8*8),)
    out = compute_with_buffers(
        input_arrays={0: dct_flat,
                     1: zig_dim1,
                     2: zig_dim2},
        output_arrays={3: (*out_shape, dtype_mapping[np.int32])},
        shader=rle_preview,
        n=(64,64, 1)
    )

    rle_lengths = np.frombuffer(out[3], dtype=np.int32)
    out_shape = (rle_lengths.shape[0],)
    out = compute_with_buffers(
        input_arrays={0: rle_lengths},
        output_arrays={1: (*out_shape, dtype_mapping[np.int32]),
                       2: (1, dtype_mapping[np.int32])},
        shader=cumsum_shader,
        n=(1,1,1)
    )
    
    start_pts = np.frombuffer(out[1], dtype=np.int32)
    total_length = np.frombuffer(out[2], dtype=np.int32)
    
    
    out = compute_with_buffers(
        input_arrays={0: dct_flat,
                      1: zig_dim1,
                      2:zig_dim2,
                      3: start_pts,
                      4: rle_lengths},
        
        output_arrays={5: (total_length, dtype_mapping[np.int32])},
        shader=rle_full,
        n=(64,64,1)
    )
    
    return np.frombuffer(out[5], dtype=np.int32)

In [28]:
%lprun -f my_func my_func()

Timer unit: 1e-09 s

Total time: 0.0343335 s
File: /tmp/ipykernel_3287654/1886378806.py
Function: my_func at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def my_func():
     2         1       1807.0   1807.0      0.0      out_shape = ((512*512) / (8*8),)
     3         2   12942860.0    6e+06     37.7      out = compute_with_buffers(
     4         2       2845.0   1422.5      0.0          input_arrays={0: dct_flat,
     5         1        523.0    523.0      0.0                       1: zig_dim1,
     6         1        434.0    434.0      0.0                       2: zig_dim2},
     7         1       4583.0   4583.0      0.0          output_arrays={3: (*out_shape, dtype_mapping[np.int32])},
     8         1        677.0    677.0      0.0          shader=rle_preview,
     9         1        460.0    460.0      0.0          n=(64,64, 1)
    10                                               )
    11               

In [22]:
total_length

array([119246], dtype=int32)

In [None]:
zig_dim1.shape

# Now let's try the wgpu data

In [None]:
device = wgpu.utils.get_default_device()
# Create buffer objects, input buffer is mapped.

buffer_input = device.create_buffer_with_data(data=dct_flat, usage=wgpu.BufferUsage.STORAGE)
buffer_zig1 = device.create_buffer_with_data(data=zig_dim1, usage=wgpu.BufferUsage.STORAGE)
buffer_zig2 = device.create_buffer_with_data(data=zig_dim2, usage=wgpu.BufferUsage.STORAGE)
