# Conditional GAN (cGAN) — Bug-Fixes

In [None]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------
# Data  (normalize to [-1,1] to match Tanh)
# ---------------------------
# Resize to 32x32 because the generator emits 32x32 images and the
# discriminator's linear head is sized for that resolution.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),                 # pixels -> [0, 1]
    transforms.Normalize([0.5], [0.5]),    # (x - 0.5) / 0.5 -> [-1, 1]   FIX [SE-1]
])
ds = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
# Pinned host memory only speeds up host -> CUDA copies, so request it only
# when training on a GPU (on CPU-only machines it is useless and PyTorch warns).
loader = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == 'cuda',
    drop_last=True,  # every batch has exactly batch_size samples
)

# ---------------------------
# Hyperparams (sane defaults)
# ---------------------------
# Latent / conditioning sizes
z_dim = 128          # length of the noise vector fed to G
emb_dim = 50         # size of the learned class-label embedding (used by G and D)
num_classes = 10     # number of label values (MNIST digits)

# Adam settings shared by both networks (lr=2e-4, beta1=0.5)   FIX [E-8]
d_lr = 2e-4
g_lr = d_lr
betas = (0.5, 0.999)

# ---------------------------
# Models
# ---------------------------
class CGAN_G(nn.Module):
    """Conditional generator producing 32x32 images with values in [-1, 1].

    The noise vector is concatenated with a learned embedding of the class
    label and decoded by transposed convolutions:
    1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.

    Args:
        z: Length of the noise vector passed to ``forward``.
        emb_dim: Size of the learned class embedding.
        num_classes: Number of distinct labels accepted by ``forward``.
        ch: Base channel width (feature maps go ch*4 -> ch*2 -> ch).
        img_ch: Number of output image channels (1 = grayscale, 3 = RGB).
    """

    def __init__(self, z=128, emb_dim=50, num_classes=10, ch=64, img_ch=1):
        super().__init__()
        self.embed = nn.Embedding(num_classes, emb_dim)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z + emb_dim, ch*4, 4, 1, 0, bias=False),  # 1x1   -> 4x4
            nn.BatchNorm2d(ch*4), nn.ReLU(True),
            nn.ConvTranspose2d(ch*4, ch*2, 4, 2, 1, bias=False),         # 4x4   -> 8x8
            nn.BatchNorm2d(ch*2), nn.ReLU(True),
            nn.ConvTranspose2d(ch*2, ch, 4, 2, 1, bias=False),           # 8x8   -> 16x16
            nn.BatchNorm2d(ch), nn.ReLU(True),
            nn.ConvTranspose2d(ch, img_ch, 4, 2, 1, bias=False),         # 16x16 -> 32x32
            nn.Tanh(),  # match the [-1, 1] normalization of real images   FIX [SE-2]
        )

    def forward(self, z, y):
        """Generate images from noise ``z`` conditioned on labels ``y``.

        Args:
            z: Noise tensor of shape (B, z).
            y: Integer (long) class labels of shape (B,), each in [0, num_classes).

        Returns:
            Tensor of shape (B, img_ch, 32, 32) with values in [-1, 1].
        """
        e = self.embed(y)                      # (B, emb_dim)
        inp = torch.cat([z, e], dim=1)         # (B, z + emb_dim)   FIX [M-9]: concat on feature dim
        x = inp.view(inp.size(0), -1, 1, 1)    # features as a 1x1 spatial map
        return self.net(x)

class CGAN_D(nn.Module):
    """Conditional discriminator returning one raw logit per image.

    Image features from two stride-2 convolutions are flattened, concatenated
    with a learned embedding of the class label, and scored by a linear layer.
    No sigmoid is applied, so pair the output with ``nn.BCEWithLogitsLoss``.

    Args:
        emb_dim: Size of the learned class embedding.
        num_classes: Number of distinct labels accepted by ``forward``.
        ch: Base channel width (feature maps go ch -> ch*2).
        img_ch: Number of input image channels.
        img_size: Height == width of input images. The two stride-2 convs
            reduce it to ``img_size // 4``, which sizes the linear head.
    """

    def __init__(self, emb_dim=50, num_classes=10, ch=64, img_ch=1, img_size=32):
        super().__init__()
        self.embed = nn.Embedding(num_classes, emb_dim)
        self.img = nn.Sequential(
            nn.Conv2d(img_ch, ch, 4, 2, 1), nn.LeakyReLU(0.2, True),                      # S   -> S/2
            nn.Conv2d(ch, ch*2, 4, 2, 1), nn.BatchNorm2d(ch*2), nn.LeakyReLU(0.2, True),  # S/2 -> S/4
        )
        feat = img_size // 4
        self.fc = nn.Linear(ch*2*feat*feat + emb_dim, 1)   # FIX [SE-3]: no Sigmoid (use logits)

    def forward(self, x, y):
        """Score images ``x`` of shape (B, img_ch, img_size, img_size) under labels ``y``.

        Args:
            x: Image batch.
            y: Integer (long) class labels of shape (B,).

        Returns:
            Logits of shape (B,).
        """
        h = self.img(x).flatten(1)       # (B, ch*2 * feat * feat)
        e = self.embed(y)                # (B, emb_dim)
        return self.fc(torch.cat([h, e], dim=1)).view(x.size(0))

