In [0]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.autograd import Variable

In [0]:
# Image preprocessing
transform = transforms.Compose([transforms.Resize(64),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                                     std=(0.5, 0.5, 0.5))])

# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize(mean=(0.5, 0.5, 0.5),
#                                                      std=(0.5, 0.5, 0.5))])

In [0]:
stl10 = dsets.STL10(root="./dataLsun",
                   download=True,
                  transform=transform)

# Load CIFAR 10 dataset
cifar_trainset = dsets.CIFAR10(root="./data",
                               train=True,
                               download=True,
                               transform=transform)



Files already downloaded and verified
Files already downloaded and verified


In [0]:
batch_size = 128

# Parameters loader
cifar_loader = DataLoader(cifar_trainset,
                          batch_size=batch_size,
                          shuffle=True)

stl10_loader = DataLoader(stl10,
                         batch_size=batch_size,
                         shuffle=True)

cifar_loader.dataset
stl10_loader.dataset

Dataset STL10
    Number of datapoints: 5000
    Split: train
    Root Location: ./dataLsun
    Transforms (if any): Compose(
                             Resize(size=64, interpolation=PIL.Image.BILINEAR)
                             ToTensor()
                             Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                         )
    Target Transforms (if any): None

In [0]:
for val, _ in stl10_loader:
    print(val.shape)
    vutils.save_image(val,
                              './temp.png',
                              normalize=True)
    break

torch.Size([128, 3, 64, 64])


In [0]:
# Generator model

class Generator(nn.Module):
    def __init__(self, Z_dim, ngf, ncc):
        super(Generator, self).__init__()
        
        self.layers = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(Z_dim, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.ReLU(),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, ncc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (ncc) x 64 x 64
        )
        
    def forward(self, input_):
        output = self.layers(input_)
        print(output.shape)
        return output

In [0]:
# Discriminator model

class Discriminator(nn.Module):
    def __init__(self, ndf, ncc):
        super(Discriminator, self).__init__()
        
        self.layers = nn.Sequential(
            nn.Conv2d(ncc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),
            # nn.Sigmoid()
            
        )
    
    def forward(self, input_):
        # print("before")
        output = self.layers(input_)
        print(output.shape)
        # print("after")
        return output.view(-1, 1).squeeze(1)

In [0]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [0]:
# Dimension of the latent space
Z_dim = 128

# Number of color channel in the final image
ncc = 3

# Number of ?
ngf, ndf = 64, 64

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)

cuda:0


In [0]:
netG = Generator(Z_dim, ngf, ncc).to(device)
netG.apply(weights_init)
print(netG)

netD = Discriminator(ndf, ncc).to(device)
netD.apply(weights_init)
print(netD)

Generator(
  (layers): Sequential(
    (0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
Discriminator(
  (layers): Sequential(
    (0)

In [0]:
# Setup loss function
criterion = nn.BCELoss()
criterion = nn.MSELoss()

fixed_noise = torch.randn(batch_size, Z_dim, 1, 1, device=device)
real_label = 1
fake_label = 0

learning_rate = 0.0002
nb_epochs = 30

# Setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [0]:
for epoch in range(nb_epochs):
    for i, data in enumerate(cifar_loader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        output = netD(real_cpu)
        # print(output)
        # print(label)
        errD_real = criterion(output, label) # 0.5 * torch.mean((output-label)**2) # criterion(output, label)
        errD_real.backward()
        
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, Z_dim, 1, 1, device=device)
        
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        
        errD_fake = criterion(output, label) # 0.5 * torch.mean((output-label)**2) # criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label) # 0.5 * torch.mean((output-label)**2) # criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print(f'[{epoch}/{nb_epochs}][{i}/{len(cifar_loader)}]\
              Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}\
              D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              './real_samples.png',
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              f'./fake_samples_epoch_{epoch}.png',
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), f'./netG_epoch_{epoch}.pth')
    torch.save(netD.state_dict(), f'./netD_epoch_{epoch}.pth')
        

torch.Size([128, 1, 1, 1])
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.Size([128, 1, 1, 1])
[0/30][0/391]              Loss_D: 2.4020 Loss_G: 12.4874              D(x): 0.4877 D(G(z)): 0.3675 / -2.5063
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.Size([128, 1, 1, 1])
[0/30][1/391]              Loss_D: 10.5672 Loss_G: 33.4812              D(x): 2.4051 D(G(z)): 2.0548 / -4.7148
torch.Size([128, 1, 1, 1])
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.Size([128, 1, 1, 1])
[0/30][2/391]              Loss_D: 7.6727 Loss_G: 3.7971              D(x): -1.1091 D(G(z)): -0.4139 / -0.8127
torch.Size([128, 1, 1, 1])
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.Size([128, 1, 1, 1])
[0/30][3/391]              Loss_D: 8.9395 Loss_G: 21.2113              D(x): 0.3443 D(G(z)): 2.4656 / -3.5107
torch.Size([128, 1, 1, 1])
torch.Size([128, 3, 64, 64])
torch.Size([128, 1, 1, 1])
torch.

KeyboardInterrupt: ignored