# Training an Unconditional GAN Image Generator using PyTorch / ONNX

This notebook walks you through the steps of training your own image-generation machine learning model: an unconditional GAN that turns random noise vectors into images resembling your dataset.

To run the notebook, put your cursor in each cell in turn and press Shift+Enter. When training is done, you can download the latest model from the `output` folder (it will be called something like `generator_epoch_XXX.onnx`).

In [None]:
# Make sure you are connected to a runtime with a GPU.
# `nvidia-smi -L` lists the NVIDIA GPUs visible to this runtime; if the command
# fails or lists nothing, switch to a GPU runtime before continuing.
!nvidia-smi -L

In [None]:
# Install ONNX (not installed by default), together with matplotlib and tqdm.
# NOTE(review): the two commented-out lines below would force the preferred
# encoding to UTF-8; presumably a workaround for a locale error on some hosted
# runtimes -- confirm whether it is still needed before deleting it.
#import locale
#locale.getpreferredencoding = lambda: "UTF-8"
%pip install -q onnx matplotlib tqdm

In [None]:
# Import all other dependencies
import glob, os, random, copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
from PIL import Image

import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
from types import SimpleNamespace
from IPython.display import clear_output
from torch.nn.utils import spectral_norm
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Ask PyTorch whether a CUDA device is usable and report the answer.
gpu_available = torch.cuda.is_available()
print("GPU is", ("NOT AVAILABLE", "available")[gpu_available])

In [None]:
# Download and unzip the dataset
!curl -O https://algorithmicgaze.s3.amazonaws.com/workshops/2025-raive/patterns_512.zip
!mkdir -p datasets/patterns
!unzip -j -o -qq *.zip -d datasets/patterns
!rm -r datasets/patterns/._*

In [None]:
# Some helper functions for creating/checking directories.
def directory_should_exist(*args):
    """Join path components and require the result to be an existing directory.

    Args:
        *args: Path components, combined with ``os.path.join``.

    Returns:
        The joined path.

    Raises:
        NotADirectoryError: If the joined path is missing or is not a
            directory. This is an ``Exception`` subclass, so callers that
            catch ``Exception`` keep working.
    """
    path = os.path.join(*args)
    if not os.path.isdir(path):
        raise NotADirectoryError("Path '{}' is not a directory.".format(path))
    return path

def ensure_directory(*args):
    """Build a path from ``args``, create it (with parents) if absent, return it."""
    target = os.path.join(*args)
    # exist_ok makes repeated calls harmless when the directory is already there.
    os.makedirs(target, exist_ok=True)
    return target

In [None]:
# Point to your dataset and configure training.
# input_dir must already exist (an exception is raised otherwise);
# output_dir is created if it is missing.
input_dir = directory_should_exist("datasets/patterns")
output_dir = ensure_directory("output")

# I/O and schedule
epochs = 100
batch_size = 64              # 512x512 is heavy; tweak based on VRAM
sample_interval = 15        # iterations between samples
snapshot_interval = 1       # epochs between checkpoints

# Latent & image
z_dim = 128                 # length of the random vector fed to the generator
img_channels = 3            # RGB
image_size = 512            # images are resized/cropped to image_size x image_size

# Optimizer settings: Adam learning rates and betas for generator (g) and
# discriminator (d); chosen for a spectral-norm discriminator with hinge loss.
g_lr = 2e-4
d_lr = 2e-4
betas = (0.0, 0.99)

# Regularization / stability
d_reg_every = 16           # R1 every N D steps (lazy)
r1_gamma = 10.0            # R1 weight (StyleGAN2 uses 10 at 256; 10 is fine here)
ema_decay = 0.995          # EMA for generator weights
use_amp = True             # mixed precision for speed
torch.backends.cudnn.benchmark = True  # speed-up on constant shapes

