# INFOGAN with metrics

* souce: https://github.com/openai/InfoGAN/tree/master

In [5]:
import tensorflow as tf
print("TensorFlow:", tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))

TensorFlow: 2.19.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [1]:
!nvidia-smi
!nvcc --version


Thu Jul 17 03:01:27 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A4000               Off | 00000000:55:00.0  On |                  Off |
| 41%   36C    P8              29W / 140W |     54MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [6]:
!pip install tensorflow>=2.13 scipy scikit-learn pandas pillow tqdm

In [None]:
#!pip install prettytensor

In [None]:
#!pip install progressbar

In [None]:
#!pip install python-dateutil

In [8]:
"""
InfoGAN em TensorFlow 2 / Keras – versão 2
-----------------------------------------
Compatível com Python ≥ 3.9 e TensorFlow ≥ 2.13 (eager por default).
Principais mudanças em relação ao script original:
• Todas as chamadas TF‑1.x (tf.random_normal, tf.multinomial, tf.pack, etc.)
  foram substituídas pelas equivalentes em TF‑2.x.
• Classes de distribuição (Gaussian, Categorical, Product) portadas para a nova API.
• Ajuste no cálculo de MI: em vez de depender de `Product.reg_z`, selecionamos
  as últimas dimensões do vetor latente (categorical + continuous).
• Pequenas correções antipandas/NumPy para garantir execução em Windows + Anaconda.

Requisitos:
 pip install tensorflow>=2.13 scipy scikit-learn pandas pillow tqdm
"""
from __future__ import annotations
import os, math, json, errno, time
from pathlib import Path
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.linalg import sqrtm
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import mutual_info_score
import tensorflow as tf
from tensorflow.keras import layers, Model
from PIL import Image

# ----------------------------------------------------------------------------
# Utils ----------------------------------------------------------------------
# ----------------------------------------------------------------------------
TINY = 1e-8
floatX = np.float32

def mkdir_p(path: str):
    Path(path).mkdir(parents=True, exist_ok=True)

# ----------------------------------------------------------------------------
# Dataset loader (CelebA + atributos) ----------------------------------------
# ----------------------------------------------------------------------------
class CelebAWithAttr:
    """Loader mínimo da CelebA com 40 atributos binários."""

    def __init__(self, root_dir: str = '.', image_shape=(64, 64, 3), split_ratio=0.9):
        self.image_shape = image_shape
        self.image_dim = int(np.prod(image_shape))
        base = Path(root_dir)

        # encontra arquivo de atributos
        attr_files = list(base.glob('list_attr_celeba.*'))
        attr_path = attr_files[0]

        # leitura com pandas (auto-separador)
        df = pd.read_csv(attr_path, sep=None, engine='python')
        fname_col = df.columns[0]
        self.attr_names = [c for c in df.columns if c != fname_col]

        img_dir = base / 'img_align_celeba'
        sub = img_dir / 'img_align_celeba'
        if sub.is_dir():
            img_dir = sub

        self.files = df[fname_col].apply(lambda fn: img_dir / fn).values
        attrs = df[self.attr_names].replace(-1, 0).values.astype(np.int8)
        self.attrs = attrs

        # split train
        n_train = int(len(self.files) * split_ratio)
        self.train_idx = np.arange(n_train)
        np.random.shuffle(self.train_idx)
        self.ptr = 0

    def next_batch(self, batch_size):
        if self.ptr + batch_size > len(self.train_idx):
            np.random.shuffle(self.train_idx)
            self.ptr = 0
        sel = self.train_idx[self.ptr:self.ptr + batch_size]
        self.ptr += batch_size
        imgs, atts = [], []
        for i in sel:
            img = Image.open(self.files[i]).resize(self.image_shape[:2])
            imgs.append(np.asarray(img, np.float32) / 127.5 - 1.0)
            atts.append(self.attrs[i])
        x = np.stack(imgs).reshape(batch_size, -1)
        a = np.stack(atts)
        return x, a

    def inverse_transform(self, flat):
        imgs = flat.reshape((-1,) + self.image_shape)
        return ((imgs + 1.) * 127.5).clip(0, 255).astype(np.uint8)

