In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from torch.utils.tensorboard import SummaryWriter

root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/'

# Make Disc and Gen classes

In [2]:
class Discriminator(nn.Module):
    def __init__(self, num_classes, img_size, img_channels=3, features=128): 
        super().__init__()
        self.disc = nn.Sequential(
            self._make_conv2d_block(img_channels+1, features, use_bn=False), 

            self._make_conv2d_block(features, features*2), 
            self._make_conv2d_block(features*2, features*4), 
            self._make_conv2d_block(features*4, features*8), 

            self._make_conv2d_block(features*8, 1, 4, 2, 0, use_bn=False, use_act=False), 
            nn.Sigmoid(),
        )
        self.embed = nn.Embedding(num_classes, img_size*img_size)
        self.img_size = img_size

    def _make_conv2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True, leak=0.2):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.LeakyReLU(leak))
        return nn.Sequential(*layers)
    
    def forward(self, x, labels):
        x = torch.cat([x, self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)], axis=1)
        return self.disc(x)
    

class Generator(nn.Module):
    def __init__(self, num_classes, embed_size, z_dim=100, img_channels=3, features=1024):
        super().__init__()
        self.gen = nn.Sequential(
            self._make_convT2d_block(z_dim+embed_size, features*8, 4, 1, 0), 

            self._make_convT2d_block(features*8, features*4), 
            self._make_convT2d_block(features*4, features*2), 
            self._make_convT2d_block(features*2, features), 

            self._make_convT2d_block(features, img_channels, use_bn=False, use_act=False), 
            nn.Tanh(),
        )
        self.embed = nn.Embedding(num_classes, embed_size)

    def _make_convT2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True):
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)        
    
    def forward(self, x, labels):
        x = torch.cat([x, self.embed(labels).unsqueeze(2).unsqueeze(3)], axis=1)
        return self.gen(x)
    
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(
                tensor=m.weight.data,
                mean=0.0,
                std=0.02,
            )

# Testing

In [3]:
# BATCH_SIZE, NUM_CHANNELS, H, W = 8, 3, 64, 64
# Z_DIM = 100
# FEATURES = 64
# NUM_CLASSES = 10
# LABELS = torch.arange(8)

# img_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, H, W))
# disc = Discriminator(NUM_CLASSES, H, NUM_CHANNELS, FEATURES)
# assert disc(img_test, LABELS).shape == (BATCH_SIZE, 1, 1, 1), 'Disc Test Failed'

# noise_test = torch.randn((BATCH_SIZE, Z_DIM, 1, 1))
# gen = Generator(NUM_CLASSES, H, Z_DIM, NUM_CHANNELS, FEATURES)
# assert gen(noise_test, LABELS).shape == (BATCH_SIZE, NUM_CHANNELS, H, W), 'Gen Test Failed'

# print('Success!')

# Hyperparameters

In [4]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128 
NUM_CHANNELS = 1
IMG_SIZE = 64
NUM_CLASSES = 10
GEN_EMBEDDING = 100
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 128
LR = 1e-4
NUM_EPOCHS = 50
DISC_ITER = 1
GRID_SHOW = 10

# SetUp

In [5]:
transformations = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.Grayscale(NUM_CHANNELS),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(NUM_CHANNELS)], [0.5 for _ in range(NUM_CHANNELS)]
        ),
    ]
)
# images = datasets.ImageFolder(root=root_path + 'dogs/images', transform=transformations)
images = datasets.MNIST(root='mnist', train=True, transform=transformations, download=True)
images_loader = DataLoader(dataset=images, batch_size=BATCH_SIZE, shuffle=True)

disc = Discriminator(NUM_CLASSES, IMG_SIZE, NUM_CHANNELS, FEATURES_DISC).to(DEVICE)
gen = Generator(NUM_CLASSES, GEN_EMBEDDING, Z_DIM, NUM_CHANNELS, FEATURES_GEN).to(DEVICE)
initialize_weights(disc), initialize_weights(gen)

fixed_labels = torch.arange(10).to(DEVICE)
optim_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
optim_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCELoss()

writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')
steps = 1

