In [None]:
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader, Subset

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores (image, label) pairs.

    The label is embedded into a tensor with the same shape as one image and
    concatenated to the image along the channel axis. Three stride-2
    convolutions followed by a linear layer and a sigmoid then produce one
    value in [0, 1] per sample.

    Args:
        label_dim: Number of distinct labels (rows of the embedding table).
        img_channels: Number of channels of the input images.
        img_size: Height and width of the (square) input images. Must be a
            positive multiple of 8 because the convolution stack halves the
            resolution three times.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, label_dim, img_channels=3, img_size=64):
        super().__init__()
        if img_size < 8 or img_size % 8 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 8, got {img_size}"
            )
        # Keep the geometry on the instance so forward() does not depend on
        # notebook-level globals that may be defined in a later cell.
        self.img_channels = img_channels
        self.img_size = img_size

        self.label_emb = nn.Embedding(label_dim, img_channels * img_size * img_size)

        # Each stride-2 conv halves the resolution: img_size -> img_size / 8.
        feature_hw = img_size // 8
        self.model = nn.Sequential(
            # Input is the image stacked with the label map: 2 * img_channels.
            nn.Conv2d(2 * img_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(256 * feature_hw * feature_hw, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Return a (batch, 1) tensor of real/fake scores.

        Args:
            img: Image batch of shape (batch, img_channels, img_size, img_size).
            labels: Integer label indices of shape (batch,).
        """
        # Reshape the flat embedding into an image-shaped conditioning map.
        label_embedding = self.label_emb(labels).view(
            labels.size(0), self.img_channels, self.img_size, self.img_size
        )
        x = torch.cat([img, label_embedding], dim=1)
        return self.model(x)

In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, label) pairs to images.

    Each label is looked up in an embedding table whose rows have the same
    width as the noise vector. Noise and embedding are concatenated, projected
    to a 128 x 8 x 8 feature map, and upsampled by three stride-2 transposed
    convolutions (8 -> 16 -> 32 -> 64) before a tanh squashes the output into
    [-1, 1].

    Args:
        z_dim: Length of the noise vector (and of each label embedding).
        label_dim: Number of distinct labels.
        img_channels: Number of channels of the generated images.
    """

    def __init__(self, z_dim, label_dim, img_channels=3):
        super().__init__()
        self.label_emb = nn.Embedding(label_dim, z_dim)

        seed_channels, seed_hw = 128, 8
        flat_features = seed_channels * seed_hw * seed_hw

        # Dense projection of [noise | label embedding] to the seed feature map.
        stem = [
            nn.Linear(2 * z_dim, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (seed_channels, seed_hw, seed_hw)),
        ]
        # Two hidden upsampling stages, each doubling the spatial size.
        upsample = [
            nn.ConvTranspose2d(seed_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        ]
        # Final upsampling to image channels, bounded by tanh.
        head = [
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stem, *upsample, *head)

    def forward(self, z, labels):
        """Generate images for noise ``z`` conditioned on ``labels``."""
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        return self.model(conditioned)

In [None]:

# --- Configuration -----------------------------------------------------------
z_dim = 100          # length of the generator's noise vector
label_dim = 2        # number of conditioning labels
batch_size = 64
lr = 2e-4
epochs = 20
img_size = 64        # images are resized to img_size x img_size
num_images = 10000   # upper bound on how many dataset images are used
seed = 42

# Seed PyTorch's RNGs so Restart -> Run All draws the same noise and labels.
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1] to match the generator's tanh.
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])

dataset = datasets.ImageFolder(root='/content/celeba/img_align_celeba', transform=transform)
# Clamp to the dataset size so a smaller folder does not raise IndexError
# when the loader reaches an index that does not exist.
subset = Subset(dataset, list(range(min(num_images, len(dataset)))))
dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True)

# --- Models, loss, optimizers ------------------------------------------------
G = Generator(z_dim, label_dim).to(device)
D = Discriminator(label_dim).to(device)

criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

os.makedirs("generated", exist_ok=True)

# --- Training loop -----------------------------------------------------------
for epoch in range(1, epochs + 1):
    for imgs, _ in dataloader:
        # Size of *this* batch (the last one may be smaller). Kept in its own
        # name so the `batch_size` config value above is not overwritten.
        cur_batch = imgs.size(0)
        real_imgs = imgs.to(device)

        # NOTE(review): real images are paired with random labels, so the
        # label carries no information about the image and the conditioning
        # cannot be learned. Replace with real attribute labels if available.
        real_labels = torch.randint(0, label_dim, (cur_batch,), device=device)

        # Targets are created directly on the training device.
        ones = torch.ones(cur_batch, 1, device=device)
        zeros = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator step ----
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_labels = torch.randint(0, label_dim, (cur_batch,), device=device)
        fake_imgs = G(z, fake_labels)

        real_loss = criterion(D(real_imgs, real_labels), ones)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_loss = criterion(D(fake_imgs.detach(), fake_labels), zeros)
        d_loss = real_loss + fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        z = torch.randn(cur_batch, z_dim, device=device)
        gen_labels = torch.randint(0, label_dim, (cur_batch,), device=device)
        gen_imgs = G(z, gen_labels)
        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(D(gen_imgs, gen_labels), ones)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Losses reported here are those of the last batch of the epoch.
    print(f"Epoch [{epoch}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
    if epoch % 5 == 0:
        # Sample in eval mode so BatchNorm uses its running statistics.
        G.eval()
        with torch.no_grad():
            z = torch.randn(16, z_dim, device=device)
            sample_labels = torch.randint(0, label_dim, (16,), device=device)
            fake_imgs = G(z, sample_labels)
            fake_imgs = (fake_imgs + 1) / 2  # Denormalize
            utils.save_image(fake_imgs, f"generated/sample_epoch_{epoch}.png", nrow=4)
        G.train()


Epoch [1/20] | D Loss: 0.8003 | G Loss: 1.2384
Epoch [2/20] | D Loss: 0.6971 | G Loss: 3.0782
Epoch [3/20] | D Loss: 0.5445 | G Loss: 2.0417
Epoch [4/20] | D Loss: 0.6199 | G Loss: 1.6738
Epoch [5/20] | D Loss: 0.8204 | G Loss: 3.5382
Epoch [6/20] | D Loss: 1.3846 | G Loss: 0.5518
Epoch [7/20] | D Loss: 0.7650 | G Loss: 0.8919
Epoch [8/20] | D Loss: 0.9777 | G Loss: 1.2962
Epoch [9/20] | D Loss: 0.6779 | G Loss: 1.7242
Epoch [10/20] | D Loss: 0.5818 | G Loss: 1.7265
Epoch [11/20] | D Loss: 0.6770 | G Loss: 0.6224
Epoch [12/20] | D Loss: 1.0123 | G Loss: 0.6182
Epoch [13/20] | D Loss: 0.7717 | G Loss: 1.2984
Epoch [14/20] | D Loss: 0.5891 | G Loss: 0.9734
Epoch [15/20] | D Loss: 0.5045 | G Loss: 1.9183
Epoch [16/20] | D Loss: 0.5665 | G Loss: 3.9800
Epoch [17/20] | D Loss: 0.6862 | G Loss: 2.0207
Epoch [18/20] | D Loss: 0.3417 | G Loss: 3.0633
Epoch [19/20] | D Loss: 0.6892 | G Loss: 1.9215
Epoch [20/20] | D Loss: 0.4587 | G Loss: 2.2163