# ----------------------------------------------------------------------------
# Distributions (TF‑2.x) -----------------------------------------------------
# ----------------------------------------------------------------------------
class Distribution:
    @property
    def dist_flat_dim(self):
        raise NotImplementedError

    @property
    def dim(self):
        raise NotImplementedError

    @property
    def effective_dim(self):
        raise NotImplementedError

    def logli(self, x_var, dist_info):
        raise NotImplementedError

    def sample(self, dist_info):
        raise NotImplementedError

    def sample_prior(self, batch_size):
        return self.sample(self.prior_dist_info(batch_size))

    def prior_dist_info(self, batch_size):
        raise NotImplementedError

    # helpers usados pelo InfoGAN ------------------------------------------
    def entropy(self, dist_info):
        raise NotImplementedError

    def marginal_entropy(self, dist_info):
        raise NotImplementedError

    def marginal_logli(self, x_var, dist_info):
        raise NotImplementedError

    def kl(self, p, q):
        raise NotImplementedError

    def dist_info_keys(self):
        raise NotImplementedError

    def activate_dist(self, flat_dist):
        raise NotImplementedError


class Categorical(Distribution):
    def __init__(self, dim: int):
        self._dim = dim

    # --- propriedades ------------------------------------------------------
    @property
    def dim(self): return self._dim
    @property
    def dist_flat_dim(self): return self._dim
    @property
    def effective_dim(self): return 1
    @property
    def dist_info_keys(self): return ['prob']

    # --- likelihood / KL ---------------------------------------------------
    def logli(self, x_var, dist_info):
        prob = dist_info['prob']
        return tf.reduce_sum(tf.math.log(prob + TINY) * x_var, axis=1)

    def kl(self, p, q):
        p_prob, q_prob = p['prob'], q['prob']
        return tf.reduce_sum(p_prob * (tf.math.log(p_prob + TINY) - tf.math.log(q_prob + TINY)), axis=1)

    # --- sampling ----------------------------------------------------------
    def sample(self, dist_info):
        prob = dist_info['prob']
        ids = tf.random.categorical(tf.math.log(prob + TINY), 1)[:, 0]
        onehot = tf.eye(self.dim, dtype=tf.float32)
        return tf.nn.embedding_lookup(onehot, ids)

    def sample_prior(self, batch_size):
        prob = tf.ones([batch_size, self.dim], dtype=floatX) / self.dim
        return self.sample(dict(prob=prob))

    # --- helpers -----------------------------------------------------------
    def activate_dist(self, flat_dist):
        return {'prob': tf.nn.softmax(flat_dist)}

    def entropy(self, dist_info):
        prob = dist_info['prob']
        return -tf.reduce_sum(prob * tf.math.log(prob + TINY), axis=1)

    def marginal_entropy(self, dist_info):
        prob = dist_info['prob']
        avg_prob = tf.tile(tf.reduce_mean(prob, axis=0, keepdims=True), [tf.shape(prob)[0], 1])
        return self.entropy({'prob': avg_prob})

    def marginal_logli(self, x_var, dist_info):
        prob = dist_info['prob']
        avg_prob = tf.tile(tf.reduce_mean(prob, axis=0, keepdims=True), [tf.shape(prob)[0], 1])
        return self.logli(x_var, {'prob': avg_prob})

    def nonreparam_logli(self, x_var, dist_info):
        return self.logli(x_var, dist_info)


