# AutoStyleGAN

## General Imports

In [1]:
!pip install munch

Collecting munch
  Downloading munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Downloading munch-4.0.0-py2.py3-none-any.whl (9.9 kB)
Installing collected packages: munch
Successfully installed munch-4.0.0


Munch is a small Python library that converts dictionaries (`dict`) into objects whose keys can be read and written as attributes.

In [2]:
import os
import math
import copy
import argparse
import random
import time
import numpy as np
from munch import Munch
from collections import OrderedDict
from scipy import linalg
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import models
from torchvision.utils import save_image, make_grid

import cv2
from PIL import Image

## Utilities and Helper Functions

In [3]:
def he_init(module):
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

The function `he_init(module)` **initializes** the weights of a PyTorch layer with **He initialization** (well suited to ReLU-type activations) when the module is a `Conv2d` or `Linear` layer; if the layer has a bias, it sets the bias to **zero**.  
Its purpose is to give the layers well-scaled starting weights so that training is **faster and more stable**.

**He initialization** draws the initial weights from a zero-mean distribution whose variance is scaled by the number of inputs to each unit (`fan_in`). For the normal variant that `kaiming_normal_` uses with its default arguments, the standard deviation is `sqrt(2 / fan_in)`. It is designed for deep networks with **ReLU**-type activations.

It keeps the **variance of the activations roughly constant** from layer to layer, which prevents gradients from vanishing or exploding and thus makes learning faster and more stable.
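A small pure-Python experiment (an illustration only, not part of the notebook) makes the variance argument concrete: it pushes a random input through a stack of bias-free fully connected ReLU layers and compares He-scaled weights (`std = sqrt(2 / fan_in)`) with weights scaled by `sqrt(1 / fan_in)`. The helper name `relu_stack_second_moment` is hypothetical.

```python
import math
import random


def relu_stack_second_moment(num_layers, width, weight_std, seed=0):
    """Mean of x**2 after `num_layers` random Linear (no bias) + ReLU layers."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    for _ in range(num_layers):
        # Fresh zero-mean Gaussian weight matrix for every layer
        w = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
        pre = [sum(wij * xj for wij, xj in zip(row, x)) for row in w]
        x = [max(0.0, v) for v in pre]  # ReLU
    return sum(v * v for v in x) / width


width, depth = 200, 10
he = relu_stack_second_moment(depth, width, math.sqrt(2.0 / width))
small = relu_stack_second_moment(depth, width, math.sqrt(1.0 / width))
print(f"He init:  {he:.3f}")     # stays on the order of 1
print(f"1/fan_in: {small:.5f}")  # shrinks by roughly half per layer
```

The ReLU zeroes about half of the pre-activations, which halves their second moment; the factor 2 in He's variance compensates for exactly that loss.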

In [4]:
def denormalize(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)

The function converts a tensor normalized to `[-1, 1]` back to `[0, 1]` and clamps it, so every value falls in the valid range for display or saving.
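The data pipeline's `Normalize(mean=0.5, std=0.5)` maps a pixel value `p` in `[0, 1]` to `(p - 0.5) / 0.5 = 2p - 1`, and `denormalize` inverts this with `(x + 1) / 2`. A scalar sketch in plain Python (the function names are illustrative, not from the notebook):

```python
def normalize(p):
    # Same arithmetic as transforms.Normalize(mean=0.5, std=0.5): [0, 1] -> [-1, 1]
    return (p - 0.5) / 0.5


def denormalize_scalar(x):
    # Scalar version of denormalize(): [-1, 1] -> [0, 1], then clamp
    return min(max((x + 1) / 2, 0.0), 1.0)


for p in (0.0, 0.25, 1.0):
    print(p, normalize(p), denormalize_scalar(normalize(p)))

# The generator ends in a plain convolution, so its outputs can leave [-1, 1];
# the clamp maps them back into the displayable range.
print(denormalize_scalar(1.4))  # 1.0
```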

In [5]:
def label2onehot(labels, dim):
    batch_size = labels.size(0)
    out = torch.zeros(batch_size, dim).to(labels.device)
    out.scatter_(1, labels.unsqueeze(1), 1)
    return out

The function `label2onehot(labels, dim)` converts a vector of integer labels into its **one-hot** representation in PyTorch, creating the output on the same device (CPU or GPU) as the labels.
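For integer labels, `label2onehot` should agree with PyTorch's built-in `torch.nn.functional.one_hot`, which returns an integer tensor that can be cast to float. A quick check, with the function re-declared so the snippet runs on its own:

```python
import torch
import torch.nn.functional as F


def label2onehot(labels, dim):
    # Same as the notebook: write a 1 at column labels[i] of row i
    out = torch.zeros(labels.size(0), dim).to(labels.device)
    out.scatter_(1, labels.unsqueeze(1), 1)
    return out


labels = torch.tensor([0, 2, 1])
manual = label2onehot(labels, 3)
builtin = F.one_hot(labels, num_classes=3).float()
print(manual)
print(torch.equal(manual, builtin))  # True
```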


In [6]:
def resize(x, size):
    return F.interpolate(x, size=size, mode='bilinear', align_corners=True)

The function `resize(x, size)` resizes an image tensor with bilinear interpolation; `align_corners=True` keeps the corner pixels of the input and output grids aligned.

In [7]:
def make_image_grid(x, nrow=8, padding=2):
    x = denormalize(x)
    grid = make_grid(x, nrow=nrow, padding=padding)
    return grid

The function `make_image_grid(x, nrow=8, padding=2)` denormalizes a batch of images and arranges them in a grid (mosaic) for easy viewing or saving.


In [8]:
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

The function `accumulate(model1, model2, decay)` updates each parameter of `model1` as an **exponential moving average**, `decay * model1 + (1 - decay) * model2`, so that `model1` follows `model2` smoothly and is more stable.
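The checkpoints are named `*_nets_ema.ckpt`, but `Solver` never calls `accumulate`. A common pattern, sketched here as an assumption rather than what this notebook does, keeps a frozen copy of the generator and pulls it toward the trained one after each optimizer step. `G_ema` and the toy `nn.Linear` stand-in are hypothetical:

```python
import copy

import torch
import torch.nn as nn


def accumulate(model1, model2, decay=0.999):
    # model1 <- decay * model1 + (1 - decay) * model2, parameter by parameter
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)


G = nn.Linear(4, 4)              # stand-in for the trained generator
G_ema = copy.deepcopy(G).eval()  # averaged copy, never trained directly
for p in G_ema.parameters():
    p.requires_grad_(False)
w_old = G_ema.weight.detach().clone()

with torch.no_grad():
    G.weight.add_(1.0)           # pretend an optimizer step changed G

accumulate(G_ema, G, decay=0.9)
# Each averaged weight moved 10% of the way toward the new weight
print(torch.allclose(G_ema.weight, w_old + 0.1, atol=1e-6))  # True
```

In such a setup the averaged copy is the one used for sampling and evaluation, while the optimizer only ever updates `G`.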

In [9]:
class MovingAverage:
    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema = None

    def update(self, x):
        if self.ema is None:
            self.ema = x
        else:
            self.ema = self.decay * self.ema + (1 - self.decay) * x
        return self.ema

The `MovingAverage` class computes an **exponential moving average**: the first value initializes the average, and each later value is blended in as `ema = decay * ema + (1 - decay) * x`, so the average follows trends in the data smoothly, with the amount of smoothing controlled by `decay`.
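Starting from a first value `x0` and then feeding a constant `c`, the average after `n` further updates is `c + (x0 - c) * decay**n`. So `decay` sets how long old values are remembered: with `decay = 0.999`, about `ln 2 / (1 - 0.999) ≈ 693` updates close half of the gap. A pure-Python check, with the class re-declared unchanged:

```python
import math


class MovingAverage:
    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema = None

    def update(self, x):
        # The first value initializes the average; later values are blended in
        if self.ema is None:
            self.ema = x
        else:
            self.ema = self.decay * self.ema + (1 - self.decay) * x
        return self.ema


avg = MovingAverage(decay=0.999)
avg.update(0.0)                # x0 = 0
for _ in range(693):
    value = avg.update(1.0)    # constant input c = 1
print(round(value, 3))         # 0.5: half of the gap has been closed
half_life = math.log(0.5) / math.log(0.999)
print(round(half_life))        # 693
```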

## Model Definitions

In [10]:
class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
        super().__init__()
        self.normalize = normalize
        self.downsample = downsample
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, 1),
            nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity(),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, 3, 1, 1),
            nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity()
        )
        self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
        self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def forward(self, x):
        out = self.main(x)
        out = self.downsample_layer(out)
        skip = self.skip(x)
        skip = self.downsample_layer(skip)
        return (out + skip) / math.sqrt(2)

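`ResBlk` is a residual block: it processes its input with two convolutions (optionally followed by instance normalization), optionally halves the spatial resolution with average pooling, and adds the result to a 1×1-convolved copy of the input (the skip path). Dividing the sum by `math.sqrt(2)` keeps the output on roughly the same scale as its two branches. The skip connection **improves information flow** and **mitigates vanishing gradients** in deep networks.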

In [11]:
class AdaIN(nn.Module):
    def __init__(self, num_features, style_dim):
        super(AdaIN, self).__init__()
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        h = self.fc(s)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)
        return (1 + gamma) * x + beta

The `AdaIN` class modulates the feature map `x` **channel by channel** with a scale (`1 + gamma`) and a shift (`beta`) that a linear layer computes from a **style vector** `s`, which lets the **style** of the image change dynamically.
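In the standard AdaIN formulation, `x` is first instance-normalized (each channel of each sample is shifted to zero mean and unit variance) and only then scaled and shifted by the style; the class above applies the modulation to `x` directly. A minimal sketch of the normalized variant for comparison (the class name `AdaINNormalized` is hypothetical):

```python
import torch
import torch.nn as nn


class AdaINNormalized(nn.Module):
    def __init__(self, num_features, style_dim):
        super().__init__()
        # Per-sample, per-channel normalization with no learned affine parameters
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        gamma, beta = torch.chunk(self.fc(s), chunks=2, dim=1)
        gamma = gamma[:, :, None, None]  # (B, C) -> (B, C, 1, 1)
        beta = beta[:, :, None, None]
        return (1 + gamma) * self.norm(x) + beta


torch.manual_seed(0)
layer = AdaINNormalized(num_features=8, style_dim=64)
x = torch.randn(2, 8, 16, 16) * 5 + 3  # arbitrary channel statistics
s = torch.randn(2, 64)
out = layer(x, s)
print(out.shape)  # torch.Size([2, 8, 16, 16])

# With a zeroed style projection, gamma = beta = 0 and only the normalization remains
nn.init.zeros_(layer.fc.weight)
nn.init.zeros_(layer.fc.bias)
plain = layer(x, s)
print(plain.mean(dim=(2, 3)).abs().max())  # close to 0
```

Because `InstanceNorm2d(affine=False)` has no parameters, adding it would not change the set of weights stored in a `state_dict`, although it does change what the network computes.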


In [12]:
class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.w_hpf = w_hpf

        self.norm1 = AdaIN(dim_in, style_dim)
        self.norm2 = AdaIN(dim_out, style_dim)
        self.actv = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)

        if dim_in != dim_out:
            self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
        else:
            self.skip = nn.Identity()

    def forward(self, x, s):
        x_orig = x

        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            x_orig = F.interpolate(x_orig, scale_factor=2, mode='nearest')

        h = self.norm1(x, s)
        h = self.actv(h)
        h = self.conv1(h)

        h = self.norm2(h, s)
        h = self.actv(h)
        h = self.conv2(h)

        skip = self.skip(x_orig)

        out = (h + skip) / math.sqrt(2)
        return out

`AdainResBlk` is a **residual block** that applies convolutions, injects the style through **AdaIN** conditioned on the style vector, and can **double the resolution** (nearest-neighbor upsampling) when needed, keeping the information stable through a **residual sum**.


In [13]:
class Generator(nn.Module):
    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
        super().__init__()
        dim_in = 64
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 4
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
            dim_in = dim_out
        self.encode = nn.Sequential(*blocks)

        self.decode = nn.ModuleList()
        for _ in range(repeat_num):
            dim_out = dim_in // 2
            self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
            dim_in = dim_out
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_in, 3, 1, 1, 0)
        )

    def forward(self, x, s):
        x = self.encode(x)
        for block in self.decode:
            x = block(x, s)
        out = self.to_rgb(x)
        return out

The `Generator` takes an image and a style vector, **encodes** the image into a lower-resolution feature map, **injects** the style while progressively **increasing the resolution**, and finally produces a **stylized image** in RGB format.
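The number of encoder and decoder blocks follows from the image size: `repeat_num = log2(img_size) - 4`. Each encoder `ResBlk` halves the resolution and doubles the channels (capped at `max_conv_dim`), and each decoder `AdainResBlk` doubles the resolution and halves the channels. The hypothetical helper below only mirrors the constructor's arithmetic, listing the feature-map shapes for the configuration used later (`img_size=128`):

```python
import math


def generator_shapes(img_size=256, max_conv_dim=512):
    """(stage, channels, height) after each stage of Generator, mirroring __init__."""
    dim_in, res = 64, img_size
    shapes = [("stem conv", dim_in, res)]
    repeat_num = int(math.log2(img_size)) - 4
    for i in range(repeat_num):  # encoder: ResBlk(downsample=True)
        dim_out = min(dim_in * 2, max_conv_dim)
        res //= 2
        shapes.append((f"encode {i}", dim_out, res))
        dim_in = dim_out
    for i in range(repeat_num):  # decoder: AdainResBlk(upsample=True)
        dim_out = dim_in // 2
        res *= 2
        shapes.append((f"decode {i}", dim_out, res))
        dim_in = dim_out
    shapes.append(("to_rgb", 3, res))
    return shapes


for name, channels, res in generator_shapes(128):
    print(f"{name:10s} {channels:4d} x {res} x {res}")
```

With `img_size=128` the bottleneck is a 512 × 16 × 16 map. With `img_size=256` the 512-channel cap keeps the last encoder stage at 512 channels, so the decoder, which always halves, ends at 32 channels before `to_rgb`.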



In [14]:
class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=16, style_dim=64, num_domains=2, hidden_dim=512):
        super(MappingNetwork, self).__init__()
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU()
        ]
        for _ in range(3):
            layers += [
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU()
            ]
        self.shared = nn.Sequential(*layers)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared.append(nn.Linear(hidden_dim, style_dim))

    def forward(self, z, y):
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out.append(layer(h))
        out = torch.stack(out, dim=1)
        idx = torch.arange(y.size(0)).to(y.device)
        s = out[idx, y]
        return s

`MappingNetwork` turns a random latent vector `z` into a **style vector** `s` for the **target domain** `y`: a shared MLP does the general processing, one separate linear head per domain specializes the output, and for each sample the head matching its domain label is selected.
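The last lines of `forward` pick one head per sample with advanced indexing: the stacked `out` has shape `(batch, num_domains, style_dim)`, and `out[idx, y]` takes row `y[i]` from sample `i`. A small check against an explicit loop, using made-up tensors:

```python
import torch

batch, num_domains, style_dim = 4, 3, 5
out = torch.randn(batch, num_domains, style_dim)  # stacked outputs of all heads
y = torch.tensor([2, 0, 1, 2])                      # target domain of each sample

idx = torch.arange(batch)
s = out[idx, y]                                     # (batch, style_dim)

# Equivalent explicit loop: sample i uses the head of domain y[i]
s_loop = torch.stack([out[i, y[i]] for i in range(batch)])
print(s.shape, torch.equal(s, s_loop))  # torch.Size([4, 5]) True
```

Every head runs on every sample and the unused outputs are discarded; this is cheap here because each head is a single linear layer. `StyleEncoder` and `Discriminator` use the same selection trick.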


In [15]:
class StyleEncoder(nn.Module):
    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 64
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2

        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Linear(dim_in, style_dim)]

    def forward(self, x, y):
        h = self.shared(x)
        h = F.adaptive_avg_pool2d(h, (1,1))
        h = h.view(h.size(0), -1)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)
        idx = torch.arange(y.size(0)).to(y.device)
        s = out[idx, y]
        return s


`StyleEncoder` turns an image `x` into a **style vector** `s` for the **specified domain** `y`, using a deep convolutional encoder followed by global average pooling and one linear head per domain.

In [16]:
class Discriminator(nn.Module):
    def __init__(self, img_size=256, num_domains=3, max_conv_dim=512):
        super().__init__()
        dim_in = 64
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, normalize=False, downsample=True)]
            dim_in = dim_out
        self.main = nn.Sequential(*blocks)
        self.conv1 = nn.Conv2d(dim_out, num_domains, 1, 1, 0)

    def forward(self, x, y):
        h = self.main(x)
        out = self.conv1(h)
        out = out.mean([2, 3])
        idx = torch.arange(y.size(0)).to(y.device)
        out = out[idx, y]
        return out

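The `Discriminator` takes an image `x` and a domain `y`, extracts deep features, and returns one logit per image scoring **how real** the image looks **for that specific domain** (a 1×1 convolution produces one channel per domain, which is averaged over space before the logit of domain `y` is selected).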
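`StyleEncoder` and `Discriminator` share the same downsampling trunk with `repeat_num = log2(img_size) - 2` blocks, so a 128 × 128 input ends as a 4 × 4 map with at most `max_conv_dim` channels. A pure-Python sketch of that arithmetic (the helper is hypothetical):

```python
import math


def trunk_shapes(img_size=256, max_conv_dim=512):
    """(channels, height) after each downsampling ResBlk of StyleEncoder/Discriminator."""
    dim_in, res = 64, img_size
    shapes = []
    for _ in range(int(math.log2(img_size)) - 2):
        dim_in = min(dim_in * 2, max_conv_dim)
        res //= 2
        shapes.append((dim_in, res))
    return shapes


print(trunk_shapes(128))
# [(128, 64), (256, 32), (512, 16), (512, 8), (512, 4)]
```

For a batch, the Discriminator output is therefore `(batch, num_domains, 4, 4)` after `conv1`, `(batch, num_domains)` after `mean([2, 3])`, and `(batch,)` after selecting the logit of domain `y`.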


## Loss Functions

In [17]:
def adversarial_loss(logits, target):
    targets = torch.full_like(logits, fill_value=target)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss

The `adversarial_loss` function computes the **binary cross-entropy** between the prediction logits and a constant target (`target` = 1 for "real", 0 for "fake").
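`binary_cross_entropy_with_logits` applies the sigmoid internally in a numerically stable form, `max(l, 0) - l*t + log(1 + exp(-|l|))` for a logit `l` and target `t`, and averages over elements by default. A scalar sketch in plain Python (the function names are illustrative):

```python
import math


def bce_with_logits(logit, target):
    # Stable form of -[t*log(sigmoid(l)) + (1 - t)*log(1 - sigmoid(l))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce(logits, target):
    # Mean over the batch, like F.binary_cross_entropy_with_logits by default
    return sum(bce_with_logits(l, target) for l in logits) / len(logits)


print(round(mean_bce([0.0], 1), 4))       # 0.6931 = log 2: the critic is undecided
print(round(mean_bce([5.0, 5.0], 1), 4))  # 0.0067: confidently "real"
print(round(mean_bce([1000.0], 0), 1))    # 1000.0, no overflow
```

In `_compute_d_loss` the real logits use target 1 and the fake logits target 0; in `_compute_g_loss` the fake logits use target 1, the non-saturating generator loss.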


In [18]:
def r1_reg(d_out, x_in):
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(),
        inputs=x_in,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg

`r1_reg` computes the **R1 penalty**: the squared norm of the Discriminator's gradient with respect to the real images, summed over pixels and averaged over the batch. It measures **how sensitive the Discriminator is** to small perturbations of real images and **penalizes** large gradients, which makes training more **stable and robust**.
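For a linear critic `D(x) = sum(w * x)` the gradient with respect to every image is `w`, so the penalty must equal `sum(w**2)` exactly, which makes a handy sanity check. `r1_reg` is re-declared below with the same logic (minus the `only_inputs` argument, which recent PyTorch versions ignore); the toy critic is hypothetical:

```python
import torch


def r1_reg(d_out, x_in):
    # Squared L2 norm of dD/dx per sample, averaged over the batch
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in, create_graph=True, retain_graph=True
    )[0]
    return grad_dout.pow(2).view(batch_size, -1).sum(1).mean(0)


torch.manual_seed(0)
w = torch.randn(3, 4, 4)
x = torch.randn(2, 3, 4, 4, requires_grad=True)  # like x_real.requires_grad_()
d_out = (x * w).sum(dim=(1, 2, 3))               # one logit per image
reg = r1_reg(d_out, x)
print(reg.item(), w.pow(2).sum().item())          # the two numbers match
```

Summing the logits before differentiating is valid because each logit depends only on its own image, so the gradient of the sum contains every sample's individual gradient.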

## Data Loader

In [19]:
class ImageFolder(Dataset):
    def __init__(self, root, transform, mode, which='source'):
        self.transform = transform
        self.paths = []
        # Only subdirectories are domains; each label is the position of its
        # folder in sorted order (stray files in `root` no longer shift the labels)
        domains = sorted(d for d in os.listdir(root)
                         if os.path.isdir(os.path.join(root, d)))
        for label, domain in enumerate(domains):
            files = os.listdir(os.path.join(root, domain))
            self.paths += [(os.path.join(root, domain, f), label) for f in files]
        if mode == 'train' and which == 'reference':
            random.shuffle(self.paths)

    def __getitem__(self, index):
        path, label = self.paths[index]
        img = Image.open(path).convert('RGB')
        return self.transform(img), label

    def __len__(self):
        return len(self.paths)

The `ImageFolder` class loads images from folders organized by domain (one subfolder per domain), applies the transformations, and returns (transformed image, domain label) pairs.
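`ImageFolder` expects one subfolder per domain under `root`, and each label is the position of its folder in sorted order (assuming `root` contains only the domain folders). The sketch below builds a throwaway layout with `tempfile` and reproduces that labeling with the standard library; the folder names follow the `# BMW, Corvette, Mazda` comment in the configuration, and the file names are made up:

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Hypothetical dataset layout: root/<domain>/<image>
    layout = {"BMW": ["a.jpg", "b.jpg"], "Corvette": ["c.jpg"], "Mazda": ["d.jpg"]}
    for domain, files in layout.items():
        os.makedirs(os.path.join(root, domain))
        for name in files:
            open(os.path.join(root, domain, name), "wb").close()

    # Sorted subfolder names -> label = position in that order
    domains = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    pairs = sorted(
        (name, label)
        for label, domain in enumerate(domains)
        for name in os.listdir(os.path.join(root, domain))
    )

print(domains)  # ['BMW', 'Corvette', 'Mazda']
print(pairs)    # [('a.jpg', 0), ('b.jpg', 0), ('c.jpg', 1), ('d.jpg', 2)]
```

Because labels come from sorted folder names, renaming or adding a folder can change the label of other domains, which would silently mismatch the per-domain heads of an existing checkpoint; `main` only checks the number of folders.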

In [20]:
def get_transform(img_size, mode='train', prob=0.5):
    transform = []

    transform.append(transforms.Resize((img_size, img_size)))

    if mode == 'train':
        transform.append(transforms.RandomHorizontalFlip())
        transform.append(transforms.RandomApply([
            transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0))
        ], p=prob))

    transform.append(transforms.ToTensor())
    transform.append(transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                           std=[0.5, 0.5, 0.5]))
    return transforms.Compose(transform)

`get_transform` builds the **preprocessing and augmentation pipeline** for training or validation: resizing, then (in training mode only) a random horizontal flip and, with probability `prob`, a random resized crop, followed by conversion to a tensor and normalization to `[-1, 1]`.


In [21]:
def get_train_loader(root, which='source', img_size=256, batch_size=8, prob=0.5, num_workers=4):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=prob),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # `which` selects source vs. reference; mode='train' enables the reference shuffle
    dataset = ImageFolder(root=root, transform=transform, mode='train', which=which)
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

    return loader

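`get_train_loader` creates a `DataLoader` that yields shuffled **batches of preprocessed and augmented images** for training. It builds its own pipeline instead of calling `get_transform`: `prob` (set from `randcrop_prob`) is used as the horizontal-flip probability, and no random crop is applied.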


In [22]:
def get_test_loader(root, img_size=256, batch_size=8, shuffle=False, num_workers=4, mode='reference'):
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = ImageFolder(root=root, transform=transform, mode=mode)
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=False)

    return loader

`get_test_loader` creates a `DataLoader` that loads preprocessed images for the **test or evaluation** phase, **without augmentation**, and groups them into batches.


## Evaluation Metrics

In [23]:
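class InceptionV3(nn.Module):
    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13)
        inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1,
                                        transform_input=False)
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1,
            inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b,
            inception.Mixed_5c,
            inception.Mixed_5d,
            inception.Mixed_6a,
            inception.Mixed_6b,
            inception.Mixed_6c,
            inception.Mixed_6d,
            inception.Mixed_6e,
            # Stopping at Mixed_6e leaves a 768x17x17 map (~222k values per image),
            # far too large for a covariance matrix; the last blocks plus global
            # average pooling give the usual 2048-d FID features
            inception.Mixed_7a,
            inception.Mixed_7b,
            inception.Mixed_7c,
            nn.AdaptiveAvgPool2d(output_size=(1, 1))
        )
        self.eval()  # use BatchNorm running statistics for feature extraction

    def forward(self, x):
        # (N, 3, 299, 299) -> (N, 2048)
        return self.blocks(x).view(x.size(0), -1)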

`InceptionV3` extracts **deep feature representations** of images using the convolutional part of a **pretrained Inception v3** network.

In [24]:
def calculate_fid(real_features, fake_features):
    mu1 = np.mean(real_features, axis=0)
    mu2 = np.mean(fake_features, axis=0)
    sigma1 = np.cov(real_features, rowvar=False)
    sigma2 = np.cov(fake_features, rowvar=False)

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2*np.trace(covmean)
    return fid

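The `calculate_fid` function compares the feature distributions of real and generated images by computing the **Fréchet Inception Distance (FID)**: it fits a Gaussian (mean and covariance) to each set of features and returns `||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))`, the distance between the two Gaussians. Lower values mean more similar distributions.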
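For one-dimensional features the formula reduces to `(mu1 - mu2)**2 + var1 + var2 - 2*sqrt(var1*var2)`, i.e. `(mu1 - mu2)**2 + (sd1 - sd2)**2`: the squared gap between the means plus the squared gap between the spreads. A standard-library sketch of that special case, using the sample variance as `np.cov` does by default (`fid_1d` is a hypothetical name):

```python
import math
import statistics


def fid_1d(real, fake):
    """Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu1, mu2 = statistics.mean(real), statistics.mean(fake)
    var1, var2 = statistics.variance(real), statistics.variance(fake)  # ddof=1
    return (mu1 - mu2) ** 2 + var1 + var2 - 2 * math.sqrt(var1 * var2)


print(fid_1d([0.0, 2.0], [1.0, 3.0]))  # 1.0: same spread, means 1 apart
print(fid_1d([0.0, 2.0], [0.0, 4.0]))  # 3.0: means 1 apart, spreads differ too
```

In the full version, `scipy.linalg.sqrtm` plays the role of the square root, and its result can pick up tiny imaginary parts from numerical error, which is why `calculate_fid` keeps only `covmean.real`.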

In [25]:
class LPIPS(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features[:16].eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        return F.l1_loss(x_vgg, y_vgg)

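The `LPIPS` class measures a **perceptual distance** between two images by **comparing internal features** taken from the first layers of a **pretrained VGG16** (up to `relu3_3`) with an **L1 loss**. It is a simplified stand-in for LPIPS: the official metric channel-normalizes the activations of several layers and weights them with learned per-channel weights before averaging the differences.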


## Solver Definition

In [26]:
class Solver(object):
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.G = Generator(args.img_size, args.style_dim)
        self.D = Discriminator(args.img_size, args.num_domains)
        self.M = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains)
        self.S = StyleEncoder(args.img_size, args.style_dim, args.num_domains)

        self.G.to(self.device)
        self.D.to(self.device)
        self.M.to(self.device)
        self.S.to(self.device)

        self.g_optimizer = torch.optim.Adam(
            list(self.G.parameters()) + list(self.M.parameters()) + list(self.S.parameters()),
            args.lr, [args.beta1, args.beta2], weight_decay=args.weight_decay)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(),
            args.lr, [args.beta1, args.beta2], weight_decay=args.weight_decay)

        self.start_iter = 0
        self.inception = InceptionV3().to(self.device)
        self.lpips = LPIPS().to(self.device)

        if self.args.resume_iter > 0:
            ckpt_path = os.path.join(self.args.checkpoint_dir, f'{self.args.resume_iter}_nets_ema.ckpt')
            if os.path.exists(ckpt_path):
                print(f"Cargando checkpoint desde {ckpt_path}...")
                checkpoint = torch.load(ckpt_path, map_location=self.device)

                self.G.load_state_dict(checkpoint['generator'])
                self.M.load_state_dict(checkpoint['mapping_network'])
                self.S.load_state_dict(checkpoint['style_encoder'])
                self.D.load_state_dict(checkpoint['discriminator'])

                self.start_iter = self.args.resume_iter
                print(f"Checkpoint cargado correctamente. Empezando desde iteración {self.start_iter}.")
            else:
                print(f"No se encontró checkpoint en {ckpt_path}. Entrenando desde cero.")
        else:
            print("Entrenando desde cero (resume_iter = 0).")

    def _reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _compute_d_loss(self, x_real, y_org, y_trg, z_trg, x_ref, masks=None):
        x_real.requires_grad_()
        out_real = self.D(x_real, y_org)
        loss_real = adversarial_loss(out_real, 1)
        loss_reg = r1_reg(out_real, x_real)

        with torch.no_grad():
            s_trg = self.M(z_trg, y_trg)
            x_fake = self.G(x_real, s_trg)
        out_fake = self.D(x_fake, y_trg)
        loss_fake = adversarial_loss(out_fake, 0)

        loss = loss_real + loss_fake + self.args.lambda_reg * loss_reg
        return loss

    def _compute_g_loss(self, x_real, y_org, y_trg, z_trgs, x_refs, masks=None):
        s_trg = self.M(z_trgs[0], y_trg)
        x_fake = self.G(x_real, s_trg)
        out_fake = self.D(x_fake, y_trg)
        loss_adv = adversarial_loss(out_fake, 1)

        s_pred = self.S(x_fake, y_trg)
        loss_sty = F.l1_loss(s_pred, s_trg)

        x_rec = self.G(x_fake, self.S(x_real, y_org))
        loss_cyc = F.l1_loss(x_rec, x_real)

        s_trg2 = self.M(z_trgs[1], y_trg)
        x_fake2 = self.G(x_real, s_trg2)
        loss_ds = -F.l1_loss(x_fake, x_fake2)

        loss = loss_adv + self.args.lambda_sty * loss_sty + self.args.lambda_cyc * loss_cyc
        if self.args.lambda_ds > 0:
            loss += self.args.lambda_ds * loss_ds

        return loss

    # Training
    def train(self, loaders):
      src_loader = loaders.src
      ref_loader = loaders.ref
      print("Training started...")


      wandb.init(
          project="AutoStyleGAN",
          name=f"Entrenamiento_2",
          config=vars(self.args)
      )

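      # Persistent iterators: calling next(iter(loader)) on every step would
      # rebuild the iterator (and respawn its worker processes) each iteration.
      src_iter, ref_iter = iter(src_loader), iter(ref_loader)

      for it in range(self.start_iter, self.args.total_iters):
          try:
              x_real, y_org = next(src_iter)
          except StopIteration:  # epoch finished: start a new pass
              src_iter = iter(src_loader)
              x_real, y_org = next(src_iter)
          try:
              x_ref, y_trg = next(ref_iter)
          except StopIteration:
              ref_iter = iter(ref_loader)
              x_ref, y_trg = next(ref_iter)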
          x_real, y_org = x_real.to(self.device), y_org.to(self.device)
          x_ref, y_trg = x_ref.to(self.device), y_trg.to(self.device)

          # Discriminator update
          z_trg = torch.randn(x_real.size(0), self.args.latent_dim).to(self.device)
          z_trgs = [torch.randn(x_real.size(0), self.args.latent_dim).to(self.device) for _ in range(2)]

          d_loss = self._compute_d_loss(x_real, y_org, y_trg, z_trg, x_ref)
          self._reset_grad()
          d_loss.backward()
          self.d_optimizer.step()

          # Generator update (also updates the mapping network and style encoder)
          g_loss = self._compute_g_loss(x_real, y_org, y_trg, z_trgs, x_ref)
          self._reset_grad()
          g_loss.backward()
          self.g_optimizer.step()

          if (it + 1) % self.args.print_every == 0:
              print(f"Iter [{it+1}/{self.args.total_iters}] d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

              wandb.log({
                  "Discriminator Loss": d_loss.item(),
                  "Generator Loss": g_loss.item(),
                  "Iteration": it + 1
              })


          if (it + 1) % self.args.save_every == 0:

              ckpt_path = os.path.join(self.args.checkpoint_dir, f'{it+1}_nets_ema.ckpt')
              torch.save({
                  'generator': self.G.state_dict(),
                  'mapping_network': self.M.state_dict(),
                  'style_encoder': self.S.state_dict(),
                  'discriminator': self.D.state_dict()
              }, ckpt_path)
              print(f"Checkpoint guardado en {ckpt_path}")

              self.G.eval()
              self.S.eval()
              with torch.no_grad():
                  x_src, y_src = next(iter(src_loader))
                  x_ref, y_ref = next(iter(ref_loader))
                  x_src, y_src = x_src.to(self.device), y_src.to(self.device)
                  x_ref, y_ref = x_ref.to(self.device), y_ref.to(self.device)

                  s_ref = self.S(x_ref, y_ref)
                  x_fake = self.G(x_src, s_ref)
                  x_concat = torch.cat([x_src, x_ref, x_fake], dim=3)

                  os.makedirs(self.args.sample_dir, exist_ok=True)

                  sample_path = os.path.join(self.args.sample_dir, f'sample_{it+1}.png')
                  save_image((x_concat.data.cpu() + 1) / 2, sample_path, nrow=1, padding=0)
                  print(f"Muestra guardada en {sample_path}")

                  wandb.log({
                      f"Samples_{it+1}": wandb.Image(sample_path)
                  })
              self.G.train()
              self.S.train()


    def sample(self, loaders):
      src_loader = loaders.src
      ref_loader = loaders.ref
      print("Sampling started...")

      x_src, y_src = next(iter(src_loader))
      x_ref, y_ref = next(iter(ref_loader))

      n_samples = min(x_src.size(0), x_ref.size(0))

      x_src, y_src = x_src[:n_samples].to(self.device), y_src[:n_samples].to(self.device)
      x_ref, y_ref = x_ref[:n_samples].to(self.device), y_ref[:n_samples].to(self.device)

      with torch.no_grad():
          s_ref = self.S(x_ref, y_ref)
          x_fake = self.G(x_src, s_ref)
          x_concat = torch.cat([x_src, x_ref, x_fake], dim=3)

          save_path = os.path.join(self.args.sample_dir, 'sample.png')
          os.makedirs(self.args.sample_dir, exist_ok=True)
          # denormalize() already maps to [0, 1]; normalize=True would rescale again
          save_image(denormalize(x_concat.data.cpu()), save_path, nrow=1, padding=0)

      print(f"Sample images saved to {save_path}")


    def evaluate(self):
        print("Evaluation started...")
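        # NOTE: placeholder inputs. Random noise stands in for both the real and the
        # generated images, so these scores only exercise the metric code and say
        # nothing about the trained model.
        real_images = torch.randn(50, 3, 256, 256).to(self.device)
        fake_images = torch.randn(50, 3, 256, 256).to(self.device)

        with torch.no_grad():
            # Inception v3 expects 299x299 inputs
            real_feats = self.inception(resize(real_images, size=(299,299))).view(50, -1).cpu().numpy()
            fake_feats = self.inception(resize(fake_images, size=(299,299))).view(50, -1).cpu().numpy()
            lpips_score = self.lpips(real_images, fake_images).mean().item()

        fid_score = calculate_fid(real_feats, fake_feats)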

        print(f"FID: {fid_score:.4f}, LPIPS: {lpips_score:.4f}")

The `Solver` class is the **controller** for training, sampling, and evaluating the GAN: it builds the networks and optimizers, optionally resumes from a checkpoint, computes the losses, updates the parameters, saves checkpoints and sample images, and evaluates the quality of the generated images.
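`args` defines `ds_iter`, but `Solver` keeps `lambda_ds` constant. The StarGAN v2 reference implementation instead decays the diversity-sensitivity weight linearly to zero over `ds_iter` iterations, which relaxes the push for diverse outputs as training progresses. A sketch of that schedule as a pure function; the name `lambda_ds_at` is hypothetical, and wiring it into `_compute_g_loss` is left as an assumption:

```python
def lambda_ds_at(step, initial=1.0, ds_iter=100000):
    # Linear decay from `initial` at step 0 to 0 at `ds_iter`, then held at 0
    return initial * max(0.0, 1.0 - step / ds_iter)


for step in (0, 24000, 50000, 100000, 120000):
    print(step, lambda_ds_at(step))
```

With this notebook's values (`lambda_ds=1`, `ds_iter=100000`) and training resumed at iteration 24000, the weight would be about 0.76 at the resume point.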


## Main Code

In [27]:
def str2bool(v):
    # ('true') is a plain string, so `in` would do a substring test ('t' -> True)
    return v.lower() == 'true'

The function converts a string `v` into a boolean (`True` or `False`), returning `True` only when the text is `'true'` (case-insensitive).
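A classic pitfall with this kind of check: parentheses alone, as in `('true')`, do not create a tuple, so `in` performs a substring test and accepts strings such as `'t'`, `'rue'`, or even `''`. A one-element tuple needs a trailing comma. A plain-Python comparison (both function names are illustrative):

```python
def str2bool_substring(v):
    return v.lower() in ('true')    # ('true') is just the string 'true'


def str2bool_exact(v):
    return v.lower() in ('true',)   # one-element tuple: exact match only


for v in ("True", "t", "", "false"):
    print(repr(v), str2bool_substring(v), str2bool_exact(v))
```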

In [28]:
def subdirs(dname):
    return [d for d in os.listdir(dname)
            if os.path.isdir(os.path.join(dname, d))]

This function returns a list of all the subdirectories (folders) directly inside a given directory.

In [29]:
def main(args):
    print(args)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains

        loaders = Munch(
            src=get_train_loader(root=args.train_img_dir,
                                 which='source',
                                 img_size=args.img_size,
                                 batch_size=args.batch_size,
                                 prob=args.randcrop_prob,
                                 num_workers=args.num_workers),
            ref=get_train_loader(root=args.train_img_dir,
                                 which='reference',
                                 img_size=args.img_size,
                                 batch_size=args.batch_size,
                                 prob=args.randcrop_prob,
                                 num_workers=args.num_workers),
            val=get_test_loader(root=args.val_img_dir,
                                img_size=args.img_size,
                                batch_size=args.val_batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)
        )
        solver.train(loaders)

    elif args.mode == 'sample':
        loaders = Munch(
            src=get_test_loader(root=args.src_dir,
                                img_size=args.img_size,
                                batch_size=args.val_batch_size,
                                shuffle=False,
                                num_workers=args.num_workers),
            ref=get_test_loader(root=args.ref_dir,
                                img_size=args.img_size,
                                batch_size=args.val_batch_size,
                                shuffle=False,
                                num_workers=args.num_workers)
        )
        solver.sample(loaders)

    elif args.mode == 'eval':
        solver.evaluate()

    else:
        raise NotImplementedError

The `main(args)` function drives the program according to `args.mode`. It first prints the arguments, enables the cuDNN autotuner (`torch.backends.cudnn.benchmark = True`), which speeds up convolutions on the GPU when input sizes stay fixed, and sets a random seed to make runs more reproducible.

It then creates a `Solver` object, which handles the main actions.

- If the mode is `'train'`, it checks that the training and validation directories each contain as many subfolders as expected domains, builds the source, reference, and validation loaders, and calls `solver.train()`.

- If it is `'sample'`, it loads the source and reference images from `src_dir` and `ref_dir` and generates samples with `solver.sample()`.

- If it is `'eval'`, it runs the model evaluation with `solver.evaluate()`.

- Any other value raises a `NotImplementedError`.

## Parameter Configuration

In [30]:
!git clone https://github.com/Pacolaz/AutoStyleGAN

Cloning into 'AutoStyleGAN'...
remote: Enumerating objects: 958, done.
remote: Counting objects: 100% (958/958), done.
remote: Compressing objects: 100% (955/955), done.
remote: Total 958 (delta 2), reused 949 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (958/958), 8.97 MiB | 12.36 MiB/s, done.
Resolving deltas: 100% (2/2), done.


In [31]:
project_root = 'AutoStyleGAN'

expr_folders = [
    'expr/checkpoints/autos',
    'expr/eval',
    'expr/results',
    'expr/samples'
]

for folder in expr_folders:
    path = os.path.join(project_root, folder)
    os.makedirs(path, exist_ok=True)

print("Carpetas creadas exitosamente")


Carpetas creadas exitosamente


In [32]:
from types import SimpleNamespace

args = SimpleNamespace(
    img_size=128,
    num_domains=3, # BMW, Corvette, Mazda
    latent_dim=16,
    hidden_dim=1024,
    style_dim=64,
    lambda_reg=1,
    lambda_cyc=1,
    lambda_sty=1,
    lambda_ds=1,
    ds_iter=100000,
    randcrop_prob=0.5,
    total_iters=100000,
    resume_iter=24000,
    batch_size=4,
    val_batch_size=4,
    lr=1e-4,
    f_lr=1e-6,
    beta1=0.0,
    beta2=0.99,
    weight_decay=1e-4,
    num_outs_per_domain=10,
    mode='train', # train, sample, eval
    num_workers=4,
    seed=8365,
    train_img_dir = 'AutoStyleGAN/dataset/train',
    val_img_dir = 'AutoStyleGAN/dataset/val',
    sample_dir = 'AutoStyleGAN/expr/samples',
    checkpoint_dir = 'AutoStyleGAN/expr/checkpoints',
    eval_dir = 'AutoStyleGAN/expr/eval',
    result_dir = 'AutoStyleGAN/expr/results',
    src_dir = 'AutoStyleGAN/assets/src',
    ref_dir = 'AutoStyleGAN/assets/ref',
    print_every=100,
    sample_every=500,
    save_every=500,
    eval_every=500
)
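`argparse` is imported in the first cell, but the options are set with a `SimpleNamespace`. A hypothetical `argparse` equivalent for a few of the fields, with defaults copied from the namespace above, could look like this; the `--use_wandb` flag is an invented example of a boolean option parsed with `str2bool`:

```python
import argparse


def str2bool(v):
    return v.lower() in ('true',)


def build_parser():
    parser = argparse.ArgumentParser(description="AutoStyleGAN (subset of options)")
    parser.add_argument('--img_size', type=int, default=128)
    parser.add_argument('--num_domains', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--resume_iter', type=int, default=24000)
    parser.add_argument('--mode', choices=['train', 'sample', 'eval'], default='train')
    parser.add_argument('--use_wandb', type=str2bool, default=True)  # hypothetical flag
    return parser


# Passing an explicit list avoids parsing the Jupyter kernel's own argv
cli_args = build_parser().parse_args(['--mode', 'sample', '--img_size', '256', '--use_wandb', 'false'])
print(cli_args.mode, cli_args.img_size, cli_args.use_wandb)  # sample 256 False
```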

In [None]:
main(args)

namespace(img_size=128, num_domains=3, latent_dim=16, hidden_dim=1024, style_dim=64, lambda_reg=1, lambda_cyc=1, lambda_sty=1, lambda_ds=1, ds_iter=100000, randcrop_prob=0.5, total_iters=100000, resume_iter=24000, batch_size=4, val_batch_size=4, lr=0.0001, f_lr=1e-06, beta1=0.0, beta2=0.99, weight_decay=0.0001, num_outs_per_domain=10, mode='train', num_workers=4, seed=8365, train_img_dir='AutoStyleGAN/dataset/train', val_img_dir='AutoStyleGAN/dataset/val', sample_dir='AutoStyleGAN/expr/samples', checkpoint_dir='AutoStyleGAN/expr/checkpoints', eval_dir='AutoStyleGAN/expr/eval', result_dir='AutoStyleGAN/expr/results', src_dir='AutoStyleGAN/assets/src', ref_dir='AutoStyleGAN/assets/ref', print_every=100, sample_every=500, save_every=500, eval_every=500)

Cargando checkpoint desde AutoStyleGAN/expr/checkpoints/24000_nets_ema.ckpt...
Checkpoint cargado correctamente. Empezando desde iteración 24000.
Training started...
Iter [24100/100000] d_loss: 0.5440, g_loss: 2.9086
Iter [24200/100000] d_loss: 0.6751, g_loss: 2.4661
Iter [24300/100000] d_loss: 0.7235, g_loss: 2.4762
Iter [24400/100000] d_loss: 0.7929, g_loss: 1.9892
Iter [24500/100000] d_loss: 0.4674, g_loss: 2.5760
Checkpoint guardado en AutoStyleGAN/expr/checkpoints/24500_nets_ema.ckpt
Muestra guardada en AutoStyleGAN/expr/samples/sample_24500.png
Iter [24600/100000] d_loss: 0.5507, g_loss: 1.9651
Iter [24700/100000] d_loss: 0.4183, g_loss: 3.1053
Iter [24800/100000] d_loss: 0.7290, g_loss: 1.9248
Iter [24900/100000] d_loss: 0.6641, g_loss: 1.4705
Iter [25000/100000] d_loss: 0.6118, g_loss: 2.5347
Checkpoint guardado en AutoStyleGAN/expr/checkpoints/25000_nets_ema.ckpt
Muestra guardada en AutoStyleGAN/expr/samples/sample_25000.png
Iter [25100/100000] d_loss: 0.8215, g_loss: 1.5936
Iter [25200/100000] d_loss: 1.1846, g_loss: 2.3469
Iter [25300/100000] d_loss: 0.6002, g_loss: 1.9973
Iter [25400/100000] d_loss: 0.7153, g_loss: 2.4499
Iter [25500/100000] d_loss: 0.7628, g_loss: 3.5995
Checkpoint guardado en AutoStyleGAN/expr/checkpoints/25500_nets_ema.ckpt
Muestra guardada en AutoStyleGAN/expr/samples/sample_25500.png
Iter [25600/100000] d_loss: 0.6612, g_loss: 0.6519
Iter [25700/100000] d_loss: 0.7981, g_loss: 1.6274
Iter [25800/100000] d_loss: 0.6842, g_loss: 3.4070
Iter [25900/100000] d_loss: 0.4642, g_loss: 2.3215
Iter [26000/10