In [1]:
import pandas as pd
import numpy as np
from ctgan_adapter import CtganAdapter
from ctgan_benchmark import evaluate_ctgan, print_evaluation_results
from ctgan_utils import preprocess_data, get_tstr_results
import json

In [2]:
# Load configuration
print("# Load configuration")
with open("config.json", "r") as f:
    config = json.load(f)

# Load configuration


In [3]:
# 1. Load and prepare the dataset
print("# 1. Load and prepare the dataset")
data_raw = pd.read_csv("nursery.csv")
print(f"Columns in dataset: {data_raw.columns.tolist()}")
print(f"Dataset shape: {data_raw.shape}")
print(data_raw.head())

# 1. Load and prepare the dataset
Columns in dataset: [' parents', 'has_nurs', 'form', 'children', 'housing', 'finance', 'social', 'health', 'Target']
Dataset shape: (12960, 9)
   parents has_nurs      form children     housing     finance         social  \
0    usual   proper  complete        1  convenient  convenient        nonprob   
1    usual   proper  complete        1  convenient  convenient        nonprob   
2    usual   proper  complete        1  convenient  convenient        nonprob   
3    usual   proper  complete        1  convenient  convenient  slightly_prob   
4    usual   proper  complete        1  convenient  convenient  slightly_prob   

        health     Target  
0  recommended  recommend  
1     priority   priority  
2    not_recom  not_recom  
3  recommended  recommend  
4     priority   priority  


In [4]:
# 2. Preprocess data and detect categorical columns
print("\n# 2. Preprocess data and detect categorical columns")
data, categorical_columns = preprocess_data(data_raw)
print(f"Detected categorical columns: {categorical_columns}")

2025-03-31 10:57:17,179 - INFO - Converted  parents to category type (has 3 unique values)
2025-03-31 10:57:17,184 - INFO - Converted has_nurs to category type (has 5 unique values)
2025-03-31 10:57:17,190 - INFO - Converted form to category type (has 4 unique values)
2025-03-31 10:57:17,195 - INFO - Converted children to category type (has 4 unique values)
2025-03-31 10:57:17,200 - INFO - Converted housing to category type (has 3 unique values)
2025-03-31 10:57:17,206 - INFO - Converted finance to category type (has 2 unique values)
2025-03-31 10:57:17,214 - INFO - Converted social to category type (has 3 unique values)
2025-03-31 10:57:17,221 - INFO - Converted health to category type (has 3 unique values)
2025-03-31 10:57:17,228 - INFO - Converted Target to category type (has 5 unique values)



# 2. Preprocess data and detect categorical columns
Detected categorical columns: [' parents', 'has_nurs', 'form', 'children', 'housing', 'finance', 'social', 'health', 'Target']


In [5]:
# 3. Define the target column for this dataset
print("\n# 3. Define the target column for this dataset")
target_column = "Target"
print(f"Target column: {target_column}")


# 3. Define the target column for this dataset
Target column: Target


In [7]:
# 4. Split the data into features and target
print("\n# 4. Split the data into features and target")
X = data.drop(columns=[target_column])
y = data[target_column]
print(f"Features shape: {X.shape}")
print(f"Target shape: {y.shape}")
print(f"Target distribution:\n{y.value_counts()}")


# 4. Split the data into features and target
Features shape: (12960, 8)
Target shape: (12960,)
Target distribution:
Target
not_recom     4320
priority      4266
spec_prior    4044
very_recom     328
recommend        2
Name: count, dtype: int64


In [8]:
# 5. Initialize and train CTGAN
print("\n# 5. Initialize and train CTGAN")
ctgan = CtganAdapter(**config["ctgan_params"])
print("Training CTGAN model...")
ctgan.fit(X, y)
print("Training completed")


# 5. Initialize and train CTGAN
Training CTGAN model...


Training Epochs:   0%|          | 1/300 [00:58<4:51:25, 58.48s/it]

Epoch 0, Loss D: 6.2771, Loss G: 1.0666


Training Epochs:  10%|█         | 31/300 [22:30<3:02:54, 40.80s/it]

Epoch 30, Loss D: 0.0373, Loss G: 0.3428


Training Epochs:  20%|██        | 61/300 [42:19<2:42:49, 40.88s/it]

Epoch 60, Loss D: -0.2794, Loss G: 0.8706


Training Epochs:  30%|███       | 91/300 [1:01:58<2:16:35, 39.21s/it]

Epoch 90, Loss D: -1.8146, Loss G: 1.3066


Training Epochs:  40%|████      | 121/300 [1:21:22<1:39:50, 33.46s/it]

Epoch 120, Loss D: -1.9714, Loss G: 1.2658


Training Epochs:  50%|█████     | 151/300 [1:33:58<1:01:45, 24.87s/it]

