In [1]:
import torch
from torch import nn

import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

In [2]:
torch.manual_seed(1)

<torch._C.Generator at 0x1ec34c03f30>

In [3]:
train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [
    (train_data[i], train_labels[i]) for i in range(train_data_length)
]

In [4]:
train_labels

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [5]:
plt.plot(train_data[:, 0], train_data[:, 1], ".")

[<matplotlib.lines.Line2D at 0x1ec39b7ce90>]

In [6]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output

In [8]:
discriminator = Discriminator()

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        output = self.model(x)
        return output

generator = Generator()

In [10]:
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()

In [11]:
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

In [12]:
for epoch in range(num_epochs):
    for n, (real_samples, _) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples_labels = torch.ones((batch_size, 1))
        latent_space_samples = torch.randn((batch_size, 2))
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 2))

        # Training the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss
        if epoch % 10 == 0 and n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")

Epoch: 0 Loss D.: 0.24032321572303772
Epoch: 0 Loss G.: 2.227731704711914
Epoch: 10 Loss D.: 0.6768372058868408
Epoch: 10 Loss G.: 0.8582908511161804
Epoch: 20 Loss D.: 0.6395218372344971
Epoch: 20 Loss G.: 1.1333470344543457
Epoch: 30 Loss D.: 0.7259066700935364
Epoch: 30 Loss G.: 0.7502137422561646
Epoch: 40 Loss D.: 0.6980507969856262
Epoch: 40 Loss G.: 0.6755833625793457
Epoch: 50 Loss D.: 0.6807307004928589
Epoch: 50 Loss G.: 0.7314662933349609
Epoch: 60 Loss D.: 0.6899004578590393
Epoch: 60 Loss G.: 0.6731136441230774
Epoch: 70 Loss D.: 0.7024956941604614
Epoch: 70 Loss G.: 0.704626202583313
Epoch: 80 Loss D.: 0.6843791007995605
Epoch: 80 Loss G.: 0.6842287182807922
Epoch: 90 Loss D.: 0.6455491185188293
Epoch: 90 Loss G.: 0.803554892539978
Epoch: 100 Loss D.: 0.6961420178413391
Epoch: 100 Loss G.: 0.6905702948570251
Epoch: 110 Loss D.: 0.690811276435852
Epoch: 110 Loss G.: 0.6690704226493835
Epoch: 120 Loss D.: 0.6924923658370972
Epoch: 120 Loss G.: 0.6940162181854248
Epoch: 130 

In [13]:
latent_space_samples = torch.randn(100, 2)
generated_samples = generator(latent_space_samples)

In [14]:
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")

[<matplotlib.lines.Line2D at 0x1ec3e109520>]

In [15]:
generated_samples

tensor([[ 1.0972,  0.9353],
        [ 3.7861, -0.5362],
        [ 1.9094,  1.0108],
        [ 1.7305,  1.0169],
        [ 2.8937,  0.3262],
        [ 0.3284,  0.3809],
        [ 7.2420, -0.0284],
        [ 0.6798,  0.6673],
        [ 3.5401, -0.3008],
        [ 3.8419, -0.5886],
        [ 0.7306,  0.7132],
        [ 4.9814, -1.0209],
        [ 2.4591,  0.6952],
        [ 6.4962, -0.1363],
        [ 4.0299, -0.7586],
        [ 2.6406,  0.5625],
        [ 5.7852, -0.5178],
        [ 6.3938, -0.2606],
        [ 1.1104,  0.9396],
        [ 3.7113, -0.4690],
        [ 4.9523, -0.9905],
        [ 3.6367, -0.3913],
        [ 5.6782, -0.6718],
        [ 1.8010,  1.0253],
        [ 2.0489,  0.9510],
        [ 4.9745, -1.0489],
        [ 2.5648,  0.6199],
        [ 2.1575,  0.8968],
        [ 1.4427,  1.0346],
        [ 0.6222,  0.6295],
        [ 4.7314, -0.9987],
        [ 2.9474,  0.2649],
        [ 1.0631,  0.9216],
        [ 2.8311,  0.3862],
        [ 3.6361, -0.3876],
        [ 3.9082, -0