In [None]:
# Unconditional image dataset (no splitting) + shape guard
class UncondImageDataset(Dataset):
    """Flat folder of images served one at a time, with no labels and no split.

    Every file directly inside ``root_dir`` whose name ends in .jpg, .jpeg or
    .png (case-insensitive) belongs to the dataset. File names are kept sorted
    so an index refers to the same image on every run; ``os.listdir`` alone
    returns entries in arbitrary order.
    """

    def __init__(self, root_dir, transform=None, expected_hw=None):
        """
        root_dir: directory that holds the image files.
        transform: optional callable applied to the RGB PIL image.
        expected_hw: tuple (H, W) to enforce final tensor size, e.g. (512, 512).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.expected_hw = expected_hw
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, and check its shape.

        Raises:
            ValueError: If ``expected_hw`` is set and the transformed item is
                not a (3, H, W) tensor-like object of that size.
        """
        path = os.path.join(self.root_dir, self.image_files[idx])
        # Decode inside a context manager so the file handle is released
        # immediately; convert() forces the pixel data to be read first.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)
        if self.expected_hw is not None:
            expH, expW = self.expected_hw
            # PIL images (no transform) have no .shape; report that as a shape
            # mismatch instead of failing with an AttributeError.
            shape = getattr(img, "shape", None)
            if shape is None or tuple(shape) != (3, expH, expW):
                raise ValueError(f"Bad tensor shape {shape} for {path}; expected (3,{expH},{expW})")
        return img

In [None]:
# Preprocessing: scale the shorter side to image_size, centre-crop to a square,
# flip horizontally half of the time, then map pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(image_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
])

# Pass expected_hw so the dataset's shape guard is actually active.
dataset = UncondImageDataset(
    input_dir,
    transform=transform,
    expected_hw=(image_size, image_size),
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runtimes.
    pin_memory=torch.cuda.is_available(),
    # Incomplete final batches are dropped so every batch has batch_size items.
    drop_last=True,
)

In [None]:
# Show a single real image from the dataset
def plot_image(ax, title, img):
    """Draw a channel-first image tensor on ``ax`` with a title and no axes.

    ``img`` is shifted from the [-1, 1] range to [0, 1] and permuted from
    (C, H, W) to (H, W, C) before being handed to ``ax.imshow``.
    """
    unit_range = (img + 1) / 2
    height_width_channel = unit_range.permute(1, 2, 0)
    ax.imshow(height_width_channel.cpu().numpy())
    ax.set_title(title)
    ax.axis("off")

# Validate with explicit errors rather than assert (asserts vanish under -O).
if len(dataset) == 0:
    raise RuntimeError(f"No images found in {input_dir}. Supported: .jpg .jpeg .png")
# A loader that yields no batches (e.g. fewer images than batch_size while
# incomplete batches are dropped) would make next(iter(...)) fail with a bare
# StopIteration, so report the real cause instead.
if len(dataloader) == 0:
    raise RuntimeError(
        f"Only {len(dataset)} image(s) in {input_dir}, which is not enough for a "
        f"single batch of {batch_size}; add images or lower batch_size."
    )
real_batch = next(iter(dataloader))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
plot_image(ax, "Real Image", real_batch[0])
plt.show()

In [None]:
# Anti-alias blur (StyleGAN-ish upfirdn-lite)
def _make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[:, None] * k[None, :]
    k /= k.sum()
    return k

class Blur(nn.Module):
    def __init__(self, channels: int, kernel=(1,3,3,1), pad=(1,1)):
        super().__init__()
        k = _make_kernel(kernel)
        w = k.view(1,1,k.shape[0],k.shape[1]).repeat(channels,1,1,1)
        self.register_buffer("weight", w)
        self.pad = pad
        self.channels = channels
    def forward(self, x):
        return F.conv2d(x, self.weight, None, 1, self.pad, 1, self.channels)

