# 🧱 DCGAN - Bricks Data

In this notebook, we'll walk through the steps required to train your own DCGAN on the LEGO bricks dataset.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Data settings ---
IMAGE_SIZE = 64  # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 1  # number of image channels fed to / produced by the models
BATCH_SIZE = 128
DATA_PATH = "./data/lego-brick-images/dataset"

# --- Model / optimisation settings ---
Z_DIM = 100  # length of the latent noise vector given to the generator
EPOCHS = 300
LEARNING_RATE = 2e-4
ADAM_BETA_1, ADAM_BETA_2 = 0.5, 0.999
NOISE_PARAM = 0.1  # scale of the random noise added to the training labels
LOAD_MODEL = False

# --- Output locations, created up front so later cells can write into them ---
OUTPUT_DIR = "./output"
MODEL_DIR = "./models"
for _directory in (OUTPUT_DIR, MODEL_DIR):
    os.makedirs(_directory, exist_ok=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Keep preprocessing consistent with CHANNELS (which the models are built
# with): collapse to one channel only for grayscale, and give Normalize one
# mean/std entry per channel. ImageFolder's default loader yields RGB images,
# so CHANNELS == 3 needs no channel transform at all.
if CHANNELS not in (1, 3):
    raise ValueError(f"CHANNELS must be 1 (grayscale) or 3 (RGB), got {CHANNELS}")
channel_transforms = [transforms.Grayscale(num_output_channels=1)] if CHANNELS == 1 else []

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    *channel_transforms,
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] to [-1, 1], the generator's Tanh range
    transforms.Normalize([0.5] * CHANNELS, [0.5] * CHANNELS),
])

train_dataset = datasets.ImageFolder(root=DATA_PATH, transform=transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,  # every training step sees a full batch
)

In [None]:
def show_batch(batch):
    """Display a batch of image tensors as a single grid with 8 images per row.

    ``make_grid(normalize=True)`` min-max rescales the grid into [0, 1] so the
    [-1, 1] training tensors render correctly.
    """
    image_grid = utils.make_grid(batch, nrow=8, normalize=True)
    # make_grid returns (C, H, W); matplotlib expects an (H, W, C) array on the CPU.
    hwc_pixels = image_grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(10, 5))
    plt.imshow(hwc_pixels, cmap="gray")
    plt.axis("off")
    plt.show()

# Preview one shuffled batch of preprocessed training images (labels are unused).
sample_batch, _ = next(iter(train_loader))
show_batch(sample_batch)

## 2. Build the GAN <a name="build"></a>

In [None]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self, channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


In [None]:
# Generator
class Generator(nn.Module):
    def __init__(self, z_dim=100, channels=1):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64, momentum=0.9),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), 1, 1)
        return self.model(x)

In [None]:
discriminator = Discriminator(channels=CHANNELS).to(device)
generator = Generator(z_dim=Z_DIM, channels=CHANNELS).to(device)

# Honour LOAD_MODEL: resume from the state dicts written by the checkpoint
# cell (same file names under MODEL_DIR) instead of starting from scratch.
# map_location lets checkpoints saved on GPU load on a CPU-only machine.
if LOAD_MODEL:
    generator.load_state_dict(torch.load(f"{MODEL_DIR}/generator.pth", map_location=device))
    discriminator.load_state_dict(torch.load(f"{MODEL_DIR}/discriminator.pth", map_location=device))

## 3. Train the GAN <a name="train"></a>

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs drives both players.
criterion = nn.BCELoss()

# Both networks share the same Adam hyper-parameters.
_adam_settings = {"lr": LEARNING_RATE, "betas": (ADAM_BETA_1, ADAM_BETA_2)}
d_optimizer = optim.Adam(discriminator.parameters(), **_adam_settings)
g_optimizer = optim.Adam(generator.parameters(), **_adam_settings)


In [None]:
def save_generated_images(generator, epoch, n_samples=10):
    """Sample ``n_samples`` images from ``generator`` and save them as one PNG row.

    The image is written to ``OUTPUT_DIR/generated_img_<epoch:03d>.png``.

    Args:
        generator: model mapping (n, Z_DIM) noise on ``device`` to images.
        epoch: epoch number embedded in the file name.
        n_samples: number of images, laid out in a single row.
    """
    z = torch.randn(n_samples, Z_DIM, device=device)
    # Sampling only: don't build an autograd graph that would be discarded.
    with torch.no_grad():
        gen_imgs = generator(z).cpu()
    grid = utils.make_grid(gen_imgs, nrow=n_samples, normalize=True)
    fig = plt.figure(figsize=(10, 5))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap="gray")
    plt.axis("off")
    plt.savefig(f"{OUTPUT_DIR}/generated_img_{epoch:03d}.png")
    # Close exactly the figure created here so figures don't pile up across epochs.
    plt.close(fig)

