# 📖 Analyzing and Improving the Image Quality of StyleGAN – Summary
Paper: https://arxiv.org/pdf/1912.04958


## 🔎 Introduction
- StyleGAN (arXiv 2018, CVPR 2019) achieved state-of-the-art high-resolution image generation, but showed **characteristic artifacts** (e.g., droplet-like blobs and details that stay locked in place, the "phase" artifacts).
- This paper (often called **StyleGAN2**) identifies the causes and **proposes architectural and training improvements**:
  - Redesigning normalization inside the generator.
  - Removing issues caused by progressive growing.
  - Adding a new **path length regularizer** for smoother latent–image mapping.

---

## 🎯 Motivation
- **Problems in StyleGAN:**
  - Blob-like artifacts from **AdaIN normalization**.
  - Progressive growing introduces **phase artifacts**: details such as teeth or eyes stay fixed to image coordinates instead of moving with the object.
  - Metrics like FID don’t fully capture **shape consistency**.
- **Goal:** Improve image quality, stability, and **make the generator easier to invert**.

---

## 🧮 Key Methods

### 1. Weight Demodulation
- Replaces AdaIN normalization.  
- Prevents the generator from sneaking information through spikes in feature maps.  
- Removes droplet-like artifacts without losing style control.  
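
The idea behind the bullets above can be illustrated with a small, dependency-free sketch. It assumes a toy kernel stored as nested lists (`weight[j][i]` holds the taps from input channel `i` to output channel `j`); the helper name `modulate_demodulate` is made up for this illustration and is not from the paper or any library. The style first scales each input channel of the kernel, then each output channel is rescaled to unit L2 norm, so the normalization acts on the weights rather than on the feature maps.

```python
import math

def modulate_demodulate(weight, style, eps=1e-8):
    """weight[j][i]: kernel taps from input channel i to output channel j.
    style[i]: per-input-channel scale (in the notebook this comes from w)."""
    result = []
    for per_output in weight:
        # Modulation: scale every tap of input channel i by style[i].
        scaled = [[s * tap for tap in taps] for s, taps in zip(style, per_output)]
        # Demodulation: divide by the L2 norm over all taps feeding this output channel.
        norm = math.sqrt(sum(t * t for taps in scaled for t in taps) + eps)
        result.append([[t / norm for t in taps] for taps in scaled])
    return result

# Toy kernel: 2 output channels, 2 input channels, 2 taps each (illustrative values).
weight = [[[1.0, 2.0], [3.0, 4.0]],
          [[0.5, -1.0], [2.0, 0.0]]]
style = [2.0, 0.5]
demod = modulate_demodulate(weight, style)
norms = [math.sqrt(sum(t * t for taps in out for t in taps)) for out in demod]
print(norms)  # each output channel ends up with (approximately) unit norm
```

Because the norm is taken after modulation, multiplying all styles by a common factor leaves the demodulated kernel essentially unchanged; only the relative scaling between input channels carries style information.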

### 2. Path Length Regularization
- Encourages the generator to map latent vectors to images in a smoother, well-conditioned way.  
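- Penalizes deviations of the generator's Jacobian norm (with respect to the latent) from its running average, so a fixed-size step in the latent space produces a similar-size change in the image.  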
- Correlates with **better perceptual quality** (lower Perceptual Path Length → smoother latent interpolations).  
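
For reference, the regularizer described above can be written compactly. The symbols follow the paper's notation as best recalled here: \( g \) is the synthesis network, \( \mathbf{J}_{\mathbf{w}} = \partial g(\mathbf{w}) / \partial \mathbf{w} \) its Jacobian, \( \mathbf{y} \) a random image with normally distributed pixel intensities, and \( a \) a running average of the observed path lengths.

```latex
\mathcal{L}_{\mathrm{pl}}
  = \mathbb{E}_{\mathbf{w},\ \mathbf{y} \sim \mathcal{N}(0, \mathbf{I})}
    \left( \left\lVert \mathbf{J}_{\mathbf{w}}^{\top} \mathbf{y} \right\rVert_2 - a \right)^2,
\qquad
\mathbf{J}_{\mathbf{w}}^{\top} \mathbf{y}
  = \nabla_{\mathbf{w}} \left( g(\mathbf{w}) \cdot \mathbf{y} \right)
```

The second identity is what makes the term cheap to compute: the Jacobian is never formed explicitly, and a single backward pass through \( g(\mathbf{w}) \cdot \mathbf{y} \) yields the product.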

### 3. Alternative to Progressive Growing
- Instead of altering topology during training, StyleGAN2 uses **skip connections** (generator) and **residual networks** (discriminator).  
- This avoids phase artifacts while still letting training focus on coarse-to-fine features.  

### 4. Larger Networks
- Identified a **capacity bottleneck**: StyleGAN underutilized its highest resolutions.  
- Doubling channels in high-res layers improved fidelity significantly.  

---

## ⚙️ Metrics
- **FID (Fréchet Inception Distance):** Distribution similarity.  
- **Precision & Recall (P&R):** Diversity vs. fidelity.  
- **Perceptual Path Length (PPL):** Smoothness and consistency of latent interpolation.  
- The paper observes that lower PPL tends to go with more consistent, stable shapes, an aspect of perceived quality that FID and P&R can miss.  
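
To make the FID entry above concrete: FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images, \( \lVert \mu_r - \mu_g \rVert^2 + \mathrm{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}) \). The sketch below is a one-dimensional reduction of that formula for intuition only; a real FID computation uses high-dimensional Inception-v3 features and a matrix square root, and the function name `fid_1d` is made up for this example.

```python
import math

def fid_1d(mu_r, var_r, mu_g, var_g):
    """Fréchet distance between two univariate Gaussians.
    In 1-D the trace term reduces to var_r + var_g - 2*sqrt(var_r*var_g)."""
    mean_term = (mu_r - mu_g) ** 2
    cov_term = var_r + var_g - 2.0 * math.sqrt(var_r * var_g)
    return mean_term + cov_term

same = fid_1d(0.0, 1.0, 0.0, 1.0)      # identical distributions
shifted = fid_1d(0.0, 1.0, 1.0, 4.0)   # different mean and variance
print(same, shifted)
```

Lower is better: the distance is zero only when both the means and the (co)variances match.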

---

## 🧪 Empirical Results
- **Artifacts removed:** Blob and phase artifacts eliminated.  
- **Better image quality:**  
  - StyleGAN2 achieves lower FID and PPL across datasets (FFHQ, LSUN Cats, Cars, Churches, Horses).  
  - Improved recall (diversity) while maintaining precision.  
- **Projection easier:** StyleGAN2 makes inversion of generated images into latent space more accurate, enabling attribution.  

---

## 📌 Contributions
- Introduced **StyleGAN2** with:
  - Weight demodulation (artifact-free normalization).  
  - Path length regularization (smoother mappings).  
  - Redesigned architecture without progressive growing.  
- Advanced **state-of-the-art image generation** with better quality and stability.  
- Made generated images easier to attribute back to their network of origin.  

---

## 🏆 Impact
- Redefined the **benchmark for unconditional image generation**.  
- Widely adopted as the backbone of face, art, and high-fidelity image synthesis.  
- Inspired follow-ups: **StyleGAN2-ADA** (adaptive augmentation) and **StyleGAN3** (alias-free generation).  

---

✅ **Reference:**  
Karras, T., Laine, S., Aittala, M., Hellsten, J., Lehtinen, J., & Aila, T. (2020). *Analyzing and Improving the Image Quality of StyleGAN*. CVPR 2020.


# 🔄 Comparison: StyleGAN (2019) vs. StyleGAN2 (2020)

| Aspect                          | StyleGAN (CVPR 2019)                                   | StyleGAN2 (CVPR 2020)                                         |
|---------------------------------|--------------------------------------------------------|---------------------------------------------------------------|
| **Core Idea**                   | Style-based generator with AdaIN + noise injection     | Redesigned generator with **weight demodulation**             |
| **Normalization**               | AdaIN (Adaptive Instance Normalization)                | Removed AdaIN → replaced by **weight demod** (artifact-free)  |
| **Artifacts**                   | Droplet-like blobs, phase artifacts from progressive growing | Artifacts largely **eliminated** via new design               |
| **Latent Space**                | \( Z \to W \) mapping for disentanglement              | Same, but improved smoothness with **path length regularizer** |
| **Training Strategy**           | Progressive growing (train from low → high res)        | No progressive growing → **skip/residual connections** instead |
| **Path Length Regularization**  | Not present                                            | Introduced → enforces smooth latent–image mapping              |
| **Capacity**                    | Underutilized high-res layers                          | Doubled channels in high-res layers → improved fidelity        |
| **Metrics**                     | FID; introduced PPL                                    | FID, PPL + **Precision/Recall** for diversity/fidelity trade-off |
| **Datasets**                    | CelebA-HQ, LSUN, FFHQ (introduced)                     | FFHQ, LSUN (Cats, Cars, Churches, Horses)                     |
| **Image Quality**               | High, but with artifacts                              | **State-of-the-art**, sharper, artifact-free, more diverse     |
| **Impact**                      | Foundation for controllable synthesis (style mixing)   | Redefined benchmark; backbone for StyleGAN2-ADA & StyleGAN3    |

---

✅ **Key Takeaway:**  
StyleGAN introduced **style-based control** (coarse-to-fine disentanglement).  
StyleGAN2 fixed **visual artifacts**, improved stability, and became the **new gold standard** for high-fidelity generative modeling.


In [1]:
# Imports
import math, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
torch.manual_seed(42)


Device: cuda

In [13]:
# Config (32×32 CIFAR-10)
class Cfg:
    img = 32
    z_dim = 128
    w_dim = 256
    fmap = 128          # base feature maps
    batch = 128
    epochs = 30
    lr = 2e-4
    betas = (0.0, 0.99)
    r1_gamma = 1.0      # R1 regularization weight (D)
    pl_weight = 2.0     # path-length regularization weight (G)
    pl_every = 4        # apply PL every N iters
    r1_every = 16       # apply R1 every N iters
    sample_every = 2

cfg = Cfg()


In [14]:
cfg.batch = 32        # try 16 if you still OOM
cfg.fmap = 96
cfg.w_dim = 192

cfg.pl_weight = 1.0   # 0.0 to temporarily disable
cfg.pl_every  = 16
cfg.r1_gamma  = 0.5
cfg.r1_every  = 64
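
The `r1_every` / `pl_every` settings above apply each regularizer only on some iterations (the paper calls this lazy regularization). The sketch below shows which iterations are hit and what the average applied weight works out to. The `scale_by_interval` option is an assumption for illustration: some reference implementations multiply the penalty by the interval so its average strength stays the same, whereas the training loop in this notebook does not.

```python
def reg_schedule(n_iters, every, weight, scale_by_interval=False):
    """Map iteration index -> regularization weight applied on that iteration."""
    applied = weight * every if scale_by_interval else weight
    return {it: applied for it in range(n_iters) if it % every == 0}

n_iters = 128
r1 = reg_schedule(n_iters, every=64, weight=0.5)   # mirrors cfg.r1_every / cfg.r1_gamma
pl = reg_schedule(n_iters, every=16, weight=1.0)   # mirrors cfg.pl_every / cfg.pl_weight

# Average per-iteration strength without and with interval scaling.
avg_r1 = sum(r1.values()) / n_iters
avg_r1_scaled = sum(reg_schedule(n_iters, 64, 0.5, scale_by_interval=True).values()) / n_iters
print(sorted(r1), len(pl), avg_r1, avg_r1_scaled)
```

Without scaling, raising `r1_every` from 16 to 64 also weakens the regularizer by roughly 4x on average, which is worth keeping in mind when tuning `r1_gamma`.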


In [15]:
# Data: CIFAR-10 in [-1, 1]
transform = transforms.Compose([
    transforms.Resize(cfg.img),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
])
ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
loader = DataLoader(ds, batch_size=cfg.batch, shuffle=True, drop_last=True, num_workers=2)


In [16]:
class PixelNorm(nn.Module):
    def forward(self, x, eps=1e-8):
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + eps)

class Mapping(nn.Module):
    def __init__(self, z_dim, w_dim, n_layers=4):
        super().__init__()
        layers = [PixelNorm()]
        dim = z_dim
        for _ in range(n_layers):
            layers += [nn.Linear(dim, w_dim), nn.LeakyReLU(0.2, True)]
            dim = w_dim
        self.net = nn.Sequential(*layers)
    def forward(self, z):  # [B, z_dim] -> [B, w_dim]
        return self.net(z)
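
As a quick illustration of what `PixelNorm` does to a single latent vector before the mapping layers, here is a plain-Python sketch (the function name `pixel_norm` is only for this example). It divides the vector by its root-mean-square, so the mean of the squared entries becomes approximately 1.

```python
import math

def pixel_norm(vec, eps=1e-8):
    """Scale a vector so that the mean of its squared entries is ~1."""
    mean_sq = sum(v * v for v in vec) / len(vec)
    return [v / math.sqrt(mean_sq + eps) for v in vec]

z = [3.0, -4.0, 0.0, 0.0]
out = pixel_norm(z)
mean_sq_out = sum(v * v for v in out) / len(out)
print(out, mean_sq_out)
```

One consequence is that the overall scale of `z` is discarded: `z` and `2*z` map to (almost) the same normalized vector.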


In [5]:
class ModulatedConv2d(nn.Module):
    """
    StyleGAN2 modulated convolution with optional demodulation.
    x: [B, C_in, H, W]
    w: [B, w_dim] -> style -> scales for C_in
    """
    def __init__(self, in_ch, out_ch, w_dim, k=3, up=False, demod=True):
        super().__init__()
        self.in_ch, self.out_ch, self.k, self.up, self.demod = in_ch, out_ch, k, up, demod
        self.weight = nn.Parameter(torch.randn(out_ch, in_ch, k, k) * 0.02)
        self.style = nn.Linear(w_dim, in_ch)
        self.bias = nn.Parameter(torch.zeros(out_ch))
        self.pad = k // 2

    def forward(self, x, w):
        B, C, H, W = x.shape
        if self.up:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            H, W = x.shape[-2:]

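        # Modulate: scale each input channel of the kernel by its per-sample style
        s = self.style(w).view(B, 1, self.in_ch, 1, 1)        # [B,1,Cin,1,1]
        w_mod = self.weight.unsqueeze(0) * (s + 1.0)          # [B,Cout,Cin,k,k]

        if self.demod:
            # Demodulate: rescale each output channel to unit L2 norm
            d = torch.rsqrt((w_mod**2).sum(dim=[2,3,4], keepdim=True) + 1e-8)  # [B,Cout,1,1,1]
            w_mod = w_mod * d

        # Grouped-conv trick: fold the batch into channels so each sample uses its own kernel.
        # reshape (not view) also handles non-contiguous inputs such as channels_last tensors.
        x = x.reshape(1, B * self.in_ch, H, W)
        w_mod = w_mod.reshape(B * self.out_ch, self.in_ch, self.k, self.k)
        out = F.conv2d(x, w_mod, padding=self.pad, groups=B)
        out = out.reshape(B, self.out_ch, H, W) + self.bias.view(1, -1, 1, 1)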
        return out


In [19]:
def lrelu(x): return F.leaky_relu(x, 0.2)

class SynthesisBlock(nn.Module):
    def __init__(self, in_ch, out_ch, w_dim, up):
        super().__init__()
        self.conv1 = ModulatedConv2d(in_ch, out_ch, w_dim, k=3, up=up, demod=True)
        self.conv2 = ModulatedConv2d(out_ch, out_ch, w_dim, k=3, up=False, demod=True)
        self.toRGB = ModulatedConv2d(out_ch, 3, w_dim, k=1, up=False, demod=False)

    def forward(self, x, w1, w2, rgb=None):
        x = lrelu(self.conv1(x, w1))
        x = lrelu(self.conv2(x, w2))
        rgb_new = self.toRGB(x, w2)
        rgb = rgb_new if rgb is None else F.interpolate(rgb, scale_factor=2, mode='nearest') + rgb_new
        return x, rgb

class Generator(nn.Module):
    """
    4x4 -> 8 -> 16 -> 32 with skip-to-RGB (StyleGAN2)
    """
    def __init__(self, z_dim, w_dim, fmap):
        super().__init__()
        self.mapping = Mapping(z_dim, w_dim)
        self.const = nn.Parameter(torch.randn(1, fmap*4, 4, 4))
        self.b4  = SynthesisBlock(fmap*4, fmap*4, w_dim, up=False) # 4x4
        self.b8  = SynthesisBlock(fmap*4, fmap*2, w_dim, up=True)  # 8x8
        self.b16 = SynthesisBlock(fmap*2, fmap,   w_dim, up=True)  # 16x16
        self.b32 = SynthesisBlock(fmap,   fmap//2,w_dim, up=True)  # 32x32

    def forward(self, z, return_w=False):
        w = self.mapping(z)                                 # [B, w_dim]
        x = self.const.repeat(z.size(0), 1, 1, 1)
        rgb = None
        # each block consumes two styles (like StyleGAN2)
        x, rgb = self.b4 (x, w, w, rgb)
        x, rgb = self.b8 (x, w, w, rgb)
        x, rgb = self.b16(x, w, w, rgb)
        x, rgb = self.b32(x, w, w, rgb)
        if return_w: return torch.tanh(rgb), w
        return torch.tanh(rgb)

class ResBlockD(nn.Module):
    def __init__(self, in_ch, out_ch, down=True):
        super().__init__()
        self.down = down
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.skip  = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
    def forward(self, x):
        h = lrelu(self.conv1(x))
        h = lrelu(self.conv2(h))
        if self.down: h = F.avg_pool2d(h, 2)
        s = self.skip(x)
        if self.down: s = F.avg_pool2d(s, 2)
        return (h + s) / math.sqrt(2)

class Discriminator(nn.Module):
    def __init__(self, fmap):
        super().__init__()
        c = fmap
        self.fromRGB = nn.Conv2d(3, c//2, 1)
        self.b32 = ResBlockD(c//2, c,   down=True)
        self.b16 = ResBlockD(c,    c*2, down=True)
        self.b8  = ResBlockD(c*2,  c*4, down=True)
        self.b4  = ResBlockD(c*4,  c*4, down=False)
        self.out = nn.Linear(c*4*4*4, 1)
    def forward(self, x):
        x = lrelu(self.fromRGB(x))
        x = self.b32(x); x = self.b16(x); x = self.b8(x); x = self.b4(x)
        x = x.view(x.size(0), -1)
        return self.out(x).squeeze(1)
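
To sanity-check the channel and resolution bookkeeping in the `Generator` and `Discriminator` above, the following sketch traces `(resolution, channels)` through both networks without building any tensors. The helper names are made up for this illustration; the block layout is copied from the class definitions above.

```python
def generator_shapes(fmap):
    """(resolution, channels) after each synthesis block: b4, b8, b16, b32."""
    return [(4, fmap * 4), (8, fmap * 2), (16, fmap), (32, fmap // 2)]

def discriminator_shapes(fmap, img=32):
    """(resolution, channels) after fromRGB and after each residual block."""
    shapes = [(img, fmap // 2)]  # after fromRGB
    res = img
    blocks = [(fmap, True), (fmap * 2, True), (fmap * 4, True), (fmap * 4, False)]
    for out_ch, down in blocks:
        if down:
            res //= 2            # avg_pool2d(., 2) halves the resolution
        shapes.append((res, out_ch))
    return shapes

fmap = 96  # value set in the config override cell
g_shapes = generator_shapes(fmap)
d_shapes = discriminator_shapes(fmap)
final_res, final_ch = d_shapes[-1]
flat_features = final_ch * final_res * final_res  # input size of the final Linear
print(g_shapes, d_shapes, flat_features)
```

The flattened size matches the `nn.Linear(c*4*4*4, 1)` head only for 32×32 inputs; a different `cfg.img` would require changing the number of downsampling blocks or the head.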


In [18]:
G = Generator(cfg.z_dim, cfg.w_dim, cfg.fmap).to(device)
D = Discriminator(cfg.fmap).to(device)
opt_G = optim.Adam(G.parameters(), lr=cfg.lr, betas=cfg.betas)
opt_D = optim.Adam(D.parameters(), lr=cfg.lr, betas=cfg.betas)


In [17]:
from torch.cuda.amp import autocast, GradScaler
g_scaler = GradScaler()
d_scaler = GradScaler()




In [21]:
def d_loss_fn(real_logits, fake_logits):
    # Logistic discriminator loss; the R1 penalty is added separately in the training loop
    return F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()

def g_loss_fn(fake_logits):
    # Non-saturating logistic generator loss
    return F.softplus(-fake_logits).mean()

def r1_regularizer(x, logits):
    grad = torch.autograd.grad(outputs=logits.sum(), inputs=x, create_graph=True)[0]
    return grad.pow(2).view(grad.size(0), -1).sum(1).mean()

def path_length_reg(img, w):
    """Simplified PL: gradient of (img * noise).sum() w.r.t. w"""
    noise = torch.randn_like(img) / math.sqrt(img.numel() / img.size(0))
    grad = torch.autograd.grad(outputs=(img * noise).sum(), inputs=w, create_graph=True)[0]
    pl = grad.pow(2).sum(1).sqrt().mean()
    return pl
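
The `softplus` form of the losses above can look opaque at first. The sketch below, in plain Python on scalar logits, shows that `softplus(-x)` equals `-log(sigmoid(x))`, so these are the usual logistic GAN losses written in a numerically stable way. The function names mirror the notebook's but are re-implemented here for illustration only.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def d_loss(real_logits, fake_logits):
    # -log sigmoid(D(real)) - log(1 - sigmoid(D(fake))), averaged over each batch.
    real_term = sum(softplus(-r) for r in real_logits) / len(real_logits)
    fake_term = sum(softplus(f) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term

def g_loss(fake_logits):
    # Non-saturating generator loss: -log sigmoid(D(fake)).
    return sum(softplus(-f) for f in fake_logits) / len(fake_logits)

identity_gap = softplus(-1.5) + math.log(sigmoid(1.5))  # should be ~0
undecided = d_loss([0.0], [0.0])                        # D outputs 0 for everything
print(identity_gap, undecided, g_loss([5.0]), g_loss([-5.0]))
```

When the discriminator is undecided (all logits 0), its loss is `2*ln(2)`; the generator loss falls as the discriminator assigns higher logits to fakes.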


In [22]:
def sample_z(n): return torch.randn(n, cfg.z_dim, device=device)

@torch.no_grad()
def show_grid(imgs, title=""):
    g = utils.make_grid(imgs, nrow=int(math.sqrt(len(imgs))), normalize=True, value_range=(-1,1))
    plt.figure(figsize=(6,6)); plt.imshow(g.permute(1,2,0)); plt.axis('off'); plt.title(title); plt.show()


In [23]:
G.train(); D.train()
g_hist, d_hist = [], []
pl_ema = None

for epoch in range(cfg.epochs):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
    for it, (real, _) in enumerate(pbar):
        real = real.to(device)

        # ----------------- D step -----------------
        z = sample_z(real.size(0))
        with torch.no_grad():
            fake = G(z)
        real_logits = D(real)
        fake_logits = D(fake)
        d_loss = d_loss_fn(real_logits, fake_logits)

        if (it % cfg.r1_every) == 0:
            real.requires_grad_(True)
            r1 = r1_regularizer(real, D(real))
            real.requires_grad_(False)
            d_loss = d_loss + (cfg.r1_gamma * 0.5) * r1

        opt_D.zero_grad(set_to_none=True)
        d_loss.backward()
        opt_D.step()

        # ----------------- G step -----------------
        z = sample_z(real.size(0))
        img, w = G(z, return_w=True)
        g_loss = g_loss_fn(D(img))

        # Path-length regularization every N steps
        if (it % cfg.pl_every) == 0:
            pl = path_length_reg(img, w)
            if pl_ema is None: pl_ema = pl.detach()
            pl_pen = (pl - pl_ema).pow(2)
            pl_ema = pl_ema.lerp(pl.detach(), 0.01)
            g_loss = g_loss + cfg.pl_weight * pl_pen

        opt_G.zero_grad(set_to_none=True)
        g_loss.backward()
        opt_G.step()

        g_hist.append(float(g_loss.detach()))
        d_hist.append(float(d_loss.detach()))
        pbar.set_postfix(g=f"{g_hist[-1]:.3f}", d=f"{d_hist[-1]:.3f}")

    if (epoch+1) % cfg.sample_every == 0:
        with torch.no_grad():
            imgs = G(sample_z(16)).cpu()
        show_grid(imgs, title=f"Samples — epoch {epoch+1}")
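
The `pl_ema` bookkeeping inside the loop above is easy to misread, so here it is isolated as a scalar sketch. `tensor.lerp(target, 0.01)` computes `tensor + 0.01 * (target - tensor)`; the helper name `pl_step` is made up for this illustration.

```python
def pl_step(pl, ema, decay=0.01):
    """Return (penalty, updated_ema) for one path-length measurement."""
    if ema is None:
        ema = pl                      # first measurement initializes the average
    penalty = (pl - ema) ** 2         # squared deviation from the running average
    new_ema = ema + decay * (pl - ema)
    return penalty, new_ema

first_penalty, ema = pl_step(3.0, None)
second_penalty, ema2 = pl_step(3.0, 2.0)
print(first_penalty, ema, second_penalty, ema2)
```

Note that the very first regularized step contributes zero penalty, because the running average is initialized to the first measurement.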


In [None]:
@torch.no_grad()
def style_mix_demo(n_pairs=8):
    zA = sample_z(n_pairs); zB = sample_z(n_pairs)
    A = G(zA).cpu(); B = G(zB).cpu()

    # Crude stand-in for style mixing: blend the two z vectors and generate once.
    # This interpolates in Z; it does not swap per-layer styles in W.
    # (For a fuller impl, expose per-block ws in synthesis and swap after a chosen block.)
    mix = G(0.5*zA + 0.5*zB).cpu()

    rows = []
    for i in range(n_pairs):
        rows += [A[i], B[i], mix[i]]
    grid = utils.make_grid(torch.stack(rows), nrow=3, normalize=True, value_range=(-1,1))
    plt.figure(figsize=(8, n_pairs)); plt.imshow(grid.permute(1,2,0)); plt.axis('off')
    plt.title("A / B / blend of zA and zB"); plt.show()

style_mix_demo()
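
For comparison with the demo above, true style mixing swaps per-layer styles at a crossover point instead of blending latents. The sketch below shows only that selection logic, using string placeholders in place of real `w` vectors. `mix_styles` is a hypothetical helper, and the `Generator` in this notebook would first need to accept a list of per-layer styles, which it currently does not.

```python
def mix_styles(ws_a, ws_b, crossover):
    """Use source A's styles for layers before `crossover`, source B's afterwards."""
    if len(ws_a) != len(ws_b):
        raise ValueError("both sources must provide one style per layer")
    if not 0 <= crossover <= len(ws_a):
        raise ValueError("crossover must lie within the layer range")
    return ws_a[:crossover] + ws_b[crossover:]

n_layers = 8  # 4 synthesis blocks x 2 modulated convs in the generator above
ws_a = ["A"] * n_layers   # placeholders for per-layer w vectors of source A
ws_b = ["B"] * n_layers   # placeholders for per-layer w vectors of source B
mixed = mix_styles(ws_a, ws_b, crossover=4)
print(mixed)
```

With an early crossover, B controls most layers and A contributes only coarse structure; a late crossover keeps A's structure and takes only fine detail from B.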


In [None]:
G.train(); D.train()
g_hist, d_hist = [], []
pl_ema = None

for epoch in range(cfg.epochs):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
    for it, (real, _) in enumerate(pbar):
        real = real.to(device, non_blocking=True)

        # ----------------- D step (AMP) -----------------
        z = torch.randn(real.size(0), cfg.z_dim, device=device)
        with torch.no_grad():
            with autocast():
                fake = G(z)

        opt_D.zero_grad(set_to_none=True)
        with autocast():
            real_logits = D(real)
            fake_logits = D(fake)
            d_loss = F.softplus(-real_logits).mean() + F.softplus(fake_logits).mean()

        # R1 every r1_every iters, computed in fp32 (outside autocast) on a smaller subset
        if (it % cfg.r1_every) == 0 and cfg.r1_gamma > 0.0:
            mb = min(16, real.size(0))  # microbatch to save memory
            real_small = real[:mb].detach().requires_grad_(True)
            real_logits_small = D(real_small)  # fp32 since outside autocast
            r1 = torch.autograd.grad(outputs=real_logits_small.sum(),
                                     inputs=real_small,
                                     create_graph=True)[0]
            r1 = r1.pow(2).view(mb, -1).sum(1).mean()
            d_loss = d_loss + (cfg.r1_gamma * 0.5) * r1

        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_D)
        d_scaler.update()

        # ----------------- G step (AMP) -----------------
        z = torch.randn(real.size(0), cfg.z_dim, device=device)

        opt_G.zero_grad(set_to_none=True)
        with autocast():
            img, w = G(z, return_w=True)
            g_loss = F.softplus(-D(img)).mean()

        # Path-length every pl_every iters, on a microbatch to reduce graph size
        if (it % cfg.pl_every) == 0 and cfg.pl_weight > 0.0:
            mb = min(16, img.size(0))
            # Cast to fp32 outside autocast to limit fp16 gradient issues
            img_small = img[:mb].float()
            noise = torch.randn_like(img_small) / (img_small.numel() / mb)**0.5
            # Differentiate w.r.t. the full `w`: a slice w[:mb] is a new tensor that is not
            # part of img's graph, so autograd.grad would fail on it. Slice the gradient instead.
            pl = torch.autograd.grad(outputs=(img_small * noise).sum(),
                                     inputs=w,
                                     create_graph=True)[0][:mb]
            pl = pl.float().pow(2).sum(1).sqrt().mean()
            if pl_ema is None:
                pl_ema = pl.detach()
            pl_pen = (pl - pl_ema).pow(2)
            pl_ema = pl_ema.lerp(pl.detach(), 0.01)
            g_loss = g_loss + cfg.pl_weight * pl_pen

        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_G)
        g_scaler.update()

        g_hist.append(float(g_loss.detach()))
        d_hist.append(float(d_loss.detach()))
        pbar.set_postfix(g=f"{g_hist[-1]:.3f}", d=f"{d_hist[-1]:.3f}")


In [None]:
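# Optional speed-ups; run this cell before the training loop for it to take effect.
# Channels-last layout can speed up convolutions on Tensor Core GPUs, mainly together with AMP
G = G.to(memory_format=torch.channels_last)
D = D.to(memory_format=torch.channels_last)

# Enable TF32 matmul/conv (Ampere or newer, e.g. A100 / RTX 30 series; V100 does not support TF32)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True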
