In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# =====================================================
# CONFIG
# =====================================================
# --- data / model geometry ---
IMG_SIZE = 32        # side length (pixels) every training image is resized to
CHANNELS = 3         # colour channels produced by G and consumed by D
LATENT_DIM = 100     # number of channels in the 1x1 noise tensor fed to G

# --- training schedule ---
BATCH = 128          # DataLoader batch size
EPOCHS = 100         # full passes over the dataset

# --- filesystem locations ---
DATA_DIR = r"D:\pk\dataset\real"   # folder scanned for training images
GEN_PATH = "generator.pth"         # where the generator state_dict is written
DISC_PATH = "discriminator.pth"    # where the discriminator state_dict is written

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# =====================================================
# DATASET
# =====================================================
class ImageDataset(Dataset):
    """Unlabelled image dataset backed by the JPEG/PNG files of one folder.

    Each item is an RGB image resized to ``IMG_SIZE x IMG_SIZE``, converted
    to a tensor and normalised with mean 0.5 / std 0.5 per channel, which
    puts pixel values in the same [-1, 1] range as the generator's Tanh
    output.
    """

    # Dotted suffixes, so a name that merely ends in "jpg" or "png"
    # (for example "fakejpg") is not mistaken for an image.
    EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, folder):
        """Index the image files directly inside *folder* (non-recursive).

        Raises whatever ``os.listdir`` raises when *folder* is missing or
        unreadable.
        """
        self.files = self._list_images(folder)
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    @staticmethod
    def _list_images(folder):
        """Return the sorted paths of regular image files in *folder*.

        Matching is case-insensitive on the extension. Directories are
        skipped even if their name looks like an image, and the result is
        sorted so that index -> file is reproducible across runs
        (``os.listdir`` order is arbitrary).
        """
        candidates = (
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(ImageDataset.EXTENSIONS)
        )
        return sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        """Number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image *idx* from disk and return it as a normalised tensor."""
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the with-block exits.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

# Index the training folder once, then serve it in reshuffled mini-batches.
dataset = ImageDataset(folder=DATA_DIR)
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH,
    shuffle=True,
)

# =====================================================
# GENERATOR
# =====================================================
class Generator(nn.Module):
    """Transposed-convolution generator: noise tensor in, image out.

    A ``(N, LATENT_DIM, 1, 1)`` input is expanded to 4x4 by the first
    layer and then doubled three times (8, 16, 32), ending in a Tanh so
    the ``(N, CHANNELS, 32, 32)`` output lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Stem: project the 1x1 latent tensor to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(LATENT_DIM, 256, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        ]

        # Two identical upsampling stages; each doubles the spatial size.
        for c_in, c_out in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Head: final doubling straight to image channels, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(64, CHANNELS, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]

        # Layer order (and therefore the state_dict keys net.0 ... net.10)
        # is fixed by the list built above.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent tensor *z* to a batch of generated images."""
        return self.net(z)

