To run this colab you can use your own colab setup or try
[Sandwich Image Compression Lowres Codec](https://colab.research.google.com/github/google/sandwiched_compression/blob/main/sandwich_image_compression_lowres_codec.ipynb).


In [1]:
!pip install -q mediapy tensorflow-datasets==4.9.4

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Fetch the sandwiched_compression sources into the working directory, unless
# compress_intra_model.py is already present. The `&&` makes the move run only
# when the clone succeeded.
!if [ ! -f compress_intra_model.py ]; then \
  git clone https://github.com/google/sandwiched_compression && \
  mv sandwiched_compression/* .; \
fi

Cloning into 'sandwiched_compression'...
remote: Enumerating objects: 213, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 213 (delta 58), reused 25 (delta 25), pack-reused 125 (from 2)[K
Receiving objects: 100% (213/213), 64.64 MiB | 4.66 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [3]:
import tensorflow as tf

import mediapy as media
import compress_intra_model
import datasets

In [4]:
# See https://www.tensorflow.org/datasets for datasets to try.
def dataset_fn(
    batch_size: int, training_mode: bool, take_count: int = 100
) -> tf.data.Dataset:
  """Loads the image dataset and keeps at most `take_count` batches of it.

  Args:
    batch_size: Forwarded to `datasets.load_tfds_image_dataset`.
    training_mode: Forwarded to `datasets.load_tfds_image_dataset`.
    take_count: Maximum number of batches kept from the loaded dataset.

  Returns:
    The loaded dataset truncated with `.take(take_count)`.
  """
  full_dataset = datasets.load_tfds_image_dataset(
      batch_size=batch_size,
      training_mode=training_mode,
      dataset_name='clic',  # Insert the preferred dataset name here.
      target_size=256,  # Images are 256x256.
  )
  # Per the original note, the take_count batches are drawn randomly in each
  # epoch (depends on the loader's shuffling -- TODO confirm).
  return full_dataset.take(take_count)

In [5]:
# gamma is the Lagrangian multiplier for D + \gamma R loss.
def create_grayscale_codec_model(gamma: float) -> tf.keras.Model:
  """Builds the sandwich model whose bottleneck is a single-channel image.

  Args:
    gamma: Lagrangian multiplier, forwarded to
      `compress_intra_model.create_basic_model`.

  Returns:
    The model returned by `compress_intra_model.create_basic_model`.
  """
  codec_options = dict(
      model_keys=['image'],
      bottleneck_channels=1,  # grayscale
      output_channels=3,
      num_mlp_layers=2,
      use_jpeg_rate_model=True,
      downsample_factor=1,  # full-res
      num_truncate_bits=0,
      gamma=gamma,
      loop_filter_folder=None,  # Check code to see how to train one separately.
      use_unet_preprocessor=True,
      use_unet_postprocessor=True,
  )
  return compress_intra_model.create_basic_model(**codec_options)

# Only change two parameters for the low-res codec scenario.
def create_lowres_codec_model(gamma: float) -> tf.keras.Model:
  """Builds the sandwich model whose bottleneck is a half-resolution RGB image.

  Args:
    gamma: Lagrangian multiplier, forwarded to
      `compress_intra_model.create_basic_model`.

  Returns:
    The model returned by `compress_intra_model.create_basic_model`.
  """
  lowres_options = {
      'model_keys': ['image'],
      'bottleneck_channels': 3,  # rgb
      'output_channels': 3,
      'num_mlp_layers': 2,
      'use_jpeg_rate_model': True,
      'downsample_factor': 2,  # half-res
      'num_truncate_bits': 0,
      'gamma': gamma,
      'loop_filter_folder': None,
      'use_unet_preprocessor': True,
      'use_unet_postprocessor': True,
  }
  return compress_intra_model.create_basic_model(**lowres_options)

In [6]:
# Batch sizes for the two pipelines.
train_batch_size = 4
eval_batch_size = 1

# training_mode=True pulls from the train split, False from the eval split.
train_dataset = dataset_fn(batch_size=train_batch_size, training_mode=True)
eval_dataset = dataset_fn(batch_size=eval_batch_size, training_mode=False)

Downloading and preparing dataset 7.48 GiB (download: 7.48 GiB, generated: 7.48 GiB, total: 14.96 GiB) to /root/tensorflow_datasets/clic/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/1633 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/clic/1.0.0.incompleteW9DN3Y/clic-train.tfrecord*...:   0%|          | 0/16…

Generating validation examples...:   0%|          | 0/102 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/clic/1.0.0.incompleteW9DN3Y/clic-validation.tfrecord*...:   0%|          |…

Generating test examples...:   0%|          | 0/428 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/clic/1.0.0.incompleteW9DN3Y/clic-test.tfrecord*...:   0%|          | 0/428…

Dataset clic downloaded and prepared to /root/tensorflow_datasets/clic/1.0.0. Subsequent calls will reuse this data.


In [None]:
# Simple trainer. It is recommended to use a custom trainer and train to
# convergence.

num_epochs = 800
gamma = 50  # Lagrange multiplier

base_model = create_lowres_codec_model(gamma)
learning_rate = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Running mean of the per-batch loss; reset at the end of every epoch.
epoch_stat = tf.keras.metrics.Mean()
loss_fn = compress_intra_model.create_basic_loss(gamma=gamma)

for i in range(num_epochs):
  for train_batch in train_dataset:
    # Only the forward pass and the loss need to be recorded on the tape.
    # NOTE(review): the model is called without an explicit `training`
    # argument -- confirm that its default is the intended mode for training.
    with tf.GradientTape() as tape:
      out = base_model(train_batch)
      loss = loss_fn(train_batch, out)

    # Differentiate and update outside the tape context, following the usual
    # TensorFlow custom-training-loop pattern.
    trainable_variables = base_model.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    epoch_stat.update_state(loss)

  # Note each epoch is over a varying set of take_count x batch_size images.
  # Calculate a median or change the dataset loader to always use the same set
  # if you prefer.
  print(f'Epoch {i:=4d}/{num_epochs:=4d} Loss: {epoch_stat.result():=4.4f}')
  epoch_stat.reset_state()

1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#acce

Epoch    0/ 800 Loss: 2231.4250
Epoch    1/ 800 Loss: 456.7496
Epoch    2/ 800 Loss: 287.4542
Epoch    3/ 800 Loss: 254.8706
Epoch    4/ 800 Loss: 213.7722
Epoch    5/ 800 Loss: 197.6987
Epoch    6/ 800 Loss: 182.4476


In [None]:
# It is recommended to generate R-D curves by training multiple models for
# multiple gammas, then evaluate each model for multiple qsteps, and construct
# the Pareto frontier. Please see the paper for details:
# https://arxiv.org/abs/2402.05887

# Discussion on the results shown below:
#
# For the low-res codec scenario, pay attention to areas where the simple
# linear path has lost detail (e.g., text and textures) or shows aliasing
# (e.g., merging lines/edges). Notice how much better the model predictions
# are, and also notice what the compressed bottlenecks transport in these
# areas. Running post-processing-only models will typically generate wrong
# results in these areas. Please see the paper for examples.
#
# One can design models to hallucinate detail but it is important to understand
# that hallucination is not accurate transport. When one watches a movie one
# wants to see it as the director, cinematographer, etc., have intended it. One
# does not want to see some model's hallucinated reinterpretation of the
# art/reality.

def simple_linear_path(sample: tf.Tensor) -> tf.Tensor:
  """Downsamples and then upsamples an image, with no codec in between.

  The image is shrunk by `base_model.downsample_factor` with bicubic resampling
  and then resized back to its original size with Lanczos3 resampling.

  Args:
    sample: Image laid out as [height, width, channels], optionally with a
      leading batch axis. The spatial size is read from the third- and
      second-to-last axes, so both layouts work.

  Returns:
    The result of the down/up resize, with the same spatial size as `sample`.
  """
  # Assumes downsample_factor is a positive integer -- TODO confirm against
  # the model definition.
  factor = base_model.downsample_factor
  height, width = sample.shape[-3], sample.shape[-2]
  # NOTE(review): tf.image.resize defaults to antialias=False; confirm that an
  # un-antialiased downsample is the intended baseline.
  low_res = tf.image.resize(
      sample,
      size=[height // factor, width // factor],
      method=tf.image.ResizeMethod.BICUBIC,
  )
  return tf.image.resize(
      low_res,
      size=[height, width],
      method=tf.image.ResizeMethod.LANCZOS3,
  )


# Pictures to show. Can also look at the proxy rate through 'rate', calculate
# distortion or whatever else you would like.
show_keys = ['prediction', 'compressed_bottleneck']
show_count = 10

# Upsample the bottlenecks using nearest neighbor for clarity.
upsample_keys = ['compressed_bottleneck']

# Only the first show_count evaluation batches are visualized.
for sample in eval_dataset.take(show_count).as_numpy_iterator():
  # Path 1: Simple demo:
  # Run the pre-processor, codec-proxy, and the post-processor.
  output = base_model(sample)

  # Path 2: Actual performance with your codec:
  # Run the pre-processor, your codec, then post-processor.
  #
  # bottlenecks = base_model.run_preprocessor(sample, training=False)
  # compressed_bottlenecks = insert_your_image_codec_binary(bottlenecks)
  # output = {
  #     'bottleneck': bottlenecks,
  #     'compressed_bottleneck': compressed_bottlenecks,
  #     'prediction': base_model.run_postprocessor(
  #         compressed_bottlenecks, training=False
  #     ),
  # }

  # First image of the batch, scaled by 1/255 for display.
  images = {'original': sample['image'][0] / 255}

  # Emulate linear up-down without compression.
  images['simple_linear'] = simple_linear_path(images['original'])

  images.update(
      {key: value[0] / 255 for key, value in output.items() if key in show_keys}
  )

  # Spatial size of the input batch (axes 1 and 2 of sample['image']).
  target_height, target_width = sample['image'].shape[1:3]
  for key in upsample_keys:
    if key not in images:
      # The key was not selected in show_keys or is absent from the model
      # output; nothing to upsample.
      continue
    images[key] = tf.image.resize(
        images[key],
        size=[target_height, target_width],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    ).numpy()
  media.show_images(images, height=512)