In [12]:
import os
import sys


# When the kernel starts inside the "analysis" folder, move up to its parent
# directory and put that directory on sys.path.
_start_dir = os.getcwd()
if os.path.basename(_start_dir) == "analysis":
    os.chdir(os.path.dirname(_start_dir))
    sys.path.append(os.getcwd())

import numpy as np
import torch
from lossless.component.coolchic import CoolChicEncoder, CoolChicEncoderParameter
from lossless.nnquant.quantizemodel import quantize_model
from lossless.training.loss import loss_function
from lossless.training.manager import ImageEncoderManager
from lossless.util.config import args
from lossless.util.image_loading import load_image_as_tensor
from lossless.util.parsecli import (
    change_n_out_synth,
    get_coolchic_param_from_args,
    get_manager_from_args,
)

# Global torch settings: autograd anomaly detection on, float32 matmul precision "high".
torch.autograd.set_detect_anomaly(True)
torch.set_float32_matmul_precision("high")

# Directory whose entries are the trained model files evaluated below.
model_location_dir = "../logs_cluster/logs/full_runs/21_11_2025_YCoCg_with_fixed_colorregression/trained_models"


def _kodim_number(file_name):
    """Return the integer that directly follows "_kodim" in a file name."""
    after_tag = file_name.split("_kodim")[1]
    return int(after_tag.split("_")[0])


# Order the files numerically by their kodim number rather than lexicographically.
model_paths = sorted(os.listdir(model_location_dir), key=_kodim_number)

# Evaluate every trained model on its image and collect one result row per image
# in `table` (consumed by the LaTeX-table and averaging cells).
device = "cuda:0"
color_space = "YCoCg"
use_image_arm = True

table = []
for img_index, model_name in enumerate(model_paths):
    print(f"Using model: {model_name}")

    # NOTE(review): checkpoint i is paired with args["input"][i]; this assumes the
    # input list is ordered like the sorted checkpoints — confirm.
    im_path = args["input"][img_index]
    im_tensor, c_bitdepths = load_image_as_tensor(im_path, device=device, color_space=color_space)

    # ==========================================================================================
    # LOAD PRESETS, COOLCHIC PARAMETERS
    # ==========================================================================================
    image_encoder_manager = ImageEncoderManager(**get_manager_from_args(args))
    encoder_param = CoolChicEncoderParameter(**get_coolchic_param_from_args(args, "lossless"))
    encoder_param.set_image_size((im_tensor.shape[2], im_tensor.shape[3]))
    encoder_param.layers_synthesis = change_n_out_synth(
        encoder_param.layers_synthesis, args["output_dim_size"]
    )
    encoder_param.use_image_arm = use_image_arm
    coolchic = CoolChicEncoder(param=encoder_param)
    coolchic.to_device(device)
    # map_location keeps loading independent of the device the checkpoint was saved from.
    # NOTE(review): the last recorded run of this cell failed in load_state_dict with
    # size mismatches (e.g. synthesis layers with 9 outputs in the checkpoint vs 6 in
    # the model built here), so `args` must match the configuration used for training.
    state_dict = torch.load(os.path.join(model_location_dir, model_name), map_location=device)
    coolchic.load_state_dict(state_dict)

    # ==========================================================================================
    # QUANTIZE AND EVALUATE
    # ==========================================================================================
    quantized_coolchic = CoolChicEncoder(param=encoder_param)
    quantized_coolchic.to_device(device)
    quantized_coolchic.set_param(coolchic.get_param())
    # The quantization step is what produces `rate_per_module` and
    # `total_network_rate`, both of which are required further down; with it
    # commented out the rest of the loop could only raise NameError.
    # NOTE(review): the original comment said quantization is technically not needed
    # for an uncompressed model — confirm this is the intended way to obtain the rate.
    quantized_coolchic = quantize_model(
        quantized_coolchic,
        im_tensor,
        image_encoder_manager,
        None,  # type:ignore
        color_bitdepths=c_bitdepths,
    )
    rate_per_module, total_network_rate = quantized_coolchic.get_network_rate()

    # Bits needed to store the image ARM parameters, assuming float32 (32 bits each).
    arm_params_bits = sum(p.numel() for p in quantized_coolchic.image_arm.parameters()) * 32
    print(f"ARM parameter bits per image element: {arm_params_bits / im_tensor.numel()}")

    # Add the ARM bits to the network rate and normalise by the number of image elements.
    total_network_rate += arm_params_bits
    total_network_rate /= im_tensor.numel()
    total_network_rate = float(total_network_rate)

    with torch.no_grad():
        # Forward pass with no quantization noise
        predicted_prior = quantized_coolchic.forward(
            image=im_tensor,
            quantizer_noise_type="none",
            quantizer_type="hardround",
            AC_MAX_VAL=-1,
            flag_additional_outputs=False,
        )
        predicted_priors_rates = loss_function(
            predicted_prior,
            im_tensor,
            rate_mlp_bpd=total_network_rate,
            latent_multiplier=1.0,
            channel_ranges=c_bitdepths,
        )
    print(
        f"Rate per module: {rate_per_module},\n",
        f"Final results after quantization: {predicted_priors_rates}"
    )
    table.append(
        {
            "Index": img_index,
            "Loss": predicted_priors_rates.loss.cpu().item(),
            "Rate NN": predicted_priors_rates.rate_nn_bpd,
            "Rate Latent": predicted_priors_rates.rate_latent_bpd,
            "Rate Img": predicted_priors_rates.rate_img_bpd,
        }
    )

