In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
import pandas as pd

df = pd.read_csv('data/cancer_data.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 24 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   Age                       1000 non-null   int64 
 1   Gender                    1000 non-null   object
 2   Air Pollution             1000 non-null   int64 
 3   Alcohol use               1000 non-null   int64 
 4   Dust Allergy              1000 non-null   int64 
 5   OccuPational Hazards      1000 non-null   int64 
 6   Genetic Risk              1000 non-null   int64 
 7   chronic Lung Disease      1000 non-null   int64 
 8   Balanced Diet             1000 non-null   int64 
 9   Obesity                   1000 non-null   int64 
 10  Smoking                   1000 non-null   int64 
 11  Passive Smoker            1000 non-null   int64 
 12  Chest Pain                1000 non-null   int64 
 13  Coughing of Blood         1000 non-null   int64 
 14  Fatigue                  

In [3]:
df.replace(
    {
        'Gender':{'Male':0,'Female':1},
        'Level':{'Normal':0,'Benign':1,'Malignant':2}
    },
    inplace=True
)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 24 columns):
 #   Column                    Non-Null Count  Dtype
---  ------                    --------------  -----
 0   Age                       1000 non-null   int64
 1   Gender                    1000 non-null   int64
 2   Air Pollution             1000 non-null   int64
 3   Alcohol use               1000 non-null   int64
 4   Dust Allergy              1000 non-null   int64
 5   OccuPational Hazards      1000 non-null   int64
 6   Genetic Risk              1000 non-null   int64
 7   chronic Lung Disease      1000 non-null   int64
 8   Balanced Diet             1000 non-null   int64
 9   Obesity                   1000 non-null   int64
 10  Smoking                   1000 non-null   int64
 11  Passive Smoker            1000 non-null   int64
 12  Chest Pain                1000 non-null   int64
 13  Coughing of Blood         1000 non-null   int64
 14  Fatigue                   1000 non-null  

In [4]:
from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(df)

In [6]:
from sdv.single_table import TVAESynthesizer

synthesizer = TVAESynthesizer(
    metadata,
    embedding_dim=512,
    compress_dims=(512, 512),
    decompress_dims=(512, 512),
    epochs=5000,
)

synthesizer.fit(df)

synthesizer.get_loss_values()

Unnamed: 0,Epoch,Batch,Loss
0,0,0,95.572693
1,0,1,124.664413
2,1,0,97.038643
3,1,1,96.498344
4,2,0,100.876518
...,...,...,...
9995,4997,1,4.055758
9996,4998,0,3.408872
9997,4998,1,3.473060
9998,4999,0,3.702660


In [8]:
# synthesizer = TVAESynthesizer.load('synthesizer.pkl')
synthesizer.save('synthesizer.pkl')

In [15]:
from sdv.evaluation.single_table import evaluate_quality

synth_df = synthesizer.sample(300000, batch_size=500)
big_df = pd.concat([df, synth_df], ignore_index=True).drop_duplicates()
report = evaluate_quality(df, big_df, metadata)
big_df.info()

Sampling rows: 100%|██████████| 300000/300000 [02:05<00:00, 2383.81it/s]


Generating report ...
(1/2) Evaluating Column Shapes: : 100%|██████████| 24/24 [00:00<00:00, 92.67it/s]
(2/2) Evaluating Column Pair Trends: : 100%|██████████| 276/276 [00:33<00:00,  8.32it/s]

Overall Score: 81.45%

Properties:
- Column Shapes: 90.81%
- Column Pair Trends: 72.1%
<class 'pandas.core.frame.DataFrame'>
Index: 101640 entries, 0 to 300995
Data columns (total 24 columns):
 #   Column                    Non-Null Count   Dtype
---  ------                    --------------   -----
 0   Age                       101640 non-null  int64
 1   Gender                    101640 non-null  int64
 2   Air Pollution             101640 non-null  int64
 3   Alcohol use               101640 non-null  int64
 4   Dust Allergy              101640 non-null  int64
 5   OccuPational Hazards      101640 non-null  int64
 6   Genetic Risk              101640 non-null  int64
 7   chronic Lung Disease      101640 non-null  int64
 8   Balanced Diet             101640 non-null  int64
 9   Obesity       

In [16]:
big_df.replace(
    {
        'Gender':{0:'Male',1:'Female'},
        'Level':{0:'Normal',1:'Benign',2:'Malignant'}
    },
    inplace=True
)

big_df.to_csv('data/big_cancer_data.csv', index=False)