minor text fixes
Justin-Tan committed Aug 17, 2020
1 parent 9542165 commit 7077e9e
Showing 7 changed files with 38 additions and 95 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -18,7 +18,7 @@ JPG, 0.264 bpp / 90.1 kB
```
![guess](assets/comparison/camp_jpg_compress.png)

-The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image.
+The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images and other examples are stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image.

## Details
This repository defines a model for learnable image compression capable of compressing images of arbitrary size and resolution. There are three main components to this model, as described in the original paper:
@@ -61,7 +61,7 @@ python3 train.py --model_type compression --regime low --n_steps 1e6
```
python3 train.py --model_type compression_gan --regime low --n_steps 1e6 --warmstart --ckpt path/to/base/checkpoint
```
-* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` on average.
+* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` per image, on average using the default config.
* If you get out-of-memory errors, try:
* Reducing the number of residual blocks in the generator (default 7, the original paper used 9).
* Decreasing the batch size (default 16).
@@ -73,23 +73,24 @@ tensorboard --logdir experiments/my_experiment/tensorboard
```

### Compression
-* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image. This model will work with images of arbitrary sizes and resolution.
+* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image (as a PNG). This model will work with images of arbitrary sizes and resolution (provided you don't run out of memory). This will work with JPG and PNG (without alpha channels).
```
-python3 compress.py --img path/to/image/dir --ckpt path/to/trained/model
+python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model
```
* A pretrained model using the OpenImages dataset can be found here: [Drive link]. This model was trained for 2e5 warmup steps and 2e5 steps with the full generative loss. To use this, download the model and point the `-ckpt` argument in the command above to the corresponding path.

* The reported `bpp` is the theoretical bitrate required to losslessly store the quantized latent representation of an image as determined by the learned probability model provided by the hyperprior using some entropy coding algorithm. Comparing this (not the size of the reconstruction) against the original size of the image will give you an idea of the reduction in memory footprint. This repository does not currently support actual compression to a bitstring ([TensorFlow Compression](https://github.com/tensorflow/compression) does this well). We're working on an ANS entropy coder to support this in the future.

### Notes
* The "size" of the compressed image as reported in `bpp` does not account for the size of the model required to decode the compressed format.
-* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory.
-* You may get an OOM error when compressing images which are too large. We're working on a fix.
+* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. A complete forward pass using a batch of 10 images takes around 45s on a 2.8 GHz Intel Core i7.
+* You may get an OOM error when compressing images which are too large (`>~ 4000 x 4000`). It's possible to get around this by applying the network to evenly sized crops of the input image whose forward pass will fit in memory. We're working on a fix to automatically support this.
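
The crop-based workaround described in the last note could be sketched roughly as follows. This is an illustration only, not code from the repository: `crop_boxes` and `max_side` are hypothetical names, and running the model on each crop and stitching the reconstructions back together is left out.

```python
import math


def crop_boxes(height, width, max_side=4000):
    """Split a (height x width) image into a grid of near-equal crops.

    Returns a list of (top, left, bottom, right) boxes, each at most
    max_side pixels along either axis. max_side=4000 mirrors the rough
    limit quoted in the note above; the right value depends on available memory.
    """
    # Smallest grid whose cells all fit within max_side.
    n_rows = math.ceil(height / max_side)
    n_cols = math.ceil(width / max_side)
    tile_h = math.ceil(height / n_rows)
    tile_w = math.ceil(width / n_cols)

    boxes = []
    for top in range(0, height, tile_h):
        for left in range(0, width, tile_w):
            bottom = min(top + tile_h, height)
            right = min(left + tile_w, width)
            boxes.append((top, left, bottom, right))
    return boxes


boxes = crop_boxes(5000, 4500)
print(len(boxes), boxes[0])
```

Each box can then be passed through the model separately. Crops on the last row or column may be slightly smaller than the others when the image size is not evenly divisible.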

### Contributing
All content in this repository is licensed under the Apache-2.0 license. Feel free to submit any corrections or suggestions as issues.

### Acknowledgements
* The code under `hific/perceptual_similarity/` implementing the perceptual distortion loss is modified from the [Perceptual Similarity repository](https://github.com/richzhang/PerceptualSimilarity).
<!-- * The cat in the main image is my neighbour's. -->

### Authors
* Grace Han
7 changes: 4 additions & 3 deletions compress.py
@@ -67,7 +67,8 @@ def compress_batch(args):
input_filenames_total.extend(filenames)

for subidx in range(reconstruction.shape[0]):
-fname = os.path.join(args.output_dir, "{}_RECON.png".format(filenames[subidx]))
+bpp_per_im = float(bpp[subidx].cpu().numpy())
+fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], bpp_per_im))
torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
output_filenames_total.append(fname)

@@ -97,7 +98,7 @@ def compress_batch(args):

def main(**kwargs):

-description = "Compresses batch of images using specified learned model."
+description = "Compresses batch of images using learned model specified via -ckpt argument."
parser = argparse.ArgumentParser(description=description,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-ckpt", "--ckpt_path", type=str, required=True, help="Path to model to be restored")
@@ -106,7 +107,7 @@ def main(**kwargs):
parser.add_argument("-o", "--output_dir", type=str, default='data/reconstructions',
help="Path to directory to store output images")
parser.add_argument('-bs', '--batch_size', type=int, default=1,
-help="Dataloader batch size. Set to 1 for images of different sizes.")
+help="Loader batch size. Set to 1 if images in directory are different sizes.")
args = parser.parse_args()

input_images = glob.glob(os.path.join(args.image_dir, '*.jpg'))
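
For orientation, the command-line interface touched by these hunks can be approximated with the standalone parser below. The `-ckpt`, `-o` and `-bs` arguments are copied from the diff. The `-i`/`--image_dir` argument is not shown in any hunk, so its long name (inferred from `args.image_dir`), its `required=True` setting and its help text are assumptions.

```python
import argparse


def build_parser():
    description = "Compresses batch of images using learned model specified via -ckpt argument."
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-ckpt", "--ckpt_path", type=str, required=True,
                        help="Path to model to be restored")
    # Assumption: this argument is not visible in the diff; name and help are guessed.
    parser.add_argument("-i", "--image_dir", type=str, required=True,
                        help="Path to directory containing images to compress")
    parser.add_argument("-o", "--output_dir", type=str, default="data/reconstructions",
                        help="Path to directory to store output images")
    parser.add_argument("-bs", "--batch_size", type=int, default=1,
                        help="Loader batch size. Set to 1 if images in directory are different sizes.")
    return parser


# Equivalent to: python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model
args = build_parser().parse_args(
    ["-i", "path/to/image/dir", "-ckpt", "path/to/trained/model"])
print(args.ckpt_path, args.image_dir, args.output_dir, args.batch_size)
```

Single-dash multi-character flags such as `-ckpt` and `-bs` are accepted by `argparse` as long as they are matched exactly, which is why the README command uses `-ckpt` rather than `--ckpt`.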
10 changes: 5 additions & 5 deletions default_config.py
@@ -13,7 +13,7 @@ class ModelTypes(object):

class ModelModes(object):
TRAINING = 'training'
-VALIDATION = 'validation' # Monitoring
+VALIDATION = 'validation'
EVALUATION = 'evaluation'

class Datasets(object):
@@ -30,13 +30,13 @@ class directories(object):
experiments = 'experiments'

class checkpoints(object):
-gan1 = 'experiments/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12/checkpoints/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12_epoch1_idx56776_2020_08_14_18:43.pt'
+gan1 = 'experiments/lossless.pt'

class args(object):
"""
Shared config
"""
-name = 'hific_v0'
+name = 'hific_v0.1'
silent = True
n_epochs = 8
n_steps = 1e6
@@ -52,8 +52,8 @@ class args(object):
model_mode = ModelModes.TRAINING

# Architecture params - Table 3a) of [1]
-latent_channels = 220 #220
-n_residual_blocks = 7 #7 # Authors use 9 blocks, performance saturates at 5
+latent_channels = 220
+n_residual_blocks = 7 # Authors use 9 blocks, performance saturates at 5
lambda_B = 2**(-4) # Loose rate
k_M = 0.075 * 2**(-5) # Distortion
k_P = 1. # Perceptual loss
21 changes: 2 additions & 19 deletions src/helpers/maths.py
@@ -12,7 +12,7 @@ def backward(ctx, grad_output):
return grad_output.clone(), None


-class LowerBoundToward_0(torch.autograd.Function):
+class LowerBoundToward(torch.autograd.Function):
"""
Assumes output shape is identical to input shape.
"""
@@ -24,26 +24,9 @@ def forward(ctx, tensor, lower_bound):

@staticmethod
def backward(ctx, grad_output):
-# gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype))
-gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type_as(grad_output.data))
+gate = torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype)
return grad_output * gate, None

-class LowerBoundToward(torch.autograd.Function):
-@staticmethod
-def forward(ctx, inputs, bound):
-b = torch.ones_like(inputs) * bound
-ctx.save_for_backward(inputs, b)
-return torch.max(inputs, b)
-
-@staticmethod
-def backward(ctx, grad_output):
-inputs, b = ctx.saved_tensors
-pass_through_1 = inputs >= b
-pass_through_2 = grad_output < 0
-
-pass_through = pass_through_1 | pass_through_2
-return pass_through.type(grad_output.dtype) * grad_output, None
-
def standardized_CDF_gaussian(value):
# Gaussian
# return 0.5 * (1. + torch.erf(value/ np.sqrt(2)))
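
The gradient gating used by `LowerBoundToward` can be illustrated on plain Python scalars. This is a sketch of the idea rather than the repository's autograd code; the function names are hypothetical, and it assumes `ctx.mask` marks entries at or above the bound, as the `inputs >= b` test in the class shown above does.

```python
def lower_bound_forward(x, bound):
    # Forward pass: clamp from below, identical to max(x, bound).
    return max(x, bound)


def lower_bound_backward(x, bound, grad_output):
    # Pass the gradient through if the input was not clamped, or if the
    # gradient is negative. Under gradient descent (x <- x - lr * grad) a
    # negative gradient increases x, i.e. pushes it back toward the bound.
    pass_through = (x >= bound) or (grad_output < 0)
    return grad_output if pass_through else 0.0


print(lower_bound_backward(0.5, 1.0, -2.0))  # clamped, but update moves x up: kept
print(lower_bound_backward(0.5, 1.0, 2.0))   # clamped, update moves x further down: blocked
print(lower_bound_backward(2.0, 1.0, 2.0))   # not clamped: kept
```

A plain clamp would give zero gradient whenever the input is below the bound, so a value stuck there could never recover. The gate keeps exactly those gradients that help it recover.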
21 changes: 3 additions & 18 deletions src/model.py
@@ -22,7 +22,7 @@
Intermediates = namedtuple("Intermediates",
["input_image", # [0, 1] (after scaling from [0, 255])
"reconstruction", # [0, 1]
-"latents_quantized", # Latents post-quantization.
+"latents_quantized", # Latents post-quantization.
"n_bpp", # Differential entropy estimate.
"q_bpp"]) # Shannon entropy estimate.

@@ -134,6 +134,7 @@ def compression_forward(self, x):
total_nbpp = hyperinfo.total_nbpp
total_qbpp = hyperinfo.total_qbpp

+# Use quantized latents as input to G
reconstruction = self.Generator(latents_quantized)

if self.args.normalize_input_image is True:
@@ -160,7 +161,6 @@ def discriminator_forward(self, intermediates, train_generator):
D_in = torch.cat([x_real, x_gen], dim=0)

latents = intermediates.latents_quantized.detach()
-# latents = torch.cat([latents, latents], dim=0)
latents = torch.repeat_interleave(latents, 2, dim=0)

D_out, D_out_logits = self.Discriminator(D_in, latents)
@@ -170,14 +170,11 @@ def discriminator_forward(self, intermediates, train_generator):
D_real, D_gen = torch.chunk(D_out, 2, dim=0)
D_real_logits, D_gen_logits = torch.chunk(D_out_logits, 2, dim=0)

-# Tensorboard
-# real_response, gen_response = D_real.mean(), D_fake.mean()
-
return Disc_out(D_real, D_gen, D_real_logits, D_gen_logits)

def distortion_loss(self, x_gen, x_real):
# loss in [0,255] space but normalized by 255 to not be too big
-# - Delegate to weighting
+# - Delegate scaling to weighting
sq_err = self.squared_difference(x_gen*255., x_real*255.) # / 255.
return torch.mean(sq_err)

@@ -196,30 +193,18 @@ def compression_loss(self, intermediates, hyperinfo):
x_real = (x_real + 1.) / 2.
x_gen = (x_gen + 1.) / 2.

-# print('X REAL MAX', x_real.max())
-# print('X REAL MIN', x_real.min())
-# print('X GEN MAX', x_gen.max())
-# print('X GEN MIN', x_gen.min())
-
distortion_loss = self.distortion_loss(x_gen, x_real)
perceptual_loss = self.perceptual_loss_wrapper(x_gen, x_real, normalize=True)

weighted_distortion = self.args.k_M * distortion_loss
weighted_perceptual = self.args.k_P * perceptual_loss

-# print('Distortion loss size', weighted_distortion.size())
-# print('Perceptual loss size', weighted_perceptual.size())
-
weighted_rate, rate_penalty = losses.weighted_rate_loss(self.args, total_nbpp=intermediates.n_bpp,
total_qbpp=intermediates.q_bpp, step_counter=self.step_counter)

-# print('Weighted rate loss size', weighted_rate.size())
weighted_R_D_loss = weighted_rate + weighted_distortion
weighted_compression_loss = weighted_R_D_loss + weighted_perceptual

-# print('Weighted R-D loss size', weighted_R_D_loss.size())
-# print('Weighted compression loss size', weighted_compression_loss.size())
-
# Bookkeeping
if (self.step_counter % self.log_interval == 1):
self.store_loss('rate_penalty', rate_penalty)
40 changes: 11 additions & 29 deletions src/network/hyperprior.py
@@ -45,17 +45,15 @@ def _quantize(self, x, mode='noise', means=None):

if mode == 'noise':
quantization_noise = torch.nn.init.uniform_(torch.zeros_like(x), -0.5, 0.5)
-# quantization_noise = torch.rand(x.size()).to(x) - 0.5
x = x + quantization_noise
-elif mode == 'quantize':

+elif mode == 'quantize':
if means is not None:
x = x - means
x = torch.floor(x + 0.5)
x = x + means
else:
x = torch.floor(x + 0.5)
-# x = torch.round(x)
else:
raise NotImplementedError
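
The two branches of `_quantize` can be mirrored on a single scalar to show what each mode does. This is a plain-Python sketch with hypothetical names, not the tensor code from the repository.

```python
import math
import random


def quantize_scalar(x, mode="noise", mean=None, rng=random):
    """Scalar illustration of the 'noise' and 'quantize' modes."""
    if mode == "noise":
        # Training-time relaxation: additive uniform noise in [-0.5, 0.5].
        return x + rng.uniform(-0.5, 0.5)
    if mode == "quantize":
        # Round to the nearest integer offset from the mean;
        # floor(v + 0.5) rounds exact halves upward.
        offset = 0.0 if mean is None else mean
        return math.floor(x - offset + 0.5) + offset
    raise NotImplementedError(mode)


print(quantize_scalar(2.5, "quantize"))            # halves round up
print(quantize_scalar(2.3, "quantize", mean=0.4))  # lands on mean + integer
print(quantize_scalar(1.0, "noise"))               # somewhere in [0.5, 1.5]
```

One visible difference from `torch.round`, which the removed comment referred to, is the treatment of exact halves: `floor(x + 0.5)` always rounds them up, whereas `torch.round` rounds them to the nearest even integer.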

@@ -71,16 +69,8 @@ def _estimate_entropy(self, likelihood, spatial_shape):
n_pixels = np.prod(spatial_shape)

log_likelihood = torch.log(likelihood + EPS)
-# print('LOG LIKELIHOOD', log_likelihood.mean().item())
n_bits = torch.sum(log_likelihood) / (batch_size * quotient)
bpp = n_bits / n_pixels
-# print('N_PIXELS', n_pixels)
-# print('BATCH SIZE', batch_size)
-# print('LH', likelihood)
-#print('LH MAX', likelihood.max())
-#print('LH MAX', likelihood.min())
-#print('NB', n_bits)
-#print('BPP', bpp)

return n_bits, bpp
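
The arithmetic in `_estimate_entropy` amounts to summing the ideal code length of every latent symbol and dividing by the number of image pixels. The sketch below restates it in plain Python. It assumes that `quotient`, which is defined outside this hunk, converts natural logarithms to bits (that is, equals `-ln 2`); `estimate_bpp` is a hypothetical name.

```python
import math

EPS = 1e-9  # guards against log(0), like EPS in the hunk above


def estimate_bpp(likelihoods, n_pixels, batch_size=1):
    """Return (bits per image, bits per pixel) for a flat list of likelihoods."""
    # Ideal code length of a symbol with probability p is -log2(p) bits.
    total_bits = -sum(math.log2(p + EPS) for p in likelihoods)
    n_bits = total_bits / batch_size
    return n_bits, n_bits / n_pixels


# Four latent symbols, each with probability 0.5, for one 2 x 2 image.
n_bits, bpp = estimate_bpp([0.5] * 4, n_pixels=4)
print(round(n_bits, 3), round(bpp, 3))
```

This is the theoretical rate the README refers to: an entropy coder driven by the same probability model would approach it, but no bitstring is actually produced.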

@@ -192,13 +182,13 @@ def likelihood(self, x):

# Numerical stability using some sigmoid identities
# to avoid subtraction of two numbers close to 1
-# sign = -torch.sign(cdf_upper + cdf_lower)
-# sign = sign.detach()
-# likelihood_ = torch.abs(
-# torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower))
+sign = -torch.sign(cdf_upper + cdf_lower)
+sign = sign.detach()
+likelihood_ = torch.abs(
+torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower))

# Naive
-likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower)
+# likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower)

# Reshape to (N,C,H,W)
likelihood_ = torch.reshape(likelihood_, shape)
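
The reason for preferring the sign-flipped form over the naive difference shows up already with double-precision scalars. The sketch below uses the identity sigmoid(-v) = 1 - sigmoid(v) to evaluate both terms in the tail where they are small rather than close to 1. Function names are hypothetical. Unlike `torch.sign`, the sketch never produces a zero sign, so the case `upper + lower == 0` falls back to the naive form.

```python
import math


def sigmoid(v):
    # Overflow-safe logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def naive_interval(upper, lower):
    return sigmoid(upper) - sigmoid(lower)


def stable_interval(upper, lower):
    # If both arguments sit in the right tail, flip them into the left tail,
    # where sigmoid values are tiny and their difference keeps its precision.
    sign = -1.0 if (upper + lower) > 0 else 1.0
    return abs(sigmoid(sign * upper) - sigmoid(sign * lower))


print(naive_interval(40.5, 39.5))   # both terms round to 1.0, difference is lost
print(stable_interval(40.5, 39.5))  # small but non-zero
```

For arguments near zero the two forms agree, so the stable version only changes the result where the naive one would collapse to zero and send the log-likelihood to its floor.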
@@ -268,13 +258,13 @@ def latent_likelihood(self, x, mean, scale):

# Assumes 1 - CDF(x) = CDF(-x)
x = x - mean
-# x = torch.abs(x)
-# cdf_upper = self.standardized_CDF((0.5 - x) / scale)
-# cdf_lower = self.standardized_CDF(-(0.5 + x) / scale)
+x = torch.abs(x)
+cdf_upper = self.standardized_CDF((0.5 - x) / scale)
+cdf_lower = self.standardized_CDF(-(0.5 + x) / scale)

# Naive
-cdf_upper = self.standardized_CDF( (x + 0.5) / scale )
-cdf_lower = self.standardized_CDF( (x - 0.5) / scale )
+# cdf_upper = self.standardized_CDF( (x + 0.5) / scale )
+# cdf_lower = self.standardized_CDF( (x - 0.5) / scale )

likelihood_ = cdf_upper - cdf_lower
likelihood_ = lower_bound_toward(likelihood_, self.min_likelihood)
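
The switch from the naive CDF difference to the `abs`-based form relies on the symmetry noted in the comment, 1 - CDF(x) = CDF(-x). A scalar sketch under the assumption of a standard Gaussian CDF (the helper names are hypothetical, and the repository's `standardized_CDF` may be implemented differently):

```python
import math


def std_normal_cdf(v):
    # Standard normal CDF written with erfc, which stays accurate in the left tail.
    return 0.5 * math.erfc(-v / math.sqrt(2.0))


def naive_likelihood(x, mean, scale):
    d = x - mean
    return std_normal_cdf((d + 0.5) / scale) - std_normal_cdf((d - 0.5) / scale)


def tail_likelihood(x, mean, scale):
    # Reflect onto the left tail: both CDF values are small there,
    # so their difference does not cancel to zero.
    d = abs(x - mean)
    return std_normal_cdf((0.5 - d) / scale) - std_normal_cdf(-(0.5 + d) / scale)


print(naive_likelihood(10.0, 0.0, 1.0))  # both CDFs round to 1.0
print(tail_likelihood(10.0, 0.0, 1.0))   # tiny but non-zero
```

Both functions compute the probability mass of the unit-width bin around `x`. They differ only in which tail the subtraction is carried out in, which matters for latents far from the predicted mean.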
@@ -298,9 +288,6 @@ def forward(self, latents, spatial_shape, **kwargs):
quantized_hyperlatent_bits, quantized_hyperlatent_bpp = self._estimate_entropy(
quantized_hyperlatent_likelihood, spatial_shape)

-#print('QUANT HL', quantized_hyperlatents)
-#print('maxQUANT HL', quantized_hyperlatents.max())
-#print('minQUANT HL', quantized_hyperlatents.min())
if self.training is True:
hyperlatents_decoded = noisy_hyperlatents
else:
@@ -343,11 +330,6 @@ def forward(self, latents, spatial_shape, **kwargs):
side_bitstring=None, # TODO
)

-# print(quantized_latents)
-# print(quantized_hyperlatents)
-# print(noisy_latents)
-# print(noisy_hyperlatents)
-
return info

class HyperpriorAnalysis(nn.Module):
19 changes: 5 additions & 14 deletions train.py
@@ -27,7 +27,6 @@
from default_config import hific_args, mse_lpips_args, directories, ModelModes, ModelTypes

# go fast boi!!
-# Optimizes cuda kernels by benchmarking - no dynamic input sizes!
torch.backends.cudnn.benchmark = True

def create_model(args, device, logger, storage, storage_test):
@@ -304,19 +303,9 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers):
else:
model = create_model(args, device, logger, storage, storage_test)
model = model.to(device)
-# amortization_parameters = itertools.chain.from_iterable(
-# [am.parameters() for am in model.amortization_models])
-
-amort_names, amortization_parameters = list(), list()
-for n, p in model.named_parameters():
-if ('Encoder' in n) or ('Generator' in n):
-amort_names.append(n)
-amortization_parameters.append(p)
-logger.info(f'AM {n} - {p.shape}')
-if ('analysis' in n) or ('synthesis' in n):
-amort_names.append(n)
-amortization_parameters.append(p)
-logger.info(f'AM {n} - {p.shape}')
+amortization_parameters = itertools.chain.from_iterable(
+[am.parameters() for am in model.amortization_models])
+
hyperlatent_likelihood_parameters = model.Hyperprior.hyperlatent_likelihood.parameters()

amortization_opt = torch.optim.Adam(amortization_parameters,
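
One practical difference between the removed list-building loop and `itertools.chain.from_iterable` is that the chained object is a single-use iterator. The sketch below shows this with a hypothetical stand-in class instead of real `torch.nn.Module` objects.

```python
import itertools


class StubModule:
    """Hypothetical stand-in for a module exposing a parameters() generator."""

    def __init__(self, names):
        self._names = names

    def parameters(self):
        yield from self._names


amortization_models = [StubModule(["enc.w", "enc.b"]), StubModule(["gen.w"])]

params = itertools.chain.from_iterable(
    [am.parameters() for am in amortization_models])

first_pass = list(params)   # consumes the iterator
second_pass = list(params)  # nothing left

print(first_pass)
print(second_pass)
```

Handing such an iterator to one optimizer is fine, since `torch.optim` optimizers turn their `params` argument into a list on construction. Any later code that wants to log or count the same parameters would need its own list or a fresh chain.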
@@ -332,6 +321,8 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers):

n_gpus = torch.cuda.device_count()
if n_gpus > 1 and args.multigpu is True:
+# Not supported at this time
+raise NotImplementedError('MultiGPU not supported yet.')
logger.info('Using {} GPUs.'.format(n_gpus))
model = nn.DataParallel(model)
