In [1]:
from sdv.single_table import (
    CopulaGANSynthesizer, CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer)
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torch.utils.data
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
from torch.nn import functional as F
import logging
import pandas as pd
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.preprocessing import KBinsDiscretizer
import sdgym
from sdv.metadata.single_table import SingleTableMetadata
from sdgym.datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
import rdt
from rdt.transformers import AnonymizedFaker, IDGenerator, RegexGenerator, get_default_transformers
from copy import deepcopy
import inspect
import copy
import traceback
import importlib
from pathlib import Path
from rdt.transformers.pii.anonymization import get_anonymized_transformer
import warnings
import sys
from pandas.core.tools.datetimes import _guess_datetime_format_for_array
from pandas.api.types import is_float_dtype, is_integer_dtype
import json
import datetime
import pkg_resources
import cloudpickle
import os
from collections import defaultdict
import uuid
import math
from copulas.multivariate import GaussianMultivariate
from tqdm import tqdm as tqdm1
from tqdm import tqdm as tqdm2
import functools
import copulas
from ctgan import CTGAN
from sdgym import create_single_table_synthesizer
from joblib import Parallel, delayed
from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder
from collections import namedtuple
import contextlib
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from keras.layers import Input, Dense
from keras.models import Model
from sklearn.ensemble import RandomForestClassifier
from table_evaluator import TableEvaluator
from scipy.spatial.distance import cdist
from rdt import HyperTransformer

