# GAN usage in DEGANN

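This notebook demonstrates how to create, train, and evaluate a GAN with DEGANN. The GAN's generator learns to map inputs $x \in [0, 1]$ to $y = \sin(10x)$, and at the end its outputs are compared against the true function.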

### Import necessary libraries:

In [None]:
import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from sklearn.model_selection import train_test_split
from degann.networks import IModel

from degann.networks.topology.densenet.topology_config import DenseNetParams
from degann.networks.topology.densenet.compile_config import DenseNetCompileParams
from degann.networks.topology.gan.topology_config import GANTopologyParams
from degann.networks.topology.gan.compile_config import GANCompileParams

### Prepare data for neural network training:

In [None]:
# Target function the GAN has to approximate: y = sin(10x)
def sin10x(x):
    """Return sin(10 * x).

    Applied element-wise when ``x`` is a NumPy array; also accepts plain scalars.
    """
    scaled_x = 10 * x
    return np.sin(scaled_x)


# Generate synthetic dataset
data_size = 2048
# Create input data with 20% extra to account for train/test split
X = np.linspace(0, 1, int(data_size / 0.8)).reshape(-1, 1)  # Reshape for Keras compatibility
y = sin10x(X)

# Split data into training and testing sets (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2
)

### Define and instantiate the GAN architecture:

In [None]:
# Both networks use the same hidden-layer layout:
# 3 hidden layers of 32 neurons each, with Leaky ReLU activation.
HIDDEN_LAYERS = [32, 32, 32]
ACTIVATION = "leaky_relu"

# Generator configuration: maps an input x to a candidate output y.
# In this notebook the generator is called on the data inputs X themselves
# (gan.train(X_train, ...), gan.feedforward(X_test)), not on random noise.
gen_config = DenseNetParams(
    input_size=1,  # one input feature: the x value
    block_size=list(HIDDEN_LAYERS),  # copy so the two configs don't share one list object
    output_size=1,  # one output value, same dimension as y
    activation_func=ACTIVATION,
)

# Discriminator configuration: distinguishes real vs generated data pairs
disc_config = DenseNetParams(
    input_size=2,  # concatenated [input, output] pair, i.e. (x, y)
    block_size=list(HIDDEN_LAYERS),
    output_size=1,  # single real/fake score
    activation_func=ACTIVATION,
)

# Combine components into the GAN architecture
gan_params = GANTopologyParams(
    generator_params=gen_config,
    discriminator_params=disc_config,
)
gan = IModel(gan_params)  # Instantiate GAN model

### Compile GAN:

In [None]:
# Shared optimizer settings for both networks. A learning rate of 2e-4 with Adam
# is a commonly used starting point for GAN training.
LEARNING_RATE = 0.0002
OPTIMIZER = "Adam"

# Generator: trained with binary cross-entropy on the discriminator's verdict,
# i.e. it is rewarded for fooling the discriminator.
gen_compile_config = DenseNetCompileParams(
    rate=LEARNING_RATE,
    optimizer=OPTIMIZER,
    loss_func="BinaryCrossentropy",
    metric_funcs=["mean_absolute_error"],  # tracks how close outputs are to the real y
)

# Discriminator: standard binary real/fake classification.
disc_compile_config = DenseNetCompileParams(
    rate=LEARNING_RATE,
    optimizer=OPTIMIZER,
    loss_func="BinaryCrossentropy",
    metric_funcs=["binary_accuracy"],  # tracks classification performance
)

# Combine both compilation configurations and compile the GAN
gan_compile_config = GANCompileParams(
    generator_params=gen_compile_config,
    discriminator_params=disc_compile_config,
)
gan.compile(gan_compile_config)

### Train the GAN and log metrics to TensorBoard:

In [None]:
# Training hyperparameters
EPOCHS = 1500
BATCH_SIZE = 64

# Each run logs into its own timestamped subdirectory, so re-running the notebook
# does not mix event files from different runs into one TensorBoard run.
log_dir = f"./gan_usage_log/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    embeddings_freq=0,
    update_freq="epoch",  # log metrics after each epoch
)

# Train the GAN. The return value is kept instead of being echoed as the cell
# output (presumably a Keras History object; verify against DEGANN's IModel.train).
train_history = gan.train(
    X_train,
    y_train,
    epochs=EPOCHS,
    mini_batch_size=BATCH_SIZE,
    callbacks=[tensorboard_callback],  # enable TensorBoard logging
)

### Evaluate the GAN on test data and visualize the results:

In [None]:
# Evaluate model performance on unseen data. The result is stored and printed,
# because a bare call in the middle of a cell would silently discard it
# (presumably loss/metric values; format depends on DEGANN's IModel.evaluate).
test_metrics = gan.evaluate(
    X_test,
    y_test,
    batch_size=64,
    callbacks=[tensorboard_callback],
)
print(f"Test evaluation: {test_metrics}")

# Run the generator once on the test inputs and convert to a NumPy array for plotting
y_generated = np.asarray(gan.feedforward(X_test))

# Visualize results
fig, ax = plt.subplots(figsize=(10, 6))

# Real vs generated outputs at the same test inputs
ax.scatter(X_test, y_test, c="blue", label="Real Data", alpha=0.5)
ax.scatter(X_test, y_generated, c="red", label="Generated Data", alpha=0.5)

# Dense reference curve of the ideal target function
x_ref = np.linspace(0, 1, 100)
ax.plot(x_ref, sin10x(x_ref), c="green", linestyle="--", label="Ideal: y = sin(10x)")

ax.set_xlabel("X")
ax.set_ylabel("y")
ax.set_title("Data Distribution Comparison")
ax.legend()
ax.grid(True)
plt.show()