In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [2]:
batch_size = 64

# Preprocessing applied to every CIFAR-10 image:
#   1. PIL image -> float tensor in [0, 1]
#   2. resize so the shorter side is 32 px
#   3. per-channel (x - 0.5) / 0.5, mapping [0, 1] -> [-1, 1], the same range
#      as the generator's Tanh output
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(32),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [3]:
# CIFAR-10 train/test splits, downloaded into ./cifar10/ if missing and
# preprocessed with the `transform` pipeline.
train_CIFAR10_set = torchvision.datasets.CIFAR10(
    root='./cifar10/', train=True, transform=transform, download=True
)
test_CIFAR10_set = torchvision.datasets.CIFAR10(
    root='./cifar10/', train=False, transform=transform, download=True
)

# drop_last=True keeps every batch at exactly `batch_size` samples.
train_CIFAR10_dataloader = DataLoader(
    train_CIFAR10_set, batch_size=batch_size, drop_last=True, shuffle=True
)
test_CIFAR10_dataloader = DataLoader(
    test_CIFAR10_set, batch_size=batch_size, drop_last=True, shuffle=True
)

print('#' * 40)
print("CIFAR10 dataloader Generated")

Files already downloaded and verified
Files already downloaded and verified
########################################
CIFAR10 dataloader Generated


In [4]:
class GeneratorWGAN(nn.Module):
    """MLP generator: latent noise vector -> image in [-1, 1].

    Args:
        latent_dim: length of the input noise vector (default 100).
        img_shape: output image shape ``(channels, height, width)``
            (default ``(3, 32, 32)``, i.e. CIFAR-10).
        bn_eps: ``eps`` of the hidden BatchNorm1d layers. The original code passed
            0.8 *positionally*, which lands on ``eps`` (not ``momentum``). It is kept
            at 0.8 so existing checkpoints behave identically.
            NOTE(review): this was probably meant as a Keras-style momentum=0.8;
            confirm before changing, since it alters the trained model's outputs.

    The layer order inside ``self.network`` is unchanged, so ``state_dict`` keys
    (``network.0.weight``, ``network.3.running_mean``, ...) still match old checkpoints.
    """

    def __init__(self, latent_dim=100, img_shape=(3, 32, 32), bn_eps=0.8):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        num_outputs = torch.Size(self.img_shape).numel()

        def create_block(input_features, output_features, normalize=True):
            # Linear -> (optional) BatchNorm1d -> LeakyReLU(0.2)
            layers = [nn.Linear(input_features, output_features)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_features, eps=bn_eps))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.network = nn.Sequential(
            *create_block(latent_dim, 128, normalize=False),
            *create_block(128, 256),
            *create_block(256, 512),
            *create_block(512, 1024),
            nn.Linear(1024, num_outputs),
            nn.Tanh(),  # squashes pixels into [-1, 1], matching the normalised data
        )

    def forward(self, noise_vector):
        """Map ``(N, latent_dim)`` noise to ``(N, *img_shape)`` images in [-1, 1]."""
        image = self.network(noise_vector)
        return image.view(image.shape[0], *self.img_shape)

