# SDV SYNTHETIC DATA GENERATION

## SETUP

In [1]:
import pandas as pd
# Do not truncate columns when a DataFrame is displayed (None = no limit)
pd.set_option('display.max_columns', None)

## LOAD REAL DATA

In [None]:
from ucimlrepo import fetch_ucirepo 

'''# metadata 
print(diabetes_130_us_hospitals_for_years_1999_2008.metadata) 
  
# variable information 
print(diabetes_130_us_hospitals_for_years_1999_2008.variables) '''
  
# download the dataset (UCI ML repository entry 296)
diabetes_130_us_hospitals_for_years_1999_2008 = fetch_ucirepo(id=296)

# features and targets, as exposed by the fetched dataset object
dataset_tables = diabetes_130_us_hospitals_for_years_1999_2008.data
X = dataset_tables.features
y = dataset_tables.targets

# assemble the complete real dataset: features plus the target column
diabetes = pd.DataFrame(X)
diabetes["readmitted"] = y

# preview the first rows
diabetes.head()

## EXPLORE & PREPROCESS REAL DATA

In [None]:
# dimensions of the real dataset, as reported by DataFrame.shape
real_shape = diabetes.shape
print(f"Dimension: {real_shape}")

In [None]:
# detect sensitive columns by intuition, based on their names
# ("\n" here: the original "\c" was an invalid escape sequence that printed a
# literal backslash and triggers a warning on recent Python versions)
print(f"\ncolumns: {diabetes.columns}\n")

# columns judged to carry identity-sensitive information
sensitive_column_names = [
    'race',
    'gender',
    'weight',
    'admission_type_id',
    'discharge_disposition_id',
    'admission_source_id',
    'payer_code',
    'medical_specialty',
]
print(f"\nSensitive columns: {sensitive_column_names}\n")

###  Check columns values & distribution

In [None]:
# SINGLE COLUMN: check columns values & distribution
def visualize_columns_distributions(df):
    """Print the value counts of every column of ``df``.

    For each column a header with the column name is printed, followed by a
    one-column table (named ``Items``) with the frequency of each value.
    Missing values are counted as well (``dropna=False``).
    """
    for column_name in df.columns:
        print(f"\n\nColumn: {column_name}")

        # frequency of each distinct value, NaN included
        frequencies = df[column_name].value_counts(dropna=False)

        # wrap in a DataFrame so the counts are printed under a named column
        print(pd.DataFrame({'Items': frequencies}))
    
# run the per-column distribution report on the real data
visualize_columns_distributions(diabetes)

###  Generalize NaN race values to 'Other'

In [None]:
# Replace missing (NaN) race values with 'Other'
import numpy as np

# rows whose race is NaN or already 'Other'
race_missing_or_other = diabetes["race"].isin([np.nan, 'Other'])
diabetes.loc[race_missing_or_other, "race"] = "Other"

# validate change
diabetes["race"].value_counts()

###  Remove 'Unknown/Invalid' gender values

In [None]:
# 'Unknown/Invalid': only 3 records; their gender cannot be determined,
# so the best option is to remove them
unknown_gender_rows = diabetes[diabetes["gender"] == 'Unknown/Invalid']

# remove the 'Unknown/Invalid' gender rows by index label and report the
# shape before and after (only 3 rows fewer are expected)
print(f"Shape before drop: {diabetes.shape}")
diabetes = diabetes.drop(unknown_gender_rows.index)
print(f"Shape after drop: {diabetes.shape}")

###  Check for null values per column

In [None]:
# percentage of null values in each column
null_counts = diabetes.isna().sum()
null_counts * 100 / len(diabetes)

###  Drop "weight" column

In [None]:
# remove the weight column from the dataframe, reporting the column count
# before and after
print(f"Columns before remove {len(diabetes.columns)}")
print(f"Columns:{diabetes.columns}")
diabetes = diabetes.drop(columns='weight')
print(f"Columns after remove {len(diabetes.columns)}")

###  Check for variability

In [None]:
# Find columns without variability
def columns_without_variability(data_frame):
    """Return the names of the columns that have no variability.

    A column has no variability when it holds fewer than two distinct
    values. ``Series.unique()`` is used, so a missing value counts as a
    distinct value of its own.
    """
    return [
        column_name
        for column_name in data_frame.columns
        if len(data_frame[column_name].unique()) < 2
    ]

# collect the names of the columns without variability in a list
cols_without_variability = columns_without_variability(diabetes)

# remove the columns without variability, reporting the column count
# before and after
print(f"Columns before remove {len(diabetes.columns)}")
print(f"Columns:{diabetes.columns}")
diabetes = diabetes.drop(cols_without_variability, axis=1)
print(f"Columns after remove {len(diabetes.columns)}")

###  Check dtype uniformity 

In [None]:
# data information
# DataFrame.info() writes its report to stdout and returns None, so it is
# called on its own; interpolating it into an f-string printed
# "Data information: None" after the report.
print("\nData information:")
diabetes.info()
print()

# Check that the values of each column correspond to its dtype
for column_name in diabetes.columns:
    print(f"\nColumn: {column_name} values: {diabetes[column_name].unique()}")