In [6]:
def model_size(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Generator Size: {model_size(gen):.3e}, Discriminator Size: {model_size(disc):.3e}')

Generator Size: 1.429e+07, Discriminator Size: 2.806e+06


# Training

In [7]:
gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(images_loader):
        real = real.to(DEVICE)
        labels = labels.to(DEVICE)

        for _ in range(DISC_ITER):
            noise = torch.randn((real.shape[0], Z_DIM, 1, 1)).to(DEVICE)
            fake = gen(noise, labels)
            disc_real = disc(real, labels).view(-1)
            loss_disc_real = criterion(
                input=disc_real,
                target=torch.ones_like(disc_real),
            )
            disc_fake = disc(fake.detach(), labels).view(-1)
            loss_disc_fake = criterion(
                input=disc_fake,
                target=torch.zeros_like(disc_fake),
            )
            loss_disc = (loss_disc_real + loss_disc_fake)
            disc.zero_grad()
            loss_disc.backward()
            optim_disc.step()
            
        disc_fake = disc(fake, labels).view(-1)
        loss_gen = criterion(
            input=disc_fake,
            target=torch.ones_like(disc_fake),
        )
        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        if batch_idx % 40 == 0:
            print(
                f'Epoch: {epoch+1}/{NUM_EPOCHS} -- Step: {steps} -- Batch: {batch_idx+1}/{len(images_loader)} -- Disc Loss: {loss_disc:.4f} -- Gen Loss: {loss_gen:.4f}'
            )
            with torch.no_grad():
                fixed_noise = torch.randn((GRID_SHOW, Z_DIM, 1, 1)).to(DEVICE)

                fake = gen(fixed_noise, fixed_labels)
                fake_images = torchvision.utils.make_grid(fake, nrow=5, normalize=True)
                save_image(fake_images, f'dcgan_results/{steps}.png')
                real_images = torchvision.utils.make_grid(real[:GRID_SHOW], nrow=5, normalize=True)
                writer_fake.add_scalar('Gen Loss', loss_gen, global_step=steps)
                writer_fake.add_image(
                    'Fake', fake_images, global_step=steps
                )
                writer_real.add_image(
                    'Real', real_images, global_step=steps
                )
                steps += 1
                if steps == 358:
                    break

Epoch: 1/50 -- Step: 1 -- Batch: 1/469 -- Disc Loss: 1.3881 -- Gen Loss: 0.7137
Epoch: 1/50 -- Step: 2 -- Batch: 41/469 -- Disc Loss: 0.3444 -- Gen Loss: 2.0648
Epoch: 1/50 -- Step: 3 -- Batch: 81/469 -- Disc Loss: 0.1026 -- Gen Loss: 3.2163
Epoch: 1/50 -- Step: 4 -- Batch: 121/469 -- Disc Loss: 0.0511 -- Gen Loss: 3.8288
Epoch: 1/50 -- Step: 5 -- Batch: 161/469 -- Disc Loss: 0.0301 -- Gen Loss: 4.3112
Epoch: 1/50 -- Step: 6 -- Batch: 201/469 -- Disc Loss: 0.0203 -- Gen Loss: 4.6997
Epoch: 1/50 -- Step: 7 -- Batch: 241/469 -- Disc Loss: 0.1415 -- Gen Loss: 3.1004
Epoch: 1/50 -- Step: 8 -- Batch: 281/469 -- Disc Loss: 0.0392 -- Gen Loss: 4.0925
Epoch: 1/50 -- Step: 9 -- Batch: 321/469 -- Disc Loss: 2.8101 -- Gen Loss: 0.1479
Epoch: 1/50 -- Step: 10 -- Batch: 361/469 -- Disc Loss: 0.5958 -- Gen Loss: 1.8808
Epoch: 1/50 -- Step: 11 -- Batch: 401/469 -- Disc Loss: 0.6492 -- Gen Loss: 1.9418
Epoch: 1/50 -- Step: 12 -- Batch: 441/469 -- Disc Loss: 0.6691 -- Gen Loss: 2.4900
Epoch: 2/50 -- St