<a href="https://colab.research.google.com/github/PsorTheDoctor/artificial-intelligence/blob/master/modern_approach/text_to_image/stable_diffusion_compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion image compression

## Initial setup

In [None]:
# Install the dependencies used by this notebook; -qq silences pip's output.
# diffusers/transformers/ftfy provide the Stable Diffusion models loaded further down.
!pip install -qq diffusers["training"] transformers ftfy
# libimagequant palettizes the quantized latents inside compress().
!pip install -qq libimagequant
# Lossless JPEG optimizer applied to the JPEG baseline inside compress().
!pip install -qq mozjpeg-lossless-optimization
# Source of the PSNR / SSIM metrics.
!pip install -qq scikit-image
# Pillow is pinned to 9.0.0; this line runs last, so the pin is applied after the other installs.
!pip install -qq Pillow==9.0.0 -U

In [2]:
# Create the working directories (relative to the current directory): source images go in
# `input`, results are written to `output`.
%mkdir input
%mkdir output

In [5]:
# Absolute locations of the working directories; both end with '/' because later cells build
# file names by plain string concatenation.
input_dir = '/content/input/'
output_dir = '/content/output/'
# Hugging Face model id from which the VAE, UNet, text encoder and tokenizer are loaded.
pretrained_model = 'CompVis/stable-diffusion-v1-4'

In [3]:
from huggingface_hub import notebook_login

# Interactive Hugging Face login; the model downloads below pass use_auth_token=True.
notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token


In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel, StableDiffusionImg2ImgPipeline
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
import torch
from torch.cuda.amp import autocast

# Device that the VAE, the UNet and all latent tensors are moved to.
device = 'cuda'

# Autoencoder used by to_latents() / to_img() to map between images and latents.
vae = AutoencoderKL.from_pretrained(
    pretrained_model, subfolder="vae", use_auth_token=True
).to(device)

# Noise-prediction network used by denoise().
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model, subfolder="unet", use_auth_token=True
).to(device)

# NOTE(review): `scheduler` is bound to the return value of set_format("pt"), so this relies on
# set_format returning the scheduler itself — confirm against the installed diffusers version.
scheduler = PNDMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
    num_train_timesteps=1000, skip_prk_steps=True
).set_format("pt")

# The text encoder is not moved to `device`; only the embeddings it produces are (see below).
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model, subfolder="text_encoder", use_auth_token=True
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model, subfolder="tokenizer", use_auth_token=True
)
# Tokenize the empty prompt, padded to the tokenizer's maximum length.
uncond_input = tokenizer([""], padding="max_length", 
                         max_length=tokenizer.model_max_length, 
                         return_tensors="pt")
# Embeddings of the empty prompt; denoise() passes them to the UNet as encoder_hidden_states.
with torch.no_grad():
  uncond_embeddings = text_encoder(uncond_input.input_ids)[0].to(device)

## Helper functions

In [7]:
import PIL
from PIL import Image
import numpy as np
import inspect
import io
import libimagequant as liq
import zlib
import gc
import time
import mozjpeg_lossless_optimization
from skimage.metrics import structural_similarity as get_ssim
from skimage.metrics import peak_signal_noise_ratio as get_psnr

@torch.no_grad()
def to_latents(img: "Image.Image"):
  """Encode an image into VAE latents.

  Args:
    img: A PIL image (or anything ``np.array`` turns into an H x W x C uint8 array).
      PIL images in a mode other than RGB are converted to RGB first, because the
      transpose below needs a channel axis and greyscale/palette images have none.

  Returns:
    The tensor sampled from ``vae.encode(...).latent_dist`` with a generator seeded
    to 0, so repeated calls on the same image give the same latents.
  """
  if isinstance(img, Image.Image) and img.mode != 'RGB':
    img = img.convert('RGB')
  # Map uint8 [0, 255] to float32 [-1, 1].
  np_img = (np.array(img).astype(np.float32) / 255.0) * 2.0 - 1.0
  # H x W x C -> 1 x C x H x W
  np_img = np_img[None].transpose(0, 3, 1, 2)
  torch_img = torch.from_numpy(np_img)
  with autocast():
    # The generator lives on the same device as the VAE input (was hard-coded to "cuda").
    generator = torch.Generator(device).manual_seed(0)
    latents = vae.encode(torch_img.to(vae.dtype).to(device)).latent_dist.sample(generator=generator)
  return latents

