In [None]:
import os
import pickle
import pandas as pd
from sdv.datasets.local import load_csvs
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer, TVAESynthesizer
import matplotlib.pyplot as plt

# Step 1 - CTGAN

In [None]:
# Read the CSV tables found in the datasets folder and pick the table to train on.
# REVIEW - replace the empty key below with the key of the training table.
datasets_folder = 'datasets/'
datasets = load_csvs(folder_name=datasets_folder)
gan = datasets['']  #REVIEW - Update path

In [None]:
# Drop every rank that is represented by fewer than MIN_ROWS_PER_RANK rows (mutations),
# i.e. keep only ranks with at least 3 rows. Rows whose rank is missing are dropped too,
# because groupby excludes them and the resulting NaN size fails the comparison.
MIN_ROWS_PER_RANK = 3
rank_sizes = gan.groupby('rank')['rank'].transform('size')
gan = gan.loc[rank_sizes >= MIN_ROWS_PER_RANK]

In [None]:
# Build single-table metadata by inspecting the training frame, then check it is valid
# before it is handed to the synthesizer.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data=gan)
metadata.validate()

# Show the detected metadata for inspection.
metadata

In [None]:
# Step 1 CTGAN hyperparameters
epochs = 80
batch_size = 4000
discriminator_dim = (8, 8)
discriminator_decay = 1e-6
discriminator_lr = 2e-4
discriminator_steps = 2
embedding_dim = 128
generator_dim = (8, 8)
generator_decay = 1e-6
generator_lr = 2e-4
pac = 10

# Underscore-joined layer sizes for the saved-model filename, e.g. (8, 8) -> '8_8'.
# Joining every element (rather than indexing [0] and [1]) keeps all layers in the
# name if the tuples are changed to a different number of layers.
discriminator_dim_print = '_'.join(str(size) for size in discriminator_dim)
generator_dim_print = '_'.join(str(size) for size in generator_dim)

# Configure the CTGAN synthesizer with the metadata detected above.
# NOTE(review): enforce_min_max_values=False presumably lets samples fall outside the
# training range - confirm against the SDV docs if that matters downstream.
synthesizer = CTGANSynthesizer(
    metadata,
    epochs=epochs,
    batch_size=batch_size,
    discriminator_dim=discriminator_dim,
    discriminator_decay=discriminator_decay,
    discriminator_lr=discriminator_lr,
    discriminator_steps=discriminator_steps,
    embedding_dim=embedding_dim,
    generator_dim=generator_dim,
    generator_decay=generator_decay,
    generator_lr=generator_lr,
    pac=pac,
    verbose=True,
    enforce_min_max_values=False,
)

In [None]:
# Let the synthesizer choose a transformer for each column of the training frame,
# then display the resulting assignment.
synthesizer.auto_assign_transformers(gan)
transformers = synthesizer.get_transformers()
transformers

In [None]:
# Train the CTGAN on the filtered table and persist it to disk.
synthesizer.fit(gan)

# The filename encodes the CTGAN hyperparameters (except the two decay terms), so runs
# with different settings do not overwrite each other.
step1_model_dir = '/outdir'  #REVIEW - Update path
step1_model_path = os.path.join(
    step1_model_dir,
    f'ctgan_ep{epochs}_bs{batch_size}_ddim{discriminator_dim_print}_dlr{discriminator_lr}'
    f'_ds{discriminator_steps}_edim{embedding_dim}_gdim{generator_dim_print}'
    f'_glr{generator_lr}_pac{pac}_step1.pkl',
)
# Writing into a directory that does not exist would fail, so create it first.
os.makedirs(step1_model_dir, exist_ok=True)
synthesizer.save(filepath=step1_model_path)

# Display the saved path; the Step 2 / Step 3 load paths must point to this file.
step1_model_path

In [None]:
# Generate exactly as many synthetic rows as the real (filtered) table contains.
synthetic_data = synthesizer.sample(num_rows=len(gan))

## Compare real vs simulated results

In [None]:
# Map every rank that appears in the real or synthetic data to a consecutive integer
# position ('conv'), in sorted rank order, so both bar charts share the same x-axis.
# (Previously .unique() kept order of first appearance, so the axis was not sorted.)
all_ranks = (
    pd.concat([gan['rank'], synthetic_data['rank']])
    .drop_duplicates()
    .sort_values(ignore_index=True)
)
conversion = pd.DataFrame({'rank': all_ranks})
conversion['conv'] = range(len(conversion))

In [None]:
# Bar chart of how many real rows fall in each rank; x positions come from `conversion`.
gan_bar = gan.merge(conversion, on='rank', how='left')
bar_pd = pd.crosstab(gan_bar['conv'], 'count')
counts = bar_pd['count']

fig, ax = plt.subplots()
ax.bar(counts.index, counts.values)
ax.set_xlabel('Rank')
ax.set_ylabel('Frequency')
ax.set_title('Real data')
plt.show()

In [None]:
# Bar chart of how many synthetic rows fall in each rank, on the same x positions as the real data.
synthetic_bar = synthetic_data.merge(conversion, on='rank', how='left')
bar_pd = pd.crosstab(synthetic_bar['conv'], 'count')
counts = bar_pd['count']

fig, ax = plt.subplots()
ax.bar(counts.index, counts.values)
ax.set_xlabel('Rank')
ax.set_ylabel('Frequency')
ax.set_title('Synthetic data')
plt.show()

# Step 2 - TVAE

In [None]:
# Reload the training table so Step 2 can run on its own.
# REVIEW - replace the empty key below with the key of the training table.
datasets_folder = 'datasets/'
datasets = load_csvs(folder_name=datasets_folder)
gan = datasets['']  #REVIEW - Update path