Using model: 2025_11_22__16_44_29__trained_coolchic_kodak_kodim01_img_rate_3.2631609439849854.pth
Converting image to YCoCg color space


RuntimeError: Error(s) in loading state_dict for CoolChicEncoder:
	size mismatch for synthesis.layers.2.weight: copying a param with shape torch.Size([9, 24, 1, 1]) from checkpoint, the shape in current model is torch.Size([6, 24, 1, 1]).
	size mismatch for synthesis.layers.2.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([6]).
	size mismatch for synthesis.layers.4.weight: copying a param with shape torch.Size([9, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 2, 3, 3]).
	size mismatch for synthesis.layers.4.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([6]).
	size mismatch for synthesis.layers.6.weight: copying a param with shape torch.Size([9, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 2, 3, 3]).
	size mismatch for synthesis.layers.6.bias: copying a param with shape torch.Size([9]) from checkpoint, the shape in current model is torch.Size([6]).
	size mismatch for image_arm.models.0.0.weight: copying a param with shape torch.Size([6, 33]) from checkpoint, the shape in current model is torch.Size([6, 30]).
	size mismatch for image_arm.models.1.0.weight: copying a param with shape torch.Size([6, 34]) from checkpoint, the shape in current model is torch.Size([6, 31]).
	size mismatch for image_arm.models.1.4.weight: copying a param with shape torch.Size([6, 6]) from checkpoint, the shape in current model is torch.Size([4, 6]).
	size mismatch for image_arm.models.1.4.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([4]).
	size mismatch for image_arm.models.2.0.weight: copying a param with shape torch.Size([6, 35]) from checkpoint, the shape in current model is torch.Size([6, 32]).
	size mismatch for image_arm.models.2.4.weight: copying a param with shape torch.Size([8, 6]) from checkpoint, the shape in current model is torch.Size([4, 6]).
	size mismatch for image_arm.models.2.4.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([4]).

In [4]:
def dict_list_to_latex_table(rows, floatfmt="{:.3f}"):
    """
    Convert a list of dicts into a LaTeX ``tabular`` string.

    The column order is taken from the keys of the first row. Column names and
    all non-float cell values are converted with ``str`` and have LaTeX special
    characters (backslash, ampersand, percent, dollar, hash, underscore, braces,
    tilde, caret) escaped so they appear literally. Float cells are rendered
    with ``floatfmt`` and inserted verbatim, so the format string may itself
    contain LaTeX markup.

    Args:
        rows (list of dict): All dicts must have the same keys.
        floatfmt (str): Format string for floats, e.g. "{:.4f}".

    Returns:
        str: LaTeX table, or an empty string when ``rows`` is empty.

    Raises:
        KeyError: If a row lacks one of the keys of the first row.
    """
    if not rows:
        return ""

    # Column names taken from dict keys
    cols = list(rows[0].keys())

    # One translation table applied in a single pass, so the braces that belong
    # to a replacement (e.g. "\textbackslash{}") are never escaped a second time.
    special_chars = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    def escape(s):
        # str() first so non-string keys/values (ints, None, ...) are accepted.
        return str(s).translate(special_chars)

    def render_cell(value):
        if isinstance(value, float):
            return floatfmt.format(value)
        return escape(value)

    header = " & ".join(escape(c) for c in cols) + r" \\"

    # Build table rows
    body_lines = [
        " & ".join(render_cell(row[c]) for c in cols) + r" \\"
        for row in rows
    ]
    body = "\n".join(body_lines)

    # Combine into a LaTeX table: one left-aligned column per key.
    latex = (
        "\\begin{tabular}{%s}\n" % ("l" * len(cols)) +
        header + "\n\\hline\n" +
        body + "\n\\end{tabular}"
    )
    return latex

# Render the collected result rows as a LaTeX tabular and print it under a heading.
latex_table = dict_list_to_latex_table(table, floatfmt="{:.3f}")
print("\nLaTeX Table:\n", latex_table, sep="\n")


LaTeX Table:

\begin{tabular}{lllll}
Index & Loss & Rate NN & Rate Latent & Rate Img \\
\hline
0 & 4.178 & 0.029 & 0.371 & 3.778 \\
\end{tabular}


In [5]:
def _mean_of(column):
    """Return the mean of one column over all rows collected in `table`."""
    return np.mean([entry[column] for entry in table])


# Per-metric averages over every row in `table`.
average_loss = _mean_of("Loss")
average_rate_nn = _mean_of("Rate NN")
average_rate_latent = _mean_of("Rate Latent")
average_rate_img = _mean_of("Rate Img")

print("\nAverages:")
print(f"Average Loss: {average_loss:.6f}")
print(f"Average Rate NN: {average_rate_nn:.6f}")
print(f"Average Rate Latent: {average_rate_latent:.6f}")
print(f"Average Rate Img: {average_rate_img:.6f}")


Averages:
Average Loss: 4.178029
Average Rate NN: 0.028973
Average Rate Latent: 0.370850
Average Rate Img: 3.778206
