In [None]:
import torch
import numpy as np
import os
import glob
import collections
from PIL import Image
from compress import prepare_model, prepare_dataloader, compress_and_save, load_and_decompress
import shutil
# One decoded result: path of the reconstructed PNG, path of the compressed
# .hfc file, compressed size in bytes, and bits per pixel.
File = collections.namedtuple(
    'File',
    ['output_path', 'compressed_path', 'num_bytes', 'bpp'],
)

# Original on-disk size of each staged input image, keyed by file stem
# (filled by the staging cell).
original_sizes = {}

# Paths, relative to the notebook's working directory.
weights_path = 'weights/hific_low.pt'
img_path = 'pepeimg'      # input images to compress
stage_path = 'temp_dir'   # staging directory handed to the data loader
out_path = 'outputimg'    # compressed .hfc files and decoded .png files

# HiFiC variant -> identifier (presumably a Google Drive file ID; not used
# in the visible cells). NOTE(review): weights_path above always points at the
# low-rate checkpoint, independent of model_choice — confirm they stay in sync.
model_choices = {'HIFIC-low': '1hfFTkZbs_VOBmXQ-M4bYEPejrD76lAY9',
                 'HIFIC-med': '1QNoX0AGKTBkthMJGPfQI0dT0_tnysYUb',
                 'HIFIC-high': '1BFYpvhVIA_Ek2QsHBbKnaBE8wn1GhFyA'}
model_choice = 'HIFIC-low'
model_ID = model_choices[model_choice]

In [None]:
# Report whether a CUDA device is visible before loading the model.
if torch.cuda.is_available():
    print(f'Found GPU {torch.cuda.get_device_name(0)}')
else:
    print('WARNING: No GPU found. Compression/decompression will be slow!')

# Load the model from weights_path; `args` is returned alongside it
# (presumably its configuration) and is reused by the data-loader cell.
model, args = prepare_model(weights_path, stage_path)

In [None]:
# Preview every input image at reduced size: half size when there is a single
# image, quarter size otherwise.
all_files = os.listdir(img_path)
print(f'Got following files ({len(all_files)}):')
scale_factor = 2 if len(all_files) == 1 else 4

for file_name in all_files:
    full_path = os.path.join(img_path, file_name)
    # Subdirectories cannot be opened as images.
    if os.path.isdir(full_path):
        continue
    try:
        # The context manager releases the file handle; resize() returns a new,
        # independent image, so `img` stays usable after the file is closed.
        with Image.open(full_path) as original:
            w, h = original.size
            # Clamp to 1 px: a 0-sized target makes resize() raise for tiny images.
            img = original.resize((max(1, w // scale_factor),
                                   max(1, h // scale_factor)))
    except OSError:
        # PIL signals unreadable / non-image files with an OSError subclass.
        print('Skipping non-image', file_name, '...')
        continue
    print('-> ' + file_name + ':')
    display(img)

In [None]:

# File extensions (lower-case, with leading dot) the staging loop accepts.
SUPPORTED_EXT = {'.png', '.jpg'}

# List the inputs again so this cell does not depend on the preview cell,
# and stop early with a clear message when there is nothing to compress.
all_files = os.listdir(img_path)
if len(all_files) == 0:
    raise ValueError("Please upload/download images!")

def get_bpp(image_dimensions, num_bytes):
    """Compute bits per pixel for a compressed payload.

    Args:
        image_dimensions: a ``(width, height)`` pair (e.g. ``PIL.Image.size``);
            it must unpack into exactly two values.
        num_bytes: size of the compressed data in bytes.

    Returns:
        ``num_bytes * 8 / (width * height)`` as a float.

    Raises:
        ZeroDivisionError: if either dimension is 0.
    """
    width, height = image_dimensions
    total_bits = num_bytes * 8
    return total_bits / (width * height)

def has_alpha(img_p):
    """Return True if the image at ``img_p`` is in PIL's ``'RGBA'`` mode.

    Only ``'RGBA'`` is checked; every other mode returns False, including
    modes such as ``'LA'`` that also carry an alpha band.

    Args:
        img_p: path to an image file readable by PIL.

    Returns:
        bool: ``True`` when the image mode is exactly ``'RGBA'``.
    """
    # Use a context manager so the file handle is released immediately;
    # Image.open is lazy and otherwise keeps the file open until garbage collection.
    with Image.open(img_p) as im:
        return im.mode == 'RGBA'

# Empty the staging directory so files left over from a previous run are not
# staged again. Only regular files are removed; subdirectories are left alone.
# Create the directory first: os.listdir raises on a missing directory, and
# shutil.copy(src, stage_path) needs stage_path to be a directory to copy *into*.
os.makedirs(stage_path, exist_ok=True)
for filename in os.listdir(stage_path):
    file_path = os.path.join(stage_path, filename)
    if os.path.isfile(file_path):
        os.unlink(file_path)
# Copy every supported, alpha-free image from img_path into stage_path and
# record its original size in original_sizes, keyed by file stem.
for file_name in all_files:
    full_path = os.path.join(img_path, file_name)
    # Test the entry inside img_path; a bare file_name would be resolved
    # against the current working directory instead.
    if os.path.isdir(full_path):
        continue
    stem, ext = os.path.splitext(file_name)
    # Compare case-insensitively so e.g. 'photo.JPG' is not skipped.
    if ext.lower() not in SUPPORTED_EXT:
        print('Skipping non-image', file_name, '...')
        continue
    if has_alpha(full_path):
        print('Skipping because of alpha channel:', file_name)
        continue
    original_sizes[stem] = os.path.getsize(full_path)
    shutil.copy(full_path, stage_path)

In [None]:
# Build the data loader over the staged images; out_path is passed through as well.
data_loader = prepare_dataloader(
    args,
    stage_path,
    out_path,
)

In [None]:
# Compress the staged images; the decode cell below picks up the resulting
# .hfc files from out_path.
compress_and_save(
    model,
    args,
    data_loader,
    out_path,
)

In [None]:
# Decode every compressed .hfc file in out_path to a PNG and record, per file,
# the PNG path, the .hfc path, the compressed size and the bits per pixel.
all_outputs = []

# sorted(): glob returns matches in arbitrary order.
for compressed_file in sorted(glob.glob(os.path.join(out_path, '*.hfc'))):
    # glob results already include the out_path prefix, so strip it with
    # basename before re-joining; otherwise the PNG path would become
    # out_path/out_path/<name>.png.
    file_name, _ = os.path.splitext(os.path.basename(compressed_file))
    output_path = os.path.join(out_path, f'{file_name}.png')

    # Model decode (presumably writes the reconstruction to output_path,
    # which is read back below).
    reconstruction = load_and_decompress(model, compressed_file, output_path)

    num_bytes = os.path.getsize(compressed_file)
    # Close the decoded PNG once its dimensions are read.
    with Image.open(output_path) as decoded:
        image_size = decoded.size

    all_outputs.append(File(output_path=output_path,
                            compressed_path=compressed_file,
                            num_bytes=num_bytes,
                            bpp=get_bpp(image_size, num_bytes)))
              