diff --git a/README.md b/README.md index f2d5907..f32c9c0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ JPG, 0.264 bpp / 90.1 kB ``` ![guess](assets/comparison/camp_jpg_compress.png) -The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image. +The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images and other examples are stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image. ## Details This repository defines a model for learnable image compression capable of compressing images of arbitrary size and resolution. There are three main components to this model, as described in the original paper: @@ -61,7 +61,7 @@ python3 train.py --model_type compression --regime low --n_steps 1e6 ``` python3 train.py --model_type compression_gan --regime low --n_steps 1e6 --warmstart --ckpt path/to/base/checkpoint ``` -* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` on average. +* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` per image, on average using the default config. * If you get out-of-memory errors, try: * Reducing the number of residual blocks in the generator (default 7, the original paper used 9). * Decreasing the batch size (default 16). @@ -73,23 +73,24 @@ tensorboard --logdir experiments/my_experiment/tensorboard ``` ### Compression -* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image. This model will work with images of arbitrary sizes and resolution. +* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image (as a PNG). This model will work with images of arbitrary sizes and resolution (provided you don't run out of memory). This will work with JPG and PNG (without alpha channels). ``` -python3 compress.py --img path/to/image/dir --ckpt path/to/trained/model +python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model ``` +* A pretrained model using the OpenImages dataset can be found here: [Drive link]. This model was trained for 2e5 warmup steps and 2e5 steps with the full generative loss. To use this, download the model and point the `-ckpt` argument in the command above to the corresponding path. + * The reported `bpp` is the theoretical bitrate required to losslessly store the quantized latent representation of an image as determined by the learned probability model provided by the hyperprior using some entropy coding algorithm. Comparing this (not the size of the reconstruction) against the original size of the image will give you an idea of the reduction in memory footprint. This repository does not currently support actual compression to a bitstring ([TensorFlow Compression](https://github.com/tensorflow/compression) does this well). We're working on an ANS entropy coder to support this in the future. ### Notes * The "size" of the compressed image as reported in `bpp` does not account for the size of the model required to decode the compressed format. -* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. -* You may get an OOM error when compressing images which are too large. We're working on a fix. +* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. A complete forward pass using a batch of 10 images takes around 45s on a 2.8 GHz Intel Core i7. +* You may get an OOM error when compressing images which are too large (`>~ 4000 x 4000`). It's possible to get around this by applying the network to evenly sized crops of the input image whose forward pass will fit in memory. We're working on a fix to automatically support this. ### Contributing All content in this repository is licensed under the Apache-2.0 license. Feel free to submit any corrections or suggestions as issues. ### Acknowledgements * The code under `hific/perceptual_similarity/` implementing the perceptual distortion loss is modified from the [Perceptual Similarity repository](https://github.com/richzhang/PerceptualSimilarity). - ### Authors * Grace Han diff --git a/compress.py b/compress.py index f472708..638ac90 100644 --- a/compress.py +++ b/compress.py @@ -67,7 +67,8 @@ def compress_batch(args): input_filenames_total.extend(filenames) for subidx in range(reconstruction.shape[0]): - fname = os.path.join(args.output_dir, "{}_RECON.png".format(filenames[subidx])) + bpp_per_im = float(bpp[subidx].cpu().numpy()) + fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], bpp_per_im)) torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True) output_filenames_total.append(fname) @@ -97,7 +98,7 @@ def compress_batch(args): def main(**kwargs): - description = "Compresses batch of images using specified learned model." + description = "Compresses batch of images using learned model specified via -ckpt argument." parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("-ckpt", "--ckpt_path", type=str, required=True, help="Path to model to be restored") @@ -106,7 +107,7 @@ def main(**kwargs): parser.add_argument("-o", "--output_dir", type=str, default='data/reconstructions', help="Path to directory to store output images") parser.add_argument('-bs', '--batch_size', type=int, default=1, - help="Dataloader batch size. Set to 1 for images of different sizes.") + help="Loader batch size. Set to 1 if images in directory are different sizes.") args = parser.parse_args() input_images = glob.glob(os.path.join(args.image_dir, '*.jpg')) diff --git a/default_config.py b/default_config.py index 309e100..dff96f0 100644 --- a/default_config.py +++ b/default_config.py @@ -13,7 +13,7 @@ class ModelTypes(object): class ModelModes(object): TRAINING = 'training' - VALIDATION = 'validation' # Monitoring + VALIDATION = 'validation' EVALUATION = 'evaluation' class Datasets(object): @@ -30,13 +30,13 @@ class directories(object): experiments = 'experiments' class checkpoints(object): - gan1 = 'experiments/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12/checkpoints/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12_epoch1_idx56776_2020_08_14_18:43.pt' + gan1 = 'experiments/lossless.pt' class args(object): """ Shared config """ - name = 'hific_v0' + name = 'hific_v0.1' silent = True n_epochs = 8 n_steps = 1e6 @@ -52,8 +52,8 @@ class args(object): model_mode = ModelModes.TRAINING # Architecture params - Table 3a) of [1] - latent_channels = 220 #220 - n_residual_blocks = 7 #7 # Authors use 9 blocks, performance saturates at 5 + latent_channels = 220 + n_residual_blocks = 7 # Authors use 9 blocks, performance saturates at 5 lambda_B = 2**(-4) # Loose rate k_M = 0.075 * 2**(-5) # Distortion k_P = 1. # Perceptual loss diff --git a/src/helpers/maths.py b/src/helpers/maths.py index 89c0c57..0507005 100644 --- a/src/helpers/maths.py +++ b/src/helpers/maths.py @@ -12,7 +12,7 @@ def backward(ctx, grad_output): return grad_output.clone(), None -class LowerBoundToward_0(torch.autograd.Function): +class LowerBoundToward(torch.autograd.Function): """ Assumes output shape is identical to input shape. """ @@ -24,26 +24,9 @@ def forward(ctx, tensor, lower_bound): @staticmethod def backward(ctx, grad_output): - # gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype)) - gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type_as(grad_output.data)) + gate = torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype) return grad_output * gate, None -class LowerBoundToward(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, bound): - b = torch.ones_like(inputs) * bound - ctx.save_for_backward(inputs, b) - return torch.max(inputs, b) - - @staticmethod - def backward(ctx, grad_output): - inputs, b = ctx.saved_tensors - pass_through_1 = inputs >= b - pass_through_2 = grad_output < 0 - - pass_through = pass_through_1 | pass_through_2 - return pass_through.type(grad_output.dtype) * grad_output, None - def standardized_CDF_gaussian(value): # Gaussian # return 0.5 * (1. + torch.erf(value/ np.sqrt(2))) diff --git a/src/model.py b/src/model.py index c947cc6..1428524 100644 --- a/src/model.py +++ b/src/model.py @@ -22,7 +22,7 @@ Intermediates = namedtuple("Intermediates", ["input_image", # [0, 1] (after scaling from [0, 255]) "reconstruction", # [0, 1] - "latents_quantized", # Latents post-quantization. + "latents_quantized", # Latents post-quantization. "n_bpp", # Differential entropy estimate. "q_bpp"]) # Shannon entropy estimate. @@ -134,6 +134,7 @@ def compression_forward(self, x): total_nbpp = hyperinfo.total_nbpp total_qbpp = hyperinfo.total_qbpp + # Use quantized latents as input to G reconstruction = self.Generator(latents_quantized) if self.args.normalize_input_image is True: @@ -160,7 +161,6 @@ def discriminator_forward(self, intermediates, train_generator): D_in = torch.cat([x_real, x_gen], dim=0) latents = intermediates.latents_quantized.detach() - # latents = torch.cat([latents, latents], dim=0) latents = torch.repeat_interleave(latents, 2, dim=0) D_out, D_out_logits = self.Discriminator(D_in, latents) @@ -170,14 +170,11 @@ def discriminator_forward(self, intermediates, train_generator): D_real, D_gen = torch.chunk(D_out, 2, dim=0) D_real_logits, D_gen_logits = torch.chunk(D_out_logits, 2, dim=0) - # Tensorboard - # real_response, gen_response = D_real.mean(), D_fake.mean() - return Disc_out(D_real, D_gen, D_real_logits, D_gen_logits) def distortion_loss(self, x_gen, x_real): # loss in [0,255] space but normalized by 255 to not be too big - # - Delegate to weighting + # - Delegate scaling to weighting sq_err = self.squared_difference(x_gen*255., x_real*255.) # / 255. return torch.mean(sq_err) @@ -196,30 +193,18 @@ def compression_loss(self, intermediates, hyperinfo): x_real = (x_real + 1.) / 2. x_gen = (x_gen + 1.) / 2. - # print('X REAL MAX', x_real.max()) - # print('X REAL MIN', x_real.min()) - # print('X GEN MAX', x_gen.max()) - # print('X GEN MIN', x_gen.min()) - distortion_loss = self.distortion_loss(x_gen, x_real) perceptual_loss = self.perceptual_loss_wrapper(x_gen, x_real, normalize=True) weighted_distortion = self.args.k_M * distortion_loss weighted_perceptual = self.args.k_P * perceptual_loss - # print('Distortion loss size', weighted_distortion.size()) - # print('Perceptual loss size', weighted_perceptual.size()) - weighted_rate, rate_penalty = losses.weighted_rate_loss(self.args, total_nbpp=intermediates.n_bpp, total_qbpp=intermediates.q_bpp, step_counter=self.step_counter) - # print('Weighted rate loss size', weighted_rate.size()) weighted_R_D_loss = weighted_rate + weighted_distortion weighted_compression_loss = weighted_R_D_loss + weighted_perceptual - # print('Weighted R-D loss size', weighted_R_D_loss.size()) - # print('Weighted compression loss size', weighted_compression_loss.size()) - # Bookkeeping if (self.step_counter % self.log_interval == 1): self.store_loss('rate_penalty', rate_penalty) diff --git a/src/network/hyperprior.py b/src/network/hyperprior.py index 170de54..e04290f 100644 --- a/src/network/hyperprior.py +++ b/src/network/hyperprior.py @@ -45,17 +45,15 @@ def _quantize(self, x, mode='noise', means=None): if mode == 'noise': quantization_noise = torch.nn.init.uniform_(torch.zeros_like(x), -0.5, 0.5) - # quantization_noise = torch.rand(x.size()).to(x) - 0.5 x = x + quantization_noise - elif mode == 'quantize': + elif mode == 'quantize': if means is not None: x = x - means x = torch.floor(x + 0.5) x = x + means else: x = torch.floor(x + 0.5) - # x = torch.round(x) else: raise NotImplementedError @@ -71,16 +69,8 @@ def _estimate_entropy(self, likelihood, spatial_shape): n_pixels = np.prod(spatial_shape) log_likelihood = torch.log(likelihood + EPS) - # print('LOG LIKELIHOOD', log_likelihood.mean().item()) n_bits = torch.sum(log_likelihood) / (batch_size * quotient) bpp = n_bits / n_pixels - # print('N_PIXELS', n_pixels) - # print('BATCH SIZE', batch_size) - # print('LH', likelihood) - #print('LH MAX', likelihood.max()) - #print('LH MAX', likelihood.min()) - #print('NB', n_bits) - #print('BPP', bpp) return n_bits, bpp @@ -192,13 +182,13 @@ def likelihood(self, x): # Numerical stability using some sigmoid identities # to avoid subtraction of two numbers close to 1 - # sign = -torch.sign(cdf_upper + cdf_lower) - # sign = sign.detach() - # likelihood_ = torch.abs( - # torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower)) + sign = -torch.sign(cdf_upper + cdf_lower) + sign = sign.detach() + likelihood_ = torch.abs( + torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower)) # Naive - likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower) + # likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower) # Reshape to (N,C,H,W) likelihood_ = torch.reshape(likelihood_, shape) @@ -268,13 +258,13 @@ def latent_likelihood(self, x, mean, scale): # Assumes 1 - CDF(x) = CDF(-x) x = x - mean - # x = torch.abs(x) - # cdf_upper = self.standardized_CDF((0.5 - x) / scale) - # cdf_lower = self.standardized_CDF(-(0.5 + x) / scale) + x = torch.abs(x) + cdf_upper = self.standardized_CDF((0.5 - x) / scale) + cdf_lower = self.standardized_CDF(-(0.5 + x) / scale) # Naive - cdf_upper = self.standardized_CDF( (x + 0.5) / scale ) - cdf_lower = self.standardized_CDF( (x - 0.5) / scale ) + # cdf_upper = self.standardized_CDF( (x + 0.5) / scale ) + # cdf_lower = self.standardized_CDF( (x - 0.5) / scale ) likelihood_ = cdf_upper - cdf_lower likelihood_ = lower_bound_toward(likelihood_, self.min_likelihood) @@ -298,9 +288,6 @@ def forward(self, latents, spatial_shape, **kwargs): quantized_hyperlatent_bits, quantized_hyperlatent_bpp = self._estimate_entropy( quantized_hyperlatent_likelihood, spatial_shape) - #print('QUANT HL', quantized_hyperlatents) - #print('maxQUANT HL', quantized_hyperlatents.max()) - #print('minQUANT HL', quantized_hyperlatents.min()) if self.training is True: hyperlatents_decoded = noisy_hyperlatents else: @@ -343,11 +330,6 @@ def forward(self, latents, spatial_shape, **kwargs): side_bitstring=None, # TODO ) - # print(quantized_latents) - # print(quantized_hyperlatents) - # print(noisy_latents) - # print(noisy_hyperlatents) - return info class HyperpriorAnalysis(nn.Module): diff --git a/train.py b/train.py index e6f5f8f..ce1431b 100644 --- a/train.py +++ b/train.py @@ -27,7 +27,6 @@ from default_config import hific_args, mse_lpips_args, directories, ModelModes, ModelTypes # go fast boi!! -# Optimizes cuda kernels by benchmarking - no dynamic input sizes! torch.backends.cudnn.benchmark = True def create_model(args, device, logger, storage, storage_test): @@ -304,19 +303,9 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers): else: model = create_model(args, device, logger, storage, storage_test) model = model.to(device) - # amortization_parameters = itertools.chain.from_iterable( - # [am.parameters() for am in model.amortization_models]) - - amort_names, amortization_parameters = list(), list() - for n, p in model.named_parameters(): - if ('Encoder' in n) or ('Generator' in n): - amort_names.append(n) - amortization_parameters.append(p) - logger.info(f'AM {n} - {p.shape}') - if ('analysis' in n) or ('synthesis' in n): - amort_names.append(n) - amortization_parameters.append(p) - logger.info(f'AM {n} - {p.shape}') + amortization_parameters = itertools.chain.from_iterable( + [am.parameters() for am in model.amortization_models]) + hyperlatent_likelihood_parameters = model.Hyperprior.hyperlatent_likelihood.parameters() amortization_opt = torch.optim.Adam(amortization_parameters, @@ -332,6 +321,8 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers): n_gpus = torch.cuda.device_count() if n_gpus > 1 and args.multigpu is True: + # Not supported at this time + raise NotImplementedError('MultiGPU not supported yet.') logger.info('Using {} GPUs.'.format(n_gpus)) model = nn.DataParallel(model)