<a href="https://colab.research.google.com/github/JHyunjun/DQTGAN/blob/main/WGANGP_VGGLOSS_CIFAR_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.models import vgg16
from torch.utils.data import DataLoader
from torch.autograd import Variable, grad

# ---- Hyperparameters --------------------------------------------------------
batch_size = 32   # images per training mini-batch
epochs = 10       # full passes over the CIFAR-10 training split
latent_dim = 100  # NOTE(review): not referenced by any visible cell; G is conditioned on images, not noise
lambda_gp = 10    # weight of the WGAN-GP gradient-penalty term in the critic loss
n_critic = 5      # the generator is updated on every n_critic-th critic step

# ---- Device -----------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Data -------------------------------------------------------------------
# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 per channel
# maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Resize the shorter side to 8 px: applied to 32x32 batches this yields the
# 8x8 low-resolution images that are fed to the generator.
downsample = transforms.Resize(8)

dataset = datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
import matplotlib.pyplot as plt
import torchvision
import numpy as np

def imshow(img):
    """Display a single image tensor that is normalized to [-1, 1].

    ``img`` must be rank 3 in channel-first layout (C x H x W); it is transposed to
    H x W x C for matplotlib. The tensor may live on any device and may require
    grad: it is detached and copied to the CPU before conversion to NumPy.
    Values are clamped to [0, 1] after unnormalizing so matplotlib does not have
    to clip (and warn about) slightly out-of-range pixels.
    """
    img = img.detach().cpu() / 2 + 0.5  # unnormalize [-1, 1] -> [0, 1]
    npimg = img.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Grab one shuffled training batch to preview.
dataiter = iter(dataloader)
images, labels = next(dataiter)

# Preview the first three images twice: at 8x8 (the generator's input size)
# and at the original 32x32 resolution.
downsampled_images = downsample(images)
for caption, batch in (
    ('8X8 Images', downsampled_images),
    ('Original(32X32) Images', images),
):
    print(caption)
    imshow(torchvision.utils.make_grid(batch[:3]))



