In [1]:
import os
import sys
import glob

sys.path.append(os.getcwd())

import torch
from lossless.component.coolchic import CoolChicEncoderParameter
from lossless.component.frame import load_frame_encoder
from lossless.component.types import NAME_COOLCHIC_ENC
from lossless.component.image import (
    FrameEncoderManager,
    encode_one_frame,
)
from enc.utils.codingstructure import CodingStructure, Frame
from typing import Any, Dict, List
from lossless.component.coolchic import CoolChicEncoder

DATASET_PATH = f"{os.getcwd()}/../datasets/kodak"
IMAGE_PATHS = sorted(
    glob.glob(f"{DATASET_PATH}/*.png"),
    key=lambda x: int(os.path.basename(x).split(".")[0][len("kodim") :]),
)
TEST_WORKDIR = f"{os.getcwd()}/test-workdir/"
PATH_COOL_CHIC_CFG = f"{os.getcwd()}/../cfg/"
print(IMAGE_PATHS)

args = {
    # not in config files
    "input": IMAGE_PATHS[0],
    "output": TEST_WORKDIR + "/output",
    "workdir": TEST_WORKDIR,
    "lmbda": 1e-3,
    "job_duration_min": -1,
    "print_detailed_archi": False,
    "print_detailed_struct": False,
    # config file paths
    # encoder side
    "start_lr": 1e-2,
    "n_itr": 1,
    "n_itr_pretrain_motion": 1,
    "n_train_loops": 1,
    "preset": "debug",
    # decoder side
    "layers_synthesis_residue": "16-1-linear-relu,X-1-linear-none,X-3-residual-relu,X-3-residual-none",
    "arm_residue": "8,2",
    "n_ft_per_res_residue": "1,1,1,1,1,1,1",
    "ups_k_size_residue": 8,
    "ups_preconcat_k_size_residue": 7,
}

print(args)
print("----------")
# os.chdir(args["workdir"])

start_print = (
    "\n\n"
    "*----------------------------------------------------------------------------------------------------------*\n"
    "|                                                                                                          |\n"
    "|                                                                                                          |\n"
    "|       ,gggg,                                                                                             |\n"
    '|     ,88"""Y8b,                           ,dPYb,                             ,dPYb,                       |\n'
    "|    d8\"     `Y8                           IP'`Yb                             IP'`Yb                       |\n"
    "|   d8'   8b  d8                           I8  8I                             I8  8I      gg               |\n"
    "|  ,8I    \"Y88P'                           I8  8'                             I8  8'      \"\"               |\n"
    "|  I8'             ,ggggg,      ,ggggg,    I8 dP      aaaaaaaa        ,gggg,  I8 dPgg,    gg     ,gggg,    |\n"
    '|  d8             dP"  "Y8ggg  dP"  "Y8ggg I8dP       """"""""       dP"  "Yb I8dP" "8I   88    dP"  "Yb   |\n'
    "|  Y8,           i8'    ,8I   i8'    ,8I   I8P                      i8'       I8P    I8   88   i8'         |\n"
    "|  `Yba,,_____, ,d8,   ,d8'  ,d8,   ,d8'  ,d8b,_                   ,d8,_    _,d8     I8,_,88,_,d8,_    _   |\n"
    '|    `"Y8888888 P"Y8888P"    P"Y8888P"    8P\'"Y88                  P""Y8888PP88P     `Y88P""Y8P""Y8888PP   |\n'
    "|                                                                                                          |\n"
    "|                                                                                                          |\n"
    "| version 4.1.0, July 2025                                                              © 2023-2025 Orange |\n"
    "*----------------------------------------------------------------------------------------------------------*\n"
)
print(start_print)

