# GPU rANS: CPU/GPU entropy-coding consistency check and benchmark

In [1]:
import torch
from compressai.entropy_models import EntropyBottleneck

# Fix the RNG so the random test latents (and therefore the bitstreams) are reproducible.
torch.manual_seed(0)
eb = EntropyBottleneck(channels=192).cuda()
# A freshly constructed bottleneck has no quantized CDF tables yet; build them now.
eb.update(force=True)

x = torch.randn(32, 192, 8, 8, device="cuda")
# EntropyBottleneck.decompress has the signature (strings, size), where size is the
# spatial extent of the encoded tensor; derive it from x instead of hard-coding it.
spatial_size = x.shape[-2:]

# CPU (reference) roundtrip: `use_gpu_ans = False` selects the CPU coding path.
eb.use_gpu_ans = False
s_cpu = eb.compress(x)
xhat_cpu = eb.decompress(s_cpu, spatial_size)

# GPU roundtrip: `use_gpu_ans = True` selects the GPU coding path.
eb.use_gpu_ans = True
s_gpu = eb.compress(x)
xhat_gpu = eb.decompress(s_gpu, spatial_size)

# Compare: the GPU coder is expected to be bit-exact with the CPU reference, both in
# the emitted bitstreams and in the decoded tensors (same checks as the benchmark cell).
assert s_cpu == s_gpu, "CPU and GPU bitstreams differ"
max_diff = (xhat_cpu - xhat_gpu).abs().max()
assert max_diff.item() == 0.0, f"decoded tensors differ (max |diff| = {max_diff.item()})"
print(max_diff)


tensor(0., device='cuda:0')


  merged_u8 = torch.frombuffer(memoryview(merged), dtype=torch.uint8).to(device=cdf.device, non_blocking=False)


In [1]:
import time
import torch
import numpy as np
from compressai.entropy_models import EntropyBottleneck
from compressai.zoo import bmshj2018_factorized

def time_cpu(fn, iters=10, warmup=2):
    """Return the mean wall-clock seconds per call of ``fn``.

    ``fn`` is first called ``warmup`` times without timing, then ``iters``
    times under ``time.perf_counter``; the total is divided by ``iters``.
    """
    def _repeat(count):
        # Call fn `count` times, discarding its results.
        for _ in range(count):
            fn()

    _repeat(warmup)
    begin = time.perf_counter()
    _repeat(iters)
    return (time.perf_counter() - begin) / iters

