In [1]:
"""device setting"""
import torch

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
# NOTE(review): none of the cells visible in this notebook reference `device`
# afterwards -- confirm whether it is meant to be handed to the model.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
"""load dataset and specify column types"""
import pandas as pd

# Raw loan table, read from a path relative to the current working directory.
data = pd.read_csv('./loan.csv')

# Columns declared as continuous.
continuous_features = ['Age', 'Experience', 'Income', 'CCAvg', 'Mortgage']

# Columns declared as categorical.
# NOTE(review): the generated tables shown further down also contain 'ID',
# 'ZIP Code' and 'Education', none of which are listed in this cell --
# confirm how DistVAE treats columns that are not declared here.
categorical_features = [
    'Family', 'Personal Loan', 'Securities Account',
    'CD Account', 'Online', 'CreditCard',
]

# Continuous columns that are additionally declared as integer features:
# every continuous column except 'CCAvg', in the same order.
integer_features = [name for name in continuous_features if name != 'CCAvg']

In [3]:
"""DistVAE"""
# Import the module under an alias: the model instance below is bound to the
# name `distvae` (later cells call `distvae.train()` / `distvae.generate_data()`),
# and without the alias that assignment would overwrite the imported module,
# making `DistVAE` unreachable through it for the rest of the session.
from distvae_tabular import distvae as distvae_module

# Build the model from the loaded table and the column-type lists declared in
# the previous cell.
distvae = distvae_module.DistVAE(
    data=data,
    continuous_features=continuous_features,
    categorical_features=categorical_features,
    integer_features=integer_features,
    epochs=5 # for quick checking (default is 1000)
)

Tranform Continuous Features...: 100%|██████████| 5/5 [00:00<00:00, 619.12it/s]


In [4]:
"""training"""
# Fit the model built in the previous cell. The captured output below shows one
# progress bar and one summary line (loss, recon, KL, activated) per epoch.
distvae.train()

inner loop: 100%|██████████| 20/20 [00:01<00:00, 12.09it/s]


Epoch [001/5], loss: 21.6421, recon: 21.6358, KL: 0.0631, activated: 0.0000


inner loop: 100%|██████████| 20/20 [00:01<00:00, 13.20it/s]


Epoch [002/5], loss: 20.6265, recon: 20.6030, KL: 0.2351, activated: 0.0000


inner loop: 100%|██████████| 20/20 [00:01<00:00, 11.87it/s]


Epoch [003/5], loss: 20.1609, recon: 20.0515, KL: 1.0937, activated: 0.0000


inner loop: 100%|██████████| 20/20 [00:01<00:00, 12.32it/s]


Epoch [004/5], loss: 19.7933, recon: 19.5525, KL: 2.4079, activated: 0.0024


inner loop: 100%|██████████| 20/20 [00:01<00:00, 13.05it/s]

Epoch [005/5], loss: 19.3658, recon: 18.9908, KL: 3.7507, activated: 0.0037





In [5]:
"""generate synthetic data"""
# Ask the trained model for synthetic data and keep the result in `syndata`.
# NOTE(review): presumably 100 is the number of rows to generate -- confirm
# against the DistVAE API.
syndata = distvae.generate_data(100)
# Bare last expression so the notebook renders the resulting table.
syndata

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 15.17it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,42,11,117,1.712880,26,1812,430,0,0,0,0,0,1,0
1,45,17,79,2.721187,32,3102,316,1,0,0,0,0,1,0
2,45,14,-20,-0.476261,3,996,316,0,1,0,0,0,1,0
3,33,17,32,2.210978,-14,24,277,1,1,0,0,0,0,0
4,41,17,109,2.218716,5,285,316,2,2,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,37,32,125,0.897688,-14,603,367,3,0,0,0,0,0,0
96,50,23,62,1.948308,98,3999,430,2,2,0,0,0,1,1
97,42,25,75,2.327831,86,4923,188,2,1,1,0,0,1,1
98,60,18,108,1.413141,102,2750,173,0,0,0,0,0,1,0


In [6]:
"""generate synthetic data with Differential Privacy"""
# Same generation call as the previous cell, with `lambda_=0.1` added.
# This reassigns `syndata`, replacing the table produced by the previous cell.
# NOTE(review): `lambda_` is presumably the privacy-related parameter implied
# by this cell's title -- confirm its meaning and scale in the DistVAE docs.
syndata = distvae.generate_data(100, lambda_=0.1)
# Bare last expression so the notebook renders the resulting table.
syndata

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 15.65it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,41,21,66,2.342595,5,2982,316,0,1,0,0,0,1,0
1,44,16,77,2.839641,25,876,367,2,0,0,0,0,1,0
2,43,8,37,2.633479,63,1625,310,0,2,0,0,0,1,0
3,37,11,106,0.498669,3,165,367,1,2,0,0,0,0,0
4,41,21,75,-2.032567,-11000,4890,139,3,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,52,9,74,2.564981,68,4902,367,1,0,0,0,0,1,0
96,50,27,86,0.993105,67,4803,35,3,1,0,0,0,0,0
97,53,22,78,1.422818,-3411,710,149,3,1,1,1,0,1,0
98,-97,25,82,1.741950,-4,703,430,1,2,0,0,0,0,0
