In [11]:
# Imports & environment synthcity setup (prefer site-packages synthcity)
import sys
from pathlib import Path
import site
import numpy as np
import pandas as pd
import torch
import random
from copy import deepcopy
from tqdm import tqdm
# sklearn helpers
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
# Ensure the repo-local `synthcity/` checkout doesn't shadow the installed
# package. The repo root can be overridden with the THREE_S_REPO_ROOT
# environment variable instead of editing this hard-coded default.
import os

repo_root = Path(
    os.environ.get("THREE_S_REPO_ROOT", "/home/dhanush/work/3S-Testing")
).resolve()
repo_synthcity = repo_root / 'synthcity'


def path_is_under(path: Path, root: Path) -> bool:
    """Return True if `path` is `root` or lies inside it.

    Compares path components, so siblings that merely share a string prefix
    (e.g. `<root>2`) or paths that contain `root` mid-string are not matched.
    """
    return path == root or root in path.parents


# Drop empty entries ('' means the current directory) and anything inside the
# repo (which also covers `repo_synthcity`), so `import synthcity` cannot
# resolve to the repo-local copy.
new_sys_path = []
for p in sys.path:
    if not p:
        continue
    try:
        rp = Path(p).resolve()
    except Exception:
        # Keep entries that cannot be resolved rather than losing them.
        new_sys_path.append(p)
        continue
    if path_is_under(rp, repo_root):
        continue
    new_sys_path.append(p)
sys.path = new_sys_path

# Move site-packages to the front so the installed synthcity wins; iterating
# in reverse keeps the getsitepackages() order after the inserts.
site_packages = site.getsitepackages() if hasattr(site, 'getsitepackages') else []
for sp in reversed(site_packages):
    if sp in sys.path:
        sys.path.remove(sp)
    sys.path.insert(0, sp)

# Add the repo root at the *end* so `src` imports work without letting the
# repo-local synthcity take precedence.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))
# Import synthcity plugins and project helpers
from synthcity.plugins import Plugins
import synthcity
print('synthcity from', synthcity.__file__)
from src.data_loader import load_adult_data
from src.metrics import *
from src.utils import *
# Seed NumPy, stdlib `random`, PyTorch and (when available) CUDA with the same
# value, then report which synthcity plugins loaded successfully.
seed = 0
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
print('Imports complete — Plugins available:', Plugins().list())

