In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

In [2]:
class ImageClassificationBase(nn.Module):
    """Mixin-style base class adding train/validation helpers to a classifier.

    Subclasses provide ``forward``; the helpers here call ``self(images)`` and
    score the result with cross-entropy.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of the model on one ``(images, labels)`` batch."""
        inputs, targets = batch
        logits = self(inputs)
        return F.cross_entropy(logits, targets)

    def validation_step(self, batch):
        """Evaluate one ``(images, labels)`` batch.

        Returns a dict with the detached loss under ``'val_loss'`` and the value
        of the module-level ``accuracy`` helper under ``'val_acc'``. Both values
        are later combined with ``torch.stack`` in ``validation_epoch_end``, so
        they need to be tensors.
        """
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = F.cross_entropy(logits, targets)
        batch_acc = accuracy(logits, targets)
        return {'val_loss': batch_loss.detach(), 'val_acc': batch_acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch dicts from ``validation_step`` into Python floats."""
        def _epoch_mean(key):
            # Stack the per-batch scalars for this metric and reduce to one tensor.
            return torch.stack([record[key] for record in outputs]).mean()

        mean_loss = _epoch_mean('val_loss')
        mean_acc = _epoch_mean('val_acc')
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line summary; ``result`` must hold train_loss, val_loss and val_acc."""
        template = "Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
        metrics = [result[key] for key in ('train_loss', 'val_loss', 'val_acc')]
        print(template.format(epoch, *metrics))
        
def accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose arg-max equals ``labels``.

    Args:
        outputs: 2-D tensor of per-class scores; the predicted class is the
            index of the largest value along ``dim=1``.
        labels: tensor of target class indices, one per row of ``outputs``.

    Returns:
        A 0-dim float tensor. A tensor (rather than a Python float) is returned
        because ``validation_epoch_end`` combines the per-batch values with
        ``torch.stack``.
    """
    _, preds = torch.max(outputs, dim=1)
    # Mean of the element-wise match mask == correct / total.
    return (preds == labels).float().mean()

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
class Generator(nn.Module):
  """DCGAN-style generator built from transposed convolutions.

  The network applies one stride-1 and four stride-2 transposed convolutions
  and ends with ``Tanh``. For an input of shape ``(N, z_dim, 1, 1)`` the
  spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, so the output has shape
  ``(N, channels_img, 64, 64)`` with values in [-1, 1].

  Args:
    z_dim: number of channels of the latent input tensor.
    channels_img: number of channels of the generated image.
    features_g: base width; hidden blocks use 16x, 8x, 4x and 2x this value.
  """

  def __init__(self, z_dim, channels_img, features_g ):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
      self._block(z_dim, features_g*16, 4, 1, 0, bias = False),
      self._block(features_g*16, features_g*8, 4, 2, 1, bias = False),
      self._block(features_g*8, features_g*4, 4, 2, 1, bias = False),
      self._block(features_g*4, features_g*2, 4, 2, 1, bias = False),
      # Output layer: no BatchNorm, Tanh squashes pixel values into [-1, 1].
      nn.ConvTranspose2d(features_g*2, channels_img, 4, 2, 1, bias = False),
      nn.Tanh(),
    )


  def _block(self, in_channels, out_channels, kernel_size, stride, padding, bias = False):
    """Build one ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

    Args:
      in_channels: channels entering the transposed convolution.
      out_channels: channels produced (and normalised by BatchNorm2d).
      kernel_size: kernel size of the transposed convolution.
      stride: stride of the transposed convolution.
      padding: padding of the transposed convolution.
      bias: whether the transposed convolution has a learnable bias. Defaults
        to False; the BatchNorm2d that follows carries its own shift term.

    Returns:
      An ``nn.Sequential`` holding the three layers.
    """
    return nn.Sequential(
      # Forward the caller's choice instead of hard-coding bias=False.
      nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias = bias),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
    )

  def forward(self, x):
    """Run the latent tensor ``x`` through the generator stack."""
    return self.gen(x)

In [5]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator built from strided convolutions.

  Four stride-2 convolutions halve the spatial size each time and a final
  4x4 stride-1 convolution followed by ``Sigmoid`` produces one score per
  image. For an input of shape ``(N, channels_img, 64, 64)`` the spatial size
  shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, so the output has shape
  ``(N, 1, 1, 1)`` with values in [0, 1].

  Args:
    channels_img: number of channels of the input image.
    features_d: base width; hidden layers use 1x, 2x, 4x and 8x this value.
  """

  def __init__(self, channels_img, features_d):
    super(Discriminator, self).__init__()
    self.disc = nn.Sequential(
      # First layer has no BatchNorm, only the leaky activation.
      nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      self._block(features_d, features_d*2, 4, 2, 1, bias=False),
      self._block(features_d*2, features_d*4, 4, 2, 1, bias=False),
      self._block(features_d*4, features_d*8, 4, 2, 1, bias=False),
      # Collapse the remaining 4x4 map to a single logit, then squash to [0, 1].
      nn.Conv2d(features_d*8, 1, kernel_size=4, stride=1, padding=0, bias=False),
      nn.Sigmoid(),
    )

  
  def _block(self, in_channels, out_channels, kernel_size, stride, padding, bias = False):
    """Build one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) downsampling stage.

    Args:
      in_channels: channels entering the convolution.
      out_channels: channels produced (and normalised by BatchNorm2d).
      kernel_size: kernel size of the convolution.
      stride: stride of the convolution.
      padding: padding of the convolution.
      bias: whether the convolution has a learnable bias. Defaults to False;
        the BatchNorm2d that follows carries its own shift term.

    Returns:
      An ``nn.Sequential`` holding the three layers.
    """
    return nn.Sequential(
      # Forward the caller's choice instead of hard-coding bias=False.
      nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = bias),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2, inplace=True),
    )

  def forward(self, x):
    """Score the image batch ``x``."""
    return self.disc(x)