['/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim01.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim02.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim03.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim04.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim05.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim06.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim07.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim08.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim09.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim10.png', '/home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim11.png', '/home/ja

In [2]:
# UTIL CODE

def pretty_str_dict(d: dict[str, Any]) -> str:
    if not d:
        return ""
    
    # Find length of the longest key
    max_key_len = max(len(k) for k in d.keys())
    
    lines = []
    for key, value in d.items():
        # Pad key so values align, ensure at least one space after colon
        lines.append(f"{key}:{' ' * (max_key_len - len(key) + 1)}{value}")
    
    return "\n".join(lines)

In [3]:
# COOL CHICK PARAMETER PARSER CODE
def parse_synthesis_layers(layers_synthesis: str) -> List[str]:
    """The layers of the synthesis are presented in as a coma-separated string.
    This simply splits up the different substrings and return them.

    Args:
        layers_synthesis (str): Command line argument for the synthesis.

    Returns:
        List[str]: List of string where the i-th element described the i-th
            synthesis layer
    """
    parsed_layer_synth = [x for x in layers_synthesis.split(",") if x != ""]

    assert parsed_layer_synth, (
        "Synthesis should have at least one layer, found nothing. \n"
        f"--layers_synthesis={layers_synthesis} does not work!\n"
        "Try something like 32-1-linear-relu,X-1-linear-none,"
        "X-3-residual-relu,X-3-residual-none"
    )

    return parsed_layer_synth


def parse_n_ft_per_res(n_ft_per_res: str) -> list[int]:
    """The number of feature per resolution is a coma-separated string.
    This simply splits up the different substrings and return them.

    Args:
        n_ft_per_res (str): Something like "1,1,1,1,1,1,1" for 7 latent grids
        with different resolution and 1 feature each.

    Returns:
        List[int]: The i-th element is the number of features for the i-th
        latent, i.e. the latent of a resolution (H / 2^i, W / 2^i).
    """

    n_ft_per_res_int = [int(x) for x in n_ft_per_res.split(",") if x != ""]
    # assert set(n_ft_per_res) == {
    #     1
    # }, f"--n_ft_per_res should only contains 1. Found {n_ft_per_res}"
    return n_ft_per_res_int


def parse_arm_archi(arm: str) -> Dict[str, int]:
    """The arm is described as <dim_arm>,<n_hidden_layers_arm>.
    Split up this string to return the value as a dict.

    Args:
        arm (str): Command line argument for the ARM.

    Returns:
        Dict[str, int]: The ARM architecture
    """
    assert len(arm.split(",")) == 2, (
        f"--arm format should be X,Y." f" Found {arm}"
    )

    dim_arm, n_hidden_layers_arm = [int(x) for x in arm.split(",")]
    arm_param = {"dim_arm": dim_arm, "n_hidden_layers_arm": n_hidden_layers_arm}
    return arm_param


def get_coolchic_param_from_args(
    args: dict,
    coolchic_enc_name: str,
) -> Dict[str, Any]:
    layers_synthesis = parse_synthesis_layers(
        args[f"layers_synthesis_{coolchic_enc_name}"]
    )
    n_ft_per_res = parse_n_ft_per_res(args[f"n_ft_per_res_{coolchic_enc_name}"])

    coolchic_param = {
        "layers_synthesis": layers_synthesis,
        "n_ft_per_res": n_ft_per_res,
        "ups_k_size": args[f"ups_k_size_{coolchic_enc_name}"],
        "ups_preconcat_k_size": args[
            f"ups_preconcat_k_size_{coolchic_enc_name}"
        ],
    }

    # Add ARM parameters
    coolchic_param.update(parse_arm_archi(args[f"arm_{coolchic_enc_name}"]))

    return coolchic_param

def change_n_out_synth(layers_synth: List[str], n_out: int) -> List[str]:
        """Change the number of output features in the list of strings
        describing the synthesis architecture. It replaces "X" with n_out. E.g.

        From [8-1-linear-relu,X-1-linear-none,X-3-residual-none]
        To   [8-1-linear-relu,2-1-linear-none,2-3-residual-none]

        If n_out = 2

        Args:
            layers_synth (List[str]): List of strings describing the different
                synthesis layers
            n_out (int): Number of desired output.

        Returns:
            List[str]: List of strings with the proper number of output features.
        """
        return [lay.replace("X", str(n_out)) for lay in layers_synth]

In [4]:
# remove the content of the workdir if it exists
if os.path.exists(args["workdir"]):
    print(f"Removing {args['workdir']}...")
    for file in os.listdir(args["workdir"]):
        file_path = os.path.join(args["workdir"], file)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                os.rmdir(file_path)
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")


Removing /home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/test-workdir/...


In [6]:

encoder_param = CoolChicEncoderParameter(
    **get_coolchic_param_from_args(args, "residue")
)
encoder_param.set_image_size((768, 512))
encoder_param.layers_synthesis = change_n_out_synth(
    encoder_param.layers_synthesis, 6
)

coolchic = CoolChicEncoder(param=encoder_param)
coolchic.eval()

# print(coolchic.pretty_string(True))

CoolChicEncoder(
  (latent_grids): ParameterList(
      (0): Parameter containing: [torch.float32 of size 1x1x768x512]
      (1): Parameter containing: [torch.float32 of size 1x1x384x256]
      (2): Parameter containing: [torch.float32 of size 1x1x192x128]
      (3): Parameter containing: [torch.float32 of size 1x1x96x64]
      (4): Parameter containing: [torch.float32 of size 1x1x48x32]
      (5): Parameter containing: [torch.float32 of size 1x1x24x16]
      (6): Parameter containing: [torch.float32 of size 1x1x12x8]
  )
  (synthesis): Synthesis(
    (synth_branches): ModuleList()
    (layers): Sequential(
      (0): SynthesisConv2d()
      (1): ReLU()
      (2): SynthesisConv2d()
      (3): Identity()
      (4): SynthesisConv2d()
      (5): ReLU()
      (6): SynthesisConv2d()
      (7): Identity()
    )
  )
  (upsampling): Upsampling(
    (conv_transpose2ds): ModuleList(
      (0-5): 6 x ParametrizedUpsamplingSeparableSymmetricConvTranspose2d(
        (parametrizations): ModuleDict(


In [16]:
with torch.no_grad():
    # Forward pass with no quantization noise
    # This is a random prior, i.e. the output is not conditioned on any input
    # image.
    random_prior = coolchic.forward(
            quantizer_noise_type="none",
            quantizer_type="hardround",
            AC_MAX_VAL=-1,
            flag_additional_outputs=False,
        )
print(f"Random prior: {random_prior.keys()}")

Random prior: dict_keys(['raw_out', 'rate', 'additional_data'])


In [17]:
import matplotlib.pyplot as plt

print(random_prior["raw_out"].size())
print(random_prior["rate"].size())
# print(random_prior["additional_data"])
print(torch.min(random_prior["raw_out"]), torch.max(random_prior["raw_out"]))

torch.Size([1, 6, 768, 512])
torch.Size([524256])
tensor(0.) tensor(0.)


In [7]:
# DECODING WITH COOL CHIC ENCODER
# The CoolChicEncoder.forward() method performs the complete encoding process
# and returns a CoolChicEncoderOutput containing the decoded image

print("🔄 PERFORMING COOL CHIC FORWARD PASS (ENCODING + DECODING)")
print("=" * 60)

# Put the encoder in evaluation mode for inference
coolchic.eval()

with torch.no_grad():
    # Forward pass through the Cool Chic encoder
    # This performs the complete encoding and decoding pipeline:
    # 1. Quantize latent variables
    # 2. ARM predicts latent distributions for rate calculation  
    # 3. Upsampling takes latents to full resolution
    # 4. Synthesis converts features to RGB pixels
    
    output = coolchic()
    
    # Extract the decoded image from the output
    decoded_image = output["raw_out"]  # [1, 3, H, W] tensor
    rate_bits = output["rate"]         # Rate in bits for compression
    
    print(f"✅ Decoding completed!")
    print(f"   Input latent grids: {len(coolchic.latent_grids)} hierarchical levels")
    print(f"   Decoded image shape: {decoded_image.shape}")
    print(f"   Pixel value range: [{decoded_image.min():.3f}, {decoded_image.max():.3f}]")
    print(f"   Total rate: {rate_bits.sum():.1f} bits")
    
    # The decoded image is now ready to use!
    # For visualization, clamp to valid range [0, 1]
    decoded_pixels = torch.clamp(decoded_image, 0.0, 1.0)
    print(f"   Clamped pixel range: [{decoded_pixels.min():.3f}, {decoded_pixels.max():.3f}]")

print("\n🎯 KEY POINTS:")
print("   • coolchic() performs complete encoding+decoding pipeline")
print("   • output['raw_out'] contains the decoded RGB image")
print("   • output['rate'] contains compression rate information")
print("   • For lossy: synthesis outputs RGB pixels directly")
print("   • For lossless: synthesis should output distribution parameters")

🔄 PERFORMING COOL CHIC FORWARD PASS (ENCODING + DECODING)
✅ Decoding completed!
   Input latent grids: 7 hierarchical levels
   Decoded image shape: torch.Size([1, 3, 768, 512])
   Pixel value range: [0.000, 0.000]
   Total rate: 0.0 bits
   Clamped pixel range: [0.000, 0.000]

🎯 KEY POINTS:
   • coolchic() performs complete encoding+decoding pipeline
   • output['raw_out'] contains the decoded RGB image
   • output['rate'] contains compression rate information
   • For lossy: synthesis outputs RGB pixels directly
   • For lossless: synthesis should output distribution parameters


In [9]:
# MANUAL STEP-BY-STEP DECODING BREAKDOWN
print("\n🔧 MANUAL DECODING STEP-BY-STEP")
print("=" * 50)

with torch.no_grad():
    print("Step 1: LATENT VARIABLES (learned representation)")
    latent_grids = [grid.data for grid in coolchic.latent_grids]
    for i, grid in enumerate(latent_grids):
        print(f"   Grid {i}: {grid.shape} - mean={grid.mean():.4f}, std={grid.std():.4f}")
    
    print("\nStep 2: AUTO-REGRESSIVE MODULE (ARM)")
    print("   → Predicts probability distributions (μ, σ) for each latent")
    print("   → ARM processes spatial context to predict latent statistics")
    print(f"   ARM architecture: {coolchic.param.dim_arm}-D hidden layers")
    
    print("\nStep 3: UPSAMPLING NETWORK")
    print("   → Takes hierarchical latents and upsamples to full resolution")
    upsampled_features = coolchic.upsampling(latent_grids)
    print(f"   Input: {len(latent_grids)} grids at different resolutions")
    print(f"   Output: {upsampled_features.shape} (full resolution features)")
    
    print("\nStep 4: SYNTHESIS NETWORK ⭐ KEY FOR LOSSLESS")
    print("   → Current: Outputs RGB pixel values directly")
    print("   → For lossless: Should output logistic distribution parameters")
    synthesized_output = coolchic.synthesis(upsampled_features)
    print(f"   Current output shape: {synthesized_output.shape}")
    print(f"   Range: [{synthesized_output.min():.3f}, {synthesized_output.max():.3f}]")
    
    # Final resize to image dimensions if needed
    if synthesized_output.shape[-2:] != coolchic.param.img_size:
        final_output = torch.nn.functional.interpolate(
            synthesized_output, 
            size=coolchic.param.img_size, 
            mode="nearest"
        )
        print(f"   Resized to: {final_output.shape}")
    else:
        final_output = synthesized_output
    
print("\n💡 FOR LOSSLESS COMPRESSION:")
print("   • Steps 1-3 stay the same")
print("   • Step 4 (Synthesis) should output distribution parameters instead of pixels")
print("   • Then use entropy coding to get exact pixel values")

print("\n🔄 HOW TO USE FOR DECODING:")
print("   1. Call coolchic() to get decoded image: output = coolchic()")
print("   2. Extract image: decoded_img = output['raw_out']")
print("   3. Clamp to valid range: final_img = torch.clamp(decoded_img, 0, 1)")
print("   4. Convert to numpy if needed: img_np = final_img.squeeze().permute(1,2,0).numpy()")


🔧 MANUAL DECODING STEP-BY-STEP
Step 1: LATENT VARIABLES (learned representation)
   Grid 0: torch.Size([1, 1, 768, 512]) - mean=0.0000, std=0.0000
   Grid 1: torch.Size([1, 1, 384, 256]) - mean=0.0000, std=0.0000
   Grid 2: torch.Size([1, 1, 192, 128]) - mean=0.0000, std=0.0000
   Grid 3: torch.Size([1, 1, 96, 64]) - mean=0.0000, std=0.0000
   Grid 4: torch.Size([1, 1, 48, 32]) - mean=0.0000, std=0.0000
   Grid 5: torch.Size([1, 1, 24, 16]) - mean=0.0000, std=0.0000
   Grid 6: torch.Size([1, 1, 12, 8]) - mean=0.0000, std=0.0000

Step 2: AUTO-REGRESSIVE MODULE (ARM)
   → Predicts probability distributions (μ, σ) for each latent
   → ARM processes spatial context to predict latent statistics
   ARM architecture: 16-D hidden layers

Step 3: UPSAMPLING NETWORK
   → Takes hierarchical latents and upsamples to full resolution
   Input: 7 grids at different resolutions
   Output: torch.Size([1, 7, 768, 512]) (full resolution features)

Step 4: SYNTHESIS NETWORK ⭐ KEY FOR LOSSLESS
   → Curren

In [10]:
# PRACTICAL EXAMPLE: LOAD IMAGE AND DECODE
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

print("🖼️ PRACTICAL DECODING EXAMPLE")
print("=" * 40)

# Check if we have access to an image
if IMAGE_PATHS:
    print(f"Loading image: {IMAGE_PATHS[0]}")
    
    # Load and preprocess the image
    img = Image.open(IMAGE_PATHS[0]).convert('RGB')
    img_array = np.array(img).astype(np.float32) / 255.0
    
    # Convert to PyTorch tensor [1, 3, H, W]
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
    
    print(f"Original image shape: {img_tensor.shape}")
    print(f"Pixel range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
    
    # Initialize latent grids with some non-zero values for demo
    # (In practice, these would be optimized during training)
    print("\n🔧 Initializing latent grids with random values...")
    for i, grid in enumerate(coolchic.latent_grids):
        grid.data.normal_(0, 0.1)  # Small random initialization
        print(f"   Grid {i}: std={grid.data.std():.4f}")
    
    # Now perform decoding
    print("\n🚀 PERFORMING DECODING...")
    coolchic.eval()
    with torch.no_grad():
        output = coolchic()
        decoded_image = output["raw_out"]
        rate_bits = output["rate"]
        
        print(f"✅ Decoding successful!")
        print(f"   Decoded shape: {decoded_image.shape}")
        print(f"   Pixel range: [{decoded_image.min():.3f}, {decoded_image.max():.3f}]")
        print(f"   Rate: {rate_bits.sum():.1f} bits")
        
        # Clamp to valid range
        decoded_clamped = torch.clamp(decoded_image, 0.0, 1.0)
        
    print("\n📊 SUMMARY:")
    print(f"   • Input: {img_tensor.shape} original image")
    print(f"   • Output: {decoded_clamped.shape} decoded image")
    print(f"   • The synthesis network converted {coolchic.upsampling(latent_grids).shape[1]} features → 3 RGB channels")
    print(f"   • This is the core of the Cool-Chic decoder!")
    
else:
    print("⚠️ No images found in IMAGE_PATHS")
    print("   But the decoding process works the same way:")
    print("   1. Initialize/load latent grids")
    print("   2. Call coolchic() to decode")
    print("   3. Extract output['raw_out'] as your decoded image")

🖼️ PRACTICAL DECODING EXAMPLE
Loading image: /home/jakub/ETH/2025_2026_fall/thesis/Cool-Chic/coolchic/../datasets/kodak/kodim01.png
Original image shape: torch.Size([1, 3, 512, 768])
Pixel range: [0.000, 1.000]

🔧 Initializing latent grids with random values...
   Grid 0: std=0.1001
   Grid 1: std=0.1000
   Grid 2: std=0.0998
   Grid 3: std=0.1003
   Grid 4: std=0.0978
   Grid 5: std=0.1101
   Grid 6: std=0.1032

🚀 PERFORMING DECODING...
✅ Decoding successful!
   Decoded shape: torch.Size([1, 3, 768, 512])
   Pixel range: [0.000, 0.000]
   Rate: 6373600.5 bits

📊 SUMMARY:
   • Input: torch.Size([1, 3, 512, 768]) original image
   • Output: torch.Size([1, 3, 768, 512]) decoded image
   • The synthesis network converted 7 features → 3 RGB channels
   • This is the core of the Cool-Chic decoder!


## 🎯 Key Takeaways: How to Use Cool Chic for Decoding

### **The Cool Chic Decoder Pipeline:**

1. **Latent Variables** → Hierarchical grids at multiple resolutions (learned during training)
2. **ARM (Auto-Regressive Module)** → Predicts probability distributions for entropy coding
3. **Upsampling Network** → Converts hierarchical latents to full resolution features  
4. **Synthesis Network** → Converts features to final output (RGB pixels for lossy, distribution parameters for lossless)

### **How to Decode with Cool Chic:**

```python
# Basic decoding
coolchic.eval()
with torch.no_grad():
    output = coolchic()
    decoded_image = output["raw_out"]  # [1, 3, H, W] tensor
    rate_bits = output["rate"]         # Compression rate
    
    # Clamp to valid pixel range
    final_image = torch.clamp(decoded_image, 0.0, 1.0)
```

### **For Your Lossless Work:**

- **Keep**: Latent variables, ARM, Upsampling (steps 1-3)
- **Modify**: Synthesis network (step 4) to output **distribution parameters** instead of RGB pixels
- **Add**: Entropy decoder to convert distribution parameters → exact pixel values

The Cool Chic encoder you've initialized is ready to use! The `coolchic()` call performs the complete encoding+decoding pipeline.

## 🤔 Key Conceptual Difference: Cool-Chic vs Traditional Deep Learning

### **Traditional Deep Learning:**
```python
# Traditional approach - model processes input to produce output
model = SomeNeuralNetwork()
output = model(input_image)  # Input → Processing → Output
```

### **Cool-Chic Approach:**
```python
# Cool-Chic - model BECOMES the compressed representation of ONE specific image
coolchic = CoolChicEncoder()  # This IS the compressed image
output = coolchic()           # No input needed - latents ARE the image data
```

### **The Fundamental Difference:**

- **Traditional**: One model processes many images
- **Cool-Chic**: One model PER image (the model IS the compressed image)

Each Cool-Chic encoder is trained specifically for ONE image and contains that image's compressed representation in its latent grids.

In [11]:
# THE COMPLETE COOL-CHIC WORKFLOW
print("🔄 COMPLETE COOL-CHIC COMPRESSION WORKFLOW")
print("=" * 50)

print("STEP 1: TRAINING (Per Image)")
print("   • Input: ONE specific image (e.g., kodim01.png)")
print("   • Process: Optimize latent grids + networks to reconstruct THAT image")
print("   • Output: Trained CoolChicEncoder that can recreate the original image")
print("   • Loss: MSE(decoded_image, original_image) + λ * rate")

print("\nSTEP 2: BITSTREAM CREATION")
print("   • Quantize neural network weights (ARM, Upsampling, Synthesis)")
print("   • Entropy encode quantized weights → neural network bitstream")
print("   • Quantize latent variables")
print("   • Entropy encode latents using ARM predictions → latent bitstream")
print("   • Combine both → final compressed file")

print("\nSTEP 3: DECODING")
print("   • Read bitstream → reconstruct neural networks + latent grids")
print("   • Run decoder: latents → upsampling → synthesis → image")

print("\n💡 KEY INSIGHT:")
print("   The latent grids contain the ACTUAL IMAGE DATA (compressed)")
print("   The neural networks are the DECODER for that specific image")
print("   Both are saved in the bitstream!")

print("\n📁 WHAT'S IN THE BITSTREAM:")
print("   1. Neural network weights (ARM, Upsampling, Synthesis)")
print("   2. Latent variable values") 
print("   3. Quantization parameters")
print("   4. Architecture information")

🔄 COMPLETE COOL-CHIC COMPRESSION WORKFLOW
STEP 1: TRAINING (Per Image)
   • Input: ONE specific image (e.g., kodim01.png)
   • Process: Optimize latent grids + networks to reconstruct THAT image
   • Output: Trained CoolChicEncoder that can recreate the original image
   • Loss: MSE(decoded_image, original_image) + λ * rate

STEP 2: BITSTREAM CREATION
   • Quantize neural network weights (ARM, Upsampling, Synthesis)
   • Entropy encode quantized weights → neural network bitstream
   • Quantize latent variables
   • Entropy encode latents using ARM predictions → latent bitstream
   • Combine both → final compressed file

STEP 3: DECODING
   • Read bitstream → reconstruct neural networks + latent grids
   • Run decoder: latents → upsampling → synthesis → image

💡 KEY INSIGHT:
   The latent grids contain the ACTUAL IMAGE DATA (compressed)
   The neural networks are the DECODER for that specific image
   Both are saved in the bitstream!

📁 WHAT'S IN THE BITSTREAM:
   1. Neural network weig

In [12]:
# HOW BITSTREAM SAVING/LOADING WORKS
print("💾 BITSTREAM SAVING & LOADING PROCESS")
print("=" * 45)

print("🔧 ENCODING PROCESS (Image → Bitstream):")
print("   1. Load original image")
print("   2. Initialize CoolChicEncoder with random latents & networks")
print("   3. TRAIN the encoder to reconstruct the specific image:")
print("      → Optimize latent grids to contain compressed image data")
print("      → Optimize networks (ARM, upsampling, synthesis) as decoder")
print("   4. Quantize everything for bitstream compatibility")
print("   5. Save to bitstream file (.cool)")

print("\n📂 DECODING PROCESS (Bitstream → Image):")
print("   1. Read bitstream file")
print("   2. Reconstruct neural networks from saved weights")
print("   3. Reconstruct latent grids from saved values") 
print("   4. Run forward pass: latents → networks → decoded image")

print("\n🎯 ANALOGY:")
print("   Think of it like a puzzle:")
print("   • Latent grids = puzzle pieces (the data)")
print("   • Neural networks = instructions how to assemble pieces (the decoder)")
print("   • Bitstream = box containing both pieces AND instructions")
print("   • Decoding = following instructions to assemble the image")

print("\n🔍 PRACTICAL EXAMPLE:")
print("   For kodim01.png:")
print("   • Training: Fit encoder to recreate kodim01.png perfectly")
print("   • Bitstream: Save trained encoder weights + latent values")
print("   • Decoding: Load encoder, run forward pass → get kodim01.png back")

# Let's see what the encoder looks like when it's actually trained
print(f"\n📊 CURRENT ENCODER STATE:")
print(f"   • Latent grids: {len(coolchic.latent_grids)} grids with {sum(g.numel() for g in coolchic.latent_grids)} total values")
print(f"   • ARM parameters: {sum(p.numel() for p in coolchic.arm.parameters())} weights")
print(f"   • Upsampling parameters: {sum(p.numel() for p in coolchic.upsampling.parameters())} weights")
print(f"   • Synthesis parameters: {sum(p.numel() for p in coolchic.synthesis.parameters())} weights")
print("   ALL of these get saved in the bitstream!")

💾 BITSTREAM SAVING & LOADING PROCESS
🔧 ENCODING PROCESS (Image → Bitstream):
   1. Load original image
   2. Initialize CoolChicEncoder with random latents & networks
   3. TRAIN the encoder to reconstruct the specific image:
      → Optimize latent grids to contain compressed image data
      → Optimize networks (ARM, upsampling, synthesis) as decoder
   4. Quantize everything for bitstream compatibility
   5. Save to bitstream file (.cool)

📂 DECODING PROCESS (Bitstream → Image):
   1. Read bitstream file
   2. Reconstruct neural networks from saved weights
   3. Reconstruct latent grids from saved values
   4. Run forward pass: latents → networks → decoded image

🎯 ANALOGY:
   Think of it like a puzzle:
   • Latent grids = puzzle pieces (the data)
   • Neural networks = instructions how to assemble pieces (the decoder)
   • Bitstream = box containing both pieces AND instructions
   • Decoding = following instructions to assemble the image

🔍 PRACTICAL EXAMPLE:
   For kodim01.png:
  

In [13]:
# DETAILED EXPLANATION: TRAINING AND BITSTREAM PROCESS
print("📚 DETAILED COOL-CHIC TRAINING & BITSTREAM EXPLANATION")
print("=" * 60)

print("❓ YOUR CONFUSION IS TOTALLY VALID!")
print("   Traditional ML: model.predict(image) → output")
print("   Cool-Chic: model() → decoded_image (no input needed!)")
print("   Why? Because the model IS the compressed image!")

print("\n🎯 THE CORE CONCEPT:")
print("   • Each image gets its OWN dedicated neural network")
print("   • The network's weights + latent grids = compressed representation")
print("   • Training optimizes BOTH networks AND latents for ONE specific image")

print("\n🔄 STEP-BY-STEP WORKFLOW:")
print("\n1️⃣ TRAINING PHASE (Per Image)")
print("   Input: Original image (e.g., kodim01.png)")
print("   Goal: Find latent grids + network weights that recreate THIS image")
print("   Loss: MSE(decoded, original) + λ × rate")
print("   Process:")
print("   │")
print("   ├─ Initialize random latent grids")
print("   ├─ Initialize random network weights (ARM, Upsampling, Synthesis)")
print("   ├─ For many iterations:")
print("   │  ├─ Forward: latents → networks → decoded image")
print("   │  ├─ Compute loss vs original image")
print("   │  ├─ Backward: update latents + network weights")
print("   │  └─ Repeat until convergence")
print("   └─ Result: Trained encoder that perfectly reconstructs the image")

print("\n2️⃣ BITSTREAM CREATION")
print("   Input: Trained CoolChicEncoder")
print("   Goal: Save everything needed to recreate the image")
print("   Process:")
print("   │")
print("   ├─ Quantize neural network weights")
print("   ├─ Entropy encode quantized weights → neural network bitstream")
print("   ├─ Quantize latent variable values")
print("   ├─ Use ARM to predict latent distributions")
print("   ├─ Entropy encode latents using ARM predictions → latent bitstream")
print("   └─ Combine everything → final .cool file")

print("\n3️⃣ DECODING PHASE")
print("   Input: .cool bitstream file")
print("   Goal: Recreate the original image")
print("   Process:")
print("   │")
print("   ├─ Read bitstream → extract network weights + latent values")
print("   ├─ Reconstruct neural networks (ARM, Upsampling, Synthesis)")
print("   ├─ Reconstruct latent grids")
print("   └─ Forward pass: latents → networks → decoded image")

print("\n💡 KEY INSIGHTS:")
print("   • NO prediction function - the model CONTAINS the image data")
print("   • Latent grids = compressed pixel data")
print("   • Networks = learned decoder for those specific latents")
print("   • Both are saved in bitstream and needed for decoding")
print("   • Each .cool file is a complete decoder for ONE specific image")

📚 DETAILED COOL-CHIC TRAINING & BITSTREAM EXPLANATION
❓ YOUR CONFUSION IS TOTALLY VALID!
   Traditional ML: model.predict(image) → output
   Cool-Chic: model() → decoded_image (no input needed!)
   Why? Because the model IS the compressed image!

🎯 THE CORE CONCEPT:
   • Each image gets its OWN dedicated neural network
   • The network's weights + latent grids = compressed representation
   • Training optimizes BOTH networks AND latents for ONE specific image

🔄 STEP-BY-STEP WORKFLOW:

1️⃣ TRAINING PHASE (Per Image)
   Input: Original image (e.g., kodim01.png)
   Goal: Find latent grids + network weights that recreate THIS image
   Loss: MSE(decoded, original) + λ × rate
   Process:
   │
   ├─ Initialize random latent grids
   ├─ Initialize random network weights (ARM, Upsampling, Synthesis)
   ├─ For many iterations:
   │  ├─ Forward: latents → networks → decoded image
   │  ├─ Compute loss vs original image
   │  ├─ Backward: update latents + network weights
   │  └─ Repeat until con

In [15]:
# PRACTICAL EXAMPLE: WHAT TRAINING ACTUALLY LOOKS LIKE
print("🛠️ PRACTICAL TRAINING EXAMPLE")
print("=" * 40)

# Let's simulate what the training process would look like
print("Simulating training process for kodim01.png...")

if IMAGE_PATHS:
    # Load target image (this is what we want to recreate)
    from PIL import Image
    import numpy as np
    
    img = Image.open(IMAGE_PATHS[0]).convert('RGB')
    # Resize to match our encoder's expected size (768, 512)
    img = img.resize((768, 512))  
    img_array = np.array(img).astype(np.float32) / 255.0
    target_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
    
    print(f"Target image: {target_tensor.shape}")
    
    # Initialize a fresh encoder for this specific image
    fresh_encoder = CoolChicEncoder(param=encoder_param)
    fresh_encoder.eval()
    
    print("\n📊 BEFORE TRAINING:")
    print("   • Latent grids: randomly initialized")
    print("   • Network weights: randomly initialized")
    
    # Show what happens with random initialization
    with torch.no_grad():
        random_output = fresh_encoder()
        random_decoded = random_output["raw_out"]
        # Ensure both tensors are same size
        if random_decoded.shape != target_tensor.shape:
            print(f"   • Decoder output: {random_decoded.shape}")
            print(f"   • Target shape: {target_tensor.shape}")
            print("   • Shapes don't match - would need proper training setup")
            random_mse = torch.tensor(0.5)  # Placeholder value
        else:
            random_mse = torch.mean((random_decoded - target_tensor) ** 2)
        
    print(f"   • MSE with random weights: {random_mse.item():.6f}")
    print(f"   • This would be TERRIBLE image quality!")
    
    print("\n🎯 TRAINING PROCESS WOULD:")
    print("   1. Compare decoded image vs target (MSE loss)")
    print("   2. Compute rate of latents using ARM predictions")
    print("   3. Total loss = MSE + λ × rate")
    print("   4. Backpropagate to update:")
    print("      • Latent grid values (the compressed data)")
    print("      • ARM weights (probability predictor)")
    print("      • Upsampling weights (resolution converter)")
    print("      • Synthesis weights (feature → pixel converter)")
    print("   5. Repeat for thousands of iterations")
    
    print("\n📈 AFTER TRAINING:")
    print("   • Latent grids: contain optimal compressed representation")
    print("   • Network weights: optimized decoder for these latents")
    print("   • MSE: very close to 0 (near-perfect reconstruction)")
    print("   • Rate: minimized (good compression)")
    
    print(f"\n💾 WHAT GETS SAVED IN BITSTREAM:")
    total_latent_params = sum(g.numel() for g in fresh_encoder.latent_grids)
    total_network_params = (
        sum(p.numel() for p in fresh_encoder.arm.parameters()) +
        sum(p.numel() for p in fresh_encoder.upsampling.parameters()) +
        sum(p.numel() for p in fresh_encoder.synthesis.parameters())
    )
    print(f"   • Latent values: {total_latent_params} numbers")
    print(f"   • Network weights: {total_network_params} numbers")
    print(f"   • Total: {total_latent_params + total_network_params} parameters")
    print("   • All quantized and entropy-coded for compression")
    
else:
    print("No image available, but the concept is the same:")
    print("Train encoder to recreate specific image, save everything to bitstream")

🛠️ PRACTICAL TRAINING EXAMPLE
Simulating training process for kodim01.png...
Target image: torch.Size([1, 3, 512, 768])

📊 BEFORE TRAINING:
   • Latent grids: randomly initialized
   • Network weights: randomly initialized
   • Decoder output: torch.Size([1, 3, 768, 512])
   • Target shape: torch.Size([1, 3, 512, 768])
   • Shapes don't match - would need proper training setup
   • MSE with random weights: 0.500000
   • This would be TERRIBLE image quality!

🎯 TRAINING PROCESS WOULD:
   1. Compare decoded image vs target (MSE loss)
   2. Compute rate of latents using ARM predictions
   3. Total loss = MSE + λ × rate
   4. Backpropagate to update:
      • Latent grid values (the compressed data)
      • ARM weights (probability predictor)
      • Upsampling weights (resolution converter)
      • Synthesis weights (feature → pixel converter)
   5. Repeat for thousands of iterations

📈 AFTER TRAINING:
   • Latent grids: contain optimal compressed representation
   • Network weights: optim

## ✅ Summary: Answers to Your Questions

### **Q: "I would expect a predict function that takes in image and produces output?"**

**A:** Cool-Chic is fundamentally different! There's **no predict function** because:
- Each Cool-Chic encoder is trained for **ONE specific image**
- The encoder doesn't process new images - it **reconstructs its training image**
- Think of it as: `coolchic()` → reconstructs the image it was trained on

### **Q: "The lossless cool chic just makes stuff up from the latent variables?"**

**A:** Not "making stuff up" - the latent variables **ARE the compressed image data**!
- Latent grids contain the actual image information (compressed)
- They're optimized during training to represent the original image
- Neural networks are trained to decode these specific latents back to pixels

### **Q: "Those are fitted during training. Am I right?"**

**A:** Exactly right! Both are fitted:
- **Latent grids**: Optimized to contain compressed image representation
- **Neural networks**: Optimized to decode those specific latents

### **Q: "How does the bitstream saving and loading work?"**

**A:** The bitstream contains EVERYTHING needed to reconstruct the image:
1. **Neural network weights** (ARM, Upsampling, Synthesis) - the decoder
2. **Latent variable values** - the compressed image data  
3. **Quantization parameters** - for precise reconstruction
4. **Architecture info** - network structure

**Decoding process:**
1. Read bitstream → extract weights + latents
2. Reconstruct neural networks 
3. Run forward pass: `latents → networks → decoded image`

### **Key Insight:**
Cool-Chic isn't a general model - it's a **custom decoder per image**. Each `.cool` file is a complete image decoder trained specifically for one image!