# Updated MedGAN Training with 5-Fold Cross-Validation and 7 Runs

In [None]:

import numpy as np
from sklearn.model_selection import KFold
from data_loader import load_shuttle_data  # or load_nursery_data / load_letter_data
from medgan_model import Medgan
import tensorflow as tf


In [None]:

# Dataset location. The value below is a placeholder and must be edited
# before this cell is run.
csv_path = "path/to/your/data.csv"  # <-- Change this to your CSV file path

# The loader's result is unpacked into two parts; only the unpacking shape
# (exactly two items) is relied on here.
_splits = load_shuttle_data(csv_path)
X_train, X_test = _splits


In [None]:

# Experiment configuration read by the later cells.
k_folds, n_runs = 5, 7        # folds per cross-validation pass, repeated passes
epochs, batch_size = 50, 128  # training epochs per fold, mini-batch size


In [None]:

# Repeated k-fold cross-validation.
#
# For each of `n_runs` runs, X_train is split into `k_folds` folds; a fresh
# Medgan is trained on the training part of each fold and its three losses
# are then measured on the held-out part.  The per-run mean of each loss is
# appended to the all_*_losses lists (one entry per completed run).
all_ae_losses = []
all_d_losses = []
all_g_losses = []

for run in range(n_runs):
    print(f"Starting Run {run+1}/{n_runs}")

    # The run index seeds the splitter, so every run uses a different but
    # repeatable partition of X_train.
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=run)

    run_ae_losses = []
    run_d_losses = []
    run_g_losses = []

    for fold, (train_index, val_index) in enumerate(kf.split(X_train)):
        print(f"  Fold {fold+1}/{k_folds}")

        # Integer-array indexing: assumes X_train is a NumPy array rather
        # than a DataFrame -- TODO confirm against the data loader.
        X_tr, X_val = X_train[train_index], X_train[val_index]

        # A new model per fold so that no weights leak between folds.
        model = Medgan(input_dim=X_train.shape[1])

        # One optimizer per sub-network.  Recent Keras optimizers bind their
        # state to the first variable list they are applied to and reject
        # any other list, so a single shared Adam instance cannot update the
        # autoencoder, the discriminator and the generator in turn.
        ae_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        n_train = X_tr.shape[0]

        for _epoch in range(epochs):
            # Visit the training rows in a new random order every epoch and
            # feed them in chunks of `batch_size` (the final chunk may be
            # smaller) instead of as one full-batch step.
            order = np.random.permutation(n_train)

            for start in range(0, n_train, batch_size):
                X_batch = X_tr[order[start:start + batch_size]]
                random_data = np.random.normal(
                    size=(X_batch.shape[0], model.random_dim)
                )

                # Persistent because three separate gradients are taken from
                # the same recorded forward pass.
                with tf.GradientTape(persistent=True) as tape:
                    ae_loss, d_loss, g_loss = model.train_step(X_batch, random_data)

                # Read the variable lists after the forward pass, as the
                # original code did, in case the sub-networks create their
                # weights lazily on first use.
                ae_vars = (
                    model.encoder.trainable_variables
                    + model.decoder.trainable_variables
                )
                d_vars = model.discriminator.trainable_variables
                g_vars = model.generator.trainable_variables

                # Take every gradient first, then drop the tape so the
                # resources held by a persistent tape are released promptly.
                ae_grads = tape.gradient(ae_loss, ae_vars)
                d_grads = tape.gradient(d_loss, d_vars)
                g_grads = tape.gradient(g_loss, g_vars)
                del tape

                ae_optimizer.apply_gradients(zip(ae_grads, ae_vars))
                d_optimizer.apply_gradients(zip(d_grads, d_vars))
                g_optimizer.apply_gradients(zip(g_grads, g_vars))

        # Evaluate on the held-out fold.  No gradients are applied here, so
        # this cell only reads the losses; it assumes train_step itself does
        # not modify the weights -- TODO confirm in medgan_model.
        random_data_val = np.random.normal(size=(X_val.shape[0], model.random_dim))
        ae_loss_val, d_loss_val, g_loss_val = model.train_step(X_val, random_data_val)

        run_ae_losses.append(ae_loss_val.numpy())
        run_d_losses.append(d_loss_val.numpy())
        run_g_losses.append(g_loss_val.numpy())

    # Mean over this run's folds.
    all_ae_losses.append(np.mean(run_ae_losses))
    all_d_losses.append(np.mean(run_d_losses))
    all_g_losses.append(np.mean(run_g_losses))


In [None]:

# Final figures: the mean of the per-run means stored in all_*_losses.
final_ae_loss = np.mean(all_ae_losses)
final_d_loss = np.mean(all_d_losses)
final_g_loss = np.mean(all_g_losses)

# The banner reports the number of runs actually aggregated and the
# configured fold count, so it stays correct when the parameters change or
# the training loop was stopped early.
print(f"\nFinal Results After {len(all_ae_losses)} Runs with {k_folds}-Fold CV:")
print(f"AE Loss: {final_ae_loss:.4f}")
print(f"Discriminator Loss: {final_d_loss:.4f}")
print(f"Generator Loss: {final_g_loss:.4f}")
