In [36]:
import torch
import torch.nn as nn

In [37]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

In [38]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            self._block(z_dim, features_g * 16, 4, 1, 0),  # 1x1 -> 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # 4x4 -> 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # 8x8 -> 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # 16x16 -> 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),  # 32x32 -> 64x64
            nn.ConvTranspose2d(
                channels_img, channels_img, kernel_size=(2, 5), stride=(1, 2), padding=(0, 1)
            ),  # 64x64 -> 78x64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.gen(x)

In [39]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [40]:
def test():
  N, in_channels, H, W = 8, 3, 64, 64
  z_dim = 100
  x = torch.randn((N, in_channels, H, W))
  disc = Discriminator(in_channels, 8)
  initialize_weights(disc)
  assert disc(x).shape == (N, 1, 1, 1)
  gen = Generator(z_dim, in_channels, 8)
  initialize_weights(gen)
  z = torch.randn((N, z_dim, 1, 1))
  assert gen(z).shape == (N, in_channels, H, W)

In [41]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [43]:
LEARNING_RATE = 1e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
NUM_EPOCHS = 2
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

In [44]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [45]:
dataset = datasets.MNIST(root = "dataset/", train = True, transform = transforms, download = True)
# dataset = datasets.ImageFolder(root="C:/Users/princ/Desktop/Deep Learning/GAN'S/DCGAN/celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

In [46]:
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

In [47]:
opt_gen = optim.Adam(gen.parameters(), lr = LEARNING_RATE, betas=(0.0, 0.9))
opt_disc = optim.Adam(disc.parameters(), lr = LEARNING_RATE, betas=(0.0, 0.9))

In [48]:
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_fake = SummaryWriter(f"runs/WGAN/fake")
writer_real = SummaryWriter(f"runs/WGAN/real")
step = 0

In [49]:
def gradient_penalty(disc, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    
    # # Print shapes for debugging
    # print(f"real shape: {real.shape}, fake shape: {fake.shape}, alpha shape: {alpha.shape}")
    
    interpolated_images = real * alpha + fake * (1 - alpha)
    
    # Print shape of interpolated images
    # print(f"interpolated_images shape: {interpolated_images.shape}")
    mixed_scores = disc(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [50]:
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [51]:
import torch.nn.functional as F

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)

        # Train Discriminator
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
            fake = gen(noise)
            fake = F.interpolate(fake, size=(real.shape[2], real.shape[3]), mode='bilinear', align_corners=False)
            disc_real = disc(real).reshape(-1)
            disc_fake = disc(fake).reshape(-1)
            gp = gradient_penalty(disc, real, fake, device=device)
            loss_disc = (
                -(torch.mean(disc_real) - torch.mean(disc_fake)) + LAMBDA_GP * gp
            )
            disc.zero_grad()
            loss_disc.backward(retain_graph=True)
            opt_disc.step()

        # Train Generator
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
        fake = gen(noise)
        fake = F.interpolate(fake, size=(real.shape[2], real.shape[3]), mode='bilinear', align_corners=False)
        output = disc(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                fake = F.interpolate(fake, size=(real.shape[2], real.shape[3]), mode='bilinear', align_corners=False)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real, normalize=True)
                writer_fake.add_image(
                    "Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/2] Batch 0/938 Loss D: 404.2479, loss G: -0.6013
Epoch [0/2] Batch 100/938 Loss D: -96.5284, loss G: 52.1825
Epoch [0/2] Batch 200/938 Loss D: -121.5974, loss G: 72.1163
Epoch [0/2] Batch 300/938 Loss D: -133.1516, loss G: 83.5268
Epoch [0/2] Batch 400/938 Loss D: -137.3414, loss G: 95.7015
Epoch [0/2] Batch 500/938 Loss D: -137.8174, loss G: 105.3022
Epoch [0/2] Batch 600/938 Loss D: -138.2807, loss G: 112.9258
Epoch [0/2] Batch 700/938 Loss D: -138.8304, loss G: 118.2692
Epoch [0/2] Batch 800/938 Loss D: -135.6711, loss G: 121.7909
Epoch [0/2] Batch 900/938 Loss D: -136.1325, loss G: 124.5179


RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 0