def time_gpu(fn, iters=50, warmup=10):
    """Return the mean seconds per call of ``fn`` measured with CUDA events.

    ``fn`` is run ``warmup`` times and the device is synchronized, so the timed
    region starts with no pending device work. ``iters`` calls are then
    bracketed by a pair of timing events, the device is synchronized again, and
    the elapsed time (which CUDA reports in milliseconds) is converted to seconds.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    tic, toc = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    tic.record()
    for _ in range(iters):
        fn()
    toc.record()
    # elapsed_time is only valid once `toc` has actually completed on the device.
    torch.cuda.synchronize()

    per_call_ms = tic.elapsed_time(toc) / iters
    return per_call_ms / 1000.0

def _load_model(project_dir, quality, device):
    """Build bmshj2018-factorized at ``quality`` and load its local checkpoint.

    The checkpoint is read from
    ``{project_dir}/data/model/bmshj2018-factorized-prior-{quality}.pth``.
    Returns the network switched to eval mode and moved to ``device``.
    """
    net = bmshj2018_factorized(quality=quality, pretrained=False)
    state_dict = torch.load(
        f"{project_dir}/data/model/bmshj2018-factorized-prior-{quality}.pth",
        map_location=device,
    )
    net.load_state_dict(state_dict)
    return net.eval().to(device)


def _encode_latents(net, data_path, device):
    """Load the ``.npy`` array at ``data_path`` and return ``net.g_a`` applied to it.

    The array is cast to float32 and moved to ``device``; no gradients are tracked.
    NOTE(review): the array is assumed to already be an (N, C, H, W) batch that
    ``g_a`` accepts -- no reshaping or normalization is done here.
    """
    arr = np.load(data_path)
    x0 = torch.from_numpy(arr).float().to(device)
    with torch.no_grad():
        return net.g_a(x0)


def _verify_compress(eb, y):
    """Encode ``y`` on the CPU and the GPU ANS paths and require identical bytes.

    Returns ``(s_cpu, s_gpu)``. On the first differing string it prints the index,
    both lengths and the first 64 bytes of each, then raises ``SystemExit(1)``.
    """
    eb.use_gpu_ans = False
    s_cpu = eb.compress(y)

    eb.use_gpu_ans = True
    s_gpu = eb.compress(y)

    assert len(s_cpu) == len(s_gpu), (len(s_cpu), len(s_gpu))
    for i, (a, b) in enumerate(zip(s_cpu, s_gpu)):
        if a != b:
            print("Mismatch at", i, len(a), len(b))
            print(a[:64])
            print(b[:64])
            raise SystemExit(1)
    print("Pass verify: CPU compressed == GPU compressed (bytes match)")
    return s_cpu, s_gpu


def _verify_decompress(eb, s_cpu, s_gpu, size):
    """Cross-check decoding: each decoder on each bitstream (4 combinations).

    Every result is compared against CPU-decoder-on-CPU-bitstream; any non-zero
    difference prints a message and raises ``SystemExit(2)``.
    """
    def decode(use_gpu, strings):
        eb.use_gpu_ans = use_gpu
        # decompress returns dequantized float values.
        return eb.decompress(strings, size)

    yhat_cpu_from_cpu = decode(False, s_cpu)   # 1) CPU decode of CPU bitstream (reference)
    yhat_gpu_from_gpu = decode(True, s_gpu)    # 2) GPU decode of GPU bitstream
    yhat_cpu_from_gpu = decode(False, s_gpu)   # 3) CPU decode of GPU bitstream
    yhat_gpu_from_cpu = decode(True, s_cpu)    # 4) GPU decode of CPU bitstream

    def max_abs(a, b):
        return (a - b).abs().max().item()

    m1 = max_abs(yhat_cpu_from_cpu, yhat_gpu_from_gpu)
    m2 = max_abs(yhat_cpu_from_cpu, yhat_cpu_from_gpu)
    m3 = max_abs(yhat_cpu_from_cpu, yhat_gpu_from_cpu)

    print("\n=== Decompress verify ===")
    print(f"max|CPU(cpu)->yhat  - GPU(gpu)->yhat| = {m1}")
    print(f"max|CPU(cpu)->yhat  - CPU(gpu)->yhat| = {m2}")
    print(f"max|CPU(cpu)->yhat  - GPU(cpu)->yhat| = {m3}")

    # Exact match is required: both paths decode the same integer symbols, so
    # any non-zero difference means the two decoders disagree.
    if m1 != 0.0 or m2 != 0.0 or m3 != 0.0:
        print("Decompress mismatch! (expected exact match)")
        raise SystemExit(2)

    print("Pass verify: CPU/GPU decompress are mutually compatible")


def _benchmark(eb, y, s_cpu, s_gpu, size):
    """Time compress and decompress on both paths.

    Returns mean seconds per call as
    ``(cpu_compress, gpu_compress, cpu_decompress, gpu_decompress)``.
    Each decoder is timed on the bitstream produced by its own encoder.
    """
    def cpu_compress():
        eb.use_gpu_ans = False
        return eb.compress(y)

    def gpu_compress():
        eb.use_gpu_ans = True
        return eb.compress(y)

    def cpu_decompress_cpu_stream():
        eb.use_gpu_ans = False
        return eb.decompress(s_cpu, size)

    def gpu_decompress_gpu_stream():
        eb.use_gpu_ans = True
        return eb.decompress(s_gpu, size)

    # The CPU path is slow (tolist + Python loop), so it gets fewer iterations.
    cpu_comp_sec = time_cpu(cpu_compress, iters=3, warmup=1)
    gpu_comp_sec = time_gpu(gpu_compress, iters=30, warmup=5)
    cpu_decomp_sec = time_cpu(cpu_decompress_cpu_stream, iters=3, warmup=1)
    gpu_decomp_sec = time_gpu(gpu_decompress_gpu_stream, iters=30, warmup=5)
    return cpu_comp_sec, gpu_comp_sec, cpu_decomp_sec, gpu_decomp_sec


def _print_speed(op, cpu_sec, gpu_sec, gib_raw):
    """Print mean latency, CPU/GPU speedup and throughput (GiB of latents per second) for ``op``."""
    print(f"\n=== Speed: {op} ===")
    print(f"CPU {op} avg: {cpu_sec*1000:.3f} ms / call")
    print(f"GPU {op} avg: {gpu_sec*1000:.3f} ms / call")
    print(f"Speedup: {cpu_sec/gpu_sec:.2f}x")
    print(f"Throughput CPU: {gib_raw/cpu_sec:.6f} GiB/s")
    print(f"Throughput GPU: {gib_raw/gpu_sec:.6f} GiB/s")


def _report(y, s_cpu, timings):
    """Print latent/compressed sizes, compression ratio and the speed tables."""
    cpu_comp_sec, gpu_comp_sec, cpu_decomp_sec, gpu_decomp_sec = timings

    total_bytes = sum(len(s) for s in s_cpu)
    # Use the tensor's real element size instead of assuming 4-byte floats.
    raw_bytes = y.numel() * y.element_size()
    gib_raw = raw_bytes / (1024**3)

    print("\n=== Size ===")
    print("ga output:", y.shape, f"{raw_bytes} bytes", y.dtype)
    print(f"Total compressed bytes: {total_bytes}")
    print(f"CR: {raw_bytes/total_bytes:.2f}")

    _print_speed("compress", cpu_comp_sec, gpu_comp_sec, gib_raw)
    _print_speed("decompress", cpu_decomp_sec, gpu_decomp_sec, gib_raw)


def run(
    project_dir="/hwj",
    quality=1,
    device="cuda:0",
    data_path="/hwj/project/aiz-accelerate/data/nyx-dark_matter_density.npy",
):
    """Verify CPU/GPU ANS bit-exactness on real latents, then benchmark both paths.

    Args:
        project_dir: root containing ``data/model/bmshj2018-factorized-prior-{quality}.pth``.
        quality: bmshj2018-factorized quality level whose checkpoint is loaded.
        device: torch device for the model and latents.
        data_path: ``.npy`` file fed through the analysis transform ``g_a``.

    Raises:
        SystemExit(1): CPU and GPU bitstreams differ.
        SystemExit(2): any CPU/GPU decode combination differs from the CPU reference.
    """
    net = _load_model(project_dir, quality, device)
    y = _encode_latents(net, data_path, device)

    eb = net.entropy_bottleneck
    # Per compressai's EntropyBottleneck.update, a non-forced update returns early
    # when the checkpoint already carried the CDF tables, and builds them otherwise
    # (compress/decompress cannot run without them).
    eb.update()

    torch.manual_seed(0)

    # A) compress correctness: CPU bytes == GPU bytes
    s_cpu, s_gpu = _verify_compress(eb, y)

    # B) decompress correctness (4-way cross check).
    # EntropyBottleneck.decompress(strings, size) takes the spatial size.
    size = y.size()[-2:]
    _verify_decompress(eb, s_cpu, s_gpu, size)

    # C) speed tests: compress + decompress
    timings = _benchmark(eb, y, s_cpu, s_gpu, size)
    _report(y, s_cpu, timings)


if __name__ == "__main__":
    run()


Pass verify: CPU compressed == GPU compressed (bytes match)

=== Decompress verify ===
max|CPU(cpu)->yhat  - GPU(gpu)->yhat| = 0.0
max|CPU(cpu)->yhat  - CPU(gpu)->yhat| = 0.0
max|CPU(cpu)->yhat  - GPU(cpu)->yhat| = 0.0
Pass verify: CPU/GPU decompress are mutually compatible

=== Size ===
ga output: torch.Size([512, 192, 8, 8]) 25165824 bytes torch.float32
Total compressed bytes: 261808
CR: 96.12

=== Speed: compress ===
CPU compress avg: 602.875 ms / call
GPU compress avg: 7.667 ms / call
Speedup: 78.63x
Throughput CPU: 0.038876 GiB/s
Throughput GPU: 3.056780 GiB/s

=== Speed: decompress ===
CPU decompress avg: 776.710 ms / call
GPU decompress avg: 7.397 ms / call
Speedup: 105.00x
Throughput CPU: 0.030175 GiB/s
Throughput GPU: 3.168412 GiB/s
