In [16]:
import pandas as pd
import glob
import os
import torch
import numpy as np
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.utils import shuffle
import sdv
from sdv.single_table import CTGANSynthesizer
from sdv.metadata import SingleTableMetadata
import plotly.subplots as sp
import plotly.graph_objects as go

In [17]:
data = pd.read_csv("data/train_X.csv")
data = data.sample(frac=0.1223, random_state=22)

In [18]:
display(dir(sdv.single_table))

['CTGANSynthesizer',
 'CopulaGANSynthesizer',
 'GaussianCopulaSynthesizer',
 'TVAESynthesizer',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'base',
 'copulagan',
 'copulas',
 'ctgan',
 'utils']

In [19]:
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
python_dict = metadata.to_dict()

In [20]:
metadata.validate_data(data=data)
import torch
print(torch.cuda.is_available()) 
print(torch.cuda.get_device_name(0)) 


True
NVIDIA GeForce RTX 4090


In [22]:
synthesizer = CTGANSynthesizer(
    metadata,
    enforce_rounding=False,
    epochs=100,
    batch_size=500, 
    pac=5,     
    verbose=True
)


synthesizer.fit(data)

synthetic_data = synthesizer.sample(num_rows=1200000)


The 'SingleTableMetadata' is deprecated. Please use the new 'Metadata' class for synthesizers.

Gen. (-0.09) | Discrim. (-0.01): 100%|██████████| 100/100 [1:08:57<00:00, 41.38s/it]


In [None]:
synthetic_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,0.004033,0.771967,0.513299,0.264703,0.157676,0.487437,0.000220,1.00000,0.000165,0.000018,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0
1,0.003506,0.777136,0.379838,0.229758,0.303190,0.887261,0.000000,1.00000,0.000183,0.000013,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
2,0.003471,0.776111,0.032084,0.830562,0.631533,0.906900,0.000190,1.00000,0.000189,0.000139,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0
3,0.004033,0.776635,0.506780,0.753028,0.029415,0.443860,0.000146,1.00000,0.000189,0.000000,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
4,0.391260,0.405719,0.434949,0.221919,0.804530,0.815630,0.998842,0.00131,0.000044,0.000000,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1199995,0.004335,0.772278,0.421279,0.392138,0.688748,0.716551,0.000000,1.00000,0.000071,0.000011,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0
1199996,0.004020,0.776841,0.817479,0.073521,0.058460,0.165057,0.000121,1.00000,0.000136,0.000000,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0
1199997,0.003361,0.776280,0.039561,0.810913,0.742475,0.579413,0.000000,1.00000,0.000189,0.000000,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0
1199998,0.393544,0.403136,0.826173,0.454794,0.215914,0.451817,1.000000,0.00000,0.000037,0.000000,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0


In [None]:
from sdv.evaluation.single_table import get_column_plot

figs = []
for i,j in enumerate(list(data.columns)[:-1]) : 

    fig = get_column_plot(
        real_data=data,
        synthetic_data=synthetic_data,
        column_name=j,
        metadata=metadata
    )
    figs.append(fig)
    
fig = sp.make_subplots(rows=2, cols=3, subplot_titles=list(data.columns)[:-1])

rows =2
columns = 3
for i in range(rows) :
    for j in range(columns) :
        count = i+j
        
        fig.add_trace(figs[count]['data'][0], row=i+1, col=j+1)
        fig.add_trace(figs[count]['data'][1], row=i+1, col=j+1)
    

fig.update_layout(height=600, width=900, title_text='2x3 Subplots')

fig.show()

In [None]:
from sdv.evaluation.single_table import evaluate_quality

quality_report = evaluate_quality(
    data,
    synthetic_data,
    metadata)

Generating report ...

(1/2) Evaluating Column Shapes: |██████████| 18/18 [00:02<00:00,  6.81it/s]|
Column Shapes Score: 73.39%

(2/2) Evaluating Column Pair Trends: |██████████| 153/153 [00:14<00:00, 10.45it/s]|
Column Pair Trends Score: 76.83%

Overall Score (Average): 75.11%