In [2]:
class DataProcessor:
    """Single table data processor.

    This class handles all pre and post processing that is done to a single table to get it ready
    for modeling and finalize sampling. These processes include formatting, transformations,
    anonymization and constraint handling.

    Args:
        metadata (metadata.SingleTableMetadata):
            The single table metadata instance that will be used to apply constraints and
            transformations to the data.
        enforce_rounding (bool):
            Define rounding scheme for FloatFormatter. If True, the data returned by
            reverse_transform will be rounded to that place. Defaults to True.
        enforce_min_max_values (bool):
            Specify whether or not to clip the data returned by reverse_transform of the numerical
            transformer, FloatFormatter, to the min and max values seen during fit.
            Defaults to True.
        model_kwargs (dict):
            Dictionary specifying the kwargs that need to be used in each tabular
            model when working on this table. This dictionary contains as keys the name of the
            TabularModel class and as values a dictionary containing the keyword arguments to use.
            This argument exists mostly to ensure that the models are fitted using the same
            arguments when the same DataProcessor is used to fit different model instances on
            different slices of the same table.
        table_name (str):
            Name of table this processor is for. Optional.
        locales (str or list):
            Default locales to use for AnonymizedFaker transformers. Optional, defaults to using
            Faker's default locale.
    """

    _DTYPE_TO_SDTYPE = {
        'i': 'numerical',
        'f': 'numerical',
        'O': 'categorical',
        'b': 'boolean',
        'M': 'datetime',
    }

    _COLUMN_RELATIONSHIP_TO_TRANSFORMER = {
        'address': 'RandomLocationGenerator',
    }

    def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values):
        custom_float_formatter = rdt.transformers.FloatFormatter(
            missing_value_replacement='mean',
            missing_value_generation='random',
            learn_rounding_scheme=enforce_rounding,
            enforce_min_max_values=enforce_min_max_values
        )
        self._transformers_by_sdtype.update({'numerical': custom_float_formatter})

    def _detect_multi_column_transformers(self):
        """Detect if there are any multi column transformers in the metadata.

        Returns:
            dict:
                A dictionary mapping column names to the multi column transformer.
        """
        result = {}
        # if self.metadata.column_relationships:
        #     for relationship in self.metadata._valid_column_relationships:
        #         column_names = tuple(relationship['column_names'])
        #         relationship_type = relationship['type']
        #         if relationship_type in self._COLUMN_RELATIONSHIP_TO_TRANSFORMER:
        #             transformer_name = self._COLUMN_RELATIONSHIP_TO_TRANSFORMER[relationship_type]
        #             module = getattr(rdt.transformers, relationship_type)
        #             transformer = getattr(module, transformer_name)
        #             result[column_names] = transformer(locales=self._locales)

        return result

    def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
                 model_kwargs=None, table_name=None, locales=None):
        self.metadata = metadata
        self._enforce_rounding = enforce_rounding
        self._enforce_min_max_values = enforce_min_max_values
        self._model_kwargs = model_kwargs or {}
        self._locales = locales
        self._constraints_list = []
        self._constraints = []
        self._constraints_to_reverse = []
        self._custom_constraint_classes = {}

        self._transformers_by_sdtype = deepcopy(get_default_transformers())
        self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator()
        del self._transformers_by_sdtype['text']
        self.grouped_columns_to_transformers = self._detect_multi_column_transformers()

        self._update_numerical_transformer(enforce_rounding, enforce_min_max_values)
        self._hyper_transformer = rdt.HyperTransformer()
        self.table_name = table_name
        self._dtypes = None
        self.fitted = False
        self.formatters = {}
        self._primary_key = self.metadata.primary_key
        self._prepared_for_fitting = False
        self._keys = deepcopy(self.metadata.alternate_keys)
        if self._primary_key:
            self._keys.append(self._primary_key)
        self.columns = None

    def _get_grouped_columns(self):
        """Get the columns that are part of a multi column transformer.

        Returns:
            list:
                A list of columns that are part of a multi column transformer.
        """
        return [
            col for col_tuple in self.grouped_columns_to_transformers for col in col_tuple
        ]

    def _get_columns_in_address_transformer(self):
        """Get the columns that are part of an address transformer.

        Returns:
            list:
                A list of columns that are part of the address transformers.
        """
        try:
            _check_import_address_transformers()
            result = []
            for col_tuple, transformer in self.grouped_columns_to_transformers.items():
                is_randomlocationgenerator = isinstance(
                    transformer, rdt.transformers.address.RandomLocationGenerator
                )
                is_regionalanonymizer = isinstance(
                    transformer, rdt.transformers.address.RegionalAnonymizer
                )
                if is_randomlocationgenerator or is_regionalanonymizer:
                    result.extend(list(col_tuple))

            return result
        except ImportError:
            return []

    def get_model_kwargs(self, model_name):
        """Return the required model kwargs for the indicated model.

        Args:
            model_name (str):
                Qualified Name of the model for which model kwargs
                are needed.

        Returns:
            dict:
                Keyword arguments to use on the indicated model.
        """
        return deepcopy(self._model_kwargs.get(model_name))

    def set_model_kwargs(self, model_name, model_kwargs):
        """Set the model kwargs used for the indicated model.

        Args:
            model_name (str):
                Qualified Name of the model for which the kwargs will be set.
            model_kwargs (dict):
                The key word arguments for the model.
        """
        self._model_kwargs[model_name] = model_kwargs

    def get_sdtypes(self, primary_keys=False):
        """Get a ``dict`` with the ``sdtypes`` for each column of the table.

        Args:
            primary_keys (bool):
                Whether or not to include the primary key fields. Defaults to ``False``.

        Returns:
            dict:
                Dictionary that contains the column names and ``sdtypes``.
        """
        sdtypes = {}
        for name, column_metadata in self.metadata.columns.items():
            sdtype = column_metadata['sdtype']

            if primary_keys or (name not in self._keys):
                sdtypes[name] = sdtype

        return sdtypes

    def _validate_custom_constraint_name(self, class_name):
        reserved_class_names = list(get_subclasses(Constraint))
        if class_name in reserved_class_names:
            error_message = (
                f"The name '{class_name}' is a reserved constraint name. "
                'Please use a different one for the custom constraint.'
            )
            raise InvalidConstraintsError(error_message)

    def _validate_custom_constraints(self, filepath, class_names, module):
        errors = []
        for class_name in class_names:
            try:
                self._validate_custom_constraint_name(class_name)
            except InvalidConstraintsError as err:
                errors += err.errors

            if not hasattr(module, class_name):
                errors.append(f"The constraint '{class_name}' is not defined in '{filepath}'.")

        if errors:
            raise InvalidConstraintsError(errors)

    def load_custom_constraint_classes(self, filepath, class_names):
        """Load a custom constraint class for the current synthesizer.

        Args:
            filepath (str):
                String representing the absolute or relative path to the python file where
                the custom constraints are declared.
            class_names (list):
                A list of custom constraint classes to be imported.
        """
        path = Path(filepath)
        module = load_module_from_path(path)
        self._validate_custom_constraints(filepath, class_names, module)
        for class_name in class_names:
            constraint_class = getattr(module, class_name)
            self._custom_constraint_classes[class_name] = constraint_class

    def add_custom_constraint_class(self, class_object, class_name):
        """Add a custom constraint class for the synthesizer to use.

        Args:
            class_object (sdv.constraints.Constraint):
                A custom constraint class object.
            class_name (str):
                The name to assign this custom constraint class. This will be the name to use
                when writing a constraint dictionary for ``add_constraints``.
        """
        self._validate_custom_constraint_name(class_name)
        self._custom_constraint_classes[class_name] = class_object

    def _validate_constraint_dict(self, constraint_dict):
        """Validate a constraint against the single table metadata.

        Args:
            constraint_dict (dict):
                A dictionary containing:
                    * ``constraint_class``: Name of the constraint to apply.
                    * ``constraint_parameters``: A dictionary with the constraint parameters.
        """
        params = {'constraint_class', 'constraint_parameters'}
        keys = constraint_dict.keys()
        missing_params = params - keys
        if missing_params:
            raise SynthesizerInputError(
                f'A constraint is missing required parameters {missing_params}. '
                'Please add these parameters to your constraint definition.'
            )

        extra_params = keys - params
        if extra_params:
            raise SynthesizerInputError(
                f'Unrecognized constraint parameter {extra_params}. '
                'Please remove these parameters from your constraint definition.'
            )

        constraint_class = constraint_dict['constraint_class']
        constraint_parameters = constraint_dict['constraint_parameters']
        try:
            if constraint_class in self._custom_constraint_classes:
                constraint_class = self._custom_constraint_classes[constraint_class]

            else:
                constraint_class = Constraint._get_class_from_dict(constraint_class)

        except KeyError:
            raise InvalidConstraintsError(f"Invalid constraint class ('{constraint_class}').")

        if 'column_name' in constraint_parameters:
            column_names = [constraint_parameters.get('column_name')]
        else:
            column_names = constraint_parameters.get('column_names')

        columns_in_address = self._get_columns_in_address_transformer()
        if columns_in_address and column_names:
            address_constraint_columns = set(column_names) & set(columns_in_address)
            if address_constraint_columns:
                to_print = "', '".join(address_constraint_columns)
                raise InvalidConstraintsError(
                    f"The '{to_print}' columns are part of an address. You cannot add constraints "
                    'to columns that are part of an address group.'
                )

        constraint_class._validate_metadata(**constraint_parameters)

    def add_constraints(self, constraints):
        """Add constraints to the data processor.

        Args:
            constraints (list):
                List of constraints described as dictionaries in the following format:
                    * ``constraint_class``: Name of the constraint to apply.
                    * ``constraint_parameters``: A dictionary with the constraint parameters.
        """
        errors = []
        validated_constraints = []
        for constraint_dict in constraints:
            constraint_dict = deepcopy(constraint_dict)
            if 'constraint_parameters' in constraint_dict:
                constraint_dict['constraint_parameters'].update({'metadata': self.metadata})
            try:
                self._validate_constraint_dict(constraint_dict)
                validated_constraints.append(constraint_dict)
            except (AggregateConstraintsError, InvalidConstraintsError) as e:
                reformated_errors = '\n'.join(map(str, e.errors))
                errors.append(reformated_errors)

        if errors:
            raise InvalidConstraintsError(errors)

        self._constraints_list.extend(validated_constraints)
        self._prepared_for_fitting = False

    def get_constraints(self):
        """Get a list of the current constraints that will be used.

        Returns:
            list:
                List of dictionaries describing the constraints for this data processor.
        """
        constraints = deepcopy(self._constraints_list)
        for i in range(len(constraints)):
            del constraints[i]['constraint_parameters']['metadata']

        return constraints

    def _load_constraints(self):
        loaded_constraints = []
        default_constraints_classes = list(get_subclasses(Constraint))
        for constraint in self._constraints_list:
            if constraint['constraint_class'] in default_constraints_classes:
                loaded_constraints.append(Constraint.from_dict(constraint))

            else:
                constraint_class = self._custom_constraint_classes[constraint['constraint_class']]
                loaded_constraints.append(
                    constraint_class(**constraint.get('constraint_parameters', {}))
                )

        return loaded_constraints

    def _fit_constraints(self, data):
        self._constraints = self._load_constraints()
        errors = []
        for constraint in self._constraints:
            try:
                constraint.fit(data)
            except Exception as e:
                errors.append(e)

        if errors:
            raise AggregateConstraintsError(errors)

    def _transform_constraints(self, data, is_condition=False):
        """Run ``transform`` of every constraint in ``self._constraints`` over ``data``.

        The constraints are applied in order, each one receiving the output of the
        previous one. A ``MissingConstraintColumnError`` or ``FunctionError`` raised by a
        constraint is logged and the constraint is skipped; any other exception is
        collected and raised at the end inside an ``AggregateConstraintsError``.

        Args:
            data:
                The data to transform. Presumably a ``pandas.DataFrame`` (the condition
                branch uses ``data.columns`` and ``data.drop``) -- TODO confirm.
            is_condition (bool):
                When ``False`` (default), ``self._constraints_to_reverse`` is reset and
                every constraint whose transform succeeded is appended to it. When
                ``True``, that list is left untouched and the ``constraint_columns`` of a
                constraint that could not be transformed are dropped from ``data``.

        Returns:
            The data after all the applicable constraint transformations.

        Raises:
            AggregateConstraintsError:
                If any constraint raised an exception other than
                ``MissingConstraintColumnError`` or ``FunctionError``.
        """
        errors = []
        if not is_condition:
            # Start a fresh record of the constraints that will need to be reversed.
            self._constraints_to_reverse = []

        for constraint in self._constraints:
            try:
                data = constraint.transform(data)
                if not is_condition:
                    self._constraints_to_reverse.append(constraint)

            except (MissingConstraintColumnError, FunctionError) as error:
                # Expected failures: log them and carry on with the remaining constraints.
                if isinstance(error, MissingConstraintColumnError):
                    LOGGER.info(
                        'Unable to transform %s with columns %s because they are not all available'
                        ' in the data. This happens due to multiple, overlapping constraints.',
                        constraint.__class__.__name__,
                        error.missing_columns
                    )
                    log_exc_stacktrace(LOGGER, error)
                else:
                    # Error came from custom constraint. We don't want to crash but we do
                    # want to log it.
                    LOGGER.info(
                        'Unable to transform %s with columns %s due to an error in transform: \n'
                        '%s\nUsing the reject sampling approach instead.',
                        constraint.__class__.__name__,
                        constraint.column_names,
                        str(error)
                    )
                    log_exc_stacktrace(LOGGER, error)
                if is_condition:
                    # Remove from the conditions the columns of the constraint that
                    # could not be transformed.
                    indices_to_drop = data.columns.isin(constraint.constraint_columns)
                    columns_to_drop = data.columns.where(indices_to_drop).dropna()
                    data = data.drop(columns_to_drop, axis=1)

            except Exception as error:
                # Unexpected failures are reported together once every constraint was tried.
                errors.append(error)

        if errors:
            raise AggregateConstraintsError(errors)

        return data

    def _update_transformers_by_sdtypes(self, sdtype, transformer):
        self._transformers_by_sdtype[sdtype] = transformer

    @staticmethod
    def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness, locales=None):
        """Create an instance of an ``AnonymizedFaker``.

        Read the extra keyword arguments from the ``column_metadata`` and use them to create
        an instance of an ``AnonymizedFaker`` transformer.

        Args:
            sdtype (str):
                Sematic data type or a ``Faker`` function name.
            column_metadata (dict):
                A dictionary representing the rest of the metadata for the given ``sdtype``.
            enforce_uniqueness (bool):
                If ``True`` overwrite ``enforce_uniqueness`` with ``True`` to ensure unique
                generation for primary keys.
            locales (str or list):
                Locale or list of locales to use for the AnonymizedFaker transfomer. Optional,
                defaults to using Faker's default locale.

        Returns:
            Instance of ``rdt.transformers.pii.AnonymizedFaker``.
        """
        kwargs = {'locales': locales}
        for key, value in column_metadata.items():
            if key not in ['pii', 'sdtype']:
                kwargs[key] = value

        if enforce_uniqueness:
            kwargs['enforce_uniqueness'] = True

        try:
            transformer = get_anonymized_transformer(sdtype, kwargs)
        except AttributeError as error:
            raise SynthesizerInputError(
                f"The sdtype '{sdtype}' is not compatible with any of the locales. To "
                "continue, try changing the locales or adding 'en_US' as a possible option."
            ) from error

        return transformer

    def create_regex_generator(self, column_name, sdtype, column_metadata, is_numeric):
        """Create a ``RegexGenerator`` for the ``id`` columns.

        Read the keyword arguments from the ``column_metadata`` and use them to create
        an instance of a ``RegexGenerator``. If ``regex_format`` is not present in the
        metadata a default ``[0-1a-z]{5}`` will be used for object like data and an increasing
        integer from ``0`` will be used for numerical data. Also if the column name is a primary
        key or alternate key this will enforce the values to be unique.

        Args:
            column_name (str):
                Name of the column.
            sdtype (str):
                Sematic data type or a ``Faker`` function name.
            column_metadata (dict):
                A dictionary representing the rest of the metadata for the given ``sdtype``.
            is_numeric (boolean):
                A boolean representing whether or not data type is numeric or not.

        Returns:
            transformer:
                Instance of ``rdt.transformers.text.RegexGenerator`` or
                ``rdt.transformers.pii.AnonymizedFaker`` with ``enforce_uniqueness`` set to
                ``True``.
        """
        default_regex_format = r'\d{30}' if is_numeric else '[0-1a-z]{5}'
        regex_format = column_metadata.get('regex_format', default_regex_format)
        transformer = rdt.transformers.RegexGenerator(
            regex_format=regex_format,
            enforce_uniqueness=(column_name in self._keys)
        )

        return transformer

    def _get_transformer_instance(self, sdtype, column_metadata):
        transformer = self._transformers_by_sdtype[sdtype]
        if isinstance(transformer, AnonymizedFaker):
            is_lexify = transformer.function_name == 'lexify'
            is_baseprovider = transformer.provider_name == 'BaseProvider'
            if is_lexify and is_baseprovider:  # Default settings
                return self.create_anonymized_transformer(
                    sdtype, column_metadata, False, self._locales
                )

        kwargs = {
            key: value for key, value in column_metadata.items()
            if key not in ['pii', 'sdtype']
        }
        if sdtype == 'datetime':
            kwargs['enforce_min_max_values'] = self._enforce_min_max_values

        if kwargs and transformer is not None:
            transformer_class = transformer.__class__
            return transformer_class(**kwargs)

        return deepcopy(transformer)

    def _update_constraint_transformers(self, data, columns_created_by_constraints, config):
        missing_columns = set(columns_created_by_constraints) - config['transformers'].keys()
        for column in missing_columns:
            dtype_kind = data[column].dtype.kind
            if dtype_kind in ('i', 'f'):
                config['sdtypes'][column] = 'numerical'
                config['transformers'][column] = rdt.transformers.FloatFormatter(
                    missing_value_replacement='mean',
                    missing_value_generation='random',
                    enforce_min_max_values=self._enforce_min_max_values
                )
            else:
                sdtype = self._DTYPE_TO_SDTYPE.get(dtype_kind, 'categorical')
                config['sdtypes'][column] = sdtype
                config['transformers'][column] = self._get_transformer_instance(sdtype, {})

        # Remove columns that have been dropped by the constraint
        for column in list(config['sdtypes'].keys()):
            if column not in data:
                LOGGER.info(
                    f"A constraint has dropped the column '{column}', removing the transformer "
                    "from the 'HyperTransformer'."
                )
                config['sdtypes'].pop(column)
                config['transformers'].pop(column)

        return config

    def _create_config(self, data, columns_created_by_constraints):
        """Build the ``sdtypes`` / ``transformers`` configuration for the ``HyperTransformer``.

        A transformer is selected for every column of ``data`` that was not created by a
        constraint, based on the sdtype declared in the metadata. The columns created by
        the constraints are then handled by ``_update_constraint_transformers``.

        Args:
            data (pandas.DataFrame):
                The data after the constraints have been applied.
            columns_created_by_constraints (set):
                Names of the columns added to the data by the constraints.

        Returns:
            dict:
                Dictionary with the keys ``transformers`` and ``sdtypes``.
        """
        sdtypes = {}
        transformers = {}

        columns_in_multi_col_transformer = self._get_grouped_columns()
        for column in set(data.columns) - columns_created_by_constraints:
            column_metadata = self.metadata.columns.get(column)
            sdtype = column_metadata.get('sdtype')

            # Columns handled by a multi column transformer only record their sdtype here;
            # their transformer is added after the loop, keyed by the column tuple.
            if column in columns_in_multi_col_transformer:
                sdtypes[column] = sdtype
                continue

            # Unless the metadata says otherwise, a column is treated as pii when there
            # is no default transformer for its sdtype.
            pii = column_metadata.get('pii', sdtype not in self._transformers_by_sdtype)
            sdtypes[column] = 'pii' if pii else sdtype

            if sdtype == 'id':
                is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype)
                if column_metadata.get('regex_format', False):
                    # An explicit regex format wins over the other id strategies.
                    transformers[column] = self.create_regex_generator(
                        column,
                        sdtype,
                        column_metadata,
                        is_numeric
                    )
                    sdtypes[column] = 'text'

                elif column in self._keys:
                    # Key columns without regex: ``IDGenerator``, prefixed for
                    # non-numeric data.
                    prefix = None
                    if not is_numeric:
                        prefix = 'sdv-id-'

                    transformers[column] = IDGenerator(prefix=prefix)
                    sdtypes[column] = 'text'

                else:
                    # Any other id column gets values made with ``bothify``.
                    transformers[column] = AnonymizedFaker(
                        provider_name=None,
                        function_name='bothify',
                        function_kwargs={'text': '#####'}
                    )
                    sdtypes[column] = 'pii'

            elif sdtype == 'unknown':
                transformers[column] = AnonymizedFaker(
                    function_name='bothify',
                )
                # ``function_kwargs`` is assigned on the instance after construction
                # rather than passed to the constructor.
                transformers[column].function_kwargs = {
                    'text': 'sdv-pii-?????',
                    'letters': '0123456789abcdefghijklmnopqrstuvwxyz'
                }

            elif pii:
                # Uniqueness is only enforced for primary and alternate keys.
                enforce_uniqueness = bool(column in self._keys)
                transformers[column] = self.create_anonymized_transformer(
                    sdtype,
                    column_metadata,
                    enforce_uniqueness,
                    self._locales
                )

            elif sdtype in self._transformers_by_sdtype:
                transformers[column] = self._get_transformer_instance(sdtype, column_metadata)

            else:
                # Non-pii column with an sdtype that has no default transformer:
                # handle it as categorical.
                sdtypes[column] = 'categorical'
                transformers[column] = self._get_transformer_instance(
                    'categorical',
                    column_metadata
                )

        # Multi column transformers are keyed by the tuple of columns they handle.
        for columns, transformer in self.grouped_columns_to_transformers.items():
            transformers[columns] = transformer

        config = {'transformers': transformers, 'sdtypes': sdtypes}
        config = self._update_constraint_transformers(data, columns_created_by_constraints, config)

        return config

    def update_transformers(self, column_name_to_transformer):
        """Update any of the transformers assigned to each of the column names.

        Args:
            column_name_to_transformer (dict):
                Dict mapping column names to transformers to be used for that column.
        """
        if self._hyper_transformer.field_transformers == {}:
            raise NotFittedError(
                'The DataProcessor must be prepared for fitting before the transformers can be '
                'updated.'
            )

        for column, transformer in column_name_to_transformer.items():
            if column in self._keys and not type(transformer) in (AnonymizedFaker, RegexGenerator):
                raise SynthesizerInputError(
                    f"Invalid transformer '{transformer.__class__.__name__}' for a primary "
                    f"or alternate key '{column}'. Please use 'AnonymizedFaker' or "
                    "'RegexGenerator' instead."
                )

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', module='rdt.hyper_transformer')
            self._hyper_transformer.update_transformers(column_name_to_transformer)

        self.grouped_columns_to_transformers = {
            col_tuple: transformer
            for col_tuple, transformer in self._hyper_transformer.field_transformers.items()
            if isinstance(col_tuple, tuple)
        }

    def _fit_hyper_transformer(self, data):
        """Create and return a new ``rdt.HyperTransformer`` instance.

        First get the ``dtypes`` and then use them to build a transformer dictionary
        to be used by the ``HyperTransformer``.

        Args:
            data (pandas.DataFrame):
                Data to transform.

        Returns:
            rdt.HyperTransformer
        """
        self._hyper_transformer.fit(data)

    def _fit_formatters(self, data):
        """Fit ``NumericalFormatter`` and ``DatetimeFormatter`` for each column in the data."""
        for column_name in data:
            column_metadata = self.metadata.columns.get(column_name)
            sdtype = column_metadata.get('sdtype')
            if sdtype == 'numerical' and column_name != self._primary_key:
                representation = column_metadata.get('computer_representation', 'Float')
                self.formatters[column_name] = NumericalFormatter(
                    enforce_rounding=self._enforce_rounding,
                    enforce_min_max_values=self._enforce_min_max_values,
                    computer_representation=representation
                )
                self.formatters[column_name].learn_format(data[column_name])

            elif sdtype == 'datetime' and column_name != self._primary_key:
                datetime_format = column_metadata.get('datetime_format')
                self.formatters[column_name] = DatetimeFormatter(datetime_format=datetime_format)
                self.formatters[column_name].learn_format(data[column_name])

    def prepare_for_fitting(self, data):
        """Prepare the ``DataProcessor`` for fitting.

        This method will learn the ``dtypes`` of the data, fit the numerical and datetime
        formatters, fit the constraints and create the configuration for the
        ``rdt.HyperTransformer``. Nothing is done when ``self._prepared_for_fitting`` is
        already ``True``.

        When the ``rdt.HyperTransformer`` already has a configuration, that configuration
        is kept and is only extended with the columns created by the constraints that it
        does not cover yet.

        Args:
            data (pandas.DataFrame):
                Table data to be learnt.
        """
        if not self._prepared_for_fitting:
            LOGGER.info(f'Fitting table {self.table_name} metadata')
            self._dtypes = data[list(data.columns)].dtypes

            # Formatters are always rebuilt from scratch.
            self.formatters = {}
            LOGGER.info(f'Fitting formatters for table {self.table_name}')
            self._fit_formatters(data)

            LOGGER.info(f'Fitting constraints for table {self.table_name}')
            # Constraints are only (re)loaded and fitted when the number of registered
            # constraint definitions differs from the number of loaded constraints.
            if len(self._constraints_list) != len(self._constraints):
                self._fit_constraints(data)

            constrained = self._transform_constraints(data)
            columns_created_by_constraints = set(constrained.columns) - set(data.columns)

            config = self._hyper_transformer.get_config()
            missing_columns = columns_created_by_constraints - config.get('sdtypes').keys()
            if not config.get('sdtypes'):
                # No configuration yet: build it from the metadata and the constrained data.
                LOGGER.info((
                    'Setting the configuration for the ``HyperTransformer`` '
                    f'for table {self.table_name}'
                ))
                config = self._create_config(constrained, columns_created_by_constraints)
                self._hyper_transformer.set_config(config)

            elif missing_columns:
                # Existing configuration that lacks some constraint columns: extend it
                # and load it into a brand new ``HyperTransformer``.
                config = self._update_constraint_transformers(
                    constrained,
                    missing_columns,
                    config
                )
                self._hyper_transformer = rdt.HyperTransformer()
                self._hyper_transformer.set_config(config)

            self._prepared_for_fitting = True

    def fit(self, data):
        """Fit this processor to the given data.

        The preparation step is always run again, then the constraints are applied and
        the ``HyperTransformer`` is fitted on the constrained data. The column order of
        ``data`` is remembered in ``self.columns``.

        Args:
            data (pandas.DataFrame):
                Table to be analyzed.

        Raises:
            ValueError:
                If ``data`` is empty, or if it is empty after applying the constraints.
        """
        if data.empty:
            raise ValueError('The fit dataframe is empty, synthesizer will not be fitted.')

        # Reset the flag so that ``prepare_for_fitting`` does not short-circuit.
        self._prepared_for_fitting = False
        self.prepare_for_fitting(data)

        constrained_data = self._transform_constraints(data)
        if constrained_data.empty:
            raise ValueError(
                'The constrained fit dataframe is empty, synthesizer will not be fitted.')

        LOGGER.info(f'Fitting HyperTransformer for table {self.table_name}')
        self._fit_hyper_transformer(constrained_data)
        self.fitted = True
        self.columns = [*data.columns]

    def reset_sampling(self):
        """Reset the randomization state of the underlying ``HyperTransformer``.

        This restarts the generators used for the anonymized columns and primary keys.
        """
        hyper_transformer = self._hyper_transformer
        hyper_transformer.reset_randomization()

    def generate_keys(self, num_rows, reset_keys=False):
        """Create values for the columns registered as ``keys``.

        Args:
            num_rows (int):
                Number of rows to be created. Must be an integer greater than 0.
            reset_keys (bool):
                Accepted for interface compatibility; this method does not use it.
                Defaults to ``False``.

        Returns:
            pandas.DataFrame:
                A dataframe with the newly generated primary keys of the size ``num_rows``.
        """
        return self._hyper_transformer.create_anonymized_columns(
            column_names=self._keys,
            num_rows=num_rows,
        )

    def transform(self, data, is_condition=False):
        """Transform the given data.

        Only the columns returned by ``get_sdtypes`` that are present in ``data`` are
        kept. The constraints are applied first and the result is then passed through
        the ``HyperTransformer``.

        Args:
            data (pandas.DataFrame):
                Table data.
            is_condition (bool):
                Forwarded to the constraint transformation. When ``True`` the primary
                keys are not requested from ``get_sdtypes`` and the data is not indexed
                by the primary key. Defaults to ``False``.

        Returns:
            pandas.DataFrame:
                Transformed data. If the ``HyperTransformer`` is not fitted or has no
                configuration, the constrained data is returned as is.

        Raises:
            NotFittedError:
                If this processor has not been fitted.
        """
        data = data.copy()
        if not self.fitted:
            raise NotFittedError()

        # Filter columns that can be transformed
        known_columns = self.get_sdtypes(primary_keys=not is_condition)
        columns = [name for name in known_columns if name in data.columns]

        LOGGER.debug(f'Transforming constraints for table {self.table_name}')
        data = self._transform_constraints(data[columns], is_condition)

        LOGGER.debug(f'Transforming table {self.table_name}')
        index_by_primary_key = self._keys and not is_condition
        if index_by_primary_key:
            data = data.set_index(self._primary_key, drop=False)

        try:
            return self._hyper_transformer.transform_subset(data)
        except (rdt.errors.NotFittedError, rdt.errors.ConfigNotSetError):
            return data

    def reverse_transform(self, data, reset_keys=False):
        """Reverse the transformed data to the original format.

        The steps are, in order: reverse the ``HyperTransformer`` on the columns it
        produced, reverse the grouped-column transformers that have no output columns,
        generate the anonymized and key columns, reverse the constraints, cast every
        column back to the dtype seen during fitting, apply the formatters and finally
        reorder the columns to match the order of the fitted data.

        Args:
            data (pandas.DataFrame):
                Data to be reverse transformed.
            reset_keys (bool):
                Forwarded to ``generate_keys``. Defaults to ``False``.

        Returns:
            pandas.DataFrame:
                The data in the original format, with the columns ordered as
                ``self.columns`` (the column order stored by ``fit``).

        Raises:
            NotFittedError:
                If this processor has not been fitted.
            ValueError:
                If a column cannot be cast back to its original dtype and its sdtype is
                one of the values of ``_DTYPE_TO_SDTYPE``.
        """
        if not self.fitted:
            raise NotFittedError()

        # Only the HyperTransformer output columns that are present can be reversed.
        reversible_columns = [
            column
            for column in self._hyper_transformer._output_columns
            if column in data.columns
        ]

        reversed_data = data
        try:
            if not data.empty:
                reversed_data = self._hyper_transformer.reverse_transform_subset(
                    data[reversible_columns]
                )
        except rdt.errors.NotFittedError:
            # Best effort: the data is kept untouched when the HyperTransformer is not fitted.
            LOGGER.info(f'HyperTransformer has not been fitted for table {self.table_name}')

        # Transformers without output columns are reversed directly on the full frame.
        for transformer in self.grouped_columns_to_transformers.values():
            if not transformer.output_columns:
                reversed_data = transformer.reverse_transform(reversed_data)

        num_rows = len(reversed_data)
        sampled_columns = list(reversed_data.columns)
        # Metadata columns that were neither sampled nor keys but do have a transformer
        # assigned are created through ``create_anonymized_columns``.
        missing_columns = [
            column
            for column in self.metadata.columns.keys() - set(sampled_columns + self._keys)
            if self._hyper_transformer.field_transformers.get(column)
        ]
        if missing_columns and num_rows:
            anonymized_data = self._hyper_transformer.create_anonymized_columns(
                num_rows=num_rows,
                column_names=missing_columns
            )
            sampled_columns.extend(missing_columns)
            reversed_data[anonymized_data.columns] = anonymized_data[anonymized_data.notna()]

        if self._keys and num_rows:
            generated_keys = self.generate_keys(num_rows, reset_keys)
            sampled_columns.extend(self._keys)
            reversed_data[generated_keys.columns] = generated_keys[generated_keys.notna()]

        # Constraints are undone in the opposite order of the stored list.
        for constraint in reversed(self._constraints_to_reverse):
            reversed_data = constraint.reverse_transform(reversed_data)

        # Add new columns generated by the constraint
        new_columns = list(set(reversed_data.columns) - set(sampled_columns))
        sampled_columns.extend(new_columns)

        # Sort the sampled columns in the order of the metadata.
        # Any extra columns not present in the metadata will be dropped.
        # In multitable there may be missing columns in the sample such as foreign keys
        # And alternate keys. Thats the reason of ensuring that the metadata column is within
        # The sampled columns.
        sampled_columns = [
            column for column in self.metadata.columns.keys()
            if column in sampled_columns
        ]
        for column_name in sampled_columns:
            column_data = reversed_data[column_name]

            # Float values destined for an integer dtype are rounded before the cast.
            dtype = self._dtypes[column_name]
            if is_integer_dtype(dtype) and is_float_dtype(column_data.dtype):
                column_data = column_data.round()

            reversed_data[column_name] = column_data[column_data.notna()]
            try:
                reversed_data[column_name] = reversed_data[column_name].astype(dtype)
            except ValueError as e:
                column_metadata = self.metadata.columns.get(column_name)
                sdtype = column_metadata.get('sdtype')
                if sdtype not in self._DTYPE_TO_SDTYPE.values():
                    # Cast failures are tolerated for these sdtypes; the formatter is
                    # removed so that it is not applied to the uncast column below.
                    LOGGER.info(
                        f"The real data in '{column_name}' was stored as '{dtype}' but the "
                        'synthetic data could not be cast back to this type. If this is a '
                        'problem, please check your input data and metadata settings.'
                    )
                    if column_name in self.formatters:
                        self.formatters.pop(column_name)

                else:
                    raise ValueError(e)

        # reformat columns using the formatters
        for column in sampled_columns:
            if column in self.formatters:
                data_to_format = reversed_data[column]
                reversed_data[column] = self.formatters[column].format_data(data_to_format)
        d = reversed_data[sampled_columns]
        # NOTE(review): ``self.columns`` is the column list stored by ``fit``; this selection
        # presumably requires every fitted column to be among the sampled ones -- confirm.
        new_column_order = self.columns
        df_syn = d[new_column_order]
        return df_syn

    def filter_valid(self, data):
        """Keep only the rows that every constraint considers valid.

        The constraints are applied one after another, each one receiving the rows
        that survived the previous one.

        Args:
            data (pandas.DataFrame):
                Table data.

        Returns:
            pandas.DataFrame:
                Table containing only the valid rows.
        """
        return functools.reduce(
            lambda rows, constraint: constraint.filter_valid(rows),
            self._constraints,
            data,
        )

    def to_dict(self):
        """Build a dict representation of this DataProcessor.

        The metadata and the model kwargs are deep-copied, so changing the returned
        dict does not modify this instance.

        Returns:
            dict:
                Dict representation of this DataProcessor.
        """
        serialized_constraints = [
            constraint.to_dict() for constraint in self._constraints_to_reverse
        ]

        representation = {}
        representation['metadata'] = deepcopy(self.metadata.to_dict())
        representation['constraints_list'] = self.get_constraints()
        representation['constraints_to_reverse'] = serialized_constraints
        representation['model_kwargs'] = deepcopy(self._model_kwargs)
        return representation

    @classmethod
    def from_dict(cls, metadata_dict, enforce_rounding=True, enforce_min_max_values=True):
        """Build a DataProcessor from its dict representation.

        Args:
            metadata_dict (dict):
                Dict metadata to load. The ``'metadata'`` entry is required;
                ``'model_kwargs'``, ``'constraints_to_reverse'`` and
                ``'constraints_list'`` are optional.
            enforce_rounding (bool):
                If passed, set the ``enforce_rounding`` on the new instance.
            enforce_min_max_values (bool):
                If passed, set the ``enforce_min_max_values`` on the new instance.

        Returns:
            The newly created instance of ``cls``.
        """
        metadata = SingleTableMetadata.load_from_dict(metadata_dict['metadata'])
        instance = cls(
            metadata=metadata,
            enforce_rounding=enforce_rounding,
            enforce_min_max_values=enforce_min_max_values,
            model_kwargs=metadata_dict.get('model_kwargs')
        )

        serialized_constraints = metadata_dict.get('constraints_to_reverse', [])
        instance._constraints_to_reverse = [
            Constraint.from_dict(serialized) for serialized in serialized_constraints
        ]
        instance._constraints_list = metadata_dict.get('constraints_list', [])

        return instance

    def to_json(self, filepath):
        """Write the dict representation of this DataProcessor to a JSON file.

        Args:
            filepath (str):
                Path of the JSON file where this metadata will be stored.
        """
        with open(filepath, 'w') as json_file:
            representation = self.to_dict()
            json.dump(representation, json_file, indent=4)

    @classmethod
    def from_json(cls, filepath):
        """Create a DataProcessor from a JSON file.

        The file content is parsed and handed over to ``from_dict``.

        Args:
            filepath (str):
                Path of the JSON file to load
        """
        with open(filepath, 'r') as json_file:
            metadata_dict = json.load(json_file)
            return cls.from_dict(metadata_dict)

