In [2]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [3]:
batch_size = 64
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Resize(32), 
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [4]:
train_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=True, download=True, transform=transform)
test_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=False, download=True, transform=transform)

train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)

print('#' * 40)
print("CIFAR10 dataloader Generated")


Files already downloaded and verified
Files already downloaded and verified
########################################
CIFAR10 dataloader Generated


In [5]:
class GeneratorACGAN(nn.Module):
    def __init__(self):
        super(GeneratorACGAN, self).__init__()
        self.embedding = nn.Embedding(10, 100)
        self.fully_connected = nn.Linear(100, 128 * 64)  # Simplified expression for 128 * 8 * 8
        self.generator_network = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        combined_input = torch.mul(self.embedding(labels), noise)
        transformed_input = self.fully_connected(combined_input)
        reshaped_input = transformed_input.view(transformed_input.shape[0], 128, 8, 8)
        output_image = self.generator_network(reshaped_input)
        return output_image

class DiscriminatorACGAN(nn.Module):
    def __init__(self):
        super(DiscriminatorACGAN, self).__init__()

        def discriminator_block(input_channels, output_channels, use_batchnorm=True):
            layers = [nn.Conv2d(input_channels, output_channels, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(output_channels, 0.8))
            return layers

        self.discriminator_network = nn.Sequential(
            *discriminator_block(3, 16, use_batchnorm=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        self.adversarial_layer = nn.Sequential(nn.Linear(128 * 4, 1), nn.Sigmoid())
        self.classification_layer = nn.Sequential(nn.Linear(128 * 4, 10), nn.Softmax(dim=1))

    def forward(self, image):
        processed_image = self.discriminator_network(image)
        flattened_output = processed_image.view(processed_image.shape[0], -1)
        validity_output = self.adversarial_layer(flattened_output)
        label_output = self.classification_layer(flattened_output)
        return validity_output, label_output

print("Instantiating ACGAN generator and discriminator...")
acgan_generator_instance = GeneratorACGAN().to(device)
acgan_discriminator_instance = DiscriminatorACGAN().to(device)
print("Models are set up and moved to the device.")


Instantiating ACGAN generator and discriminator...
Models are set up and moved to the device.


In [6]:
epochs = 50
learning_rate = 2e-4

def train_acgan(generator_model, discriminator_model, data_loader):
    source_loss = nn.BCELoss()
    class_loss = nn.NLLLoss()
    optimizer_gen = torch.optim.Adam(generator_model.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_disc = torch.optim.Adam(discriminator_model.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    if not os.path.exists('train_generated_images_acgan_real/'): 
        os.makedirs('train_generated_images_acgan_real')
    if not os.path.exists('train_generated_images_acgan_fake/'): 
        os.makedirs('train_generated_images_acgan_fake')
        
    inception_log_file = open("inception_score_acgan.csv", "w")
    inception_log_file.write('epoch, inception_score \n')

    for epoch in tqdm(range(epochs)): 
        for real_images, real_labels in data_loader:
            current_batch_size = real_images.shape[0]
            real_images = Variable(real_images.type(torch.cuda.FloatTensor)).to(device)
            real_labels = Variable(real_labels.type(torch.cuda.LongTensor)).to(device)

            valid_labels = torch.ones(current_batch_size).to(device)
            fake_labels = torch.zeros(current_batch_size).to(device)

            ### Train generator
            optimizer_gen.zero_grad()
            noise_vector = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (current_batch_size, 100))))
            random_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, current_batch_size)))

            generated_images = generator_model(noise_vector, random_labels)
            validity, predicted_labels = discriminator_model(generated_images)
            generator_loss = 0.5 * (source_loss(validity, valid_labels.unsqueeze(1)) + class_loss(predicted_labels, random_labels))
            generator_loss.backward()
            optimizer_gen.step()

            ### Train discriminator
            optimizer_disc.zero_grad()

            real_validity, real_predicted_labels = discriminator_model(real_images)
            loss_real = 0.5 * (source_loss(real_validity, valid_labels.unsqueeze(1)) + class_loss(real_predicted_labels, real_labels))

            fake_validity, fake_predicted_labels = discriminator_model(generated_images.detach())
            loss_fake = 0.5 * (source_loss(fake_validity, fake_labels.unsqueeze(1)) + class_loss(fake_predicted_labels, random_labels))

            discriminator_loss = 0.5 * (loss_real + loss_fake)
            discriminator_loss.backward()
            optimizer_disc.step()

        # Evaluate and save generated samples every epoch
        test_noise_vector = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (current_batch_size, 100))))
        test_random_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, current_batch_size)))
        samples = generator_model(test_noise_vector, test_random_labels)

        samples = samples.mul(0.5).add(0.5)
        
        assert 0 <= samples.min() and samples.max() <= 1
        inception_score, std_dev_inception = get_inception_score(samples)
        print(f"Epoch: {epoch}, Inception Score: {round(inception_score, 2)} ± {round(std_dev_inception, 2)}")

        samples = samples[:64].data.cpu()
        utils.save_image(samples, f'train_generated_images_acgan_fake/epoch_{epoch}.png')
        utils.save_image(real_images, f'train_generated_images_acgan_real/epoch_{epoch}.png')
        
        inception_log_file.write(f"{epoch}, {round(inception_score, 2)}\n")

    inception_log_file.close()


