In [1]:
# Importing Libraries
import torch 
import torch.nn as nn
import torchvision
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


In [2]:
# Parameters
batch_size = 128
img_size = 64
noise_channels = 100
img_channels = 3
learning_rate = 0.0002
num_epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Transforms
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((img_size, img_size)),
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# Loading Dataset
dataset = torchvision.datasets.CelebA(root='./data/', download=True, transform=transforms)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
dataloader = iter(dataloader)

# Looking at sample images
sample_images, labels = next(dataloader)
def plot_images(images, labels, number):
  for i in range(number):
    image = sample_images[i].permute(1, 2, 0)
    plt.subplot(4, number//4, i+1)
    plt.imshow(image)

plot_images(sample_images, labels, 16)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [5]:
# Designing A Model
class Generator(nn.Module):
  def __init__(self, img_channels, noise_channels):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        nn.ConvTranspose2d(noise_channels, 1024, kernel_size=4, stride=1, padding=0),
        nn.BatchNorm2d(1024),
        nn.ReLU(0.2),
        nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.ReLU(0.2),
        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(0.2),
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(0.2),
        nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def forward(self, x):
    return self.gen(x)

class Discriminator(nn.Module):
  def __init__(self, img_channels):
    super(Discriminator, self).__init__()
    self.disc = nn.Sequential(
        nn.Conv2d(img_channels, 128, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),
        nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),
        nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.2),
        nn.Conv2d(1024, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.disc(x)

def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)

# img_channels = 3
# noise_channels = 100
# batch_size = 32
# noise = torch.randn((batch_size, noise_channels, 1, 1))
# image = torch.randn((batch_size, img_channels, img_size, img_size))
# gen = Generator(img_channels, noise_channels)
# disc = Discriminator(img_channels)
# gen_output = gen(noise)
# print('Gen Shape: ', gen_output.size())
# disc_output = disc(image)
# print('Disc Shape: ', disc_output.size())

In [6]:
# Initializing Models
gen = Generator(img_channels, noise_channels).to(device)
disc = Discriminator(img_channels).to(device)

# Initializing Weights
initialize_weights(gen)
initialize_weights(disc)

# Loss 
criterion = nn.BCELoss()

# Optimizer
gen_optim = optim.Adam(gen.parameters(), lr=learning_rate)
disc_optim = optim.Adam(disc.parameters(), lr=learning_rate)

# Noise
noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)

In [7]:
disc_losses = []
gen_losses = []
for epoch in range(1):
  for i, (images, labels) in enumerate(dataloader):
    real = images.to(device)
    noise = torch.randn(batch_size, noise_channels, 1, 1).to(device)
    fake = gen(noise)


    # Train Discriminator
    disc_real = disc(real).reshape(-1)
    disc_fake = disc(fake).reshape(-1)
    dLoss_real = criterion(disc_real, torch.ones_like(disc_real))
    dLoss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    disc_loss = (dLoss_real + dLoss_fake) / 2
    
    # Discriminator Backpropagation
    disc_optim.zero_grad()
    disc_loss.backward(retain_graph=True)
    disc_optim.step()

    # Training Generator
    output = disc(fake).reshape(-1)
    gen_loss = criterion(output, torch.ones_like(output))
    
    # Generator Backprop
    gen_optim.zero_grad()
    gen_loss.backward()
    gen_optim.step()

    if i % 100 == 0:
      with torch.no_grad():
        print(f'Epoch {epoch}/{num_epochs} Batch {i}/{len(dataloader)} GenLoss: {gen_loss.item()} DiscLoss: {disc_loss.item()}')
        # Plot Real vs Fake Images.
        gen_losses.append(gen_loss.item())
        disc_losses.append(disc_loss.item())


Epoch 0/5 Batch 0/1272 GenLoss: 0.9090712070465088 DiscLoss: 0.6962864398956299


KeyboardInterrupt: 

In [None]:
plt.plot(disc_losses, color='orange')
plt.plot(gen_losses, color='green')
plt.show()