@torch.no_grad()
def to_img(latents):
  """Decode VAE latents into a PIL image.

  Args:
    latents: Latent tensor accepted by ``vae.decode``; only the first batch item is
      turned into an image.

  Returns:
    The image built by ``Image.fromarray`` from the decoded uint8 H x W x C array.
  """
  with autocast():
    decoded = vae.decode(latents.to(vae.dtype).to(device)).sample
  # Map [-1, 1] to [0, 1] and clip anything the decoder overshoots.
  decoded = (decoded / 2 + 0.5).clamp(0, 1)
  # Leave half precision before scaling, then go N x C x H x W -> H x W x C.
  pixels = decoded.float().cpu().permute(0, 2, 3, 1).numpy()[0]
  # Round to the nearest level; a plain uint8 cast truncates and darkens every pixel
  # by up to one level, which skews the reported PSNR/SSIM.
  pixels = np.rint(pixels * 255.0).astype(np.uint8)
  return Image.fromarray(pixels)

def resize_to_512(input_file, output_file):
  """Centre-crop an image to a square and save it resized to 512x512.

  Args:
    input_file: Path of the image to read; it is converted to RGB.
    output_file: Path to write; the format follows from its extension.
  """
  img = Image.open(input_file).convert('RGB')
  # Largest centred square. PIL's crop box is (left, upper, right, lower) with the
  # right/lower edges exclusive, so the box must span exactly `side` pixels; the former
  # "- 1" terms cut a (side - 1) square and shifted it off-centre.
  side = min(img.width, img.height)
  left = (img.width - side) // 2
  top = (img.height - side) // 2
  img = img.crop((left, top, left + side, top + side))
  # Resize
  img = img.resize((512, 512), Image.LANCZOS)
  img.save(output_file, lossless = True, quality = 100)

def print_metrics(gt, img):
  """Print PSNR and SSIM of ``img`` measured against the reference ``gt``.

  Args:
    gt: Reference image (anything ``np.array`` accepts), channels on the last axis.
    img: Image to evaluate, same shape as ``gt``.
  """
  gt = np.array(gt)
  img = np.array(img)
  print('PSNR: {:.4f}'.format(get_psnr(gt, img)))
  # scikit-image replaced `multichannel=True` with `channel_axis=-1`, and the install cell
  # does not pin the version, so pick whichever keyword the installed function declares.
  if 'channel_axis' in inspect.signature(get_ssim).parameters:
    channel_kwargs = {'channel_axis': -1}
  else:
    channel_kwargs = {'multichannel': True}
  # NOTE(review): data_range is the observed range of `img`, not the dtype range that PSNR
  # uses by default — kept as in the original so reported numbers stay comparable.
  ssim = get_ssim(gt, img, data_range=img.max() - img.min(), **channel_kwargs)
  print('SSIM: {:.4f}'.format(ssim))

## Compression methods

In [8]:
# Latent scaling factor: denoise() multiplies the latents by it before running the UNet and
# divides by it afterwards; quantize()/unquantize() use 255 * coeff as their latent scale.
coeff = 0.18215

