In [1]:
"""device setting"""
import torch

# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
"""load dataset and specify column types"""
import pandas as pd

# The CSV path is relative to the notebook's working directory.
loan_csv_path = './loan.csv'
data = pd.read_csv(loan_csv_path)
# Columns handed to DistVAE as `continuous_features`.
continuous_features = ['Age', 'Experience', 'Income', 'CCAvg', 'Mortgage']

# Columns handed to DistVAE as `categorical_features`.
# NOTE(review): 'ID', 'ZIP Code' and 'Education' show up in the generated
# output but are listed neither here nor above -- confirm how DistVAE treats
# columns that are not named in any list.
categorical_features = [
    'Family',
    'Personal Loan', 'Securities Account', 'CD Account',
    'Online', 'CreditCard',
]

# Every continuous column except 'CCAvg', in the same order.
# The recorded outputs show these columns as whole numbers and 'CCAvg' as a
# float; presumably DistVAE rounds them on generation -- TODO confirm.
integer_features = [
    column for column in continuous_features if column != 'CCAvg'
]

In [3]:
"""DistVAE"""
# Import the module under an alias: the model instance is bound to the name
# `distvae` below (later cells call `distvae.train()` etc.), and without the
# alias that assignment would shadow the module and make `DistVAE` unreachable
# for the rest of the session.
from distvae_tabular import distvae as distvae_module

# Build the model from the loaded frame and the column-type lists.
# NOTE(review): `device` from the first cell is not passed here -- confirm
# whether DistVAE selects a device on its own or accepts one as an argument.
distvae = distvae_module.DistVAE(
    data=data,
    continuous_features=continuous_features,
    categorical_features=categorical_features,
    integer_features=integer_features,
    epochs=5 # for quick checking (default is 1000)
)

Tranform Continuous Features...: 100%|██████████| 5/5 [00:00<00:00, 849.32it/s]


In [4]:
"""training"""
# Fit the model constructed earlier. The recorded run covers the 5 epochs
# requested at construction and logs loss / recon / KL / activated per epoch.
distvae.train()

inner loop: 100%|██████████| 10/10 [00:00<00:00, 12.27it/s]


Epoch [001/5], loss: 21.8929, recon: 21.8910, KL: 0.0197, activated: 0.0000


inner loop: 100%|██████████| 10/10 [00:00<00:00, 13.21it/s]


Epoch [002/5], loss: 20.9521, recon: 20.9433, KL: 0.0885, activated: 0.0000


inner loop: 100%|██████████| 10/10 [00:00<00:00, 12.39it/s]


Epoch [003/5], loss: 20.6083, recon: 20.5974, KL: 0.1090, activated: 0.0000


inner loop: 100%|██████████| 10/10 [00:00<00:00, 13.29it/s]


Epoch [004/5], loss: 20.3527, recon: 20.3105, KL: 0.4215, activated: 0.0000


inner loop: 100%|██████████| 10/10 [00:00<00:00, 13.30it/s]

Epoch [005/5], loss: 20.1602, recon: 20.0909, KL: 0.6932, activated: 0.0000





In [5]:
"""generate synthetic data"""
# Number of synthetic rows to request from the trained model.
n_synthetic = 100
syndata = distvae.generate_data(n_synthetic)
# Bare last expression so the notebook renders the frame as the cell output.
syndata

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 13.52it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,45,14,109,1.850601,36,1812,430,0,0,0,1,0,1,0
1,54,23,80,2.333430,42,3102,185,1,0,0,0,0,1,0
2,45,11,41,0.702516,17,996,316,0,1,0,0,0,0,0
3,40,24,53,2.084561,-17,1715,277,1,1,0,0,0,1,0
4,43,19,99,2.047455,28,285,316,2,2,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,42,34,95,0.743443,-3,603,367,3,0,0,0,0,0,1
96,46,21,87,2.897642,74,2838,430,2,0,0,1,0,0,0
97,43,32,97,2.985854,111,4923,352,2,0,1,0,0,0,1
98,61,15,112,1.746705,83,830,173,0,0,0,0,0,1,0


In [6]:
"""generate synthetic data with Differential Privacy"""
# Number of synthetic rows to request from the trained model.
n_synthetic = 100
# Value forwarded as `lambda_`; per this cell's title it switches on the
# differentially-private generation path.
# NOTE(review): the recorded output contains implausible values (e.g. Age -97,
# Mortgage -11020, negative CCAvg) -- presumably injected privacy noise;
# confirm the meaning and scale of `lambda_` against the library docs.
privacy_lambda = 0.1
syndata = distvae.generate_data(n_synthetic, lambda_=privacy_lambda)
# Bare last expression so the notebook renders the frame as the cell output.
syndata

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 13.65it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,44,27,75,2.813653,18,2982,316,0,1,1,0,1,1,0
1,52,23,75,2.340590,1,876,367,2,1,0,0,0,1,0
2,43,7,77,3.830019,49,2704,310,0,2,0,0,0,1,0
3,45,18,134,0.968052,11,165,367,1,2,0,0,0,0,1
4,43,25,72,-2.296475,-11020,4890,139,3,1,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,56,13,65,1.128928,57,3710,367,0,1,0,0,0,1,0
96,46,27,107,1.811131,50,4803,35,3,1,0,0,0,0,0
97,54,28,99,2.104103,-3394,2345,149,3,1,1,1,0,1,0
98,-97,26,88,1.994714,0,1331,430,1,0,0,0,0,0,0
