## Import packages

In [None]:
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import HistGradientBoostingClassifier
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from fasd.utils import set_seed
from fasd import TabularFASD

# One seed shared by the generator (random_state) and both train/test splits below.
# NOTE(review): set_seed presumably seeds the global RNGs used by fasd — confirm in fasd.utils.
seed = 123
set_seed(seed)

## Generate synthetic data

In [None]:
# Real data: feature columns plus the "target" column in a single DataFrame.
X = load_breast_cancer(as_frame=True).frame

# Fit the generator on the real *training* partition only. The evaluation
# section holds out 30% of the real rows as the test set; fitting on the full
# frame would let those rows leak into the synthetic data and inflate the
# "Train Synthetic Test Real" score. Using the same stratify labels,
# train_size and random_state as the evaluation split reproduces the same
# partition here.
X_fit, _X_holdout = train_test_split(
    X, stratify=X["target"], train_size=0.7, random_state=seed
)

generator = TabularFASD(target_column="target", random_state=seed)
generator.fit(X_fit)

# Generate as many synthetic rows as there are real rows in total.
syn = generator.generate(len(X))

## Evaluate Machine Learning Efficacy

In [None]:
# Separate features from the label for the real and the synthetic frames.
xx = X.drop(columns="target")
yy = X["target"].copy()
X_syn = syn.drop(columns="target")
y_syn = syn["target"].copy()

# Identical stratified 70/30 split settings for both data sources.
split_kwargs = dict(train_size=0.7, random_state=seed)
X_tr, X_te, y_tr, y_te = train_test_split(xx, yy, stratify=yy, **split_kwargs)
X_syn_tr, X_syn_te, y_syn_tr, y_syn_te = train_test_split(
    X_syn, y_syn, stratify=y_syn, **split_kwargs
)

# Train one classifier per training source; both are scored on the real test set.
training_sets = (
    ("Train Real Test Real ROCAUC", X_tr, y_tr),
    ("Train Synthetic Test Real ROCAUC", X_syn_tr, y_syn_tr),
)
for label, train_features, train_labels in training_sets:
    model = HistGradientBoostingClassifier(max_depth=3)
    model.fit(train_features, train_labels)
    # Column 1 of predict_proba holds the probability of the positive class.
    preds = model.predict_proba(X_te)[:, 1]
    score = roc_auc_score(y_te, preds)
    print(f"{label}: {score}")

## Plot Feature Distributions
Although the feature distributions are not retained very well (poor fidelity), task-specific utility remains high, as shown by the ML efficacy results above. This is exactly the purpose of FASD.

In [None]:
%matplotlib inline

# Overlay real vs. synthetic histograms for every column of X, one subplot per
# column, with shared bin edges so the two distributions are directly comparable.
columns = list(X.columns)
n_cols = 5
# Ceiling division so the grid always has room for every column
# (the layout was previously hard-coded to 7 rows).
n_rows = max(1, -(-len(columns) // n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(28, 15))
axes = axes.flatten()

for ax, col in zip(axes, columns):
    # Bin edges are computed on the pooled real + synthetic values.
    bins = np.histogram_bin_edges(
        pd.concat((X[col], syn[col])).astype(float), bins="auto"
    )
    ax.hist(X[col], bins=bins, alpha=0.5, label="real")
    ax.hist(syn[col], bins=bins, alpha=0.5, label="synthetic")
    ax.set_title(col, fontsize=10)
    ax.tick_params(labelsize=8)

# Identify the two overlaid distributions once, on the first subplot.
if columns:
    axes[0].legend(fontsize=8)

# Hide grid cells that did not receive a column; slicing avoids depending on
# a loop variable that would be undefined if there were no columns.
for ax in axes[len(columns):]:
    ax.axis("off")

plt.tight_layout()
plt.show()