# Conditional DCGAN on CIFAR-10 / Adam optimizer, BCEWithLogits loss


In [None]:
!pip install torch torchvision matplotlib tqdm --quiet

In [None]:
# Mount Google Drive at /content/drive (Colab-only). base_dir below points
# inside this mount, so the generated sample folders land on Drive.
from google.colab import drive
drive.mount('/content/drive')

In [3]:
# Root output folder on the mounted Drive. Per-class sample folders are created
# beneath it later with os.makedirs(..., exist_ok=True), which also creates this
# root, so no separate makedirs call is needed here.
base_dir = "/content/drive/MyDrive/Cogs185/CIFAR-DCGAN-1"

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.utils as vutils
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from torch.nn.utils import spectral_norm

In [5]:
def one_hot(labels, num_classes=3):
    """Encode integer class ids as float one-hot vectors.

    Args:
        labels: integer (long) tensor of class ids in ``[0, num_classes)``.
        num_classes: width of the encoding.

    Returns:
        A float32 tensor whose trailing dimension has size ``num_classes``.
    """
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.to(torch.float32)

In [6]:
class ConditionalGenerator(nn.Module):
    """DCGAN-style generator conditioned on a class label.

    The noise vector is concatenated with a one-hot encoding of the label and
    reshaped into a 1x1 feature map. Five transposed convolutions then upsample
    it 1 -> 4 -> 8 -> 16 -> 32 -> 64, and a final Tanh puts pixels in [-1, 1].

    Args:
        z_dim: length of each noise vector.
        num_classes: number of condition classes (width of the one-hot code).
        img_channels: channels of the generated image.
        feature_maps: base channel width; hidden layers use 8x, 4x, 2x and 1x it.
    """

    def __init__(self, z_dim=100, num_classes=3, img_channels=3, feature_maps=128):
        super().__init__()
        # Kept so forward() encodes labels with the same width that sizes the
        # first layer's input channels below.
        self.num_classes = num_classes
        self.input_dim = z_dim + num_classes

        self.gen = nn.Sequential(
            nn.ConvTranspose2d(self.input_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),

            nn.ConvTranspose2d(feature_maps, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: noise of shape (batch, z_dim).
            labels: integer (long) class ids of shape (batch,).

        Returns:
            Images of shape (batch, img_channels, 64, 64) with values in [-1, 1].
        """
        # Encode with the configured class count. A fixed width of 3 would make
        # the concatenation disagree with input_dim whenever num_classes != 3.
        one_hot_labels = F.one_hot(labels, self.num_classes).to(device=z.device, dtype=z.dtype)
        x = torch.cat([z, one_hot_labels], dim=1)
        # Treat the conditioned vector as a 1x1 spatial map for the first deconv.
        x = x.view(x.size(0), -1, 1, 1)
        return self.gen(x)

In [11]:
class ConditionalDiscriminator(nn.Module):
    """Conditional discriminator that returns one raw logit per image.

    A learned class embedding is broadcast over the spatial grid and stacked
    onto the image channels. Four stride-2 spectrally-normalised convolutions
    follow, and the final single-channel map is spatially averaged into a
    (batch, 1) output. No sigmoid is applied, so pair it with a logits loss.
    """

    def __init__(self, img_channels=3, num_classes=3, feature_maps=128):
        super().__init__()
        self.label_embed = nn.Embedding(num_classes, num_classes)

        # Channel progression of the three hidden conv layers.
        widths = [img_channels + num_classes, feature_maps, feature_maps * 2, feature_maps * 4]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1)))
            layers.append(nn.LeakyReLU(0.2))
        # Output head: one channel, averaged over whatever spatial size remains.
        layers.append(spectral_norm(nn.Conv2d(widths[-1], 1, 4, 2, 1)))
        layers.extend([nn.AdaptiveAvgPool2d(1), nn.Flatten()])
        self.disc = nn.Sequential(*layers)

    def forward(self, x, labels):
        """Score images ``x`` (batch, C, H, W) under class ids ``labels`` (batch,)."""
        _, _, height, width = x.shape
        embedded = self.label_embed(labels)[:, :, None, None]
        label_planes = embedded.expand(-1, -1, height, width)
        return self.disc(torch.cat((x, label_planes), dim=1))


In [12]:
# Preprocessing: upsample images to 64x64 (the generator's output size) and
# map pixels from [0, 1] to [-1, 1] to match the generator's Tanh range.
# A single-element mean/std broadcasts across all three RGB channels.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Keep three CIFAR-10 classes (0 = airplane, 1 = automobile, 9 = truck) and
# remap them to contiguous ids 0..2 for the 3-way conditioning.
wanted_labels = [0, 1, 9]
label_map = {0: 0, 1: 1, 9: 2}

raw_data = datasets.CIFAR10(root="./data", download=True, transform=transform)

# Filter on the dataset's plain label list. Iterating `raw_data` itself would
# load, resize and normalise every image in the split just to read its label.
filtered_indices = [i for i, label in enumerate(raw_data.targets) if label in wanted_labels]

class RelabeledCIFAR(Dataset):
    """View of ``base`` restricted to ``indices`` with remapped class labels.

    Each item is ``(image, label)``. ``image`` is whatever ``base`` yields, with
    its transform already applied. ``label`` is a 0-d tensor holding
    ``label_mapping[base.targets[i]]``.

    Args:
        base: dataset exposing ``targets`` (one original label per item) and
            returning ``(image, label)`` pairs.
        indices: positions in ``base`` to keep, in order.
        label_mapping: dict from original label to new label. Defaults to the
            module-level ``label_map``, so existing two-argument calls behave
            exactly as before.
    """

    def __init__(self, base, indices, label_mapping=None):
        if label_mapping is None:
            label_mapping = label_map
        self.data = Subset(base, indices)
        # Read labels from base.targets rather than base[i], so no image is
        # loaded or transformed at construction time.
        self.targets = [label_mapping[base.targets[i]] for i in indices]

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        img, _ = self.data[idx]
        return img, torch.tensor(self.targets[idx])

# Shuffled mini-batches of (image, remapped label) over the three kept classes.
train_loader = DataLoader(
    dataset=RelabeledCIFAR(raw_data, filtered_indices),
    batch_size=128,
    shuffle=True,
)

In [13]:
# Noise length shared by training and sampling.
z_dim = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tuple elements are evaluated left to right, so G is built before D,
# which keeps the random initialisation order unchanged.
G, D = ConditionalGenerator(z_dim=z_dim).to(device), ConditionalDiscriminator().to(device)

# Same Adam hyper-parameters for both networks (lr=2e-4, beta1=0.5).
opt_G, opt_D = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999)) for net in (G, D)
)

In [None]:
# Adversarial training with the standard alternating D / G updates.
# D's last layer is Flatten (no sigmoid), so its outputs are logits and
# BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
epochs = 10

G.train()
D.train()

for epoch in range(epochs):
    loop = tqdm(train_loader, desc=f"Adam+BCE Epoch {epoch+1}/{epochs}")
    for real_images, labels in loop:
        batch_size = real_images.size(0)  # last batch may be smaller than 128
        real_images = real_images.to(device)
        labels = labels.to(device)

        # ---- Discriminator step: real -> 1, fake -> 0 ----
        # Fakes are conditioned on the same labels as the real batch.
        # NOTE(review): this forward builds an autograd graph for G that is never
        # used (the output is detached below); wrapping it in torch.no_grad()
        # would save memory.
        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = G(z, labels)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        D_real = D(real_images, labels)
        # detach(): the D loss must not push gradients into G.
        D_fake = D(fake_images.detach(), labels)

        loss_real = criterion(D_real, real_labels)
        loss_fake = criterion(D_fake, fake_labels)
        loss_D = loss_real + loss_fake

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # ---- Generator step: fresh noise, non-saturating loss ----
        # Targets are "real" (1) so G is rewarded for fooling D.
        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = G(z, labels)
        D_pred = D(fake_images, labels)

        loss_G = criterion(D_pred, real_labels)

        # loss_G.backward() also writes gradients into D's parameters; they are
        # discarded by opt_D.zero_grad() before the next D backward pass.
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        # Progress bar shows the losses of the most recent batch.
        loop.set_postfix({
            "D_loss": loss_D.item(),
            "G_loss": loss_G.item()
        })

# Relative paths: checkpoints go to the current working directory.
# NOTE(review): base_dir (on Drive) is not used here; on Colab the working
# directory may not persist across sessions — confirm this is intended.
torch.save(G.state_dict(), "cDCGAN_generator.pt")
torch.save(D.state_dict(), "cDCGAN_discriminator.pt")


In [15]:
# Output folder name for each remapped class id (0..2).
label_names = {
    0: "CIFAR_Airplane_gan_01",
    1: "CIFAR_Automobile_gan_01",
    2: "CIFAR_Truck_gan_01"
}

# Write 30 generated PNGs per class to base_dir/<folder name>/.
# NOTE(review): G is still in train mode here, so BatchNorm normalises with each
# 30-image batch's statistics and updates its running stats. Call G.eval()
# first if running-statistics inference is intended.
to_pil = transforms.ToPILImage()
with torch.no_grad():
    for class_idx, folder_name in label_names.items():
        folder_path = os.path.join(base_dir, folder_name)
        os.makedirs(folder_path, exist_ok=True)

        labels = torch.full((30,), class_idx, device=device, dtype=torch.long)
        z = torch.randn(30, z_dim).to(device)
        imgs = G(z, labels)

        prefix = folder_name.lower()
        for i, img in enumerate(imgs):
            # Undo Normalize((0.5,), (0.5,)): map [-1, 1] back to [0, 1].
            img = (img * 0.5 + 0.5).clamp(0, 1)
            img_path = os.path.join(folder_path, f"{prefix}_{i:02d}.png")
            to_pil(img.cpu()).save(img_path)