In [None]:
# dcgan_mnist_minimal.py
import os, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Output folder for the per-epoch sample grids written during training.
os.makedirs("samples", exist_ok=True)

# ===== 1) Data: 28x28 MNIST digits, rescaled to [-1, 1] =====
# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps x -> (x - 0.5) / 0.5,
# i.e. into [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
ds = datasets.MNIST("./data", train=True, transform=transform, download=True)
# drop_last=True keeps every batch at exactly 128 samples.
dl = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

# ===== 2) Model hyper-parameters =====
latent_dim = 100   # length of the latent noise vector z ~ N(0, 1)
img_channels = 1   # MNIST is single-channel (grayscale)
img_size = 28      # MNIST image height/width in pixels

class Generator(nn.Module):
    """DCGAN generator: latent vector z -> image of shape (out_channels, 28, 28).

    A linear layer projects z to a ``base x 7 x 7`` feature map, and two
    stride-2 transposed convolutions upsample it 7 -> 14 -> 28. A final Tanh
    bounds pixel values to [-1, 1].

    Args:
        z_dim: length of the latent vector z.
        base: channel count of the 7x7 feature map (halved after the first
            upsampling step).
        out_channels: number of image channels. ``None`` (the default) uses
            the module-level ``img_channels`` setting, as before.
    """

    def __init__(self, z_dim=100, base=128, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = img_channels
        # Layer order/indices are kept stable so state_dict keys stay the same.
        self.net = nn.Sequential(
            # z -> base*7*7
            nn.Linear(z_dim, base * 7 * 7),
            nn.BatchNorm1d(base * 7 * 7),
            nn.ReLU(True),
            # (N, base*7*7) -> (N, base, 7, 7); built-in, no custom helper needed
            nn.Unflatten(1, (base, 7, 7)),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(base, base // 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base // 2),
            nn.ReLU(True),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(base // 2, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),  # [-1, 1]
        )

    def forward(self, z):
        """Map a (N, z_dim) batch of latent vectors to (N, out_channels, 28, 28) images."""
        return self.net(z)

class Discriminator(nn.Module):
    """DCGAN discriminator: 28x28 image -> one real/fake logit per image.

    Two stride-2 convolutions downsample 28 -> 14 -> 7, then a linear layer
    produces a single unnormalised score (a logit, to be paired with
    ``BCEWithLogitsLoss``).

    Args:
        base: channel count after the first convolution (doubled by the second).
        in_channels: number of image channels. ``None`` (the default) uses the
            module-level ``img_channels`` setting, as before.
    """

    def __init__(self, base=64, in_channels=None):
        super().__init__()
        if in_channels is None:
            in_channels = img_channels
        # Layer order/indices are kept stable so state_dict keys stay the same.
        self.net = nn.Sequential(
            # 28x28 -> 14x14
            nn.Conv2d(in_channels, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(base, base * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, base*2, 7, 7) -> (N, base*2*7*7); keeps the batch dim explicit
            nn.Flatten(),
            nn.Linear(base * 2 * 7 * 7, 1),  # logits
        )

    def forward(self, x):
        """Return a (N,) tensor of logits for a (N, in_channels, 28, 28) batch."""
        return self.net(x).squeeze(1)

class View(nn.Module):
    """Reshape the input to a fixed target shape, for use inside ``nn.Sequential``.

    Args:
        shape: target shape (a tuple/list of ints, at most one ``-1``), e.g.
            ``(-1, C, H, W)`` to leave the batch dimension free.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # reshape returns a view when the memory layout allows it (exactly what
        # .view did) and copies otherwise, so non-contiguous inputs also work.
        return x.reshape(*self.shape)

    def extra_repr(self):
        return f"shape={self.shape}"

# Build both networks (generator first, then discriminator) and move their
# parameters to the selected device; Module.to returns the module itself.
G = Generator(z_dim=latent_dim)
D = Discriminator()
G, D = G.to(DEVICE), D.to(DEVICE)

# DCGAN weight initialisation, applied to every submodule via Module.apply.
def weights_init(m):
    """Apply DCGAN-style initialisation to a single module ``m``.

    * BatchNorm1d / BatchNorm2d: scale ~ N(1, 0.02), shift set to 0.
    * Conv2d / ConvTranspose2d / Linear: weight ~ N(0, 0.02); any bias keeps
      PyTorch's default initialisation.
    * Every other module type is left unchanged.

    This is the N(0, 0.02) scheme from the DCGAN paper, not He or Xavier
    initialisation.
    """
    norm_layers = (nn.BatchNorm1d, nn.BatchNorm2d)
    weight_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    # The two type groups are disjoint, so at most one branch applies.
    if isinstance(m, norm_layers):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)
    elif isinstance(m, weight_layers):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
# Apply the DCGAN initialisation to every submodule of both networks.
G.apply(weights_init)
D.apply(weights_init)

# ===== 3) Loss & optimisers =====
# The discriminator outputs raw logits, so the sigmoid lives inside the loss.
criterion = nn.BCEWithLogitsLoss()
# Adam with the usual DCGAN settings (lr=2e-4, beta1=0.5), one per network.
opt_G, opt_D = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (G, D)
)

# Fixed batch of 64 latent vectors reused after every epoch, so the saved
# sample grids show how the generator's output for the same z evolves.
fixed_z = torch.randn((64, latent_dim), device=DEVICE)

# ===== 4) Training loop =====
EPOCHS = 5

# Training runs only when this file is executed directly (as a script or a
# notebook cell, where __name__ == "__main__"). The DataLoader uses
# num_workers=2; with the "spawn" start method (default on Windows and macOS)
# each worker re-imports this module, and without this guard every worker would
# try to start training itself, which multiprocessing rejects with a
# RuntimeError about the "bootstrapping phase".
if __name__ == "__main__":
    for epoch in range(1, EPOCHS + 1):
        for real, _ in dl:
            real = real.to(DEVICE)
            b = real.size(0)

            # --- Discriminator step: maximize log D(x) + log(1 - D(G(z))),
            # i.e. minimize BCE(D(x), 1) + BCE(D(G(z)), 0).
            z = torch.randn(b, latent_dim, device=DEVICE)
            fake = G(z).detach()  # block gradients into G during the D update
            real_logits = D(real)
            fake_logits = D(fake)

            d_loss_real = criterion(real_logits, torch.ones(b, device=DEVICE))
            d_loss_fake = criterion(fake_logits, torch.zeros(b, device=DEVICE))
            d_loss = d_loss_real + d_loss_fake

            opt_D.zero_grad()
            d_loss.backward()
            opt_D.step()

            # --- Generator step (non-saturating loss): maximize log D(G(z)),
            # i.e. minimize BCE(D(G(z)), 1). This backward also accumulates
            # gradients in D's parameters; opt_D.zero_grad() clears them before
            # the next discriminator update.
            z = torch.randn(b, latent_dim, device=DEVICE)
            gen = G(z)
            gen_logits = D(gen)
            g_loss = criterion(gen_logits, torch.ones(b, device=DEVICE))

            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

        # Save an 8x8 grid generated from the fixed noise after each epoch.
        # G is still in train mode here, so its BatchNorm layers normalise with
        # the statistics of these 64 samples.
        with torch.no_grad():
            samples = G(fixed_z).cpu()
            samples = (samples + 1) / 2  # map the Tanh range [-1, 1] to [0, 1] for saving
            utils.save_image(samples, f"samples/epoch_{epoch}.png", nrow=8)
        # The reported losses are those of the epoch's last batch.
        print(f"[Epoch {epoch}] D_loss={d_loss.item():.4f} | G_loss={g_loss.item():.4f}")

    print("Done. Check folder: samples/")
