In [None]:
import torch
from torch import nn
import numpy as np
from torchvision.utils import save_image

# Seed PyTorch's random number generators so that weight initialisation, batch
# shuffling and noise sampling start from the same state on every
# "Restart Kernel -> Run All" (as far as the backend's kernels are deterministic).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

# 1. Dataset

In [None]:
import torchvision

# Side length (in pixels) of the square images used throughout the notebook,
# and the number of label classes available for conditioning.
img_size, num_classes = 32, 10

In [None]:
# Preprocessing applied to every image, in this order:
#   1. resize to (img_size, img_size),
#   2. convert to a tensor,
#   3. normalise with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose(
    transforms=[
        torchvision.transforms.Resize(size=(img_size, img_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split, downloaded into ./mnist_data when it is not there yet.
mnist_images = torchvision.datasets.MNIST(
    root='mnist_data',
    train=True,
    transform=transform,
    download=True,
)

In [None]:
from torch.utils.data import DataLoader

# Number of images per mini-batch.
BATCH_SIZE = 64

# Serves the MNIST training images in shuffled mini-batches.
dataloader = DataLoader(
    dataset=mnist_images,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# 2. Model

In [None]:
# Length of the noise vector fed to the generator.
latent_dim = 100

# Number of image channels, and the resulting (C, H, W) shape of one image.
channels = 1
img_shape = (channels, img_size, img_size)

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN-style generator.

    The noise vector is concatenated with a learned embedding of the class
    label, projected by a linear layer to a ``128 x init_size x init_size``
    feature map, and upsampled twice (2x each) to the final image size.
    The closing ``Tanh`` keeps every output value in ``[-1, 1]``.

    Args:
        num_classes: Number of distinct labels in the embedding table.
        emb_dim: Size of the label embedding vector.
        noise_dim: Length of the noise vector ``z``. Defaults to the
            notebook-level ``latent_dim``.
        image_size: Side length of the square output image. Must be a positive
            multiple of 4 because of the two 2x upsampling stages. Defaults to
            the notebook-level ``img_size``.
        image_channels: Number of output channels. Defaults to the
            notebook-level ``channels``.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, num_classes, emb_dim, noise_dim=None, image_size=None,
                 image_channels=None):
        super().__init__()

        # Fall back to the notebook globals so the original two-argument
        # construction keeps working unchanged.
        if noise_dim is None:
            noise_dim = latent_dim
        if image_size is None:
            image_size = img_size
        if image_channels is None:
            image_channels = channels

        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )

        # Spatial size before the two 2x upsampling stages (8 for 32x32 images).
        self.init_size = image_size // 4

        self.label_emb = nn.Embedding(num_classes, emb_dim)
        self.fc = nn.Linear(noise_dim + emb_dim,
                            128 * self.init_size * self.init_size)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, padding=1),

            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1),

            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, image_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z, label):
        """Generate one image per (noise vector, label) pair.

        Args:
            z: Noise tensor of shape ``(batch, noise_dim)``.
            label: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, image_channels, image_size, image_size)``.
        """
        cond = self.label_emb(label)
        # Condition the generator by appending the label embedding to the noise.
        x = torch.cat([z, cond], 1)
        x = self.fc(x)
        # Reshape the flat projection into a 128-channel feature map.
        x = x.view(x.size(0), 128, self.init_size, self.init_size)
        img = self.conv_blocks(x)
        return img