In [3]:
LOGGER = logging.getLogger(__name__)

# Highest number of decimal places tried when learning how to round a column.
MAX_DECIMALS = sys.float_info.dig - 1

# Inclusive (min, max) range of every supported integer representation,
# signed representations first and unsigned ones after.
INTEGER_BOUNDS = {
    **{
        f'Int{bits}': (-2**(bits - 1), 2**(bits - 1) - 1)
        for bits in (8, 16, 32, 64)
    },
    **{
        f'UInt{bits}': (0, 2**bits - 1)
        for bits in (8, 16, 32, 64)
    },
}


class NumericalFormatter:
    """Formatter that restores the look of a numerical column.

    Args:
        enforce_rounding (bool):
            Whether or not to learn, in ``learn_format``, the number of decimal places
            of the data. If ``True`` and a rounding scheme was found, ``format_data``
            rounds to that place. Defaults to ``False``.
        enforce_min_max_values (bool):
            Whether or not ``format_data`` clips the data to the min and max values seen
            in ``learn_format``. Defaults to ``False``.
        computer_representation (dtype):
            Accepts ``'Int8'``, ``'Int16'``, ``'Int32'``, ``'Int64'``, ``'UInt8'``, ``'UInt16'``,
            ``'UInt32'``, ``'UInt64'``, ``'Float'``. Only used for clipping when
            ``enforce_min_max_values`` is ``False``. Defaults to ``'Float'``.
    """

    _dtype = None
    _min_value = None
    _max_value = None
    _rounding_digits = None

    def __init__(self, enforce_rounding=False, enforce_min_max_values=False,
                 computer_representation='Float'):
        self.enforce_rounding = enforce_rounding
        self.enforce_min_max_values = enforce_min_max_values
        self.computer_representation = computer_representation

    @staticmethod
    def _learn_rounding_digits(data):
        """Find the smallest number of decimals that represents the data exactly.

        Returns ``None`` when the data has no finite values or when no rounding scheme
        up to ``MAX_DECIMALS`` digits reproduces it.
        """
        column_name = data.name
        values = np.array(data)
        unusable = np.isinf(values) | pd.isna(values)
        finite_values = values[~unusable]

        # Doesn't contain numbers
        if len(finite_values) == 0:
            return None

        # Doesn't contain decimal digits
        if ((finite_values % 1) == 0).all():
            return 0

        # Try to round to fewer digits, but only if MAX_DECIMALS digits are enough
        representable = (finite_values == finite_values.round(MAX_DECIMALS)).all()
        candidates = range(MAX_DECIMALS + 1) if representable else ()
        for digits in candidates:
            if (finite_values == finite_values.round(digits)).all():
                return digits

        # Can't round, not equal after MAX_DECIMALS digits of precision
        LOGGER.info(
            f"No rounding scheme detected for column '{column_name}'."
            ' Synthetic data will not be rounded.'
        )
        return None

    def learn_format(self, column):
        """Learn the format of a column.

        Args:
            column (pandas.Series):
                Data to learn the format.
        """
        self._dtype = column.dtype

        if self.enforce_min_max_values:
            self._min_value, self._max_value = column.min(), column.max()

        if self.enforce_rounding:
            self._rounding_digits = self._learn_rounding_digits(column)

    def format_data(self, column):
        """Format a column according to the learned format.

        Args:
            column (pd.Series):
                Data to format.

        Returns:
            numpy.ndarray:
                containing the formatted data. Integer columns that contain missing
                values are returned without being cast to the learned dtype.
        """
        values = column.copy().to_numpy()

        if self.enforce_min_max_values:
            values = values.clip(self._min_value, self._max_value)
        elif self.computer_representation != 'Float':
            lower, upper = INTEGER_BOUNDS[self.computer_representation]
            values = values.clip(lower, upper)

        is_integer = np.dtype(self._dtype).kind == 'i'
        round_to_learned = self.enforce_rounding and self._rounding_digits is not None
        if round_to_learned:
            values = values.round(self._rounding_digits)
        elif is_integer:
            values = values.round(0)

        # Missing values cannot be stored in an integer array, so skip the cast.
        if pd.isna(values).any() and is_integer:
            return values

        return values.astype(self._dtype)
