In [None]:
import matplotlib
matplotlib.use('Agg')  # Non-interactive, no GUI
import matplotlib.pyplot as plt

# Save the plot instead of showing it
import torch
data = torch.tensor([
    [1.0, 2.0],
    [2.0, 3.0],
    [3.0, 2.5],
    [4.0, 4.0]
])

plt.plot(data[:, 0], data[:, 1])
plt.title("Sanity Check")
plt.savefig("sanity_plot.png")

In [None]:
import torch
import matplotlib.pyplot as plt

# Sanity test: small 2D tensor
data = torch.tensor([
    [1.0, 2.0],
    [2.0, 3.0],
    [3.0, 2.5],
    [4.0, 4.0]
])

# Simple plot
plt.plot(data[:, 0], data[:, 1])
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Sanity Check Plot")
plt.grid(True)
plt.show()

In [2]:
import torch
from torch import nn
import math
import matplotlib.pyplot as plt

In [3]:
torch.manual_seed(1)

<torch._C.Generator at 0x1f9458a02b0>

In [4]:
train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [
    (train_data[i], train_labels[i]) for i in range(train_data_length)
]
train_data.shape

torch.Size([1024, 2])

In [4]:
train_labels

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [5]:
print(train_data.shape)
print(train_data.device)
print(train_data)

torch.Size([1024, 2])
cpu
tensor([[ 4.7603, -0.9989],
        [ 1.7550,  0.9831],
        [ 2.5326,  0.5721],
        ...,
        [ 3.0366,  0.1048],
        [ 2.1415,  0.8415],
        [ 1.1515,  0.9134]])


In [None]:
plt.scatter(train_data[:, 0], train_data[:, 1], s=5)
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Scatter Plot of 2D Points")
plt.show()

In [None]:
plt.plot(train_data[:, 0], train_data[:, 1], ".")
plt.show() 

In [6]:
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.model(x)
        return output

In [8]:
discriminator = Discriminator()

In [9]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        output = self.model(x)
        return output

generator = Generator()

In [10]:
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()

In [11]:
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

In [12]:
for epoch in range(num_epochs):
    for n, (real_samples, _) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples_labels = torch.ones((batch_size, 1))
        latent_space_samples = torch.randn((batch_size, 2))
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1))
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 2))

        # Training the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss
        if epoch % 10 == 0 and n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")

Epoch: 0 Loss D.: 0.24032321572303772
Epoch: 0 Loss G.: 2.227731466293335
Epoch: 10 Loss D.: 0.6768372058868408
Epoch: 10 Loss G.: 0.8582908511161804
Epoch: 20 Loss D.: 0.6395218372344971
Epoch: 20 Loss G.: 1.1333469152450562
Epoch: 30 Loss D.: 0.7259066104888916
Epoch: 30 Loss G.: 0.7502138614654541
Epoch: 40 Loss D.: 0.698050856590271
Epoch: 40 Loss G.: 0.6755833029747009
Epoch: 50 Loss D.: 0.6791377663612366
Epoch: 50 Loss G.: 0.7212958335876465
Epoch: 60 Loss D.: 0.6789308190345764
Epoch: 60 Loss G.: 0.7021872997283936
Epoch: 70 Loss D.: 0.6948771476745605
Epoch: 70 Loss G.: 0.6239566802978516
Epoch: 80 Loss D.: 0.6906977891921997
Epoch: 80 Loss G.: 0.6636168956756592
Epoch: 90 Loss D.: 0.66363525390625
Epoch: 90 Loss G.: 0.7764437198638916
Epoch: 100 Loss D.: 0.7190923690795898
Epoch: 100 Loss G.: 0.6860860586166382
Epoch: 110 Loss D.: 0.6613975763320923
Epoch: 110 Loss G.: 0.7114980220794678
Epoch: 120 Loss D.: 0.6866812705993652
Epoch: 120 Loss G.: 0.7393763661384583
Epoch: 130 

In [13]:
latent_space_samples = torch.randn(100, 2)
generated_samples = generator(latent_space_samples)

In [14]:
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")

[<matplotlib.lines.Line2D at 0x1eecf037d40>]

In [15]:
generated_samples

tensor([[ 1.0624,  0.9014],
        [ 3.7070, -0.5651],
        [ 1.7968,  1.0005],
        [ 1.5020,  1.0356],
        [ 2.8605,  0.2287],
        [ 0.2720,  0.3065],
        [ 5.8030,  0.0683],
        [ 0.6268,  0.6052],
        [ 3.4981, -0.3757],
        [ 3.8958, -0.7212],
        [ 0.6414,  0.6239],
        [ 4.8305, -0.9712],
        [ 2.3808,  0.6713],
        [ 5.7797, -0.2467],
        [ 3.9632, -0.7576],
        [ 2.3912,  0.6731],
        [ 5.6864, -0.4606],
        [ 5.6860, -0.3285],
        [ 1.0040,  0.8699],
        [ 3.7615, -0.6254],
        [ 4.8772, -0.9591],
        [ 3.5233, -0.3922],
        [ 5.3365, -0.6542],
        [ 1.6349,  1.0278],
        [ 1.8248,  1.0044],
        [ 4.7032, -0.9912],
        [ 2.3623,  0.6933],
        [ 2.0088,  0.9257],
        [ 1.3600,  1.0099],
        [ 0.5252,  0.5317],
        [ 4.6842, -1.0139],
        [ 2.7776,  0.3186],
        [ 0.9989,  0.8727],
        [ 2.6012,  0.5004],
        [ 3.4045, -0.2722],
        [ 3.8580, -0