In [1]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, classification_report
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from copy import deepcopy

import tensorflow as tf
from medgan_model import Medgan
from data_loader import load_nursery_data



In [3]:
# --- Reproducibility ---
# A dedicated NumPy Generator drives batch order and GAN noise; TensorFlow's
# global seed is set as well so a Restart & Run All is repeatable
# (presumably this also fixes Medgan's weight initialisation -- confirm
# against medgan_model).
SEED = 42
rng = np.random.default_rng(SEED)
tf.random.set_seed(SEED)

# --- Load Nursery Dataset ---
x_train, x_test, y_train, y_test = load_nursery_data("datasets/nursery.csv")
input_dim = x_train.shape[1]

# --- Fit LabelEncoder on all labels to avoid unseen class error ---
# Only the label vocabulary is taken from train + test; no features are
# involved. LabelEncoder is already imported in the imports cell.
label_encoder = LabelEncoder()
label_encoder.fit(pd.concat([y_train, y_test], axis=0))  # Combine before fitting

y_train_encoded = label_encoder.transform(y_train)
y_test_encoded = label_encoder.transform(y_test)


# --- Initialize and Train MedGAN ---
medgan = Medgan(input_dim=input_dim, ae_loss_type='bce')
# NOTE(review): this optimizer is never passed to `medgan` in this notebook;
# kept because train_step's internals are not visible here -- confirm whether
# it is needed at all.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)
batch_size = 128
n_epochs = 300

for epoch in range(n_epochs):
    # Draw a shuffled *copy* each epoch instead of shuffling x_train in place:
    # an in-place shuffle breaks the row alignment between x_train and
    # y_train / y_train_encoded and makes the cell non-idempotent.
    x_epoch = rng.permutation(x_train)
    for i in range(0, len(x_epoch), batch_size):
        batch = x_epoch[i:i + batch_size]
        # One noise vector per sample in the (possibly short, final) batch.
        noise = rng.normal(size=(batch.shape[0], medgan.random_dim))
        medgan.train_step(batch, noise)
    print(f"Epoch {epoch+1}/{n_epochs} completed")



Epoch 1/300 completed
Epoch 2/300 completed
Epoch 3/300 completed
Epoch 4/300 completed
Epoch 5/300 completed
Epoch 6/300 completed
Epoch 7/300 completed
Epoch 8/300 completed
Epoch 9/300 completed
Epoch 10/300 completed
Epoch 11/300 completed
Epoch 12/300 completed
Epoch 13/300 completed
Epoch 14/300 completed
Epoch 15/300 completed
Epoch 16/300 completed
Epoch 17/300 completed
Epoch 18/300 completed
Epoch 19/300 completed
Epoch 20/300 completed
Epoch 21/300 completed
Epoch 22/300 completed
Epoch 23/300 completed
Epoch 24/300 completed
Epoch 25/300 completed
Epoch 26/300 completed
Epoch 27/300 completed
Epoch 28/300 completed
Epoch 29/300 completed
Epoch 30/300 completed
Epoch 31/300 completed
Epoch 32/300 completed
Epoch 33/300 completed
Epoch 34/300 completed
Epoch 35/300 completed
Epoch 36/300 completed
Epoch 37/300 completed
Epoch 38/300 completed
Epoch 39/300 completed
Epoch 40/300 completed
Epoch 41/300 completed
Epoch 42/300 completed
Epoch 43/300 completed
Epoch 44/300 complet

In [4]:

# --- Generate Synthetic Features ---
synthetic_data = medgan.generate_data(num_samples=1000)

# --- Generate Synthetic Labels (randomly from real y_train distribution) ---
# A seeded Generator makes the label draw repeatable; without it the
# random_state on the split below does not make this cell reproducible.
# NOTE(review): labels are drawn from the empirical distribution of
# y_train_encoded independently of the generated features, so each synthetic
# row carries no feature -> label relationship. Any classifier trained on
# (synthetic_data, synthetic_labels) can only learn the class prior. The GAN
# would need to generate labels jointly with the features (or the rows be
# labelled by a model fit on correctly aligned real data) for the
# train-on-synthetic / test-on-real comparison to be meaningful.
label_rng = np.random.default_rng(42)
synthetic_labels = label_rng.choice(y_train_encoded, size=synthetic_data.shape[0])

