In [None]:
# ✅ Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

# ✅ Device
# Prefer the GPU when PyTorch can see one; models and tensors below are moved to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ✅ Set Paths
# Sample grids and checkpoints are written under save_dir.
# NOTE(review): assumes Google Drive is already mounted at /content/drive (e.g. in Colab);
# this cell does not mount it, so without a mount makedirs just creates a plain local folder.
save_dir = "/content/drive/MyDrive/cgan_cifar10"  # Your Google Drive folder
for artifact_dir in (save_dir, f"{save_dir}/generated_samples", f"{save_dir}/checkpoints"):
    os.makedirs(artifact_dir, exist_ok=True)

# ✅ Hyperparameters
num_epochs = 500
batch_size = 128
learning_rate = 2e-4
noise_dim = 100       # length of the generator's input noise vector
num_classes = 10      # CIFAR-10 classes; also the label-embedding width in G and D
img_size = 32         # CIFAR-10 images are 32x32
channels = 3          # RGB

# ✅ Reproducibility
# Seed before the dataset/loader and models are built, so weight init, shuffling,
# augmentation draws and the fixed sample noise are repeatable across runs.
seed = 42
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
np.random.seed(seed)

# ✅ Transformations
# Only border-free augmentation is used here. A GAN learns to reproduce whatever the
# "real" images look like, so RandomCrop(..., padding=4) (zero padding by default)
# would put black bands into the real data and the generator would learn to draw them.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    # Per-channel (x - 0.5) / 0.5 maps [0, 1] -> [-1, 1], the Generator's Tanh output range.
    transforms.Normalize((0.5,) * channels, (0.5,) * channels),
])

# ✅ Dataset and Loader
# CIFAR-10 training split, downloaded into ./data if missing; `transform` is applied per image.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Reshuffled every epoch; the last batch may be smaller than batch_size (drop_last is left at False).
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# ✅ Generator
class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> image with values in [-1, 1].

    The class label is embedded, concatenated with the noise vector, and passed
    through a 4-layer MLP whose Tanh output is reshaped into an image.

    Args:
        noise_dim: length of each noise vector.
        num_classes: number of class labels (also the label-embedding width).
        channels: channels of the generated image (default 3, as in the config cell).
        img_size: height and width of the generated image (default 32, as in the config cell).
    """

    def __init__(self, noise_dim, num_classes, channels=3, img_size=32):
        super().__init__()
        # Kept on the instance so forward() does not depend on notebook globals
        # that might be redefined after the model is built.
        self.channels = channels
        self.img_size = img_size
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, channels * img_size * img_size),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: float tensor of shape (batch, noise_dim).
            labels: integer tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, channels, img_size, img_size) in [-1, 1].
        """
        label_vec = self.label_emb(labels)
        x = self.model(torch.cat([noise, label_vec], dim=1))
        return x.view(x.size(0), self.channels, self.img_size, self.img_size)

# ✅ Discriminator
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (image, class label) -> probability the pair is real.

    The image is flattened, concatenated with an embedding of its label, and passed
    through a 3-layer MLP ending in Sigmoid (so the output is a probability, suitable
    for BCELoss).

    Args:
        num_classes: number of class labels (also the label-embedding width).
        channels: channels of the input image (default 3, as in the config cell).
        img_size: height and width of the input image (default 32, as in the config cell).
    """

    def __init__(self, num_classes, channels=3, img_size=32):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(channels * img_size * img_size + num_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score a batch of (image, label) pairs.

        Args:
            img: float tensor of shape (batch, channels, img_size, img_size).
            labels: integer tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, 1) with probabilities in (0, 1).
        """
        label_vec = self.label_emb(labels)
        # flatten(1), unlike .view, also accepts non-contiguous inputs (e.g. transposed batches).
        img_flat = img.flatten(1)
        return self.model(torch.cat([img_flat, label_vec], dim=1))

# ✅ Initialize models
generator = Generator(noise_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)

# ✅ Loss and optimizers
# The Discriminator ends in Sigmoid, so BCELoss receives probabilities, not logits.
adversarial_loss = nn.BCELoss()

# Same Adam settings for both networks; beta1 = 0.5 is the usual choice for GAN training.
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# ✅ Learning rate schedulers
# Each learning rate is multiplied by 0.1 every 100 scheduler.step() calls.
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=100, gamma=0.1)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=100, gamma=0.1)

# ✅ Fixed noise for samples
# The same inputs every epoch make the saved grids comparable over training.
# Labels cycle through 0..num_classes-1 (instead of random draws) so every class
# appears in the grid and the label layout is identical on every run.
num_fixed_samples = 25  # saved as a 5 x 5 grid
fixed_noise = torch.randn(num_fixed_samples, noise_dim, device=device)
fixed_labels = torch.arange(num_fixed_samples, device=device) % num_classes

