In [None]:
'''GANs employ two neural networks, the Generator and the Discriminator, which are trained simultaneously in an adversarial game: the Generator learns to create new data that is indistinguishable from real data, while the Discriminator learns to tell real data apart from generated data.'''
import torch
import torch.nn as nn
def gen_block(in_dim, out_dim):
    """Build one fully connected generator stage: Linear -> BatchNorm1d -> ReLU.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features (also the BatchNorm width).

    Returns:
        An ``nn.Sequential`` holding the three layers.
    """
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim),
        nn.ReLU(inplace=True),
    )


class Generator(nn.Module):
    """Fully connected GAN generator mapping a noise vector to a flat sample.

    Architecture: three ``gen_block`` stages (in_dim -> 256 -> 512 -> 1024),
    then a Linear projection to ``out_dim`` followed by a Sigmoid, so every
    output value lies in [0, 1].

    Args:
        in_dim: Length of each input noise vector.
        out_dim: Length of each generated (flattened) sample.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Chain the stages into a single sequential model
        self.generator = nn.Sequential(
            gen_block(in_dim, 256),
            gen_block(256, 512),
            gen_block(512, 1024),
            # Project to the requested output size
            nn.Linear(1024, out_dim),
            # Squash outputs into [0, 1]
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Generate samples from noise; assumes x is (batch, in_dim), returns (batch, out_dim)."""
        return self.generator(x)

In [None]:
def disc_block(in_dim, out_dim):
    """Return a discriminator stage: a Linear layer followed by LeakyReLU(0.2)."""
    layers = [
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(0.2),
    ]
    return nn.Sequential(*layers)