class Gaussian(Distribution):
    def __init__(self, dim: int, fix_std: bool = False):
        self._dim = dim
        self._fix_std = fix_std

    @property
    def dim(self): return self._dim
    @property
    def dist_flat_dim(self): return self._dim * 2
    @property
    def effective_dim(self): return self._dim
    @property
    def dist_info_keys(self): return ['mean', 'stddev']

    # --- likelihood / KL ---------------------------------------------------
    def logli(self, x_var, dist_info):
        mean, std = dist_info['mean'], dist_info['stddev']
        eps = (x_var - mean) / (std + TINY)
        return tf.reduce_sum(-0.5 * np.log(2 * np.pi) - tf.math.log(std + TINY) - 0.5 * tf.square(eps), axis=1)

    def kl(self, p, q):
        μ1, σ1 = p['mean'], p['stddev']
        μ2, σ2 = q['mean'], q['stddev']
        num = tf.square(μ1 - μ2) + tf.square(σ1) - tf.square(σ2)
        den = 2. * tf.square(σ2)
        return tf.reduce_sum(num / (den + TINY) + tf.math.log(σ2 + TINY) - tf.math.log(σ1 + TINY), axis=1)

    # --- sampling ----------------------------------------------------------
    def sample(self, dist_info):
        mean, std = dist_info['mean'], dist_info['stddev']
        eps = tf.random.normal(tf.shape(mean))
        return mean + eps * std

    def sample_prior(self, batch_size):
        return tf.random.normal([batch_size, self.dim])

    # --- helpers -----------------------------------------------------------
    def activate_dist(self, flat):
        mean = flat[:, :self.dim]
        if self._fix_std:
            std = tf.ones_like(mean)
        else:
            std = tf.sqrt(tf.exp(flat[:, self.dim:]))
        return {'mean': mean, 'stddev': std}

    def entropy(self, dist_info):
        std = dist_info['stddev']
        return tf.reduce_sum(0.5 * np.log(2 * np.pi * np.e) + tf.math.log(std + TINY), axis=1)

    def marginal_entropy(self, dist_info):
        return self.entropy(dist_info)

    def marginal_logli(self, x_var, dist_info):
        return self.logli(x_var, dist_info)

    def nonreparam_logli(self, x_var, dist_info):
        return tf.zeros_like(x_var[:, 0])

    def prior_dist_info(self, batch_size):
        mean = tf.zeros([batch_size, self.dim])
        std = tf.ones([batch_size, self.dim])
        return {'mean': mean, 'stddev': std}

# Espaço latente
class Product(Distribution):
    def __init__(self, dists: list[Distribution]):
        self._dists = dists

    # --- short‑cuts --------------------------------------------------------
    @property
    def dists(self): return list(self._dists)
    @property
    def dim(self): return sum(d.dim for d in self.dists)
    @property
    def effective_dim(self): return sum(d.effective_dim for d in self.dists)
    @property
    def dist_flat_dim(self): return sum(d.dist_flat_dim for d in self.dists)

    def dims(self):
        return [d.dim for d in self.dists]

    def dist_flat_dims(self):
        return [d.dist_flat_dim for d in self.dists]

    # ----------------------------------------------------------------------
    def dist_info_keys(self):
        keys = []
        for idx, dist in enumerate(self.dists):
            for k in dist.dist_info_keys:
                keys.append(f'id_{idx}_{k}')
        return keys

    def split_dist_info(self, dist_info):
        ret = []
        for idx, dist in enumerate(self.dists):
            cur = {k: dist_info[f'id_{idx}_{k}'] for k in dist.dist_info_keys}
            ret.append(cur)
        return ret

    def join_dist_infos(self, infos):
        ret = {}
        for idx, dist, info in zip(itertools.count(), self.dists, infos):
            for k in dist.dist_info_keys:
                ret[f'id_{idx}_{k}'] = info[k]
        return ret

    def split_var(self, x):
        cum = np.cumsum([d.dim for d in self.dists])
        outs, start = [], 0
        for end in cum:
            outs.append(x[:, start:end])
            start = end
        return outs

    def split_dist_flat(self, flat):
        cum = np.cumsum([d.dist_flat_dim for d in self.dists])
        outs, start = [], 0
        for end in cum:
            outs.append(flat[:, start:end])
            start = end
        return outs

    # --- sampling ----------------------------------------------------------
    def sample(self, dist_info):
        parts = [tf.cast(d.sample(i), tf.float32) for d, i in zip(self.dists, self.split_dist_info(dist_info))]
        return tf.concat(parts, axis=1)

    def sample_prior(self, batch_size):
        parts = [tf.cast(d.sample_prior(batch_size), tf.float32) for d in self.dists]
        return tf.concat(parts, axis=1)

    # --- likelihood / entropy / etc. --------------------------------------
    def logli(self, x_var, dist_info):
        return tf.add_n([d.logli(xi, di) for d, xi, di in zip(self.dists, self.split_var(x_var), self.split_dist_info(dist_info))])

    def marginal_logli(self, x_var, dist_info):
        return tf.add_n([d.marginal_logli(xi, di) for d, xi, di in zip(self.dists, self.split_var(x_var), self.split_dist_info(dist_info))])

    def entropy(self, dist_info):
        return tf.add_n([d.entropy(di) for d, di in zip(self.dists, self.split_dist_info(dist_info))])

    def marginal_entropy(self, dist_info):
        return tf.add_n([d.marginal_entropy(di) for d, di in zip(self.dists, self.split_dist_info(dist_info))])

    def nonreparam_logli(self, x_var, dist_info):
        return tf.add_n([d.nonreparam_logli(xi, di) for d, xi, di in zip(self.dists, self.split_var(x_var), self.split_dist_info(dist_info))])

    def kl(self, p, q):
        return tf.add_n([d.kl(pi, qi) for d, pi, qi in zip(self.dists, self.split_dist_info(p), self.split_dist_info(q))])

    def activate_dist(self, flat):
        ret = {}
        for idx, d, f in zip(itertools.count(), self.dists, self.split_dist_flat(flat)):
            info = d.activate_dist(f)
            for k, v in info.items():
                ret[f'id_{idx}_{k}'] = v
        return ret

    def prior_dist_info(self, batch_size):
        infos = [d.prior_dist_info(batch_size) for d in self.dists]
        return self.join_dist_infos(infos)