def get_datetime_format(value):
    """Guess the ``strftime`` format of ``value``.

    The input is converted to a ``pandas.Series`` if needed, its missing values are
    removed and the remaining values, as strings, are given to
    ``_guess_datetime_format_for_array`` from ``pandas.core.tools.datetimes``, whose
    result is returned unchanged.

    Args:
        value (pandas.Series, np.ndarray, list, or str):
            Input to attempt detecting the format.

    Return:
        String representing the datetime format in ``strftime`` format or ``None`` if not detected.
    """
    series = value if isinstance(value, pd.Series) else pd.Series(value)
    present = series[~series.isna()]
    candidates = present.astype(str).to_numpy()

    return _guess_datetime_format_for_array(candidates)
class DatetimeFormatter:
    """Formatter that restores the look of a datetime column.

    Args:
        datetime_format (str):
            The strftime to use for parsing time. For more information, see
            https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior.
            If ``None`` it will attempt to learn it by itself. Defaults to ``None``.
    """

    def __init__(self, datetime_format=None):
        self.datetime_format = datetime_format

    def learn_format(self, column):
        """Learn the format of a column.

        The dtype of the column is always stored; the datetime format is only
        detected when none was given.

        Args:
            column (pandas.Series):
                Data to learn the format.
        """
        self._dtype = column.dtype
        format_is_unknown = self.datetime_format is None
        if format_is_unknown:
            self.datetime_format = get_datetime_format(column)

    def format_data(self, column):
        """Format a column according to the learned format.

        Args:
            column (pd.Series):
                Data to format.

        Returns:
            numpy.ndarray:
                containing the formatted data.
        """
        datetime_format = self.datetime_format
        if not datetime_format:
            return column.astype(self._dtype)

        try:
            formatted = pd.to_datetime(column, format=datetime_format).dt.strftime(
                datetime_format)
        except ValueError:
            # Parsing with the explicit format failed: let pandas infer how to parse.
            formatted = pd.to_datetime(column).dt.strftime(datetime_format)

        return formatted.astype(self._dtype)