def quantize(latents):
  """Turn a latent tensor into an 8-bit H x W x C array.

  Latents are divided by ``255 * coeff``, shifted by 0.5 and clipped to [0, 1]; the first
  batch item is then scaled to [0, 255] and rounded to uint8.
  """
  scale = 255 * coeff
  unit = (latents / scale + 0.5).clamp(0, 1)
  # First batch item, C x H x W -> H x W x C.
  hwc = unit.cpu()[0].permute(1, 2, 0).detach().numpy()
  # Adding 0.5 before the truncating cast rounds to the nearest level.
  return (hwc * 255 + 0.5).astype(np.uint8)

def unquantize(quantized):
  """Inverse of ``quantize``: map an 8-bit H x W x C array back to a latent tensor.

  Returns a float32 tensor of shape 1 x C x H x W placed on ``device``.
  """
  unit = quantized.astype(np.float32) / 255.0
  # H x W x C -> 1 x C x H x W
  nchw = np.transpose(unit, (2, 0, 1))[None]
  restored = (nchw - 0.5) * (255 * coeff)
  return torch.from_numpy(restored).to(device)

@torch.no_grad()
def denoise(latents, strength=0.04, step_size=15, eta=0.9):
  """Run the last few scheduler steps of the UNet over ``latents``.

  No noise is added: the input is fed to the UNet as it is, conditioned on the
  empty-prompt embeddings, for the final ``int(n_steps * strength)`` timesteps.

  Args:
    latents: Unscaled latents; they are multiplied by ``coeff`` for the UNet and
      divided by it again before returning.
    strength: Fraction of the inference schedule to run (default matches the
      previously hard-coded 0.04).
    step_size: Training timesteps per inference step; the schedule has
      ``num_train_timesteps // step_size`` steps.
    eta: Passed to ``scheduler.step`` only when that method declares an ``eta`` parameter.

  Returns:
    The processed latents, on the same scale as the input.
  """
  latents = latents * coeff
  n_inference_steps = scheduler.config.get('num_train_timesteps', 1000) // step_size
  scheduler.set_timesteps(n_inference_steps)
  offset = scheduler.config.get('steps_offset', 0)
  init_timestep = min(int(n_inference_steps * strength) + offset, n_inference_steps)

  extra_step_kwargs = {}
  if 'eta' in inspect.signature(scheduler.step).parameters:
    extra_step_kwargs['eta'] = eta

  latents = latents.to(unet.dtype).to(device)
  # Index of the first timestep to run; everything before it is skipped.
  t_start = max(n_inference_steps - init_timestep + offset, 0)
  with autocast():
    for t in scheduler.timesteps[t_start:]:
      noise_pred = unet(latents, t, encoder_hidden_states=uncond_embeddings).sample
      latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

  # The shared scheduler is reset to a one-step schedule afterwards, as before.
  scheduler.set_timesteps(1)
  return latents / coeff

def _smallest_encoding_at_least(encode, target_bytes, max_quality=100):
  """Return ``(quality, data)`` for the lowest quality whose encoding has >= target_bytes bytes.

  ``encode(quality)`` must return the encoded bytes. If no quality up to ``max_quality``
  reaches the target, the ``max_quality`` encoding is returned.
  """
  quality, data = 0, b''
  for quality in range(max_quality + 1):
    data = encode(quality)
    if len(data) >= target_bytes:
      break
  return quality, data


