In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import os
import io
import time
import traceback
import h5py
import tqdm

import pandas as pd
import numpy as np
from collections import defaultdict
from tifffile import imread, imwrite
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from math import isinf, isnan, sqrt
import hdf5plugin
import copy

import sdt.loc
import sdt
from utils import *
import random
import csv

In [4]:
def preprocess_img(im, mode='10bit'):
    if im.dtype == np.uint16:
        if mode == '8bit':
            im = np.power(im + 1, 0.5) - 1
            im = np.array(im, dtype=np.uint8)
        if mode == '10bit':
            im = np.power(im + 1, 0.71) - 1
            im = np.array(im, dtype=np.uint16)
        
    return im

In [5]:
def postprocess_img(im, mode='10bit'):
    if mode == '8bit':
        im = np.power(im + 1, 2) - 1
        im = np.array(im, dtype=np.uint16)
    if mode == '10bit':
        im = np.power(im + 1, 1/0.71) - 1
        im = np.array(im, dtype=np.uint16)

    return im

In [6]:
def test_compression(img, params={}, dataset=None, codec=None, cs=(256, 256, 256)):
    out = []
    try:
        if 'AV1' in codec:
            pre_img = preprocess_img(img)
        else:
            pre_img = copy.deepcopy(img)
        with io.BytesIO() as tmpfile:
            with h5py.File(tmpfile, 'w') as h5file:
                start = time.time()
                h5file.create_dataset('data', data=pre_img, **params, chunks=cs)
                out.append(tmpfile.getbuffer().nbytes)
                out.append(time.time() - start)
            with h5py.File(tmpfile, 'r') as h5file:
                start = time.time()
                tmp = np.array(h5file['data'])
                out.append(time.time() - start)

        if 'AV1' in codec:
            tmp = postprocess_img(tmp)
        # print(img.dtype, tmp.dtype)
        imwrite(f'STORM/tif/{codec}.tif', tmp)
    
        rmse = sqrt(mean_squared_error(img, tmp))
        psnr = peak_signal_noise_ratio(img, tmp)
        ssim = structural_similarity(img, tmp, data_range=np.amax(img) - np.amin(img))
        out.append(rmse)
        out.append(psnr)
        out.append(ssim)

    except Exception:
        print(traceback.format_exc())

    return out

In [7]:
data_path = '/data/duanb/'

In [8]:
datasets = {
    'STORM': imread(data_path + 'STORM/Aquired STORM.tif'),
}

for k, v in datasets.items():
    print(k, v.shape, v.dtype)

STORM (40000, 256, 256) uint16


In [9]:
csv_file = 'STORM_compression.csv'

In [10]:
crf_range = range(1, 12, 2)

In [11]:
# compression options
crf_settings = defaultdict(lambda: {})
for crf in crf_range:
    key = f'SVTAV1_Q{crf}'
    param = (6, 6, 256, 256, 256, 1, 409, 400, crf, 0, 0)
    toadd = {**crf_settings[key], **hdf5plugin.FFMPEG(*param)}
    crf_settings[key] = toadd

compression_methods = {
    'Uncompressed': {},
    'Blosc-Zstd': hdf5plugin.Blosc(cname='zstd', clevel=5, shuffle=hdf5plugin.Blosc.BITSHUFFLE),
    **crf_settings,
}

In [12]:
N = 1
chunk_size = 256
subsampled_images = defaultdict(lambda: [])
for dname, dset in tqdm.tqdm(datasets.items()):
    for i in range(N):
        selector = None
        repeat = True
        while repeat:
            z = np.random.randint(0, max(1, dset.shape[0]-chunk_size))
            x = np.random.randint(0, max(1, dset.shape[1]-chunk_size))
            y = np.random.randint(0, max(1, dset.shape[2]-chunk_size))
            selector = [
                slice(z, z+chunk_size),
                slice(x, x+chunk_size),
                slice(y, y+chunk_size)
            ]
            if dset.shape[0] < chunk_size:
                selector[0] = slice(None)

            tmp = np.array(dset[tuple(selector)])

            if tmp.max() < 100:
                repeat = True
            else:
                repeat = False
                subsampled_images[dname].append(tmp)

