# Step 3: SDG

In [1]:
%pwd

'C:\\Users\\flore\\source\\repos\\master-thesis-vt23\\notebooks'

In [2]:
# necessary imports for the section
from sdv.single_table import CTGANSynthesizer as CTGAN
from sdv.metadata import SingleTableMetadata

import pickle
import pandas as pd 
import os 
import sys
from io import StringIO

sys.path.append('../src')
from utils import (getPicklesFromDir, 
                   getExperimentConfig, 
                   extract_loss_info_from_stdout, 
                   create_loss_plot)

from mlflow_manager import MLFlowManager

# Load the global experiment settings; later cells read the 'run_dataset',
# 'ctgan_param' and 'folders' entries from this mapping.
config = getExperimentConfig()
# Directory settings; this notebook uses the 'settings_dir', 'real_dir'
# and 'sd_dir' entries.
folders = config['folders']
# Dataset-specific settings loaded from the settings directory. Each entry is
# later indexed with a 'meta' key -- presumably one unpickled settings object
# per file; verify against utils.getPicklesFromDir.
dataset_settings = getPicklesFromDir(folders['settings_dir'])

In [3]:



def capture_stdout(func):
    """Decorator that captures everything ``func`` writes to ``sys.stdout``.

    The decorated function returns a tuple ``(func_output, captured_stdout)``
    where ``func_output`` is the original return value and ``captured_stdout``
    is the text printed while ``func`` was running.

    ``sys.stdout`` is restored even when ``func`` raises, so a failing call
    (e.g. a ``MemoryError`` during training) does not leave the notebook
    kernel printing into a discarded buffer. The exception itself propagates
    unchanged.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Remember the real stdout so it can be put back afterwards
        original_stdout = sys.stdout

        # Temporarily redirect stdout into an in-memory text buffer
        buffer = StringIO()
        sys.stdout = buffer
        try:
            # Call the original function and get its output
            func_output = func(*args, **kwargs)
            # Retrieve everything that was printed during the call
            captured_stdout = buffer.getvalue()
        finally:
            # Always restore stdout and release the buffer, also on errors
            sys.stdout = original_stdout
            buffer.close()

        # Return both the function output and the captured stdout
        return func_output, captured_stdout

    return wrapper

@capture_stdout
def train_sdg_model(model, data, sdg_name):
    """Fit ``model`` on ``data`` and print marker lines around the fit output.

    Prints ``#START#`` and ``sdg_name`` on separate lines before calling
    ``model.fit(data)`` and ``#END#`` afterwards, then returns the fitted
    model. Because of the ``capture_stdout`` decorator the caller receives
    ``(model, captured_stdout)`` instead of the bare model.
    """
    # Opening marker followed by the SDG name, one per line
    # (presumably what extract_loss_info_from_stdout keys on -- TODO confirm)
    print("#START#", sdg_name, sep="\n")
    model.fit(data)
    # Closing marker after training has finished
    print("#END#")

    return model

In [4]:
# Specify datasets by Id, if None, all is run
run_dataset = config['run_dataset']

# CTGAN settings:
#   quality_params - one set of CTGAN keyword arguments per quality key
#   sd_size_factor - size of each synthetic dataset relative to the original
#   num_SD         - number of synthetic datasets sampled per trained model
quality_params = config['ctgan_param']['quality_params']
sd_size_factor = config['ctgan_param']['sd_size_factor']
num_SD = config['ctgan_param']['num_sd']


# run SDG generation
# for each dataset specific settings
for s_index, settings in enumerate(dataset_settings):

    # skip datasets that were not selected for this run
    if run_dataset is not None and settings['meta']['id'] not in run_dataset:
        continue

    metadata = SingleTableMetadata().load_from_json(settings['meta']['meta_filepath'])
    # Init experiment logging
    experiment_name = f"{settings['meta']['id']}-SDG"
    mlflow = MLFlowManager(experiment_name)

    # load original dataset, applying the per-column dtypes when the settings
    # provide them (None lets pandas infer the dtypes)
    cols_dtype = None
    if 'cols_dtype' in settings['meta']:
        cols_dtype = settings['meta']['cols_dtype']

    original_data = pd.read_csv(f"{folders['real_dir']}{settings['meta']['filename']}", dtype=cols_dtype)

    # get the size to generate the synthetic data; cast to int because a
    # fractional size factor would otherwise produce a float row count
    original_data_size = len(original_data)
    sd_size = int(original_data_size * sd_size_factor)

    logg_tags = {'Source': settings['meta']['id']}

    # loop through the different quality parameters for the SDG
    for quality in quality_params:

        display(f"Start: SDG-{settings['meta']['id']}{quality}")
        logg_tags['Quality'] = quality

        sdg_name = f"S{settings['meta']['id']}{quality}"
        log_run = mlflow.start_run(sdg_name, tags=logg_tags)

        # make sure the run is closed even if training, saving or sampling fails
        try:
            # Get path to save the artifacts, relative to notebooks dir
            artifact_path = log_run.info.artifact_uri.split('notebooks/')[1]

            mlflow.log_params(quality_params[quality])

            # creates model with the metadata and the quality parameters
            model = CTGAN(metadata=metadata, **quality_params[quality])

            if 'sdg_constraints' in settings['meta']:
                model.add_constraints(constraints=settings['meta']['sdg_constraints'])

            # progress markers shown in the notebook output; the training
            # output itself is captured and returned as stdout_loss
            print("#START#")
            print(sdg_name)

            model, stdout_loss = train_sdg_model(model, original_data, sdg_name)
            print("#END#")

            # extract loss from the captured output and create the loss plot
            loss_dict = extract_loss_info_from_stdout(stdout_loss)
            fig = create_loss_plot(sdg_name, loss_dict[sdg_name])

            # save plot
            fig_path = f"{artifact_path}/{sdg_name}_loss_plot.png"
            fig.savefig(fig_path)

            # save loss data
            loss_df_path = f"{artifact_path}/{sdg_name}_loss.csv"
            loss_dict[sdg_name].to_csv(loss_df_path, index=False)

            # save the trained SDG model next to the other run artifacts
            model_path = f"{artifact_path}/{sdg_name}.pkl"
            model.save(model_path)

            # sample num_SD synthetic datasets from the trained SDG for validating results
            for itr in range(num_SD):

                # creates Synthetic dataset name, using dataset id, quality key, and itr number
                # e.g. SD1Q1_2 means SDG trained on dataset D1 with quality Q1 and copy num 2
                SD_name = f"S{settings['meta']['id']}{quality}_{str(itr)}"

                # relative file path for the synthetic dataset
                sd_path = f"{folders['sd_dir']}{SD_name}.csv"

                # generate synthetic data
                synthetic_data = model.sample(num_rows=sd_size)

                # save the synthetic dataset
                synthetic_data.to_csv(sd_path, index=False)
        finally:
            mlflow.end_run()


'Start: SDG-D2'

#START#
SD2Q1


MemoryError: Unable to allocate 7.61 GiB for an array with shape (45211, 45211) and data type int32

In [None]:
""" The Loss values captured from the cell above's standard output will 
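be used to create the generator vs discriminator loss plots.
NOTE: this post-processing is deprecated; the training loop in the cell above
now creates and saves the loss plot and the loss CSV for each SDG.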
"""

"""
#Deprecated
loss_values = extract_loss_info_from_stdout(stdout_loss.stdout)

if (loss_values not None):
    # Combine the loss values and save them
    combined_loss_df = pd.concat(loss_values.values(), keys=loss_values.keys(), axis=0, ignore_index=False)
    combined_loss_df = combined_loss_df.reset_index().rename(columns={'level_0': 'SDG'})
    combined_loss_df.to_csv(f"{folders['data_dir']}combined_sdg_loss.csv", index=False)

    for sdg_id in loss_values:
        fig = create_loss_plot(sdg_id, loss_values[sdg_id])
        # Save the plot to correct mlflow log
        run=mlflow.load_run_by_name(sdg_id)
        path=run.info.artifact_uri.replace("file:///", "")
        #save plot
        fig.savefig(f"{path}/{sdg_id}_loss_plot.png")
        #save data
        loss_values[sdg_id].to_csv(f"{path}/{sdg_id}.csv", index=False)

"""

---