def compress(input_file, output_path):
  """Compress one image via palettized VAE latents and compare against JPEG and WebP.

  Each stage (VAE round trip, 8-bit latents, palettized latents, de-noised latents,
  JPEG and WebP baselines) is displayed and its size/metrics are printed.

  Args:
    input_file: Path of the image to compress.
    output_path: Output path without extension. Files written:
      ``_latents.webp`` (lossless 8-bit latents), ``.npz`` and ``.bin`` (palettized
      latents), ``.jpg`` and ``.webp`` (baselines at least as large as the ``.bin``).
  """
  import os  # imported here because the notebook only imports os in a later cell

  img = Image.open(input_file)
  if img.mode != 'RGB':
    # to_latents() transposes an H x W x C array, so modes without a channel axis
    # have to be expanded first.
    img = img.convert('RGB')
  display(img)
  print('Original')
  print('Size: {:.2f} kB'.format(os.stat(input_file).st_size / 1024))

  latents = to_latents(img)
  img_from_latents = to_img(latents)
  display(img_from_latents)
  print('VAE roundtrip')
  print_metrics(img, img_from_latents)

  quantized = quantize(latents)
  del latents
  height, width = quantized.shape[:2]
  quantized_img = Image.fromarray(quantized)
  # Own file name: the WebP baseline below is written to output_path + '.webp' and
  # used to overwrite these latents.
  quantized_img.save(output_path + '_latents.webp', lossless=True, quality=100)

  unquantized_latents = unquantize(quantized)
  unquantized_img = to_img(unquantized_latents)
  display(unquantized_img)
  del unquantized_latents
  print('VAE decoded from 8-bit quantized latents')
  print_metrics(img, unquantized_img)

  # Palettize the 4-channel latents (treated as RGBA pixels) to at most 256 colours.
  attr = liq.Attr()
  attr.speed = 1
  attr.max_colors = 256
  input_img = attr.create_rgba(quantized.flatten('C').tobytes(),
                               quantized_img.width, quantized_img.height, 0
  )
  quantization_result = input_img.quantize(attr)
  quantization_result.dithering_level = 1.0
  out_pixels = quantization_result.remap_image(input_img)
  out_palette = quantization_result.get_palette()
  indices = np.frombuffer(out_pixels, np.uint8)
  palette = np.array([c for color in out_palette for c in color], dtype=np.uint8)

  palettized_bytes = io.BytesIO()
  np.savez_compressed(palettized_bytes, w=width, h=height, i=indices.flatten(), p=palette)
  with open(output_path + '.npz', 'wb') as f:
    f.write(palettized_bytes.getbuffer())

  # The .bin payload is the flattened palette followed by the index map, zlib-compressed.
  compressed_bytes = zlib.compress(
      np.concatenate((palette, indices), dtype=np.uint8).tobytes(), level=9
  )
  with open(output_path + '.bin', 'wb') as f:
    f.write(compressed_bytes)

  sd_bytes = len(compressed_bytes)

  indices = indices.reshape((height, width))
  palettized_latent_img = Image.fromarray(indices, mode='P')
  palettized_latent_img.putpalette(palette, rawmode='RGBA')
  latents = np.array(palettized_latent_img.convert('RGBA'))
  latents = unquantize(latents)
  palettized_img = to_img(latents)
  display(palettized_img)
  print('VAE decoding of palettized and dithered 8-bit latents')
  print_metrics(img, palettized_img)

  latents = denoise(latents)
  denoised_img = to_img(latents)
  display(denoised_img)
  del latents
  print('VAE decoding of de-noised dithered 8-bit latents')
  # Size of this image's own .bin (was hard-coded to /content/output/lena.bin).
  print('Size: {:.2f} kB'.format(os.stat(output_path + '.bin').st_size / 1024))
  print_metrics(img, denoised_img)

  def encode_jpeg(quality):
    raw = io.BytesIO()
    img.save(raw, format='JPEG', quality=quality, optimize=True, subsampling=1)
    return mozjpeg_lossless_optimization.optimize(raw.getvalue())

  def encode_webp(quality):
    raw = io.BytesIO()
    img.save(raw, format='WEBP', quality=quality, method=6)
    return raw.getvalue()

  # Baselines: lowest quality whose file is at least as large as the .bin. The quality that
  # is printed is the one actually used (it used to be reported one too high).
  jpg_quality, jpg_data = _smallest_encoding_at_least(encode_jpeg, sd_bytes)
  with open(output_path + '.jpg', 'wb') as f:
    f.write(jpg_data)

  jpg = Image.open(io.BytesIO(jpg_data))
  try:
    display(jpg)
    print('JPG comprassed with quality setting: {}'.format(jpg_quality))
    print('size: {:.2f} kB'.format(len(jpg_data) / 1024))
    print_metrics(img, jpg)
  except Exception as err:
    print('Sth went wrong compressing {}.jpg'.format(output_path))
    print(repr(err))

  webp_quality, webp_data = _smallest_encoding_at_least(encode_webp, sd_bytes)
  with open(output_path + '.webp', 'wb') as f:
    f.write(webp_data)
  try:
    webp = Image.open(io.BytesIO(webp_data))
    display(webp)
    print('WebP compressed with quality setting: {}'.format(webp_quality))
    print('size: {:.2f} kB'.format(len(webp_data) / 1024))
    print_metrics(img, webp)
  except Exception as err:
    print('Sth went wrong compressing {}.webp'.format(output_path))
    print(repr(err))

