In [7]:
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.transforms import Resize, ToTensor, Normalize
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as tt

import matplotlib.pyplot as plt
import numpy as np
#from vae_utils import get_vector_from_label, add_vector_to_images, morph_faces

# Record the library versions this notebook was run with (one per line).
print(torch.__version__, torchvision.__version__, sep="\n")

# Pick the best available accelerator: CUDA first, then Apple's MPS backend,
# and finally the CPU. Falling back to "mps" unconditionally produces a device
# that cannot hold tensors on machines that have neither CUDA nor MPS.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

2.0.0
0.15.1


device(type='mps')

In [16]:
# Hyperparameters, grouped by what they control.

# -- data --
IMAGE_SIZE = 64    # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 3   # channels per image (also used for the normalisation stats)
BATCH_SIZE = 64    # samples per DataLoader batch and per noise batch

# -- model --
Z_DIM = 100        # length of the latent noise vector fed to the generator
FEATURES_GEN = 64  # base channel width of the generator
FEATURES_DISC = 64 # base channel width of the critic

# -- optimisation --
LEARNING_RATE = 5e-5   # shared by both RMSprop optimizers
NUM_EPOCHS = 5         # full passes over the dataloader
CRITIC_ITERATIONS = 5  # critic updates performed per generator update
WEIGHT_CLIP = 0.01     # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

In [3]:
# Root folder handed to ImageFolder. Kept in one named constant so the
# machine-specific location is edited in a single place when the notebook is
# run elsewhere.
DATA_ROOT = "/Users/parkermoesta/datasets/CelebA/img_align_celeba/"

# Preprocessing: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
# NOTE: this variable shadows the `torchvision.transforms` module imported at
# the top of the notebook; refer to the module through its `tt` alias.
transforms = tt.Compose(
    [
        tt.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        tt.ToTensor(),
        tt.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

#dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=True)

dataset = datasets.ImageFolder(root=DATA_ROOT, transform=transforms)

# Reshuffled every epoch; the final batch may hold fewer than BATCH_SIZE images.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [23]:
class Discriminator(nn.Module):
    """Convolutional critic that reduces an image to a single unbounded score.

    The stack is a strided convolution + LeakyReLU stem, three
    conv/batch-norm/LeakyReLU stages that double the channel count each time,
    and a final convolution down to one channel. No activation follows the
    last convolution.

    Args:
        channel_img: Number of channels of the input images.
        features_d: Channel width of the stem; later stages use 2x, 4x and 8x
            this width.
    """

    def __init__(self, channel_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Stem (no batch norm). For a 64x64 input this yields 32x32 maps.
        layers = [
            nn.Conv2d(channel_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Each stage halves the spatial size: 32 -> 16 -> 8 -> 4 for 64x64 input.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self.d_block(c_in, c_out, 4, 2, 1))
        # Head: a 4x4 kernel without padding collapses the 4x4 maps to 1x1.
        layers.append(
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)
        )

        self.disc = nn.Sequential(*layers)

    def d_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The convolution is created without a bias term; the batch-norm layer
        that follows has its own learnable shift.
        """
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the critic and return the raw score tensor."""
        return self.disc(x)

In [24]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to an image.

    Four ConvTranspose2d/batch-norm/ReLU stages upsample the latent vector,
    then a final transposed convolution plus Tanh produces the image, so the
    output values lie in [-1, 1].

    Args:
        z_dim: Number of channels of the latent input (shape N x z_dim x 1 x 1).
        channels_img: Number of channels of the generated images.
        features_g: Base channel width; the stages use 16x, 8x, 4x and 2x this
            width.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # First stage expands the 1x1 latent input to 4x4 feature maps.
        stages = [self.gen_block(z_dim, widths[0], 4, 1, 0)]
        # Each following stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(self.gen_block(c_in, c_out, 4, 2, 1))
        # Output layer: one more doubling to 64x64, squashed by Tanh.
        stages.append(
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one ConvTranspose2d -> BatchNorm2d -> ReLU stage.

        The transposed convolution has no bias term; the batch-norm layer that
        follows has its own learnable shift.
        """
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        return nn.Sequential(upsample, norm, act)

    def forward(self, x):
        """Generate images from the latent batch ``x``."""
        return self.gen(x)

In [25]:
def initialize_weights(model):
    """Apply DCGAN-style initialisation to every conv and batch-norm layer.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
    scale parameters are drawn from N(1, 0.02) and their shift is set to zero:
    a scale centred on 0 would make every normalised activation start out
    close to zero. Convolution biases are left untouched.

    Args:
        model: Module whose sub-modules (as returned by ``model.modules()``)
            are re-initialised in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [26]:
# Instantiate both networks and move them to the selected device.
gen = Generator(
    z_dim=Z_DIM, channels_img=CHANNELS_IMG, features_g=FEATURES_GEN
).to(device)
critic = Discriminator(
    channel_img=CHANNELS_IMG, features_d=FEATURES_DISC
).to(device)

# Replace PyTorch's default initialisation with the custom scheme
# (generator first, then critic).
initialize_weights(gen)
initialize_weights(critic)

In [27]:
# One RMSprop optimizer per network, both using the same learning rate.
opt_gen = torch.optim.RMSprop(params=gen.parameters(), lr=LEARNING_RATE)
opt_critic = torch.optim.RMSprop(params=critic.parameters(), lr=LEARNING_RATE)

# No criterion object is defined: the training loop computes its losses
# directly from the critic scores.
#criterion = nn.BCELoss()

In [28]:
# Fixed latent batch so the logged samples are comparable between steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

step = 0  # global step used for the TensorBoard image logs

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)

        ## Train critic: max E[critic(real)] - E[critic(gen_fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
            # The critic update must not touch the generator: sample the fakes
            # without building a graph, so no gradient from the critic loss
            # leaks into the generator's parameters.
            with torch.no_grad():
                fake = gen(noise)
            critic_real = critic(real).reshape(-1)  # flatten
            critic_fake = critic(fake).reshape(-1)  # flatten
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Clip every critic parameter to [-WEIGHT_CLIP, WEIGHT_CLIP].
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        ## Train generator: min -E[critic(gen_fake)]
        noise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)
        fake = gen(noise)
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        # Gradients accumulate across backward() calls, so clear the
        # generator's before each update. The gradients this backward pass
        # leaves on the critic are cleared by critic.zero_grad() above.
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # print losses occasionally and log image grids to tensorboard
        if batch_idx % 100 == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the source indentation into the printed line.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1

# Push any buffered TensorBoard events to disk; the writers stay open.
writer_real.flush()
writer_fake.flush()


Epoch [0/5] Batch 0/1583                   Loss D: -0.0709, loss G: 0.0147


KeyboardInterrupt: 