This workshop is adapted from an original resource available [here](https://realpython.com/generative-adversarial-networks/). Many thanks to the original authors!

# Step 1: Generate Training Data

GANs are amazing. They can produce stunning art, realistic photographs, and 3D objects. Generally, GANs are used for visual tasks like generating images, but in principle they can be used for much more.

Fundamentally, GANs just look at example data and try to generate new data that looks the same. So when you have a set of example data points, you can use a GAN to create more.

---

Training a GAN takes a long time, so we won't be able to train an image-generating GAN during this workshop. (There is, however, some code at the bottom of this notebook that trains the computer to draw the digits 0-9. It takes about half an hour to start getting interesting results, so it's an excellent thing to try on your own time.)

Instead, our simple GAN will generate points in 2D space that fall on a sine wave. Essentially, this means that our GAN is going to "learn" the sine function. By the time we're done, you will be able to feed it random input points and it will spit out points that are on the sine curve.

Is it the most riveting example ever? Perhaps not. But it demonstrates all the basic tools you need to create GANs that solve more complex problems; the only additional ingredient you need on top of what we're doing today is more training time.

---

Before doing anything else, let's import the libraries we need so that we have them from now on. Run the following code without changing anything:

In [None]:
import torch
from torch import nn

import math
import matplotlib.pyplot as plt

So how do we generate our training data? We want to create a big list of 2D points $(x, y)$ which are all on the sine wave. Let's start by making a list of points that are all $(0, 0)$ and go from there.

The following code generates a list of four zero points. Can you change it to generate a list of 1024 points?

In [None]:
# CHANGE ME! I need to generate 1024 points, not 4.

train_data_length = 4
train_data = torch.zeros((train_data_length, 2))

print("Training data:")
print(train_data)

print("Number of points:", len(train_data))

Good start. But of course, the points shouldn't all be $(0, 0)$. We want our points to be randomly distributed between $x = 0$ and $x = 2\pi$.

The following code generates a list of 1024 different numbers between 0 and 1:

In [None]:
torch.rand(train_data_length)

And this code generates 1024 different numbers between $0$ and $\pi$:

In [None]:
math.pi * torch.rand(train_data_length)

Can you write code that generates 1024 different numbers between $0$ and $2\pi$?

In [None]:
# TODO: Generate 1024 numbers between 0 and 2*pi

Amazing! Now... remember `train_data`, our list of points that are all $(0, 0)$? Let's update that list by replacing all the x values with our new list of random numbers.

Replace the `# ???` in the following code with your generation code (from $0$ to $2\pi$) from above.

If all goes well, you should see that the x values have been randomized but the y values are still 0.

In [None]:
# Change the x values of the train_data points to be random
train_data[:, 0] = # ???

print(train_data)

Let's graph a scatterplot of the points to see if they look right:

In [None]:
plt.plot(train_data[:, 0], train_data[:, 1], ".")

Cool! You should see a horizontal line from $x = 0$ to $x = 2\pi \approx 6.28$.

Now we just need to set the y values. The following code will set the y values to be the cosine of each x value.

Can you change it to `sin`?

In [None]:
# This code sets the y values to cos(x). Can you change it to sin(x)?
train_data[:, 1] = torch.cos(train_data[:, 0])

print(train_data)

Check your results, and make sure it looks like a sine wave (not cosine):

In [None]:
plt.plot(train_data[:, 0], train_data[:, 1], ".")

Amazing! This will be our training data for the GAN. The GAN's job will be to take in random points as creative fuel, and spit out points that are on this sine wave. Hopefully it can learn this task!
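
If you would like a numeric check to go with the plot, here is a minimal sketch in plain Python. The helper name `max_sine_error` and the hand-made points are illustrative assumptions, not part of the workshop code.

```python
import math

def max_sine_error(points):
    """Return the largest |y - sin(x)| over a list of (x, y) pairs."""
    return max(abs(y - math.sin(x)) for x, y in points)

# Hand-made points for illustration: the first two sit on the sine wave,
# the extra third point sits on the cosine wave instead.
on_curve = [(0.0, 0.0), (math.pi / 2, 1.0)]
off_curve = on_curve + [(0.0, 1.0)]

print("on the sine wave: ", max_sine_error(on_curve))
print("with a cosine point:", max_sine_error(off_curve))
```

With the notebook's data, `max_sine_error(train_data.tolist())` should come out very close to zero if the y values really are `sin(x)`, and noticeably larger if they are still `cos(x)`.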

# Step 2: Train on Sine Wave

Every GAN is made up of two competing neural networks: The generator and the discriminator. To keep things (relatively) simple, we aren't going to go into great detail about how these networks work, and I'm not going to ask you to design them yourself.

Just know that the following code creates two models, `generator` and `discriminator`, each of which is a simple neural network to learn its respective task.

(You do not need to change this code at all; just run it.)

In [None]:
# Create the generator
# (Note that the definition inside __init__ begins with a 2 and ends with a 2,
# so the model takes a 2D point as input and gives a 2D point as output.)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 16), # Take 2D point as input...
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2), # ...and give 2D point as output
        )

    def forward(self, x):
        output = self.model(x)
        return output

generator = Generator()

In [None]:
# Create the discriminator
# (Note that the definition inside __init__ begins with a 2 and ends with a 1,
# so the model takes a 2D point as input and gives one number--the prediction for
# whether or not the point is real--as output.)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 256), # Take 2D point as input...
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1), # ...and give one number as output
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output

discriminator = Discriminator()

## A slow-motion step

During the presentation, we created this diagram of a single training step for a GAN:

<center>
  <img src="https://i.imgur.com/iud7C21.png" alt="GAN Diagram" width="500" />
</center>

Eventually, we want to write code that performs this entire process many, many times inside a `for` loop. But let's start by walking through each part of the process one bit at a time. (We'll put it all together at the end.)

---

To start, let's generate some "latent space" samples. In our diagram, these look like TV static:

<center>
  <img src="https://i.imgur.com/7pdjDyX.png" alt="GAN Diagram with latent space labels highlighted" width="400" />
</center>

They are essentially just random input values (which is why it looks like TV static for a GAN that generates images). In our case, a random input value is just a random point in 2D space (which is not necessarily on the sine wave). We can generate these using the `torch.randn()` function, which generates a tensor of random values (drawn from a standard normal distribution) with a given shape.
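
Note that `torch.randn()` is not the same as the `torch.rand()` we used in step 1: `torch.rand()` draws uniformly from between 0 and 1, while `torch.randn()` draws from a bell curve centered on 0. The sketch below mimics the difference using Python's built-in `random` module (an assumption made to keep the example free of PyTorch), so the variable names here are purely illustrative.

```python
import random

random.seed(0)  # fixed seed so the sketch is repeatable

n = 10_000
uniform = [random.random() for _ in range(n)]        # like torch.rand: values in [0, 1)
normal = [random.gauss(0.0, 1.0) for _ in range(n)]  # like torch.randn: mean 0, std 1

# Roughly half of the bell-curve samples should be negative
share_negative = sum(v < 0 for v in normal) / n

print("uniform min/max:", min(uniform), max(uniform))
print("normal  min/max:", min(normal), max(normal))
print("share of normal samples below 0:", share_negative)
```

This is why the latent space plot is a blob centered on $(0, 0)$ with points on both sides of each axis, rather than a square between 0 and 1.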

The following code generates a batch of `5` random `2`-dimensional points. Try changing the code to make `batch_size` be 32 instead, and check to see that 32 points appear.

In [None]:
# TODO: Change `batch_size` from 5 to 32
batch_size = 5

# Generate latent space samples
latent_space_samples = torch.randn((batch_size, 2))

# Plot the points
plt.plot(latent_space_samples[:, 0], latent_space_samples[:, 1], ".")

Great! Now let's take our `latent_space_samples` and pass them through the `generator` to get the `generated_samples`.

<center>
  <img src="https://i.imgur.com/T28WLV2.png" alt="GAN Diagram with generator process highlighted" width="400" />
</center>

The following code does just that:

In [None]:
# Use the `generator` to get `generated_samples` from `latent_space_samples`
generated_samples = generator(latent_space_samples)

# Plot the results
plt.plot(generated_samples.detach()[:, 0], generated_samples.detach()[:, 1], ".")

Right now, the points being produced just look like new random garbage. But as the generator trains and improves, it will start to produce points that actually land on the sine wave.

---

Next, we want to grab a batch of real samples from the original dataset of sine wave points we created in step 1.

<center>
  <img src="https://i.imgur.com/r2OIQLt.png" alt="GAN Diagram with real sampling highlighted" width="400" />
</center>

The PyTorch library gives us a useful tool called a `DataLoader` that will do this for us. The following code creates a `DataLoader` called `train_loader` (because it loads the training examples). Then we can use a `for` loop to loop through each batch it gives us.

Edit the following code so that inside the `for` loop it prints out some useful information. (Your goal is to understand what `index` and `real_samples` are.)

In [None]:
# Create a DataLoader called `train_loader` that will give us batches of real samples
train_loader = torch.utils.data.DataLoader(
  train_data, batch_size=batch_size, shuffle=True
)

# Then we can iterate through the batches using a `for` loop:
for index, real_samples in enumerate(train_loader):
  # TODO: Print out some useful information. Try printing the index and real_samples
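
To see the same idea at a smaller scale, here is a rough sketch of the batching behavior in plain Python. It is only a stand-in written for illustration (the function `iterate_batches` and the 10-point `toy_data` are made up, and a real `DataLoader` does considerably more), but it shows the basic pattern of shuffling and then slicing into batches.

```python
import random

def iterate_batches(data, batch_size, shuffle=True):
    """Yield consecutive batches of `data`, optionally in shuffled order."""
    order = list(range(len(data)))
    if shuffle:
        random.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]

toy_data = [(float(i), 0.0) for i in range(10)]  # 10 made-up points

# `index` counts the batches; `batch` holds the points in that batch
for index, batch in enumerate(iterate_batches(toy_data, batch_size=4)):
    print(index, len(batch), batch)
```

Ten points with a batch size of 4 gives batches of 4, 4, and 2 points. Our real dataset has 1024 points, which splits evenly into 32 batches of 32.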

Now that we have the ability to get `generated_samples` and `real_samples`, it's time to put the `discriminator` to work:

<center>
  <img src="https://i.imgur.com/f2KA7AB.png" alt="GAN Diagram with discriminator highlighted" width="400" />
</center>

First, let's combine `real_samples` (a list of 32 points in 2D space) and `generated_samples` (another list of 32 points in 2D space) into one big list called `all_samples`. This will be what we give the discriminator as a test.

We'll use the `torch.cat()` concatenation function to combine the lists:

In [None]:
# Show that `real_samples` and `generated_samples` are both lists of 32 points in 2D space:
print("Shape of real samples:", real_samples.shape)
print("Shape of generated samples:", generated_samples.shape)

# Then merge the two into one big list of samples...
all_samples = torch.cat((real_samples, generated_samples))

# ...and check that the combined list has 64 points in 2D space:
print("Shape of ALL samples (combined):", all_samples.shape)

Now that we have our test questions, we can ask the discriminator to give us its report:

In [None]:
discriminator_result = discriminator(all_samples)

print(discriminator_result)

As you can see, the output is a big list of 64 numbers. Each number corresponds to one of the samples from `all_samples`, the "test" we gave the discriminator, and is the discriminator's estimate of the probability that the sample is real.

**What do you notice about the probabilities?**

Most likely, they will all be around 50%. That's because the discriminator hasn't learned anything yet, so it is basically just guessing. Over time, the discriminator will improve. (Although the generator will also get better at being tricky; it's a bit of an arms race.)
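
The 50% figure comes from the `nn.Sigmoid()` at the end of the discriminator, which squashes any raw number into the range 0 to 1. An untrained network typically produces raw values close to zero, and the sigmoid of zero is exactly 0.5. The sketch below works through that arithmetic in plain Python; the input values are made up for illustration.

```python
import math

def sigmoid(t):
    """Squash any real number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-t))

# Raw values near zero map to probabilities near 0.5;
# only large positive or negative values give confident answers.
for raw in (-4.0, -0.1, 0.0, 0.1, 4.0):
    print(f"raw value {raw:+.1f} -> probability {sigmoid(raw):.3f}")
```

As the discriminator trains, its raw values move further from zero, so its probabilities move toward 0 or 1.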

---

Finally, we need to build an "answer key" to the test, so that the discriminator and generator can both look at the results and improve themselves during backpropagation.

In machine learning lingo, this answer key is called the "labels" of the dataset, because we're labelling each of the samples with whether it is actually real or generated.

The following code should create `real_samples_labels`, which is a list of ones (because the label 1 means "real") and `generated_samples_labels` which is a list of zeros (because the label 0 means "generated"). Then they can be combined to give `all_samples_labels`, our answer key to the overall test.

You will need to modify the code below to do just that:

In [None]:
# Create `real_samples_labels`, a list of 32 ones:
real_samples_labels = torch.ones((batch_size, 1))

# TODO: Use the function torch.zeros(), which is just like torch.ones(),
# to create `generated_samples_labels`:
generated_samples_labels = # ???

# TODO: Combine `real_samples_labels` and `generated_samples_labels` into a single list
# called `all_samples_labels` using the function torch.cat() in the exact same way as
# we used it before to combine `real_samples` and `generated_samples`
all_samples_labels = # ???

print(all_samples_labels)

Amazing! Hopefully you see a list with a bunch of 1s and then a bunch of 0s. If so, you've made the answer key (`all_samples_labels`) correctly.
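
How does the answer key get used? The training code in the next section scores the discriminator with `nn.BCELoss()`, the binary cross-entropy loss. Here is a minimal sketch of that formula in plain Python, assuming the default behavior of averaging over all samples. The function name and the example predictions are made up for illustration.

```python
import math

def binary_cross_entropy(predictions, labels):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over all samples."""
    total = 0.0
    for p, y in zip(predictions, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)

labels = [1, 1, 0, 0]             # answer key: two real, two generated
guessing = [0.5, 0.5, 0.5, 0.5]   # an untrained discriminator
confident = [0.9, 0.8, 0.2, 0.1]  # mostly right
fooled = [0.1, 0.2, 0.8, 0.9]     # mostly wrong

for name, predictions in [("guessing", guessing), ("confident", confident), ("fooled", fooled)]:
    print(name, round(binary_cross_entropy(predictions, labels), 3))
```

The loss is small when the predictions agree with the answer key and large when they disagree, which is what lets the discriminator improve by pushing the loss down.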

## Actual Training
At this point, you've performed all the individual processes that go into one training step. Now it's just a matter of putting it all together.

The following code, when completed, will run 300 "epochs" of training (which means running the training process on the entire dataset 300 times).

Your job is to fill in the blanks, indicated by `# ???`, by copying the code you wrote above. Once you've done that, you should be able to run this entire massive code block and watch the generator and discriminator learn.

In [None]:
history_of_results = []

lr = 0.001 # learning rate
loss_function = nn.BCELoss() # binary cross-entropy loss function

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

for epoch in range(300):
  # This loop gets batches of `real_samples` from the training dataset
  for n, real_samples in enumerate(train_loader):
    # Get `generated_samples` by plugging latent space samples (TV static) into the generator
    latent_space_samples = # ???
    generated_samples = # ???

    # Combine `real_samples` and `generated_samples` into one list
    all_samples = # ???

    # Create labels (the "answer key") for the data
    real_samples_labels = # ???
    generated_samples_labels = # ???
    all_samples_labels = # ???

    # Train the discriminator
    discriminator.zero_grad()
    discriminator_result = # ???
    loss_discriminator = loss_function(discriminator_result, all_samples_labels)
    loss_discriminator.backward(retain_graph=True)
    optimizer_discriminator.step()

    # Train the generator
    generator.zero_grad()
    output_discriminator_generated = discriminator(generated_samples)
    # Use the "real" labels here: the generator improves by tricking the
    # discriminator into thinking these are real samples
    loss_generator = loss_function(
      output_discriminator_generated, real_samples_labels
    )
    loss_generator.backward()
    optimizer_generator.step()

  # Show loss once every 10 epochs
  if epoch % 10 == 0:
    print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
    print(f"Epoch: {epoch} Loss G.: {loss_generator}")

  # Save data to build a graph once every 50 epochs
  if epoch % 50 == 0:
    latent_space_samples = torch.randn(100, 2)
    generated_samples = generator(latent_space_samples)
    generated_samples = generated_samples.detach()
    history_of_results.append(generated_samples)
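
One detail in the loop is worth a second look: the generator's samples are scored against `real_samples_labels` (all ones), even though they are generated. The sketch below shows why, using the binary cross-entropy formula for a single prediction in plain Python. The helper name `bce` and the three example outputs are made up for illustration.

```python
import math

def bce(p, y):
    """Binary cross-entropy for one prediction p against one label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# Suppose the discriminator gives a generated sample this probability of being real
for d_out in (0.1, 0.5, 0.9):
    generator_loss = bce(d_out, 1)      # generator is scored against the label "real"
    discriminator_loss = bce(d_out, 0)  # discriminator is scored against "generated"
    print(f"D output {d_out}: generator loss {generator_loss:.3f}, "
          f"discriminator loss {discriminator_loss:.3f}")
```

The two losses pull in opposite directions on the same number: the generator's loss shrinks as the discriminator is fooled, while the discriminator's loss grows. That tug-of-war is the arms race mentioned earlier.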

Once the training is done, run the following code to visualize the generator output over the course of training (i.e. how it learned over time):

In [None]:
from matplotlib.pyplot import figure

f, axs = plt.subplots(1, len(history_of_results))

for i, generated_samples in enumerate(history_of_results):
  # Snapshot i was saved during epoch i * 50 (epochs 0, 50, ..., 250)
  axs[i].set_title("Epoch {}".format(i * 50))
  axs[i].plot(generated_samples[:, 0], generated_samples[:, 1], ".")

f.set_size_inches(24, 3)

Hopefully, you will see that in the beginning (epoch 0), the generator was really bad and basically just smeared the points out on the screen. But by the end, it should be arranging the points into a sine wave.

You can run this code to give the generator one final test, and make sure it arranges points correctly:

In [None]:
# Final test of generator
latent_space_samples = torch.randn(100, 2)
generated_samples = generator(latent_space_samples)
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")

Also, as a little curiosity, you can check out what the discriminator has learned. Hopefully it can successfully identify where the sine wave is and is not:

In [None]:
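# Final test of discriminator
discriminator.eval()  # turn off dropout so the predictions are not randomly perturbed
xs = torch.linspace(0, 2 * math.pi, steps=314)
ys = torch.linspace(-1, 1, steps=100)
x, y = torch.meshgrid(xs, ys, indexing='xy')  # x and y each have shape (100, 314)
z = discriminator(torch.stack((x.flatten(), y.flatten()), 1))
Z = z.reshape((100, 314))
discriminator.train()  # switch dropout back on in case you train some more

plt.imshow(Z.detach(), interpolation='bilinear', origin='lower', extent=[0,2*math.pi,-1,1])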

Usually it will be very good on the left and a little messy on the right. With more training time, this would slowly improve.

---

You have now successfully trained a GAN! If you want something fun to do during the workshop, you can try going back to step 1 and generating training data for some new curve (maybe cosine, or some fun polynomial). Then re-run the training process and see if the GAN can learn to produce points on that new curve.

---

And, if you have time to run code at home, you can try out the code below. It isn't commented, but it does essentially the exact same thing as the code above. The only difference is that it learns to produce handwritten digit images rather than sine wave points.

Here is a GIF showing what that result will look like:

<center>
  <img src="https://files.realpython.com/media/fig_gan_mnist.5d8784a85944.gif" alt="GAN digit training result over time" width="500" />
</center>

# Step 3 (Take-Home Bonus): ~Hand~written Digits

In [None]:
import torch
from torch import nn

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

In [None]:
device = ""
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [None]:
train_set = torchvision.datasets.MNIST(
    root=".", train=True, download=True, transform=transform
)

In [None]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)

In [None]:
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(real_samples[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(x.size(0), 784)
        output = self.model(x)
        return output

discriminator = Discriminator().to(device=device)

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        output = self.model(x)
        output = output.view(x.size(0), 1, 28, 28)
        return output

generator = Generator().to(device=device)

In [None]:
lr = 0.0001
num_epochs = 20
loss_function = nn.BCELoss()

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

This next code is the big training step. It will take a long time... (It ends after 20 epochs.)

In [None]:
for epoch in range(num_epochs):
    for n, (real_samples, mnist_labels) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones((batch_size, 1)).to(
            device=device
        )
        latent_space_samples = torch.randn((batch_size, 100)).to(
            device=device
        )
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1)).to(
            device=device
        )
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels
        )
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 100)).to(
            device=device
        )

        # Training the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss
        if n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")

In [None]:
torch.save(discriminator.state_dict(), "discriminator")
torch.save(generator.state_dict(), "generator")

In [None]:
latent_space_samples = torch.randn(batch_size, 100).to(device=device)
generated_samples = generator(latent_space_samples)

In [None]:
generated_samples = generated_samples.cpu().detach()
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(generated_samples[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])