In [None]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms, utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Ask cuDNN to benchmark and cache the fastest convolution algorithms.
# The transforms below resize/crop every image to the same spatial size,
# which is the situation this setting is meant for.
cudnn.benchmark = True

## normal gan (worked fine, i guess)

###  data formatting

first we have to flatten all our rgba images onto a white background and save them as rgb so pil doesn't freak out

In [None]:
import os
from PIL import Image

dataroot = "data/"

# One-off preprocessing switch: set to True to rewrite, IN PLACE, every image
# under `dataroot` as an RGB file flattened onto a white background.
need_conversion = False
if need_conversion:
    image_exts = {".png", ".jpg", ".jpeg"}
    for start, dirs, files in os.walk(dataroot):
        for f in files:
            path = os.path.join(start, f)
            # Compare case-insensitively so files such as "IMG.PNG" are not
            # silently skipped.
            ext = os.path.splitext(path)[1].lower()
            if ext not in image_exts:
                continue
            # convert() yields a fully loaded image; the context manager then
            # closes the source file handle before the same path is overwritten.
            with Image.open(path) as src:
                rgba = src.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255))
            # Composite over white so transparent pixels do not turn black
            # when the alpha channel is dropped.
            flattened = Image.alpha_composite(background, rgba).convert("RGB")
            flattened.save(path)


then we make our dataset/dataloader, nothing much here

In [None]:
# Side length (pixels) every image is resized and cropped to.
image_size = 64

# Resize -> center-crop -> tensor -> normalise each channel with mean 0.5, std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.ImageFolder(root=dataroot, transform=preprocess)

batch_size = 64
workers = 4
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Preview: tile up to 64 images from one training batch into a single grid.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=2, normalize=True
).cpu()
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns channels-first; imshow wants channels-last.
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

## model init

and here we initialize our generator/discriminator along with their losses/optimizers  
the model classes themselves live in `gans.py`