# ----------------------------------------------------------------------------
# Modelos Keras --------------------------------------------------------------
# ----------------------------------------------------------------------------

#Gerador DCGAN-64: projeção → reshape → 4 transposed‐convs com batch norm + ReLU, final em tanh.
def build_generator(z_dim: int, img_shape):
    """Gerador DCGAN‑64 clássico (4× upsampling → 64×64)."""
    h, w, c = img_shape  # h==w==64

    inp = layers.Input(shape=(z_dim,))

    # 1) projeção + reshape → 4×4×512
    x = layers.Dense(4 * 4 * 512, use_bias=False)(inp)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((4, 4, 512))(x)

    # 2) 8×8×256
    x = layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 3) 16×16×128
    x = layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 4) 32×32×64
    x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 5) 64×64×c
    x = layers.Conv2DTranspose(c, kernel_size=4, strides=2, padding='same', activation='tanh')(x)

    return Model(inp, x, name='Generator')


def build_discriminator_q(img_shape, cat_dim, cont_dim):
    inp = layers.Input(shape=img_shape)
    x = layers.Conv2D(64, 4, 2, 'same')(inp)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Conv2D(128, 4, 2, 'same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(1024)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    # discriminator
    d_out = layers.Dense(1, activation='sigmoid', name='d_out')(x)

    # Q‐network: duas saídas
    q_cat_logits = layers.Dense(cat_dim, name='q_cat_logits')(x)
    q_cont_params = layers.Dense(cont_dim * 2, name='q_cont_params')(x)

    return Model(inp, [d_out, q_cat_logits, q_cont_params], name='Discriminator_Q')


# ----------------------------------------------------------------------------
# Métricas

_inception = tf.keras.applications.InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))

# ------------------------- FID helpers (stream‑safe) -------------------------
FID_BATCH = 256  # nº de imagens por forward pass no Inception (evita OOM)