In [7]:
print("training ACGAN model...")
train_acgan(acgan_generator_instance, acgan_discriminator_instance, train_CIFAR10_dataloader)


print("saving ACGAN model to file...")
torch.save(acgan_generator_instance.state_dict(), 'acgan_generator.pkl')
torch.save(acgan_discriminator_instance.state_dict(), 'acgan_discriminator.pkl')

training ACGAN model...


  2%|▏         | 1/50 [00:23<19:17, 23.63s/it]

Epoch: 0, Inception Score: 1.48 ± 0.14


  4%|▍         | 2/50 [00:45<17:52, 22.34s/it]

Epoch: 1, Inception Score: 1.53 ± 0.19


  6%|▌         | 3/50 [01:06<17:17, 22.07s/it]

Epoch: 2, Inception Score: 1.49 ± 0.19


  8%|▊         | 4/50 [01:27<16:37, 21.68s/it]

Epoch: 3, Inception Score: 1.56 ± 0.11


 10%|█         | 5/50 [01:48<15:55, 21.24s/it]

Epoch: 4, Inception Score: 1.69 ± 0.12


 12%|█▏        | 6/50 [02:08<15:25, 21.03s/it]

Epoch: 5, Inception Score: 1.61 ± 0.2


 14%|█▍        | 7/50 [02:30<15:13, 21.25s/it]

Epoch: 6, Inception Score: 1.57 ± 0.14


 16%|█▌        | 8/50 [02:53<15:15, 21.80s/it]

Epoch: 7, Inception Score: 1.57 ± 0.13


 18%|█▊        | 9/50 [03:16<15:11, 22.23s/it]

Epoch: 8, Inception Score: 1.63 ± 0.16


 20%|██        | 10/50 [03:39<14:59, 22.48s/it]

Epoch: 9, Inception Score: 1.66 ± 0.17


 22%|██▏       | 11/50 [04:03<14:45, 22.70s/it]

Epoch: 10, Inception Score: 1.72 ± 0.14


 24%|██▍       | 12/50 [04:25<14:15, 22.50s/it]

Epoch: 11, Inception Score: 1.7 ± 0.17


 26%|██▌       | 13/50 [04:46<13:41, 22.20s/it]

Epoch: 12, Inception Score: 1.69 ± 0.19


 28%|██▊       | 14/50 [05:07<13:07, 21.88s/it]

Epoch: 13, Inception Score: 1.72 ± 0.19


 30%|███       | 15/50 [05:28<12:31, 21.46s/it]

Epoch: 14, Inception Score: 1.61 ± 0.11


 32%|███▏      | 16/50 [05:48<12:01, 21.22s/it]

Epoch: 15, Inception Score: 1.56 ± 0.05


 34%|███▍      | 17/50 [06:09<11:33, 21.01s/it]

Epoch: 16, Inception Score: 1.78 ± 0.25


 36%|███▌      | 18/50 [06:30<11:08, 20.88s/it]

Epoch: 17, Inception Score: 1.72 ± 0.19


 38%|███▊      | 19/50 [06:50<10:44, 20.80s/it]

Epoch: 18, Inception Score: 1.86 ± 0.32


 40%|████      | 20/50 [07:13<10:46, 21.54s/it]

Epoch: 19, Inception Score: 1.9 ± 0.35


 42%|████▏     | 21/50 [07:37<10:39, 22.05s/it]

Epoch: 20, Inception Score: 1.71 ± 0.18


 44%|████▍     | 22/50 [08:00<10:26, 22.39s/it]

Epoch: 21, Inception Score: 1.98 ± 0.29


 46%|████▌     | 23/50 [08:23<10:12, 22.70s/it]

Epoch: 22, Inception Score: 1.88 ± 0.21


 48%|████▊     | 24/50 [08:47<10:00, 23.08s/it]

Epoch: 23, Inception Score: 1.74 ± 0.29


 50%|█████     | 25/50 [09:11<09:42, 23.31s/it]

Epoch: 24, Inception Score: 1.75 ± 0.2


 52%|█████▏    | 26/50 [09:35<09:23, 23.47s/it]