# File that ``validate_file_path`` creates when it is not given an output path.
TMP_FILE_NAME = '.sample.csv.temp'
# Sentinel accepted by ``validate_file_path`` to disable the output file altogether.
DISABLE_TMP_FILE = 'disable'
# NOTE(review): not referenced in the nearby code; presumably dictionary keys that are
# skipped when model parameters are processed -- confirm against its users.
IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type']
def validate_file_path(output_file_path):
    """Validate the user-passed output file arg, and create the file.

    Returns ``None`` when the output file is disabled with ``DISABLE_TMP_FILE``.
    Otherwise an empty file is created, at the absolute version of the given path or
    at ``TMP_FILE_NAME`` when no path is given (replacing a previous temporary file),
    and its path is returned.

    Raises:
        AssertionError:
            If the given output path already exists.
    """
    if output_file_path == DISABLE_TMP_FILE:
        # Temporary way of disabling the output file feature, used by HMA1.
        return None

    if output_file_path:
        output_path = os.path.abspath(output_file_path)
        if os.path.exists(output_path):
            raise AssertionError(f'{output_path} already exists.')

    else:
        output_path = TMP_FILE_NAME
        if os.path.exists(TMP_FILE_NAME):
            os.remove(TMP_FILE_NAME)

    # Create the file.
    open(output_path, 'w+').close()

    return output_path
def groupby_list(list_to_check):
    """Unwrap single-element lists: return the element itself, otherwise the list."""
    if len(list_to_check) == 1:
        return list_to_check[0]

    return list_to_check
