# Train MedGAN with 5-Fold Cross Validation and 7 Runs (Final Version)

In [None]:

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import MinMaxScaler
from data_loader import load_shuttle_data
from medgan_model import Medgan


In [None]:

# --- Read the raw Shuttle CSV and split off its last column ---
# The file is read without a header row, so columns are addressed by position.
df = pd.read_csv("datasets/shuttle.csv", header=None)  # <-- Replace with your path

# X_real: every column except the last; y_real: the last column
# (presumably the class label -- confirm against the dataset description).
X_real, y_real = df.iloc[:, :-1], df.iloc[:, -1]

print("Real data loaded successfully.")


In [None]:

# --- Fetch the unlabeled train/test split used for GAN training ---
# Splitting, normalisation and shuffling are delegated to the project helper;
# the keyword arguments below are passed through unchanged.
X_train, X_test = load_shuttle_data(
    "datasets/shuttle.csv",
    test_size=0.2,
    normalize=True,
    n_shuffle=10,
)

# Width of one training row; used later as the model's input dimension.
input_dim = X_train.shape[1]

print(f"Training data shape: {X_train.shape}")


In [None]:

# --- Experiment configuration ---
k_folds, n_runs = 5, 7            # folds per run / number of repeated runs
n_epochs, batch_size = 100, 64    # epochs per fold / mini-batch size
learning_rate = 2e-4              # step size handed to the Adam optimizer

print(
    "Parameters set: "
    f"Folds={k_folds}, Runs={n_runs}, Epochs={n_epochs}, Batch Size={batch_size}"
)


In [None]:

# --- Training with k-fold CV repeated over several runs ---
#
# For every run the training set is split into `k_folds` folds (the split is
# seeded with the run index). A fresh Medgan is trained on each fold's training
# part for `n_epochs`, then its three losses are computed once on the held-out
# part. Fold losses are averaged per run and appended to the `all_run_*` lists,
# which end up holding one value per run.
#
# NOTE(review): only the KFold split is seeded. The per-epoch shuffling, the
# noise vectors and (presumably) the model initialisation draw from unseeded
# global RNGs, so exact numbers will differ between executions.

all_run_ae_losses = []
all_run_d_losses = []
all_run_g_losses = []

for run in range(n_runs):
    print(f"\n===== Starting Run {run+1}/{n_runs} =====")

    # Different but reproducible fold assignment for each run.
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=run)

    fold_ae_losses = []
    fold_d_losses = []
    fold_g_losses = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
        print(f"  --- Fold {fold+1}/{k_folds} ---")

        # Split data
        X_tr, X_val = X_train[train_idx], X_train[val_idx]

        # Initialize fresh model
        medgan = Medgan(input_dim=input_dim, ae_loss_type='bce')

        # One optimizer per sub-network. A single shared instance is bound to
        # the first variable list it sees in recent Keras versions and raises
        # when later applied to a different list; in older versions it would
        # still share one step counter across three unrelated objectives.
        ae_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

        # Loss collectors for epochs
        # NOTE(review): these lists are created but never appended to below.
        ae_losses_epoch = []
        d_losses_epoch = []
        g_losses_epoch = []

        # Epochs
        for epoch in range(n_epochs):
            # Shuffle training data
            idx = np.random.permutation(len(X_tr))
            X_tr = X_tr[idx]

            # Mini-batch training
            for i in range(0, len(X_tr), batch_size):
                batch = X_tr[i:i+batch_size]
                noise = np.random.normal(size=(batch.shape[0], medgan.random_dim))

                # Persistent tape: three separate gradient() calls are made on it.
                with tf.GradientTape(persistent=True) as tape:
                    ae_loss, d_loss, g_loss = medgan.train_step(batch, noise)

                # Collected after the forward pass, in case the sub-models
                # only create their variables on first call.
                ae_vars = medgan.encoder.trainable_variables + medgan.decoder.trainable_variables
                d_vars = medgan.discriminator.trainable_variables
                g_vars = medgan.generator.trainable_variables

                ae_optimizer.apply_gradients(zip(tape.gradient(ae_loss, ae_vars), ae_vars))
                d_optimizer.apply_gradients(zip(tape.gradient(d_loss, d_vars), d_vars))
                g_optimizer.apply_gradients(zip(tape.gradient(g_loss, g_vars), g_vars))

                # Release the resources held by the persistent tape right away.
                del tape

            # Print epoch progress
            print(f"Epoch {epoch+1}/{n_epochs} completed", end="\r")

        # After all epochs -> validate on fold
        # NOTE(review): validation reuses `train_step` only to obtain the three
        # losses; no gradients are applied here. Confirm that `train_step`
        # itself has no weight-updating side effects.
        random_val_data = np.random.normal(size=(X_val.shape[0], medgan.random_dim))
        ae_loss_val, d_loss_val, g_loss_val = medgan.train_step(X_val, random_val_data)

        fold_ae_losses.append(ae_loss_val.numpy())
        fold_d_losses.append(d_loss_val.numpy())
        fold_g_losses.append(g_loss_val.numpy())

    # Average fold losses for this run
    all_run_ae_losses.append(np.mean(fold_ae_losses))
    all_run_d_losses.append(np.mean(fold_d_losses))
    all_run_g_losses.append(np.mean(fold_g_losses))


In [None]:

# --- Final Averaging and Results ---
# Collapse the per-run mean losses into a single figure per network and
# print a short report.

final_ae_loss = np.mean(all_run_ae_losses)
final_d_loss = np.mean(all_run_d_losses)
final_g_loss = np.mean(all_run_g_losses)

# The banner takes the run and fold counts from the configuration so it stays
# truthful when `n_runs` or `k_folds` are changed.
print(f"""
=======================================
Final Results after {n_runs} Runs and {k_folds}-Fold CV
=======================================""")
print(f"Autoencoder Loss (AE): {final_ae_loss:.4f}")
print(f"Discriminator Loss (D): {final_d_loss:.4f}")
print(f"Generator Loss (G): {final_g_loss:.4f}")
