# 0 prepare data and libraries

In [None]:
!pip install sunpy zeep drms hvpy

In [None]:
!pip install pytorch_msssim compressai

In [None]:
import hvpy
import matplotlib.pyplot as plt
from sunpy.time import parse_time
from sunpy.util.config import get_and_create_download_dir
from matplotlib.image import imread
import math, io, os, torch
from torchvision import transforms
import numpy as np
from PIL import Image
from pytorch_msssim import ms_ssim
from compressai.zoo import bmshj2018_factorized
from ipywidgets import interact, widgets


In [None]:
# Download sample image from STEREO
cor2_file = hvpy.save_file(hvpy.getJP2Image(parse_time('2014/05/15 07:54').datetime,
                                            hvpy.DataSource.COR2_A.value),
                           get_and_create_download_dir() + "/COR2.jp2")
print(cor2_file)
!cp /root/sunpy/data/COR2.jp2 COR2.jp2

In [None]:
!ls -luah
# The downloaded example image "COR2.jp2" is about 265K. Note that it is a JPEG 2000 file - the real raw data might be different.

In [None]:
# Load the sample image and preview it in grayscale.
img = imread("COR2.jp2")
print(img.shape, img.dtype)

# Explicit figure/axes with a title so the plot stands alone; plt.show() keeps
# the AxesImage repr out of the cell output.
fig, ax = plt.subplots()
ax.imshow(img, cmap="gray")
ax.set_title("COR2.jp2 (original)")
plt.show()

# 1 on-board (compress)

In [None]:
# https://github.com/InterDigitalInc/CompressAI/blob/master/examples/CompressAI%20Inference%20Demo.ipynb
# model -> https://github.com/InterDigitalInc/CompressAI/blob/b10cc7c1c51a0af26ea5deae474acfd5afdc1454/compressai/models/google.py

# Pinned to CPU; the commented line below selects a GPU when one is available.
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

# Pretrained bmshj2018_factorized model (quality=2), switched to eval mode.
net = bmshj2018_factorized(quality=2, pretrained=True).eval().to(device)
n_params = sum(p.numel() for p in net.parameters())
print(f'Parameters: {n_params}')

gray = imread("COR2.jp2")
print("original data range (min,mean,max):", gray.min(), gray.mean(), gray.max()) # 0-255

# Repeat the single channel three times along a new last axis -> (H, W, 3) fake RGB.
img = np.stack([gray, gray, gray], axis=-1)

# Convert to a tensor and add a leading batch dimension.
x = transforms.ToTensor()(img).unsqueeze(0).to(device)
print("x data range (min,mean,max):", x.min(), x.mean(), x.max()) # 0-1

In [None]:
%%time

with torch.no_grad():
    # Full pass: out_net = net.forward(x)
    # Compress:
    print("x", x.shape)
    y = net.g_a(x)
    print("y", y.shape)
    y_strings = net.entropy_bottleneck.compress(y)
    print("len(y_strings) = ",len(y_strings[0]))

    strings = [y_strings]
    shape = y.size()[-2:]

In [None]:
print(type(strings[0][0]))
# Encode the latent's last two dimensions in the file stem (e.g. "latent_128_128")
# so the decompression side can recover the shape from the file name alone.
name = f"latent_{shape[0]}_{shape[1]}"

In [None]:
# Save compressed forms: write the byte string at strings[0][0] to "<name>.bytes".
latent_path = name + ".bytes"
with open(latent_path, 'wb') as out_file:
    out_file.write(strings[0][0])

In [None]:
# List the working directory (long format, human-readable sizes) to inspect the saved ".bytes" file.
!ls -luah

# 2 ground-based (decompress)

In [None]:
# Read the compressed byte string back and wrap it in the nested-list layout
# that is passed to net.decompress.
with open(name+".bytes", "rb") as in_file:
    payload = in_file.read()
strings_loaded = [[payload]]

# Recover the latent shape from the file stem "latent_<a>_<b>".
name_parts = name.split("_")
a, b = int(name_parts[1]), int(name_parts[2])
shape_loaded = [a, b]

In [None]:
%%time

with torch.no_grad():
    out_net = net.decompress(strings_loaded, shape_loaded)
    #(is already called inside) out_net['x_hat'].clamp_(0, 1)

x_hat = out_net['x_hat']
print("x_hat data range (min,mean,max):", torch.min(x_hat), torch.mean(x_hat), torch.max(x_hat)) # 0-1

print(out_net.keys())

In [None]:
# Convert the reconstructed tensor (singleton dims squeezed out) back into a PIL image.
x_rec = out_net['x_hat']
rec_net = transforms.ToPILImage()(x_rec.squeeze().cpu())
print("reconstruction data range (min,mean,max):", np.min(rec_net), np.mean(rec_net), np.max(rec_net)) # 0-255 again

