In [1]:
import os
import sys

os.chdir(os.path.dirname(os.getcwd()))
sys.path.append(os.getcwd())

import torch
from lossless.component.coolchic import (CoolChicEncoder,
                                         CoolChicEncoderParameter)
from lossless.util.config import args, str_args
from lossless.util.image_loading import load_image_as_tensor
from lossless.util.logger import TrainingLogger
from lossless.util.parsecli import (change_n_out_synth,
                                    get_coolchic_param_from_args)

image_index = 0
print(args["input"])
use_color_regression = False
use_image_arm = True
encoder_gain = 16

im_path = args["input"][image_index]
im_tensor, c_bitdepths = load_image_as_tensor(im_path, device="cuda:0")
dataset = im_path.split("/")[-2]

logger = TrainingLogger(
    log_folder_path=args["LOG_PATH"],
    image_name=f"{dataset}_" + im_path.split("/")[-1].split(".")[0],
)
logger.log_result(f"{str_args(args)}")
logger.log_result(f"Processing image {im_path}")

encoder_param = CoolChicEncoderParameter(
    **get_coolchic_param_from_args(args, "lossless")
)
encoder_param.use_image_arm = use_image_arm
encoder_param.use_color_regression = use_color_regression
encoder_param.encoder_gain = encoder_gain
encoder_param.set_image_size((im_tensor.shape[2], im_tensor.shape[3]))
encoder_param.layers_synthesis = change_n_out_synth(
    encoder_param.layers_synthesis, 9 if use_color_regression else 6
)
coolchic = CoolChicEncoder(param=encoder_param)
coolchic.to_device("cuda:0")

['/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim01.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim02.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim03.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim04.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim05.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim06.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim07.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim08.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim09.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim10.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim11.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim12.png', '/home/jakub/ETH/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim13.png', '/home/jakub/ETH/thesis/

In [2]:
# print(coolchic.get_flops())
# print(coolchic.flops_str())
with torch.no_grad():
    print(coolchic.get_total_mac_per_pixel())
    print(coolchic.str_complexity())

1694.3125
| module                            | #parameters or shape   | #flops     |
|:----------------------------------|:-----------------------|:-----------|
| model                             | 0.526M                 | 0.666G     |
|  latent_grids                     |  0.524M                |            |
|   latent_grids.0                  |   (1, 1, 512, 768)     |            |
|   latent_grids.1                  |   (1, 1, 256, 384)     |            |
|   latent_grids.2                  |   (1, 1, 128, 192)     |            |
|   latent_grids.3                  |   (1, 1, 64, 96)       |            |
|   latent_grids.4                  |   (1, 1, 32, 48)       |            |
|   latent_grids.5                  |   (1, 1, 16, 24)       |            |
|   latent_grids.6                  |   (1, 1, 8, 12)        |            |
|  synthesis.layers                 |  0.57K                 |  0.208G    |
|   synthesis.layers.0              |   0.192K               |   66.06M   |
| 