In [3]:
# Install dependencies if not installed
# !pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Create a folder for generated samples
os.makedirs("generated_samples", exist_ok=True)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Better CIFAR-10 Generator (updated)
class Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        self.init_size = 8  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(noise_dim + num_classes, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        labels = self.label_emb(labels)
        gen_input = torch.cat((noise, labels), -1)
        out = self.l1(gen_input)
        out = out.view(out.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [5]:
# Discriminator remains the same
class Discriminator(nn.Module):
    def __init__(self, num_classes=10):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        self.model = nn.Sequential(
            nn.Linear(3*32*32 + num_classes, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img, labels):
        img_flat = img.view(img.size(0), -1)
        labels = self.label_emb(labels)
        input = torch.cat((img_flat, labels), -1)
        return self.model(input)

In [6]:
# Hyperparameters
batch_size = 64
lr = 0.0002
epochs = 50
noise_dim = 100
num_classes = 10

# CIFAR-10 Classes
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Models (assuming you have Generator and Discriminator classes ready)
generator = Generator(noise_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)

# Optimizers
optim_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optim_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss
adversarial_loss = nn.BCELoss()

# Training Loop
for epoch in range(epochs):
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        
        batch_size_curr = imgs.size(0)

        # --------------------
        # Train Discriminator
        # --------------------
        optim_D.zero_grad()

        # Real images
        real_validity = discriminator(imgs, labels)
        real_labels = torch.empty_like(real_validity).uniform_(0.8, 1.0)  # Label smoothing
        real_loss = adversarial_loss(real_validity, real_labels)

        # Fake images
        noise = torch.randn(batch_size_curr, noise_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size_curr,), device=device)
        fake_imgs = generator(noise, gen_labels)

        fake_validity = discriminator(fake_imgs.detach(), gen_labels)
        fake_labels = torch.empty_like(fake_validity).uniform_(0.0, 0.2)  # Label smoothing
        fake_loss = adversarial_loss(fake_validity, fake_labels)

        # Total Discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optim_D.step()

        # --------------------
        # Train Generator
        # --------------------
        optim_G.zero_grad()

        gen_validity = discriminator(fake_imgs, gen_labels)
        valid_labels = torch.ones_like(gen_validity)  # Generator wants D to think it's real
        g_loss = adversarial_loss(gen_validity, valid_labels)

        g_loss.backward()
        optim_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")


    # --------------------
    # Save generated images after every epoch
    # --------------------
    with torch.no_grad():
        generated_imgs = generator(fixed_noise, fixed_labels)
        generated_imgs = (generated_imgs + 1) / 2.0  # Denormalize to [0,1]
        save_image(generated_imgs, f"generated_samples/epoch_{epoch+1:03d}.png", nrow=5)


Files already downloaded and verified
Epoch [1/50] | D Loss: 0.6593 | G Loss: 1.1157


NameError: name 'fixed_noise' is not defined

In [6]:
torch.save(generator.state_dict(), "cifar_cgan_generator.pth")
print("Saved Conditional CIFAR-10 Generator!")

Saved Conditional CIFAR-10 Generator!