In [None]:
def weights_init(m):
    """Initialise one module in place, DCGAN style (meant for ``model.apply``).

    * class name contains "Conv":      weight ~ Normal(mean=0.0, std=0.02)
    * class name contains "BatchNorm": weight ~ Normal(mean=1.0, std=0.02),
      bias set to 0
    * anything else is left untouched.

    Matching is done on the class name, so a module can match without owning
    the expected parameters (for example a container class called
    ``ConvBlock``, or ``BatchNorm2d(affine=False)`` whose weight and bias are
    ``None``). Missing parameters are skipped instead of raising.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute absent" and "attribute is None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
# Sizes handed to the generator / discriminator constructors.
nc = 3     # number of image channels
nz = 100   # length of the generator's input (latent) vector
ngf = 64   # feature-map size in the generator
ndf = 64   # feature-map size in the discriminator

In [None]:
from importlib import reload
import gans

# Re-import gans.py so edits to it are picked up without restarting the kernel.
reload(gans)

# Checkpoint files for the two networks.
gen_path = "models/gan/gen.pth"
discr_path = "models/gan/discr.pth"

# Generator: build, move to the target device, then run `weights_init` over
# every submodule.
gen = gans.Generator(nc, nz, ngf)
gen = gen.to(device)
gen.apply(weights_init)
# To resume from a checkpoint instead:
# gen.load_state_dict(torch.load(gen_path))

# Discriminator: same recipe.
discr = gans.Discriminator(nc, ndf)
discr = discr.to(device)
discr.apply(weights_init)
# discr.load_state_dict(torch.load(discr_path))

In [None]:
# Binary cross-entropy between the discriminator's output and the targets below.
criterion = nn.BCELoss()

# Target values: what the discriminator should output for real / generated images.
real_label = 1.0
fake_label = 0.0

# A fixed batch of 64 latent vectors, reused for every progress snapshot so the
# generator's output can be compared across training iterations.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

In [None]:
# Both networks use Adam with the same learning rate and betas.
lr = 2e-4
beta1 = 0.5
adam_betas = (beta1, 0.999)

optim_d = torch.optim.Adam(discr.parameters(), lr=lr, betas=adam_betas)
optim_g = torch.optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)

## training time

i loved it when pytorch said "it's torchin time" and then torched all over the models

In [None]:
# Training history, read by the plotting cells that follow.
img_list = []  # image grids of gen(fixed_noise), to visualise progress
g_loss = []    # generator loss, one entry per iteration
d_loss = []    # discriminator loss (real + fake), one entry per iteration
iters = 0

num_epochs = 200

# Create the checkpoint directories up front so the first torch.save() cannot
# fail after a whole epoch of training has already been spent.
for ckpt_path in (gen_path, discr_path):
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(num_epochs):
    tqdm_data = tqdm(dataloader)
    for i, data in enumerate(tqdm_data):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        discr.zero_grad()

        ## Real batch, target = real_label
        real_emoji = data[0].to(device)
        b_size = real_emoji.size(0)
        # NOTE(review): the (b_size, 1) target assumes discr returns one value
        # per image with that exact shape - confirm against gans.py.
        label = torch.full((b_size, 1), real_label, dtype=torch.float, device=device)
        output = discr(real_emoji)
        err_d_rl = criterion(output, label)
        err_d_rl.backward()

        ## Fake batch, target = fake_label
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = gen(noise)
        label.fill_(fake_label)
        # detach() so this backward pass does not reach the generator.
        output = discr(fake.detach())
        err_d_fk = criterion(output, label)
        # Gradients accumulate on top of those from the real batch.
        err_d_fk.backward()
        optim_d.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        gen.zero_grad()
        # D was just updated, so score the same fake batch again (not detached
        # this time: the gradient has to flow back into G).
        output = discr(fake)
        label.fill_(real_label)  # fake labels are real for generator cost
        err_g = criterion(output, label)
        err_g.backward()
        optim_g.step()

        # Record losses for plotting later.
        err_d = err_d_rl + err_d_fk
        g_loss.append(err_g.item())
        d_loss.append(err_d.item())

        # Snapshot G's output on fixed_noise every 500 iterations and on the
        # very last iteration of training.
        is_last_step = epoch == num_epochs - 1 and i == len(dataloader) - 1
        if iters % 500 == 0 or is_last_step:
            with torch.no_grad():
                snapshot = gen(fixed_noise).cpu()
            img_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

        iters += 1

        # Refresh the progress-bar text every 10th batch.
        if i % 10 == 0:
            tqdm_data.set_description(
                f"[{epoch + 1}/{num_epochs}][{i}/{len(dataloader)}]\t"
                f"d_loss: {err_d.item():.4f}\tg_loss: {err_g.item():.4f}"
            )

    # Checkpoint both networks at the end of every epoch.
    torch.save(gen.state_dict(), gen_path)
    torch.save(discr.state_dict(), discr_path)


In [None]:
# Per-iteration loss curves from the `g_loss` / `d_loss` lists.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(g_loss, label="G")
ax.plot(d_loss, label="D")
ax.set_xlabel("Iterations")
# The plotted values are raw loss values, not percentages, so the axis label
# carries no unit.
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Pull one batch of real images and tile (up to) the first 64 into a grid.
real_batch = next(iter(dataloader))
real_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=5, normalize=True
).cpu()

plt.figure(figsize=(15, 15))

# Left panel: real images.
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Right panel: the most recent generator snapshot.
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(8, 8))
plt.axis("off")

# One animation frame per saved generator snapshot (channels-last for imshow).
ims = []
for grid in img_list:
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the animation as an interactive HTML/JS player in the cell output.
HTML(ani.to_jshtml())

# w-gan time (let's hope this works better)

## construct models
and some parameters that the constructors need

In [None]:
# Sizes handed to the WGAN model constructors.
nz = 100            # size of the latent z vector
nc = 3              # input image channels
ngf = 64            # size of feature maps in generator
ndf = 64            # size of feature maps in discriminator
ng = 64             # NOTE(review): not referenced in the cells shown here - possibly a typo for ngf
n_extra_layers = 2  # number of extra layers on the generator and discriminator

In [None]:
import wgans

wgen_path = "models/wgan/gen.pth"

# Architecture switch: True -> MLP generator, False -> DCGAN-style generator.
mlp_g = False
gen = (
    wgans.MLPGenerator(image_size, nz, nc, ngf)
    if mlp_g
    else wgans.DCGANGenerator(image_size, nz, nc, ngf, n_extra_layers)
)

# To resume from a checkpoint: gen.load_state_dict(torch.load(wgen_path))
gen = gen.to(device)

In [None]:
wdiscr_path = "models/wgan/discr.pth"

# Architecture switch: True -> MLP discriminator, False -> DCGAN-style one.
mlp_d = False
if not mlp_d:
    discr = wgans.DCGANDiscriminator(image_size, nc, ndf, n_extra_layers)
else:
    discr = wgans.MLPDiscriminator(image_size, nz, nc, ndf)

# To resume from a checkpoint: discr.load_state_dict(torch.load(wdiscr_path))
discr = discr.to(device)

In [None]:
# Optimizer switch: True -> Adam, False -> RMSprop.
adam = True
lr_d = 5e-5
lr_g = 5e-5
beta1 = 0.5  # only used by the Adam branch

if not adam:
    optim_d = torch.optim.RMSprop(discr.parameters(), lr=lr_d)
    optim_g = torch.optim.RMSprop(gen.parameters(), lr=lr_g)
else:
    adam_betas = (beta1, 0.999)
    optim_d = torch.optim.Adam(discr.parameters(), lr=lr_d, betas=adam_betas)
    optim_g = torch.optim.Adam(gen.parameters(), lr=lr_g, betas=adam_betas)

In [None]:
# Latent batch reused for every progress snapshot of the WGAN generator.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# +1 / -1 one-element tensors on the training device.
# NOTE(review): neither is referenced by the training cell shown below.
one = torch.tensor([1.0], device=device)
mone = -one

In [None]:
gen_iters = 0  # number of generator updates performed so far
img_list = []  # image grids of gen(fixed_noise), to visualise progress
g_loss = []    # generator loss, one entry per generator update
d_loss = []    # mean critic loss over the critic steps preceding each generator update
epochs = 150

# Weight-clipping bounds for the critic. The WGAN objective is only meaningful
# while the critic is (approximately) Lipschitz, and this loop applies no other
# constraint. The bounds are the WGAN paper's default c = 0.01.
# NOTE(review): drop the clamp below if wgans.py already enforces this
# (e.g. spectral norm or a gradient penalty).
clamp_lower, clamp_upper = -0.01, 0.01

# Store a generator snapshot every this many generator updates (plus one at the
# very end). Snapshotting on every update keeps one 64-image grid per step in
# memory and grows without bound.
snapshot_every = 500

# Create the checkpoint directories up front so torch.save() cannot fail after
# an epoch of training.
for ckpt_path in (wgen_path, wdiscr_path):
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(epochs):
    data_iter = iter(dataloader)
    i = 0  # batches consumed from the dataloader in this epoch
    # The context manager closes the bar at the end of each epoch.
    with tqdm(total=len(dataloader), position=0, leave=True) as pbar:
        pbar.set_description(
            f"[{epoch}/{epochs}][{i}/{len(dataloader)}][{gen_iters}]\t"
            f"d_loss_avg: NA\tg_loss: NA"
        )
        while i < len(dataloader):
            ############################
            # (1) Update D network
            ###########################
            # Train the critic much harder at the very start and then again
            # periodically; otherwise 5 critic steps per generator step.
            d_iters = 100 if gen_iters < 25 or gen_iters % 500 == 0 else 5

            j = 0  # critic steps actually taken (can be cut short by the end of the epoch)
            # Accumulate as a plain float: no autograd graph is retained and
            # there is no dtype/device/shape mismatch with the critic output.
            err_d_sum = 0.0
            while j < d_iters and i < len(dataloader):
                j += 1

                # Clip the critic's weights before each critic step.
                for p in discr.parameters():
                    p.data.clamp_(clamp_lower, clamp_upper)

                data = next(data_iter)
                real_emoji = data[0].to(device)
                b_size = real_emoji.size(0)

                i += 1
                pbar.update(1)

                # train with real: push the critic's score on real images up
                discr.zero_grad()
                err_d_rl = -discr(real_emoji)
                err_d_rl.backward()

                # train with fake: push the critic's score on generated images
                # down; no_grad() because the generator is not updated here
                noise = torch.randn(b_size, nz, 1, 1, device=device)
                with torch.no_grad():
                    fake = gen(noise)
                err_d_fk = discr(fake)
                err_d_fk.backward()

                optim_d.step()
                err_d_sum += (err_d_rl + err_d_fk).item()

            # The outer loop condition guarantees at least one critic step, so j >= 1.
            err_d_avg = err_d_sum / j

            ############################
            # (2) Update G network
            ###########################
            gen.zero_grad()
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = gen(noise)
            err_g = -discr(fake)
            err_g.backward()
            optim_g.step()

            gen_iters += 1
            d_loss.append(err_d_avg)
            g_loss.append(err_g.item())

            pbar.set_description(
                f"[{epoch}/{epochs}][{i}/{len(dataloader)}][{gen_iters}]\t"
                f"d_loss_avg: {err_d_avg:.4f}\tg_loss: {err_g.item():.4f}"
            )

            # Always keep a snapshot from the final update so img_list[-1]
            # reflects the fully trained generator.
            is_last_step = epoch == epochs - 1 and i >= len(dataloader)
            if gen_iters % snapshot_every == 0 or is_last_step:
                with torch.no_grad():
                    snapshot = gen(fixed_noise).cpu()
                img_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

    # do checkpointing
    torch.save(gen.state_dict(), wgen_path)
    torch.save(discr.state_dict(), wdiscr_path)

In [None]:
# Side-by-side comparison: a batch of real images vs. the latest generator snapshot.
real_batch = next(iter(dataloader))
panels = [
    (
        "Real Images",
        vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),
    ),
    ("Fake Images", img_list[-1]),
]

plt.figure(figsize=(15, 15))
for slot, (panel_title, grid) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.axis("off")
    plt.title(panel_title)
    # Grids are channels-first; imshow wants channels-last.
    plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()