In [45]:
import os
import sys



# Use the parent of the current working directory (presumably the notebook's
# folder) as the project root: make it the working directory and put it on
# sys.path so that the `lossless` package imported below can be found there.
# The root is computed only once per kernel. Re-running this cell therefore
# does NOT climb one more directory each time, and sys.path never gets
# duplicate entries.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(PROJECT_ROOT)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
from lossless.component.core.arm_image import ImageARMParameter
from lossless.component.coolchic import (CoolChicEncoder,
                                         CoolChicEncoderParameter)
from lossless.configs.config import args, str_args
from lossless.util.image_loading import load_image_as_tensor
from lossless.util.logger import TrainingLogger
from lossless.util.parsecli import (change_n_out_synth,
                                    get_coolchic_param_from_args)
from lossless.util.command_line_args_loading import load_args

# --- Experiment configuration ------------------------------------------------
image_index = 0
use_image_arm = True
encoder_gain = 64
# Single device for both the input image and the encoder so they cannot diverge.
DEVICE = "cuda:0"

# Notebook-level settings handed to load_args as overrides
# (presumably taking precedence over CLI/config defaults — confirm in load_args).
command_line_args = load_args(notebook_overrides = {
    "image_index": image_index,
    "encoder_gain": encoder_gain,
    "color_space": "YCoCg",
    "use_image_arm": use_image_arm,
    "multiarm_setup": "1x1",
})

# --- Load the selected input image --------------------------------------------
im_path = args["input"][command_line_args.image_index]
im_tensor, colorspace_bitdepths = load_image_as_tensor(
    im_path, device=DEVICE, color_space=command_line_args.color_space
)
# Spatial size from dims 2 and 3 of the tensor
# (assumes (batch, channels, height, width) layout — TODO confirm).
image_size = (im_tensor.shape[2], im_tensor.shape[3])

# Logging is currently disabled; re-enable to write per-image result logs.
# dataset = im_path.split("/")[-2]
# logger = TrainingLogger(
#     log_folder_path=args["LOG_PATH"],
#     image_name=f"{dataset}_" + im_path.split("/")[-1].split(".")[0],
# )
# # logger.log_result(f"{str_args(args)}")
# print(logger.results_logs_path)
# logger.log_result(f"Processing image {im_path}")

# --- Architecture overrides ------------------------------------------------------
# NOTE: these mutate the shared `args` config dict in place, so the values are
# visible to any later cell (or module) that reads `args`.
#
# Image ARM: the ORIGINAL configuration is the active one.
args["arm_image_params"] = ImageARMParameter(context_size=8, n_hidden_layers=2, hidden_layer_dim=6)
# Alternative ARM that was tried (inactive):
# args["arm_image_params"] = ImageARMParameter(context_size=16, n_hidden_layers=2, hidden_layer_dim=48)

# Lossless synthesis: the CHANGED configuration is the active one.
# ORIGINAL value, kept for reference only (inactive):
# "24-1-1-linear-relu,X-1-1-linear-none,X-3-3-residual-relu,X-3-3-residual-none"
args["layers_synthesis_lossless"] = "26-3-1-linear-relu,26-3-1-residual-none,26-3-1-residual-relu,X-3-1-linear-none"

# --- Build the encoder ---------------------------------------------------------------
encoder_param = get_coolchic_param_from_args(
    args,
    "lossless",
    image_size=image_size,
    use_image_arm=command_line_args.use_image_arm,
    encoder_gain=command_line_args.encoder_gain,
    multi_region_image_arm_setup=command_line_args.multiarm_setup,
)
# Explicitly re-apply the notebook values and image size on the parameter
# object (presumably redundant with the call above — kept to be safe; confirm).
encoder_param.use_image_arm = use_image_arm
encoder_param.encoder_gain = encoder_gain
encoder_param.set_image_size(image_size)
# Presumably sets the number of synthesis outputs: 9 when color regression is
# enabled, otherwise 6 (meaning defined by the lossless package — confirm).
encoder_param.layers_synthesis = change_n_out_synth(
    encoder_param.layers_synthesis, 9 if args['use_color_regression'] else 6
)
coolchic = CoolChicEncoder(param=encoder_param)
coolchic.to_device(DEVICE)

In [46]:
# Inspect model complexity. No gradients are needed, so autograd tracking is
# disabled for the duration of the report.
with torch.no_grad():
    mac_per_pixel = coolchic.get_total_mac_per_pixel()
    print(mac_per_pixel)
    complexity_table = coolchic.str_complexity()
    print(complexity_table)

16376.3125
| module                                                     | #parameters or shape   | #flops     |
|:-----------------------------------------------------------|:-----------------------|:-----------|
| model                                                      | 0.541M                 | 6.439G     |
|  latent_grids                                              |  0.524M                |            |
|   latent_grids.0                                           |   (1, 1, 512, 768)     |            |
|   latent_grids.1                                           |   (1, 1, 256, 384)     |            |
|   latent_grids.2                                           |   (1, 1, 128, 192)     |            |
|   latent_grids.3                                           |   (1, 1, 64, 96)       |            |
|   latent_grids.4                                           |   (1, 1, 32, 48)       |            |
|   latent_grids.5                                           |   (1, 1, 16, 24) 