In [20]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

from importlib import reload

%matplotlib inline

In [21]:
# import the dataset and loader from data_utils.py
import data_utils

image_folder_path = "data/image"
im_dim = 64

batch_size = 1
# Dataset such that dataset[i] returns the i-th image.
# NOTE(review): the (im_dim, im_dim) argument presumably sets the output image size — confirm in data_utils.py.
rl_data = data_utils.EmojiDataset(image_folder_path, (im_dim, im_dim))
# Batch the images for parallel processing. shuffle=True gives each epoch a fresh
# sample order instead of always iterating the dataset in index order.
rl_loader = data.DataLoader(rl_data, batch_size=batch_size, shuffle=True)

In [22]:
import models

# Re-import models.py so edits to the architectures take effect without a kernel restart.
reload(models)

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both networks receive im_dim, the same side length the dataset is built with;
# nn.Module.to returns the module itself, so construction and placement chain.
generator = models.Generator(im_dim).to(device)
discriminator = models.Discriminator(im_dim).to(device)

In [23]:
# Binary cross-entropy between the discriminator's outputs and the real/fake targets.
# NOTE(review): BCELoss expects probabilities in [0, 1], so the discriminator presumably
# ends in a sigmoid — confirm in models.py.
criterion = torch.nn.BCELoss()

# Shape of a single latent sample (channels, height, width) fed to the generator.
# It mirrors the latent vectors built in the training loop, (batch, 64, 32, 32);
# the previous (64, 100, 1, 1) noise did not match that generator input.
latent_shape = (64, 32, 32)

# Fixed batch of 64 latent vectors, reused to visualise the generator's progression.
fixed_noise = torch.randn(64, *latent_shape, device=device)

# Target values used for real and fake images during training
real_label = 1.
fake_label = 0.

# DCGAN-style optimizer settings (lr=2e-4, beta1=0.5). The previous lr=0.1 is orders of
# magnitude above typical Adam-family learning rates and tends to destabilise GAN training.
lr = 2e-4
betas = (0.5, 0.999)

# One optimizer per network; each only ever sees its own network's parameters.
optim_d = torch.optim.AdamW(discriminator.parameters(), lr=lr, betas=betas)
optim_g = torch.optim.AdamW(generator.parameters(), lr=lr, betas=betas)

In [24]:
# functions that save and load the model and optimizer
save_to = "./checkpoints/model.pt"
def save(path, gen, disc, op_g, op_d):
    """Save both networks' weights and both optimizers' states in one checkpoint.

    Args:
        path: destination passed to ``torch.save``. When it is a filesystem path,
            missing parent directories are created first.
        gen: generator module.
        disc: discriminator module.
        op_g: optimizer of the generator.
        op_d: optimizer of the discriminator.
    """
    import os

    # torch.save does not create directories (e.g. "./checkpoints"), so make them here.
    # File-like objects are passed straight through, as before.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    torch.save(
        {
            "generator_weights" : gen.state_dict(),
            "discriminator_weights" : disc.state_dict(),
            "generator_optimizer_weights" : op_g.state_dict(),
            "discriminator_optimizer_weights" : op_d.state_dict(),
        },
        path
    )


def load(path, dim=None):
    """Rebuild both networks and their optimizers from a checkpoint written by ``save``.

    Args:
        path: checkpoint to read.
        dim: value passed to ``models.Generator`` / ``models.Discriminator``; defaults
            to the notebook's ``im_dim`` so the architecture matches the trained one.

    Returns:
        ``(gen, disc, op_g, op_d)`` with weights and optimizer states restored,
        placed on the GPU if available, else the CPU.
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if dim is None:
        dim = im_dim

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(path, map_location=dev)

    gen = models.Generator(dim).to(dev)
    disc = models.Discriminator(dim).to(dev)
    gen.load_state_dict(checkpoint["generator_weights"])
    disc.load_state_dict(checkpoint["discriminator_weights"])

    # Each optimizer must wrap its own network's parameters (they were previously
    # swapped). AdamW matches the training optimizers; lr, betas and moment
    # estimates are restored from the checkpoint by load_state_dict.
    op_g = torch.optim.AdamW(gen.parameters())
    op_d = torch.optim.AdamW(disc.parameters())
    op_g.load_state_dict(checkpoint["generator_optimizer_weights"])
    op_d.load_state_dict(checkpoint["discriminator_optimizer_weights"])

    return gen, disc, op_g, op_d

In [31]:
# Adversarial training: each batch first updates D on real and fake images,
# then updates G so that D labels its fakes as real.
generator.train()
discriminator.train()

# Per-iteration losses, kept for plotting after training.
d_losses = []
g_losses = []

num_epochs = 1
for epoch in range(1, num_epochs + 1):
    for img in tqdm(rl_loader, desc=f"epoch {epoch}/{num_epochs}"):
        img = img.to(device)
        # Actual batch size: the loader's last batch may be smaller than batch_size.
        n = img.size(0)

        ########################################################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ########################################################
        optim_d.zero_grad()

        # Real batch. full_like builds a target with exactly D's output shape.
        output = discriminator(img)
        d_rl_loss = criterion(output, torch.full_like(output, real_label))

        # Fake batch. Latent shape: [batch size, channel #, side length, side length]
        latent_vec = torch.randn(n, 64, 32, 32, device=device)
        fake_img = generator(latent_vec)

        # detach(): the D update must not backpropagate into G, so no graph has to be
        # retained and no generator gradients are computed here.
        output = discriminator(fake_img.detach())
        d_fk_loss = criterion(output, torch.full_like(output, fake_label))

        # D's error is the sum over the real and fake batches
        d_err = d_rl_loss + d_fk_loss
        d_err.backward()
        optim_d.step()

        ########################################################
        # (2) Update G network: maximize log(D(G(z)))
        ########################################################
        optim_g.zero_grad()

        # Re-run the just-updated D on the non-detached fakes so gradients reach G.
        output = discriminator(fake_img)
        # G is trained to make D output real_label on its fakes.
        g_loss = criterion(output, torch.full_like(output, real_label))
        g_loss.backward()
        optim_g.step()

        # Save losses for plotting later
        d_losses.append(d_err.item())
        g_losses.append(g_loss.item())

        # TODO: periodically run the generator on fixed_noise (under torch.no_grad())
        # to check how it is doing.

  0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# generate images from the model