def check_num_rows(num_rows, expected_num_rows, is_reject_sampling, max_tries_per_batch):
    """Compare the number of sampled rows with the number of rows that was requested.

    Nothing happens when enough rows were sampled. When no rows were sampled at all a
    ``ValueError`` is raised, and when only part of the rows were sampled a warning
    is issued.

    Args:
        num_rows (int):
            The number of sampled rows.
        expected_num_rows (int):
            The expected number of rows.
        is_reject_sampling (bool):
            If reject sampling is used or not.
        max_tries_per_batch (int):
            Number of times to retry sampling until the batch size is met.

    Side Effects:
        ValueError or warning.
    """
    if not num_rows < expected_num_rows:
        return

    if num_rows != 0:
        # This case should only happen with reject sampling.
        warnings.warn(
            f'Only able to sample {num_rows} rows for the given conditions. '
            'To sample more rows, try increasing `max_tries_per_batch` '
            f'(currently: {max_tries_per_batch}). Note that increasing this value '
            'will also increase the sampling time.'
        )
        return

    if is_reject_sampling:
        advice = (
            f'Try increasing `max_tries_per_batch` (currently: {max_tries_per_batch}). '
            'Note that increasing this value will also increase the sampling time.'
        )
    else:
        advice = (
            'This may be because the provided values are out-of-bounds in the '
            'current model. \nPlease try again with a different set of values.'
        )

    raise ValueError('Unable to sample any rows for the given conditions. ' + advice)
def detect_discrete_columns(metadata, data, transformers):
    """Find the columns of ``data`` that should be treated as discrete.

    The metadata is static while the data may have been preprocessed, so the metadata
    is only trusted for numerical, datetime, categorical and boolean columns. Every
    other column is inspected: it is considered discrete unless it holds non-integer
    numbers or integers with too many different values.

    Args:
        metadata (sdv.metadata.SingleTableMetadata):
            Metadata that belongs to the given ``data``.

        data (pandas.DataFrame):
            ``pandas.DataFrame`` that matches the ``metadata``.

        transformers (dict[str: rdt.transformers.BaseTransformer]):
            A dictionary mapping between column names and the transformers assigned
            for it.

    Returns:
        discrete_columns (list):
            A list of discrete columns to be used with some of ``sdv`` synthesizers.
    """
    discrete_columns = []
    for name in data.columns:
        if name in metadata.columns:
            sdtype = metadata.columns[name]['sdtype']

            # Numerical and datetime columns never get preprocessed into categorical ones
            if sdtype in ['numerical', 'datetime']:
                continue

            if sdtype in ['categorical', 'boolean']:
                # Categorical data that its transformer turns into floats is not discrete
                transformer = transformers.get(name)
                if transformer and transformer.get_output_sdtypes().get(name) == 'float':
                    continue

                discrete_columns.append(name)
                continue

        # Logic to detect columns produced by transformers outside of the metadata scope
        # or columns created by constraints.
        observed = data[name].dropna()

        # Ignore columns with only nans and empty datasets
        if observed.empty:
            continue

        # Non-integer floats and integers with too many unique values are not categorical
        try:
            numeric = observed.astype('float')
            has_decimals = not numeric.equals(numeric.round())
            max_categories = max(10, len(numeric) * .1)
            has_many_categories = numeric.nunique() > max_categories
            if has_decimals or has_many_categories:
                continue

        except (ValueError, TypeError):
            # Values that cannot be read as numbers fall through to categorical.
            pass

        # Everything else is presumed categorical
        discrete_columns.append(name)

    return discrete_columns

### Load the original training dataset

In [4]:
# Read the training table and its metadata description from the dataset folder.
train = pd.read_csv('../original data/intrusion/train.csv') # original training data
with open('../original data/intrusion/metadata.json', 'r') as f:
    metadata = json.load(f)  # metadata as a plain dict parsed from the JSON file
# The dict is wrapped into a SingleTableMetadata object.
metadata_obj = SingleTableMetadata.load_from_dict(metadata)
label = 'label' # label indicates the column to be classified

### Data generation

In [5]:
# Cholesky Decomposition method to generate data
def generate_synthetic_cholesky(mean, covariance, size):
    """Draw ``size`` samples with the given mean and covariance.

    Standard normal noise is multiplied by the transposed Cholesky factor of
    ``covariance`` and then shifted by ``mean``. The noise comes from the global
    ``np.random`` state.

    Returns:
        numpy.ndarray with one row per sample and one column per entry of ``mean``.
    """
    n_features = len(mean)
    cholesky_factor = np.linalg.cholesky(covariance)
    noise = np.random.normal(loc=0, scale=1, size=(size, n_features))
    synthetic_data = np.dot(noise, cholesky_factor.T)
    synthetic_data += mean
    return synthetic_data
# PCA method to generate data
def generate_synthetic_PCA(mean, variance, covariance, size, e_dim=2):
    """Draw ``size`` samples from a reduced ``e_dim``-dimensional normal distribution.

    The eigenvectors of ``covariance`` that belong to its ``e_dim`` largest
    eigenvalues form a projection matrix. The mean and the variance are projected
    with it, samples are drawn in the reduced space and then mapped back to the
    original space with the transposed projection matrix. The random numbers come
    from the global ``np.random`` state.

    Args:
        mean: Per-column means, of length ``n_columns``.
        variance: Per-column variances, of length ``n_columns``.
        covariance: Square covariance matrix of the columns.
        size (int): Number of samples to draw.
        e_dim (int): Dimension after the reduction; it can be changed according to
            the dataset. Defaults to ``2``.

    Returns:
        numpy.ndarray of shape ``(size, n_columns)``.

    Raises:
        ValueError: If ``e_dim`` is not between 1 and the number of columns.
    """
    column_n = len(mean)
    if not 1 <= e_dim <= column_n:
        raise ValueError(
            f'e_dim must be between 1 and the number of columns ({column_n}), got {e_dim}.'
        )

    eigenvalues, eigenvectors = np.linalg.eig(covariance)
    # Positions of the e_dim largest eigenvalues (not sorted among themselves).
    max_index = np.argpartition(eigenvalues, -e_dim)[-e_dim:]

    # get the transformation matrix: the selected eigenvectors, in reversed order
    matrix = np.zeros((column_n, e_dim))
    matrix[:, :] = eigenvectors[:, max_index[::-1]]

    # calculate new mean, variance and covariance for sampling
    nmean = np.dot(mean, matrix)
    nvar = np.dot(variance, matrix * matrix)
    # NOTE(review): the diagonal passed as ``cov`` holds sqrt(nvar), i.e. standard
    # deviations, while ``multivariate_normal`` documents ``cov`` as a covariance
    # matrix. Kept as is so that results do not change -- confirm the intent.
    conv = np.diag(np.sqrt(nvar))
    A = np.random.multivariate_normal(mean=nmean, cov=conv, size=size)

    # calculate the new samples
    return np.dot(A, matrix.T)
# activate synthetic data
def apply_activate(data, output_info):
    """Apply the activation of every output block to its columns of ``data``.

    ``output_info`` is read as a sequence of ``(width, activation)`` items that
    partition the columns of ``data`` from left to right. ``'tanh'`` blocks go
    through ``torch.tanh`` and ``'softmax'`` blocks through a Gumbel-softmax with
    ``tau=0.2``; any other activation name fails an assertion.

    Args:
        data (pandas.DataFrame): Values to activate, one row per sample.
        output_info (list): The ``(width, activation)`` items described above.

    Returns:
        torch.Tensor with the activated blocks concatenated along dimension 1.
    """
    def gumbel(logits):
        return F.gumbel_softmax(logits, tau=0.2)

    activated = []
    start = 0
    for item in output_info:
        kind = item[1]
        if kind == 'tanh':
            activation = torch.tanh
        elif kind == 'softmax':
            activation = gumbel
        else:
            assert 0
            continue

        end = start + item[0]
        block = torch.tensor(data.iloc[:, start:end].values, dtype=torch.float32)
        activated.append(activation(block))
        start = end

    return torch.cat(activated, dim=1)
# function to balance data
def balance_original_data(data, label, random_state=42):
    """Resample ``data`` so that every class of ``label`` has the same number of rows.

    Every class is sampled with replacement to ``int(len(data) / n_classes)`` rows,
    so the result has roughly as many rows as the input. The classes are processed
    in the order given by ``value_counts``.

    Args:
        data (pandas.DataFrame): Table to balance.
        label (str): Name of the column that holds the class of each row.
        random_state (int): Seed used for the sampling of every class.
            Defaults to ``42``.

    Returns:
        pandas.DataFrame with the sampled rows of all the classes, which keep their
        original index values. An empty copy of ``data`` is returned when the label
        column has no values to count.
    """
    class_counts = data[label].value_counts()
    n_class = len(class_counts)
    if n_class == 0:
        # Nothing to balance; also avoids dividing by zero below.
        return data.iloc[0:0].copy()

    step = int(len(data) / n_class)
    # Collect the per-class samples and concatenate once instead of growing a
    # DataFrame inside the loop.
    samples = [
        data[data[label] == class_value].sample(
            n=step, random_state=random_state, replace=True)
        for class_value in class_counts.index
    ]
    return pd.concat(samples, axis=0)

