In [54]:
import pandas as pd
from ydata_synthetic.synthesizers.regular import RegularSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters


# Read the insurance dataset from the working directory into a DataFrame.
data_frame = pd.read_csv(filepath_or_buffer="insurance.csv")

# Print the column names, non-null counts, dtypes and memory footprint.
data_frame.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1338 entries, 0 to 1337
Data columns (total 7 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       1338 non-null   int64  
 1   sex       1338 non-null   object 
 2   bmi       1338 non-null   float64
 3   children  1338 non-null   int64  
 4   smoker    1338 non-null   object 
 5   region    1338 non-null   object 
 6   charges   1338 non-null   float64
dtypes: float64(2), int64(2), object(3)
memory usage: 73.3+ KB


In [55]:
# Column roles handed to the synthesizer: categorical vs. numerical features.
cat_cols = ['sex', 'smoker', 'region']
num_cols = ['age', 'bmi', 'children', 'charges']

# Training hyper-parameters.
# NOTE(review): the full frame has 1338 rows, so batch_size may exceed the
# row count of a smoker-filtered subset — confirm this is intended.
batch_size = 900
epochs = 500 + 1
learning_rate = 0.0002
beta_1 = 0.1
beta_2 = 0.2

# Parameter bundles; both are reused when fitting on other subsets.
train_args = TrainParameters(epochs=epochs)
ctgan_args = ModelParameters(
    batch_size=batch_size,
    lr=learning_rate,
    betas=(beta_1, beta_2),
)

# Fit a 'ctgan' synthesizer on the smoker == 'yes' rows only, then save it.
# 'smoker' is still listed in cat_cols, so it is single-valued in this subset.
synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)
synth.fit(
    data=data_frame.loc[data_frame['smoker'].eq('yes')],
    train_arguments=train_args,
    num_cols=num_cols,
    cat_cols=cat_cols,
)
synth.save('insurance_syntetic_yes.pkl')

In [56]:
# --- Load the saved smoker == 'yes' synthesizer and sample from it ---

# Restore the model from the file written by synth.save().
synth = RegularSynthesizer.load('insurance_syntetic_yes.pkl')

# Draw 1000 synthetic rows and write them as comma-separated text,
# leaving out the DataFrame index.
synth_data = synth.sample(1000)
synth_data.to_csv(
    path_or_buf='insurance_syntetic_yes.txt',
    sep=',',
    index=False,
)

In [57]:
# Fit a fresh 'ctgan' synthesizer on the smoker == 'no' rows, reusing the
# ctgan_args / train_args / num_cols / cat_cols objects that already exist in
# the notebook namespace, then save the trained model.
synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)
synth.fit(
    data=data_frame.loc[data_frame['smoker'].eq('no')],
    train_arguments=train_args,
    num_cols=num_cols,
    cat_cols=cat_cols,
)
synth.save('insurance_syntetic_no.pkl')



Epoch: 0 | critic_loss: 7.713175296783447 | generator_loss: 0.684279203414917
Epoch: 1 | critic_loss: 7.541848659515381 | generator_loss: 0.6597513556480408
Epoch: 2 | critic_loss: 7.436697006225586 | generator_loss: 0.6570126414299011
Epoch: 3 | critic_loss: 7.369967460632324 | generator_loss: 0.6342743039131165
Epoch: 4 | critic_loss: 7.218049049377441 | generator_loss: 0.6673294901847839
Epoch: 5 | critic_loss: 7.020167827606201 | generator_loss: 0.6536417603492737
Epoch: 6 | critic_loss: 6.889869213104248 | generator_loss: 0.6503921151161194
Epoch: 7 | critic_loss: 6.713210105895996 | generator_loss: 0.5866379737854004
Epoch: 8 | critic_loss: 6.523467063903809 | generator_loss: 0.5901046395301819
Epoch: 9 | critic_loss: 6.420361042022705 | generator_loss: 0.5595600605010986
Epoch: 10 | critic_loss: 6.126209735870361 | generator_loss: 0.5609115958213806
Epoch: 11 | critic_loss: 5.896271705627441 | generator_loss: 0.5880247354507446
Epoch: 12 | critic_loss: 5.688381671905518 | genera

In [58]:
# --- Load the saved smoker == 'no' synthesizer and sample from it ---

# Restore the model from the file written by synth.save().
synth = RegularSynthesizer.load('insurance_syntetic_no.pkl')

# Draw 3000 synthetic rows and write them as comma-separated text,
# leaving out the DataFrame index.
synth_data = synth.sample(3000)
synth_data.to_csv(
    path_or_buf='insurance_syntetic_no.txt',
    sep=',',
    index=False,
)