def _get_inception_activations(img_uint8, bs: int = FID_BATCH):
    """Extrai ativações do pool-3 da Inception em minibatches para não estourar RAM."""
    acts = []
    for i in range(0, len(img_uint8), bs):
        batch = img_uint8[i:i + bs]
        batch = tf.image.resize(batch, (299, 299))
        batch = tf.keras.applications.inception_v3.preprocess_input(tf.cast(batch, tf.float32))
        acts.append(_inception(batch, training=False))
    # Retorna só o array numpy resultante
    return tf.concat(acts, axis=0).numpy()


def fid_np(real_uint8, gen_uint8):
    act1, act2 = _get_inception_activations(real_uint8), _get_inception_activations(gen_uint8)
    mu1, mu2 = act1.mean(0), act2.mean(0)
    sigma1, sigma2 = np.cov(act1, rowvar=False), np.cov(act2, rowvar=False)
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2 * covmean))

# --- HSIC / MIG / SAP / quase-ortogonal -----------------------------------

def _hsic(K, L):
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    HKH, HLH = H @ K @ H, H @ L @ H
    return np.trace(HKH @ HLH) / ((n - 1) ** 2)

def metric_hsic(z: np.ndarray, attr: np.ndarray):
    n, d = z.shape
    hsic_vals = []
    for k in range(d):
        zk = z[:, k:k + 1]
        K = np.exp(-squareform(pdist(zk, 'sqeuclidean')) / (np.median(zk) ** 2 + 1e-8))
        for j in range(attr.shape[1]):
            aj = attr[:, j:j + 1]
            L = (aj == aj.T).astype(np.float32)
            hsic_vals.append(_hsic(K, L))
    return float(np.mean(hsic_vals))

def metric_mutual_info(z, attr):
    n_lat = z.shape[1]
    mi = np.zeros((n_lat, attr.shape[1]))
    for i in range(n_lat):
        zi_disc = np.digitize(z[:, i], np.histogram(z[:, i], bins=20)[1][:-1])
        for j in range(attr.shape[1]):
            mi[i, j] = mutual_info_score(zi_disc, attr[:, j])
    return mi

def metric_mig(z, attr):
    mi = metric_mutual_info(z, attr)
    entropy_attr = np.array([mutual_info_score(attr[:, j], attr[:, j]) for j in range(attr.shape[1])])
    sorted_mi = -np.sort(-mi, axis=0)
    gap = (sorted_mi[0] - sorted_mi[1]) / (entropy_attr + 1e-12)
    return float(np.mean(gap))

def metric_sap(z, attr):
    mi = metric_mutual_info(z, attr)
    sorted_mi = -np.sort(-mi, axis=0)
    return float(np.mean(sorted_mi[0] - sorted_mi[1]))

def quasi_orthogonality(z):
    zc = z - z.mean(0)
    cov = np.cov(zc, rowvar=False)
    off = cov - np.diag(np.diag(cov))
    max_abs = np.max(np.abs(off))
    return float(max_abs < 1e-5), float(max_abs)

# ----------------------------------------------------------------------------
# InfoGAN Trainer 