In [None]:
# Generator
class Generator(nn.Module):
    """Upsample a low-resolution 3-channel image to 3 x 32 x 32.

    The generator is conditioned on an image (the training loop feeds it the
    8 x 8 downsampled real batch), not on a latent vector. For a 3 x 8 x 8 input
    the spatial size evolves 8 -> 9 -> 16 -> 32. The final Tanh keeps the output
    in [-1, 1], matching the normalized data range.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # 3 x 8 x 8 -> 256 x 9 x 9      (out = (in - 1) * 1 - 0 + 2)
            nn.ConvTranspose2d(3, 256, kernel_size=2, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 256 x 9 x 9 -> 128 x 16 x 16  (out = (in - 1) * 2 - 2 + 2)
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128 x 16 x 16 -> 3 x 32 x 32  (out = (in - 1) * 2 - 2 + 4)
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        # Kept as one Sequential named `main` so parameter names are main.<index>.*
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        output = self.main(input)
        return output

class Discriminator(nn.Module):
    """WGAN-GP critic for 3 x 32 x 32 images.

    The critic returns *unbounded* real-valued scores. A Wasserstein critic is a
    real-valued (Lipschitz-constrained) function, so no Sigmoid is applied: a
    Sigmoid would squash the scores into (0, 1) and saturate the gradients.

    Normalization is per-sample (InstanceNorm) instead of BatchNorm. The gradient
    penalty constrains the gradient of each sample's score with respect to that
    sample's input. BatchNorm would make every score depend on the whole batch.

    The last convolution leaves a 1 x 4 x 4 map per image, so ``forward`` returns
    ``N * 16`` flattened patch scores, not one score per image.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # Input is 3 x 32 x 32
            nn.Conv2d(3, 256, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Size now is 256 x 16 x 16
            nn.Conv2d(256, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # Size now is 256 x 8 x 8
            nn.Conv2d(256, 1, 4, 2, 1, bias=False),
            # Output is 1 x 4 x 4: raw critic scores, no activation
        )

    def forward(self, input):
        """Return flattened critic scores; shape (N * 16,) for 32 x 32 inputs."""
        return self.main(input).view(-1)

# Pretrained VGG for perceptual loss
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: L1 distance between VGG16 ``relu1_2`` feature maps.

    Both arguments are expected in [-1, 1]. That is the range produced by the data
    ``Normalize((0.5,)*3, (0.5,)*3)`` transform and by the generator's Tanh. They
    are mapped to [0, 1] and then standardized with the ImageNet mean/std that the
    pretrained torchvision VGG16 weights expect.

    The VGG layers act as a frozen feature extractor. Their parameters have
    ``requires_grad=False``, so backpropagating this loss produces gradients only
    for the inputs (e.g. the generator output) and never accumulates ``.grad`` on
    VGG itself.
    """

    # ImageNet channel statistics used to train torchvision's VGG16 weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        super().__init__()
        try:
            from torchvision.models import VGG16_Weights
        except ImportError:
            # torchvision < 0.13 only supports the legacy `pretrained` flag.
            model = vgg16(pretrained=True)
        else:
            # Same weights as `pretrained=True`, without the deprecation warning.
            model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        features = model.features
        # features[0:4] are conv1_1, ReLU, conv1_2, ReLU -> the relu1_2 activation.
        self.to_relu_1_2 = nn.Sequential()
        for x in range(4):
            self.to_relu_1_2.add_module(str(x), features[x])
        self.to_relu_1_2 = self.to_relu_1_2.eval()
        for param in self.to_relu_1_2.parameters():
            param.requires_grad_(False)

        # Non-persistent buffers follow `.to(device)` but stay out of state_dict.
        self.register_buffer(
            "mean", torch.tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self._IMAGENET_STD).view(1, 3, 1, 1), persistent=False
        )

    def _normalize(self, images):
        """Map images from [-1, 1] to ImageNet-standardized VGG input."""
        return ((images + 1) / 2 - self.mean) / self.std

    def forward(self, input, target):
        """Return the mean absolute difference of relu1_2 features of input and target."""
        return torch.nn.functional.l1_loss(
            self.to_relu_1_2(self._normalize(input)),
            self.to_relu_1_2(self._normalize(target)),
        )

# WGAN-GP gradient penalty
def gradient_penalty(critic, real, fake, device=None):
    """Compute the WGAN-GP gradient penalty.

    Samples one point uniformly on the segment between each paired real and fake
    image. For each point, it takes the gradient of the critic's score with respect
    to the input and penalizes ``(||grad||_2 - 1) ** 2``, averaged over the batch.

    Args:
        critic: callable mapping a batch of images to scores. The scores are
            summed (``grad_outputs`` of ones) before differentiation, so a critic
            that emits several scores per image is penalized on their sum.
        real: real images, shape (N, C, H, W).
        fake: generated images, same shape as ``real``. It may or may not require
            grad: both endpoints are detached, so the penalty never backpropagates
            into whatever produced ``fake`` (e.g. the generator).
        device: device for the random interpolation weights; defaults to
            ``real.device``.

    Returns:
        A scalar tensor. It is built with ``create_graph=True``, so calling
        ``backward`` on a loss containing it reaches the critic's parameters.
    """
    if device is None:
        device = real.device
    batch_size = real.shape[0]
    # One interpolation weight per sample, broadcast over (C, H, W).
    alpha = torch.rand(batch_size, 1, 1, 1, device=device, dtype=real.dtype)
    # A detached leaf that requires grad: `grad` below works even when neither
    # input tracks gradients, and no gradient leaks into the generator graph.
    interpolated_images = (
        alpha * real.detach() + (1 - alpha) * fake.detach()
    ).requires_grad_(True)
    mixed_scores = critic(interpolated_images)

    gradient = grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # the penalty itself must be differentiable w.r.t. the critic
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

# Create models
G = Generator().to(device)
D = Discriminator().to(device)
vgg_loss = VGGPerceptualLoss().to(device)

# Optimizers
G_optimizer = optim.Adam(G.parameters(), lr=0.0003, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=0.0003, betas=(0.5, 0.999))

# Training loop.
# - The critic (D) is updated on every batch with the WGAN-GP loss.
# - The generator is updated on every n_critic-th batch.
# - G is conditioned on the 8x8 downsampled real batch, not on latent noise.
for epoch in range(epochs):
    for i, (real, _) in enumerate(dataloader):
        real = real.to(device)
        real_downsampled = downsample(real)  # Resize keeps the tensor on `device`

        # ---- Train critic ----
        D_optimizer.zero_grad()

        fake = G(real_downsampled)  # Use downsampled real image as input for Generator
        real_score = D(real)
        # Detach: the critic update must not backpropagate into G. The copy given
        # to gradient_penalty is marked requires_grad so the interpolated images it
        # builds are differentiable without linking back to G's graph.
        fake_score = D(fake.detach())
        gp = gradient_penalty(D, real, fake.detach().requires_grad_(True), device)
        d_loss = -(torch.mean(real_score) - torch.mean(fake_score)) + lambda_gp * gp

        # G's graph was not traversed above, so nothing has to be retained.
        d_loss.backward()
        D_optimizer.step()

        # ---- Train generator ----
        if i % n_critic == 0:
            G_optimizer.zero_grad()

            fake_score = D(fake)  # re-scored by the just-updated critic
            perceptual_loss = vgg_loss(fake, real)
            g_loss = -torch.mean(fake_score) + perceptual_loss

            g_loss.backward()
            G_optimizer.step()

    # i == 0 always triggers a generator step, so g_loss is defined here.
    print(f'Epoch [{epoch + 1}/{epochs}] d_loss: {d_loss.item()} g_loss: {g_loss.item()}')

    if epoch % 5 == 0:
        utils.save_image(fake.detach()[:25], f"{epoch}.png", nrow=5, normalize=True)

In [None]:
import matplotlib.pyplot as plt

# Slice [image_start, image_end) of the first test batch is visualized.
image_start, image_end = 0, 5

# Held-out CIFAR-10 split, preprocessed exactly like the training data.
test_dataset = datasets.CIFAR10(root='./', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Iterator over the (shuffled) test loader.
data_iter = iter(test_dataloader)

# Take the images (element 0) of the first batch, keep the requested slice,
# and move it to the training device.
original_imgs = next(data_iter)[0]
original_imgs = original_imgs[image_start:image_end].to(device)

# Generator inputs: the same images downsampled to 8x8.
test_images = downsample(torch.clone(original_imgs))

def _to_unit_range(images):
    """Map a batch normalized to [-1, 1] back to [0, 1], clamping any overshoot."""
    return ((images + 1) / 2).clamp(0, 1)


def show_imgs(original_imgs):
    """Plot original images, their 8x8 downsampled versions, and G's reconstructions.

    The generator input is derived from ``original_imgs`` itself (via
    ``downsample``), so the three rows always show the same images. ``G`` runs in
    eval mode without gradient tracking, and its previous train/eval mode is
    restored afterwards so later training is not affected.

    Args:
        original_imgs: batch of images normalized to [-1, 1]. Presumably
            (N, 3, 32, 32) as produced by the CIFAR-10 ``transform``; confirm for
            other callers.
    """
    g_device = next(G.parameters()).device
    input_imgs = downsample(original_imgs).to(g_device)

    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            fake_imgs = G(input_imgs)
    finally:
        G.train(was_training)

    rows = (
        ("Original Images", original_imgs),
        ("Input Images", input_imgs),
        ("Generated Images", fake_imgs),
    )
    _, axs = plt.subplots(3, 1, figsize=(10, 10))
    for ax, (title, batch) in zip(axs, rows):
        # Pixels are already in [0, 1]. make_grid(normalize=True) would min-max
        # rescale each grid and distort the colors actually produced.
        grid = utils.make_grid(_to_unit_range(batch.detach().cpu()), nrow=5)
        ax.imshow(grid.permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis('off')
    plt.show()


# Show the images
show_imgs(original_imgs)