100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 10.17it/s]


In [13]:
data = pd.DataFrame({
    'Dataset': [],
    'Compression Method': [],
    'Compressed Size': [],
    'Read Time': [],
    'Write Time': [],
    'RMSE': [],
    'PSNR': [],
    'SSIM': [],
    })

In [14]:
benchmark_start = time.time()
last_start = benchmark_start

pbar = tqdm.tqdm(total=N*len(subsampled_images) * len(compression_methods), ncols=80, ascii=True)
for dname, dset in subsampled_images.items():
    for mname, mparam in compression_methods.items():
        for img_data in dset:
            # print(dname, mname)
            if dname == 'EM' and '10bit' in mname: # no 10bit compression for already 8bit image
                pass
            else:
                img = copy.deepcopy(img_data)
                s, wt, rt, rmse, psnr, ssim = test_compression(img, params=mparam, dataset=dname, codec=mname)
                data = pd.concat([data, pd.DataFrame([{
                    'Dataset': dname,
                    'Compression Method': mname,
                    'Compressed Size': s,
                    'Read Time': rt,
                    'Write Time': wt,
                    'RMSE': rmse,
                    'PSNR': psnr,
                    'SSIM': ssim,
                }])], ignore_index=True)

            pbar.update(1)

print("Benchmark ended at", time.time() - benchmark_start)

  return 10 * np.log10((data_range ** 2) / err)
100%|#############################################| 8/8 [03:41<00:00, 29.83s/it]

Benchmark ended at 221.64569091796875


In [15]:
means = data.groupby(['Dataset', 'Compression Method']).mean()
data['Compression Ratio'] = data.apply(
    lambda x: means.loc[x['Dataset']].loc['Uncompressed']['Compressed Size'] / x['Compressed Size'],
    axis=1
)

data['Effective Compression Rate (MB/s)'] = data.apply(
    lambda x: means.loc[x['Dataset']].loc['Uncompressed']['Compressed Size'] / (10**6) / x['Write Time'],
    axis=1
)

data['Effective Decompression Rate (MB/s)'] = data.apply(
    lambda x: means.loc[x['Dataset']].loc['Uncompressed']['Compressed Size'] / (10**6) / x['Read Time'],
    axis=1
)

data.sort_values(by=['Compression Ratio'], ascending=False).to_csv(csv_file)

In [16]:
data.sort_values(by=['Compressed Size'], ascending=True).head(n=50)

Unnamed: 0,Dataset,Compression Method,Compressed Size,Read Time,Write Time,RMSE,PSNR,SSIM,Compression Ratio,Effective Compression Rate (MB/s),Effective Decompression Rate (MB/s)
7,STORM,SVTAV1_Q11,477820.0,0.267239,23.143988,68.766113,59.581977,0.981787,70.233494,1.450008,125.576567
6,STORM,SVTAV1_Q9,759127.0,0.319986,24.247411,64.184908,60.180808,0.983724,44.207317,1.384023,104.876251
5,STORM,SVTAV1_Q7,1284322.0,0.391296,22.796158,57.56281,61.126626,0.986728,26.129715,1.472133,85.763604
4,STORM,SVTAV1_Q5,2152993.0,0.466867,20.517525,47.439134,62.806731,0.990764,15.587124,1.635625,71.881258
3,STORM,SVTAV1_Q3,3418563.0,0.503409,20.210391,35.816711,65.247752,0.994587,9.816688,1.660481,66.663437
2,STORM,SVTAV1_Q1,5363894.0,0.579142,22.356339,21.472745,69.691715,0.998089,6.256456,1.501094,57.946024
1,STORM,Blosc-Zstd,20552893.0,0.091346,0.280428,0.0,inf,1.0,1.63281,119.670663,367.384821
0,STORM,Uncompressed,33558968.0,0.066021,0.027374,0.0,inf,1.0,1.0,1225.941852,508.307743
