In [1]:
# Device setting: use the GPU when CUDA is usable, otherwise fall back to the CPU.
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
"""load dataset and specify column types"""
import pandas as pd

# Keep the input location in one place so it is easy to point at another copy.
DATA_PATH = './loan.csv'
data = pd.read_csv(DATA_PATH)

continuous_features = [
    'Age',
    'Experience',
    'Income',
    'CCAvg',
    'Mortgage',
]
categorical_features = [
    'Family',
    'Personal Loan',
    'Securities Account',
    'CD Account',
    'Online',
    'CreditCard',
]
# Continuous columns that should be integer-valued; must be a subset of
# `continuous_features`.
integer_features = [
    'Age',
    'Experience',
    'Income',
    'Mortgage',
]
# NOTE(review): 'ID', 'ZIP Code' and 'Education' appear in the generated data shown
# below (so presumably in loan.csv) but are in neither list. Confirm how DistVAE
# treats unlisted columns; 'ID' looks like a row identifier and probably should be
# dropped before training.

# Fail fast on a mistyped/missing column or an inconsistent spec, rather than
# surfacing an obscure error later inside DistVAE.
missing_columns = set(continuous_features + categorical_features) - set(data.columns)
assert not missing_columns, f"columns not found in {DATA_PATH}: {sorted(missing_columns)}"
assert set(integer_features) <= set(continuous_features), \
    "integer_features must be a subset of continuous_features"
assert not set(continuous_features) & set(categorical_features), \
    "a column cannot be both continuous and categorical"

In [3]:
"""DistVAE"""
# Import the module under an alias: the model instance below is bound to the name
# `distvae` (later cells use it). Importing the module under that same name meant
# re-running this cell would look up `DistVAE` on the model instance, not the module.
from distvae_tabular import distvae as distvae_module

distvae = distvae_module.DistVAE(
    data=data,
    continuous_features=continuous_features,
    categorical_features=categorical_features,
    integer_features=integer_features,

    seed=42,
    latent_dim=4,
    beta=0.1,  # KL weight: the training log reports loss = recon + beta * KL
    hidden_dim=128,

    epochs=5,  # for quick checking
    batch_size=256,
    lr=0.001,

    step=0.1,
    threshold=1e-8,
    # Honour the device chosen in the device-setting cell instead of hard-coding CPU.
    device=str(device),
)

Tranform Continuous Features...: 100%|██████████| 5/5 [00:00<00:00, 636.54it/s]


In [4]:
# Training: fit DistVAE for the number of epochs configured at construction.
distvae.train()

Training...:  20%|██        | 1/5 [00:01<00:07,  1.86s/it]

Epoch [001/5], loss: 21.5486, recon: 21.5438, KL: 0.0479, activated: 0.0000


Training...:  40%|████      | 2/5 [00:03<00:05,  1.76s/it]

Epoch [002/5], loss: 20.6885, recon: 20.6783, KL: 0.1013, activated: 0.0000


Training...:  60%|██████    | 3/5 [00:05<00:03,  1.70s/it]

Epoch [003/5], loss: 20.3174, recon: 20.2761, KL: 0.4135, activated: 0.0000


Training...:  80%|████████  | 4/5 [00:06<00:01,  1.67s/it]

Epoch [004/5], loss: 20.0674, recon: 19.9621, KL: 1.0529, activated: 0.0057


Training...: 100%|██████████| 5/5 [00:08<00:00,  1.70s/it]

Epoch [005/5], loss: 19.7280, recon: 19.4921, KL: 2.3594, activated: 0.0514





In [7]:
# Generate synthetic data: draw 100 records from the trained DistVAE model.
# NOTE(review): sampled values can fall outside plausible ranges (e.g. the negative
# Mortgage values in the output below) — validate or clip before downstream use.
syndata = distvae.generate_data(100)
syndata

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 14.85it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,53,16,90,1.560485,7,3928,295,0,0,0,0,0,1,0
1,45,9,101,1.090905,76,1714,241,3,2,0,0,0,0,0
2,33,17,120,2.587753,26,2198,292,0,2,1,0,0,0,1
3,61,42,85,1.682516,-2,92,267,0,1,0,0,0,1,0
4,45,15,93,4.307488,-19,4381,316,3,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,56,23,44,1.513769,9,3293,8,1,2,0,0,0,1,1
96,44,13,106,1.069319,15,231,367,3,0,0,0,0,1,0
97,41,15,61,3.347709,-6,4967,445,1,2,0,1,1,0,0
98,39,14,132,3.859935,361,4686,316,3,2,0,0,0,1,0


In [6]:
"""generate synthetic data with Differential Privacy"""
# Store the DP sample under its own name so it does not overwrite the non-private
# `syndata` from the previous cell (the two are meant to be compared).
# `lambda_` is passed straight through to DistVAE.generate_data; presumably it sets
# the differential-privacy strength — TODO confirm against the distvae_tabular docs.
syndata_dp = distvae.generate_data(100, lambda_=0.1)
syndata_dp

Generate Synthetic Dataset...: 100%|██████████| 1/1 [00:00<00:00, 14.09it/s]


Unnamed: 0,Age,Experience,Income,CCAvg,Mortgage,ID,ZIP Code,Family,Education,Personal Loan,Securities Account,CD Account,Online,CreditCard
0,43,14,101,3.124581,10,515,430,1,0,1,0,1,0,1
1,38,25,88,1.405275,-5,2078,367,3,2,0,0,0,1,0
2,37,16,94,1.111694,31,2449,331,2,0,0,0,0,0,1
3,64,35,41,1.027695,183,3812,316,3,2,0,0,0,1,0
4,44,22,90,3.305277,-8,2008,141,2,0,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,53,35,119,0.301635,25,111,399,1,2,0,0,0,0,0
96,41,28,49,2.027475,128,1479,40,2,1,0,0,0,0,1
97,41,12,123,3.356443,112,3899,267,2,0,1,0,1,1,0
98,38,65,-497,2.140378,166,2172,277,3,2,0,0,0,1,0