class PixelNorm(nn.Module):
    """Rescale each pixel's channel vector to unit mean square (over dim 1)."""

    def __init__(self, eps=1e-8):
        super().__init__()
        # Added under the square root's reciprocal to avoid division by zero.
        self.eps = eps

    def forward(self, x):
        mean_square = torch.mean(x * x, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

class NoiseInjection(nn.Module):
    """Add per-pixel Gaussian noise, scaled by a learned per-channel weight.

    The weight starts at zero, so the layer is initially a no-op. Outside of
    training mode the input is returned untouched, which keeps evaluation and
    export deterministic.
    """

    def __init__(self, ch):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, ch, 1, 1))

    def forward(self, x):
        if self.training:
            # One noise plane per sample, shared (broadcast) across channels.
            noise_shape = (x.size(0), 1, x.size(2), x.size(3))
            noise = torch.randn(noise_shape, device=x.device)
            return x + self.weight * noise
        return x

class UpBlock(nn.Module):
    """One generator stage: 2x bilinear upsample, optional blur, two 3x3 convs.

    Each conv is followed by LeakyReLU(0.2) and a NoiseInjection layer. An
    optional PixelNorm is applied to the main path, and a 1x1-conv projection
    of the upsampled input is added back with weight 0.1.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by this block.
        use_blur: Apply the (1,3,3,1) Blur right after upsampling.
        use_pn: Apply PixelNorm to the main path before the residual add.
    """
    def __init__(self, in_ch, out_ch, use_blur=True, use_pn=False):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.blur = Blur(in_ch, (1,3,3,1), (1,1)) if use_blur else nn.Identity()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.act1  = nn.LeakyReLU(0.2, inplace=True)
        self.noise1 = NoiseInjection(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.act2  = nn.LeakyReLU(0.2, inplace=True)
        self.noise2 = NoiseInjection(out_ch)

        self.skip = nn.Conv2d(in_ch, out_ch, 1)  # light residual path
        self.pn = PixelNorm() if use_pn else nn.Identity()

        # init
        for m in [self.conv1, self.conv2, self.skip]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        y = self.up(x)
        y = self.blur(y)
        s = self.skip(y)  # residual branch is taken before the convs

        y = self.noise1(self.act1(self.conv1(y)))
        # The second conv gets its own noise layer, mirroring the first; without
        # this call noise2 would hold a parameter that never receives a gradient.
        y = self.noise2(self.act2(self.conv2(y)))
        y = self.pn(y)
        return y + 0.1 * s

In [None]:
class Generator(nn.Module):
    """
    Latent-to-image generator with two convs per scale, early blur, early
    pixelnorm, light residual, and noise injection (train-only).

    A linear layer maps ``z`` to a 4x4 feature map with ``chs[0]`` channels;
    one UpBlock per consecutive pair in ``chs`` then upsamples by 2, and a 1x1
    conv followed by tanh produces the image. With the default ``chs`` there
    are seven blocks, i.e. a nominal 4 * 2**7 = 512 pixel output.

    Args:
        z_dim: Length of the latent vector.
        img_channels: Channels of the generated image.
        chs: Channel widths per scale; any length >= 1 is accepted.
    """
    # Blur and PixelNorm are only used in this many leading (low-res) blocks.
    EARLY_BLOCKS = 3

    def __init__(self, z_dim=128, img_channels=3, chs=(512, 512, 256, 128, 64, 32, 16, 8)):
        super().__init__()
        self.z_dim = z_dim
        self.chs = chs
        self.fc = nn.Linear(z_dim, 4 * 4 * chs[0])

        blocks = []
        for i, (cin, cout) in enumerate(zip(chs[:-1], chs[1:])):
            # Derive the flags from the index so any number of scales works;
            # fixed-length flag lists would raise IndexError for a longer chs.
            early = i < self.EARLY_BLOCKS
            blocks.append(UpBlock(cin, cout, use_blur=early, use_pn=early))
        self.up = nn.Sequential(*blocks)
        self.to_rgb = nn.Conv2d(chs[-1], img_channels, kernel_size=1)
        nn.init.xavier_uniform_(self.to_rgb.weight)
        nn.init.zeros_(self.to_rgb.bias)

    def forward(self, z):
        """Map latents of shape (N, z_dim) to images in [-1, 1]."""
        x = self.fc(z.view(z.size(0), -1)).view(-1, self.chs[0], 4, 4)
        x = self.up(x)
        return torch.tanh(self.to_rgb(x))

In [None]:
class MinibatchStdDev(nn.Module):
    """Append one feature map holding the within-group standard deviation.

    The batch is split into groups of (at most) ``group_size`` samples. For
    each group the per-element standard deviation across its members is
    averaged to a single scalar, which is broadcast to an extra channel for
    every sample of that group.
    """
    def __init__(self, group_size=4):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        N, C, H, W = x.shape
        g = max(1, min(self.group_size, N))  # handle small batches
        # The reshape below needs N to be a multiple of g; fall back to the
        # largest divisor of N that does not exceed the requested group size.
        while N % g:
            g -= 1
        n = N // g
        y = x.view(g, n, C, H, W)                     # [g, n, C, H, W]
        y = torch.var(y, dim=0, unbiased=False) + 1e-8
        y = torch.sqrt(y)                             # [n, C, H, W]
        y = torch.mean(y, dim=[1,2,3], keepdim=True)  # [n,1,1,1]
        # Sample s belongs to group s % n, which is exactly how repeat() tiles.
        y = y.repeat(g, 1, H, W)                      # [N,1,H,W]
        return torch.cat([x, y], dim=1)

class Discriminator(nn.Module):
    """Spectrally-normalised conv critic that returns one score per image.

    ``body`` is a stack of stride-2 4x4 convs (one per entry of ``chs``), each
    followed by LeakyReLU(0.2). A minibatch-stddev channel is appended, the map
    is average-pooled to 1x1, and a linear head produces the score.
    """

    def __init__(self, img_channels=3, chs=(64, 128, 256, 512, 512, 512, 512)):
        super().__init__()
        widths = (img_channels, *chs)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1)))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.body = nn.Sequential(*stages)
        self.mbstd = MinibatchStdDev(group_size=4)
        self.pool = nn.AdaptiveAvgPool2d(1)
        # One extra input feature for the channel added by mbstd.
        self.head = spectral_norm(nn.Linear(widths[-1] + 1, 1))

    def forward(self, x):
        features = self.mbstd(self.body(x))
        pooled = self.pool(features).flatten(1)
        return self.head(pooled).squeeze(1)

