# DCGAN on CIFAR-10 — Bug Fixes

In [None]:
import torch, torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Run on CUDA when PyTorch can see a usable GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ---------------------------
# Data  (Fix: normalize images to [-1, 1] so real data matches the generator's Tanh output range)
# ---------------------------

In [None]:
# Turn CIFAR-10 images into float tensors in [-1, 1]: ToTensor gives [0, 1],
# and Normalize(mean=0.5, std=0.5) maps x -> (x - 0.5) / 0.5. That puts real
# images in the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(32),                                     # CIFAR-10 is already 32x32; keeps the size explicit
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),   # FIX: [0, 1] -> [-1, 1]
])
ds = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
# Pinned host memory only speeds up host->GPU copies. On a CPU-only machine it
# is wasted work (recent PyTorch versions also warn about it), so enable it
# only when training on CUDA.
loader = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)

# ---------------------------
# Hyperparameters  (Fix: use the standard DCGAN learning rate of 2e-4 for both G and D)
# ---------------------------

In [None]:
z_dim = 100          # length of the latent noise vector sampled for the generator
g_lr = d_lr = 2e-4   # FIX: same Adam learning rate for generator and discriminator

# ---------------------------
# Models
# ---------------------------

In [None]:
class D(nn.Module):
    """Discriminator for 3x32x32 images.

    Three stride-2 4x4 convolutions downsample 32 -> 16 -> 8 -> 4. A final
    4x4 convolution with no padding collapses the 4x4 map to one value per
    sample. The output is a raw logit (no Sigmoid), so it must be paired with
    a logits-based loss such as BCEWithLogitsLoss.

    Args:
        ch: channel width of the first conv layer; later layers use 2*ch and 4*ch.
    """
    def __init__(self, ch=64):
        super().__init__()
        # First block: conv + LeakyReLU, with no BatchNorm.       # FIX
        layers = [nn.Conv2d(3, ch, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        # Two conv + BatchNorm + LeakyReLU blocks, doubling the width each time.
        for mult_in, mult_out in ((1, 2), (2, 4)):
            layers += [
                nn.Conv2d(ch * mult_in, ch * mult_out, 4, 2, 1),
                nn.BatchNorm2d(ch * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: 4x4 -> 1x1 single-channel logit (no Sigmoid).     # FIX
        layers.append(nn.Conv2d(ch * 4, 1, 4, 1, 0))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor of shape (batch,) holding one logit per image."""
        logits = self.net(x)
        return logits.view(x.size(0))

class G(nn.Module):
    """Generator: latent noise vector -> 3x32x32 image in [-1, 1].

    Transposed convolutions upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
    A final Tanh bounds pixel values to [-1, 1].

    Args:
        z: length of the latent noise vector this generator consumes.
        ch: base channel width; the hidden layers use 4*ch, 2*ch and ch channels.
    """
    def __init__(self, z=100, ch=64):
        super().__init__()
        # FIX: remember this instance's latent size so forward() no longer
        # depends on the module-level `z_dim` global.
        self.z_dim = z
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z,   ch*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ch*4), nn.ReLU(True),
            nn.ConvTranspose2d(ch*4, ch*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch*2), nn.ReLU(True),
            nn.ConvTranspose2d(ch*2, ch,   4, 2, 1, bias=False),
            nn.BatchNorm2d(ch),   nn.ReLU(True),
            nn.ConvTranspose2d(ch,  3,     4, 2, 1, bias=False),
            nn.Tanh()                                           # FIX
        )

    def forward(self, z):
        """Map noise of shape (batch, z) or (batch, z, 1, 1) to images (batch, 3, 32, 32)."""
        # Treat the latent vector as a 1x1 spatial map with z channels.
        z = z.view(z.size(0), self.z_dim, 1, 1)                 # FIX: use own latent size
        return self.net(z)

# Build the discriminator first, then the generator, and move both onto the
# selected device. Module.to() moves parameters in place and returns the module.
Dnet = D()
Dnet = Dnet.to(device)
Gnet = G(z=z_dim)
Gnet = Gnet.to(device)

# ---------------------------
# Loss & Optimizers
# ---------------------------

In [None]:
# D returns raw logits, so use the logits form of binary cross-entropy.
crit = nn.BCEWithLogitsLoss()
# Adam with beta1 = 0.5 instead of PyTorch's default 0.9: less momentum, which
# also improves stability of the adversarial updates.
adam_betas = (0.5, 0.999)
opt_d = torch.optim.Adam(Dnet.parameters(), lr=d_lr, betas=adam_betas)
opt_g = torch.optim.Adam(Gnet.parameters(), lr=g_lr, betas=adam_betas)

# ---------------------------
# Training (a single sanity-check epoch over CIFAR-10)
# ---------------------------

In [None]:
# Sanity-check training run: at most one pass over `loader`, stopping early
# after `max_steps` discriminator/generator update pairs.
# NOTE: the CIFAR-10 training split has 50,000 images, so at batch_size=128
# one epoch is 391 batches. The 600-step cap only triggers if the loader grows.
max_steps = 600
log_every = 200

Gnet.train()
Dnet.train()
for step, (real, _) in enumerate(loader):
    # non_blocking lets the host->GPU copy overlap compute when the loader
    # pins memory; otherwise it behaves like an ordinary copy.
    real = real.to(device, non_blocking=True)
    b = real.size(0)   # the final batch can be smaller than batch_size

    # ---- Discriminator step: push D(real) toward 1 and D(G(z)) toward 0 ----
    z = torch.randn(b, z_dim, device=device)
    with torch.no_grad():                      # FIX: G is not updated here; skip its graph
        fake = Gnet(z)
    opt_d.zero_grad()                          # FIX
    logits_real = Dnet(real)
    logits_fake = Dnet(fake)
    loss_d = (crit(logits_real, torch.ones_like(logits_real))
              + crit(logits_fake, torch.zeros_like(logits_fake)))   # FIX
    loss_d.backward()
    opt_d.step()                               # FIX

    # ---- Generator step: non-saturating loss, push D(G(z)) toward 1 ----
    z = torch.randn(b, z_dim, device=device)
    opt_g.zero_grad()                          # FIX
    fake = Gnet(z)                             # fresh sample, now with gradients through G
    logits_gen = Dnet(fake)
    loss_g = crit(logits_gen, torch.ones_like(logits_gen))           # FIX
    # This backward also writes into Dnet's .grad; opt_d.zero_grad() clears
    # those before the next discriminator update.
    loss_g.backward()
    opt_g.step()                               # FIX

    if step % log_every == 0:
        print(f"step {step:05d}  loss_D={loss_d.item():.3f}  loss_G={loss_g.item():.3f}")
    # FIX: the old `step == 600` test ran 601 steps; stop after exactly max_steps.
    if step + 1 >= max_steps:
        break