In [1]:
import numpy as np

import wgpu
from wgpu.utils.compute import compute_with_buffers

In [2]:
# Load the DCT coefficient image. The context manager closes the .npz archive,
# which np.load otherwise keeps open for lazy member access.
with np.load("dct_data_jul11.npz") as data_file:
    dct_raw = data_file['dct']

# Positions that hold a nonzero coefficient; every other position is re-zeroed
# after the shift below.
mask = dct_raw != 0

# Shift so the smallest coefficient becomes 0. Work in a wide integer type:
# with int16 storage (see the preview output) and a negative minimum, the
# shifted values can exceed the int16 range and would silently wrap in place.
# NOTE(review): assumes integer coefficients, as in the saved file.
wide = dct_raw.astype(np.int64)
shifted = (wide - wide.min()) * mask
assert shifted.max() <= np.iinfo(dct_raw.dtype).max, (
    "shifted DCT coefficients do not fit the stored dtype"
)
# Cast back to the stored dtype so downstream cells see the same dtype as before.
dct = shifted.astype(dct_raw.dtype)

In [3]:
# Show the top-left 8x8 block of `dct`.
dct[0:8, 0:8]

array([[451, 223, 216, 214, 206,   0,   0,   0],
       [143, 191, 227, 209, 209,   0,   0,   0],
       [168, 208, 212, 209,   0,   0,   0,   0],
       [200,   0, 209,   0,   0,   0,   0,   0],
       [204,   0, 208,   0,   0,   0,   0,   0],
       [206,   0,   0,   0,   0,   0,   0,   0],
       [206,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0]], dtype=int16)

In [4]:
def read_shader(path="./rle.wgsl"):
    """Return the WGSL source text of the compute shader.

    Parameters
    ----------
    path : str or path-like, optional
        Location of the shader file. Defaults to ``./rle.wgsl`` relative to the
        current working directory, matching the original hard-coded path.

    Returns
    -------
    str
        The full contents of the file.
    """
    # `with` closes the handle deterministically instead of leaving it to the
    # garbage collector; explicit UTF-8 avoids locale-dependent decoding.
    with open(path, encoding="utf-8") as shader_file:
        return shader_file.read()

In [5]:

# JPEG zigzag scan order for a BLOCK_SIZE x BLOCK_SIZE block, as (row, col) pairs.
# Positions on the same anti-diagonal d = row + col are visited consecutively:
# on odd diagonals the row index increases, on even diagonals it decreases.
# For BLOCK_SIZE = 8 this yields the standard 64-entry order
# (0,0), (0,1), (1,0), (2,0), (1,1), (0,2), ..., (6,7), (7,6), (7,7).
BLOCK_SIZE = 8


def _zigzag_key(rc):
    """Sort key placing (row, col) in zigzag order: by diagonal, then along it."""
    row, col = rc
    diagonal = row + col
    return (diagonal, row if diagonal % 2 else -row)


zigzag = np.array(
    sorted(
        ((row, col) for row in range(BLOCK_SIZE) for col in range(BLOCK_SIZE)),
        key=_zigzag_key,
    )
)
# Separate row / column index vectors, as int32 for the GPU storage buffers.
zig_dim1 = zigzag[:, 0].astype(np.int32)
zig_dim2 = zigzag[:, 1].astype(np.int32)

In [6]:
# Sanity check: one (row, col) pair per coefficient of an 8x8 block.
np.shape(zigzag)

(64, 2)

In [7]:
# numpy scalar type -> buffer format character (Python struct / memoryview
# codes) used when describing compute_with_buffers output arrays.
# Unsigned codes are included because WGSL storage buffers commonly use u32.
# Keys are numpy scalar types (e.g. np.int32), not np.dtype instances.
dtype_mapping = {
    np.int8: "b",
    np.uint8: "B",
    np.int16: "h",
    np.uint16: "H",
    np.int32: "i",
    np.uint32: "I",
    np.float16: "e",
    np.float32: "f",
}

In [8]:
# Number of 8x8 blocks along each axis of `dct`. Floor division would silently
# drop a trailing partial block, so require the image to tile exactly.
assert dct.shape[0] % 8 == 0 and dct.shape[1] % 8 == 0, (
    f"dct shape {dct.shape} is not a multiple of 8 in both dimensions"
)
dispatch_counts = np.array([dct.shape[0] // 8, dct.shape[1] // 8], dtype=np.int16)

In [9]:
# Row-major 1-D copy of the coefficients as int32 for the GPU input buffer.
dct_flat = np.ravel(dct).astype(np.int32)

In [10]:
# Output buffer: one int32 slot per input coefficient. For a 512x512 input this
# equals the previously hard-coded 512*512.
# NOTE(review): confirm rle.wgsl never writes more than dct_flat.size values.
out_shape = dct_flat.shape

# Dispatch size follows the number of 8x8 blocks per axis (dispatch_counts)
# instead of a hard-coded (64, 64, 1); identical for a 512x512 input.
# int() converts the numpy int16 counts to plain Python ints.
dispatch = (int(dispatch_counts[0]), int(dispatch_counts[1]), 1)

out = compute_with_buffers(
    input_arrays={0: dct_flat, 1: zig_dim1, 2: zig_dim2},
    output_arrays={3: (*out_shape, dtype_mapping[np.int32])},
    shader=read_shader(),
    n=dispatch,
)

shader_out = np.frombuffer(out[3], dtype=np.int32)
shader_out

No windowing system present. Using surfaceless platform
No config found!
No config found!
Max vertex attribute stride unknown. Assuming it is 2048


array([  0, 451,   0, ...,   0,   0,   0], dtype=int32)

In [11]:
# Re-display the top-left 8x8 block of `dct` next to the shader output above.
dct[0:8, 0:8]

array([[451, 223, 216, 214, 206,   0,   0,   0],
       [143, 191, 227, 209, 209,   0,   0,   0],
       [168, 208, 212, 209,   0,   0,   0,   0],
       [200,   0, 209,   0,   0,   0,   0,   0],
       [204,   0, 208,   0,   0,   0,   0,   0],
       [206,   0,   0,   0,   0,   0,   0,   0],
       [206,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0]], dtype=int16)