In [1]:
# stdlib
from typing import Any, List
  
# third party
import numpy as np
import pandas as pd
from ctgan import CTGAN,EnhancedCTGAN
from torch.nn import TransformerEncoder
# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
from synthcity.plugins.core.distribution import Distribution
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.distribution import (
    Distribution,
    IntegerDistribution,
)


class sdv_ctgan_plugin(Plugin):
    """Synthcity plugin that wraps a CTGAN synthesizer.

    The plugin builds a ``CTGAN`` model in the constructor, picks the discrete
    columns by cardinality at fit time, and samples from the fitted model on
    generation. It is exposed to synthcity under the name ``"trans_ctgan"``.

    NOTE(review): ``dropout``, ``num_first``, ``batch_first``, ``num_layers``
    and ``num_heads`` are forwarded to ``CTGAN`` as keyword arguments; this
    presumably requires a locally modified ``ctgan`` package - confirm that the
    installed build accepts them.

    Args:
        embedding_n_units: Passed to ``CTGAN`` as ``embedding_dim``.
        epochs: Passed to ``CTGAN`` as ``epochs``.
        batch_size: Passed to ``CTGAN`` as ``batch_size``.
        cat_limit: A column is treated as discrete when it has fewer than this
            many distinct values.
        dropout: Forwarded unchanged to ``CTGAN``.
        num_first: Forwarded unchanged to ``CTGAN``.
        batch_first: Forwarded unchanged to ``CTGAN``.
        num_layers: Forwarded unchanged to ``CTGAN``.
        num_heads: Forwarded unchanged to ``CTGAN``.
        **kwargs: Forwarded to ``Plugin.__init__``.
    """

    def __init__(
        self,
        embedding_n_units: int = 128,
        epochs: int = 150,
        batch_size: int = 100,
        cat_limit: int = 25,
        dropout: float = 0.1,
        num_first: bool = True,
        batch_first: bool = True,
        num_layers: int = 2,
        num_heads: int = 4,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.cat_limit = cat_limit
        self.model = CTGAN(
            embedding_dim=embedding_n_units,
            batch_size=batch_size,
            epochs=epochs,
            verbose=False,
            dropout=dropout,
            num_first=num_first,
            batch_first=batch_first,
            num_layers=num_layers,
            num_heads=num_heads,
        )

    @staticmethod
    def name() -> str:
        """Return the name under which the plugin is registered."""
        return "trans_ctgan"

    @staticmethod
    def type() -> str:
        """Return the plugin category string."""
        return "debug"

    @staticmethod
    def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
        """Return the search space used for AutoML benchmarks.

        Every distribution name matches a constructor parameter, so a sampled
        configuration can be passed straight to ``__init__``. (The training
        length was previously advertised as ``n_iter``, which the constructor
        does not accept; it is ``epochs``.)
        """
        return [
            IntegerDistribution(name="embedding_n_units", low=100, high=500, step=50),
            IntegerDistribution(name="batch_size", low=100, high=300, step=50),
            IntegerDistribution(name="epochs", low=100, high=500, step=50),
        ]

    def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "sdv_ctgan_plugin":
        """Select the discrete columns by cardinality and train the CTGAN.

        A column counts as discrete when its number of distinct values
        (as returned by ``unique()``) is below ``self.cat_limit``.

        Returns:
            ``self``, after the wrapped model has been fitted.
        """
        discrete_columns = [
            col for col in X.columns if len(X[col].unique()) < self.cat_limit
        ]
        self.model.fit(X.dataframe(), discrete_columns=discrete_columns)
        return self

    def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFrame:
        """Draw ``count`` rows by delegating to ``_safe_generate`` with the model's sampler."""
        return self._safe_generate(self.model.sample, count, syn_schema)

  from .autonotebook import tqdm as notebook_tqdm


    The default C++ compiler could not be found on your system.
    You need to either define the CXX environment variable or a symlink to the g++ command.
    For example if g++-8 is the command you can do
      import os
      os.environ['CXX'] = 'g++-8'
    


In [2]:
# synthcity absolute
from synthcity.plugins import Plugins

# Instantiate synthcity's plugin registry; later cells use it to register
# and retrieve generators by name.
generators = Plugins()

# Display the names of the generators currently known to the registry.
generators.list()

[2025-02-22T21:38:17.729810+0800][20896][CRITICAL] module disabled: e:\qycache\anaconda\envs\LLM\lib\site-packages\synthcity\plugins\generic\plugin_goggle.py


['privbayes',
 'image_cgan',
 'dpgan',
 'arf',
 'ctgan',
 'pategan',
 'aim',
 'nflow',
 'bayesian_network',
 'fflows',
 'adsgan',
 'great',
 'survival_ctgan',
 'marginal_distributions',
 'survae',
 'radialgan',
 'survival_nflow',
 'timevae',
 'tvae',
 'dummy_sampler',
 'image_adsgan',
 'rtvae',
 'decaf',
 'ddpm',
 'uniform_sampler',
 'survival_gan',
 'timegan']

In [3]:
# Register the custom plugin class under the name "trans_ctgan" (the same
# string returned by sdv_ctgan_plugin.name()).
generators.add("trans_ctgan", sdv_ctgan_plugin)

# List the registry again; "trans_ctgan" should now be among the names.
generators.list()

['privbayes',
 'image_cgan',
 'dpgan',
 'arf',
 'ctgan',
 'pategan',
 'aim',
 'nflow',
 'bayesian_network',
 'fflows',
 'adsgan',
 'great',
 'survival_ctgan',
 'marginal_distributions',
 'survae',
 'radialgan',
 'survival_nflow',
 'timevae',
 'tvae',
 'dummy_sampler',
 'image_adsgan',
 'rtvae',
 'decaf',
 'ddpm',
 'uniform_sampler',
 'survival_gan',
 'trans_ctgan',
 'timegan']

In [None]:
def smart_normalization(data, discrete_cols, threshold=0.1):
    """Standardize the non-discrete columns of ``data`` only when needed.

    Every column not listed in ``discrete_cols`` is inspected in order. As soon
    as one of them has ``|mean| > threshold`` or a standard deviation outside
    the open interval ``(1 - threshold, 1 + threshold)``, all non-discrete
    columns are rescaled together with ``sklearn``'s ``StandardScaler``.

    Args:
        data: Input ``pandas.DataFrame``; it is never modified in place.
        discrete_cols: Column names to leave untouched.
        threshold: Tolerance around mean 0 / std 1 before scaling is triggered.

    Returns:
        A ``(frame, scaled)`` tuple: a copy of ``data`` (with the non-discrete
        columns standardized when ``scaled`` is ``True``) and a flag telling
        whether scaling was applied.
    """
    numeric_names = [name for name in data.columns if name not in discrete_cols]

    def _off_scale(series):
        # Mirrors the per-column check: mean far from 0, or std far from 1.
        mu = series.mean()
        sigma = series.std()
        return abs(mu) > threshold or not (1 - threshold < sigma < 1 + threshold)

    # Nothing to scale, or every candidate column is already roughly standard.
    if not numeric_names:
        return data.copy(), False
    if not any(_off_scale(data[name]) for name in numeric_names):
        return data.copy(), False

    from sklearn.preprocessing import StandardScaler

    standardizer = StandardScaler()
    result = data.copy()
    result[numeric_names] = standardizer.fit_transform(data[numeric_names])
    return result, True
import pandas as pd
import numpy as np

def auto_detect_discrete_columns(data, unique_ratio_threshold=0.05, unique_count_threshold=20):
    """Guess which columns of ``data`` should be treated as discrete.

    Missing values are dropped per column before any decision is made. The
    rules are applied in this order:

    1. Columns with no non-missing values are skipped.
    2. Categorical, boolean, object and string columns are discrete.
    3. Non-numeric columns of any other dtype (e.g. datetimes) are skipped.
    4. Numeric columns with at most ``unique_count_threshold`` distinct values
       are discrete.
    5. Numeric columns whose distinct/total ratio is below
       ``unique_ratio_threshold`` are discrete if they hold integers, either by
       dtype or because every value is a whole number.

    Dtypes are inspected through ``pandas.api.types`` so that pandas extension
    dtypes (nullable ``Int64``, ``string``, ...) are classified instead of
    raising. The whole-number check uses a modulo, so a column containing
    ``inf`` is reported as non-discrete rather than failing on an integer cast.

    Args:
        data: Input ``pandas.DataFrame``.
        unique_ratio_threshold: Upper bound (exclusive) on distinct/total for
            rule 5.
        unique_count_threshold: Upper bound (inclusive) on the distinct count
            for rule 4.

    Returns:
        List of discrete column names, in the column order of ``data``.
    """
    dtype_checks = pd.api.types
    discrete_cols = []

    for col in data.columns:
        col_data = data[col].dropna()

        # Nothing observed: there is no evidence to classify the column.
        if len(col_data) == 0:
            continue

        dtype = col_data.dtype

        # Label-like dtypes are discrete regardless of cardinality.
        if (
            isinstance(dtype, pd.CategoricalDtype)
            or dtype_checks.is_bool_dtype(dtype)
            or dtype_checks.is_object_dtype(dtype)
            or dtype_checks.is_string_dtype(dtype)
        ):
            discrete_cols.append(col)
            continue

        if not dtype_checks.is_numeric_dtype(dtype):
            continue

        n_unique = col_data.nunique()
        if n_unique <= unique_count_threshold:
            discrete_cols.append(col)
            continue

        unique_ratio = n_unique / len(col_data)
        if unique_ratio >= unique_ratio_threshold:
            continue

        # Low relative cardinality only counts when the values are integers.
        # ``inf % 1`` is NaN, so non-finite values fail the whole-number test.
        if dtype_checks.is_integer_dtype(dtype) or bool((col_data % 1 == 0).all()):
            discrete_cols.append(col)

    return discrete_cols

# Script section: load one dataset, detect its discrete columns and, if
# needed, standardize the remaining columns. The results are bound to the
# module-level names real_data, discrete_cols, real_data_normalized and scaled.
if __name__ == "__main__":

    # Alternative datasets are kept commented out; exactly one real_path is active.
    #real_path = "../synthcity-main/tutorials/covertype_preprocessed.csv"
    real_path = "../CTAB-GAN-main/Real_Datasets/Credit.csv"
    #real_path = "../CTAB-GAN-main/Real_Datasets/Adult3.csv"
    #real_path = '../CTGAN-main/CTGAN-main/examples/csv/train_clean.csv'
    real_data = pd.read_csv(real_path)
    
    # Column names judged discrete by the heuristic defined above.
    discrete_cols = auto_detect_discrete_columns(real_data)

    # scaled is True only when the non-discrete columns were actually rescaled.
    real_data_normalized, scaled = smart_normalization(real_data, discrete_cols)

In [None]:
from sklearn.datasets import load_breast_cancer,load_diabetes
# Dataset used for training in the cells below. Note that this rebinds
# real_path, which the preprocessing cell also assigns (there to Credit.csv).
real_path = "../CTAB-GAN-main/Real_Datasets/Adult3.csv"
# Alternative datasets, kept for quick switching:
#real_path = '../CTGAN-main/CTGAN-main/examples/csv/train_clean.csv'
#real_path = "../CTAB-GAN-main/Real_Datasets/Credit.csv"
#real_path = "../synthcity-main/tutorials/covertype_preprocessed.csv"
#real_path = "../CTAB-GAN-main/Real_Datasets/train2.csv"
#real_path = "../CTAB-GAN-main/Real_Datasets/creditcard2.csv"
data = pd.read_csv(real_path)
# Wrap the raw (un-normalized) frame in a synthcity data loader.
loader = GenericDataLoader(data)
# Alternative: train on the standardized frame produced by smart_normalization.
#loader = GenericDataLoader(real_data_normalized)

In [None]:
# Train the new plugin
from torch.nn import TransformerEncoder
# Instantiate the registered "trans_ctgan" plugin, overriding its default
# number of epochs (150 in sdv_ctgan_plugin.__init__) with 250.
gen = generators.get("trans_ctgan", epochs=250)
# NOTE(review): the path below presumably points at the locally modified ctgan
# synthesizer this plugin relies on - confirm before running elsewhere.
#"E:\qycache\anaconda\envs\LLM\Lib\site-packages\ctgan\synthesizers\ctgan.py" 
# Fit on the loader built in the previous cell.
gen.fit(loader)

  0%|          | 0/250 [00:00<?, ?it/s]
Hype-parameter:[dropout:0.1,num_first:True,batch_first:True,num_heads:4,num_layers:2]
epoch 0
newtype2
epoch 1
newtype2
epoch 2
newtype2
epoch 3
newtype2
epoch 4
newtype2
epoch 5
newtype2
epoch 6
newtype2
epoch 7
newtype2
epoch 8
newtype2
epoch 9
newtype2
epoch 10
newtype2
epoch 11
newtype2
epoch 12
newtype2
epoch 13
newtype2
epoch 14
newtype2
epoch 15
newtype2
epoch 16
newtype2
epoch 17
newtype2
epoch 18
newtype2
epoch 19
newtype2
epoch 20
newtype2
epoch 21
newtype2
epoch 22
newtype2
epoch 23
newtype2
epoch 24
newtype2
epoch 25
newtype2
epoch 26
newtype2
epoch 27
newtype2
epoch 28
newtype2
epoch 29
newtype2
epoch 30
newtype2
epoch 31
newtype2
epoch 32
newtype2
epoch 33
newtype2
epoch 34
newtype2
epoch 35
newtype2
epoch 36
newtype2
epoch 37
newtype2
epoch 38
newtype2
epoch 39
newtype2
epoch 40
newtype2
epoch 41
newtype2
epoch 42
newtype2
epoch 43
newtype2
epoch 44
newtype2
epoch 45
newtype2
epoch 46
newtype2
epoch 47
newtype2
epoch 48
newtype2
ep

<__main__.sdv_ctgan_plugin at 0x150434e92e0>

In [12]:
# Generate some new data
# Sample 2000 synthetic rows from the fitted plugin and materialize them as a
# pandas DataFrame.
a = gen.generate(count=2000).dataframe()
# Persist the synthetic rows to CSV without the index column.
a.to_csv('./OriginalCTGAN-Adult_31.csv', index=False)