# ✅ Training Loop
# Each batch: one generator update, then one discriminator update on the same fake batch.
for epoch in range(1, num_epochs + 1):
    generator.train()
    discriminator.train()

    # Per-epoch loss accumulators. They are kept as tensors so that no per-batch .item()
    # GPU sync is needed; they are converted to floats once at the end of the epoch.
    g_loss_sum = torch.zeros((), device=device)
    d_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for imgs, labels in train_loader:
        batch_size_i = imgs.size(0)  # the last batch can be smaller than batch_size
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # ✅ One-sided label smoothing: only the discriminator's "real" targets are softened.
        # The generator still aims for D(G(z)) = 1. A 0.9 target there would penalize G
        # whenever it fools D "too well".
        valid = torch.full((batch_size_i, 1), 0.9, device=device)  # D target for real images
        fake = torch.zeros(batch_size_i, 1, device=device)         # D target for generated images
        g_target = torch.ones(batch_size_i, 1, device=device)      # G target: be judged real

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        noise = torch.randn(batch_size_i, noise_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size_i,), device=device)
        gen_imgs = generator(noise, gen_labels)

        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), g_target)

        # This also writes gradients into D's parameters; optimizer_D.zero_grad() below clears them.
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        # detach() keeps the discriminator loss from backpropagating into the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        g_loss_sum += g_loss.detach()
        d_loss_sum += d_loss.detach()
        num_batches += 1

    # Step the learning rate schedulers once per epoch
    scheduler_G.step()
    scheduler_D.step()

    # ✅ Save generated images (fixed noise/labels -> grids comparable across epochs)
    generator.eval()
    with torch.no_grad():
        samples = generator(fixed_noise, fixed_labels)
        samples = (samples + 1) / 2.0  # De-normalize [-1, 1] -> [0, 1]
        grid = make_grid(samples, nrow=5)
        save_image(grid, f"{save_dir}/generated_samples/epoch_{epoch}.png")

    # ✅ Save model checkpoints
    if epoch % 50 == 0 or epoch == num_epochs:
        torch.save(generator.state_dict(), f"{save_dir}/checkpoints/generator_epoch_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"{save_dir}/checkpoints/discriminator_epoch_{epoch}.pth")

    # Report the mean over all batches of the epoch (the last batch alone is a noisy summary);
    # max(..., 1) guards against an empty loader.
    d_loss_epoch = d_loss_sum.item() / max(num_batches, 1)
    g_loss_epoch = g_loss_sum.item() / max(num_batches, 1)
    print(f"Epoch [{epoch}/{num_epochs}]  Loss D: {d_loss_epoch:.4f}, Loss G: {g_loss_epoch:.4f} (epoch mean)")

print("Training completed ✅")


Using device: cuda
Epoch [1/500]  Loss D: 0.2601, Loss G: 5.4431
Epoch [2/500]  Loss D: 0.2813, Loss G: 4.2664
Epoch [3/500]  Loss D: 0.2888, Loss G: 2.2455
Epoch [4/500]  Loss D: 0.2204, Loss G: 3.1708
Epoch [5/500]  Loss D: 0.2450, Loss G: 2.2851
Epoch [6/500]  Loss D: 0.2261, Loss G: 2.8419
Epoch [7/500]  Loss D: 0.2751, Loss G: 2.2979
Epoch [8/500]  Loss D: 0.3154, Loss G: 2.2672
Epoch [9/500]  Loss D: 0.2365, Loss G: 2.4570
Epoch [10/500]  Loss D: 0.2251, Loss G: 2.4040
Epoch [11/500]  Loss D: 0.3018, Loss G: 2.8377
Epoch [12/500]  Loss D: 0.2424, Loss G: 2.9397
Epoch [13/500]  Loss D: 0.2505, Loss G: 4.3046
Epoch [14/500]  Loss D: 0.2183, Loss G: 2.5180
Epoch [15/500]  Loss D: 0.2391, Loss G: 3.3543
Epoch [16/500]  Loss D: 0.2283, Loss G: 2.7990
Epoch [17/500]  Loss D: 0.2140, Loss G: 3.7090
Epoch [18/500]  Loss D: 0.2140, Loss G: 3.4815
Epoch [19/500]  Loss D: 0.2561, Loss G: 2.3322
Epoch [20/500]  Loss D: 0.2396, Loss G: 2.7938
Epoch [21/500]  Loss D: 0.2925, Loss G: 2.0746
Epo

In [None]:
# Save the final generator weights under save_dir, next to the sample grids and checkpoints.
# A bare relative path would write to the kernel's working directory, outside that folder.
generator_path = os.path.join(save_dir, "cifar_cgan_generator.pth")
torch.save(generator.state_dict(), generator_path)
print(f"Saved Conditional CIFAR-10 Generator to {generator_path}!")