In [1]:
import torch
import torchvision
from torchvision import utils
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
from pytorch_gan_metrics import get_inception_score
from tqdm import tqdm
import os
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
batch_size = 64

# Seed the global RNG so weight init, data shuffling and noise sampling are repeatable
# (up to nondeterministic GPU kernels). torch.manual_seed seeds CPU and all CUDA devices.
SEED = 0
torch.manual_seed(SEED)

# ToTensor: image -> float tensor in [0, 1]; Normalize(0.5, 0.5) then maps every channel to
# [-1, 1], the same range as the generator's Tanh output, so real and fake inputs to the
# discriminator share a scale.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(32),  # no-op for CIFAR-10's 32x32 images
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [3]:

train_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=True, download=True, transform=transform)
test_CIFAR10_set = torchvision.datasets.CIFAR10(root='./cifar10/', train=False, download=True, transform=transform)

# Only the training data benefits from shuffling; evaluation order is kept deterministic.
# drop_last=True keeps every batch at exactly `batch_size` images (the final partial batch is skipped).
train_CIFAR10_dataloader = DataLoader(train_CIFAR10_set, batch_size=batch_size, shuffle=True, drop_last=True)
test_CIFAR10_dataloader = DataLoader(test_CIFAR10_set, batch_size=batch_size, shuffle=False, drop_last=True)

print('#' * 40)
print("CIFAR10 dataloader Generated")


Files already downloaded and verified
Files already downloaded and verified
########################################
CIFAR10 dataloader Generated


In [4]:
class Generator_DCGAN(nn.Module):
    """DCGAN generator: (N, 100, 1, 1) noise -> (N, 3, 32, 32) image in [-1, 1].

    Four 4x4 transposed convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 32.
    Every stage but the last is followed by BatchNorm and LeakyReLU(0.2); the
    last one is followed by Tanh.
    """

    # (in_channels, out_channels, stride, padding) of each BatchNorm'd upsampling stage.
    _HIDDEN_STAGES = ((100, 1024, 1, 0), (1024, 512, 2, 1), (512, 256, 2, 1))

    def __init__(self):
        super().__init__()
        layers = []
        for c_in, c_out, stride, pad in self._HIDDEN_STAGES:
            layers.extend((
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=stride, padding=pad),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        # Output stage: project to RGB and squash into [-1, 1].
        layers.extend((nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=1), nn.Tanh()))
        self.network = nn.Sequential(*layers)

    def forward(self, input_tensor):
        images = self.network(input_tensor)
        return images