In [None]:
class Descriminator(nn.Module):
    """Conditional DCGAN-style discriminator.

    (The class name keeps its original spelling so existing references work.)

    The flattened image is concatenated with a learned embedding of the class
    label and passed through a small MLP ("injection") that produces a
    label-aware image of the original shape. Four stride-2 convolutions then
    reduce it to a feature map, which a linear layer followed by a sigmoid
    turns into a single probability-like score per sample.

    Args:
        num_classes: Number of distinct labels in the embedding table.
        emb_dim: Size of the label embedding vector.
        image_size: Side length of the square input image. Defaults to the
            notebook-level ``img_size``.
        image_channels: Number of input channels. Defaults to the
            notebook-level ``channels``.

    Raises:
        ValueError: If ``image_size`` or ``image_channels`` is not positive.
    """

    def __init__(self, num_classes, emb_dim, image_size=None,
                 image_channels=None):
        super().__init__()

        # Fall back to the notebook globals so the original two-argument
        # construction keeps working unchanged.
        if image_size is None:
            image_size = img_size
        if image_channels is None:
            image_channels = channels
        if image_size <= 0 or image_channels <= 0:
            raise ValueError(
                "image_size and image_channels must be positive, got "
                f"{image_size} and {image_channels}"
            )

        # Shape (C, H, W) of one image; used to un-flatten the injection output.
        self.img_shape = (image_channels, image_size, image_size)
        flat_dim = image_channels * image_size * image_size

        self.label_emb = nn.Embedding(num_classes, emb_dim)

        self.injection = nn.Sequential(
            nn.Linear(flat_dim + emb_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, flat_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.Conv2d(image_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Each conv above (kernel 3, stride 2, padding 1) maps a side length n
        # to ceil(n / 2); apply that four times to get the final feature size
        # (2 for 32x32 inputs).
        feat_size = image_size
        for _ in range(4):
            feat_size = (feat_size + 1) // 2

        self.adv_layer = nn.Sequential(
            nn.Linear(128 * feat_size * feat_size, 1),
            nn.Sigmoid()
        )

    def forward(self, img, label):
        """Score each (image, label) pair.

        Args:
            img: Image tensor of shape ``(batch, image_channels, image_size,
                image_size)``.
            label: Integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        batch = img.size(0)

        flat_img = img.reshape(batch, -1)
        cond = self.label_emb(label)

        # Mix the label embedding into the image, then restore the image shape.
        x = torch.cat([flat_img, cond], 1)
        x = self.injection(x)
        x = x.view(batch, *self.img_shape)

        features = self.model(x)
        features = features.view(batch, -1)

        validity = self.adv_layer(features)

        return validity

In [None]:
# Size of the label-embedding vector, shared by both networks.
EMB_DIM = 32

# The two players of the conditional GAN.
generator = Generator(num_classes=num_classes, emb_dim=EMB_DIM)
discriminator = Descriminator(num_classes=num_classes, emb_dim=EMB_DIM)

In [None]:
# Move the generator to the selected device. The call is the cell's last
# expression, so its result is shown as the cell output.
generator.to(device)


In [None]:
# Move the discriminator to the selected device. The call is the cell's last
# expression, so its result is shown as the cell output.
discriminator.to(device)


# 3. Training

In [None]:
import os
# A grid of generated samples is written whenever `epoch % save_interval == 0`.
save_interval = 10

# Output folder for those sample grids; no error if it already exists.
os.makedirs(name="images_cDCGAN", exist_ok=True)

In [None]:
# Number of full passes over the training set.
EPOCHS = 200

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network; the generator uses half the discriminator's
# learning rate.
optimizer_G = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
optimizer_D = torch.optim.Adam(params=discriminator.parameters(), lr=2e-4)

# Per-epoch average losses, appended to by the training loop.
hist = {key: [] for key in ("train_G_loss", "train_D_loss")}

In [None]:
# Make sure both networks are in training mode. The inference cell calls
# `generator.eval()`; without this, re-running the training cell afterwards
# would train with the generator's BatchNorm layers still in eval mode.
generator.train()
discriminator.train()

for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for imgs, labels in dataloader:
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        batch_size = real_imgs.size(0)

        # Random class labels for the fake batch, and the BCE targets
        # (1 = real, 0 = fake). Created directly on the target device to
        # avoid a CPU allocation followed by a copy.
        condition_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        real_labels = torch.ones((batch_size, 1), device=device)
        fake_labels = torch.zeros((batch_size, 1), device=device)

        # -------------------------- Train Generator ---
        optimizer_G.zero_grad()

        # Noise input for Generator
        z = torch.randn((batch_size, latent_dim), device=device)

        gen_imgs = generator(z, condition_labels)
        validity = discriminator(gen_imgs, condition_labels)
        # The generator is rewarded when the discriminator labels fakes as real.
        G_loss = criterion(validity, real_labels)
        running_G_loss += G_loss.item()

        # This backward pass also accumulates gradients in the discriminator;
        # they are discarded by optimizer_D.zero_grad() below.
        G_loss.backward()
        optimizer_G.step()

        # -------------- Train Discriminator ---
        optimizer_D.zero_grad()

        real_validity = discriminator(real_imgs, labels)
        real_loss = criterion(real_validity, real_labels)

        # detach() keeps the discriminator update from back-propagating into
        # the generator.
        fake_validity = discriminator(gen_imgs.detach(), condition_labels)
        fake_loss = criterion(fake_validity, fake_labels)

        D_loss = (real_loss + fake_loss) / 2
        running_D_loss += D_loss.item()

        D_loss.backward()
        optimizer_D.step()

    epoch_G_loss = running_G_loss / len(dataloader)
    epoch_D_loss = running_D_loss / len(dataloader)

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    # Periodically save up to 25 samples from the last batch of the epoch.
    if epoch % save_interval == 0:
        save_image(gen_imgs.detach()[:25], f"images_cDCGAN/epoch_{epoch}.png", nrow=5, normalize=True)

# 4. Inference

In [None]:
%matplotlib inline
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
# Switch the generator to evaluation mode for sampling.
generator.eval()

# Number of images generated (and shown side by side) for each class.
num_sample = 5
for target_class in range(num_classes):
    z = torch.randn((num_sample, latent_dim), device=device)
    condition_labels = torch.full((num_sample,), target_class, dtype=torch.long, device=device)

    # Pure inference: no_grad() avoids building an autograd graph.
    with torch.no_grad():
        gen_imgs = generator(z, condition_labels).cpu()

    # make_grid returns (C, H, W); imshow expects (H, W, C).
    grid = make_grid(gen_imgs, nrow=num_sample, normalize=True).permute(1, 2, 0).numpy()
    plt.imshow(grid)
    plt.title(f"Generated samples for class {target_class}")
    plt.show()