In [10]:
"""
Training of WGAN-GP
"""

'\nTraining of WGAN-GP\n'

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights
from tqdm import tqdm
from utils import gradient_penalty

In [12]:
# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
# CHANNELS_IMG = 1
CHANNELS_IMG = 1
NUM_CLASSES = 10
GEN_EMBEDDING = 100
NOISE_DIM = 100
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
# WEIGHT_CLIP = 0.01
LAMBDA_GP = 1

In [13]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),  # ADD THIS LINE - crops to exactly 64x64
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)])
    ]
)

dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=True)
# dataset = datasets.ImageFolder(root="custom_datasets/celeb_dataset", transform=transforms)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [14]:
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMAGE_SIZE, GEN_EMBEDDING).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC, NUM_CLASSES, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [15]:
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
# criterion = nn.BCELoss()

In [16]:
fixed_noise = torch.randn(32,  Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

In [17]:
gen.train()
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
  (embed): Embedding(10, 4096)
)

In [None]:
for epoch in range(NUM_EPOCHS):
    # labels needed for supervised, and not needed in unsupervised
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]  # ← Get actual batch size
        labels = labels.to(device)
    
        for _ in range(CRITIC_ITERATIONS):
            # noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
            noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)  # ← Use cur_batch_size
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1) # flat vector
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP*gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            # for p in critic.parameters():
            #     p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        ### Train Generator: min -E[critic(gen_fake)]
        noise = torch.randn((cur_batch_size, Z_DIM, 1, 1)).to(device)  # ← Use cur_batch_size here too
        output = critic(fake, labels).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses ocassionally and print to tensorboard
        if batch_idx % 100 == 0:
            # Print losses occasionally and print to tensorboard
            if batch_idx % 100 == 0:
                print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                    Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
                )
            
            with torch.no_grad():
                fake = gen(noise, labels)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )
                
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            
            step += 1

  0%|                                                                                  | 1/938 [00:02<35:55,  2.30s/it]

Epoch [0/5] Batch 0/938                     Loss D: 0.6687, loss G: 0.1424


 11%|████████▌                                                                       | 101/938 [02:57<25:40,  1.84s/it]

Epoch [0/5] Batch 100/938                     Loss D: -17.9240, loss G: 15.6571


 21%|█████████████████▏                                                              | 201/938 [05:53<22:43,  1.85s/it]

Epoch [0/5] Batch 200/938                     Loss D: -47.1300, loss G: 52.3555


 32%|█████████████████████████▋                                                      | 301/938 [08:49<19:37,  1.85s/it]

Epoch [0/5] Batch 300/938                     Loss D: -167.2788, loss G: 123.0472


 43%|██████████████████████████████████▏                                             | 401/938 [11:45<16:36,  1.86s/it]

Epoch [0/5] Batch 400/938                     Loss D: -374.8900, loss G: 189.4330


 53%|██████████████████████████████████████████▋                                     | 501/938 [14:29<10:55,  1.50s/it]

Epoch [0/5] Batch 500/938                     Loss D: -509.7459, loss G: 334.8255


 64%|███████████████████████████████████████████████████▎                            | 601/938 [16:53<08:28,  1.51s/it]

Epoch [0/5] Batch 600/938                     Loss D: -796.7569, loss G: 467.4129


 75%|███████████████████████████████████████████████████████████▊                    | 701/938 [19:18<05:57,  1.51s/it]

Epoch [0/5] Batch 700/938                     Loss D: -1718.2653, loss G: 846.8582


 85%|████████████████████████████████████████████████████████████████████▎           | 801/938 [21:43<03:26,  1.50s/it]

Epoch [0/5] Batch 800/938                     Loss D: -1746.4707, loss G: 924.2231


 96%|████████████████████████████████████████████████████████████████████████████▊   | 901/938 [24:08<00:55,  1.51s/it]

Epoch [0/5] Batch 900/938                     Loss D: -2035.8812, loss G: 1104.3096


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [25:01<00:00,  1.60s/it]
  0%|                                                                                  | 1/938 [00:01<25:20,  1.62s/it]

Epoch [1/5] Batch 0/938                     Loss D: -2709.7366, loss G: 1402.0635


 11%|████████▌                                                                       | 101/938 [02:26<21:06,  1.51s/it]

Epoch [1/5] Batch 100/938                     Loss D: -2661.9417, loss G: 1734.5701


 21%|█████████████████                                                               | 200/938 [04:50<17:47,  1.45s/it]

Epoch [1/5] Batch 200/938                     Loss D: -3854.2507, loss G: 1940.8330


 32%|█████████████████████████▋                                                      | 301/938 [07:18<16:08,  1.52s/it]

Epoch [1/5] Batch 300/938                     Loss D: -4439.8423, loss G: 2278.2910


 43%|██████████████████████████████████▏                                             | 401/938 [09:44<13:38,  1.52s/it]

Epoch [1/5] Batch 400/938                     Loss D: -4535.4399, loss G: 1908.7439


 53%|██████████████████████████████████████████▋                                     | 501/938 [12:10<11:03,  1.52s/it]

Epoch [1/5] Batch 500/938                     Loss D: -6236.4868, loss G: 3037.6973


 64%|███████████████████████████████████████████████████▎                            | 601/938 [14:36<08:32,  1.52s/it]

Epoch [1/5] Batch 600/938                     Loss D: -7045.8667, loss G: 3601.9487


 75%|███████████████████████████████████████████████████████████▊                    | 701/938 [17:02<06:00,  1.52s/it]

Epoch [1/5] Batch 700/938                     Loss D: -5071.7725, loss G: 1891.2234


 85%|████████████████████████████████████████████████████████████████████▎           | 801/938 [19:27<03:28,  1.52s/it]

Epoch [1/5] Batch 800/938                     Loss D: -2768.0720, loss G: 3284.0811


 96%|████████████████████████████████████████████████████████████████████████████▊   | 901/938 [21:53<00:56,  1.52s/it]

Epoch [1/5] Batch 900/938                     Loss D: -10872.4199, loss G: 5457.9165


100%|████████████████████████████████████████████████████████████████████████████████| 938/938 [22:46<00:00,  1.46s/it]
  0%|                                                                                  | 1/938 [00:01<25:13,  1.62s/it]

Epoch [2/5] Batch 0/938                     Loss D: -8311.0293, loss G: 3862.3376


 11%|████████▌                                                                       | 101/938 [02:27<21:13,  1.52s/it]

Epoch [2/5] Batch 100/938                     Loss D: -12179.0645, loss G: 6286.6201


 21%|█████████████████▏                                                              | 201/938 [04:53<18:38,  1.52s/it]

Epoch [2/5] Batch 200/938                     Loss D: -13072.5859, loss G: 6636.0098


 32%|█████████████████████████▋                                                      | 301/938 [07:19<16:16,  1.53s/it]

Epoch [2/5] Batch 300/938                     Loss D: -14768.1963, loss G: 7410.0664


 43%|██████████████████████████████████▏                                             | 401/938 [09:45<13:47,  1.54s/it]

Epoch [2/5] Batch 400/938                     Loss D: -15888.0234, loss G: 7938.4521
