#### About 

- LAPGAN is an abbreviation for "Laplacian Pyramid Generative Adversarial Networks." It is a generative adversarial network (GAN) used to generate high-resolution images.

- The main idea behind LAPGAN is to generate images in a step-by-step manner using a Laplacian pyramid. A Laplacian pyramid is a multiresolution image representation in which each level corresponds to a different scale of the image. The generator network in LAPGAN is trained to produce images at each level of the Laplacian pyramid, beginning with the coarsest and gradually increasing the resolution.

- LAPGAN uses a GAN with a modified loss function to generate an image at a given level of the Laplacian pyramid. This encourages the generated image to have the same Laplacian pyramid.

In [1]:
#importing modules
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as vutils
import torchvision

In [2]:
# Load FashionMNIST dataset
batch_size=2048
dataset = dset.FashionMNIST(root='./data', download=True,
                            transform=transforms.Compose([
                                transforms.Resize(64),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
                            ]))

# Create dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=2)

In [3]:
# Define generator network, discriminator and laplacian network
# Define generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.main = nn.Sequential(
            # Initial layer
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            # Layer 2
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            # Layer 3
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            # Layer 4
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # Layer 5
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        #print("Here2",output.shape)
        return output

# Define discriminator model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(512*4*4, 1)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.conv1(x), negative_slope=0.2)
        #print(x.shape)
        x = nn.functional.leaky_relu(self.conv2(x), negative_slope=0.2)
        #print(x.shape)
        x = nn.functional.leaky_relu(self.conv3(x), negative_slope=0.2)
        #print(x.shape)
        x = nn.functional.leaky_relu(self.conv4(x), negative_slope=0.2)
        #print(x.shape)
        x = x.view(x.size(0), -1)
        #print(x.shape)
        x = torch.sigmoid(self.fc(x))
        return x


# Define LaplacianPyramid class
class LaplacianPyramid(nn.Module):
    def __init__(self, num_levels):
        super(LaplacianPyramid, self).__init__()
        self.num_levels = num_levels
        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, count_include_pad=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x):
        # Create list to hold pyramid levels
        pyramid = []
        # Add original image to pyramid
        pyramid.append(x)
        # Generate additional levels
        for i in range(self.num_levels - 1):
            x = self.downsample(x)
            pyramid.append(x)
        # Reconstruct original image and store residuals
        for i in range(self.num_levels - 1, 0, -1):
            x_up = self.upsample(pyramid[i])
            x_down = pyramid[i - 1]
            h, w = x_down.size()[2:]
            x_up = x_up[:, :, :h, :w]
            diff = x_down - x_up
            pyramid[i - 1] = diff
        return pyramid       


In [4]:
num_pyramid_levels=3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define fixed noise for testing generator during training
fixed_noise = torch.randn(64, 100, 1, 1)
generator = Generator().to(device)
discriminator = Discriminator().to(device)
pyramid = LaplacianPyramid(num_pyramid_levels).to(device)
fixed_noise = fixed_noise.to(device)

In [8]:
#loss function and optimizers
# Define loss function and optimizers
learning_rate=0.0002
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
num_epochs=50


In [15]:
!rm -rf generated_images
!mkdir generated_images
!rm -rf discriminator_epoch*.*
!rm -rf generator_epoch*.*

In [None]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # Update discriminator
        discriminator.zero_grad()
        real, _ = data
        #print(real.shape)
        batch_size = real.size(0)
        real = real.to(device)
        input_real = Variable(real)
        target_real = Variable(torch.ones(batch_size))
  
        target_real = target_real.to(device)
        #print(input_real.shape)
        output_real = discriminator(input_real)
        output_real = output_real.squeeze(1)
        error_real = criterion(output_real, target_real)
        error_real.backward()

        noise = torch.randn(batch_size, 100, 1, 1)
        
        noise = noise.to(device)
        input_fake = Variable(generator(noise).data)
        target_fake = Variable(torch.zeros(batch_size))
        
        target_fake = target_fake.to(device)
        #print("Here",input_fake.shape)
        output_fake = discriminator(input_fake)
        output_fake = output_fake.squeeze(1)
        error_fake = criterion(output_fake, target_fake)
        error_fake.backward()

        optimizer_d.step()

        # Update generator
        generator.zero_grad()
        noise = torch.randn(batch_size, 100, 1, 1)
        
        noise = noise.to(device)
        input_fake = generator(noise)
        output_fake = discriminator(input_fake)
        target_real = Variable(torch.ones(batch_size))
        
        target_real = target_real.to(device)
        pyramid_fake = pyramid(input_fake)
        pyramid_real = pyramid(input_real)
        loss = 0
        for j in range(num_pyramid_levels):
            loss += criterion(torch.sigmoid(pyramid_fake[j]), pyramid_real[j].detach())
        loss.backward()
        optimizer_g.step()

        # Print loss and save generated images periodically
        if i % 10 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, num_epochs, i, len(dataloader), error_real + error_fake, loss))
            #save models
            torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pth")
            torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch}.pth")

            # Generate images from fixed noise
            generator.eval()
            with torch.no_grad():
                fixed_fake = generator(fixed_noise)
                pyramid_fake = pyramid(fixed_fake)
                pyramid_real = pyramid(input_real[:64])
            generator.train()
            # Save images
            vutils.save_image(fixed_fake.data[:64], 'generated_images/fake_samples_epoch_%03d.png' % epoch, normalize=True)
            vutils.save_image(input_real[:64], 'generated_images/real_samples.png', normalize=True)
            for j in range(num_pyramid_levels):
                vutils.save_image(pyramid_real[j].data[:64], 'generated_images/real_pyramid_%d.png' % j, normalize=True)
                vutils.save_image(pyramid_fake[j].data[:64], 'generated_images/fake_pyramid_%d_epoch_%03d.png' % (j, epoch), normalize=True)



[0/50][0/30] Loss_D: 0.0000 Loss_G: 0.5850
[0/50][10/30] Loss_D: 0.0000 Loss_G: 0.5945
[0/50][20/30] Loss_D: 0.0000 Loss_G: 0.5953
[1/50][0/30] Loss_D: 0.0000 Loss_G: 0.6010
[1/50][10/30] Loss_D: 0.0000 Loss_G: 0.5928
[1/50][20/30] Loss_D: 0.0000 Loss_G: 0.5630
[2/50][0/30] Loss_D: 0.0000 Loss_G: 0.5766
[2/50][10/30] Loss_D: 0.0000 Loss_G: 0.5546
[2/50][20/30] Loss_D: 0.0000 Loss_G: 0.5508
[3/50][0/30] Loss_D: 0.0000 Loss_G: 0.5693
[3/50][10/30] Loss_D: 0.0000 Loss_G: 0.5523
[3/50][20/30] Loss_D: 0.0000 Loss_G: 0.5680
[4/50][0/30] Loss_D: 0.0000 Loss_G: 0.5520
