In [None]:
import torch
from torch import optim, nn, utils, Tensor
import lightning as L

import numpy as np
import matplotlib.pyplot as plt

from Network import Generator, Discriminator
from Signal_Generator import *
from Signal_Analyzer import *

In [None]:
dataset = []

# Assemble ten (signal, parameters) training pairs, each drawn from a freshly
# constructed Signal_Generator with one source and unit noise amplitude.
for i in range(10):
    SG = Signal_Generator(num_sources=1, noise_amplitude=1)
    signals = SG.generating_signal()
    params = SG.printing_parameters()
    signal = signals['Signal'].values

    # Indexing with None prepends singleton axes: two in front of the signal,
    # one in front of the parameters.
    signal_tensor = torch.tensor(signal, dtype=torch.float32)[None, None]
    params_tensor = torch.tensor(params, dtype=torch.float32)[None]

    training_pair = (signal_tensor, params_tensor)
    dataset.append(training_pair)

# Latent-noise sample with a single entry along the first and last axes.
num_latent_variables = 10
z = torch.randn(1, num_latent_variables, 1)

In [None]:
class GAN(L.LightningModule):
    """Adversarial pair of a parameter-predicting generator and a discriminator.

    The generator maps ``(signal, z)`` to a parameter tensor; the discriminator
    scores ``(signal, parameters, z)`` triples. Both networks are optimised
    manually inside ``training_step`` (automatic optimisation is disabled).

    Args:
        dataset: Indexable collection of ``(signal_tensor, params_tensor)``
            pairs. It is served by ``train_dataloader`` and, unless the sizes
            below are given, its first item is used to size the networks.
        num_latent_variables: Size of the noise tensor's second axis.
        lr: Learning rate shared by both Adam optimisers.
        signal_length: Value passed as ``length`` to both networks. When
            ``None`` it is taken from the last axis of the first signal tensor.
        num_parameters: Value passed as ``num_parameters`` to both networks.
            When ``None`` it is taken from the last axis of the first
            parameter tensor.
        **kwargs: Accepted and recorded as hyperparameters; otherwise unused.

    Raises:
        ValueError: If a size has to be inferred but ``dataset`` is empty.
    """

    def __init__(
        self,
        dataset,
        num_latent_variables: int = 10,
        lr: float = 0.0001,
        signal_length=None,
        num_parameters=None,
        **kwargs,
    ):
        super().__init__()
        # The dataset is data, not a hyperparameter: excluding it keeps the
        # tensors out of `self.hparams`, checkpoints and the logger.
        self.save_hyperparameters(ignore=["dataset"])
        self.automatic_optimization = False
        self.dataset = dataset

        # Size the networks from the data handed to this instance instead of
        # from whatever module-level variables happen to exist in the kernel.
        if signal_length is None or num_parameters is None:
            try:
                sample_signal, sample_params = dataset[0]
            except IndexError as err:
                raise ValueError(
                    "dataset is empty: pass signal_length and num_parameters explicitly"
                ) from err
            # Assumes the last axis of each tensor holds the signal samples /
            # the parameters -- confirm if the dataset layout changes.
            if signal_length is None:
                signal_length = sample_signal.shape[-1]
            if num_parameters is None:
                num_parameters = sample_params.shape[-1]

        # networks
        self.generator = Generator(
            in_channels=1,
            num_latent_variables=num_latent_variables,
            length=signal_length,
            num_parameters=num_parameters,
        )
        self.discriminator = Discriminator(
            input_channels=1,
            length=signal_length,
            num_parameters=num_parameters,
        )
        # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably the
        # discriminator ends in a sigmoid -- confirm in Network.py.
        self.criterion = nn.BCELoss()

    def forward(self, signal_tensor, z):
        """Run the generator on a signal tensor and a noise tensor."""
        return self.generator(signal_tensor, z)

    def adversarial_loss(self, output_d, y):
        """Binary cross-entropy between discriminator output and target labels."""
        return self.criterion(output_d, y)

    def training_step(self, batch, batch_idx):
        """Perform one generator update followed by one discriminator update.

        Args:
            batch: ``(signal_tensor, params_tensor)`` pair from the dataloader.
            batch_idx: Index of the batch within the epoch (unused).
        """
        signal_tensor, params_tensor = batch

        # One noise tensor per batch element, matched to the batch's dtype/device.
        z = torch.randn(signal_tensor.size(0), self.hparams.num_latent_variables, 1).type_as(signal_tensor)

        optimizer_g, optimizer_d = self.optimizers()

        # Train Generator: reward parameters the discriminator labels as real.
        generated_params = self(signal_tensor, z)
        # Expose the latest prediction without keeping its autograd graph alive.
        self.generated_params = generated_params.detach()
        fake_output = self.discriminator(signal_tensor, generated_params, z)
        g_loss = self.adversarial_loss(fake_output, torch.ones_like(fake_output))
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()

        # Train Discriminator
        # Back-propagating g_loss also left gradients on the discriminator's
        # weights; drop them so this update only reflects d_loss.
        optimizer_d.zero_grad()

        # Re-run the (just updated) generator; no graph is needed because the
        # discriminator step must not reach the generator's weights.
        with torch.no_grad():
            fake_params = self.generator(signal_tensor, z)
        real_output = self.discriminator(signal_tensor, params_tensor, z)
        fake_output = self.discriminator(signal_tensor, fake_params, z)

        real_loss = self.adversarial_loss(real_output, torch.ones_like(real_output))
        fake_loss = self.adversarial_loss(fake_output, torch.zeros_like(fake_output))
        d_loss = (real_loss + fake_loss) / 2
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        optimizer_d.step()
        optimizer_d.zero_grad()

    def configure_optimizers(self):
        """Return the Adam optimisers in the order (generator, discriminator)."""
        optimizer_g = optim.Adam(self.generator.parameters(), lr=self.hparams.lr)
        optimizer_d = optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr)
        return [optimizer_g, optimizer_d]

    def train_dataloader(self):
        """Serve the stored dataset one shuffled item at a time."""
        return utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)

In [None]:
model = GAN(dataset=dataset)

# Train on the GPU when CUDA is available, otherwise on the CPU. `devices` is
# a device count and must stay positive for either accelerator; passing 0 on
# the CPU path makes the Trainer refuse to start.
trainer = L.Trainer(
    max_epochs=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
trainer.fit(model)