In [6]:
def initialize_weights(model):
  """Apply DCGAN-style weight initialisation to ``model`` in place.

  * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
  * ``BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift is set to 0,
    so each normalisation layer starts close to the identity. Drawing the
    scale around 0 would instead start every normalised activation near zero.
  * ``BatchNorm2d`` layers created with ``affine=False`` have no parameters
    and are skipped.

  Args:
    model: any ``nn.Module``; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      nn.init.normal_(m.weight, mean=1.0, std=0.02)
      nn.init.constant_(m.bias, 0.0)

def denorm(x):
  """Shift and scale ``x`` by ``(x + 1) / 2`` and clamp the result to [0, 1].

  This maps values from the [-1, 1] range onto [0, 1]; anything outside that
  range is clipped to the nearest bound.
  """
  shifted = x + 1
  return (shifted / 2).clamp(min=0, max=1)

In [8]:
import os

import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
IMAGE_DIR = "images"  # sample images saved during training are written here

# NOTE(review): this rebinds `transforms` from the torchvision module to the
# composed pipeline. The name is kept because other cells may rely on it, but
# after this line `transforms.<anything>` no longer refers to the module.
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Normalise each channel with mean 0.5 / std 0.5, i.e. [0, 1] -> [-1, 1],
        # matching the generator's Tanh output range.
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
                       download=True)

# comment mnist above and uncomment below if train on CelebA
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Fixed latent batch so the TensorBoard "Fake" grid is comparable across steps.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

# save_image does not create missing directories, so make sure the target exists.
os.makedirs(IMAGE_DIR, exist_ok=True)

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # Size the noise from the real batch: the last batch of an epoch can be
        # smaller than BATCH_SIZE, and real/fake batches should match.
        noise = torch.randn(real.size(0), NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() keeps the discriminator update from back-propagating into gen.
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the source indentation in the printed message.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
            # Include the epoch in the file name so later epochs do not
            # overwrite the samples saved by earlier ones.
            save_image(
                denorm(fake.detach()),
                os.path.join(IMAGE_DIR, f"epoch{epoch}_batch{batch_idx}.png"),
            )

            with torch.no_grad():
                fake_fixed = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake_fixed[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1

# Push any buffered TensorBoard events to disk once training is done.
writer_real.flush()
writer_fake.flush()

Epoch [0/5] Batch 0/469                   Loss D: 0.6927, loss G: 0.7039
Epoch [0/5] Batch 100/469                   Loss D: 0.0145, loss G: 4.1803
Epoch [0/5] Batch 200/469                   Loss D: 0.0043, loss G: 5.3020
Epoch [0/5] Batch 300/469                   Loss D: 0.2967, loss G: 2.0884
Epoch [0/5] Batch 400/469                   Loss D: 0.5023, loss G: 0.8054
Epoch [1/5] Batch 0/469                   Loss D: 0.5258, loss G: 3.4703
Epoch [1/5] Batch 100/469                   Loss D: 0.3241, loss G: 1.6199
Epoch [1/5] Batch 200/469                   Loss D: 0.4841, loss G: 1.2232
Epoch [1/5] Batch 300/469                   Loss D: 0.4298, loss G: 1.4435
Epoch [1/5] Batch 400/469                   Loss D: 0.4864, loss G: 1.5969
Epoch [2/5] Batch 0/469                   Loss D: 0.4349, loss G: 1.8920
Epoch [2/5] Batch 100/469                   Loss D: 0.4429, loss G: 1.5381
Epoch [2/5] Batch 200/469                   Loss D: 0.4915, loss G: 1.0780
Epoch [2/5] Batch 300/469      