class Discriminator(nn.Module):
    """Fully connected GAN discriminator producing one raw score per sample.

    Stages: im_dim -> 1024 -> 512 -> 256 (each a ``disc_block``), then a
    Linear layer down to a single unactivated output.

    Args:
        im_dim: Number of features in each flattened input sample.
    """

    def __init__(self, im_dim):
        super().__init__()
        widths = (im_dim, 1024, 512, 256)
        hidden = [disc_block(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        # Final scoring layer; no activation, so the output is a raw logit
        self.disc = nn.Sequential(*hidden, nn.Linear(widths[-1], 1))

    def forward(self, x):
        """Score a batch; assumes x is (batch, im_dim) and returns (batch, 1)."""
        logits = self.disc(x)
        return logits

Deep Convolutional GAN

In [None]:
def dc_gen_block(in_dim, out_dim, kernel_size, stride):
    """Return a DCGAN generator stage: ConvTranspose2d -> BatchNorm2d -> ReLU.

    Args:
        in_dim: Input channels.
        out_dim: Output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride (no padding argument is passed).
    """
    upsample = nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride=stride)
    return nn.Sequential(upsample, nn.BatchNorm2d(out_dim), nn.ReLU())


class DCGenerator(nn.Module):
    """Deep convolutional generator turning a noise vector into a 3-channel image.

    The noise is reshaped to (batch, in_dim, 1, 1) and upsampled by three
    ``dc_gen_block`` stages (in_dim -> 1024 -> 512 -> 256 channels), then a
    final ConvTranspose2d to 3 channels and a Tanh, so pixels lie in [-1, 1].
    With the defaults (kernel_size=4, stride=2, no padding) the spatial size
    grows 1 -> 4 -> 10 -> 22 -> 46.

    Args:
        in_dim: Length of each input noise vector.
        kernel_size: Kernel size used by every transposed convolution.
        stride: Stride used by every transposed convolution.
    """

    def __init__(self, in_dim, kernel_size=4, stride=2):
        super().__init__()
        self.in_dim = in_dim
        channels = (in_dim, 1024, 512, 256)
        stages = [
            dc_gen_block(c_in, c_out, kernel_size, stride)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        self.gen = nn.Sequential(
            *stages,
            # Map to 3 output channels, then squash into [-1, 1]
            nn.ConvTranspose2d(channels[-1], 3, kernel_size, stride=stride),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from noise; returns (batch, 3, H, W)."""
        batch_size = len(x)
        # Treat each noise vector as an in_dim-channel 1x1 feature map
        return self.gen(x.view(batch_size, self.in_dim, 1, 1))

In [None]:
def dc_disc_block(in_dim, out_dim, kernel_size, stride):
    """Return a DCGAN discriminator stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride)
    norm = nn.BatchNorm2d(out_dim)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(conv, norm, act)


class DCDiscriminator(nn.Module):
    """Deep convolutional discriminator scoring 3-channel images.

    Two ``dc_disc_block`` stages (3 -> 512 -> 1024 channels) are followed by a
    Conv2d down to one channel. The remaining spatial map is flattened, so each
    image yields a row of raw (unactivated) scores.

    Args:
        kernel_size: Kernel size used by every convolution.
        stride: Stride used by every convolution.
    """

    def __init__(self, kernel_size=4, stride=2):
        super().__init__()
        channels = (3, 512, 1024)
        stages = [
            dc_disc_block(c_in, c_out, kernel_size, stride)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        self.disc = nn.Sequential(
            *stages,
            # Collapse to a single output channel; no activation follows
            nn.Conv2d(channels[-1], 1, kernel_size, stride=stride),
        )

    def forward(self, x):
        """Score images; assumes x is (batch, 3, H, W), returns (batch, H' * W')."""
        scores = self.disc(x)
        # Flatten everything except the batch dimension
        return scores.view(len(scores), -1)

Training GANs: generator and discriminator loss functions and the training loop

In [None]:
def gen_loss(gen, disc, criterion, num_images, z_dim):
    """Compute the generator's adversarial loss on a freshly sampled batch.

    Draws ``num_images`` noise vectors of length ``z_dim``, generates fakes,
    scores them with ``disc`` and compares the scores against the "real"
    target (1). The loss is low when the discriminator is fooled.

    Args:
        gen: Generator module mapping (num_images, z_dim) noise to samples.
        disc: Discriminator module returning raw logits for a batch.
        criterion: Loss callable taking (predictions, targets), e.g.
            ``nn.BCEWithLogitsLoss()``. Pass ``None`` to use that default.
        num_images: Number of fake samples to generate.
        z_dim: Length of each noise vector.

    Returns:
        Whatever ``criterion`` returns (a scalar tensor for the default BCE loss).
    """
    # Honor the caller's criterion instead of silently replacing it
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred = disc(fake)
    # Target "real" (ones) for every fake: the generator wants them classified as real
    return criterion(disc_pred, torch.ones_like(disc_pred))

In [None]:
def disc_loss(gen, disc, real, num_images, z_dim, criterion=None):
    """Compute the discriminator's loss on one batch of real and generated samples.

    Generated samples are labelled 0 and real samples 1. The result is the
    mean of the two loss terms.

    Args:
        gen: Generator mapping (num_images, z_dim) noise to samples.
        disc: Discriminator returning raw logits.
        real: Batch of real samples accepted by ``disc``.
        num_images: Number of fake samples to generate.
        z_dim: Length of each noise vector.
        criterion: Optional loss callable taking (predictions, targets);
            defaults to ``nn.BCEWithLogitsLoss()``.

    Returns:
        ``(real_loss + fake_loss) / 2`` (a scalar tensor for the default loss).
    """
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    # detach(): this loss only trains the discriminator, so don't backprop into the generator
    fake = gen(noise).detach()
    disc_pred_fake = disc(fake)
    # Fakes should be classified as fake (0)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    # Real samples should be classified as real (1)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

In [None]:
# One optimisation step per batch: update the discriminator, then the generator.
Z_DIM = 16  # noise length; must match the generator's input size
criterion = nn.BCEWithLogitsLoss()

for epoch in range(1):
    for real in dataloader:
        # NOTE(review): assumes the dataloader yields image tensors (no labels) — confirm
        cur_batch_size = len(real)

        # --- Discriminator step ---
        disc_opt.zero_grad()
        # Use distinct names: assigning to `disc_loss`/`gen_loss` would shadow
        # the loss functions and break any later call (e.g. re-running this cell)
        d_loss = disc_loss(gen, disc, real, cur_batch_size, z_dim=Z_DIM)
        d_loss.backward()
        disc_opt.step()

        # --- Generator step ---
        gen_opt.zero_grad()
        # gen_loss signature: (gen, disc, criterion, num_images, z_dim)
        g_loss = gen_loss(gen, disc, criterion, cur_batch_size, z_dim=Z_DIM)
        g_loss.backward()
        gen_opt.step()

        print(f"Generator loss: {g_loss.item()}")
        print(f"Discriminator loss: {d_loss.item()}")
        # Demo only: stop after the first batch
        break

Evaluating GANs

In [None]:
num_images_to_generate = 5
# Latent length 16 — must match the z_dim used during training
noise = torch.randn(num_images_to_generate, 16)

# Sample without building an autograd graph.
# NOTE(review): gen is not switched to eval(), so its BatchNorm layers use
# batch statistics (and update running stats) here — confirm that's intended.
with torch.no_grad():
    fake = gen(noise)
print(f"Generated tensor shape: {fake.shape}")

# imshow expects channels-last (H, W, C); each generated image is channels-first.
# NOTE(review): a Tanh generator emits [-1, 1] while imshow clips float RGB to
# [0, 1]; consider rescaling with (x + 1) / 2 before plotting.
for i in range(num_images_to_generate):
    image_tensor = fake[i]
    image_tensor_permuted = image_tensor.permute(1, 2, 0)
    plt.imshow(image_tensor_permuted)
    plt.show()

In [None]:
# Import FrechetInceptionDistance
from torchmetrics.image.fid import FrechetInceptionDistance


def to_uint8_images(images):
    """Scale float images in [0, 1] to uint8 pixel values in [0, 255].

    Values are clamped to [0, 1] first, because out-of-range float-to-integer
    casts to ``torch.uint8`` are not clamped (their result is undefined).
    Unclamped values would corrupt the pixels FID sees.
    NOTE(review): a Tanh generator outputs [-1, 1]; if that applies, rescale
    with (x + 1) / 2 before calling this — confirm the range of `fake`/`real`.
    """
    return (images.clamp(0, 1) * 255).to(torch.uint8)


# Instantiate FID (feature=64 selects the 64-dim Inception-v3 feature layer)
fid = FrechetInceptionDistance(feature=64)

# Update FID with generated images (real=False) and real images (real=True).
# NOTE: `fake` comes from the generation cell and `real` is the last batch
# left over from the training loop, so this cell depends on notebook state.
fid.update(to_uint8_images(fake), real=False)
fid.update(to_uint8_images(real), real=True)

# Compute the metric
fid_score = fid.compute()
print(fid_score)