In [None]:
import os 
import shutil
import time
from tqdm import tqdm

# Directory for the 512x512 copies; the trailing separator is kept so the path can still
# be used with plain string concatenation.
rescaled_dir = os.path.join(input_dir, 'rescaled', '')
os.makedirs(rescaled_dir, exist_ok=True)

print('Rescaling images to 512x512')
# Iterating the list directly lets tqdm show the total; sub-directories (including
# `rescaled` itself) are skipped by the isfile() check.
for filename in tqdm(os.listdir(input_dir)):
  f_in = os.path.join(input_dir, filename)
  f_out = os.path.join(rescaled_dir, os.path.splitext(filename)[0] + '.png')
  if os.path.isfile(f_in) and not os.path.isfile(f_out):
    try:
      resize_to_512(f_in, f_out)
    except Exception as err:
      # `Exception` rather than a bare except, so interrupting the kernel still works.
      print('Skipping {} beacuse the file could not be opened.'.format(filename))
      print(repr(err))

# Start from an empty output directory on every run.
if os.path.isdir(output_dir):
  shutil.rmtree(output_dir)
os.makedirs(output_dir)
for filename in os.listdir(rescaled_dir):
  f = os.path.join(rescaled_dir, filename)
  if os.path.isfile(f):
    compress(f, os.path.splitext(os.path.join(output_dir, filename))[0])
    time.sleep(0.1)

In [None]:
import matplotlib.pyplot as plt
import cv2

filename = 'lena'

# Prefer the 512x512 copy that was actually compressed; fall back to the uploaded file.
original_path = os.path.join(rescaled_dir, filename + '.png')
if not os.path.isfile(original_path):
  original_path = input_dir + filename + '.png'
jpg_path = output_dir + filename + '.jpg'
webp_path = output_dir + filename + '.webp'
bin_path = output_dir + filename + '.bin'

# Rebuild the Stable Diffusion result from the stored .bin: compress() writes the flattened
# RGBA palette followed by the 64x64 index map, zlib-compressed.
with open(bin_path, 'rb') as f:
  packed = np.frombuffer(zlib.decompress(f.read()), dtype=np.uint8)
n_indices = 64 * 64
palette = packed[:-n_indices].reshape(-1, 4)
indices = packed[-n_indices:].reshape(64, 64)
sd_img = to_img(denoise(unquantize(palette[indices])))

# Titles carry the real file sizes instead of numbers copied from an earlier run.
panels = [
    ('Original', plt.imread(original_path)),
    ('JPG ({:.2f} kB)'.format(os.path.getsize(jpg_path) / 1024), plt.imread(jpg_path)),
    ('WebP ({:.2f} kB)'.format(os.path.getsize(webp_path) / 1024), plt.imread(webp_path)),
    ('Stable Diffusion ({:.2f} kB)'.format(os.path.getsize(bin_path) / 1024), np.array(sd_img)),
]

# One row of four equally sized panels (the first panel used a 2x2 grid and overlapped).
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for ax, (title, image) in zip(axes, panels):
  ax.imshow(image)
  ax.axis('off')
  ax.set_title(title)