In [1]:
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt

In [2]:
data_set = datasets.CIFAR10(root="./data", download=True, transform=transforms.Compose([
    transforms.Resize(64),transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))
data_loader = torch.utils.data.DataLoader(data_set, batch_size = 128, shuffle=True, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

noise_dimension = 100 

real_value = 1 

fake_value = 0 


rand_seed = random.randint(10, 100000)
random.seed(rand_seed)
torch.manual_seed(rand_seed)

<torch._C.Generator at 0x7fb53d4964d0>

In [4]:
def weights_initialize(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, passed_input):
        discriminator_output = self.main(passed_input)
        return discriminator_output.view(-1, 1).squeeze(1)

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(noise_dimension, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, passed_input):
        generator_output = self.main(passed_input)
        return generator_output

In [7]:
discriminator = Discriminator().to(device)
generator = Generator().to(device)
discriminator.apply(weights_initialize)
generator.apply(weights_initialize)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [8]:
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
disriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()
noise = torch.randn(128, noise_dimension, 1, 1, device=device)
gen_loss_list = []
dis_loss_list = []
counter = 0
counter_list = []

In [10]:
num_epochs = 10
for epoch in range(num_epochs):
    for i, data in enumerate(data_loader, 0):
        counter += 1
        counter_list.append(counter)
        
       
        real_data = data[0].to(device)
        size_of_batch = real_data.size(0)
        labels_tensor = torch.full((size_of_batch,), real_value, device = device).float()
        discriminator.zero_grad()
        dis_output = discriminator(real_data ).float()
        dis_real_error = criterion(dis_output, labels_tensor)
        dis_real_error.backward()
        dis_real_output_mean = dis_output.mean().item()

        
        labels_tensor.fill_(fake_value).float()
        noise = torch.randn(size_of_batch, noise_dimension, 1, 1, device=device)
        fake_data = generator(noise)
        dis_output = discriminator(fake_data.detach()).float()
        dis_fake_error = criterion(dis_output, labels_tensor)
        dis_fake_error.backward()
        dis_fake_output_mean = dis_output.mean().item()
        disriminator_optimizer.step()
        final_dis_error = dis_real_error + dis_fake_error
        dis_loss_list.append(final_dis_error.item())
        

       
        labels_tensor.fill_(real_value).float()
        generator.zero_grad()
        gen_output = discriminator(fake_data).float()
        gen_error = criterion(gen_output, labels_tensor)
        gen_loss_list.append(gen_error.item())
        gen_error.backward()
        gen_output_mean = gen_output.mean().item()
        generator_optimizer.step()

        
        print('[%d/%d][%d/%d] DisLoss: %.4f GenLoss: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % 
              (epoch, num_epochs, i, len(data_loader), final_dis_error.item(), 
               gen_error.item(), dis_real_output_mean, dis_fake_output_mean, gen_output_mean ))
        
  
    fake_data = generator(noise)
    vutils.save_image(real_data,'DCganOutput/real_samples.png',normalize=True)
    vutils.save_image(fake_data.detach(),'DCganOutput/fake_samples_epoch_%03d.png' % (epoch), normalize=True)

[0/10][0/391] DisLoss: 0.4627 GenLoss: 3.9463 D(x): 0.7078 D(G(z)): 0.0411 / 0.0635
[0/10][1/391] DisLoss: 0.3262 GenLoss: 3.5863 D(x): 0.8852 D(G(z)): 0.1456 / 0.0621
[0/10][2/391] DisLoss: 0.4054 GenLoss: 3.7787 D(x): 0.8791 D(G(z)): 0.1803 / 0.0358
[0/10][3/391] DisLoss: 0.3075 GenLoss: 4.2949 D(x): 0.9029 D(G(z)): 0.1611 / 0.0243
[0/10][4/391] DisLoss: 0.3486 GenLoss: 3.2414 D(x): 0.8147 D(G(z)): 0.0873 / 0.0595
[0/10][5/391] DisLoss: 0.3944 GenLoss: 3.1589 D(x): 0.8537 D(G(z)): 0.1820 / 0.0587
[0/10][6/391] DisLoss: 0.4424 GenLoss: 3.3691 D(x): 0.8334 D(G(z)): 0.1766 / 0.0497
[0/10][7/391] DisLoss: 0.4209 GenLoss: 3.4479 D(x): 0.8552 D(G(z)): 0.1837 / 0.0518
[0/10][8/391] DisLoss: 0.4905 GenLoss: 2.5475 D(x): 0.7763 D(G(z)): 0.1380 / 0.1037
[0/10][9/391] DisLoss: 0.3635 GenLoss: 3.5627 D(x): 0.9132 D(G(z)): 0.2163 / 0.0372
[0/10][10/391] DisLoss: 0.4258 GenLoss: 3.1612 D(x): 0.8143 D(G(z)): 0.1428 / 0.0550
[0/10][11/391] DisLoss: 0.4637 GenLoss: 2.6054 D(x): 0.7948 D(G(z)): 0.1546

[0/10][97/391] DisLoss: 0.9351 GenLoss: 2.8239 D(x): 0.4596 D(G(z)): 0.0021 / 0.0843
[0/10][98/391] DisLoss: 0.4604 GenLoss: 2.2571 D(x): 0.8774 D(G(z)): 0.2262 / 0.1300
[0/10][99/391] DisLoss: 0.5898 GenLoss: 5.0801 D(x): 0.9091 D(G(z)): 0.3403 / 0.0107
[0/10][100/391] DisLoss: 0.8086 GenLoss: 1.8898 D(x): 0.5512 D(G(z)): 0.0422 / 0.2140
[0/10][101/391] DisLoss: 0.9898 GenLoss: 5.8564 D(x): 0.9175 D(G(z)): 0.5290 / 0.0068
[0/10][102/391] DisLoss: 0.8972 GenLoss: 2.4159 D(x): 0.5080 D(G(z)): 0.0247 / 0.1229
[0/10][103/391] DisLoss: 0.4617 GenLoss: 3.3032 D(x): 0.9309 D(G(z)): 0.2944 / 0.0494
[0/10][104/391] DisLoss: 0.3042 GenLoss: 3.7848 D(x): 0.8793 D(G(z)): 0.1478 / 0.0295
[0/10][105/391] DisLoss: 0.4379 GenLoss: 2.5136 D(x): 0.7669 D(G(z)): 0.1041 / 0.0997
[0/10][106/391] DisLoss: 0.4032 GenLoss: 4.0190 D(x): 0.8921 D(G(z)): 0.2327 / 0.0273
[0/10][107/391] DisLoss: 0.4118 GenLoss: 3.0438 D(x): 0.7612 D(G(z)): 0.0962 / 0.0592
[0/10][108/391] DisLoss: 0.4475 GenLoss: 5.8833 D(x): 0.9

[0/10][193/391] DisLoss: 0.4766 GenLoss: 4.3651 D(x): 0.6738 D(G(z)): 0.0167 / 0.0227
[0/10][194/391] DisLoss: 0.3919 GenLoss: 3.0462 D(x): 0.8413 D(G(z)): 0.1652 / 0.0724
[0/10][195/391] DisLoss: 0.7133 GenLoss: 6.0731 D(x): 0.8607 D(G(z)): 0.3914 / 0.0040
[0/10][196/391] DisLoss: 0.5310 GenLoss: 3.6008 D(x): 0.6467 D(G(z)): 0.0265 / 0.0414
[0/10][197/391] DisLoss: 0.5060 GenLoss: 4.5267 D(x): 0.9129 D(G(z)): 0.2916 / 0.0166
[0/10][198/391] DisLoss: 0.6351 GenLoss: 2.5158 D(x): 0.6687 D(G(z)): 0.1261 / 0.1068
[0/10][199/391] DisLoss: 1.2425 GenLoss: 9.9853 D(x): 0.9200 D(G(z)): 0.6326 / 0.0002
[0/10][200/391] DisLoss: 2.5073 GenLoss: 3.5732 D(x): 0.1624 D(G(z)): 0.0017 / 0.0623
[0/10][201/391] DisLoss: 0.7301 GenLoss: 4.5715 D(x): 0.9208 D(G(z)): 0.4034 / 0.0191
[0/10][202/391] DisLoss: 0.6741 GenLoss: 4.5019 D(x): 0.7508 D(G(z)): 0.2327 / 0.0244
[0/10][203/391] DisLoss: 0.6509 GenLoss: 4.3210 D(x): 0.7947 D(G(z)): 0.2837 / 0.0224
[0/10][204/391] DisLoss: 0.9085 GenLoss: 2.7683 D(x): 

[0/10][289/391] DisLoss: 0.1379 GenLoss: 3.5316 D(x): 0.8993 D(G(z)): 0.0169 / 0.0497
[0/10][290/391] DisLoss: 0.4109 GenLoss: 4.3684 D(x): 0.9449 D(G(z)): 0.2579 / 0.0210
[0/10][291/391] DisLoss: 0.4215 GenLoss: 2.9885 D(x): 0.7657 D(G(z)): 0.0929 / 0.0812
[0/10][292/391] DisLoss: 0.4127 GenLoss: 4.6826 D(x): 0.9737 D(G(z)): 0.2928 / 0.0150
[0/10][293/391] DisLoss: 0.6122 GenLoss: 3.0580 D(x): 0.6947 D(G(z)): 0.1291 / 0.0732
[0/10][294/391] DisLoss: 0.3918 GenLoss: 4.2461 D(x): 0.9098 D(G(z)): 0.2290 / 0.0221
[0/10][295/391] DisLoss: 0.2620 GenLoss: 4.8160 D(x): 0.9122 D(G(z)): 0.1250 / 0.0124
[0/10][296/391] DisLoss: 0.4896 GenLoss: 2.9354 D(x): 0.7448 D(G(z)): 0.0991 / 0.0702
[0/10][297/391] DisLoss: 0.3991 GenLoss: 4.9963 D(x): 0.9371 D(G(z)): 0.2566 / 0.0098
[0/10][298/391] DisLoss: 0.3160 GenLoss: 4.7403 D(x): 0.8475 D(G(z)): 0.1127 / 0.0127
[0/10][299/391] DisLoss: 0.4511 GenLoss: 3.3419 D(x): 0.7789 D(G(z)): 0.1316 / 0.0506
[0/10][300/391] DisLoss: 0.5947 GenLoss: 5.6776 D(x): 

[0/10][385/391] DisLoss: 0.8528 GenLoss: 2.4299 D(x): 0.6365 D(G(z)): 0.2142 / 0.1521
[0/10][386/391] DisLoss: 0.8080 GenLoss: 3.9618 D(x): 0.8556 D(G(z)): 0.4246 / 0.0318
[0/10][387/391] DisLoss: 0.4740 GenLoss: 3.6591 D(x): 0.7629 D(G(z)): 0.1277 / 0.0384
[0/10][388/391] DisLoss: 0.3729 GenLoss: 3.0106 D(x): 0.8070 D(G(z)): 0.1154 / 0.0684
[0/10][389/391] DisLoss: 0.4875 GenLoss: 4.2411 D(x): 0.9117 D(G(z)): 0.3008 / 0.0238
[0/10][390/391] DisLoss: 0.3473 GenLoss: 4.4891 D(x): 0.8608 D(G(z)): 0.1304 / 0.0228


FileNotFoundError: [Errno 2] No such file or directory: 'DCganOutput/real_samples.png'

In [None]:
plt.plot(counter_list, gen_loss_list, 'r.', label='Generator')
plt.plot(counter_list, dis_loss_list, 'g.', label='Discriminator')
plt.title("DCGAN Loss of Generator and Discriminator ")
plt.xlabel("Batch Number")
plt.ylabel("Binary Cross Entropy Loss")
plt.legend(loc="best")
plt.show()