# Build both networks on the selected device. G is constructed first, so its
# random weight initialization is drawn before D's.
G = CGAN_G(z=z_dim, emb_dim=emb_dim, num_classes=num_classes).to(device)
D = CGAN_D(emb_dim=emb_dim, num_classes=num_classes).to(device)

# ---------------------------
# Loss & Optimizers
# ---------------------------
# D emits raw logits, so the sigmoid is folded into the (numerically stable) loss.
crit = nn.BCEWithLogitsLoss()
# One Adam per network: each optimizer only ever updates its own parameters.
optD, optG = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=betas)
    for net, lr in ((D, d_lr), (G, g_lr))
)

# ---------------------------
# Training (short sanity run)
# ---------------------------
def _endless(batches):
    """Yield from *batches* forever, starting a fresh (reshuffled) pass each time it is exhausted."""
    while True:
        yield from batches


max_steps = 800  # small sanity run: steps 0..max_steps inclusive

G.train(); D.train()
# A single pass over `loader` can end before max_steps: the MNIST training split
# at batch_size=128 with drop_last gives only ~468 batches. Cycle through epochs
# so the intended step budget is actually reached.
for step, (real, y) in enumerate(_endless(loader)):
    real = real.to(device, non_blocking=True)
    y    = y.to(device, non_blocking=True).long()                 # FIX [SE-4]: nn.Embedding needs int64

    b = real.size(0)
    z = torch.randn(b, z_dim, device=device)

    # ---- Discriminator step: real -> 1, fake -> 0 ----
    with torch.no_grad():                                         # FIX [E-7]: no graph/grads through G
        fake = G(z, y)
    optD.zero_grad()
    lossD = (crit(D(real, y), torch.ones(b, device=device))
             + crit(D(fake, y), torch.zeros(b, device=device)))  # FIX [E-5]: D sees the labels G used
    lossD.backward()
    optD.step()                                                   # FIX [H-10]

    # ---- Generator step (non-saturating): push D(G(z, y), y) towards 1 ----
    z = torch.randn(b, z_dim, device=device)
    optG.zero_grad()
    fake = G(z, y)                                                # same labels to D as used by G  [E-5]
    lossG = crit(D(fake, y), torch.ones(b, device=device))        # non-saturating G loss [E-6]
    # backward() also leaves gradients in D's parameters; they are discarded by
    # optD.zero_grad() before the next D update, and optG only steps G.
    lossG.backward()
    optG.step()                                                   # FIX [H-10]

    if step % 200 == 0:
        print(f"step {step:05d}  lossD={lossD.item():.3f}  lossG={lossG.item():.3f}")
    if step == max_steps:
        break


## What was fixed (mapping to bug tags)

1. **\[SE-1]** inputs normalized to `[-1, 1]` to match the generator's `Tanh` output range
2. **\[SE-2]** added `Tanh` at generator output
3. **\[SE-3]** removed `Sigmoid` from D (use logits + `BCEWithLogitsLoss`)
4. **\[SE-4]** labels cast to `long` for `nn.Embedding`
5. **\[E-5]** same class labels `y` passed to both `G` and `D` for fake samples
6. **\[E-6]** non-saturating G loss: G's fakes are scored against target ones, i.e. G maximizes $\log \sigma(D(G(z, y), y))$
7. **\[E-7]** no gradients to G during D step (`with torch.no_grad()` or `.detach()`)
8. **\[E-8]** sensible optimizers: `lr=2e-4`, `betas=(0.5,0.999)`
9. **\[M-9]** correct conditioning concat on `dim=1`, then `.view(B, z+emb, 1, 1)`
10. **\[H-10]** correct optimizer usage: `zero_grad()` → `backward()` → `step()` for each network, with each optimizer stepping only its own network's parameters