In [1]:
import torch
from torch import nn

import math
import matplotlib.pyplot as plt

In [2]:
# Seed torch's global RNG so data generation, weight init, dropout and latent
# sampling are reproducible across Restart & Run All.
torch.random.manual_seed(1)

<torch._C.Generator at 0x1871c9fbf70>

In [3]:
# Real training data: points (x, sin x) with x drawn uniformly from [0, 2*pi).
train_data_length = 1024
angles = 2 * math.pi * torch.rand(train_data_length)
train_data = torch.stack((angles, torch.sin(angles)), dim=1)
# Placeholder labels (all zero); the training loop builds its own targets.
train_labels = torch.zeros(train_data_length)
# List of (sample, label) pairs consumed by the DataLoader.
train_set = list(zip(train_data, train_labels))

In [4]:
# The dataset labels are placeholders: the training loop discards them
# (``for ..., (real_samples, _) in ...``) and builds its own real/fake targets.
# Pin that invariant instead of only eyeballing the printed tensor.
assert train_labels.shape == (train_data_length,)
assert not train_labels.any()
train_labels

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [None]:
# Scatter of the real training distribution; the generator should learn to
# reproduce this curve.
fig, ax = plt.subplots()
ax.plot(train_data[:, 0], train_data[:, 1], ".")
ax.set(title="Training data: (x, sin x)", xlabel="x", ylabel="sin(x)");

[<matplotlib.lines.Line2D at 0x18721a1b680>]

In [4]:
# Mini-batches of (sample, label) pairs, reshuffled every epoch.
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True)

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output

In [6]:
# Single discriminator instance; the optimizer cell binds to its parameters.
discriminator = Discriminator()

In [7]:
class Generator(nn.Module):
    """Small MLP mapping a 2-d latent vector to a 2-d sample.

    Layers: Linear(2, 16) -> ReLU -> Linear(16, 32) -> ReLU -> Linear(32, 2),
    with no output activation.
    """

    def __init__(self):
        super().__init__()
        hidden = [nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU()]
        self.model = nn.Sequential(*hidden, nn.Linear(32, 2))

    def forward(self, x):
        """Map each latent row of ``x`` to a generated 2-d point."""
        return self.model(x)


generator = Generator()

In [8]:
# Shared training hyper-parameters.
lr, num_epochs = 1e-3, 300
# Binary cross-entropy, applied to the discriminator's sigmoid scores.
loss_function = nn.BCELoss()

In [9]:
# One Adam optimizer per network, both using the shared learning rate.
optimizer_discriminator, optimizer_generator = (
    torch.optim.Adam(network.parameters(), lr=lr)
    for network in (discriminator, generator)
)

In [10]:
# Alternating GAN training: per mini-batch, one discriminator update followed
# by one generator update.
log_every = 10  # print losses every `log_every` epochs
last_batch_idx = len(train_loader) - 1  # number of batches per epoch, minus one

for epoch in range(num_epochs):
    for batch_idx, (real_samples, _) in enumerate(train_loader):
        # Size everything from the actual batch rather than the nominal
        # `batch_size`: the final batch is shorter whenever the dataset length
        # is not a multiple of `batch_size`.
        current_batch_size = real_samples.size(0)

        # Data for training the discriminator: real samples are labelled 1,
        # generated samples 0.
        real_samples_labels = torch.ones((current_batch_size, 1))
        latent_space_samples = torch.randn((current_batch_size, 2))
        # detach(): the discriminator step must not backpropagate into the
        # generator (those gradients were previously computed and discarded).
        generated_samples = generator(latent_space_samples).detach()
        generated_samples_labels = torch.zeros((current_batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        optimizer_discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels
        )
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator (drawn after the discriminator step,
        # preserving the original RNG consumption order).
        latent_space_samples = torch.randn((current_batch_size, 2))

        # Training the generator: target the "real" label so the generator is
        # rewarded when the discriminator scores its samples as real.
        optimizer_generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss on the last batch of every `log_every`-th epoch. (The old
        # check `n == batch_size - 1` only matched the last batch because
        # 1024 / 32 happens to equal 32.)
        if epoch % log_every == 0 and batch_idx == last_batch_idx:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")

Epoch: 0 Loss D.: 0.21308739483356476
Epoch: 0 Loss G.: 1.7989581823349
Epoch: 10 Loss D.: 0.6099681258201599
Epoch: 10 Loss G.: 0.907738447189331
Epoch: 20 Loss D.: 0.6200631856918335
Epoch: 20 Loss G.: 0.9532092809677124
Epoch: 30 Loss D.: 0.6277890205383301
Epoch: 30 Loss G.: 0.9933774471282959
Epoch: 40 Loss D.: 0.639619767665863
Epoch: 40 Loss G.: 0.8615650534629822
Epoch: 50 Loss D.: 0.6653093099594116
Epoch: 50 Loss G.: 0.8684064745903015
Epoch: 60 Loss D.: 0.6464164853096008
Epoch: 60 Loss G.: 0.9181209206581116
Epoch: 70 Loss D.: 0.6513052582740784
Epoch: 70 Loss G.: 0.8436679244041443
Epoch: 80 Loss D.: 0.6663874387741089
Epoch: 80 Loss G.: 0.665390133857727
Epoch: 90 Loss D.: 0.6980440616607666
Epoch: 90 Loss G.: 0.6386311650276184
Epoch: 100 Loss D.: 0.676981508731842
Epoch: 100 Loss G.: 0.6501144170761108
Epoch: 110 Loss D.: 0.6719574332237244
Epoch: 110 Loss G.: 0.7596195340156555
Epoch: 120 Loss D.: 0.6511386036872864
Epoch: 120 Loss G.: 0.7798190712928772
Epoch: 130 Los

In [11]:
# Draw fresh latent vectors and map them through the trained generator.
# no_grad(): this is inference only, so skip building an autograd graph.
num_generated = 100
latent_space_samples = torch.randn(num_generated, 2)
with torch.no_grad():
    generated_samples = generator(latent_space_samples)

In [None]:
# Plot generated points; after training they should trace the (x, sin x)
# curve of the real data.
generated_samples = generated_samples.detach()
fig, ax = plt.subplots()
ax.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
ax.set(title="Generator output", xlabel="x", ylabel="y (target: sin x)");

[<matplotlib.lines.Line2D at 0x1bd526a0b30>]