# Synthesize a Table (CTGAN)

In this notebook, we'll use the SDV to create synthetic data for a single table and evaluate it. The SDV uses machine learning to learn patterns from real data and emulates them when creating synthetic data.

We'll use the **CTGAN** algorithm to do this. CTGAN uses generative adversarial networks (GANs) to create synthetic data with high fidelity.

# 0. Installation

Install the SDV library.


In [None]:
%pip install sdv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sdv
  Downloading sdv-1.1.0-py2.py3-none-any.whl (117 kB)
Collecting Faker<15,>=10 (from sdv)
  Downloading Faker-14.2.1-py3-none-any.whl (1.6 MB)
Collecting copulas<0.10,>=0.9.0 (from sdv)
  Downloading copulas-0.9.0-py2.py3-none-any.whl (54 kB)
Collecting ctgan<0.8,>=0.7.2 (from sdv)
  Downloading ctgan-0.7.3-py2.py3-none-any.whl (26 kB)
Collecting deepecho<0.5,>=0.4.1 (from sdv)
  Downloading deepecho-0.4.1-py2.py3-none-any.whl (28 kB)
Collecting rdt<2,>=1.4.2 (from sdv)
  Downloading rdt-1.4.2-py2.py3-none-an

# 1. Loading the demo data

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
data_link = '/content/drive/MyDrive/Dataset/CICIDS2017/train and test/train_data.csv'

In [None]:
import pandas as pd

data = pd.read_csv(data_link)

In [None]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1696725 entries, 0 to 1696724
Data columns (total 78 columns):
 #   Column                       Dtype  
---  ------                       -----  
 0   Flow Duration                int64  
 1   Total Fwd Packets            int64  
 2   Total Backward Packets       int64  
 3   Total Length of Fwd Packets  int64  
 4   Total Length of Bwd Packets  int64  
 5   Fwd Packet Length Max        int64  
 6   Fwd Packet Length Min        int64  
 7   Fwd Packet Length Mean       float64
 8   Fwd Packet Length Std        float64
 9   Bwd Packet Length Max        int64  
 10  Bwd Packet Length Min        int64  
 11  Bwd Packet Length Mean       float64
 12  Bwd Packet Length Std        float64
 13  Flow Bytes/s                 float64
 14  Flow Packets/s               float64
 15  Flow IAT Mean                float64
 16  Flow IAT Std                 float64
 17  Flow IAT Max                 int64  
 18  Flow IAT Min                 int64  
 19  

In [None]:
data.tail(5)

Unnamed: 0,Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Total Length of Bwd Packets,Fwd Packet Length Max,Fwd Packet Length Min,Fwd Packet Length Mean,Fwd Packet Length Std,Bwd Packet Length Max,...,min_seg_size_forward,Active Mean,Active Std,Active Max,Active Min,Idle Mean,Idle Std,Idle Max,Idle Min,Label
1696720,5184565,8,6,385,3974,210,0,48.125,77.93484,1992,...,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
1696721,116282843,18,15,838,3532,419,0,46.555556,135.496569,1766,...,32,105545.9,220564.8,770572,38717,10000000.0,3360.645,10000000,9997860,BENIGN
1696722,31376,1,1,54,149,54,54,54.0,0.0,149,...,32,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
1696723,3,2,0,0,0,0,0,0.0,0.0,0,...,32,0.0,0.0,0,0,0.0,0.0,0,0,DoS Hulk
1696724,102630653,14,3,2541,6,231,0,181.5,98.363337,6,...,32,3204070.0,4530471.0,6407596,543,19200000.0,18800000.0,51300000,5839202,DoS slowloris


## Extract rare classes from CICIDS2017
The rare classes are "Heartbleed", "Bot" and "Infiltration".

In [None]:
# Create a list of the labels to extract
labels = ["Heartbleed", "Bot", "Infiltration"]

# Extract the rows where the Label column is in the list of labels
extracted = data[data["Label"].isin(labels)]

In [None]:
extracted.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1198 entries, 1265 to 1696683
Data columns (total 78 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   Flow Duration                1198 non-null   int64  
 1   Total Fwd Packets            1198 non-null   int64  
 2   Total Backward Packets       1198 non-null   int64  
 3   Total Length of Fwd Packets  1198 non-null   int64  
 4   Total Length of Bwd Packets  1198 non-null   int64  
 5   Fwd Packet Length Max        1198 non-null   int64  
 6   Fwd Packet Length Min        1198 non-null   int64  
 7   Fwd Packet Length Mean       1198 non-null   float64
 8   Fwd Packet Length Std        1198 non-null   float64
 9   Bwd Packet Length Max        1198 non-null   int64  
 10  Bwd Packet Length Min        1198 non-null   int64  
 11  Bwd Packet Length Mean       1198 non-null   float64
 12  Bwd Packet Length Std        1198 non-null   float64
 13  Flow Bytes/s

In [None]:
# Print count of each type
print(extracted['Label'].value_counts())

Bot             1167
Infiltration      25
Heartbleed         6
Name: Label, dtype: int64
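The counts above show how skewed even the extracted table is. The following sketch uses only the three counts printed above to make the imbalance explicit; the variable names are illustrative and not part of the original notebook.

```python
import pandas as pd

# Label counts copied from the value_counts() output above.
counts = pd.Series({"Bot": 1167, "Infiltration": 25, "Heartbleed": 6}, name="Label")

# Share of each class within the extracted rows.
shares = counts / counts.sum()

# Ratio between the most and the least frequent class.
imbalance_ratio = counts.max() / counts.min()

print(shares.round(4))
print(f"Imbalance ratio (Bot : Heartbleed) = {imbalance_ratio:.1f}")
```

With only 6 Heartbleed rows, any model fitted on this table has very little evidence for that class, which is worth keeping in mind when reading the quality scores later on.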


## Write rare classes into a CSV file

In [None]:
csv_link = '/content/drive/MyDrive/Dataset/CICIDS2017/train and test/rare_class.csv'

In [None]:
# pandas writes the file itself; index=False leaves out the original row labels
extracted.to_csv(csv_link, index=False)
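Writing with `index=False` is why the reloaded table later shows a fresh `RangeIndex` instead of the scattered row labels of `extracted`. A small self-contained sketch of that round trip, using a toy frame and an in-memory buffer (both are stand-ins, not the real data or file):

```python
import io
import pandas as pd

# Toy stand-in for `extracted`: a filtered frame keeps its original row labels.
toy = pd.DataFrame(
    {"Flow Duration": [41, 79912, 17], "Label": ["Bot", "Bot", "Heartbleed"]},
    index=[1265, 20481, 1696683],  # illustrative labels only
)

# index=False drops the row labels, as in the to_csv call above.
buffer = io.StringIO()
toy.to_csv(buffer, index=False)
buffer.seek(0)

reloaded = pd.read_csv(buffer)
print(list(reloaded.index))    # row labels restart from 0
print(list(reloaded.columns))  # no extra index column appears
```

Had the index been written, the reloaded file would carry an extra unnamed column, which metadata detection would then treat as a feature.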

## Load rare classes for resampling

In [None]:
from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata()

In [None]:
metadata.detect_from_csv(filepath=csv_link)

In [None]:
metadata.validate()

In [None]:
metadata

{
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "Flow Duration": {
            "sdtype": "numerical"
        },
        "Total Fwd Packets": {
            "sdtype": "numerical"
        },
        "Total Backward Packets": {
            "sdtype": "numerical"
        },
        "Total Length of Fwd Packets": {
            "sdtype": "numerical"
        },
        "Total Length of Bwd Packets": {
            "sdtype": "numerical"
        },
        "Fwd Packet Length Max": {
            "sdtype": "numerical"
        },
        "Fwd Packet Length Min": {
            "sdtype": "numerical"
        },
        "Fwd Packet Length Mean": {
            "sdtype": "numerical"
        },
        "Fwd Packet Length Std": {
            "sdtype": "numerical"
        },
        "Bwd Packet Length Max": {
            "sdtype": "numerical"
        },
        "Bwd Packet Length Min": {
            "sdtype": "numerical"
        },
        "Bwd Packet Length Mean": {
            "sd

In [None]:
df = pd.read_csv(csv_link)

In [None]:
df.head()

Unnamed: 0,Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Total Length of Bwd Packets,Fwd Packet Length Max,Fwd Packet Length Min,Fwd Packet Length Mean,Fwd Packet Length Std,Bwd Packet Length Max,...,min_seg_size_forward,Active Mean,Active Std,Active Max,Active Min,Idle Mean,Idle Std,Idle Max,Idle Min,Label
0,41,1,1,6,6,6,6,6.0,0.0,6,...,20,0.0,0.0,0,0,0.0,0.0,0,0,Bot
1,79912,4,3,207,134,195,0,51.75,95.541876,128,...,20,0.0,0.0,0,0,0.0,0.0,0,0,Bot
2,17,1,1,6,6,6,6,6.0,0.0,6,...,20,0.0,0.0,0,0,0.0,0.0,0,0,Bot
3,19,1,1,6,6,6,6,6.0,0.0,6,...,20,0.0,0.0,0,0,0.0,0.0,0,0,Bot
4,17,1,1,6,6,6,6,6.0,0.0,6,...,20,0.0,0.0,0,0,0.0,0.0,0,0,Bot


In [None]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1198 entries, 0 to 1197
Data columns (total 78 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   Flow Duration                1198 non-null   int64  
 1   Total Fwd Packets            1198 non-null   int64  
 2   Total Backward Packets       1198 non-null   int64  
 3   Total Length of Fwd Packets  1198 non-null   int64  
 4   Total Length of Bwd Packets  1198 non-null   int64  
 5   Fwd Packet Length Max        1198 non-null   int64  
 6   Fwd Packet Length Min        1198 non-null   int64  
 7   Fwd Packet Length Mean       1198 non-null   float64
 8   Fwd Packet Length Std        1198 non-null   float64
 9   Bwd Packet Length Max        1198 non-null   int64  
 10  Bwd Packet Length Min        1198 non-null   int64  
 11  Bwd Packet Length Mean       1198 non-null   float64
 12  Bwd Packet Length Std        1198 non-null   float64
 13  Flow Bytes/s      

# 2. Basic Usage

## 2.1 Creating a Synthesizer

An SDV **synthesizer** is an object that you can use to create synthetic data. It learns patterns from the real data and replicates them to generate synthetic data.

In [None]:
from sdv.single_table import CTGANSynthesizer

synthesizer = CTGANSynthesizer(metadata)
synthesizer.fit(df)



Training took about 1m56 on this dataset.

For larger datasets, this phase may take longer. A drawback of a GAN-based model like CTGAN is that training is slow.

When this code finishes running, the synthesizer is ready to use.

In [None]:
synthesizer.get_parameters()

{'enforce_min_max_values': True,
 'enforce_rounding': True,
 'locales': None,
 'embedding_dim': 128,
 'generator_dim': (256, 256),
 'discriminator_dim': (256, 256),
 'generator_lr': 0.0002,
 'generator_decay': 1e-06,
 'discriminator_lr': 0.0002,
 'discriminator_decay': 1e-06,
 'batch_size': 500,
 'discriminator_steps': 1,
 'log_frequency': True,
 'verbose': False,
 'epochs': 300,
 'pac': 10,
 'cuda': True}

## 2.2 Generating Synthetic Data
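Use the `sample` function to synthesize any number of rows. Here we want a fixed number of rows per rare class, so we instead build one `Condition` per label and pass them to `sample_from_conditions`.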

In [None]:
from sdv.sampling import Condition

Bot = Condition(
    num_rows=10000,
    column_values={'Label': 'Bot'}
)

Infiltration = Condition(
    num_rows=10000,
    column_values={'Label': 'Infiltration'}
)
Heartbleed = Condition(
    num_rows=10000,
    column_values={'Label': 'Heartbleed'}
)
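The three `Condition` objects above differ only in the label. As a sketch, the same arguments can be generated in a loop as plain dictionaries; `condition_specs` and `rows_per_class` are hypothetical names introduced here, and the block deliberately avoids importing SDV so it runs anywhere.

```python
# Labels of the rare classes extracted earlier.
labels = ["Bot", "Infiltration", "Heartbleed"]
rows_per_class = 10000

# One keyword-argument dictionary per class, mirroring the
# num_rows / column_values arguments used in the cell above.
condition_specs = [
    {"num_rows": rows_per_class, "column_values": {"Label": label}}
    for label in labels
]

for spec in condition_specs:
    print(spec)
```

With SDV installed, `[Condition(**spec) for spec in condition_specs]` should produce a list equivalent to the three hand-written conditions, which keeps the row count and label list in one place.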

In [None]:
synthetic_data = synthesizer.sample_from_conditions(
    conditions=[Bot, Heartbleed, Infiltration],
    output_file_path='/content/drive/MyDrive/Dataset/CICIDS2017/train and test/resampling.csv'
)

Sampling conditions: 100%|██████████| 30000/30000 [00:24<00:00, 1224.41it/s]


In [None]:
synthetic_data.head()

Unnamed: 0,Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Total Length of Bwd Packets,Fwd Packet Length Max,Fwd Packet Length Min,Fwd Packet Length Mean,Fwd Packet Length Std,Bwd Packet Length Max,...,min_seg_size_forward,Active Mean,Active Std,Active Max,Active Min,Idle Mean,Idle Std,Idle Max,Idle Min,Label
0,72353180,2717,0,12248,0,2085,0,68.729699,19.752802,0,...,20,0.0,8539700.0,0,0,0.0,26606.3,26741848,0,Bot
1,398755,21,4,1317,0,208,0,24.556069,25.782277,108,...,20,0.0,0.0,37980,13612,0.0,0.0,171432,55210,Bot
2,1,1,19,0,3093,0,0,95.229123,0.0,0,...,20,0.0,21906.51,0,15300,0.0,150348.7,0,158680,Bot
3,1,38,0,2866110,0,0,5,84.287619,22.896713,24,...,20,0.0,1023.752,0,85322,0.0,107398.4,12880,0,Bot
4,66274388,27,0,0,0,463,0,78.056624,26.651754,33,...,28,9580577.0,18727.31,0,19292,0.0,7728867.0,540354,0,Bot


In [None]:
# Print count of each type
print(synthetic_data['Label'].value_counts())

Bot             10000
Heartbleed      10000
Infiltration    10000
Name: Label, dtype: int64


## 2.3 Anonymization - SKIP THIS PART

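This part is kept from the original demo and does not apply to the CICIDS2017 table used here: the columns and the `real_data` variable below belong to a different dataset of guest records, so these cells will not run in this notebook.

In that dataset, there were some sensitive columns such as the guest's email, billing address and credit card number. In the synthetic data, these columns are **fully anonymized** -- they contain entirely fake values that follow the format of the original.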

In [None]:
sensitive_column_names = ['guest_email', 'billing_address', 'credit_card_number']

real_data[sensitive_column_names].head(3)

Unnamed: 0,guest_email,billing_address,credit_card_number
0,michaelsanders@shaw.net,"49380 Rivers Street\nSpencerville, AK 68265",4075084747483975747
1,randy49@brown.biz,"88394 Boyle Meadows\nConleyberg, TN 22063",180072822063468
2,webermelissa@neal.com,"0323 Lisa Station Apt. 208\nPort Thomas, LA 82585",38983476971380


In [None]:
synthetic_data[sensitive_column_names].head(3)

Unnamed: 0,guest_email,billing_address,credit_card_number
0,dsullivan@example.net,"90469 Karla Knolls Apt. 781\nSusanberg, NC 28401",5161033759518983
1,steven59@example.org,"1080 Ashley Creek Apt. 622\nWest Amy, NM 25058",4133047413145475690
2,brandon15@example.net,"99923 Anderson Trace Suite 861\nNorth Haley, T...",4977328103788


_Note that any repeated values between the real and synthetic data occur by random chance. This ensures that an attacker won't be able to guess the real, sensitive values based on these columns alone._

## 2.4 Evaluating Real vs. Synthetic Data
The synthetic data replicates the **mathematical properties** of the real data. To get more insight, we can use the `evaluation` module.

In [None]:
from sdv.evaluation.single_table import evaluate_quality

quality_report = evaluate_quality(
    df,
    synthetic_data,
    metadata
)

Streaming output truncated to the last 5000 lines.
The real data in column 'Bwd Avg Bulk Rate' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd Avg Bulk Rate' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Bwd PSH Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd PSH Flags' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Fwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Fwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Bwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The


Overall Quality Score: 76.29%

Properties:
Column Shapes: 64.42%
Column Pair Trends: 88.17%
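The warnings above come from columns that hold a single value in every row. A minimal sketch on a toy frame (the values are made up, only the column names are borrowed from the warnings) shows how such columns can be listed and why their correlation is undefined:

```python
import pandas as pd

# Toy frame: the two flag columns never vary, like 'Bwd PSH Flags' in the warnings.
toy = pd.DataFrame({
    "Flow Duration": [41, 79912, 17, 19],
    "Bwd PSH Flags": [0, 0, 0, 0],
    "Fwd URG Flags": [0, 0, 0, 0],
})

# A column is constant when it holds exactly one distinct value.
constant_columns = [
    col for col in toy.columns if toy[col].nunique(dropna=False) == 1
]
print(constant_columns)

# Pearson correlation divides by the standard deviation, which is 0 for a
# constant column, so pairs involving such a column come out as NaN.
correlations = toy.corr()
print(correlations)
```

Running the same list comprehension on `df` would show which of the 78 columns trigger these messages.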


The report lets us inspect the individual properties that were captured. For example, the `Column Shapes` details and visualization later in this section show _which_ individual column shapes were well-captured and which weren't.

In [None]:
from sdv.evaluation.single_table import run_diagnostic

diagnostic_report = run_diagnostic(
    real_data=df,
    synthetic_data=synthetic_data,
    metadata=metadata)

Creating report: 100%|██████████| 4/4 [01:46<00:00, 26.64s/it]


DiagnosticResults:

SUCCESS:
✓ The synthetic data covers over 90% of the categories present in the real data
✓ Over 90% of the synthetic rows are not copies of the real data
✓ The synthetic data follows over 90% of the min/max boundaries set by the real data

! The synthetic data is missing more than 10% of the numerical ranges present in the real data
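The last diagnostic line says the synthetic values do not span the full numerical ranges of the real data. The sketch below is a simplified illustration of what "range coverage" can mean for one column; `range_coverage` is a hypothetical helper and is not necessarily the exact computation the diagnostic report performs.

```python
import pandas as pd

def range_coverage(real: pd.Series, synthetic: pd.Series) -> float:
    """Fraction of the real [min, max] interval that the synthetic values span."""
    real_min, real_max = real.min(), real.max()
    width = real_max - real_min
    if width == 0:
        # Constant real column: covered if the synthetic data hits the same value.
        return float((synthetic == real_min).any())
    # Portion of the real range left uncovered at the low and high ends.
    missing_low = max((synthetic.min() - real_min) / width, 0)
    missing_high = max((real_max - synthetic.max()) / width, 0)
    return float(max(1 - (missing_low + missing_high), 0))

real = pd.Series([0, 10, 50, 100])
synthetic = pd.Series([20, 40, 60, 80])
print(range_coverage(real, synthetic))
```

In this toy case the synthetic values miss the lowest and the highest 20% of the real range, so the coverage is about 0.6.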





In [None]:
quality_report.get_score()

0.7629301694966113

In [None]:
quality_report.get_properties()

Unnamed: 0,Property,Score
0,Column Shapes,0.644165
1,Column Pair Trends,0.881695
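The numbers above are consistent with the overall score being the plain average of the two property scores. A quick arithmetic check, using only the values printed in this notebook:

```python
# Property scores copied from get_properties() above.
property_scores = {"Column Shapes": 0.644165, "Column Pair Trends": 0.881695}

# Unweighted mean of the property scores.
overall = sum(property_scores.values()) / len(property_scores)
print(f"{overall:.4f}")
```

This agrees with the value returned by `get_score()` up to the rounding of the displayed property scores.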


In [None]:
quality_report.get_details(property_name='Column Shapes')

Unnamed: 0,Column,Metric,Quality Score
0,Flow Duration,KSComplement,0.477175
1,Total Fwd Packets,KSComplement,0.374475
2,Total Backward Packets,KSComplement,0.507671
3,Total Length of Fwd Packets,KSComplement,0.513696
4,Total Length of Bwd Packets,KSComplement,0.288239
...,...,...,...
73,Idle Mean,KSComplement,0.570197
74,Idle Std,KSComplement,0.563351
75,Idle Max,KSComplement,0.308797
76,Idle Min,KSComplement,0.472964
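The `KSComplement` metric in the table is the complement of the two-sample Kolmogorov-Smirnov statistic: 1 means the real and synthetic distributions of a column coincide, 0 means they do not overlap at all. The following is a from-scratch NumPy sketch of that idea, not the library's implementation; `ks_complement` is a name introduced here.

```python
import numpy as np

def ks_complement(real, synthetic) -> float:
    """1 minus the largest gap between the two empirical CDFs."""
    real = np.sort(np.asarray(real, dtype=float))
    synthetic = np.sort(np.asarray(synthetic, dtype=float))
    # Evaluate both empirical CDFs at every observed value.
    points = np.concatenate([real, synthetic])
    cdf_real = np.searchsorted(real, points, side="right") / real.size
    cdf_synth = np.searchsorted(synthetic, points, side="right") / synthetic.size
    ks_statistic = np.max(np.abs(cdf_real - cdf_synth))
    return float(1 - ks_statistic)

print(ks_complement([1, 2, 3, 4], [1, 2, 3, 4]))      # identical samples -> 1.0
print(ks_complement([1, 2, 3, 4], [10, 20, 30, 40]))  # disjoint samples -> 0.0
print(ks_complement([1, 2, 3, 4], [3, 4, 5, 6]))      # partial overlap -> 0.5
```

Read this way, a score such as 0.288 for `Total Length of Bwd Packets` means the two cumulative distributions differ by about 0.71 at their widest point.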


In [None]:
quality_report.get_visualization('Column Shapes')

## 2.5 Visualizing the Data
For more insights, we can visualize the real vs. synthetic data.

Let's perform a 1D visualization comparing a column of the real data to the synthetic data.

In [None]:
from sdv.evaluation.single_table import get_column_plot

fig = get_column_plot(
    real_data=df,
    synthetic_data=synthetic_data,
    column_name='Label',
    metadata=metadata
)
    
fig.show()
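For the `Label` column, the real and synthetic bars will differ by construction: the real table is dominated by Bot rows, while we sampled 10,000 rows per class. One common way to quantify such a gap between two categorical distributions is the total variation distance; the sketch below applies it to the counts printed earlier in this notebook and is an illustration, not output from the SDV report.

```python
# Real label counts (from the extracted table) and synthetic counts (10,000 each).
real_counts = {"Bot": 1167, "Infiltration": 25, "Heartbleed": 6}
synthetic_counts = {"Bot": 10000, "Infiltration": 10000, "Heartbleed": 10000}

def frequencies(counts):
    """Convert raw counts into relative frequencies."""
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}

real_freq = frequencies(real_counts)
synth_freq = frequencies(synthetic_counts)

# Total variation distance: half the summed absolute frequency gaps.
tv_distance = 0.5 * sum(abs(real_freq[k] - synth_freq[k]) for k in real_freq)
print(f"TV distance = {tv_distance:.3f}, complement = {1 - tv_distance:.3f}")
```

A low complement here reflects the deliberate rebalancing rather than a modeling failure, so the `Label` column should be read differently from the numerical columns when interpreting the report.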

We can also visualize in 2D, comparing the correlations of a pair of columns.

In [None]:
from sdv.evaluation.single_table import get_column_pair_plot

fig = get_column_pair_plot(
    real_data=df,
    synthetic_data=synthetic_data,
    column_names=['min_seg_size_forward', 'act_data_pkt_fwd'],
    metadata=metadata
)
    
fig.show()

## 2.6 Saving and Loading
We can save the synthesizer to share with others and sample more synthetic data in the future.

In [None]:
synthesizer.save('/content/drive/MyDrive/Models/my_synthesizer.pkl')

In [None]:
# Load from the same path that was used in save() above
synthesizer = CTGANSynthesizer.load('/content/drive/MyDrive/Models/my_synthesizer.pkl')

# 3. CTGAN Customization
When using this synthesizer, we can trade training time against data quality using the `epochs` parameter: a higher `epochs` value means that the synthesizer trains for longer, which ideally, though not always, improves the data quality.


In [None]:
custom_synthesizer = CTGANSynthesizer(
    metadata,
    epochs=600)
custom_synthesizer.fit(df)


No rounding scheme detected for column 'Flow Bytes/s'. Data will not be rounded.


No rounding scheme detected for column 'Flow Packets/s'. Data will not be rounded.


No rounding scheme detected for column 'Flow IAT Mean'. Data will not be rounded.


No rounding scheme detected for column 'Flow IAT Std'. Data will not be rounded.


No rounding scheme detected for column 'Fwd IAT Mean'. Data will not be rounded.


No rounding scheme detected for column 'Fwd IAT Std'. Data will not be rounded.


No rounding scheme detected for column 'Bwd IAT Mean'. Data will not be rounded.


No rounding scheme detected for column 'Bwd IAT Std'. Data will not be rounded.


No rounding scheme detected for column 'Packet Length Variance'. Data will not be rounded.


No rounding scheme detected for column 'Active Mean'. Data will not be rounded.


No rounding scheme detected for column 'Active Std'. Data will not be rounded.


No rounding scheme detected for column 'Idle Mean'. Data will not be rounded.


In [None]:
custom_synthesizer.get_parameters()

{'enforce_min_max_values': True,
 'enforce_rounding': True,
 'locales': None,
 'embedding_dim': 128,
 'generator_dim': (256, 256),
 'discriminator_dim': (256, 256),
 'generator_lr': 0.0002,
 'generator_decay': 1e-06,
 'discriminator_lr': 0.0002,
 'discriminator_decay': 1e-06,
 'batch_size': 500,
 'discriminator_steps': 1,
 'log_frequency': True,
 'verbose': False,
 'epochs': 600,
 'pac': 10,
 'cuda': True}

<font color="maroon"><i><b>The fit call above takes about 10 min to run.</b></i></font>

After we've trained our synthesizer, we can verify the changes to the data quality by creating some synthetic data and evaluating it.

In [None]:
from sdv.sampling import Condition

Bot = Condition(
    num_rows=10000,
    column_values={'Label': 'Bot'}
)

Infiltration = Condition(
    num_rows=10000,
    column_values={'Label': 'Infiltration'}
)
Heartbleed = Condition(
    num_rows=10000,
    column_values={'Label': 'Heartbleed'}
)

In [None]:
synthetic_data_customized = custom_synthesizer.sample_from_conditions(
    conditions=[Bot, Heartbleed, Infiltration],
    output_file_path='/content/drive/MyDrive/Dataset/CICIDS2017/train and test/resampling_customized.csv'
)

Sampling conditions: 100%|██████████| 30000/30000 [00:29<00:00, 1009.12it/s]


In [None]:
quality_report = evaluate_quality(
    df,
    synthetic_data_customized,
    metadata
)

Streaming output truncated to the last 5000 lines.
The real data in column 'Bwd Avg Bulk Rate' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd Avg Bulk Rate' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Bwd PSH Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd PSH Flags' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Fwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Fwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The real data in column 'Bwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The synthetic data in column 'Bwd URG Flags' contains a constant value. Correlation is undefined for constant data.


The


Overall Quality Score: 75.6%

Properties:
Column Shapes: 63.31%
Column Pair Trends: 87.88%


In [None]:
from sdv.evaluation.single_table import run_diagnostic

diagnostic_report = run_diagnostic(
    real_data=df,
    synthetic_data=synthetic_data_customized,
    metadata=metadata)

Creating report: 100%|██████████| 4/4 [01:49<00:00, 27.36s/it]


DiagnosticResults:

SUCCESS:
✓ The synthetic data covers over 90% of the categories present in the real data
✓ Over 90% of the synthetic rows are not copies of the real data
✓ The synthetic data follows over 90% of the min/max boundaries set by the real data

! The synthetic data is missing more than 10% of the numerical ranges present in the real data





In [None]:
quality_report.get_score()

0.7559699641342805

In [None]:
quality_report.get_properties()

Unnamed: 0,Property,Score
0,Column Shapes,0.6331
1,Column Pair Trends,0.87884
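To see whether the longer training helped, the property scores of the two runs can be placed side by side. The sketch below uses only the numbers printed in this notebook; the column names `epochs_300` and `epochs_600` are labels chosen here for illustration.

```python
import pandas as pd

# Property scores copied from the two get_properties() outputs in this notebook.
scores = pd.DataFrame(
    {
        "epochs_300": [0.644165, 0.881695],
        "epochs_600": [0.633100, 0.878840],
    },
    index=["Column Shapes", "Column Pair Trends"],
)

# A positive delta would mean the 600-epoch run scored higher.
scores["delta"] = scores["epochs_600"] - scores["epochs_300"]
mean_delta = scores["delta"].mean()

print(scores.round(4))
print("Mean delta:", round(mean_delta, 4))
```

In this run both deltas are slightly negative, so doubling the epochs did not improve the scores. GAN training is stochastic, so a single comparison like this is only suggestive.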


In [None]:
quality_report.get_details(property_name='Column Shapes')

Unnamed: 0,Column,Metric,Quality Score
0,Flow Duration,KSComplement,0.547675
1,Total Fwd Packets,KSComplement,0.292108
2,Total Backward Packets,KSComplement,0.354571
3,Total Length of Fwd Packets,KSComplement,0.228963
4,Total Length of Bwd Packets,KSComplement,0.365419
...,...,...,...
73,Idle Mean,KSComplement,0.837364
74,Idle Std,KSComplement,0.678118
75,Idle Max,KSComplement,0.653297
76,Idle Min,KSComplement,0.368564


In [None]:
quality_report.get_visualization('Column Shapes')

While GANs are able to model complex patterns and shapes, it is not easy to understand how they are learning -- but it is possible to modify the underlying architecture of the neural networks.

For users who are familiar with the GAN architecture, there are extra parameters you can use to tune CTGAN to your particular needs. For more details, see [the CTGAN documentation](https://docs.sdv.dev/sdv/single-table-data/modeling/synthesizers/ctgansynthesizer).