class Discriminator_DCGAN(nn.Module):
    """DCGAN discriminator: (N, 3, 32, 32) image -> (N, 1, 1, 1) probability of being real.

    Three stride-2 4x4 convolutions shrink 32 -> 16 -> 8 -> 4 (LeakyReLU(0.2) after
    each, BatchNorm on all but the first); a final 4x4 convolution collapses the
    4x4 map to one logit, which Sigmoid turns into a probability.
    """

    def __init__(self):
        super().__init__()
        # Input stem has no BatchNorm.
        layers = [
            nn.Conv2d(3, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((256, 512), (512, 1024)):
            layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        layers.extend((nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0), nn.Sigmoid()))
        self.network = nn.Sequential(*layers)

    def forward(self, input_tensor):
        probabilities = self.network(input_tensor)
        return probabilities

print("Instantiating DCGAN generator and discriminator...")
# nn.Module.to moves parameters in place and returns the module itself, so the calls can be chained.
dcgan_generator = Generator_DCGAN().to(device)
dcgan_discriminator = Discriminator_DCGAN().to(device)
print("Models moved to device.")



Instantiating DCGAN generator and discriminator...
Models moved to device.


In [7]:
learning_rate = 0.0002
epochs = 50


def train(generator_model, discriminator_model, data_loader, latent_dim=100):
    """Train a DCGAN with the non-saturating BCE objective and log Inception Scores.

    Each iteration performs one discriminator update (real images labelled 1,
    generated images labelled 0) followed by one generator update (generated
    images labelled 1). After every epoch, 800 samples are drawn, their
    Inception Score is printed and appended to ``inception_score_dcgan.csv``,
    and a grid of the first 64 samples is written to
    ``train_generated_images_dcgan/epoch_{epoch}.png``.

    Args:
        generator_model: module mapping (N, latent_dim, 1, 1) noise to images in [-1, 1].
        discriminator_model: module mapping images to probabilities; its output
            must flatten to one value per image.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        latent_dim: channel count of the noise input (100 for ``Generator_DCGAN``).

    Reads the module-level ``learning_rate``, ``epochs`` and ``device``.
    """
    bce_loss = nn.BCELoss()
    optimizer_gen = torch.optim.Adam(generator_model.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_disc = torch.optim.Adam(discriminator_model.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    os.makedirs('train_generated_images_dcgan', exist_ok=True)

    # `with` guarantees the log is closed (and flushed) even if training is interrupted.
    with open("inception_score_dcgan.csv", "w") as score_log_file:
        score_log_file.write('epoch, inception_score \n')

        for epoch in tqdm(range(epochs)):
            for real_images_batch, _ in data_loader:
                real_images_batch = real_images_batch.to(device)
                # Size noise and labels from the actual batch so a short final batch
                # (a loader without drop_last=True) cannot cause a BCE shape mismatch.
                current_batch_size = real_images_batch.size(0)
                labels_real = torch.ones(current_batch_size, device=device)
                labels_fake = torch.zeros(current_batch_size, device=device)

                ### train discriminator
                # compute loss using real images
                preds_real = discriminator_model(real_images_batch)
                loss_disc_real = bce_loss(preds_real.flatten(), labels_real)

                # compute loss using fake images; detach() stops this loss from
                # backpropagating into the generator, which is not updated here.
                noise_vector = torch.randn(current_batch_size, latent_dim, 1, 1, device=device)
                fake_images_batch = generator_model(noise_vector).detach()
                preds_fake = discriminator_model(fake_images_batch)
                loss_disc_fake = bce_loss(preds_fake.flatten(), labels_fake)

                # optimize discriminator
                loss_disc_total = loss_disc_real + loss_disc_fake
                optimizer_disc.zero_grad()
                loss_disc_total.backward()
                optimizer_disc.step()

                ### train generator
                # fresh noise; generator wants the discriminator to call its output real
                noise_vector = torch.randn(current_batch_size, latent_dim, 1, 1, device=device)
                fake_images_batch = generator_model(noise_vector)
                preds_gen = discriminator_model(fake_images_batch)
                loss_gen = bce_loss(preds_gen.flatten(), labels_real)

                # optimize generator (this backward also touches discriminator grads,
                # which optimizer_disc.zero_grad() clears before the next D step)
                optimizer_gen.zero_grad()
                loss_gen.backward()
                optimizer_gen.step()

            # compute inception score and samples every epoch; no autograd graph is
            # needed for evaluation, so skip building one for the 800 samples.
            with torch.no_grad():
                noise_sample = torch.randn(800, latent_dim, 1, 1, device=device)
                samples_generated = generator_model(noise_sample)
                samples_generated = samples_generated.mul(0.5).add(0.5)  # [-1, 1] -> [0, 1]

            assert 0 <= samples_generated.min() and samples_generated.max() <= 1
            inception_score, inception_std_dev = get_inception_score(samples_generated)
            print(f"Epoch: {epoch}, Inception Score: {round(inception_score, 2)} ± {round(inception_std_dev, 2)}")

            grid_image = utils.make_grid(samples_generated[:64].cpu())
            utils.save_image(grid_image, f'train_generated_images_dcgan/epoch_{epoch}.png')

            score_log_file.write(f'{epoch}, {round(inception_score, 2)}\n')
            score_log_file.flush()  # keep the CSV current during a long run


In [8]:
# Train the DCGAN on the CIFAR-10 training split, then persist both networks' weights.
print("training DCGAN model...")
train(dcgan_generator, dcgan_discriminator, train_CIFAR10_dataloader)

print("saving DCGAN model to file...")
checkpoints = (
    (dcgan_generator, 'dcgan_generator.pkl'),
    (dcgan_discriminator, 'dcgan_discriminator.pkl'),
)
for module_to_save, checkpoint_path in checkpoints:
    torch.save(module_to_save.state_dict(), checkpoint_path)

training DCGAN model...


  2%|▏         | 1/50 [01:12<59:17, 72.60s/it]

Epoch: 0, Inception Score: 2.32 ± 0.12


  4%|▍         | 2/50 [02:25<58:00, 72.51s/it]

Epoch: 1, Inception Score: 2.54 ± 0.18


  6%|▌         | 3/50 [03:37<56:47, 72.49s/it]

Epoch: 2, Inception Score: 2.51 ± 0.11


  8%|▊         | 4/50 [04:49<55:34, 72.48s/it]

Epoch: 3, Inception Score: 2.73 ± 0.2


 10%|█         | 5/50 [06:02<54:21, 72.47s/it]

Epoch: 4, Inception Score: 3.18 ± 0.14


 12%|█▏        | 6/50 [07:14<53:08, 72.46s/it]

Epoch: 5, Inception Score: 3.47 ± 0.25


 14%|█▍        | 7/50 [08:27<51:55, 72.46s/it]

Epoch: 6, Inception Score: 3.26 ± 0.18


 16%|█▌        | 8/50 [09:39<50:43, 72.46s/it]

Epoch: 7, Inception Score: 3.28 ± 0.28


 18%|█▊        | 9/50 [10:52<49:30, 72.45s/it]

Epoch: 8, Inception Score: 3.68 ± 0.27


 20%|██        | 10/50 [12:04<48:17, 72.45s/it]

Epoch: 9, Inception Score: 3.95 ± 0.36


 22%|██▏       | 11/50 [13:17<47:05, 72.44s/it]

Epoch: 10, Inception Score: 3.75 ± 0.22


 24%|██▍       | 12/50 [14:29<45:53, 72.45s/it]

Epoch: 11, Inception Score: 3.59 ± 0.21


 26%|██▌       | 13/50 [15:42<44:40, 72.45s/it]

Epoch: 12, Inception Score: 3.56 ± 0.22


 28%|██▊       | 14/50 [16:54<43:28, 72.45s/it]

Epoch: 13, Inception Score: 3.67 ± 0.34


 30%|███       | 15/50 [18:06<42:15, 72.44s/it]

Epoch: 14, Inception Score: 4.06 ± 0.24


 32%|███▏      | 16/50 [19:19<41:03, 72.45s/it]

Epoch: 15, Inception Score: 4.04 ± 0.26


 34%|███▍      | 17/50 [20:31<39:50, 72.45s/it]

Epoch: 16, Inception Score: 4.34 ± 0.4


 36%|███▌      | 18/50 [21:44<38:39, 72.48s/it]

Epoch: 17, Inception Score: 4.25 ± 0.28


 38%|███▊      | 19/50 [22:56<37:27, 72.50s/it]

Epoch: 18, Inception Score: 4.19 ± 0.36


 40%|████      | 20/50 [24:09<36:14, 72.48s/it]

Epoch: 19, Inception Score: 4.62 ± 0.23


 42%|████▏     | 21/50 [25:21<35:01, 72.47s/it]

Epoch: 20, Inception Score: 4.24 ± 0.35


 44%|████▍     | 22/50 [26:34<33:48, 72.46s/it]

Epoch: 21, Inception Score: 4.49 ± 0.21


 46%|████▌     | 23/50 [27:46<32:36, 72.45s/it]

Epoch: 22, Inception Score: 4.69 ± 0.39


 48%|████▊     | 24/50 [28:59<31:23, 72.45s/it]

Epoch: 23, Inception Score: 4.47 ± 0.3


 50%|█████     | 25/50 [30:11<30:11, 72.46s/it]

Epoch: 24, Inception Score: 4.87 ± 0.44


 52%|█████▏    | 26/50 [31:24<28:58, 72.45s/it]

Epoch: 25, Inception Score: 4.93 ± 0.43


 54%|█████▍    | 27/50 [32:36<27:46, 72.45s/it]

Epoch: 26, Inception Score: 5.07 ± 0.41


 56%|█████▌    | 28/50 [33:48<26:33, 72.45s/it]

Epoch: 27, Inception Score: 5.02 ± 0.44


 58%|█████▊    | 29/50 [35:01<25:21, 72.44s/it]

Epoch: 28, Inception Score: 5.12 ± 0.6


 60%|██████    | 30/50 [36:13<24:08, 72.45s/it]

Epoch: 29, Inception Score: 4.94 ± 0.37


 62%|██████▏   | 31/50 [37:26<22:56, 72.44s/it]

Epoch: 30, Inception Score: 4.82 ± 0.39


 64%|██████▍   | 32/50 [38:38<21:43, 72.43s/it]

Epoch: 31, Inception Score: 5.12 ± 0.29


 66%|██████▌   | 33/50 [39:51<20:31, 72.42s/it]

Epoch: 32, Inception Score: 4.96 ± 0.4


 68%|██████▊   | 34/50 [41:03<19:19, 72.45s/it]

Epoch: 33, Inception Score: 4.98 ± 0.5


 70%|███████   | 35/50 [42:15<18:06, 72.42s/it]

Epoch: 34, Inception Score: 5.14 ± 0.24


 72%|███████▏  | 36/50 [43:28<16:53, 72.41s/it]

Epoch: 35, Inception Score: 5.32 ± 0.61


 74%|███████▍  | 37/50 [44:40<15:41, 72.41s/it]

Epoch: 36, Inception Score: 5.19 ± 0.36


 76%|███████▌  | 38/50 [45:53<14:28, 72.39s/it]

Epoch: 37, Inception Score: 5.13 ± 0.38


 78%|███████▊  | 39/50 [47:05<13:16, 72.38s/it]

Epoch: 38, Inception Score: 5.08 ± 0.55


 80%|████████  | 40/50 [48:17<12:03, 72.38s/it]

Epoch: 39, Inception Score: 5.15 ± 0.43


 82%|████████▏ | 41/50 [49:30<10:51, 72.38s/it]

Epoch: 40, Inception Score: 5.38 ± 0.58


 84%|████████▍ | 42/50 [50:42<09:38, 72.37s/it]

Epoch: 41, Inception Score: 5.2 ± 0.37


 86%|████████▌ | 43/50 [51:54<08:26, 72.37s/it]

Epoch: 42, Inception Score: 5.04 ± 0.3


 88%|████████▊ | 44/50 [53:07<07:14, 72.36s/it]

Epoch: 43, Inception Score: 5.13 ± 0.42


 90%|█████████ | 45/50 [54:19<06:01, 72.35s/it]

Epoch: 44, Inception Score: 5.37 ± 0.3


 92%|█████████▏| 46/50 [55:31<04:49, 72.36s/it]

Epoch: 45, Inception Score: 5.49 ± 0.26


 94%|█████████▍| 47/50 [56:44<03:37, 72.36s/it]

Epoch: 46, Inception Score: 5.09 ± 0.28


 96%|█████████▌| 48/50 [57:56<02:24, 72.37s/it]

Epoch: 47, Inception Score: 5.34 ± 0.3


 98%|█████████▊| 49/50 [59:09<01:12, 72.39s/it]

Epoch: 48, Inception Score: 4.96 ± 0.36


100%|██████████| 50/50 [1:00:21<00:00, 72.43s/it]

Epoch: 49, Inception Score: 5.3 ± 0.43
saving DCGAN model to file...





In [12]:
def generate_best_sample_images(generator_model, discriminator_model, num_images=100, num_best=10):
    """Sample images, rank them by discriminator confidence and save the best as a grid.

    Draws ``num_images`` noise vectors, generates images, scores each one with the
    discriminator, keeps the ``num_best`` highest-scoring images and writes them to
    ``dcgan_best_images.png`` arranged 5 per row.

    Args:
        generator_model: maps (N, 100, 1, 1) noise to images in [-1, 1].
        discriminator_model: maps images to one probability per image.
        num_images: number of candidate images to sample.
        num_best: how many top-scoring images to keep (capped at ``num_images``).
    """
    # topk raises if k exceeds the number of candidates.
    num_best = min(num_best, num_images)

    with torch.no_grad():
        noise_vector = torch.randn(num_images, 100, 1, 1, device=device)
        raw_samples = generator_model(noise_vector)

        # Score the raw [-1, 1] output: during training the discriminator only ever saw
        # [-1, 1] inputs (normalized real images and raw generator output), so rescaling
        # to [0, 1] before scoring would rank images on out-of-distribution inputs.
        scores = discriminator_model(raw_samples).view(-1)
        _, top_indices = torch.topk(scores, num_best)

        # Only the saved copies are mapped to [0, 1] for image writing.
        best_samples = raw_samples[top_indices].mul(0.5).add(0.5).cpu()

    image_grid = utils.make_grid(best_samples, nrow=5)  # Arrange grid with 5 images per row
    utils.save_image(image_grid, 'dcgan_best_images.png')
    print(f"Grid of the {num_best} best images saved to 'dcgan_best_images.png'.")
def load_trained_model(model, model_filename, map_location=None):
    """Restore ``model``'s parameters from a file written by ``torch.save(model.state_dict(), ...)``.

    Args:
        model: module whose architecture matches the saved state_dict.
        model_filename: path to the saved state_dict.
        map_location: where to place the loaded tensors; defaults to the notebook's
            ``device`` so a checkpoint saved on GPU also loads on a CPU-only machine.

    NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    """
    if map_location is None:
        map_location = device
    model.load_state_dict(torch.load(model_filename, map_location=map_location))
print("Loading DCGAN model...")
# Restore both networks from their checkpoints, then render the discriminator's favourite samples.
for trained_module, checkpoint_path in ((dcgan_generator, 'dcgan_generator.pkl'),
                                        (dcgan_discriminator, 'dcgan_discriminator.pkl')):
    load_trained_model(trained_module, checkpoint_path)
generate_best_sample_images(dcgan_generator, dcgan_discriminator)

Loading DCGAN model...
Grid of the 10 best images saved to 'dcgan_best_images.png'.