[2025-12-05T22:13:23.606814-0500][56369][CRITICAL] module disabled: /home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-12-05T22:13:23.607277-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_arf' has no attribute 'plugin'
[2025-12-05T22:13:23.607636-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_arf' has no attribute 'plugin'
[2025-12-05T22:13:23.608091-0500][56369][CRITICAL] module plugin_arf load failed
[2025-12-05T22:13:23.608650-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'
[2025-12-05T22:13:23.608988-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'
[2025-12-05T22:13:23.609382-0500][56369][CRITICAL] module plugin_great load failed
[2025-12-05T22:13:23.610177-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_bayesian_net

synthcity from /home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/synthcity/__init__.py
Imports complete — Plugins available: ['rtvae', 'timevae', 'tvae', 'fflows', 'nflow', 'dummy_sampler', 'survival_nflow', 'aim', 'uniform_sampler', 'ddpm', 'radialgan', 'ctgan', 'survival_gan', 'image_cgan', 'dpgan', 'pategan', 'image_adsgan', 'marginal_distributions', 'timegan', 'survae', 'adsgan', 'survival_ctgan']


# Get Data

In [12]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

# Load Adult and re-split it into the frames used below: X_train fits the
# classifiers and the generator, X_test is the held-out oracle data. Both
# frames carry the label in column "y". The loader's own splits are unused.
X_train, X_test, y_train, y_test, X, y = load_adult_data()

# Build the labelled frame on a copy so the loader's feature matrix `X` does
# not silently gain a "y" column (plain assignment would alias it).
D_adult = X.copy()
D_adult["y"] = y
seed = 0
# test_size=0.6 -> 40% of rows for training, 60% held out as oracle data.
X_train, X_test = train_test_split(D_adult, test_size=0.6, random_state=seed)

# Train base models

In [13]:
# Fit every base classifier on the labelled training split and keep an
# independent deep copy of each fitted estimator in `trained_model_dict`.
model_dict = dict(
    mlp=MLPClassifier(),
    knn=KNeighborsClassifier(),
    dt=DecisionTreeClassifier(),
    rf=RandomForestClassifier(),
    gbc=GradientBoostingClassifier(),
)

train_features = X_train.drop("y", axis=1)
train_labels = X_train["y"]

trained_model_dict = {}
for name, estimator in model_dict.items():
    estimator.fit(train_features, train_labels)
    trained_model_dict[name] = deepcopy(estimator)

# Train generative model (DDPM)

In [15]:
# Train DDPM generator from environment synthcity
import random
import torch

# Columns handed to the generator as discrete.
discrete_columns = [
    "education-num",
    "marital-status",
    "employment_type",
    "relationship",
    "race",
    "sex",
    "country",
]

# Re-seed immediately before building / fitting the generator so this cell's
# RNG state does not depend on how many draws earlier cells consumed.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Plugins().get() may return something callable or an already-built plugin
# object; in the latter case a fresh default instance of its class is made.
base_plugin = Plugins().get('ddpm')
syn_model = base_plugin() if callable(base_plugin) else type(base_plugin)()

# NOTE(review): confirm the installed synthcity DDPM plugin actually consumes
# `discrete_columns` / `epochs` in fit(); if they are forwarded and ignored,
# training length is set by constructor hyper-parameters instead — verify.
syn_model.fit(X_train, discrete_columns=discrete_columns, epochs=50, seed=seed)


[2025-12-05T22:17:14.375504-0500][56369][CRITICAL] module disabled: /home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/synthcity/plugins/generic/plugin_goggle.py
[2025-12-05T22:17:14.376406-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_arf' has no attribute 'plugin'
[2025-12-05T22:17:14.376798-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_arf' has no attribute 'plugin'
[2025-12-05T22:17:14.377086-0500][56369][CRITICAL] module plugin_arf load failed
[2025-12-05T22:17:14.377626-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'
[2025-12-05T22:17:14.378010-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_great' has no attribute 'plugin'
[2025-12-05T22:17:14.378431-0500][56369][CRITICAL] module plugin_great load failed
[2025-12-05T22:17:14.379026-0500][56369][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_bayesian_net

# Identify the feature whose marginal distribution will be shifted

In [16]:
from tqdm import tqdm

metric = "age"
data = X_train[metric]

# Fewer than 10 distinct values -> treat the feature as categorical; otherwise
# record the summary statistics the shift grid in later cells is built from.
# NOTE(review): `mean` / `std` exist only on the continuous branch, and the
# shift cells below require them.
cat_groups_present = len(np.unique(data)) < 10
if cat_groups_present:
    cat_groups = np.unique(data)
else:
    mean, std = np.mean(data), np.std(data)
    minimum, maximum = np.min(data), np.max(data)

# Position of `metric` among X_train's columns (first match).
eval_idx = np.flatnonzero(X_train.columns == metric)[0]
eval_idx


0

# 3S: shift synthetic (DDPM) data via rejection sampling

In [17]:
from src.shift import rejection_sample

# 3S estimate: for each run, draw a fresh synthetic dataset from the trained
# generator, rejection-sample it towards each target mean of `metric` on the
# grid, and record every model's accuracy on the shifted sample.
n_runs = 2            # independent synthetic datasets (averaged later)
n_synthetic = 10000   # rows generated per run
n_range = 10          # number of target means on the grid
n_std = 1 * std       # grid spans [mean - n_std, mean + n_std)

# Same grid expression as the oracle / RS / MS cells so all series align.
xs = list(
    np.arange(mean - n_std, mean + n_std, ((mean + n_std) - (mean - n_std)) / n_range)
)

# eval_idx is `metric`'s position in X_train.columns (0 for "age").
# NOTE(review): assumes the synthetic frame keeps X_train's column order.
feat_ids = [int(eval_idx)]

ys_all_by_model = {name: [] for name in model_dict}

for run in range(n_runs):
    # Generate synthetic data (generate() is the primary API; sample() is a
    # fallback). On total failure skip the run, reporting BOTH errors so the
    # generate() failure is not hidden behind the fallback's.
    try:
        out = syn_model.generate(count=n_synthetic)
    except Exception as generate_err:
        try:
            out = syn_model.sample(count=n_synthetic)
        except Exception as sample_err:
            print('DDPM generation failed:', repr(generate_err),
                  '| fallback sample() failed:', repr(sample_err))
            continue

    # Presumably a synthcity DataLoader exposing .dataframe(); otherwise coerce.
    try:
        shift_df = out.dataframe()
    except AttributeError:
        shift_df = pd.DataFrame(out)

    run_scores = {name: [] for name in model_dict}
    for shift_mean in xs:
        reject_df = rejection_sample(
            D=shift_df, mean=shift_mean, std=std / 2, feat_id=feat_ids
        )
        if len(reject_df) == 0:
            # Keep every series aligned with `xs`; NaN marks "no samples".
            for scores in run_scores.values():
                scores.append(np.nan)
            continue
        real_tester = pd.DataFrame(reject_df, columns=X_test.columns)
        features = real_tester.drop("y", axis=1)
        for name, clf in model_dict.items():
            y_pred = clf.predict(features)
            run_scores[name].append(accuracy_score(real_tester["y"], y_pred))

    for name, scores in run_scores.items():
        ys_all_by_model[name].append(scores)

# Per-model lists of runs (each run is a list aligned with `xs`).
ys_mlp_all = ys_all_by_model.get("mlp", [])
ys_knn_all = ys_all_by_model.get("knn", [])
ys_dt_all = ys_all_by_model.get("dt", [])
ys_rf_all = ys_all_by_model.get("rf", [])
ys_gbc_all = ys_all_by_model.get("gbc", [])


# Oracle: rejection-sample shifts of the held-out test data

In [18]:
# Oracle: accuracy of each trained model on rejection-sampled shifts of the
# held-out real test data (X_test), on the same grid of target means.
xr = list(
    np.arange(mean - n_std, mean + n_std, ((mean + n_std) - (mean - n_std)) / n_range)
)

# eval_idx is `metric`'s position in the columns (0 for "age"); X_test comes
# from the same frame as X_train, so the position carries over.
feat_ids = [int(eval_idx)]

yr_by_model = {name: [] for name in model_dict}
for shift_mean in xr:
    reject_df = rejection_sample(D=X_test, mean=shift_mean, std=std / 2, feat_id=feat_ids)
    if len(reject_df) == 0:
        # Keep every series aligned with the grid; NaN marks "no samples".
        for scores in yr_by_model.values():
            scores.append(np.nan)
        continue
    real_tester = pd.DataFrame(reject_df, columns=X_test.columns)
    features = real_tester.drop("y", axis=1)
    for name, clf in model_dict.items():
        y_pred = clf.predict(features)
        yr_by_model[name].append(accuracy_score(real_tester["y"], y_pred))

yr_mlp = yr_by_model.get("mlp", [])
yr_knn = yr_by_model.get("knn", [])
yr_dt = yr_by_model.get("dt", [])
yr_rf = yr_by_model.get("rf", [])
yr_gbc = yr_by_model.get("gbc", [])


# RS baseline: rejection-sample shifts of the source (training) data

In [19]:
# RS baseline: accuracy of each trained model on rejection-sampled shifts of
# the source (training) data X_train, on the same grid of target means.
xr = list(
    np.arange(mean - n_std, mean + n_std, ((mean + n_std) - (mean - n_std)) / n_range)
)

# eval_idx is `metric`'s position in X_train.columns (0 for "age").
feat_ids = [int(eval_idx)]

yr_val_by_model = {name: [] for name in model_dict}
for shift_mean in xr:
    reject_df = rejection_sample(D=X_train, mean=shift_mean, std=std / 2, feat_id=feat_ids)
    if len(reject_df) == 0:
        # Keep every series aligned with the grid; NaN marks "no samples".
        for scores in yr_val_by_model.values():
            scores.append(np.nan)
        continue
    real_tester = pd.DataFrame(reject_df, columns=X_train.columns)
    features = real_tester.drop("y", axis=1)
    for name, clf in model_dict.items():
        y_pred = clf.predict(features)
        yr_val_by_model[name].append(accuracy_score(real_tester["y"], y_pred))

yr_mlp_val = yr_val_by_model.get("mlp", [])
yr_knn_val = yr_val_by_model.get("knn", [])
yr_dt_val = yr_val_by_model.get("dt", [])
yr_rf_val = yr_val_by_model.get("rf", [])
yr_gbc_val = yr_val_by_model.get("gbc", [])


# MS baseline: replace the feature with mean-shifted Gaussian draws

In [20]:
# MS baseline: keep every source (training) row but overwrite `metric` with
# draws from N(shift_mean, std) for each target mean on the grid, then score
# each trained model on the perturbed frame. No rejection sampling happens
# here, so there is no empty-sample case to skip.
# NOTE(review): MS draws with scale=std while the rejection-sampling cells use
# std / 2 — confirm this difference is intended.
xr = list(
    np.arange(mean - n_std, mean + n_std, ((mean + n_std) - (mean - n_std)) / n_range)
)

yr_ms_by_model = {name: [] for name in model_dict}
for shift_mean in xr:
    # DataFrame.copy() is deep by default, so X_train itself is never modified.
    test_df = X_train.copy()
    test_df[metric] = np.random.normal(
        loc=shift_mean, scale=std, size=len(X_train[metric])
    )
    real_tester = test_df
    features = real_tester.drop("y", axis=1)
    for name, clf in model_dict.items():
        y_pred = clf.predict(features)
        yr_ms_by_model[name].append(accuracy_score(real_tester["y"], y_pred))

yr_mlp_ms = yr_ms_by_model.get("mlp", [])
yr_knn_ms = yr_ms_by_model.get("knn", [])
yr_dt_ms = yr_ms_by_model.get("dt", [])
yr_rf_ms = yr_ms_by_model.get("rf", [])
yr_gbc_ms = yr_ms_by_model.get("gbc", [])


# Compare the estimates with oracle (test-set) performance

In [22]:
# Mean absolute error of each estimator (3S, MS, RS) against the oracle
# accuracy for the RF model. Grid points are split into bins by where the
# target mean falls relative to the quartiles of the in-grid training values.
grid = np.asarray(xs, dtype=float)
oracle = np.asarray(yr_rf, dtype=float)
estimate_3s = np.nanmean(np.asarray(ys_rf_all, dtype=float), axis=0)
estimate_ms = np.asarray(yr_rf_ms, dtype=float)
estimate_rs = np.asarray(yr_rf_val, dtype=float)

# Fail loudly if any series is not aligned with the grid, instead of letting
# numpy broadcast or raise an opaque shape error.
for label, series in (("3S", estimate_3s), ("oracle", oracle),
                      ("MS", estimate_ms), ("RS", estimate_rs)):
    assert series.shape == grid.shape, (
        f"{label} series has shape {series.shape}, expected {grid.shape}"
    )

in_range = (X_train[metric] > grid[0]) & (X_train[metric] < grid[-1])
quantiles = X_train.loc[in_range, metric].quantile([0.25, 0.5, 0.75]).values
# Inclusive middle bin, so a grid point equal to a quartile is not dropped.
q1 = grid < quantiles[0]
q2 = (grid >= quantiles[0]) & (grid <= quantiles[2])
q3 = grid > quantiles[2]

threeS_err = np.abs(estimate_3s - oracle)
ms_err = np.abs(estimate_ms - oracle)
rs_err = np.abs(estimate_rs - oracle)

# nanmean skips NaN placeholders recorded for grid points with no samples.
results = {}
for label, mask in (("Q1", q1), ("Q2", q2), ("Q3", q3), ("avg", slice(None))):
    results[label] = {
        "Error 3S": np.nanmean(threeS_err[mask]),
        "Error MS": np.nanmean(ms_err[mask]),
        "Error RS": np.nanmean(rs_err[mask]),
    }

results


{'Q1': {'Error 3S': 0.032666666666666656,
  'Error MS': 0.07165074450832964,
  'Error RS': 0.09199999999999997},
 'Q2': {'Error 3S': 0.010500000000000037,
  'Error MS': 0.047062251216275997,
  'Error RS': 0.16125},
 'Q3': {'Error 3S': 0.008500000000000027,
  'Error MS': 0.07804378593542678,
  'Error RS': 0.18733333333333338},
 'avg': {'Error 3S': 0.01655000000000002,
  'Error MS': 0.06373325961963733,
  'Error RS': 0.1483}}

In [24]:
# Same comparison as for RF, repeated for every trained model.
# `results` is rebuilt for RF (same structure as before);
# `results_by_model` holds the per-bin error summary for all five models.
grid = np.asarray(xs, dtype=float)
in_range = (X_train[metric] > grid[0]) & (X_train[metric] < grid[-1])
quantiles = X_train.loc[in_range, metric].quantile([0.25, 0.5, 0.75]).values
bin_masks = {
    "Q1": grid < quantiles[0],
    # Inclusive middle bin, so a grid point equal to a quartile is not dropped.
    "Q2": (grid >= quantiles[0]) & (grid <= quantiles[2]),
    "Q3": grid > quantiles[2],
    "avg": np.ones(grid.shape, dtype=bool),
}


def shift_error_summary(ys_runs, oracle, ms, rs, masks):
    """Mean absolute error of the 3S / MS / RS accuracy estimates vs. the oracle.

    Args:
        ys_runs: list of 3S runs, each a list of accuracies aligned with the grid.
        oracle: oracle accuracies aligned with the grid.
        ms: mean-shift baseline accuracies aligned with the grid.
        rs: rejection-sampling (source) baseline accuracies aligned with the grid.
        masks: mapping bin name -> boolean mask over the grid.

    Returns:
        dict bin name -> {"Error 3S", "Error MS", "Error RS"} mean absolute
        errors; NaN entries (grid points without samples) are ignored.

    Raises:
        ValueError: if any series is not aligned with the masks.
    """
    expected = next(iter(masks.values())).shape
    oracle = np.asarray(oracle, dtype=float)
    errors = {
        "Error 3S": np.abs(np.nanmean(np.asarray(ys_runs, dtype=float), axis=0) - oracle),
        "Error MS": np.abs(np.asarray(ms, dtype=float) - oracle),
        "Error RS": np.abs(np.asarray(rs, dtype=float) - oracle),
    }
    for label, err in errors.items():
        if err.shape != expected:
            raise ValueError(f"{label}: shape {err.shape}, expected {expected}")
    return {
        bin_name: {label: np.nanmean(err[mask]) for label, err in errors.items()}
        for bin_name, mask in masks.items()
    }


model_series = {
    "mlp": (ys_mlp_all, yr_mlp, yr_mlp_ms, yr_mlp_val),
    "knn": (ys_knn_all, yr_knn, yr_knn_ms, yr_knn_val),
    "dt": (ys_dt_all, yr_dt, yr_dt_ms, yr_dt_val),
    "rf": (ys_rf_all, yr_rf, yr_rf_ms, yr_rf_val),
    "gbc": (ys_gbc_all, yr_gbc, yr_gbc_ms, yr_gbc_val),
}
results_by_model = {
    name: shift_error_summary(*series, bin_masks)
    for name, series in model_series.items()
}

results = results_by_model["rf"]
q1_dict, q2_dict, q3_dict = results["Q1"], results["Q2"], results["Q3"]
avg_dict = results["avg"]
threeS_err = np.abs(
    np.nanmean(np.asarray(ys_rf_all, dtype=float), axis=0) - np.asarray(yr_rf, dtype=float)
)

# One row per (model, bin); columns are the three estimators' errors.
pd.concat(
    {name: pd.DataFrame(summary).T for name, summary in results_by_model.items()}
)