class DiscriminatorWGAN(nn.Module):
    """MLP critic: image -> one unbounded real-valued score per sample.

    There is no final sigmoid because a WGAN critic outputs a raw score, not a
    probability.

    Args:
        img_shape: input image shape ``(channels, height, width)``
            (default ``(3, 32, 32)``, i.e. CIFAR-10).

    The layer order inside ``self.network`` is unchanged, so ``state_dict`` keys
    still match old checkpoints.
    """

    def __init__(self, img_shape=(3, 32, 32)):
        super().__init__()
        self.img_shape = tuple(img_shape)

        self.network = nn.Sequential(
            nn.Linear(torch.Size(self.img_shape).numel(), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, image):
        """Score a batch of images; returns a tensor of shape ``(N, 1)``."""
        # flatten (unlike .view) also works on non-contiguous inputs such as permuted tensors.
        image_flat = torch.flatten(image, start_dim=1)
        return self.network(image_flat)

print("Instantiating WGAN generator and discriminator...")
# Build each network and move it onto the selected device (generator first, then critic).
wgan_generator_instance, wgan_discriminator_instance = (
    model_cls().to(device) for model_cls in (GeneratorWGAN, DiscriminatorWGAN)
)
print("Models are set up and moved to the device.")

Instantiating WGAN generator and discriminator...
Models are set up and moved to the device.


In [5]:
# Training hyper-parameters read as globals by train_wgan.
learning_rate = 0.0005          # Adam step size for both networks
epochs = 50                     # full passes over the training loader
batch_size = 64                 # NOTE: re-assigning this does not resize the DataLoaders already built
n_critic = 5                    # the generator is updated once every n_critic critic iterations
weight_clipping_limit = 0.01    # critic weights are clamped to [-limit, +limit] after each step

def train_wgan(generator_model, discriminator_model, data_loader):
    """Train a weight-clipped WGAN and log an Inception Score after every epoch.

    Reads the module-level hyper-parameters ``learning_rate``, ``epochs``,
    ``n_critic`` and ``weight_clipping_limit``, and the module-level ``device``.

    Every iteration performs one critic (discriminator) update followed by weight
    clipping to ``[-weight_clipping_limit, weight_clipping_limit]``. The generator
    is updated on every ``n_critic``-th iteration.

    After each epoch a batch of samples is generated and scored with
    ``get_inception_score``. The samples are saved as a grid to
    ``train_generated_images_wgan/epoch_{epoch}.png``, and the score is appended to
    ``inception_score_wgan.csv``.

    Args:
        generator_model: maps ``(N, 100)`` noise to images in [-1, 1].
        discriminator_model: maps images to ``(N, 1)`` critic scores.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
    """
    optimizer_gen = torch.optim.Adam(generator_model.parameters(), lr=learning_rate, betas=(0.5, 0.9))
    optimizer_disc = torch.optim.Adam(discriminator_model.parameters(), lr=learning_rate, betas=(0.5, 0.9))

    os.makedirs('train_generated_images_wgan', exist_ok=True)

    # Number of images generated per epoch for evaluation. The original code used
    # the size of the last training batch. With drop_last=True that equals the
    # loader's batch size, and this value stays defined even if the loader is empty.
    # NOTE(review): 64 samples give a very noisy Inception Score; consider more.
    num_eval_samples = getattr(data_loader, 'batch_size', None) or batch_size

    # `with` guarantees the log is closed even if training is interrupted.
    with open("inception_score_wgan.csv", "w") as inception_score_log:
        inception_score_log.write('epoch, inception_score \n')

        for epoch in tqdm(range(epochs)):
            for i, (images, _) in enumerate(data_loader):
                # Device-agnostic float cast (torch.cuda.FloatTensor fails on CPU-only machines).
                real_images = images.to(device=device, dtype=torch.float32)

                ### train discriminator (critic)
                optimizer_disc.zero_grad()
                noise_vector = torch.randn(real_images.shape[0], 100, device=device)
                fake_images = generator_model(noise_vector).detach()
                # Minimising -(D(real) - D(fake)) maximises the critic's Wasserstein estimate.
                loss_discriminator = -torch.mean(discriminator_model(real_images)) + torch.mean(discriminator_model(fake_images))
                loss_discriminator.backward()
                optimizer_disc.step()

                # Weight clipping: the original WGAN's way of keeping the critic Lipschitz.
                with torch.no_grad():
                    for p in discriminator_model.parameters():
                        p.clamp_(-weight_clipping_limit, weight_clipping_limit)

                # Train generator every n_critic iterations. The critic gradients this
                # backward pass accumulates are cleared by the next zero_grad() above.
                if i % n_critic == 0:
                    optimizer_gen.zero_grad()
                    generated_images = generator_model(noise_vector)
                    loss_generator = -torch.mean(discriminator_model(generated_images))
                    loss_generator.backward()
                    optimizer_gen.step()

            # Evaluation samples are never back-propagated, so skip graph construction.
            with torch.no_grad():
                test_noise = torch.randn(num_eval_samples, 100, device=device)
                sample_images = generator_model(test_noise).mul(0.5).add(0.5)  # [-1, 1] -> [0, 1]

            assert 0 <= sample_images.min() and sample_images.max() <= 1
            inception_score, inception_score_std = get_inception_score(sample_images)
            print(f"Epoch: {epoch}, Inception Score: {round(inception_score, 2)} ± {round(inception_score_std, 2)}")

            image_grid = utils.make_grid(sample_images[:64].cpu())
            utils.save_image(image_grid, f'train_generated_images_wgan/epoch_{epoch}.png')

            inception_score_log.write(f'{epoch}, {round(inception_score, 2)}\n')
            # Flush so partial results survive a crash or kernel interrupt.
            inception_score_log.flush()


In [7]:
# Train the WGAN on CIFAR-10, then persist both networks' weights (state_dicts).
print("training WGAN model...")
train_wgan(wgan_generator_instance, wgan_discriminator_instance, train_CIFAR10_dataloader)

print("saving WGAN model to file...")
for trained_net, checkpoint_path in (
    (wgan_generator_instance, 'wgan_generator.pkl'),
    (wgan_discriminator_instance, 'wgan_discriminator.pkl'),
):
    torch.save(trained_net.state_dict(), checkpoint_path)

training WGAN model...


  2%|▏         | 1/50 [00:18<14:55, 18.28s/it]

Epoch: 0, Inception Score: 1.89 ± 0.22


  4%|▍         | 2/50 [00:33<13:15, 16.58s/it]

Epoch: 1, Inception Score: 1.83 ± 0.16


  6%|▌         | 3/50 [00:49<12:37, 16.12s/it]

Epoch: 2, Inception Score: 1.89 ± 0.22


  8%|▊         | 4/50 [01:04<12:09, 15.85s/it]

Epoch: 3, Inception Score: 1.75 ± 0.22


 10%|█         | 5/50 [01:20<11:47, 15.71s/it]

Epoch: 4, Inception Score: 1.77 ± 0.2


 12%|█▏        | 6/50 [01:35<11:23, 15.54s/it]

Epoch: 5, Inception Score: 1.71 ± 0.17


 14%|█▍        | 7/50 [01:51<11:11, 15.61s/it]

Epoch: 6, Inception Score: 1.86 ± 0.19


 16%|█▌        | 8/50 [02:06<10:55, 15.60s/it]

Epoch: 7, Inception Score: 1.96 ± 0.25


 18%|█▊        | 9/50 [02:22<10:40, 15.61s/it]

Epoch: 8, Inception Score: 1.86 ± 0.17


 20%|██        | 10/50 [02:37<10:23, 15.59s/it]

Epoch: 9, Inception Score: 1.76 ± 0.23


 22%|██▏       | 11/50 [02:53<10:04, 15.50s/it]

Epoch: 10, Inception Score: 1.66 ± 0.14


 24%|██▍       | 12/50 [03:08<09:47, 15.46s/it]

Epoch: 11, Inception Score: 1.82 ± 0.28


 26%|██▌       | 13/50 [03:23<09:31, 15.45s/it]

Epoch: 12, Inception Score: 1.81 ± 0.2


 28%|██▊       | 14/50 [03:39<09:16, 15.47s/it]

Epoch: 13, Inception Score: 1.87 ± 0.23


 30%|███       | 15/50 [03:54<09:01, 15.47s/it]

Epoch: 14, Inception Score: 1.9 ± 0.27


 32%|███▏      | 16/50 [04:10<08:47, 15.51s/it]

Epoch: 15, Inception Score: 1.8 ± 0.25


 34%|███▍      | 17/50 [04:26<08:35, 15.63s/it]

Epoch: 16, Inception Score: 1.87 ± 0.22


 36%|███▌      | 18/50 [04:42<08:22, 15.70s/it]

Epoch: 17, Inception Score: 1.94 ± 0.32


 38%|███▊      | 19/50 [04:57<07:59, 15.46s/it]

Epoch: 18, Inception Score: 1.87 ± 0.25


 40%|████      | 20/50 [05:12<07:40, 15.36s/it]

Epoch: 19, Inception Score: 1.81 ± 0.3


 42%|████▏     | 21/50 [05:29<07:41, 15.91s/it]

Epoch: 20, Inception Score: 2.05 ± 0.22


 44%|████▍     | 22/50 [05:44<07:18, 15.68s/it]

Epoch: 21, Inception Score: 1.78 ± 0.15


 46%|████▌     | 23/50 [05:59<06:58, 15.51s/it]

Epoch: 22, Inception Score: 1.95 ± 0.26


 48%|████▊     | 24/50 [06:14<06:39, 15.38s/it]

Epoch: 23, Inception Score: 1.94 ± 0.28


 50%|█████     | 25/50 [06:29<06:21, 15.27s/it]

Epoch: 24, Inception Score: 2.07 ± 0.24


 52%|█████▏    | 26/50 [06:45<06:09, 15.41s/it]

Epoch: 25, Inception Score: 1.86 ± 0.27


 54%|█████▍    | 27/50 [07:00<05:51, 15.27s/it]

Epoch: 26, Inception Score: 2.04 ± 0.43


 56%|█████▌    | 28/50 [07:15<05:35, 15.25s/it]

Epoch: 27, Inception Score: 1.76 ± 0.24


 58%|█████▊    | 29/50 [07:30<05:19, 15.23s/it]

Epoch: 28, Inception Score: 1.8 ± 0.12


 60%|██████    | 30/50 [07:46<05:04, 15.24s/it]

Epoch: 29, Inception Score: 2.04 ± 0.39


 62%|██████▏   | 31/50 [08:01<04:48, 15.18s/it]

Epoch: 30, Inception Score: 1.95 ± 0.17


 64%|██████▍   | 32/50 [08:16<04:33, 15.20s/it]

Epoch: 31, Inception Score: 2.09 ± 0.22


 66%|██████▌   | 33/50 [08:31<04:16, 15.10s/it]

Epoch: 32, Inception Score: 1.92 ± 0.29


 68%|██████▊   | 34/50 [08:45<03:57, 14.84s/it]

Epoch: 33, Inception Score: 1.85 ± 0.15


 70%|███████   | 35/50 [09:00<03:44, 14.99s/it]

Epoch: 34, Inception Score: 1.79 ± 0.24


 72%|███████▏  | 36/50 [09:18<03:42, 15.87s/it]

Epoch: 35, Inception Score: 2.03 ± 0.42
Epoch: 36, Inception Score: 1.89 ± 0.33


 76%|███████▌  | 38/50 [09:53<03:21, 16.76s/it]

Epoch: 37, Inception Score: 2.01 ± 0.44
Epoch: 38, Inception Score: 1.85 ± 0.26


 80%|████████  | 40/50 [10:26<02:45, 16.50s/it]

Epoch: 39, Inception Score: 1.85 ± 0.2


 82%|████████▏ | 41/50 [10:42<02:28, 16.46s/it]

Epoch: 40, Inception Score: 1.88 ± 0.26


 84%|████████▍ | 42/50 [11:01<02:17, 17.14s/it]

Epoch: 41, Inception Score: 1.97 ± 0.32
Epoch: 42, Inception Score: 1.95 ± 0.26


 88%|████████▊ | 44/50 [11:34<01:40, 16.82s/it]

Epoch: 43, Inception Score: 1.79 ± 0.24


 90%|█████████ | 45/50 [11:52<01:25, 17.03s/it]

Epoch: 44, Inception Score: 2.05 ± 0.25


 92%|█████████▏| 46/50 [12:10<01:09, 17.29s/it]

Epoch: 45, Inception Score: 2.09 ± 0.33


 94%|█████████▍| 47/50 [12:27<00:51, 17.33s/it]

Epoch: 46, Inception Score: 1.84 ± 0.26


 96%|█████████▌| 48/50 [12:43<00:33, 16.93s/it]

Epoch: 47, Inception Score: 1.96 ± 0.38


 98%|█████████▊| 49/50 [13:00<00:17, 17.01s/it]

Epoch: 48, Inception Score: 1.9 ± 0.18


100%|██████████| 50/50 [13:17<00:00, 15.95s/it]

Epoch: 49, Inception Score: 2.04 ± 0.41
saving WGAN model to file...





In [15]:
def generate_sample_images(generator_model, num_images=None, output_path='wgan_generated_images.png'):
    """Sample images from ``generator_model`` and save them as one PNG grid.

    Args:
        generator_model: maps ``(N, 100)`` noise to images in [-1, 1].
        num_images: number of samples; defaults to the module-level ``batch_size``
            (64 -> an 8x8 grid, because the grid holds 8 images per row).
        output_path: file the grid is written to.
    """
    if num_images is None:
        num_images = batch_size
    images_per_row = 8

    # Draw the noise directly on the model's device. The original path
    # (Variable + torch.cuda.FloatTensor of float64 numpy noise) only worked on CUDA.
    model_device = next(generator_model.parameters()).device
    with torch.no_grad():  # sampling only; no gradients needed
        noise_vector = torch.randn(num_images, 100, device=model_device)
        generated_samples = generator_model(noise_vector)
    generated_samples = generated_samples.mul(0.5).add(0.5).cpu()  # [-1, 1] -> [0, 1]

    image_grid = utils.make_grid(generated_samples, nrow=images_per_row)
    utils.save_image(image_grid, output_path)

    # Report the real grid layout, and only after the file has been written.
    cols = min(images_per_row, num_images)
    rows = -(-num_images // cols)  # ceiling division
    print(f"Grid of {rows}x{cols} images saved to '{output_path}'.")

def load_trained_model(model_instance, model_path):
    """Load a ``state_dict`` checkpoint from ``model_path`` into ``model_instance`` in place.

    Tensors are mapped onto the device the model's parameters already live on.
    A checkpoint saved on a GPU machine therefore also loads on a CPU-only one.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints you trust.

    Returns:
        ``model_instance``, for convenience (callers may ignore it).
    """
    first_param = next(model_instance.parameters(), None)
    map_location = first_param.device if first_param is not None else torch.device('cpu')
    state_dict = torch.load(model_path, map_location=map_location)
    model_instance.load_state_dict(state_dict)
    return model_instance

# Restore both networks from the checkpoints written after training.
# Expects wgan_generator_instance / wgan_discriminator_instance to already exist.
print("Loading WGAN model...")
for net, checkpoint_path in (
    (wgan_generator_instance, 'wgan_generator.pkl'),
    (wgan_discriminator_instance, 'wgan_discriminator.pkl'),
):
    load_trained_model(net, checkpoint_path)

generate_sample_images(wgan_generator_instance)

Loading WGAN model...
Grid of 8x8 images saved to 'wgan_generated_images.png'.


In [16]:
def generate_best_sample_images(generator_model, discriminator_model, num_images=100,
                                top_k=10, nrow=5, output_path='wgan_best_images.png'):
    """Sample candidate images and save the ones the critic scores highest.

    Args:
        generator_model: maps ``(N, 100)`` noise to images in [-1, 1].
        discriminator_model: maps a batch of images to ``(N, 1)`` critic scores.
        num_images: number of candidate images to sample.
        top_k: how many of the highest-scoring candidates to keep (capped at ``num_images``).
        nrow: images per row in the saved grid.
        output_path: PNG file the grid is written to.

    Raises:
        ValueError: if ``num_images`` is not positive.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be positive, got {num_images}")
    top_k = min(top_k, num_images)
    model_device = next(generator_model.parameters()).device

    with torch.no_grad():  # inference only
        noise_vector = torch.randn(num_images, 100, device=model_device)
        generated_samples = generator_model(noise_vector)
        # Score the raw [-1, 1] generator output. The critic was trained on real
        # images normalised to [-1, 1]; rescaling to [0, 1] before scoring (as the
        # original did) fed it inputs from a different range.
        scores = discriminator_model(generated_samples).view(-1)

    # torch.topk returns the largest scores, in descending order.
    _, top_indices = torch.topk(scores, top_k)
    best_samples = generated_samples[top_indices].mul(0.5).add(0.5).cpu()  # [-1, 1] -> [0, 1] for saving

    image_grid = utils.make_grid(best_samples, nrow=nrow)
    utils.save_image(image_grid, output_path)
    # Report only after the file has been written.
    print(f"Grid of the {top_k} best images saved to '{output_path}'.")


In [17]:
# Reload both checkpoints, then keep only the samples the critic rates highest.
print("Loading WGAN model...")
for net, checkpoint_path in (
    (wgan_generator_instance, 'wgan_generator.pkl'),
    (wgan_discriminator_instance, 'wgan_discriminator.pkl'),
):
    load_trained_model(net, checkpoint_path)
generate_best_sample_images(wgan_generator_instance, wgan_discriminator_instance)

Loading WGAN model...
Grid of the 10 best images saved to 'wgan_best_images.png'.
