# In[ ]:
import torch
from gan_model import GAN

# Feature widths shared by both networks. Each is defined once so the generator
# and discriminator configs cannot silently drift apart: the discriminator's
# `in_features` mirrors the generator's, and its `in_generated` presumably has
# to equal the generator's `out_features` (the width of the generated sample)
# — TODO confirm against gan_model.
IN_FEATURES = 50
OUT_FEATURES = 100

# Define generator parameters
gen_kwargs = {
    'noise_dim': 100,
    'in_features': IN_FEATURES,
    'cond_dim': 10,  # Only needed for Transformer generator
    'out_features': OUT_FEATURES,
    'd_model': 256,  # Only needed for Transformer generator
    'nhead': 4,  # Only needed for Transformer generator
    'num_encoder_layers': 3,  # Only needed for Transformer generator
    'num_decoder_layers': 3,  # Only needed for Transformer generator
    'layer_sizes': [128, 256, 512],
    'activation': 'relu',
    'batchnorm': True,
    'dropout_p': 0.2,
    'nonnegative_end_ind': 50,
    'use_skip_connections': False,
    'add_vector': None,  # Only needed if you want to add a specific vector to the output
}

# Define discriminator parameters
disc_kwargs = {
    'in_features': IN_FEATURES,
    'in_generated': OUT_FEATURES,
    'd_model': 256,  # Only needed for Transformer discriminator
    'nhead': 4,  # Only needed for Transformer discriminator
    'num_encoder_layers': 3,  # Only needed for Transformer discriminator
    'num_decoder_layers': 3,  # Only needed for Transformer discriminator
    # A separate list literal (not shared with gen_kwargs) so mutating one
    # config's layer sizes never affects the other.
    'layer_sizes': [128, 256, 512],
    'activation': 'relu',
    'batchnorm': True,
    'dropout_p': 0.2,
    'use_skip_connections': False,
}

# Build the GAN from the two config dicts above. Both `gen_arch` and
# `disc_arch` accept 'MLP' or 'Transformer'; here an MLP generator is paired
# with a Transformer discriminator.
gan = GAN(
    gen_kwargs=gen_kwargs,
    disc_kwargs=disc_kwargs,
    gen_arch='MLP',
    disc_arch='Transformer',
)

# Dummy inputs for smoke-testing the forward pass and the training loop below.
batch_size = 16
noise = gan.generator.get_noise(batch_size)
# NOTE(review): real_data is sized by gen_kwargs['in_features'] (50), not
# gen_kwargs['out_features'] (100); confirm which width gan() expects here.
real_data = torch.randn(batch_size, gen_kwargs['in_features'])
cond_data = torch.randn(batch_size, gen_kwargs['cond_dim'])  # Only needed for Transformer generator

# Sanity-check forward pass through the GAN. The outputs are only inspected
# (the training loop recomputes them), so run without autograd to avoid
# building a computation graph that is never backpropagated.
with torch.no_grad():
    fake_data, disc_real, disc_fake = gan(noise, real_data, cond_data)

# Separate Adam optimizers so each step updates only its own network.
optimizer_G = torch.optim.Adam(gan.generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(gan.discriminator.parameters(), lr=0.0002)

# Binary cross-entropy expresses the same objective as the hand-written
# -log(p) / -log(1 - p) terms, but clamps the log at -100. A saturated
# discriminator output (exactly 0 or 1) therefore gives a large finite loss
# instead of inf/NaN.
# NOTE(review): like the original log(1 - disc_fake) term, this assumes the
# discriminator outputs probabilities in [0, 1] (e.g. a sigmoid head), and
# that disc_real / disc_fake share a shape; binary_cross_entropy raises on
# values outside [0, 1] — confirm.
bce = torch.nn.functional.binary_cross_entropy

for epoch in range(1):
    # Draw a fresh latent batch each epoch so the generator is not fitted to a
    # single fixed set of noise vectors.
    noise = gan.generator.get_noise(batch_size)

    # Train Discriminator: push D(real) toward 1 and D(fake) toward 0.
    optimizer_D.zero_grad()
    _, disc_real, disc_fake = gan(noise, real_data, cond_data)
    loss_D = (
        bce(disc_real, torch.ones_like(disc_real))
        + bce(disc_fake, torch.zeros_like(disc_fake))
    )
    # gan() also runs the generator, so unless it detaches the fake batch
    # internally, this backward() fills generator gradients as well. Only
    # optimizer_D steps here, and optimizer_G.zero_grad() below discards them.
    loss_D.backward()
    optimizer_D.step()

    # Train Generator: push D(fake) toward 1 (the non-saturating -log D(G(z)) loss).
    optimizer_G.zero_grad()
    fake_data, _, disc_fake = gan(noise, real_data, cond_data)
    loss_G = bce(disc_fake, torch.ones_like(disc_fake))
    # Gradients also reach the discriminator here. optimizer_D.zero_grad() at
    # the start of the next discriminator step clears them before they are used.
    loss_G.backward()
    optimizer_G.step()

    print(f"Epoch {epoch}: Loss D: {loss_D.item()}, Loss G: {loss_G.item()}")