In [None]:
# ==== Training helpers: EMA, R1, DiffAugment ====

def r1_penalty(d_out, real_img):
    """Mean over the batch of the squared gradient norm of D w.r.t. real images.

    ``real_img`` must be part of the graph that produced ``d_out``. The
    gradient is built with ``create_graph=True`` so the penalty itself can be
    back-propagated.
    """
    (grad_real,) = torch.autograd.grad(
        outputs=d_out.sum(),
        inputs=real_img,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_sample = grad_real.pow(2).reshape(grad_real.size(0), -1).sum(dim=1)
    return per_sample.mean()

class EMA:
    """Keeps a frozen, eval-mode copy of a model whose weights track a moving average."""

    def __init__(self, model, decay=0.995):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        # The shadow is never trained directly.
        self.shadow.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s current state into the shadow copy, in place."""
        averaged = self.shadow.state_dict()
        keep, take = self.decay, 1.0 - self.decay
        for name, live in model.state_dict().items():
            slot = averaged[name]
            if not torch.is_floating_point(slot):
                # Integer buffers (e.g. counters) cannot be averaged; mirror them.
                slot.copy_(live)
                continue
            slot.mul_(keep).add_(live, alpha=take)

# ---- DiffAugment (color, translation, cutout) ----
def rand_brightness(x):
    """Add one random offset per sample, drawn uniformly from [-0.1, 0.1)."""
    centred = torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5
    return x + centred * 0.2

def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by a factor in [0.7, 1.3)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, device=x.device) * 0.6 + 0.7  # 0.7..1.3
    deviation = x - channel_mean
    return deviation * factor + channel_mean

def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a factor in [0.7, 1.3)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, device=x.device) * 0.6 + 0.7  # 0.7..1.3
    deviation = x - sample_mean
    return deviation * factor + sample_mean

def _translate_grid(x, ratio=0.06):
    B,C,H,W = x.shape
    shift = torch.randint(-int(W*ratio), int(W*ratio)+1, (B,2), device=x.device).float()
    yy, xx = torch.meshgrid(
        torch.linspace(-1,1,H,device=x.device),
        torch.linspace(-1,1,W,device=x.device),
        indexing='ij'
    )
    base = torch.stack((xx,yy), dim=-1).unsqueeze(0).repeat(B,1,1,1)
    base[...,0] += shift[:,0].view(B,1,1) * (2.0/W)
    base[...,1] += shift[:,1].view(B,1,1) * (2.0/H)
    return F.grid_sample(x, base, padding_mode='reflection', align_corners=False)

def _cutout(x, ratio=0.12):
    B,C,H,W = x.shape
    sz = max(1, int(H*ratio*0.5))
    cx = torch.randint(sz, W - sz + 1, (B,), device=x.device)
    cy = torch.randint(sz, H - sz + 1, (B,), device=x.device)
    mask = torch.ones((B,1,H,W), device=x.device)
    for i in range(B):
        mask[i,:, cy[i]-sz:cy[i]+sz, cx[i]-sz:cx[i]+sz] = 0.0
    return x * mask

def diff_augment(x):
    """Apply the full augmentation chain to a batch.

    Order: brightness, saturation, contrast, then a random translation
    (ratio 0.125) and a random cutout (ratio 0.25).
    """
    for colour_op in (rand_brightness, rand_saturation, rand_contrast):
        x = colour_op(x)
    shifted = _translate_grid(x, ratio=0.125)
    return _cutout(shifted, ratio=0.25)

In [None]:
# Load snapshot if available
def _epoch_from_path(path):
    """Epoch number at the end of a '<name>_epoch_<N>.<ext>' file name, or -1."""
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        return int(stem.rsplit("_", 1)[-1])
    except ValueError:
        return -1

def _latest_by_epoch(pattern):
    """Return the file matching ``pattern`` with the highest epoch, or None.

    Creation time is only a tie-breaker: it changes when files are copied or
    restored, so it does not reliably identify the most advanced checkpoint.
    """
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=lambda p: (_epoch_from_path(p), os.path.getctime(p)))

def get_latest_snapshot(output_dir):
    """Path of the highest-epoch ``snapshot_epoch_*.pth`` in ``output_dir``, or None."""
    return _latest_by_epoch(os.path.join(output_dir, "snapshot_epoch_*.pth"))

def get_latest_generator(output_dir):
    """Path of the highest-epoch ``generator_epoch_*.onnx`` in ``output_dir``, or None."""
    return _latest_by_epoch(os.path.join(output_dir, "generator_epoch_*.onnx"))

In [None]:
def load_snapshot(generator, discriminator, g_optimizer, d_optimizer, snapshot_path, ema_shadow=None):
    """Restore models and optimizers from a snapshot and return the epoch to resume at.

    Args:
        generator, discriminator: Modules whose state dicts are restored.
        g_optimizer, d_optimizer: Optimizers whose state dicts are restored.
        snapshot_path: File named ``snapshot_epoch_<N>.pth``.
        ema_shadow: Optional module that receives the ``ema_generator`` entry
            when the snapshot has one.

    Returns:
        ``N + 1``. The snapshot named for epoch N holds the state after that
        epoch finished, so training continues with the next one.
    """
    # Load onto the device the generator already lives on rather than relying
    # on a module-level `device` variable being defined by the time of the call.
    first_param = next(generator.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load snapshots you created yourself or otherwise trust.
    checkpoint = torch.load(snapshot_path, map_location=target_device, weights_only=False)
    generator.load_state_dict(checkpoint["generator"])
    discriminator.load_state_dict(checkpoint["discriminator"])
    g_optimizer.load_state_dict(checkpoint["g_optimizer"])
    d_optimizer.load_state_dict(checkpoint["d_optimizer"])
    if ema_shadow is not None and "ema_generator" in checkpoint:
        ema_shadow.load_state_dict(checkpoint["ema_generator"])

    saved_epoch = int(os.path.basename(snapshot_path).split("_")[2].split(".")[0])
    return saved_epoch + 1

In [None]:
# Create the training loop (unconditional GAN, hinge + R1, SN-D, EMA, AMP)
def _show_samples(generator, fixed_z, opts, epoch, it):
    """Render the fixed latents with the raw generator, save a grid, show four images."""
    generator.eval()  # Avoid noise injection during visualization
    with torch.no_grad():
        samples = generator(fixed_z).detach().float().cpu()
    generator.train()

    save_path = os.path.join(opts.output_dir, f"epoch_{epoch}_iter_{it}.jpg")
    save_image(samples, save_path, nrow=4, normalize=True, value_range=(-1, 1))
    clear_output(wait=True)
    print(f"Epoch {epoch} | iter {it}")
    grid = ((samples[:4] + 1) / 2).clamp(0, 1)
    fig, axes = plt.subplots(1, 4, figsize=(10, 3))
    for a, img in zip(axes, grid):
        a.imshow(img.permute(1, 2, 0).numpy())
        a.axis("off")
    plt.show()


def _save_snapshot_and_export(generator, discriminator, g_opt, d_opt, ema, opts, epoch, device):
    """Write a resumable .pth snapshot and an ONNX export of the raw generator."""
    snap_path = os.path.join(opts.output_dir, f"snapshot_epoch_{epoch}.pth")
    torch.save({
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "g_optimizer": g_opt.state_dict(),
        "d_optimizer": d_opt.state_dict(),
        "ema_generator": ema.shadow.state_dict(),
    }, snap_path)
    print(f"Saved snapshot to {snap_path}")

    # Export ONNX from RAW generator
    generator.eval()
    dummy_z = torch.randn(1, opts.z_dim, device=device)
    onnx_path = os.path.join(opts.output_dir, f"generator_epoch_{epoch}.onnx")
    torch.onnx.export(
        generator,
        (dummy_z,),
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["z"],
        output_names=["image"],
        dynamic_axes={"z": {0: "batch"}, "image": {0: "batch"}},
    )
    print(f"ONNX model exported to {onnx_path}")
    generator.train()


def train(generator, discriminator, dataloader, opts):
    """Train an unconditional GAN (hinge loss, lazy R1, DiffAugment, EMA, AMP).

    Args:
        generator: Maps latents of shape (N, opts.z_dim) to images.
        discriminator: Maps images to one score per sample.
        dataloader: Yields batches of real image tensors.
        opts: Namespace with output_dir, epochs, sample_interval (iterations),
            snapshot_interval (epochs), z_dim, g_lr, d_lr, betas, d_reg_every,
            r1_gamma, ema_decay, use_amp and, optionally, restart.

    Unless ``opts.restart`` is true, training resumes from the newest snapshot
    in ``opts.output_dir``, starting at the epoch ``load_snapshot`` returns.
    Samples, snapshots and ONNX exports are written to ``opts.output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.train().to(device)
    discriminator.train().to(device)

    g_opt = optim.Adam(generator.parameters(), lr=opts.g_lr, betas=opts.betas)
    d_opt = optim.Adam(discriminator.parameters(), lr=opts.d_lr, betas=opts.betas)

    # CUDA autocast and GradScaler only make sense on a CUDA device.
    amp_on = bool(opts.use_amp) and device.type == "cuda"
    scaler_g = GradScaler(enabled=amp_on)
    scaler_d = GradScaler(enabled=amp_on)

    ema = EMA(generator, decay=opts.ema_decay)

    # Fixed latent for monitoring
    fixed_z = torch.randn(16, opts.z_dim, device=device)

    start_epoch = 1
    if not getattr(opts, "restart", False):
        latest_snapshot = get_latest_snapshot(opts.output_dir)
        if latest_snapshot:
            start_epoch = load_snapshot(
                generator, discriminator, g_opt, d_opt, latest_snapshot, ema_shadow=ema.shadow
            )
            print(f"Resumed from {latest_snapshot} (start_epoch={start_epoch})")

    it = 0
    for epoch in range(start_epoch, opts.epochs + 1):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{opts.epochs}")
        for real in pbar:
            it += 1
            real = real.to(device, non_blocking=True)
            bsz = real.size(0)

            # ---- Train Discriminator ----
            z = torch.randn(bsz, opts.z_dim, device=device)
            with autocast(enabled=amp_on):
                with torch.no_grad():
                    fake = generator(z)
                # Apply DiffAugment on both
                real_aug = diff_augment(real)
                fake_aug = diff_augment(fake)
                d_real = discriminator(real_aug)
                d_fake = discriminator(fake_aug)
                d_loss = F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

            d_opt.zero_grad(set_to_none=True)
            scaler_d.scale(d_loss).backward()
            scaler_d.step(d_opt)
            scaler_d.update()

            # ---- Lazy R1 on augmented real (in FP32 for stability) ----
            if (it % opts.d_reg_every) == 0:
                with autocast(enabled=False):
                    # Cast once and reuse the same tensor for the forward pass
                    # and the gradient, so the penalty is always taken with
                    # respect to the tensor the discriminator actually saw.
                    real_req = diff_augment(real.detach().requires_grad_(True)).float()
                    d_real_r1 = discriminator(real_req)
                    r1 = r1_penalty(d_real_r1, real_req)
                    r1_loss = (opts.r1_gamma / 2.0) * r1
                d_opt.zero_grad(set_to_none=True)
                scaler_d.scale(r1_loss).backward()
                scaler_d.step(d_opt)
                scaler_d.update()

            # -----------------------
            #  Train Generator
            # -----------------------
            z = torch.randn(bsz, opts.z_dim, device=device)
            with autocast(enabled=amp_on):
                fake = generator(z)
                # The discriminator only ever sees augmented images while it is
                # trained, so the generator must be scored through the same
                # (differentiable) augmentation.
                g_fake = discriminator(diff_augment(fake))
                g_loss = -g_fake.mean()

            g_opt.zero_grad(set_to_none=True)
            scaler_g.scale(g_loss).backward()
            scaler_g.step(g_opt)
            scaler_g.update()

            # EMA update
            ema.update(generator)

            pbar.set_postfix({
                "D": f"{d_loss.item():.3f}",
                "G": f"{g_loss.item():.3f}",
            })

            # Visualization and sampling (always raw G). sample_interval counts
            # iterations, so the check belongs inside the batch loop.
            if it % opts.sample_interval == 0:
                _show_samples(generator, fixed_z, opts, epoch, it)

        # Snapshot & ONNX export at epoch end
        if (epoch % opts.snapshot_interval) == 0:
            _save_snapshot_and_export(
                generator, discriminator, g_opt, d_opt, ema, opts, epoch, device
            )

In [None]:
# Build both networks on the best available device, bundle the settings from
# the configuration cell into one namespace, and start training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
generator = Generator(z_dim=z_dim, img_channels=img_channels).to(device)
discriminator = Discriminator(img_channels=img_channels).to(device)

opts = SimpleNamespace(**{
    # I/O and schedule
    "output_dir": output_dir,
    "sample_interval": sample_interval,
    "snapshot_interval": snapshot_interval,
    "epochs": epochs,
    "restart": False,  # set to True to ignore existing snapshots
    # Model and optimizer
    "z_dim": z_dim,
    "g_lr": g_lr,
    "d_lr": d_lr,
    "betas": betas,
    # Regularization / stability
    "d_reg_every": d_reg_every,
    "r1_gamma": r1_gamma,
    "ema_decay": ema_decay,
    "use_amp": use_amp,
})

train(generator, discriminator, dataloader, opts)