In [6]:
# transformer to convert non numerical features to numbers
# Tags stored under the "type" key of the column metadata built by
# ``Transformer.get_metadata``.
CATEGORICAL = "categorical"
CONTINUOUS = "continuous"
ORDINAL = "ordinal"
class Transformer:
    """Base class of the data transformers.

    It provides the metadata extraction; ``fit``, ``transform`` and
    ``inverse_transform`` have to be implemented by the subclasses.
    """

    @staticmethod
    def get_metadata(data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Describe every column of ``data``.

        Categorical and ordinal columns are described with their number of different
        values and the list of those values, most frequent first. Every other column
        is treated as continuous and described with its min and max.

        Returns:
            list of dicts, one per column and in column order.
        """
        frame = pd.DataFrame(data)

        meta = []
        for name in frame:
            column = frame[name]

            if name in categorical_columns:
                categories = column.value_counts().index.tolist()
                entry = {
                    "name": name,
                    "type": CATEGORICAL,
                    "size": len(categories),
                    "i2s": categories
                }
            elif name in ordinal_columns:
                counts = dict(column.value_counts())
                by_frequency = sorted(counts.items(), key=lambda pair: -pair[1])
                categories = [value for value, _ in by_frequency]
                entry = {
                    "name": name,
                    "type": ORDINAL,
                    "size": len(categories),
                    "i2s": categories
                }
            else:
                entry = {
                    "name": name,
                    "type": CONTINUOUS,
                    "min": column.min(),
                    "max": column.max(),
                }

            meta.append(entry)

        return meta

    def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Learn how to transform ``data``. To be implemented by subclasses."""
        raise NotImplementedError

    def transform(self, data):
        """Transform ``data``. To be implemented by subclasses."""
        raise NotImplementedError

    def inverse_transform(self, data):
        """Revert the transformation of ``data``. To be implemented by subclasses."""
        raise NotImplementedError

class BGMTransformer(Transformer):
    """Model continuous columns with a BayesianGMM and normalized to a scalar [0, 1] and a vector.

    Discrete and ordinal columns are converted to a one-hot vector.

    The label encoders of the discrete columns are fitted once, in ``fit``, so that
    ``transform`` and ``inverse_transform`` always share the encoding learned from the
    fitted data, whatever data is transformed afterwards.
    """

    def __init__(self, n_clusters=10, eps=0.005):
        """n_cluster is the upper bound of modes.

        ``eps`` is the minimum weight a mixture component needs to be kept.
        """
        self.meta = None
        self.n_clusters = n_clusters
        self.eps = eps

    def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Learn the per-column models from ``data``.

        A ``BayesianGaussianMixture`` is fitted for every continuous column and a
        ``LabelEncoder`` for every other column. The layout of the transformed data
        is stored in ``output_info`` as ``(width, activation)`` tuples and its total
        width in ``output_dim``.

        Args:
            data (pandas.DataFrame): Data to learn from.
            categorical_columns: Names of the categorical columns.
            ordinal_columns: Names of the ordinal columns.
        """
        self.meta = self.get_metadata(data, categorical_columns, ordinal_columns)
        model = []
        c_model = []

        self.output_info = []
        self.output_dim = 0
        self.components = []
        for id_, info in enumerate(self.meta):
            if info['type'] == CONTINUOUS:
                gm = BayesianGaussianMixture(
                    n_components=self.n_clusters,
                    weight_concentration_prior_type='dirichlet_process',
                    weight_concentration_prior=0.001,
                    n_init=1)
                gm.fit(data.iloc[:, id_].values.reshape([-1, 1]))
                model.append(gm)
                c_model.append(None)
                # Only the components with a relevant weight are kept as modes.
                comp = gm.weights_ > self.eps
                self.components.append(comp)

                self.output_info += [(1, 'tanh'), (np.sum(comp), 'softmax')]
                self.output_dim += 1 + np.sum(comp)
            else:
                model.append(None)
                # Fit the encoder here, on the same data the metadata comes from.
                c_model.append(LabelEncoder().fit(data.iloc[:, id_]))
                self.components.append(None)
                self.output_info += [(info['size'], 'softmax')]
                self.output_dim += info['size']

        self.model = model
        self.c_model = c_model

    def transform(self, data):
        """Convert ``data`` into its numerical representation.

        Continuous columns become a scalar, normalized with a randomly chosen mode and
        clipped to [-0.99, 0.99], followed by the one-hot vector of that mode. The
        other columns become one-hot vectors built with the encoders fitted in ``fit``.

        Args:
            data (pandas.DataFrame): Data with the same columns that were fitted.

        Returns:
            numpy.ndarray with ``output_dim`` columns.
        """
        values = []
        for id_, info in enumerate(self.meta):
            current = data.iloc[:, id_]
            if info['type'] == CONTINUOUS:
                current = current.values.reshape([-1, 1])

                means = self.model[id_].means_.reshape((1, self.n_clusters))
                stds = np.sqrt(self.model[id_].covariances_).reshape((1, self.n_clusters))
                features = (current - means) / (4 * stds)

                probs = self.model[id_].predict_proba(current.reshape([-1, 1]))

                n_opts = sum(self.components[id_])
                features = features[:, self.components[id_]]
                probs = probs[:, self.components[id_]]

                # Sample one mode per row according to the (smoothed) probabilities.
                opt_sel = np.zeros(len(data), dtype='int')
                for i in range(len(data)):
                    pp = probs[i] + 1e-6
                    pp = pp / sum(pp)
                    opt_sel[i] = np.random.choice(np.arange(n_opts), p=pp)

                idx = np.arange((len(features)))
                features = features[idx, opt_sel].reshape([-1, 1])
                features = np.clip(features, -.99, .99)

                probs_onehot = np.zeros_like(probs)
                probs_onehot[np.arange(len(probs)), opt_sel] = 1
                values += [features, probs_onehot]
            else:
                current_encoded = self.c_model[id_].transform(current)
                col_t = np.zeros([len(data), info['size']])
                col_t[np.arange(len(data)), current_encoded] = 1
                values.append(col_t)

        return np.concatenate(values, axis=1)

    def inverse_transform(self, data, sigmas):
        """Convert the numerical representation back to the original values.

        Args:
            data (numpy.ndarray): Array laid out as described by ``output_info``.
            sigmas: If not ``None``, the scalar of every continuous column is
                replaced by a normal draw centered on it, using the entry of
                ``sigmas`` at the position of that scalar as standard deviation.

        Returns:
            pandas.DataFrame with one column per original column, with default
            integer column names.
        """
        data_t = np.zeros([len(data), len(self.meta)])
        data_t = data_t.astype(object)
        st = 0
        for id_, info in enumerate(self.meta):
            if info['type'] == CONTINUOUS:
                n_active = np.sum(self.components[id_])
                u = data[:, st]
                v = data[:, st + 1:st + 1 + n_active]

                if sigmas is not None:
                    sig = sigmas[st]
                    u = np.random.normal(u, sig)

                u = np.clip(u, -1, 1)
                # Dropped components get a very low score so argmax never picks them.
                v_t = np.ones((data.shape[0], self.n_clusters)) * -100
                v_t[:, self.components[id_]] = v
                v = v_t
                st += 1 + n_active
                means = self.model[id_].means_.reshape([-1])
                stds = np.sqrt(self.model[id_].covariances_).reshape([-1])
                p_argmax = np.argmax(v, axis=1)
                std_t = stds[p_argmax]
                mean_t = means[p_argmax]
                # Undo the normalization applied in ``transform``.
                tmp = u * 4 * std_t + mean_t
                data_t[:, id_] = tmp

            else:
                current = data[:, st:st + info['size']]
                st += info['size']
                codes = np.argmax(current, axis=1)
                data_t[:, id_] = self.c_model[id_].inverse_transform(codes)

        df_t = pd.DataFrame(data_t)
        return df_t

In [7]:
# CholeskySynthesizer method to fit and sample synthetic data
class CholeskySynthesizer():
    """Synthesizer driven by the mean and covariance of the transformed training data.

    ``fit`` encodes the training table with ``BGMTransformer`` and stores the
    per-column mean and a diagonally regularized covariance matrix of the
    encoded data. ``sample`` hands those statistics to
    ``generate_synthetic_cholesky``, runs ``apply_activate`` on the result and
    decodes it back to the original column names.
    """

    def __init__(self):
        self.e_dim = 2  # not read anywhere inside this class
        self.mean = None  # per-column means of the transformed data (set by fit)
        self.cov = None  # regularized covariance of the transformed data (set by fit)
        self.column_names = None  # original column labels (set by fit)
        self.transformer = None  # fitted BGMTransformer (set by fit)

    def fit(self, train_data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Fit the transformer and store mean/covariance of the transformed data.

        Args:
            train_data: pandas.DataFrame of training rows.
            categorical_columns: column names forwarded to ``BGMTransformer.fit``.
            ordinal_columns: column names forwarded to ``BGMTransformer.fit``.
        """
        self.column_names = list(train_data.columns)
        self.transformer = BGMTransformer()
        self.transformer.fit(train_data, categorical_columns, ordinal_columns)
        train_data = self.transformer.transform(train_data)

        column_n = train_data.shape[1]
        mean_real = [np.mean(train_data[:, i]) for i in range(column_n)]
        cov_real = np.cov(train_data.T)
        # Small ridge on the diagonal; presumably keeps the matrix positive
        # definite for the Cholesky step in generate_synthetic_cholesky.
        epsilon = 1e-6
        self.mean = mean_real
        self.cov = cov_real + epsilon * np.eye(cov_real.shape[0])

    def sample(self, n):
        """Return ``n`` synthetic rows as a DataFrame with the original column names.

        Args:
            n: number of rows to generate. A value <= 0 yields an empty frame.

        Returns:
            pandas.DataFrame produced by the transformer's ``inverse_transform``.
        """
        if n <= 0:
            # Nothing to generate; previously this crashed in np.concatenate([]).
            return pd.DataFrame(columns=self.column_names)

        output_info = self.transformer.output_info
        fake = generate_synthetic_cholesky(self.mean, self.cov, n)
        # `fake` is requested with n rows, so a single activation pass is enough.
        # The earlier loop repeated the activation n times (n*n rows in memory)
        # and then discarded everything except the first n rows.
        activated = apply_activate(pd.DataFrame(fake), output_info)
        new_data = activated.detach().cpu().numpy()[:n]
        inv_data = self.transformer.inverse_transform(new_data, None)
        inv_data.columns = self.column_names
        return inv_data
# balancedCholesky function to generate balanced data using Cholesky decomposition
def balancedCholesky(data, metadata, label, sample_n):
    """Generate ``sample_n`` synthetic rows with minority classes boosted, via CholeskySynthesizer.

    Classes of ``data[label]`` whose row count is below 30% of ``len(data)``
    are treated as small: each gets its own synthesizer and contributes
    ``int(sample_n / n_class)`` synthetic rows. The rest of the requested rows
    are sampled from one synthesizer fitted on all remaining rows.

    Args:
        data: pandas.DataFrame with the real rows.
        metadata: dict accepted by ``SingleTableMetadata.load_from_dict``; its
            ``"columns"`` entries are read for ``sdtype`` and
            ``computer_representation``.
        label: name of the class column in ``data``.
        sample_n: total number of synthetic rows requested.

    Returns:
        pandas.DataFrame of synthetic rows. The index is not reset after
        concatenation. Columns declared as ``"Int64"`` are rounded and cast to int.
    """
    metadata_obj = SingleTableMetadata.load_from_dict(metadata)
    column_info = metadata_obj.columns
    categorical_columns = [
        column for column, info in column_info.items()
        if info.get('sdtype') != 'numerical'
    ]
    # The literal used to be misspelled ('cateorical'), which made this filter
    # identical to the categorical one above.
    ordinal_columns = [
        column for column, info in column_info.items()
        if info.get('sdtype') not in ('categorical', 'numerical')
    ]

    class_counts = data[label].value_counts()
    total_rows = len(data)
    threshold = 0.3
    small_classes = [
        class_value for class_value, count in class_counts.to_dict().items()
        if count < threshold * total_rows
    ]
    n_class = len(class_counts)
    step = int(sample_n / n_class)

    # Collect the pieces and concatenate once instead of growing a frame in the loop.
    parts = []
    n_generated = 0
    for class_value in small_classes:
        class_df = data[data[label] == class_value]
        model = CholeskySynthesizer()
        model.fit(class_df, categorical_columns, ordinal_columns)
        syn = model.sample(step)
        parts.append(syn)
        n_generated += syn.shape[0]

    majority_df = data[~data[label].isin(small_classes)]
    remaining = sample_n - n_generated
    if remaining > 0:
        if majority_df.empty:
            # Every class fell under the threshold, so there is nothing left to
            # fit the remainder model on; fitting on an empty frame would fail.
            warnings.warn(
                "balancedCholesky: all classes are below the small-class threshold; "
                f"returning {n_generated} rows instead of {sample_n}."
            )
        else:
            model = CholeskySynthesizer()
            model.fit(majority_df, categorical_columns, ordinal_columns)
            parts.append(model.sample(remaining))

    if not parts:
        return pd.DataFrame()
    synthetic_data = pd.concat(parts, axis=0)

    int_col = [
        column_name for column_name, properties in metadata["columns"].items()
        if properties.get("computer_representation") == "Int64"
    ]
    if int_col:
        # The sampled frames may be object dtype, on which DataFrame.round() is a
        # no-op; without the float cast, astype(int) would truncate, not round.
        synthetic_data[int_col] = synthetic_data[int_col].astype(float).round().astype(int)
    return synthetic_data
 

In [8]:
# PCASynthesizer class to fit on training data and sample synthetic data
class PCASynthesizer():
    """Synthesizer driven by mean, variance and covariance of the transformed training data.

    ``fit`` encodes the training table with ``BGMTransformer`` and stores the
    per-column mean, per-column variance and the covariance matrix of the
    encoded data. ``sample`` hands those statistics to
    ``generate_synthetic_PCA``, runs ``apply_activate`` on the result and
    decodes it back to the original column names.
    """

    def __init__(self):
        self.e_dim = 2  # not read anywhere inside this class
        self.mean = None  # per-column means of the transformed data (set by fit)
        self.var = None  # per-column variances of the transformed data (set by fit)
        self.cov = None  # covariance of the transformed data (set by fit)
        self.column_names = None  # original column labels (set by fit)
        self.transformer = None  # fitted BGMTransformer (set by fit)

    def fit(self, train_data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Fit the transformer and store mean/variance/covariance of the transformed data.

        Args:
            train_data: pandas.DataFrame of training rows.
            categorical_columns: column names forwarded to ``BGMTransformer.fit``.
            ordinal_columns: column names forwarded to ``BGMTransformer.fit``.
        """
        self.column_names = list(train_data.columns)
        self.transformer = BGMTransformer()
        self.transformer.fit(train_data, categorical_columns, ordinal_columns)
        train_data = self.transformer.transform(train_data)

        column_n = train_data.shape[1]
        self.mean = [np.mean(train_data[:, i]) for i in range(column_n)]
        self.var = [np.var(train_data[:, i]) for i in range(column_n)]
        self.cov = np.cov(train_data.T)

    def sample(self, n):
        """Return ``n`` synthetic rows as a DataFrame with the original column names.

        Args:
            n: number of rows to generate. A value <= 0 yields an empty frame.

        Returns:
            pandas.DataFrame produced by the transformer's ``inverse_transform``.
        """
        if n <= 0:
            # Nothing to generate; previously this crashed in np.concatenate([]).
            return pd.DataFrame(columns=self.column_names)

        output_info = self.transformer.output_info
        fake = generate_synthetic_PCA(self.mean, self.var, self.cov, n)
        # `fake` is requested with n rows, so a single activation pass is enough.
        # The earlier loop repeated the activation n times (n*n rows in memory)
        # and then discarded everything except the first n rows.
        activated = apply_activate(pd.DataFrame(fake), output_info)
        new_data = activated.detach().cpu().numpy()[:n]
        inv_data = self.transformer.inverse_transform(new_data, None)
        inv_data.columns = self.column_names
        return inv_data
# BPCA function to generate balanced data using PCA
def BPCA(data, metadata, label, sample_n):
    """Generate ``sample_n`` synthetic rows with minority classes boosted, via PCASynthesizer.

    Classes of ``data[label]`` whose row count is below 30% of ``len(data)``
    are treated as small: each gets its own synthesizer and contributes
    ``int(sample_n / n_class)`` synthetic rows. The rest of the requested rows
    are sampled from one synthesizer fitted on all remaining rows.

    Args:
        data: pandas.DataFrame with the real rows.
        metadata: dict accepted by ``SingleTableMetadata.load_from_dict``; its
            ``"columns"`` entries are read for ``sdtype`` and
            ``computer_representation``.
        label: name of the class column in ``data``.
        sample_n: total number of synthetic rows requested.

    Returns:
        pandas.DataFrame of synthetic rows. The index is not reset after
        concatenation. Columns declared as ``"Int64"`` are rounded and cast to int.
    """
    metadata_obj = SingleTableMetadata.load_from_dict(metadata)
    column_info = metadata_obj.columns
    categorical_columns = [
        column for column, info in column_info.items()
        if info.get('sdtype') != 'numerical'
    ]
    # The literal used to be misspelled ('cateorical'), which made this filter
    # identical to the categorical one above.
    ordinal_columns = [
        column for column, info in column_info.items()
        if info.get('sdtype') not in ('categorical', 'numerical')
    ]

    class_counts = data[label].value_counts()
    total_rows = len(data)
    threshold = 0.3
    small_classes = [
        class_value for class_value, count in class_counts.to_dict().items()
        if count < threshold * total_rows
    ]
    n_class = len(class_counts)
    step = int(sample_n / n_class)

    # Collect the pieces and concatenate once instead of growing a frame in the loop.
    parts = []
    n_generated = 0
    for class_value in small_classes:
        class_df = data[data[label] == class_value]
        model = PCASynthesizer()
        model.fit(class_df, categorical_columns, ordinal_columns)
        syn = model.sample(step)
        parts.append(syn)
        n_generated += syn.shape[0]

    majority_df = data[~data[label].isin(small_classes)]
    remaining = sample_n - n_generated
    if remaining > 0:
        if majority_df.empty:
            # Every class fell under the threshold, so there is nothing left to
            # fit the remainder model on; fitting on an empty frame would fail.
            warnings.warn(
                "BPCA: all classes are below the small-class threshold; "
                f"returning {n_generated} rows instead of {sample_n}."
            )
        else:
            model = PCASynthesizer()
            model.fit(majority_df, categorical_columns, ordinal_columns)
            parts.append(model.sample(remaining))

    if not parts:
        return pd.DataFrame()
    synthetic_data = pd.concat(parts, axis=0)

    int_col = [
        column_name for column_name, properties in metadata["columns"].items()
        if properties.get("computer_representation") == "Int64"
    ]
    if int_col:
        # The sampled frames may be object dtype, on which DataFrame.round() is a
        # no-op; without the float cast, astype(int) would truncate, not round.
        synthetic_data[int_col] = synthetic_data[int_col].astype(float).round().astype(int)
    return synthetic_data

In [9]:
# Build two balanced synthetic tables, each requested with as many rows as
# `train`: one from the Cholesky-based generator, one from the PCA-based one.
# `train`, `metadata` and `label` are expected to come from earlier cells.
result_3 = balancedCholesky(train, metadata, label, train.shape[0])
result_4 = BPCA(train, metadata, label, train.shape[0])

  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **kwargs)
  return fit_method(estimator, *args, **