# Training a DCGAN Image Generator using PyTorch / ONNX

This notebook walks you through the steps of training your own image-generation model: an unconditional DCGAN that learns to turn random noise into 512×512 images resembling your dataset.

Basically all you have to do is put your cursor in a cell and press Shift+Enter. At the end, you can download the latest model from the `output` folder (it will be called something like `generator_epoch_XXX.onnx`).

In [None]:
# Make sure you are connected to a runtime with a GPU.
# `nvidia-smi -L` prints one line per attached NVIDIA GPU; an error here means no GPU runtime is attached.
!nvidia-smi -L

In [None]:
# Install ONNX (not installed by default)
# The two commented-out lines below are a disabled workaround that forces the
# preferred encoding to UTF-8.
# NOTE(review): presumably for runtimes where shell/pip commands fail on a
# non-UTF-8 locale — confirm before re-enabling.
#import locale
#locale.getpreferredencoding = lambda: "UTF-8"
# matplotlib and tqdm are installed alongside onnx; -q keeps pip's output short.
%pip install -q onnx matplotlib tqdm

In [None]:
# Import all other dependencies
import glob, os, random, copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
from PIL import Image

import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
from types import SimpleNamespace
from IPython.display import clear_output
from torch.nn.utils import spectral_norm
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Check if GPU is available
# (torch.cuda.is_available() is the same test train() uses to choose "cuda" vs "cpu")
gpu_available = torch.cuda.is_available()
print("GPU is", "available" if gpu_available else "NOT AVAILABLE")

In [None]:
# Download and unzip the dataset
!curl -O https://algorithmicgaze.s3.amazonaws.com/workshops/2025-raive/patterns_512.zip
!mkdir -p datasets/patterns
!unzip -j -o -qq *.zip -d datasets/patterns
!rm -r datasets/patterns/._*

In [None]:
# Some helper functions for creating/checking directories.
def directory_should_exist(*args):
    """Join ``args`` into one path and return it, requiring an existing directory.

    Args:
        *args: Path components, combined with ``os.path.join``.

    Returns:
        The joined path.

    Raises:
        NotADirectoryError: If the joined path does not exist or is not a
            directory. This is a subclass of ``Exception``, so callers that
            catch ``Exception`` keep working.
    """
    # Named `path` rather than `dir` so the builtin dir() is not shadowed.
    path = os.path.join(*args)
    if not os.path.isdir(path):
        raise NotADirectoryError("Path '{}' is not a directory.".format(path))
    return path

def ensure_directory(*args):
    """Join ``args`` into one path, create that directory if needed, and return the path.

    Missing parent directories are created as well; an already existing
    directory is not an error.
    """
    target = os.path.join(*args)
    os.makedirs(target, exist_ok=True)
    return target

In [None]:
# Point to your dataset and configure training
# input_dir must already exist (raises otherwise); output_dir is created on demand
# and receives the preview JPEGs, .pth snapshots and .onnx exports written by train().
input_dir = directory_should_exist("datasets/patterns")
output_dir = ensure_directory("output")

# I/O and schedule
epochs = 100
batch_size = 64               # NOTE(review): large for 512^2 images; 4–8 is typical — lower this if you run out of GPU memory
sample_interval = 15          # iters between previews
snapshot_interval = 1        # epochs between checkpoints

# Latent & image
# image_size drives the resize/crop in the transform; the Generator and
# Discriminator layer stacks themselves are written for 512x512.
z_dim = 128
img_channels = 3
image_size = 512

# Optimizer (classic DCGAN)
g_lr = 2e-4
d_lr = 2e-4
betas = (0.5, 0.999)

# Label smoothing
real_label_smooth = 0.1      # real target = 1 - 0.1 = 0.9

# AMP
use_amp = True
# Let cuDNN benchmark and pick convolution algorithms; input shapes are fixed
# here (constant image size, drop_last=True), which is the case this flag suits.
torch.backends.cudnn.benchmark = True