Epoch: 25, Inception Score: 1.93 ± 0.34


 54%|█████▍    | 27/50 [09:59<09:01, 23.56s/it]

Epoch: 26, Inception Score: 1.76 ± 0.17


 56%|█████▌    | 28/50 [10:22<08:39, 23.62s/it]

Epoch: 27, Inception Score: 1.76 ± 0.22


 58%|█████▊    | 29/50 [10:46<08:17, 23.70s/it]

Epoch: 28, Inception Score: 1.95 ± 0.35


 60%|██████    | 30/50 [11:10<07:56, 23.82s/it]

Epoch: 29, Inception Score: 1.78 ± 0.3


 62%|██████▏   | 31/50 [11:34<07:32, 23.84s/it]

Epoch: 30, Inception Score: 1.79 ± 0.14


 64%|██████▍   | 32/50 [11:58<07:09, 23.85s/it]

Epoch: 31, Inception Score: 2.01 ± 0.28


 66%|██████▌   | 33/50 [12:22<06:43, 23.72s/it]

Epoch: 32, Inception Score: 1.88 ± 0.16


 68%|██████▊   | 34/50 [12:46<06:20, 23.80s/it]

Epoch: 33, Inception Score: 1.92 ± 0.22


 70%|███████   | 35/50 [13:10<05:57, 23.84s/it]

Epoch: 34, Inception Score: 1.85 ± 0.16


 72%|███████▏  | 36/50 [13:33<05:33, 23.84s/it]

Epoch: 35, Inception Score: 2.0 ± 0.27


 74%|███████▍  | 37/50 [13:57<05:10, 23.86s/it]

Epoch: 36, Inception Score: 1.94 ± 0.31


 76%|███████▌  | 38/50 [14:21<04:45, 23.82s/it]

Epoch: 37, Inception Score: 1.93 ± 0.24


 78%|███████▊  | 39/50 [14:45<04:22, 23.82s/it]

Epoch: 38, Inception Score: 2.01 ± 0.29


 80%|████████  | 40/50 [15:09<03:58, 23.83s/it]

Epoch: 39, Inception Score: 1.82 ± 0.25


 82%|████████▏ | 41/50 [15:33<03:35, 23.95s/it]

Epoch: 40, Inception Score: 1.97 ± 0.24


 84%|████████▍ | 42/50 [15:57<03:11, 23.99s/it]

Epoch: 41, Inception Score: 1.94 ± 0.32


 86%|████████▌ | 43/50 [16:21<02:47, 23.91s/it]

Epoch: 42, Inception Score: 2.02 ± 0.53


 88%|████████▊ | 44/50 [16:44<02:23, 23.86s/it]

Epoch: 43, Inception Score: 2.09 ± 0.34


 90%|█████████ | 45/50 [17:08<01:59, 23.83s/it]

Epoch: 44, Inception Score: 1.94 ± 0.34


 92%|█████████▏| 46/50 [17:32<01:35, 23.78s/it]

Epoch: 45, Inception Score: 1.99 ± 0.25


 94%|█████████▍| 47/50 [17:56<01:11, 23.75s/it]

Epoch: 46, Inception Score: 1.87 ± 0.22


 96%|█████████▌| 48/50 [18:19<00:47, 23.76s/it]

Epoch: 47, Inception Score: 2.01 ± 0.26


 98%|█████████▊| 49/50 [18:43<00:23, 23.76s/it]

Epoch: 48, Inception Score: 1.85 ± 0.21


100%|██████████| 50/50 [19:07<00:00, 22.95s/it]

Epoch: 49, Inception Score: 1.95 ± 0.23
saving ACGAN model to file...





In [8]:
def generate_sample_images_acgan(generator_model):
    noise_vector = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, 100))))
    random_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, 10, batch_size)))
    generated_samples = generator_model(noise_vector, random_labels)
    
    # Normalize the images to [0, 1]
    generated_samples = generated_samples.mul(0.5).add(0.5)
    generated_samples = generated_samples.data.cpu()
    image_grid = utils.make_grid(generated_samples)
    print("Grid of 8x8 images saved to 'acgan_generated_images.png'.")
    utils.save_image(image_grid, 'acgan_generated_images.png')


generate_sample_images_acgan(acgan_generator_instance)

Grid of 8x8 images saved to 'acgan_generated_images.png'.


In [9]:
def load_trained_model_acgan(model_instance, model_path):
    model_instance.load_state_dict(torch.load(model_path))

# Load trained ACGAN models and generate sample images
print("Loading ACGAN model...")
load_trained_model_acgan(acgan_generator_instance, 'acgan_generator.pkl')
load_trained_model_acgan(acgan_discriminator_instance, 'acgan_discriminator.pkl')


Loading ACGAN model...


In [7]:
!pwd

/home/xiaofey/CPSC-8430/Homework#4
