In [24]:
import pandas as pd
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import Metadata
from sdv.metadata import SingleTableMetadata
from sdv.evaluation.single_table import get_column_plot
from sdv.evaluation.single_table import evaluate_quality
import os

In [25]:
# Read the real Ben 10 dataset; every later cell trains on / compares against it.
real_data_path = '../datasets/ben10_master.csv'
real_data = pd.read_csv(real_data_path)

In [26]:
# Metadata: reuse metadata.json if it already exists on disk; otherwise detect
# the table description from `real_data` under the name 'ben10_table' and
# write it to metadata.json so subsequent runs take the load branch.
metadata_path = 'metadata.json'

if not os.path.exists(metadata_path):
    metadata = Metadata()
    metadata.detect_table_from_dataframe(table_name='ben10_table', data=real_data)
    metadata.save_to_json(metadata_path)
    print("Created and saved new metadata to metadata.json")
else:
    metadata = Metadata.load_from_json(metadata_path)
    print("Loaded metadata from metadata.json")

Loaded metadata from metadata.json


In [27]:
# Build a CTGAN synthesizer from the metadata and train it on the real data.
# NOTE(review): 8000 epochs took ~15 minutes in the recorded run; lower
# CTGAN_EPOCHS for quick iteration.
CTGAN_EPOCHS = 8000

synthesizer = CTGANSynthesizer(
    metadata=metadata,
    enforce_min_max_values=True,
    enforce_rounding=False,
    epochs=CTGAN_EPOCHS,
    verbose=True,
    cuda=True,
)
synthesizer.fit(data=real_data)

Gen. (-0.61) | Discrim. (-0.38): 100%|██████████| 8000/8000 [15:12<00:00,  8.77it/s]


In [28]:
# Draw synthetic rows from the trained synthesizer.
NUM_SYNTHETIC_ROWS = 10_000
synthetic_data = synthesizer.sample(num_rows=NUM_SYNTHETIC_ROWS)

In [29]:
# Save the synthetic data to a CSV file. The path lives in one variable so the
# confirmation message always names the file that was actually written.
synthetic_data_path = '../datasets/ben10_CTGAN_synthetic.csv'
synthetic_data.to_csv(synthetic_data_path, index=False)
print(f"Synthetic dataset saved to {synthetic_data_path}")

Synthetic dataset saved to ../datasets/ben10_synthetic.csv


In [30]:
# Evaluate the quality of the synthetic data against the real data. The report
# covers two properties: 'Column Shapes' and 'Column Pair Trends'.
quality_report = evaluate_quality(
    real_data,
    synthetic_data,
    metadata,
)

# Show each property's per-column details. The label is printed on its own line
# and the frame goes through display() (provided by the IPython kernel) so the
# table renders with aligned headers, instead of being glued onto the label
# by print().
for property_name in ('Column Shapes', 'Column Pair Trends'):
    print(f"{property_name} Details:")
    display(quality_report.get_details(property_name=property_name))

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 13/13 [00:00<00:00, 590.93it/s]|
Column Shapes Score: 95.45%

(2/2) Evaluating Column Pair Trends: |██████████| 78/78 [00:00<00:00, 224.77it/s]|
Column Pair Trends Score: 81.21%

Overall Score (Average): 88.33%

Column Shapes Details:                    Column        Metric     Score
0              alien_name  TVComplement  0.926000
1              enemy_name  TVComplement  0.937833
2          alien1_species  TVComplement  0.910867
3      alien1_home_planet  TVComplement  0.947333
4   alien1_strength_level  TVComplement  0.987267
5      alien1_speed_level  TVComplement  0.968400
6     alien1_intelligence  TVComplement  0.990267
7          alien2_species  TVComplement  0.939700
8      alien2_home_planet  TVComplement  0.959733
9   alien2_strength_level  TVComplement  0.974833
10     alien2_speed_level  TVComplement  0.961967
11    alien2_intelligence  TVComplement  0.985100
12                 winner  TVComplement  0.9188

In [31]:
# Plot the loss values CTGAN recorded while fitting.
loss_fig = synthesizer.get_loss_values_plot()
loss_fig.show()

In [32]:
# Compare the real vs. synthetic distribution of the 'winner' column.
fig = get_column_plot(
    real_data=real_data,
    synthetic_data=synthetic_data,
    metadata=metadata,
    column_name='winner',
)
fig.show()