In [None]:
# Unconditional image dataset (no splitting) + shape guard
class UncondImageDataset(Dataset):
    """Unlabelled image dataset backed by a single flat folder.

    Every file directly inside ``root_dir`` whose name ends in ``.jpg``,
    ``.jpeg`` or ``.png`` (case-insensitive) is one sample. Files are kept in
    sorted order so an index always maps to the same file, independent of the
    order in which the OS happens to list the directory.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None, expected_hw=None):
        """
        root_dir: folder that contains the image files.
        transform: optional callable applied to each RGB PIL image.
        expected_hw: tuple (H, W) to enforce final tensor size, e.g. (512, 512).
            Requires a transform whose result has a 3-element ``.shape``
            (e.g. one ending in ToTensor).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.expected_hw = expected_hw
        self.image_files = sorted(
            name for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB, apply the transform, and run the shape guard.

        Raises:
            ValueError: if ``expected_hw`` is set and the transformed sample is
                not shaped (3, H, W) with the expected H and W.
        """
        path = os.path.join(self.root_dir, self.image_files[idx])
        # Image.open is lazy and keeps the file handle open. convert() loads the
        # pixels and returns a separate image (a copy when the mode is already
        # RGB), so the handle can be released as soon as the `with` block ends.
        with Image.open(path) as source:
            img = source.convert("RGB")
        if self.transform:
            img = self.transform(img)  # with the notebook's transform: [C,H,W], float in [-1,1]
        if self.expected_hw is not None:
            C, H, W = img.shape
            expH, expW = self.expected_hw
            if (H, W) != (expH, expW) or C != 3:
                raise ValueError(f"Bad tensor shape {img.shape} for {path}; expected (3,{expH},{expW})")
        return img

In [None]:
# Preprocessing: resize the shorter side to image_size, center-crop to a square,
# flip horizontally half of the time (augmentation), then map pixels to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(image_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
])

# expected_hw switches on the dataset's shape guard, so a sample that does not
# come out as (3, image_size, image_size) fails loudly instead of mid-training.
dataset = UncondImageDataset(
    input_dir,
    transform=transform,
    expected_hw=(image_size, image_size),
)

# Fail early with a readable message: DataLoader(shuffle=True) rejects an empty
# dataset with an opaque sampler error, and drop_last=True silently yields zero
# batches when there are fewer images than batch_size.
if len(dataset) < batch_size:
    raise ValueError(
        f"Found {len(dataset)} image(s) in {input_dir}, but batch_size={batch_size} "
        "with drop_last=True needs at least that many. "
        "Add images (.jpg .jpeg .png) or lower batch_size."
    )

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,   # keep every batch the same size
)

In [None]:
# Show a single real image from the dataset
def plot_image(ax, title, img):
    """Draw one channel-first image tensor on ``ax`` with a title and hidden axes.

    The tensor is shifted from the [-1, 1] range used by the notebook's
    Normalize step to [0, 1], and reordered to channel-last for ``imshow``.
    """
    rescaled = (img + 1) / 2
    channel_last = rescaled.permute(1, 2, 0)
    ax.imshow(channel_last.cpu().numpy())
    ax.set_title(title)
    ax.axis("off")

# Sanity check: there has to be at least one image to display.
assert len(dataset) > 0, f"No images found in {input_dir}. Supported: .jpg .jpeg .png"
# Pull one batch from the loader and show its first image.
real_batch = next(iter(dataloader))
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
plot_image(ax, "Real Image", real_batch[0])
plt.show()

In [None]:
# DCGAN weight init (as in the paper)
def weights_init_dcgan(m):
    """Initialise one module in place, DCGAN style; intended for ``model.apply(...)``.

    Modules are matched by class name:
      * names containing ``Conv`` (this includes ``ConvTranspose``): weight
        drawn from N(0, 0.02), bias zeroed;
      * names containing ``BatchNorm``: weight drawn from N(1, 0.02), bias zeroed.
    Anything else is left untouched. Parameters that are absent or ``None``
    (``bias=False`` convolutions, ``affine=False`` batch norms, wrapper modules
    that merely have ``Conv`` in their name) are skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    # 'ConvTranspose' contains 'Conv', so one substring test covers both.
    if 'Conv' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)
    if bias is not None:
        nn.init.zeros_(bias.data)

