# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from compressai.zoo import bmshj2018_factorized
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Pad, Resize
import numpy as np

# --- Generalized Gaussian PDF & log likelihood ---
def generalized_gaussian_logpdf(x, beta, loc, scale, eps=1e-8):
    """
    Element-wise log PDF of the Generalized Gaussian distribution:

        log p(x) = log(beta) - log(2 * scale) - lgamma(1 / beta)
                   - (|x - loc| / scale) ** beta

    x: tensor of points at which the density is evaluated
    beta: shape parameter (>0); beta=1 is Laplace, beta=2 is Gaussian
    loc: location (mean)
    scale: scale (>0)
    eps: lower bound applied to |x - loc| / scale before it is raised to
        the power beta (keeps gradients finite when x == loc and beta < 1)

    beta, loc and scale may be tensors (broadcastable against x) or plain
    Python numbers. Returns a tensor of log-densities with the broadcast
    shape of the inputs.
    """
    # Coerce the parameters to tensors so torch.log / torch.lgamma also work
    # when plain numbers are passed. Tensors that already match dtype/device
    # are returned unchanged, so autograd through nn.Parameters is preserved.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    beta = torch.as_tensor(beta, dtype=dtype, device=x.device)
    loc = torch.as_tensor(loc, dtype=dtype, device=x.device)
    scale = torch.as_tensor(scale, dtype=dtype, device=x.device)

    # Without the floor, the backward pass at |x - loc| == 0 with beta < 1
    # evaluates 0 ** (beta - 1) = inf and multiplies it by sign(0) = 0,
    # which yields NaN gradients.
    standardized = torch.abs((x - loc) / scale).clamp(min=eps)
    z = standardized ** beta

    log_norm = torch.log(beta) - torch.log(2 * scale) - torch.lgamma(1 / beta)
    return log_norm - z

# --- Custom Hyperprior Layer using Generalized Gaussian ---
class GenGaussHyperprior(nn.Module):
    """
    Learnable Generalized Gaussian prior used to score latents.

    Holds three trainable parameters -- shape (beta), location (loc) and
    scale -- that are shared by every element of the latent tensor.
    """

    def __init__(self, latent_channels):
        super().__init__()
        # One shared set of parameters is used for all channels, so
        # `latent_channels` is accepted but not used here; a per-channel
        # variant would size the parameters from it.
        # Initial shape 1.5 lies between Laplace (beta=1) and Gaussian (beta=2).
        self.beta = nn.Parameter(torch.tensor(1.5))
        self.loc = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, z):
        """Return the total log-likelihood of `z` (summed over all elements)."""
        # Lower bounds keep shape and scale strictly positive for the log PDF.
        shape = torch.clamp(self.beta, min=0.1)
        spread = torch.clamp(self.scale, min=1e-5)
        elementwise = generalized_gaussian_logpdf(z, shape, self.loc, spread)
        # Single scalar, used by the caller as the rate estimate.
        return torch.sum(elementwise)

# --- New Model with frozen g_a and trainable g_s + hyperprior ---
class CustomCompressionModel(nn.Module):
    """
    Pretrained bmshj2018_factorized (quality=3) with a frozen analysis
    transform (g_a), a trainable synthesis transform (g_s) and a
    Generalized Gaussian prior on the latents.
    """

    def __init__(self):
        super().__init__()
        base_model = bmshj2018_factorized(quality=3, pretrained=True)

        # Analysis transform (encoder): frozen and kept in eval mode.
        self.g_a = base_model.g_a.eval()
        for param in self.g_a.parameters():
            param.requires_grad = False  # freeze analysis transform

        self.g_s = base_model.g_s  # decoder (trainable)
        self.hyperprior = GenGaussHyperprior(latent_channels=192)  # latent channels of bmshj2018_factorized quality=3

    def train(self, mode=True):
        """
        Set the training mode of the trainable parts only.

        nn.Module.train() recurses into every submodule, which would undo the
        eval() applied to the frozen encoder in __init__; it is put back into
        eval mode here. Returns self, like nn.Module.train().
        """
        super().train(mode)
        self.g_a.eval()
        return self

    def forward(self, x):
        """Return (x_hat, z): the reconstruction of x and its latent."""
        # The encoder is frozen, so no graph is built for it; z carries no
        # gradient back to g_a.
        with torch.no_grad():
            z = self.g_a(x)
        x_hat = self.g_s(z)
        return x_hat, z

# --- Training loop skeleton ---
def train(model, dataloader, optimizer, device, epochs=10, lambda_rd=0.01):
    """
    Optimize `model` with the loss `distortion + lambda_rd * rate`.

    model: module whose forward returns (x_hat, z) and which exposes a
        `hyperprior` submodule returning the summed log-likelihood of z
    dataloader: iterable of (images, labels) batches; labels are ignored
    optimizer: optimizer holding the parameters to update
    device: device the image batches are moved to
    epochs: number of passes over `dataloader`
    lambda_rd: weight of the rate term in the loss

    Distortion is the MSE between reconstruction and input. Rate is the
    negative log2-likelihood of z divided by the number of pixels in the
    batch (batches are assumed to be laid out as (N, C, H, W)).
    After each epoch the mean distortion, rate and loss over that epoch's
    batches are printed.
    """
    import math  # local import: only needed for the nats -> bits factor

    model.train()
    mse_loss = nn.MSELoss()
    nats_to_bits = 1.0 / math.log(2)

    for epoch in range(epochs):
        distortion_sum = 0.0
        rate_sum = 0.0
        loss_sum = 0.0
        num_batches = 0

        for imgs, _ in dataloader:
            imgs = imgs.to(device)

            optimizer.zero_grad()
            x_hat, z = model(imgs)
            distortion = mse_loss(x_hat, imgs)

            # The log-likelihood is in nats (natural log); convert to bits and
            # normalize by pixels (N * H * W), not by elements (N * C * H * W).
            num_pixels = imgs.size(0) * imgs.size(-2) * imgs.size(-1)
            rate = -model.hyperprior(z) * nats_to_bits / num_pixels

            loss = distortion + lambda_rd * rate
            loss.backward()
            optimizer.step()

            distortion_sum += distortion.item()
            rate_sum += rate.item()
            loss_sum += loss.item()
            num_batches += 1

        # An empty dataloader leaves nothing to report; say so instead of
        # failing on undefined per-batch values.
        if num_batches == 0:
            print(f"Epoch {epoch+1}: no batches in dataloader")
            continue

        print(
            f"Epoch {epoch+1}: Distortion {distortion_sum / num_batches:.4f}, "
            f"Rate {rate_sum / num_batches:.4f}, Loss {loss_sum / num_batches:.4f}"
        )

# --- Usage example ---
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CustomCompressionModel().to(device)

    # Images are resized to 256x256 and fed as-is. No extra padding is
    # applied: bmshj2018_factorized downsamples by a factor of 16, so the
    # spatial size must be a multiple of 16 for the reconstruction to have
    # the same shape as the input. Padding 256 to 264 breaks that and makes
    # the MSE between x_hat and the input fail on a shape mismatch.
    transform = Compose([
        Resize((256, 256)),
        ToTensor(),
    ])

    dataset = ImageFolder(root="/your/image/folder", transform=transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Only the decoder and the prior are optimized; the encoder is frozen.
    optimizer = optim.Adam(list(model.g_s.parameters()) + list(model.hyperprior.parameters()), lr=1e-4)

    train(model, dataloader, optimizer, device, epochs=10)