In [None]:
# Drop every rank that is represented by fewer than MIN_ROWS_PER_RANK rows (mutations),
# i.e. keep only ranks with at least 3 rows, matching the Step 1 filter.
# Rows whose rank is missing are dropped too (groupby excludes them).
MIN_ROWS_PER_RANK = 3
rank_sizes = gan.groupby('rank')['rank'].transform('size')
gan = gan.loc[rank_sizes >= MIN_ROWS_PER_RANK]

In [None]:
# Reload the Step 1 CTGAN and regenerate one synthetic row per real row; the per-rank
# counts of this sample decide how many rows each Step 2 model draws.
# REVIEW - the path must point at the .pkl file written by synthesizer.save() in Step 1.
synthesizer = CTGANSynthesizer.load('/outdir/_step1.pkl')  #REVIEW - Update path
synthetic_data = synthesizer.sample(num_rows=len(gan))

In [None]:
# Every distinct rank left after filtering; used to train one TVAE per rank.
rank1 = gan['rank'].unique()

# TVAE hyperparameters shared by all per-rank models
epochs = 3000
batch_size = 5000
compress_dims = (256, 256)
decompress_dims = (256, 256)
embedding_dim = 128
l2scale = 1e-5
loss_factor = 2

# Underscore-joined layer sizes for saved filenames, e.g. (256, 256) -> '256_256'.
# Joining every element keeps all layers in the name if the tuples change length.
compress_dims_print = '_'.join(str(size) for size in compress_dims)
decompress_dims_print = '_'.join(str(size) for size in decompress_dims)

# Train one TVAE per rank on that rank's 'start' values, save each model, and sample
# as many rows per rank as the Step 1 synthetic data assigned to it.

# Ranks with fewer real rows than this are skipped (data-driven replacement for the
# former hard-coded `idx == 103` check, whose stated reason was a 1-row rank).
MIN_ROWS_TO_TRAIN = 2

# Rows per rank in the Step 1 sample, counted once instead of re-filtering every window.
synthetic_rank_counts = synthetic_data['rank'].value_counts()

synthetic_parts = []  # per-rank samples, concatenated once after the loop
for (idx, window) in enumerate(rank1):

    # One-column frame with the real 'start' values of this rank
    window_data = gan.loc[gan['rank'] == window, ['start']]

    if len(window_data) < MIN_ROWS_TO_TRAIN:
        continue

    # Detect metadata for the single-column subset
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data=window_data)

    # Define synthesizer options
    synthesizer = TVAESynthesizer(
        metadata,
        epochs=epochs,
        batch_size=batch_size,
        compress_dims=compress_dims,
        decompress_dims=decompress_dims,
        embedding_dim=embedding_dim,
        l2scale=l2scale,
        loss_factor=loss_factor,
        enforce_min_max_values=True,
    )

    # Transform the columns
    synthesizer.auto_assign_transformers(window_data)

    # Train the model; the file index is the rank's position in rank1
    synthesizer.fit(window_data)
    synthesizer.save(filepath=f'/outdir/bin{idx}_tvae_ep{epochs}_bs{batch_size}_cdim{compress_dims_print}_edim{embedding_dim}_ddim{decompress_dims_print}_step2.pkl') #REVIEW - Update path

    # Draw as many rows for this rank as Step 1 generated for it (none if it generated none).
    # NOTE(review): sampled rows carry only 'start'; the rank they belong to is not recorded.
    nrow = int(synthetic_rank_counts.get(window, 0))
    if nrow > 0:
        synthetic_parts.append(synthesizer.sample(num_rows=nrow))

# Concatenate once (growing a frame inside the loop is quadratic).
synthetic_data2 = pd.concat(synthetic_parts) if synthetic_parts else pd.DataFrame()

# Write the Step 2 and Step 1 synthetic tables; both file names share a tag built from
# the TVAE settings of this run.
run_tag = (f'tvae_ep{epochs}_bs{batch_size}_cdim{compress_dims_print}'
           f'_edim{embedding_dim}_ddim{decompress_dims_print}')
synthetic_data2.to_csv(f'/outdir/{run_tag}_step2.csv', index=False)  #REVIEW - Update path
synthetic_data.to_csv(f'/outdir/{run_tag}_step1.csv', index=False)  #REVIEW - Update path

# Step 3 - Save the models into a pickle

In [None]:
# Bundle the Step 1 CTGAN and every per-rank Step 2 TVAE into one dictionary and pickle it.
# Keys: 'step1' for the CTGAN, and each rank value for that rank's TVAE.
# Assumes the Step 2 cell ran in this session with the same settings: rank1 and the TVAE
# hyperparameters are reused to rebuild the exact filenames that cell saved.
posModel = {}

# Step1 model
# REVIEW - this must be the .pkl file written by synthesizer.save() in Step 1.
posModel['step1'] = CTGANSynthesizer.load('/outdir/step1.pkl') #REVIEW - Update path

# Step2 models. load() needs a concrete path (a literal '*' is not expanded as a glob),
# so rebuild the name used by synthesizer.save() in Step 2.
skipped_windows = []
for (idx, window) in enumerate(rank1):
    model_path = (f'/outdir/bin{idx}_tvae_ep{epochs}_bs{batch_size}_cdim{compress_dims_print}'
                  f'_edim{embedding_dim}_ddim{decompress_dims_print}_step2.pkl')  #REVIEW - Update path
    if not os.path.exists(model_path):
        # Step 2 does not save a model for every rank (some ranks are skipped).
        skipped_windows.append(window)
        continue
    posModel[window] = TVAESynthesizer.load(model_path)

# Save the dictionary into a pickle file
with open('/outdir/positions_model.pkl', 'wb') as handle: #REVIEW - Update path
    pickle.dump(posModel, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Ranks that had no Step 2 model file
skipped_windows