In [None]:
for epoch in range(1, EPOCHS + 1):
    # Per-epoch loss sums, kept on-device to avoid a host sync every step.
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    n_batches = 0

    for real_imgs, _ in tqdm(train_loader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---- Train Discriminator ----
        z = torch.randn(batch_size, Z_DIM, device=device)
        fake_imgs = generator(z)

        # Noisy (smoothed) labels: real targets in (1 - NOISE_PARAM, 1], fake
        # targets in [0, NOISE_PARAM). nn.BCELoss requires targets in [0, 1];
        # outside that range (e.g. 1 + noise or -noise) it stops being a
        # cross-entropy, can go negative, and rewards over-confident outputs.
        real_labels = 1.0 - NOISE_PARAM * torch.rand(batch_size, 1, device=device)
        fake_labels = NOISE_PARAM * torch.rand(batch_size, 1, device=device)

        d_optimizer.zero_grad()
        real_preds = discriminator(real_imgs)
        fake_preds = discriminator(fake_imgs.detach())  # no gradient into G here
        d_loss = (criterion(real_preds, real_labels) + criterion(fake_preds, fake_labels)) / 2
        d_loss.backward()
        d_optimizer.step()

        # ---- Train Generator ----
        # G is rewarded when D labels its samples as real (target 1). The
        # backward pass also writes gradients into D's parameters; those are
        # cleared by d_optimizer.zero_grad() at the next step.
        g_optimizer.zero_grad()
        fake_preds_for_g = discriminator(fake_imgs)
        g_loss = criterion(fake_preds_for_g, torch.ones_like(fake_preds_for_g))
        g_loss.backward()
        g_optimizer.step()

        d_loss_sum += d_loss.detach()
        g_loss_sum += g_loss.detach()
        n_batches += 1

    if n_batches == 0:
        # Otherwise the report below would fail with an obscure NameError/ZeroDivision.
        raise RuntimeError(
            "train_loader yielded no batches; check DATA_PATH and that the dataset "
            "holds at least BATCH_SIZE images (the loader uses drop_last=True)."
        )

    # Report epoch means rather than the noisy loss of the last batch.
    d_loss_mean = (d_loss_sum / n_batches).item()
    g_loss_mean = (g_loss_sum / n_batches).item()
    print(f"Epoch [{epoch}/{EPOCHS}] d_loss: {d_loss_mean:.4f}, g_loss: {g_loss_mean:.4f}")
    save_generated_images(generator, epoch)

In [None]:
# Persist only the weights (state dicts); the model classes rebuild the modules.
for _net_name, _net in (("generator", generator), ("discriminator", discriminator)):
    torch.save(_net.state_dict(), f"{MODEL_DIR}/{_net_name}.pth")

## 4. Generate new images <a name="decode"></a>

In [None]:
def generate_grid(generator, n_rows=3, n_cols=5):
    """Display an ``n_rows`` x ``n_cols`` grid of freshly sampled generator images.

    Args:
        generator: model mapping (n, Z_DIM) noise on ``device`` to images.
        n_rows: number of grid rows.
        n_cols: number of images per row.
    """
    z = torch.randn(n_rows * n_cols, Z_DIM, device=device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        gen_imgs = generator(z).cpu()
    grid = utils.make_grid(gen_imgs, nrow=n_cols, normalize=True)
    plt.figure(figsize=(10, 6))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap="gray")
    plt.axis("off")
    plt.show()


generate_grid(generator)

In [None]:
def compare_images(img1, img2):
    """Return the mean absolute difference between two image tensors.

    The tensors are broadcast against each other and the mean is taken over
    every element, giving a 0-dim tensor (lower means more similar).
    """
    return torch.mean(torch.abs(img1 - img2))


def _pairwise_image_distances(gen_imgs, real_imgs):
    """Return an (n_gen, n_real) matrix of ``compare_images`` distances.

    Entry [i, j] equals ``compare_images(gen_imgs[i], real_imgs[j])`` for two
    (N, C, H, W) batches, computed in one vectorized pass.
    """
    diffs = torch.abs(gen_imgs.unsqueeze(1) - real_imgs.unsqueeze(0))
    return diffs.flatten(start_dim=2).mean(dim=2)


def _to_display(img):
    """Convert one (C, H, W) image with values in [-1, 1] to an imshow-ready array."""
    pixels = (img * 0.5 + 0.5).clamp(0, 1)  # map [-1, 1] back to [0, 1]
    return pixels.permute(1, 2, 0).squeeze(-1).numpy()


r, c = 3, 5
z = torch.randn(r * c, Z_DIM, device=device)
with torch.no_grad():
    gen_imgs = generator(z).cpu()

# Search the *whole* training set. train_loader uses drop_last=True, so
# iterating it would silently skip up to BATCH_SIZE - 1 images; a separate
# unshuffled loader visits every image once. Keeping only the running best per
# generated image avoids holding the entire dataset in memory.
search_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
best_diffs = torch.full((r * c,), float("inf"))
closest_imgs = torch.zeros_like(gen_imgs)
for real_imgs, _ in search_loader:
    batch_best, batch_idx = _pairwise_image_distances(gen_imgs, real_imgs).min(dim=1)
    improved = batch_best < best_diffs  # strict "<": the earliest match wins ties
    best_diffs[improved] = batch_best[improved]
    closest_imgs[improved] = real_imgs[batch_idx[improved]]

# Show the generated samples as well as their nearest training neighbours;
# the closest images alone cannot be compared against anything.
for title, images in (
    ("Generated images", gen_imgs),
    ("Closest images in training set", closest_imgs),
):
    fig, axs = plt.subplots(r, c, figsize=(10, 6))
    fig.suptitle(title, fontsize=16)
    for ax, img in zip(axs.flat, images):
        ax.imshow(_to_display(img), cmap="gray", vmin=0, vmax=1)
        ax.axis("off")
    plt.show()