# Absolute error against the input, averaged over dim 1.
diff = (x_rec - x).abs().mean(dim=1).squeeze().cpu()

In [None]:
# Show the original, the reconstruction and their absolute difference side by side.
fig, axes = plt.subplots(1, 3, figsize=(16, 12))

# (image, title, colormap) for each panel; None keeps matplotlib's default colormap.
panels = [
    (img, 'Original', None),
    (rec_net, 'Reconstructed', None),
    (diff, 'Difference', 'viridis'),
]
for ax, (image, title, cmap) in zip(axes, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.show()


# 3 metrics

In [None]:
def compute_psnr(a, b):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The formula ``-10 * log10(mse)`` corresponds to a peak signal value of 1,
    so both inputs are expected to be scaled to the 0-1 range.

    Args:
        a: First tensor.
        b: Second tensor, broadcastable against ``a``.

    Returns:
        The PSNR as a Python float. ``float('inf')`` is returned when the mean
        squared error is exactly zero (identical inputs), where the logarithm
        is undefined.
    """
    mse = torch.mean((a - b)**2).item()
    if mse == 0:
        # log10(0) would raise ValueError; identical signals have unbounded PSNR.
        return float('inf')
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return the MS-SSIM score between ``a`` and ``b`` as a Python float.

    ``ms_ssim`` is called with ``data_range=1.``, i.e. for inputs scaled to 0-1.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()

def compute_bpp(out_net):
    """Estimate the bit-rate in bits per pixel from the model's likelihoods.

    Args:
        out_net: Mapping with an ``'x_hat'`` tensor and a ``'likelihoods'``
            mapping of tensors.

    Returns:
        Sum over all likelihood tensors of ``-log2(likelihood)``, divided by
        the pixel count taken from dims 0, 2 and 3 of ``x_hat``, as a float.
    """
    dims = out_net['x_hat'].size()
    num_pixels = dims[0] * dims[2] * dims[3]
    # Dividing by -ln(2) turns the natural-log sum into a positive bit count.
    denominator = -math.log(2) * num_pixels
    total = 0
    for likelihoods in out_net['likelihoods'].values():
        total = total + torch.log(likelihoods).sum() / denominator
    return total.item()

def convert_size(size_bytes):
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        size_bytes: Non-negative number of bytes.

    Returns:
        ``"0B"`` for zero, otherwise ``"<value> <unit>"`` with the value
        rounded to two decimals, e.g. ``"1.5 KB"``. Values beyond the largest
        unit are expressed in that unit (``"YB"``).

    Raises:
        ValueError: If ``size_bytes`` is negative.
    """
    if size_bytes == 0:
        return "0B"
    if size_bytes < 0:
        raise ValueError("size_bytes must be non-negative")
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    # Pick the unit by repeated division rather than a floating-point log, which
    # can misround at exact unit boundaries; the index is capped at the last unit.
    i = 0
    scaled = size_bytes
    while scaled >= 1024 and i < len(size_name) - 1:
        scaled /= 1024
        i += 1
    s = round(size_bytes / (1024 ** i), 2)
    return "%s %s" % (s, size_name[i])

def files_size(file_path):
    """Print the human-readable size of ``file_path`` and return it in bytes."""
    n_bytes = os.path.getsize(file_path)
    readable = convert_size(n_bytes)
    print("File", file_path, "has", readable)
    return n_bytes

# Reconstruction quality against the original input tensor.
print(f'PSNR: {compute_psnr(x, out_net["x_hat"]):.2f}dB')
print(f'MS-SSIM: {compute_msssim(x, out_net["x_hat"]):.4f}')
# The bit-rate estimate is only reported when out_net carries likelihoods.
if 'likelihoods' in out_net.keys():
    print(f'Bit-rate: {compute_bpp(out_net):.3f} bpp')

original_size = files_size("COR2.jp2")
# Measure the file under the same name it was saved with, instead of a
# hard-coded "latent_128_128.bytes" that goes stale if the latent shape changes.
latent_size = files_size(name + ".bytes")

reduction_factor = original_size / latent_size
print("Compressed with reduction factor by", round(reduction_factor,2), "times")

# (Notes)

- The data range of this loaded sample is well behaved (the original data is in the 0-255 range). Our real-world data likely will not be, so it will need to be normalised to the 0-1 range before being passed to the network.

- Real world data might have worse compression than this ".jp2" sample

- The model is pre-trained with RGB images - we waste these channels by repeating our one channel three times

- This network is relatively small (11MB), but its speed needs to be tested on tiny devices. Evaluating it on something like the Myriad chip would also require rewriting part of the code to work in (almost) pure torch without many dependencies. For now, using just the CPU on a Colab VM, it takes about 12 sec to encode and 15 sec to decode - which is not very fast...
