In [1]:
import os

import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt


In [2]:
# Get CIFAR10 Dataset
# Images are read from ./data (download=False, so the files must already be
# there), resized to 64, converted to tensors and normalized with mean 0.5 /
# std 0.5 per channel.
data_set = datasets.CIFAR10(
    root="./data",
    download=False,
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
data_loader = torch.utils.data.DataLoader(
    data_set,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Custom weight initialization, applied to the generator and the discriminator
def weights_init(m):
    """Re-initialize a single module in place (meant for ``Module.apply``).

    The module is selected by a substring match on its class name:

    * name contains ``Conv``      -> weight drawn from N(0.0, 0.02)
    * name contains ``BatchNorm`` -> weight drawn from N(1.0, 0.02), bias set to 0

    Any other module is left untouched. Matching modules that carry no
    learnable ``weight``/``bias`` (for example a BatchNorm layer built with
    ``affine=False``) are skipped instead of raising ``AttributeError``.

    Args:
        m: the module visited by ``Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

In [4]:
class Discriminator(nn.Module):
    """Convolutional discriminator with a real/fake head and a class head.

    ``main`` is four Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stages
    (3 -> 64 -> 128 -> 256 -> 512 channels, kernel 4, stride 2, padding 1).
    Its features feed two bias-free 4x4 convolution heads:

    * ``verify``: 1 channel followed by a sigmoid
    * ``labels``: 11 channels followed by a log-softmax over the channel axis

    For 64x64 inputs the heads reduce to 1x1, so ``forward`` yields one score
    and one 11-way log-probability vector per image.
    """

    def __init__(self):
        super().__init__()

        # Build the shared feature extractor stage by stage.
        widths = (3, 64, 128, 256, 512)
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.main = nn.Sequential(*stages)

        # Real/fake score head.
        self.verify = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

        # Class head (11 output channels).
        self.labels = nn.Sequential(
            nn.Conv2d(512, 11, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, passed_input):
        """Return ``(validity, output_labels)`` for a batch of images.

        ``validity`` is the sigmoid head flattened to 1-D; ``output_labels``
        is the log-softmax head reshaped to ``(-1, 11)``.
        """
        features = self.main(passed_input)
        validity = self.verify(features).view(-1)
        output_labels = self.labels(features).view(-1, 11)
        return validity, output_labels

In [5]:
class Generator(nn.Module):
    """Class-conditional generator built from transposed convolutions.

    A label is looked up in a 10 x 100 embedding table and multiplied
    element-wise with the noise vector; the product is reshaped to
    ``(-1, 100, 1, 1)`` and passed through ``main``:

    * ConvTranspose2d(100 -> 512, kernel 4, stride 1) + ReLU (no BatchNorm here)
    * three ConvTranspose2d + BatchNorm2d + ReLU stages (512 -> 256 -> 128 -> 64)
    * ConvTranspose2d(64 -> 3) + Tanh
    """

    def __init__(self):
        super().__init__()

        self.emb = nn.Embedding(10, 100)

        # Projection of the 100-d conditioned vector onto a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
        ]
        # Up-sampling stages with batch normalization.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Final stage maps to 3 channels squashed by tanh.
        layers.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, noise, inputLabels):
        """Generate images from ``noise`` conditioned on ``inputLabels``.

        ``inputLabels`` must hold integer indices valid for the 10-row
        embedding table; ``noise`` must be broadcastable against the
        100-d embeddings.
        """
        conditioned = noise * self.emb(inputLabels)
        return self.main(conditioned.view(-1, 100, 1, 1))

In [6]:
# Create Generator and Discriminator and apply initial weights
# Both networks are constructed and moved to `device` first; weights_init is
# then run over every sub-module of each network through Module.apply.
discriminator = Discriminator().to(device)
generator = Generator().to(device)
discriminator.apply(weights_init)
# Last expression of the cell, so the returned generator's repr is displayed.
generator.apply(weights_init)

Generator(
  (emb): Embedding(10, 100)
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (12): Tanh()
  )
)

In [7]:
# Setup optimizers
# One Adam optimizer per network (learning rate 0.0002, betas 0.5 / 0.999),
# plus binary cross-entropy for the real/fake head.
dis_optim = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
gen_optim = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Parameters for training
num_epochs = 10
# Pools of 10 soft targets each, indexed by batch number in the training loop:
# "real" targets lie in [0.7, 1.2), "fake" targets in [0, 0.3).
real_labels_tensor = torch.rand(10, device=device) * 0.5 + 0.7
fake_labels_tensor = torch.rand(10, device=device) * 0.3

# Variables to track training progress
counter = 0
counter_list, gen_loss_list, dis_loss_list = [], [], []

In [None]:
# Training algorithm for discriminator and generator
# Every batch runs three passes:
#   1. discriminator on real images (soft "real" target + true class labels)
#   2. discriminator on generated images (soft "fake" target + class index 10),
#      with the generator output detached so only the discriminator is updated
#   3. generator, judged by the discriminator against target 1 and the class
#      labels it was conditioned on
# After each epoch, one sample per class 0-9 is written to ACGANOutput/.

# Make sure the output folder exists before any image is written to it.
os.makedirs('ACGANOutput', exist_ok=True)

for epoch in range(num_epochs):

    # Iterate through all batches
    for index, (data, image_labels) in enumerate(data_loader):
        # Make the batch available on the training device
        data = data.to(device)
        image_labels = image_labels.to(device)
        size_of_batch = data.size(0)

        # Soft real/fake targets, cycled from the pre-drawn pools of 10 values
        labels_real = real_labels_tensor[index % 10]
        labels_fake = fake_labels_tensor[index % 10]
        # Class index 10 is the extra class the discriminator assigns to generated images
        class_labels_fake = 10 * torch.ones((size_of_batch, ), dtype=torch.long, device=device)

        # Periodically switch labels (every 25th batch uses swapped targets)
        if index % 25 == 0:
            labels_real, labels_fake = labels_fake, labels_real

        # Train Discriminator with real data
        labels_for_validate = torch.full((size_of_batch, ), labels_real, device=device)
        dis_optim.zero_grad()
        validity, output_labels = discriminator(data)
        dis_real_valid_error = criterion(validity, labels_for_validate)
        dis_real_label_error = F.nll_loss(output_labels, image_labels)
        dis_real_error = dis_real_valid_error + dis_real_label_error
        dis_real_error.backward()
        valid_mean1 = validity.mean().item()

        # Train Discriminator with fake data
        dis_fake_labels = torch.randint(0, 10, (size_of_batch, ), dtype=torch.long, device=device)
        noise = torch.randn(size_of_batch, 100, device=device)
        labels_for_validate.fill_(labels_fake)
        fake_output = generator(noise, dis_fake_labels)
        # detach(): gradients of this pass must not reach the generator
        validity, output_labels = discriminator(fake_output.detach())
        dis_fake_valid_error = criterion(validity, labels_for_validate)
        dis_fake_label_error = F.nll_loss(output_labels, class_labels_fake)
        dis_fake_error = dis_fake_valid_error + dis_fake_label_error
        dis_fake_error.backward()
        final_dis_error = dis_real_error + dis_fake_error
        valid_mean2 = validity.mean().item()
        # Gradients of the real and fake passes were accumulated; apply them together
        dis_optim.step()

        # Train Generator
        labels_for_validate.fill_(1)
        labels_for_gen = torch.randint(0, 10, (size_of_batch, ), device=device, dtype=torch.long)
        noise = torch.randn(size_of_batch, 100, device=device)
        gen_optim.zero_grad()
        fake_output = generator(noise, labels_for_gen)
        validity, output_labels = discriminator(fake_output)
        gen_valid_error = criterion(validity, labels_for_validate)
        gen_label_error = F.nll_loss(output_labels, labels_for_gen)
        final_gen_error = gen_valid_error + gen_label_error
        final_gen_error.backward()
        valid_mean3 = validity.mean().item()
        gen_optim.step()

        print("[{}/{}] [{}/{}] D(x): [{:.4f}] D(G): [{:.4f}/{:.4f}] GLoss: [{:.4f}] DLoss: [{:.4f}] DLabel: [{:.4f}] "
              .format(epoch, num_epochs, index, len(data_loader), valid_mean1, valid_mean2, valid_mean3, final_gen_error, final_dis_error,
                      dis_real_label_error + dis_fake_label_error + gen_label_error))

        # Save errors for graphing. The step counter is recorded together with
        # the losses so the three lists always have the same length, even if a
        # step fails part-way through.
        counter += 1
        counter_list.append(counter)
        dis_loss_list.append(final_dis_error.item())
        gen_loss_list.append(final_gen_error.item())

    # Save images to folder: one generated sample for each class 0-9
    labels = torch.arange(0, 10, dtype=torch.long, device=device)
    noise = torch.randn(10, 100, device=device)
    with torch.no_grad():  # sampling only, no autograd graph needed
        images = generator(noise, labels)
    vutils.save_image(images, 'ACGANOutput/fake_samples_epoch_%03d.png' % (epoch), normalize=True)

[0/10] [0/391] D(x): [0.2531] D(G): [0.8643/0.8847] GLoss: [6.1573] DLoss: [4.1448] DLabel: [8.9870] 
[0/10] [1/391] D(x): [0.3341] D(G): [0.7934/0.3802] GLoss: [7.8041] DLoss: [6.6059] DLabel: [10.5003] 
[0/10] [2/391] D(x): [0.3668] D(G): [0.4489/0.3014] GLoss: [9.9413] DLoss: [5.0902] DLabel: [11.8094] 
[0/10] [3/391] D(x): [0.6655] D(G): [0.5718/0.0997] GLoss: [9.8719] DLoss: [4.3452] DLabel: [10.0154] 
[0/10] [4/391] D(x): [0.7189] D(G): [0.4553/0.0920] GLoss: [9.6159] DLoss: [4.3058] DLabel: [10.0291] 
[0/10] [5/391] D(x): [0.7652] D(G): [0.5160/0.0563] GLoss: [11.4024] DLoss: [4.7737] DLabel: [11.4899] 
[0/10] [6/391] D(x): [0.7401] D(G): [0.4995/0.0381] GLoss: [11.5672] DLoss: [4.2521] DLabel: [10.5491] 
[0/10] [7/391] D(x): [0.5738] D(G): [0.4885/0.0541] GLoss: [11.2612] DLoss: [4.5883] DLabel: [10.8036] 
[0/10] [8/391] D(x): [0.6350] D(G): [0.6278/0.0443] GLoss: [13.2742] DLoss: [4.3422] DLabel: [12.4040] 
[0/10] [9/391] D(x): [0.7759] D(G): [0.5687/0.0109] GLoss: [14.5044] D

In [None]:
# Plot the loss of the generator and the discriminator
# One point per training step: generator in red, discriminator in green.
axes = plt.gca()
axes.plot(counter_list, gen_loss_list, 'r.', label='Generator')
axes.plot(counter_list, dis_loss_list, 'g.', label='Discriminator')
axes.set_title("ACGAN Loss of Discriminator and Generator")
axes.set_xlabel("Batch Number")
axes.set_ylabel("Loss (Binary Cross Entropy)")
axes.legend(loc="best")
plt.show()