# =====================================================
# DISCRIMINATOR
# =====================================================
class Discriminator(nn.Module):
    """Strided-convolution classifier scoring images as real (1) or fake (0).

    For 32x32 inputs the three stride-2 convolutions reduce the feature
    map to 16, 8 and 4 pixels; the final 4x4 convolution collapses it to
    a single value per image, which a Sigmoid turns into a probability.
    """

    def __init__(self):
        super().__init__()

        # First stage has no batch norm.
        layers = [
            nn.Conv2d(CHANNELS, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Two further downsampling stages, each halving the spatial size.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Head: one logit per image, mapped to (0, 1) for use with BCELoss.
        layers += [
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a flat 1-D tensor of scores for the image batch *x*."""
        scores = self.net(x)
        return scores.view(-1)

# =====================================================
# INIT MODELS
# =====================================================
# Build both networks (generator first, then discriminator) and move their
# parameters to the selected device; Module.to() works in place.
G = Generator()
G.to(DEVICE)
D = Discriminator()
D.to(DEVICE)

# =====================================================
# MODEL SUMMARY
# =====================================================
# For each network: print the layer listing, then the total number of
# parameter elements with thousands separators.
g_param_count = sum(p.numel() for p in G.parameters())
print("\nGENERATOR SUMMARY\n")
print(G)
print(f"Generator Params: {g_param_count:,}")

d_param_count = sum(p.numel() for p in D.parameters())
print("\nDISCRIMINATOR SUMMARY\n")
print(D)
print(f"Discriminator Params: {d_param_count:,}\n")

# =====================================================
# OPTIMIZER & LOSS
# =====================================================
# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Both networks use Adam with identical hyper-parameters.
adam_cfg = {"lr": 2e-4, "betas": (0.5, 0.999)}
opt_G = optim.Adam(G.parameters(), **adam_cfg)
opt_D = optim.Adam(D.parameters(), **adam_cfg)

# =====================================================
# TRAINING LOOP
# =====================================================
# Alternating GAN training: for every batch, update D once on real + fake
# images, then update G once so that the same fakes are scored as real.
for epoch in range(EPOCHS):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for real in pbar:
        real = real.to(DEVICE)
        b = real.shape[0]  # the last batch of an epoch may be smaller than BATCH

        # BCE targets: 1.0 for "real", 0.0 for "fake", one entry per image.
        real_label = torch.full((b,), 1.0, device=DEVICE)
        fake_label = torch.zeros_like(real_label)

        # One noise tensor per real image; the resulting fakes are shared
        # by the discriminator step and the generator step below.
        z = torch.randn(b, LATENT_DIM, 1, 1, device=DEVICE)
        fake = G(z)

        # ---------------------
        # Train Discriminator
        # ---------------------
        # Clearing first also discards the gradients that the previous
        # generator step left on D's parameters.
        opt_D.zero_grad()
        D_real = D(real)
        # detach(): the discriminator loss must not back-propagate into G.
        D_fake = D(fake.detach())
        loss_D = criterion(D_real, real_label) + criterion(D_fake, fake_label)
        loss_D.backward()
        opt_D.step()

        # ---------------------
        # Train Generator
        # ---------------------
        opt_G.zero_grad()
        # Re-score the same fakes with the just-updated D, this time keeping
        # the graph through G; the target is "real" because G wants to fool D.
        D_fake = D(fake)
        loss_G = criterion(D_fake, real_label)
        loss_G.backward()
        opt_G.step()

        pbar.set_postfix(D_loss=loss_D.item(), G_loss=loss_G.item())

print("✔ Training Complete")

# =====================================================
# SAVE MODELS
# =====================================================
# Persist only the weights (state_dict) of each network, generator first.
generator_weights = G.state_dict()
torch.save(generator_weights, GEN_PATH)
discriminator_weights = D.state_dict()
torch.save(discriminator_weights, DISC_PATH)
print("✔ Models Saved")

# =====================================================
# LOAD MODELS (INFERENCE)
# =====================================================
# Read the generator weights back onto the active device and switch the
# network to evaluation mode before sampling from it.
restored_weights = torch.load(GEN_PATH, map_location=DEVICE)
G.load_state_dict(restored_weights)
G.eval()

# =====================================================
# INFERENCE (IMAGE GENERATION)
# =====================================================
def generate_images(n=16):
    """Sample *n* images from the global generator ``G``.

    Draws ``n`` standard-normal latent tensors of shape
    ``(LATENT_DIM, 1, 1)`` on ``DEVICE``, runs them through ``G`` with
    gradient tracking disabled, and rescales the result from the Tanh
    range [-1, 1] to [0, 1].

    Returns:
        A CPU tensor holding the ``n`` rescaled images.
    """
    with torch.no_grad():
        latent = torch.randn(n, LATENT_DIM, 1, 1, device=DEVICE)
        generated = G(latent)
        # [-1, 1] -> [0, 1]: shift by one, then halve.
        rescaled = generated.add(1).div(2)
        return rescaled.cpu()

# =====================================================
# SHOW GENERATED IMAGES
# =====================================================
# Sample 16 images and display them as a 4x4 grid without axes.
samples = generate_images(16)

plt.figure(figsize=(5, 5))
for cell, img in enumerate(samples, start=1):
    plt.subplot(4, 4, cell)
    # imshow wants channels last: (C, H, W) -> (H, W, C).
    plt.imshow(img.permute(1, 2, 0))
    plt.axis("off")
plt.show()



GENERATOR SUMMARY

Generator(
  (net): Sequential(
    (0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Tanh()
  )
)
Generator Params: 1,069,379

DISCRIMINATOR SUMMARY

Discriminator(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_s

Epoch 1/100:  15%|█▌        | 12/79 [00:22<02:07,  1.90s/it, D_loss=0.214, G_loss=3.53]