### Categorical data codification

### Numerical data standardization

In [None]:
'''import itertools

# function that determines singularization risk columns pairs
def determine_singularization_risk(df, column_pairs):
      # Dictionary to keep track of problematic rows
      special_pairs = {}
      count =0
      for val in column_pairs:
           data_crosstab = pd.crosstab(df[val[0]], 
                            df[val[1]],  
                            margins = False) 
           # Check if any row in the crosstab has only 1 value
           #print(f"\nCross_data: {data_crosstab}")
           conflicted_value_column = data_crosstab.apply(lambda row: row[row == 1].index.tolist(), axis=1)

           # Filter only rows where the conflicted column is found (non-empty lists)
           conflicted_value_column = conflicted_value_column[conflicted_value_column.apply(len) > 0]

           if any(conflicted_value_column.apply(len) == 1):
                 #print(f"conflicted_value_column: {conflicted_value_column}")
                 # complete dictionary
                 count+=1
                 special_attention = {}
                 for index, columns in conflicted_value_column.items():
                       if len(columns)==1:
                             special_attention[val[0]] = index
                             special_attention[val[1]] = columns     
                             #print(f"special_attention: {special_attention}")

                             # add to principal dictionary    
                             special_pairs[val] = special_attention
                             #print(f"special_pairs: {special_pairs}")

      return (special_pairs, count)
      
                
# COLUMN PAIRS: create pairs, only with sensitive data
column_pairs = list(itertools.combinations(sensitive_column_names, 2))                
# function call
special_pairs, count = determine_singularization_risk(diabetes1, column_pairs)               
 
# visualize results
print(f"From: {len(column_pairs)} pairs conflictive are: {count}")
for key in special_pairs.keys():
      print(f"Check: {special_pairs.get(key)}")'''

## EXPLORE REAL DATA VISUALLY

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# one count plot per column to visualize its value distribution
for column_name in diabetes.columns:
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.countplot(x=column_name, data=diabetes, ax=ax)
    ax.set_title(column_name)
    plt.tight_layout()
    plt.show()

## CREATE SYNTHESIZER & SYNTHETIC DATA WITH SDV

In [None]:
# Describe the `diabetes` dataframe with a `SingleTableMetadata` object
from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata()

# Automatically detect the metadata from the actual DataFrame
metadata.detect_from_dataframe(diabetes)

# Change the type of the "_id" columns: treat them as categorical instead
# of numerical
id_columns = [name for name in metadata.columns if '_id' in name]
for id_column in id_columns:
    metadata.update_column(id_column, sdtype='categorical')

# Check that the metadata has been generated correctly
print(metadata)

In [None]:
from sdv.single_table import GaussianCopulaSynthesizer

# Gaussian copula synthesizer built from the metadata, with min/max
# enforcement and rounding enforcement both switched on
synthesizer = GaussianCopulaSynthesizer(
    metadata,
    enforce_rounding=True,
    enforce_min_max_values=True,
)

In [None]:
# fit the synthesizer so it learns from the real data
synthesizer.fit(diabetes)

In [None]:
# generate new data with the learned model, with as many rows as the real data
synthetic_data = synthesizer.sample(num_rows=len(diabetes))

## EXPLORE SYNTHETIC DATA AND VALIDATE

In [None]:
# dimensions
print(f"Real dimension: {diabetes.shape}")
print(f"Synth dimension: {synthetic_data.shape}")

# data information
# DataFrame.info() prints its report and returns None, so each call is made
# on its own under a header; inside an f-string it printed "...: None"
# after the report instead of labelling it.
print("\n\nReal data information:")
diabetes.info()
print("Synth data information:")
synthetic_data.info()
print()

In [None]:
# Compare the value frequencies of real and synthetic data, column by column
for col in diabetes.columns:
    print(f"\n\nColumn: {col}")

    # Building the DataFrame from both count series puts every category
    # seen in either dataset on one row; a category absent from one side
    # gets a count of 0.
    combined_counts = pd.DataFrame({
        'Real': diabetes[col].value_counts(),
        'Synthetic': synthetic_data[col].value_counts(),
    }).fillna(0)

    # Print the combined counts for easy comparison
    print(combined_counts)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# CORRELATION MATRIX
# Correlation is only defined for numeric columns. On pandas >= 2.0,
# DataFrame.corr() raises when the frame still contains string columns, so
# both frames are restricted to numeric dtypes first.
corr_r = diabetes.select_dtypes(include="number").corr()
corr_s = synthetic_data.select_dtypes(include="number").corr()

# real and synthetic heatmaps side by side, drawn with the same settings
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
heatmap_specs = (
    (ax[0], corr_r, "Blues", "REAL"),
    (ax[1], corr_s, "Greens", "SYNTH"),
)
for axis, corr, cmap, title in heatmap_specs:
    sns.heatmap(
        corr,
        xticklabels=corr.columns.values,
        yticklabels=corr.columns.values,
        cmap=cmap,
        annot=True,  # display the correlation values in the cells
        fmt=".2f",
        ax=axis,
    )
    axis.set_title(title)
plt.tight_layout()
plt.show()