In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension

  warn(


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [2]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True), # this was batch norm in wgan
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


In [4]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [5]:
def test():
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"
    gen = Generator(noise_dim, in_channels, 8)
    z = torch.randn((N, noise_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"


In [6]:
def gradient_penalty(critic, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [7]:
device = "cuda:7" if torch.cuda.is_available() else "cpu"
# lr = 1e-4
# BATCH_SIZE = 64
# CHANNELS_IMG = 1 #3
lr = 2e-4
BATCH_SIZE = 32
IMAGE_SIZE = 64
CHANNELS_IMG = 3 #3
Noise_dim = 100
num_epochs = 50
FEATURES_CRITIC = 16
FEATURES_GEN = 16
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10

In [8]:
transforms = transforms.Compose(
    [
        # transforms.Resize(IMAGE_SIZE),
        transforms.Resize((IMAGE_SIZE,IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [9]:
# dataset = datasets.MNIST(root="/mnt/disk1/Gulshan/GAN/DCGAN/dataset", transform=transforms, download=False)
dataset=datasets.ImageFolder(root="/mnt/disk1/Gulshan/dataset/Dog_data",transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,)
x,y=next(iter(loader))
x.shape

torch.Size([32, 3, 64, 64])

In [10]:
gen = Generator(Noise_dim, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=lr, betas=(0.0, 0.9))

fixed_noise = torch.randn(32, Noise_dim, 1, 1).to(device)
writer_real = SummaryWriter(f"logs_WGCAN_my/real")
writer_fake = SummaryWriter(f"logs_WGCAN_my/fake")
step = 0

In [11]:
gen.train()
critic.train()

for epoch in tqdm(range(num_epochs),total=num_epochs):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Noise_dim, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx== 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                # writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_real.add_image("Real", img_grid_real, global_step=epoch+1)
                writer_fake.add_image("Fake", img_grid_fake, global_step=epoch+1)

            # step += 1

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/266 [00:00<?, ?it/s]

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch [0/50] Batch 0/266                   Loss D: 69.6034, loss G: -0.1643


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [1/50] Batch 0/266                   Loss D: -55.3033, loss G: 110.7690


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [2/50] Batch 0/266                   Loss D: -53.0876, loss G: 103.2914


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [3/50] Batch 0/266                   Loss D: -15.0849, loss G: 101.1408


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [4/50] Batch 0/266                   Loss D: -13.3195, loss G: 77.0674


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [5/50] Batch 0/266                   Loss D: -12.5289, loss G: 77.2063


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [6/50] Batch 0/266                   Loss D: -14.8673, loss G: 70.5661


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [7/50] Batch 0/266                   Loss D: -12.3188, loss G: 69.5912


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [8/50] Batch 0/266                   Loss D: -7.8189, loss G: 75.7038


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [9/50] Batch 0/266                   Loss D: -12.0822, loss G: 54.8569


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [10/50] Batch 0/266                   Loss D: -9.7663, loss G: 62.9632


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [11/50] Batch 0/266                   Loss D: -9.2849, loss G: 69.6240


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [12/50] Batch 0/266                   Loss D: -5.9437, loss G: 69.3776


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [13/50] Batch 0/266                   Loss D: -5.8513, loss G: 51.0778


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [14/50] Batch 0/266                   Loss D: -8.9086, loss G: 75.0584


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [15/50] Batch 0/266                   Loss D: -6.6494, loss G: 72.8608


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [16/50] Batch 0/266                   Loss D: -4.0855, loss G: 64.6609


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [17/50] Batch 0/266                   Loss D: -6.7543, loss G: 53.1131


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [18/50] Batch 0/266                   Loss D: -5.5306, loss G: 61.5469


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [19/50] Batch 0/266                   Loss D: -6.9622, loss G: 42.1371


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [20/50] Batch 0/266                   Loss D: -5.7626, loss G: 62.7545


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [21/50] Batch 0/266                   Loss D: -3.7593, loss G: 45.9201


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [22/50] Batch 0/266                   Loss D: -5.1698, loss G: 58.3444


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [23/50] Batch 0/266                   Loss D: -9.2736, loss G: 53.8220


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [24/50] Batch 0/266                   Loss D: -4.8781, loss G: 49.8728


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [25/50] Batch 0/266                   Loss D: -4.0362, loss G: 55.0139


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [26/50] Batch 0/266                   Loss D: -9.1136, loss G: 48.0271


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [27/50] Batch 0/266                   Loss D: -6.6816, loss G: 37.5892


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [28/50] Batch 0/266                   Loss D: -7.1220, loss G: 50.6369


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [29/50] Batch 0/266                   Loss D: -5.3829, loss G: 41.2299


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [30/50] Batch 0/266                   Loss D: -6.1867, loss G: 30.1814


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [31/50] Batch 0/266                   Loss D: -6.0816, loss G: 17.3093


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [32/50] Batch 0/266                   Loss D: -5.5026, loss G: 31.6222


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [33/50] Batch 0/266                   Loss D: -4.2630, loss G: 37.4049


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [34/50] Batch 0/266                   Loss D: -4.0199, loss G: 19.1122


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [35/50] Batch 0/266                   Loss D: -5.5761, loss G: 27.9832


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [36/50] Batch 0/266                   Loss D: -8.0019, loss G: 32.9565


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [37/50] Batch 0/266                   Loss D: -5.7948, loss G: 37.9743


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [38/50] Batch 0/266                   Loss D: -5.3323, loss G: 39.1729


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [39/50] Batch 0/266                   Loss D: -6.2142, loss G: 26.1781


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [40/50] Batch 0/266                   Loss D: -8.0052, loss G: 11.7041


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [41/50] Batch 0/266                   Loss D: -6.3244, loss G: 11.2803


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [42/50] Batch 0/266                   Loss D: -4.0428, loss G: 12.4764


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [43/50] Batch 0/266                   Loss D: -14.2538, loss G: 48.1696


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [44/50] Batch 0/266                   Loss D: -4.8180, loss G: 9.7471


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [45/50] Batch 0/266                   Loss D: -2.1124, loss G: 27.5852


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [46/50] Batch 0/266                   Loss D: -9.7145, loss G: 26.1931


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [47/50] Batch 0/266                   Loss D: -7.1048, loss G: 18.2779


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [48/50] Batch 0/266                   Loss D: -1.7239, loss G: 14.8373


  0%|          | 0/266 [00:00<?, ?it/s]

Epoch [49/50] Batch 0/266                   Loss D: -4.1798, loss G: 1.5443


In [12]:
# ! tensorboard --logdir=/mnt/disk1/Gulshan/GAN/WGAN/logs_wgan_gp_minist/GAN_MNIST