# --- Split Synthetic Data ---
# 80/20 split, stratified on the sampled labels so both parts share the same
# class proportions.
synthetic_x_train, synthetic_x_test, synthetic_y_train, synthetic_y_test = train_test_split(
    synthetic_data, synthetic_labels, test_size=0.2, stratify=synthetic_labels, random_state=42
)


In [5]:
# --- Classifier line-up ---
# Display name -> unfitted estimator. The random forest is used bare; the
# other three estimators sit behind a StandardScaler inside a Pipeline, so
# scaling is fitted together with the model on whatever data `fit` receives.
models = dict([
    ("Random Forest", RandomForestClassifier(random_state=42)),
    ("MLP Classifier", Pipeline(steps=[
        ('scaler', StandardScaler()),
        ('mlp', MLPClassifier(max_iter=300, random_state=42)),
    ])),
    ("KNN Classifier", Pipeline(steps=[
        ('scaler', StandardScaler()),
        ('knn', KNeighborsClassifier(n_neighbors=5)),
    ])),
    ("Logistic Regression", Pipeline(steps=[
        ('scaler', StandardScaler()),
        ('lr', LogisticRegression(max_iter=300, random_state=42, class_weight='balanced')),
    ])),
])

# --- Result store for the synthetic-trained / real-tested evaluation ---
# Filled per model name by the evaluation loop.
results_synthetic = dict()

In [7]:
# Encoded ids of every class known to the encoder (LabelEncoder maps classes_
# to 0..n_classes-1 in order). Passing them explicitly keeps the report and the
# macro averages defined over the full label set: without `labels=`,
# classification_report raises when a class is missing from both y_true and
# y_pred, because the inferred label count no longer matches target_names.
all_class_ids = np.arange(len(label_encoder.classes_))

for name, model in models.items():
    # Fit a fresh copy so the templates in `models` stay unfitted and the
    # cell can be re-run safely.
    clf = deepcopy(model)
    clf.fit(synthetic_x_train, synthetic_y_train)
    y_pred = clf.predict(x_test)

    # Macro averages weight every class equally; zero_division=0 scores a
    # class that is never predicted as 0 instead of emitting a warning.
    acc = accuracy_score(y_test_encoded, y_pred)
    prec = precision_score(y_test_encoded, y_pred, labels=all_class_ids,
                           average='macro', zero_division=0)
    rec = recall_score(y_test_encoded, y_pred, labels=all_class_ids,
                       average='macro', zero_division=0)
    f1 = f1_score(y_test_encoded, y_pred, labels=all_class_ids,
                  average='macro', zero_division=0)

    results_synthetic[name] = {
        "Accuracy": acc,
        "Precision": prec,
        "Recall": rec,
        "F1 Score": f1
    }

    print(f"\n=== {name} (Trained on Synthetic, Tested on Real) ===")
    print(classification_report(
        y_test_encoded,
        y_pred,
        labels=all_class_ids,
        target_names=label_encoder.classes_,
        zero_division=0,
    ))

# --- Summary Table ---
# One row per classifier, one column per metric.
results_synthetic_df = pd.DataFrame(results_synthetic).T
print("\nSummary of Synthetic-Trained Classifiers on Real Test Data:")
print(results_synthetic_df)


=== Random Forest (Trained on Synthetic, Tested on Real) ===
              precision    recall  f1-score   support

   not_recom       0.42      0.08      0.13       870
    priority       0.37      0.09      0.14       873
   recommend       0.00      0.00      0.00         2
  spec_prior       0.29      0.83      0.43       785
  very_recom       0.00      0.00      0.00        62

    accuracy                           0.31      2592
   macro avg       0.22      0.20      0.14      2592
weighted avg       0.35      0.31      0.22      2592






=== MLP Classifier (Trained on Synthetic, Tested on Real) ===
              precision    recall  f1-score   support

   not_recom       0.38      0.41      0.39       870
    priority       0.37      0.30      0.33       873
   recommend       0.00      0.00      0.00         2
  spec_prior       0.28      0.33      0.30       785
  very_recom       0.00      0.00      0.00        62

    accuracy                           0.34      2592
   macro avg       0.21      0.21      0.21      2592
weighted avg       0.34      0.34      0.34      2592


=== KNN Classifier (Trained on Synthetic, Tested on Real) ===
              precision    recall  f1-score   support

   not_recom       0.38      0.54      0.45       870
    priority       0.38      0.29      0.33       873
   recommend       0.00      0.00      0.00         2
  spec_prior       0.35      0.30      0.32       785
  very_recom       0.25      0.02      0.03        62

    accuracy                           0.37      2592
   ma