In [3]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

data_dir = '../data/images/'

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolder(data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

In [11]:
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, image_channels=3, h_dim=1024, z_dim=128):
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(h_dim, z_dim*2)
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 8 * 8 * 512),
            nn.ReLU(),
            nn.Unflatten(1, (512, 8, 8)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, image_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )
    
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
    
    def forward(self, x):
        mu_logvar = self.encoder(x).view(-1, 2, 128)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
    
class Generator(nn.Module):
    def __init__(self, z_dim=128, image_channels=3):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 2*2*512),
            nn.ReLU(),
            nn.Unflatten(1, (512, 2, 2)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, image_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(8192, 1)
        )
    
    def forward(self, x):
        return self.model(x)

In [9]:
from torch.optim import Adam

def vae_loss(recon_x, x, mu, logvar):
    MSE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

vae = VAE().cuda()
optimizer = Adam(vae.parameters(), lr=1e-3)

num_epochs = 100

for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.cuda()
        optimizer.zero_grad()
        
        recon_batch, mu, logvar = vae(data)
        loss = vae_loss(recon_batch, data, mu, logvar)
        
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset), loss.item() / len(data)))

Epoch: 0 [0/8300]	Loss: 4927.501953
Epoch: 1 [0/8300]	Loss: 2620.732666
Epoch: 2 [0/8300]	Loss: 2221.200439
Epoch: 3 [0/8300]	Loss: 2125.510498
Epoch: 4 [0/8300]	Loss: 1989.382080
Epoch: 5 [0/8300]	Loss: 1889.310791
Epoch: 6 [0/8300]	Loss: 1698.959229
Epoch: 7 [0/8300]	Loss: 1807.059692
Epoch: 8 [0/8300]	Loss: 1804.775757
Epoch: 9 [0/8300]	Loss: 1777.194580
Epoch: 10 [0/8300]	Loss: 1756.620728
Epoch: 11 [0/8300]	Loss: 1709.420776
Epoch: 12 [0/8300]	Loss: 1742.474365
Epoch: 13 [0/8300]	Loss: 1688.092651
Epoch: 14 [0/8300]	Loss: 1764.496460
Epoch: 15 [0/8300]	Loss: 1627.638184
Epoch: 16 [0/8300]	Loss: 1609.145996
Epoch: 17 [0/8300]	Loss: 1687.813477
Epoch: 18 [0/8300]	Loss: 1776.345581
Epoch: 19 [0/8300]	Loss: 1780.171753
Epoch: 20 [0/8300]	Loss: 1560.727417
Epoch: 21 [0/8300]	Loss: 1572.530151
Epoch: 22 [0/8300]	Loss: 1619.540039
Epoch: 23 [0/8300]	Loss: 1760.682495
Epoch: 24 [0/8300]	Loss: 1561.254639
Epoch: 25 [0/8300]	Loss: 1608.885986
Epoch: 26 [0/8300]	Loss: 1453.273438
Epoch: 27 [

In [12]:
generator = Generator().cuda()
discriminator = Discriminator().cuda()

optimizer_G = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

criterion = nn.BCEWithLogitsLoss()

real_label = 1
fake_label = 0

for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.cuda()
        batch_size = real_images.size(0)
        
        # Train the discriminator
        optimizer_D.zero_grad()
        
        real_target = torch.full((batch_size, 1), real_label, dtype=torch.float32, device='cuda')
        real_output = discriminator(real_images)
        d_real_loss = criterion(real_output, real_target)
        
        z = torch.randn(batch_size, 128, device='cuda')
        fake_images = generator(z)
        fake_target = torch.full((batch_size, 1), fake_label, dtype=torch.float32, device='cuda')
        fake_output = discriminator(fake_images.detach())
        d_fake_loss = criterion(fake_output, fake_target)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        
        fake_output = discriminator(fake_images)
        g_loss = criterion(fake_output, real_target)
        g_loss.backward()
        optimizer_G.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch} [{batch_idx}/{len(dataloader)}]\tD_loss: {d_loss.item():.4f}\tG_loss: {g_loss.item():.4f}')

In [None]:
from pytorch_fid import fid_score
from torchvision.models import inception_v3
from scipy.stats import entropy
import numpy as np

def inception_score(images, batch_size=32, splits=10):
    assert images.shape[1] == 3
    model = inception_v3(pretrained=True, transform_input=False).to(device)
    model.eval()

    def get_preds(images):
        preds = model(images)
        return F.softmax(preds, dim=1).cpu().data.numpy()

    preds = []
    n_batches = images.shape[0] // batch_size

    for i in range(n_batches):
        batch = images[i * batch_size:(i + 1) * batch_size].to(device)
        preds.append(get_preds(batch))

    preds = np.concatenate(preds, axis=0)
    scores = []
    for _ in range(splits):
        part = preds[np.random.choice(preds.shape[0], preds.shape[0] // splits, replace=False)]
        py = np.mean(part, axis=0)
        scores.append(entropy(part.T, py).mean())

    return np.exp(np.mean(scores)), np.exp(np.std(scores))

num_images = 1000

with torch.no_grad():
    z = torch.randn(num_images, z_dim).to(device)
    vae_images = vae.decoder(z)
    gan_images = generator(z)
    
# Prepare the images in the format expected by the FID and Inception Score functions
real_images = (next(iter(dataloader))[0] * 0.5 + 0.5).numpy().transpose((0, 2, 3, 1))
vae_images_np = (vae_images.cpu().numpy().transpose((0, 2, 3, 1)) * 0.5 + 0.5)
gan_images_np = (gan_images.cpu().numpy().transpose((0, 2, 3, 1)) * 0.5 + 0.5)

# Compute FID score
fid_vae = fid_score.calculate_fid_given_paths([real_images, vae_images_np], batch_size=32, device=device)
fid_gan = fid_score.calculate_fid_given_paths([real_images, gan_images_np], batch_size=32, device=device)

# Compute Inception Score
is_vae_mean, is_vae_std = inception_score(torch.tensor(vae_images_np.transpose((0, 3, 1, 2))))
is_gan_mean, is_gan_std = inception_score(torch.tensor(gan_images_np.transpose((0, 3, 1, 2))))

print("VAE FID score:", fid_vae)
print("GAN FID score:", fid_gan)
print("VAE Inception Score (mean, std):", is_vae_mean, is_vae_std)
print("GAN Inception Score (mean, std):", is_gan_mean, is_gan_std)