In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

### DCGAN

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> image with values in [-1, 1].

    The latent vector is projected by a linear layer to a feature map of shape
    ``(init_channel, image_size // 4, image_size // 4)``. Two stages of
    "upsample x2 -> 3x3 conv -> batch norm -> LeakyReLU" then reduce the
    channels to 64 and 32 while quadrupling the spatial size, and a final
    3x3 conv followed by ``tanh`` produces ``output_channel`` channels.

    Args:
        image_size: Target side length; the output side is
            ``(image_size // 4) * 4``, i.e. ``image_size`` when it is a
            multiple of 4.
        latent_dim: Length of the input noise vector.
        init_channel: Channel count of the first (smallest) feature map.
        output_channel: Number of channels in the generated image.
    """

    def __init__(self, image_size, latent_dim, init_channel, output_channel):
        super().__init__()
        self.init_channel = init_channel
        self.image_size = image_size
        # Two x2 upsampling stages follow, so start at a quarter of the size.
        self.init_size = image_size // 4
        self.zt = nn.Linear(latent_dim, init_channel * self.init_size ** 2)

        # Channel count entering each stage, then the width of the last stage.
        channel_plan = [init_channel, 64, 32]

        stages = [nn.BatchNorm2d(init_channel)]
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(self._upsample_stage(c_in, c_out))
        stages.append(
            nn.Sequential(
                nn.Conv2d(channel_plan[-1], output_channel, kernel_size=3, padding=1),
                nn.Tanh(),
            )
        )
        self.generator = nn.Sequential(*stages)

    @staticmethod
    def _upsample_stage(c_in, c_out):
        """Build one 'upsample x2 -> conv -> batch norm -> LeakyReLU' stage."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, z):
        """Generate images from a 2-D batch of latent vectors ``(batch, latent_dim)``."""
        n_samples, _ = z.shape
        seed_map = self.zt(z).view(
            n_samples, self.init_channel, self.init_size, self.init_size
        )
        return self.generator(seed_map)
    
class Discriminator(nn.Module):
    """Convolutional discriminator that emits two logits per image.

    Every entry of ``hidden_dims`` adds one block of
    "3x3 conv (stride 2) -> batch norm -> LeakyReLU -> Dropout2d(0.25)".
    The resulting feature map is flattened and fed to a linear head with
    two outputs.

    Args:
        in_channel: Number of channels of the input image.
        in_size: Side length of the (square) input image.
        hidden_dims: Output channels of each stride-2 conv block. Defaults to
            ``[64, 128, 256, 512]``.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, in_channel, in_size, hidden_dims=None):
        super(Discriminator, self).__init__()

        hidden_dims = [64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one channel width")

        modules = []
        feat_size = in_size
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, h_dim, kernel_size=3, padding=1, stride=2),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Dropout2d(0.25)
                )
            )
            in_channel = h_dim
            # Conv2d(kernel_size=3, padding=1, stride=2) output size:
            # floor((n + 2*1 - 3) / 2) + 1 == (n - 1) // 2 + 1.
            # Tracking it per layer keeps the head correct for any number of
            # blocks and for sizes that are not a multiple of 2 ** len(hidden_dims).
            feat_size = (feat_size - 1) // 2 + 1
        self.encoder = nn.Sequential(*modules)
        fc_in = feat_size ** 2 * hidden_dims[-1]
        self.cls_head = nn.Linear(fc_in, 2)

    def forward(self, x):
        """Return logits of shape ``(batch, 2)`` for a batch of images."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        x = self.cls_head(x)
        return x

In [None]:
def weights_init_normal(m):
    """Re-initialise a conv or BatchNorm2d layer in place.

    Intended to be passed to ``nn.Module.apply``. The layer type is detected
    from its class name:

    * a name containing ``"Conv"`` gets weights drawn from a normal
      distribution with mean 0.0 and std 0.02;
    * otherwise, a name containing ``"BatchNorm2d"`` gets weights drawn with
      mean 1.0 and std 0.02, and a zero bias.

    Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    init = torch.nn.init

    if "Conv" in layer_kind:
        init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm2d" in layer_kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

In [None]:
# --- Configuration ----------------------------------------------------------
epochs = 5
batch_size = 4
latent_dim = 8  # length of the noise vector fed to the generator
sample_interval = 100  # only referenced by the optional image dump at the end of the loop
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Models and loss --------------------------------------------------------
# The discriminator emits two logits per image, so targets are 2-column
# vectors: [0, 1] stands for "real" and [1, 0] for "fake".
criterion = nn.BCEWithLogitsLoss()
generator = Generator(image_size=64, latent_dim=latent_dim, init_channel=128, output_channel=1).to(device=device)
discriminator = Discriminator(in_channel=1, in_size=64).to(device=device)

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# --- Data: MNIST resized to 64x64 and normalised with mean 0.5 / std 0.5 ----
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)

# Report twice per epoch; max(1, ...) avoids a modulo-by-zero when the loader
# has fewer than two batches.
log_interval = max(1, len(dataloader) // 2)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device=device, dtype=torch.float32)
        # Size everything by the batch actually delivered: the final batch is
        # smaller than batch_size whenever batch_size does not divide the
        # dataset length.
        n_samples = imgs.size(0)
        real = torch.ones((n_samples, 2), dtype=torch.float32, device=device)
        fake = torch.ones((n_samples, 2), dtype=torch.float32, device=device)
        real[:, 0] = 0
        fake[:, 1] = 0

        # ---- Train the generator: make the discriminator answer "real" ----
        g_opt.zero_grad()
        # Uniform noise in [0, 1).
        # NOTE(review): many GAN setups sample with torch.randn - confirm this is intended.
        z = torch.rand((n_samples, latent_dim), dtype=torch.float32, device=device)
        gen_z = generator(z)
        g_loss = criterion(discriminator(gen_z), real)

        g_loss.backward()
        g_opt.step()

        # ---- Train the discriminator on real and generated images ----
        d_opt.zero_grad()
        real_loss = criterion(discriminator(imgs), real)
        # detach() keeps the discriminator update from back-propagating into the generator.
        fake_loss = criterion(discriminator(gen_z.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        d_opt.step()

        if i % log_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        # batches_done = epoch * len(dataloader) + i
        # if batches_done % sample_interval == 0:
        #     save_image(gen_z.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

### Conditional GAN

In [None]:
class Generator(nn.Module):
    """Conditional (label-aware) fully connected generator.

    The class label is looked up in an ``n_classes x n_classes`` embedding
    table and concatenated to the noise vector. The result passes through
    four "linear -> batch norm -> LeakyReLU" blocks of widths 128, 256, 512
    and 1024, then a linear layer with ``tanh`` whose output is reshaped to
    ``(batch, output_channel, image_size, image_size)``.

    Args:
        image_size: Side length of the generated (square) image.
        latent_dim: Length of the input noise vector.
        output_channel: Number of channels in the generated image.
        n_classes: Number of distinct labels (also the embedding width).
    """

    def __init__(self, image_size, latent_dim, output_channel, n_classes):
        super().__init__()
        self.emb = nn.Embedding(n_classes, n_classes)
        self.image_size = image_size
        self.output_channel = output_channel

        # widths[k] -> widths[k + 1] describes the k-th hidden block.
        widths = [n_classes + latent_dim, 128, 256, 512, 1024]
        layers = [
            nn.Sequential(
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.LeakyReLU(0.2, inplace=True),
            )
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        layers.append(
            nn.Sequential(
                nn.Linear(widths[-1], output_channel * (image_size ** 2)),
                nn.Tanh(),
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, noise, label):
        """Generate one image per ``(noise, label)`` pair.

        ``noise`` must be 2-D ``(batch, latent_dim)``; ``label`` holds the
        integer class index of each sample.
        """
        n_samples, _ = noise.shape
        conditioned = torch.cat([noise, self.emb(label)], dim=-1)
        flat_image = self.model(conditioned)
        return flat_image.view(
            n_samples, self.output_channel, self.image_size, self.image_size
        )
    
class Discriminator(nn.Module):
    """Conditional fully connected discriminator that emits two logits.

    The label is looked up in an ``n_classes x n_classes`` embedding table and
    concatenated to the flattened image. The network is:

    * an input block "linear -> LeakyReLU" mapping to ``hidden_dims[0]``;
    * one "linear -> Dropout(0.2) -> LeakyReLU" block per entry of
      ``hidden_dims`` (the first of these keeps the width at
      ``hidden_dims[0]``);
    * a linear head with two outputs.

    Args:
        in_channel: Number of channels of the input image.
        in_size: Side length of the (square) input image.
        n_classes: Number of distinct labels (also the embedding width).
        hidden_dims: Widths of the hidden blocks. Defaults to
            ``[512, 256, 128, 64]``.

    Raises:
        ValueError: If ``hidden_dims`` is empty.
    """

    def __init__(self, in_channel, in_size, n_classes, hidden_dims=None):
        super(Discriminator, self).__init__()
        self.emb = nn.Embedding(n_classes, n_classes)

        # Honour a caller-supplied list; fall back to the default only when
        # hidden_dims is None.
        hidden_dims = [512, 256, 128, 64] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        modules = []
        modules.append(
            nn.Sequential(
                nn.Linear(n_classes + in_channel * in_size ** 2, hidden_dims[0]),
                nn.LeakyReLU(0.2, inplace=True)
            )
        )
        init_dim = hidden_dims[0]
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Linear(init_dim, h_dim),
                    nn.Dropout(0.2),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )
            init_dim = h_dim
        modules.append(nn.Linear(init_dim, 2))
        self.model = nn.Sequential(*modules)

    def forward(self, img, label):
        """Return logits of shape ``(batch, 2)`` for image/label pairs."""
        emb = self.emb(label)
        img = torch.flatten(img, start_dim=1)
        model_in = torch.cat([img, emb], dim=-1)
        output = self.model(model_in)
        return output

In [None]:
# --- Configuration ----------------------------------------------------------
epochs = 5
batch_size = 4
latent_dim = 8  # length of the noise vector fed to the generator
n_classes = 10  # size of the label embedding tables in both networks
sample_interval = 100  # only referenced by the optional image dump at the end of the loop
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Models and loss --------------------------------------------------------
# The discriminator emits two logits per image, so targets are 2-column
# vectors: [0, 1] stands for "real" and [1, 0] for "fake".
criterion = nn.BCEWithLogitsLoss()
generator = Generator(image_size=64, latent_dim=latent_dim, output_channel=1, n_classes=n_classes).to(device=device)
discriminator = Discriminator(in_channel=1, in_size=64, n_classes=n_classes).to(device=device)

# --- Data: MNIST resized to 64x64 and normalised with mean 0.5 / std 0.5 ----
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)

# Report twice per epoch instead of on every batch (the previous `i % 1 == 0`
# test was always true); max(1, ...) avoids a modulo-by-zero on tiny loaders.
log_interval = max(1, len(dataloader) // 2)

for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        imgs = imgs.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.long)
        # Size everything by the batch actually delivered: the final batch is
        # smaller than batch_size whenever batch_size does not divide the
        # dataset length, and the noise must line up with `labels`.
        n_samples = imgs.size(0)
        real = torch.ones((n_samples, 2), dtype=torch.float32, device=device)
        fake = torch.ones((n_samples, 2), dtype=torch.float32, device=device)
        real[:, 0] = 0
        fake[:, 1] = 0

        # ---- Train the generator: make the discriminator answer "real" ----
        g_opt.zero_grad()
        # Uniform noise in [0, 1).
        # NOTE(review): many GAN setups sample with torch.randn - confirm this is intended.
        z = torch.rand((n_samples, latent_dim), dtype=torch.float32, device=device)
        # Generated images are conditioned on the labels of the current real batch.
        gen_z = generator(z, labels)
        g_loss = criterion(discriminator(gen_z, labels), real)

        g_loss.backward()
        g_opt.step()

        # ---- Train the discriminator on real and generated images ----
        d_opt.zero_grad()
        real_loss = criterion(discriminator(imgs, labels), real)
        # detach() keeps the discriminator update from back-propagating into the generator.
        fake_loss = criterion(discriminator(gen_z.detach(), labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        d_opt.step()

        if i % log_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

        # batches_done = epoch * len(dataloader) + i
        # if batches_done % sample_interval == 0:
        #     save_image(gen_z.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)