In [None]:
class Generator(nn.Module):
    """
    DCGAN generator that maps a latent vector to a 512x512 image.

    The latent is treated as a [B, z_dim, 1, 1] feature map. A first transposed
    conv lifts it to 4x4, then seven stride-2 transposed convs double the
    spatial size each time (4→8→…→512). Output passes through tanh.
    """
    def __init__(self, z_dim=128, img_channels=3, base=64):
        super().__init__()
        self.z_dim = z_dim

        # Channel count of each hidden stage, at spatial sizes 4, 8, 16, 32, 64, 128, 256.
        widths = [base*16, base*8, base*4, base*2, base, base//2, base//4]

        # Stem: 1x1 -> 4x4
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]

        # Hidden upsampling stages: 4 -> 8 -> ... -> 256
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])

        # Output stage: 256 -> 512, no batch norm, tanh activation
        layers.extend([
            nn.ConvTranspose2d(widths[-1], img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])

        self.main = nn.Sequential(*layers)
        self.apply(weights_init_dcgan)

    def forward(self, z):
        # Accept flat [B, z_dim] latents by adding the two singleton spatial dims.
        is_flat = z.dim() == 2
        if is_flat:
            batch, channels = z.size(0), z.size(1)
            z = z.view(batch, channels, 1, 1)
        return self.main(z)

In [None]:
class Discriminator(nn.Module):
    """
    DCGAN discriminator for 512x512 images.

    Seven stride-2 4x4 convs halve the spatial size each time (512→…→4); a
    final 4x4 conv reduces the 4x4 map to a single logit per image.
    Every strided stage except the first is followed by BatchNorm
    (classic DCGAN); all of them use LeakyReLU(0.2).
    """
    def __init__(self, img_channels=3, base=64):
        super().__init__()

        # Channel count after each strided stage, at spatial sizes 256, 128, 64, 32, 16, 8, 4.
        widths = [base, base*2, base*4, base*8, base*16, base*16, base*16]

        # First stage: 512 -> 256, no batch norm
        layers = [
            nn.Conv2d(img_channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining downsampling stages: 256 -> 128 -> ... -> 4
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Head: 4x4 -> 1x1 logit
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))

        self.main = nn.Sequential(*layers)
        self.apply(weights_init_dcgan)

    def forward(self, x):
        # [B, 1, 1, 1] flattened to a 1-D vector of raw logits (no sigmoid applied).
        return self.main(x).view(-1)

In [None]:
# Non-saturating BCE losses (classic DCGAN)
# BCEWithLogitsLoss takes raw logits, which is what Discriminator.forward
# returns (it applies no sigmoid). train() uses this one criterion for both
# the discriminator loss and the generator loss.
bce = nn.BCEWithLogitsLoss()

def real_targets(b, device):
    """Return a length-``b`` vector of "real" targets on ``device``.

    Each entry is ``1.0 - real_label_smooth`` (one-sided label smoothing),
    read from the notebook-level ``real_label_smooth`` setting at call time.
    """
    smoothed_one = 1.0 - real_label_smooth
    return torch.full((b,), smoothed_one, device=device)

def fake_targets(b, device):
    """Return a length-``b`` vector of "fake" targets (all zeros) on ``device``."""
    return torch.full((b,), 0.0, device=device)

In [None]:
# Helpers that locate the most recent snapshot / exported generator in a folder.
# (They only return a path; loading is done by load_snapshot.)
def _epoch_from_filename(path):
    """Return the epoch number that ends a name like ``*_epoch_12.ext``, or -1 if there is none."""
    stem = os.path.splitext(os.path.basename(path))[0]
    tail = stem.rsplit("_", 1)[-1]
    return int(tail) if tail.isdecimal() else -1

def _latest_by_epoch(pattern):
    """Return the glob match with the highest epoch number, or None when nothing matches.

    The epoch in the file name decides, because file timestamps change when
    files are copied, restored or re-permissioned. Creation/change time is
    used only to break ties and to order names without a numeric epoch.
    """
    matches = glob.glob(pattern)
    if not matches:
        return None
    return max(matches, key=lambda p: (_epoch_from_filename(p), os.path.getctime(p)))

def get_latest_snapshot(output_dir):
    """Return the path of the highest-epoch ``snapshot_epoch_*.pth`` in ``output_dir``, or None."""
    return _latest_by_epoch(os.path.join(output_dir, "snapshot_epoch_*.pth"))

def get_latest_generator(output_dir):
    """Return the path of the highest-epoch ``generator_epoch_*.onnx`` in ``output_dir``, or None."""
    return _latest_by_epoch(os.path.join(output_dir, "generator_epoch_*.onnx"))

In [None]:
def load_snapshot(generator, discriminator, g_optimizer, d_optimizer, snapshot_path, ema_shadow=None):
    """Restore models and optimizers in place from a snapshot file.

    Args:
        generator, discriminator: modules whose ``state_dict`` is overwritten.
        g_optimizer, d_optimizer: optimizers whose state is overwritten.
        snapshot_path: file named like ``snapshot_epoch_<N>.pth`` holding the
            keys ``generator``, ``discriminator``, ``g_optimizer`` and
            ``d_optimizer`` (optionally ``ema_generator``).
        ema_shadow: optional module; loaded from ``ema_generator`` only when
            that key is present in the snapshot.

    Returns:
        The epoch number ``N`` parsed from the file name.
    """
    # Load onto the device the generator already lives on, instead of relying on
    # a notebook-global `device` that exists only after a later cell has run.
    first_param = next(generator.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load snapshots that you created yourself.
    checkpoint = torch.load(snapshot_path, map_location=target_device, weights_only=False)

    generator.load_state_dict(checkpoint["generator"])
    discriminator.load_state_dict(checkpoint["discriminator"])
    g_optimizer.load_state_dict(checkpoint["g_optimizer"])
    d_optimizer.load_state_dict(checkpoint["d_optimizer"])
    if ema_shadow is not None and "ema_generator" in checkpoint:
        ema_shadow.load_state_dict(checkpoint["ema_generator"])

    # "snapshot_epoch_12.pth" -> ["snapshot", "epoch", "12.pth"] -> "12"
    epoch_token = os.path.basename(snapshot_path).split("_")[2].split(".")[0]
    return int(epoch_token)

In [None]:
def _save_preview(generator, fixed_z, opts, epoch, it):
    """Generate from the fixed latents, save a 4-per-row JPEG grid, and display the first four samples."""
    generator.eval()
    try:
        with torch.no_grad():
            samples = generator(fixed_z).detach().cpu()
    finally:
        # Always return to training mode, even if generation fails.
        generator.train()

    save_path = os.path.join(opts.output_dir, f"epoch_{epoch}_iter_{it}.jpg")
    save_image(samples, save_path, nrow=4, normalize=True, value_range=(-1, 1))

    clear_output(wait=True)
    print(f"Epoch {epoch} | iter {it}")
    grid = (samples[:4] + 1) / 2  # [-1, 1] -> [0, 1] for imshow
    fig, axes = plt.subplots(1, 4, figsize=(10, 3))
    for a, img in zip(axes, grid):
        a.imshow(img.permute(1, 2, 0).numpy())
        a.axis("off")
    plt.show()
    # Release the figure so previews do not pile up over a long run.
    plt.close(fig)


def _save_snapshot_and_onnx(generator, discriminator, g_opt, d_opt, opts, epoch, device):
    """Write a resumable .pth snapshot and an ONNX export of the generator for this epoch."""
    snap_path = os.path.join(opts.output_dir, f"snapshot_epoch_{epoch}.pth")
    torch.save({
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "g_optimizer": g_opt.state_dict(),
        "d_optimizer": d_opt.state_dict(),
    }, snap_path)
    print(f"Saved snapshot to {snap_path}")

    generator.eval()
    try:
        dummy_z = torch.randn(1, opts.z_dim, device=device)
        onnx_path = os.path.join(opts.output_dir, f"generator_epoch_{epoch}.onnx")
        torch.onnx.export(
            generator,
            (dummy_z,),
            onnx_path,
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=["z"],
            output_names=["image"],
            # Batch dimension stays dynamic so the exported model accepts any batch size.
            dynamic_axes={"z": {0: "batch"}, "image": {0: "batch"}},
        )
        print(f"ONNX model exported to {onnx_path}")
    finally:
        generator.train()


def train(generator, discriminator, dataloader, opts):
    """Train a DCGAN generator/discriminator pair with BCE-on-logits losses.

    Args:
        generator: module mapping a [B, z_dim] latent batch to images.
        discriminator: module mapping an image batch to one logit per image.
        dataloader: iterable that yields batches of real image tensors.
        opts: namespace providing ``epochs``, ``z_dim``, ``g_lr``, ``d_lr``,
            ``betas``, ``use_amp``, ``sample_interval`` (iterations between
            previews), ``snapshot_interval`` (epochs between snapshots) and
            ``output_dir`` (must already exist).

    Side effects:
        Moves both models to CUDA when available (CPU otherwise), updates their
        weights in place, and writes preview JPEGs, ``snapshot_epoch_<N>.pth``
        and ``generator_epoch_<N>.onnx`` files into ``opts.output_dir``.
        Training always starts from the models' current weights; no snapshot
        is loaded here.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.train().to(device)
    discriminator.train().to(device)

    g_opt = optim.Adam(generator.parameters(), lr=opts.g_lr, betas=opts.betas)
    d_opt = optim.Adam(discriminator.parameters(), lr=opts.d_lr, betas=opts.betas)

    # torch.cuda.amp only does anything on CUDA; switch it off explicitly on CPU
    # rather than asking for it and having it disabled with warnings.
    amp_enabled = bool(opts.use_amp) and device.type == "cuda"
    scaler_g = GradScaler(enabled=amp_enabled)
    scaler_d = GradScaler(enabled=amp_enabled)

    # Same latents for every preview, so progress is comparable across iterations.
    fixed_z = torch.randn(16, opts.z_dim, device=device)
    it = 0  # global iteration counter across epochs
    for epoch in range(1, opts.epochs + 1):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{opts.epochs}")
        for real in pbar:
            it += 1
            real = real.to(device, non_blocking=True)
            bsz = real.size(0)

            # ------------------ Train D ------------------
            z = torch.randn(bsz, opts.z_dim, device=device)
            with autocast(enabled=amp_enabled):
                # The generator is not updated in this step, so skip its graph.
                with torch.no_grad():
                    fake = generator(z)

                d_real = discriminator(real)
                d_fake = discriminator(fake)

                # One-sided label smoothing: only the real targets are softened.
                d_loss_real = bce(d_real, real_targets(bsz, device))
                d_loss_fake = bce(d_fake, fake_targets(bsz, device))
                d_loss = d_loss_real + d_loss_fake

            d_opt.zero_grad(set_to_none=True)
            scaler_d.scale(d_loss).backward()
            scaler_d.step(d_opt)
            scaler_d.update()

            # ------------------ Train G ------------------
            z = torch.randn(bsz, opts.z_dim, device=device)
            with autocast(enabled=amp_enabled):
                fake = generator(z)
                g_fake = discriminator(fake)
                # Non-saturating: tell D "these are real". The target is a hard
                # 1.0: label smoothing regularises the discriminator only and
                # must not soften the goal the generator is pushed towards.
                g_loss = bce(g_fake, torch.ones(bsz, device=device))

            g_opt.zero_grad(set_to_none=True)
            scaler_g.scale(g_loss).backward()
            scaler_g.step(g_opt)
            scaler_g.update()

            pbar.set_postfix({"D": f"{d_loss.item():.3f}", "G": f"{g_loss.item():.3f}"})

            # --------- Preview (always raw G) ----------
            if it % opts.sample_interval == 0:
                _save_preview(generator, fixed_z, opts, epoch, it)

        # --------- Snapshot + ONNX (raw G) ----------
        if (epoch % opts.snapshot_interval) == 0:
            _save_snapshot_and_onnx(generator, discriminator, g_opt, d_opt, opts, epoch, device)

In [None]:
# Build both networks on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(z_dim=z_dim, img_channels=img_channels).to(device)
discriminator = Discriminator(img_channels=img_channels).to(device)

# Bundle the settings from the configuration cell into the single object train() reads.
opts = SimpleNamespace(
    output_dir=output_dir,
    sample_interval=sample_interval,
    snapshot_interval=snapshot_interval,
    epochs=epochs,
    z_dim=z_dim,
    g_lr=g_lr, d_lr=d_lr, betas=betas,
    use_amp=use_amp,
)

# Fresh run recommended if you changed architectures
# (train() starts from the models' current weights; it does not load snapshots from output_dir)
train(generator, discriminator, dataloader, opts)