class InfoGANTrainer:
    def __init__(self, G, DQ, latent_dist: Product, dataset: CelebAWithAttr, batch_size=64,
                 info_coeff=1.0, log_dir='logs', ckpt_dir='ckpt', snapshot=1000, max_iter=100_000,
                 noise_dim=62, cat_dim=10, cont_dim=2):
        self.G, self.DQ = G, DQ
        self.latent_dist, self.dataset = latent_dist, dataset
        self.bs = batch_size
        self.info_coeff = info_coeff
        self.snapshot = snapshot
        self.max_iter = max_iter
        self.noise_dim = noise_dim
        self.cat_dim   = cat_dim
        self.cont_dim  = cont_dim

        self.log_dir, self.ckpt_dir = Path(log_dir), Path(ckpt_dir)
        mkdir_p(self.log_dir); mkdir_p(self.ckpt_dir)

        self.d_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        # adicionamos um otimizador separado para a cabeça Q
        self.q_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        # filtramos os pesos da Q-head pelo nome
        self.q_vars = [
            v for v in self.DQ.trainable_variables
            if 'q_cat_logits' in v.name or 'q_cont_params' in v.name
        ]

        self.metric_hist = {'iter': [], 'FID': [], 'HSIC': [], 'MIG': [], 'SAP': [], 'OrthoMax': []} if not self.log_dir.is_dir() else pd.DataFrame(self.log_dir / 'info_gan_metrics.csv').to_dict(orient='list')
        self.metric_path = self.log_dir / 'info_gan_metrics.csv'

    # ----------------------------------------------------------------------
    def _mi_loss(self, z_reg, q_cat_logits, q_cont_params):
        # separa z_reg em z_cat (one‐hot) e z_cont (gaussiano)
        z_cat  = z_reg[:, :self.cat_dim]
        z_cont = z_reg[:, self.cat_dim:]

        # --- 1) Cross‐entropy (categórico) -----------------------------
        cat_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=z_cat,
                logits=q_cat_logits
            )
        )

        # --- 2) Log‐likelihood gaussiano (contínuo) ------------------
        # previsão de mean e log‐std
        mean_pred    = q_cont_params[:, :self.cont_dim]
        log_std_pred = q_cont_params[:, self.cont_dim:]
        std_pred     = tf.exp(log_std_pred)

        # log‐likelihood de cada dimensão
        eps = (z_cont - mean_pred) / (std_pred + TINY)
        # log N(z;μ,σ) = −½·(log(2π) + 2·logσ + ε²)
        logli_per_dim = -0.5 * (tf.math.log(2. * np.pi) + 2. * log_std_pred + tf.square(eps))
        logli = tf.reduce_sum(logli_per_dim, axis=1)  # soma sobre dims

        cont_loss = -tf.reduce_mean(logli)  # NLL

        return cat_loss + cont_loss


    # ----------------------------------------------------------------------
    @tf.function
    def _train_step(self, real_imgs):
        # 1) amostra z
        z = self.latent_dist.sample_prior(self.bs)

        
        # Atualiza apenas o Discriminador (sem MI)
        with tf.GradientTape() as d_tape:
            fake = self.G(z, training=True)
            d_real, _, _ = self.DQ(real_imgs, training=True)
            d_fake, _, _ = self.DQ(fake,    training=True)
            d_loss = -tf.reduce_mean(
                tf.math.log(d_real + TINY) +
                tf.math.log(1 - d_fake + TINY)
            )
        d_grads = d_tape.gradient(d_loss, self.DQ.trainable_variables)
        self.d_opt.apply_gradients(zip(d_grads, self.DQ.trainable_variables))

        
        # Atualiza o Gerador (adversarial + MI)
        
        with tf.GradientTape() as g_tape:
            fake = self.G(z, training=True)
            d_fake, q_cat_logits, q_cont_params = self.DQ(fake, training=True)

            # adversarial
            g_adv_loss = -tf.reduce_mean(tf.math.log(d_fake + TINY))

            # MI
            z_reg = z[:, self.noise_dim:]
            mi_loss = self._mi_loss(z_reg, q_cat_logits, q_cont_params)

            g_loss = g_adv_loss + self.info_coeff * mi_loss

        g_grads = g_tape.gradient(g_loss, self.G.trainable_variables)
        self.g_opt.apply_gradients(zip(g_grads, self.G.trainable_variables))

         
        # Atualiza somente a cabeça Q (MI)
        
        with tf.GradientTape() as q_tape:
            fake = self.G(z, training=False)  # G congelado aqui
            _, q_cat_logits, q_cont_params = self.DQ(fake, training=True)
            z_reg = z[:, self.noise_dim:]
            mi_loss = self._mi_loss(z_reg, q_cat_logits, q_cont_params)

        q_grads = q_tape.gradient(mi_loss, self.q_vars)
        self.q_opt.apply_gradients(zip(q_grads, self.q_vars))

        return d_loss, g_adv_loss, mi_loss



    # ----------------------------------------------------------------------
    def _evaluate_metrics(self):
        #Calcula FID (±1 k amostras) + métricas de disentanglement sem estourar memória."""
        N_EVAL = 1000  # amostras para FID / HSIC / etc.
        real_flat, real_attr = self.dataset.next_batch(N_EVAL)
        z = self.latent_dist.sample_prior(N_EVAL)
        gen_imgs = self.G(z, training=False).numpy()

        fid = fid_np(
                self.dataset.inverse_transform(real_flat),
                self.dataset.inverse_transform(gen_imgs.reshape(real_flat.shape))
            )

        z_np = z.numpy()
        hsic = metric_hsic(z_np, real_attr)
        mig = metric_mig(z_np, real_attr)
        sap = metric_sap(z_np, real_attr)
        _, omax = quasi_orthogonality(z_np)
        return fid, hsic, mig, sap, omax

    # ----------------------------------------------------------------------
    def train(self):
        for step in tqdm(range(1, self.max_iter + 1)):
            real_flat, _ = self.dataset.next_batch(self.bs)
            real_imgs = real_flat.reshape((-1,) + self.dataset.image_shape)
            d_loss, g_loss, mi = self._train_step(real_imgs)

            if step % self.snapshot == 0:
                fid, hsic, mig, sap, omax = self._evaluate_metrics()
                self.metric_hist['iter'].append(step)
                self.metric_hist['FID'].append(fid)
                self.metric_hist['HSIC'].append(hsic)
                self.metric_hist['MIG'].append(mig)
                self.metric_hist['SAP'].append(sap)
                self.metric_hist['OrthoMax'].append(omax)
                pd.DataFrame(self.metric_hist).to_csv(self.metric_path, index=False)
                ckpt_path = str(self.ckpt_dir / f'ckpt_{step}')
                self.ckpt.save(ckpt_path)
                print(f"Step {step}: FID {fid:.1f} | HSIC {hsic:.4f} | MIG {mig:.4f} | SAP {sap:.4f} | max|off‑diag| {omax:.3e}")
    
    def generate_samples(self,
                         n_samples: int = 8,
                         output_dir: str = 'samples',
                         prefix: str = 'sample'):
        """
        Gera n amostras usando o G treinado e salva como PNG em output_dir.
        Args:
            n_samples: número de imagens a gerar (padrão 8).
            output_dir: pasta onde os arquivos serão salvos.
            prefix: prefixo do nome de cada arquivo (ex: 'sample_0.png').
        """
        # garante que a pasta existe
        mkdir_p(output_dir)

        # amostra do espaço latente
        z = self.latent_dist.sample_prior(n_samples)
        # gera imagens
        gen = self.G(z, training=False).numpy()
        # converte do formato flat [-1,1] para uint8 [0,255]
        imgs_uint8 = self.dataset.inverse_transform(gen.reshape(n_samples, -1))

        # salva cada imagem
        for i, img in enumerate(imgs_uint8):
            path = os.path.join(output_dir, f"{prefix}_{i}.png")
            Image.fromarray(img).save(path)

        print(f"Geradas e salvas {n_samples} amostras em '{output_dir}/'.")


In [None]:
BATCH = 64
IMG_SHAPE = (64, 64, 3)
    
#Parametros para latente (dimensão do vetor \Re^(noise_dim + cat_dim + cont_dim))
noise_dim, cat_dim, cont_dim = 62, 10, 2
    
# O espaço latente modelado pelo objeto Product que agrupa três distribuições:
latent_dist = Product([Gaussian(noise_dim, fix_std=True),Categorical(cat_dim),Gaussian(cont_dim)])

G = build_generator(noise_dim + cat_dim + cont_dim, IMG_SHAPE)
DQ = build_discriminator_q(IMG_SHAPE, cat_dim, cont_dim)
data = CelebAWithAttr('.', IMG_SHAPE)

trainer = InfoGANTrainer(G, DQ, latent_dist, data, batch_size=BATCH, max_iter=300000, snapshot=1000, 
                             noise_dim=noise_dim,cat_dim=cat_dim, cont_dim=cont_dim)
trainer.train()

IndexError: list index out of range