Epoch 150, Loss D: -1.8867, Loss G: 1.2879


Training Epochs:  60%|██████    | 181/300 [1:46:18<49:40, 25.04s/it]  

Epoch 180, Loss D: -1.8101, Loss G: 1.3353


Training Epochs:  70%|███████   | 211/300 [1:58:38<37:51, 25.52s/it]

Epoch 210, Loss D: -1.6483, Loss G: 1.2837


Training Epochs:  80%|████████  | 241/300 [2:11:23<27:52, 28.35s/it]

Epoch 240, Loss D: -1.4880, Loss G: 1.2236


Training Epochs:  90%|█████████ | 271/300 [2:26:08<14:38, 30.28s/it]

Epoch 270, Loss D: -1.4197, Loss G: 1.0215


Training Epochs: 100%|██████████| 300/300 [2:43:48<00:00, 32.76s/it]

Training completed





In [9]:
# 6. Generate synthetic data
print("\n# 6. Generate synthetic data")
n_samples = 1000  
print(f"Generating {n_samples} synthetic samples...")
synthetic_data = ctgan.generate(n_samples)
print(f"Generated {len(synthetic_data)} synthetic samples")
print("Synthetic data head:")
print(synthetic_data.head())


# 6. Generate synthetic data
Generating 1000 synthetic samples...
Generated 1000 synthetic samples
Synthetic data head:
       parents     has_nurs        form children     housing     finance  \
0   great_pret  less_proper  incomplete        1    critical      inconv   
1   great_pret     improper  incomplete        3  convenient  convenient   
2   great_pret       proper      foster        2    critical      inconv   
3  pretentious       proper    complete        1    critical      inconv   
4  pretentious    very_crit    complete     more  convenient  convenient   

          social       health      Target  
0        nonprob  recommended  spec_prior  
1    problematic  recommended    priority  
2        nonprob    not_recom   not_recom  
3  slightly_prob     priority    priority  
4        nonprob    not_recom   not_recom  


In [10]:
# 7. Evaluate quality using TSTR and other metrics
print("\n# 7. Evaluate quality using TSTR and other metrics")
print("Running evaluation...")
evaluation_results = evaluate_ctgan(data, synthetic_data, target_column=target_column)
print_evaluation_results(evaluation_results)


2025-03-31 13:41:49,259 - INFO - Encoded categorical target with mapping: {'not_recom': 0, 'priority': 1, 'recommend': 2, 'spec_prior': 3, 'very_recom': 4}



# 7. Evaluate quality using TSTR and other metrics
Running evaluation...


2025-03-31 13:41:56,110 - ERROR - Error training XGBoost classifier: Invalid classes inferred from unique values of `y`.  Expected: [0 1 2 3], got [0 1 3 4]
2025-03-31 13:41:56,117 - INFO - Encoded categorical targets with mapping: {'not_recom': 0, 'priority': 1, 'recommend': 2, 'spec_prior': 3, 'very_recom': 4}
2025-03-31 13:42:03,178 - ERROR - Error in TSTR with XGBoost: Invalid classes inferred from unique values of `y`.  Expected: [0 1 2], got [0 1 3]
2025-03-31 13:42:03,178 - INFO - CTGAN Evaluation Results:
2025-03-31 13:42:03,178 - INFO - 
Likelihood Fitness: Not applicable for fully categorical data
2025-03-31 13:42:03,178 - INFO - 
Statistical Similarity Metrics:
2025-03-31 13:42:03,186 - INFO -   - Jensen-Shannon Divergence Mean (Categorical): 0.0198
2025-03-31 13:42:03,186 - INFO - 
Machine Learning Efficacy Metrics on Real Data:
2025-03-31 13:42:03,186 - INFO -   LogisticRegression:
2025-03-31 13:42:03,189 - INFO -     - Accuracy: 0.9267
2025-03-31 13:42:03,189 - INFO -    

In [11]:
# 8. Extract and display TSTR results specifically
print("\n# 8. TSTR Performance Results")
tstr_results = get_tstr_results(evaluation_results)
if tstr_results is not None:
    print(tstr_results)


# 8. TSTR Performance Results
                    Accuracy        F1
LogisticRegression  0.899383  0.888663
RandomForest        0.867593  0.857337
MLP                 0.849306  0.839284
XGBoost                  NaN       NaN


In [12]:
# 9. Save the synthetic data
print("\n# 9. Save synthetic data")
output_path = "nursery_synthetic.csv"
synthetic_data.to_csv(output_path, index=False)
print(f"Synthetic data saved to {output_path}")

print("\nTest completed successfully!")


# 9. Save synthetic data
Synthetic data saved to nursery_synthetic.csv

Test completed successfully!
