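# Подбор гиперпараметров CTGAN с помощью Optuna

Ноутбук подбирает гиперпараметры CTGAN для выбранного датасета и метода кодирования из реестра `datasets_registry.csv`. Целевая метрика — качество линейной модели (Ridge для регрессии, логистическая регрессия для классификации), обученной на синтетических данных и проверенной на отложенной части реальных данных: R² для регрессии или accuracy для классификации. Лучшие параметры сохраняются в формате JSON в файл `<dataset_name>.txt` в папке `optuna_results`.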

In [None]:
import ast
import json
from pathlib import Path
import traceback
import optuna
import numpy as np
import pandas as pd
from ctgan import CTGAN
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
# Configuration for the full tuning run.

# Possible locations of the dataset registry, relative to the working directory.
# The first one that exists wins; if none exists, fall back to the first
# candidate so that load_registry() can report a clear "not found" error.
CANDIDATE_REGISTRY_PATHS = [
    Path("datasets") / "datasets_registry.csv",
    Path("..") / "datasets" / "datasets_registry.csv",
    Path("datasets_registry.csv"),
    Path("..") / "datasets_registry.csv",
]
for DATASETS_REGISTRY in CANDIDATE_REGISTRY_PATHS:
    if DATASETS_REGISTRY.exists():
        break
else:
    DATASETS_REGISTRY = CANDIDATE_REGISTRY_PATHS[0]

dataset_name = "adult"                # dataset name as listed in the registry
encoding_method = "one_hot_encoding"  # one_hot_encoding | label_encoding | frequency_encoding | original
N_TRIALS = 120                        # number of Optuna trials for the full search
EPOCHS = 300                          # CTGAN training epochs per trial
RANDOM_STATE = 42

OUTPUT_DIR = Path("optuna_results")
OUTPUT_DIR.mkdir(exist_ok=True)
# NOTE: the file name depends only on dataset_name, so runs with different
# encoding_method values write to the same file.
output_path = OUTPUT_DIR.joinpath(f"{dataset_name}.txt")


In [None]:
def load_registry(registry_path: Path = DATASETS_REGISTRY) -> pd.DataFrame:
    """Read the dataset registry CSV with whitespace normalised.

    Column names and ``dataset_name`` values are stripped so that lookups by
    dataset name are not broken by stray spaces in the CSV.

    Args:
        registry_path: Location of ``datasets_registry.csv``; defaults to the
            path selected from ``CANDIDATE_REGISTRY_PATHS`` at definition time.

    Returns:
        The registry as a DataFrame.

    Raises:
        FileNotFoundError: if ``registry_path`` does not exist. The message
            lists every candidate location that was considered.
    """
    if registry_path.exists():
        registry = pd.read_csv(registry_path, skipinitialspace=True)
        registry = registry.rename(columns=str.strip)
        registry["dataset_name"] = registry["dataset_name"].str.strip()
        return registry

    tried = [str(candidate) for candidate in CANDIDATE_REGISTRY_PATHS]
    raise FileNotFoundError(
        f"Не найден datasets_registry.csv. Проверьте пути: {tried}"
    )


def _resolve_path(path_str: str, anchors: list[Path]) -> Path:
    p = Path(path_str)
    if p.is_absolute():
        return p

    # Если путь начинается с datasets/, пробуем от корня репозитория
    if path_str.startswith("datasets/"):
        for anchor in anchors:
            candidate = (anchor / path_str).resolve()
            if candidate.exists():
                return candidate

    for anchor in anchors:
        candidate = (anchor / path_str).resolve()
        if candidate.exists():
            return candidate

    return (anchors[0] / path_str).resolve()


def get_dataset_info(name: str) -> dict:
    """Look up ``name`` in the registry and return its resolved paths and target.

    Relative paths in the registry are resolved against, in order: the
    directory two levels above the registry file (treated as the repo root),
    the registry's own directory, and the current working directory.

    Returns:
        dict with ``dataset_csv`` (Path), ``dataset_path`` (Path) and
        ``target`` (whitespace-stripped target column name).

    Raises:
        ValueError: if ``name`` is not present in the registry.
    """
    registry = load_registry()
    matches = registry.loc[registry["dataset_name"] == name]
    if len(matches) == 0:
        raise ValueError(f"Датасет {name} не найден в {DATASETS_REGISTRY}")
    entry = matches.iloc[0]

    search_dirs = [
        DATASETS_REGISTRY.parent.parent.resolve(),
        DATASETS_REGISTRY.parent.resolve(),
        Path.cwd(),
    ]

    resolved = {
        column: _resolve_path(str(entry[column]), search_dirs)
        for column in ("dataset_csv", "dataset_path")
    }
    resolved["target"] = entry["target"].strip()
    return resolved


def _parse_cat_cols(raw, method: str) -> list:
    """Parse a ``New_cat_cols`` cell (a Python-literal list of names) into a list.

    An empty cell (NaN) means "no categorical columns" and yields ``[]``
    silently. A cell that cannot be parsed, or that does not hold a list/tuple,
    also yields ``[]``, but a warning is printed. Otherwise the categorical
    columns would silently be treated as continuous by CTGAN.
    """
    if pd.isna(raw):
        return []
    try:
        parsed = ast.literal_eval(str(raw))
    except (ValueError, SyntaxError, TypeError) as err:
        print(
            f"Предупреждение: не удалось разобрать New_cat_cols для метода {method}: "
            f"{raw!r} ({err})"
        )
        return []
    if not isinstance(parsed, (list, tuple)):
        # e.g. a bare quoted string would otherwise be iterated char by char.
        print(f"Предупреждение: New_cat_cols для метода {method} не является списком: {raw!r}")
        return []
    return list(parsed)


def get_encoded_dataset(name: str, method: str):
    """Locate the encoded dataset for ``name``/``method`` and its discrete columns.

    The per-dataset index CSV (``dataset_csv`` in the registry) is read. It
    must have a ``method`` column and a ``path`` column, and may have an
    optional ``New_cat_cols`` column with a Python-literal list of
    categorical column names.

    Returns:
        dict with ``csv_path`` (Path to the encoded CSV), ``target`` (str) and
        ``discrete_features`` (list of non-empty column names).

    Raises:
        FileNotFoundError: if the index CSV or the encoded dataset is missing.
        ValueError: if ``method`` is not listed in the index CSV.
    """
    info = get_dataset_info(name)
    data_csv_path = info["dataset_csv"]
    if not data_csv_path.exists():
        raise FileNotFoundError(
            f"{data_csv_path} не найден. Выполните notebooks/dataset_encoding.ipynb или скорректируйте путь в datasets_registry.csv"
        )

    data_df = pd.read_csv(data_csv_path)
    # Compare stripped names so stray spaces in the CSV do not hide the row.
    methods = data_df["method"].astype(str).str.strip()
    row = data_df.loc[methods == method]
    if row.empty:
        raise ValueError(f"Метод {method} не найден в {data_csv_path}")
    rec = row.iloc[0]

    new_cat_cols = _parse_cat_cols(rec.get("New_cat_cols", "[]"), method)

    repo_root = DATASETS_REGISTRY.parent.parent.resolve()
    anchors = [repo_root, DATASETS_REGISTRY.parent.resolve(), data_csv_path.parent.resolve(), Path.cwd()]
    dataset_path = _resolve_path(str(rec["path"]), anchors)

    if not dataset_path.exists():
        raise FileNotFoundError(
            f"{dataset_path} не найден. Перегенерируйте кодировки через dataset_encoding.ipynb"
        )

    return {
        "csv_path": dataset_path,
        "target": info["target"],
        "discrete_features": [c for c in new_cat_cols if c],
    }



In [None]:
def prepare_data(csv_path: Path, target: str):
    """Load an encoded dataset and split it into features and target.

    The task counts as regression when the target is numeric and has more
    than 20 distinct values; otherwise it counts as classification.

    Returns:
        ``(df, X, y, is_regression)``: the full frame, the features without the
        target column, the target series, and the task flag.

    Raises:
        ValueError: if ``target`` is not a column of the CSV.
    """
    frame = pd.read_csv(csv_path)
    if target in frame.columns:
        features = frame.drop(columns=[target])
        labels = frame[target]
        many_values = labels.nunique() > 20
        regression = many_values and pd.api.types.is_numeric_dtype(labels)
        return frame, features, labels, regression
    raise ValueError(f"Целевая колонка {target} отсутствует в {csv_path}")


def evaluate_synthetic_quality(df_real: pd.DataFrame, df_synth: pd.DataFrame, target: str, is_regression: bool):
    """Score synthetic data by training on synthetic rows and testing on real ones.

    A simple model is fitted on *all* synthetic rows: Ridge for regression,
    logistic regression for classification, both after standard scaling. It
    is then evaluated on a 25 % hold-out of the real data. The hold-out comes
    from ``train_test_split(test_size=0.25, random_state=RANDOM_STATE)``, so for
    a given ``df_real`` it is always the same set of rows.

    Returns:
        R² (regression) or accuracy (classification) on the real hold-out.
    """
    # Only the real hold-out is needed; the real training part is not used here.
    _, X_val_real, _, y_val_real = train_test_split(
        df_real.drop(columns=[target]),
        df_real[target],
        test_size=0.25,
        random_state=RANDOM_STATE,
    )

    # Train on every synthetic row (previously a quarter was split off and
    # discarded), with columns in the same order as the real features.
    X_synth = df_synth.drop(columns=[target])[list(X_val_real.columns)]
    y_synth = df_synth[target]

    if is_regression:
        model = make_pipeline(StandardScaler(), Ridge(random_state=RANDOM_STATE))
        model.fit(X_synth, y_synth)
        return r2_score(y_val_real, model.predict(X_val_real))

    # Classification
    model = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=2000, solver="lbfgs", n_jobs=4),
    )
    model.fit(X_synth, y_synth)
    return accuracy_score(y_val_real, model.predict(X_val_real))


def objective(trial, df_real, target, discrete_features, is_regression):
    """Optuna objective: fit CTGAN with sampled hyperparameters and score it.

    Search space: embedding size, width of the two generator/discriminator
    layers, batch size, and both learning rates. CTGAN is trained for
    ``EPOCHS`` epochs on CPU. It then samples as many rows as ``df_real`` has,
    and they are scored with ``evaluate_synthetic_quality`` (higher is better).

    To keep the evaluation rows out of the generator's training data, CTGAN is
    fitted only on the real rows outside the hold-out used by
    ``evaluate_synthetic_quality``. Both sides call ``train_test_split`` with
    ``test_size=0.25`` and ``random_state=RANDOM_STATE`` on the same number of
    rows, which selects the same row positions.

    Returns:
        The score, or NaN if anything in the trial raised. Optuna marks a trial
        that returns NaN as FAILED, so a crashed trial can never be selected as
        the best one. The error text is stored in the trial's user attrs.
    """
    params = {
        "embedding_dim": trial.suggest_int("embedding_dim", 64, 256, step=32),
        "gen_dim": trial.suggest_categorical("gen_dim", [128, 256, 512]),
        "disc_dim": trial.suggest_categorical("disc_dim", [64, 128, 256]),
        "batch_size": trial.suggest_categorical("batch_size", [256, 512, 1024]),
        "generator_lr": trial.suggest_float("generator_lr", 1e-4, 5e-3, log=True),
        "discriminator_lr": trial.suggest_float("discriminator_lr", 1e-4, 5e-3, log=True),
    }

    # Same split parameters as the real split in evaluate_synthetic_quality.
    train_pos, _ = train_test_split(
        np.arange(len(df_real)), test_size=0.25, random_state=RANDOM_STATE
    )
    df_train = df_real.iloc[train_pos].reset_index(drop=True)

    try:
        ctgan = CTGAN(
            epochs=EPOCHS,
            batch_size=int(params["batch_size"]),
            embedding_dim=int(params["embedding_dim"]),
            generator_dim=(int(params["gen_dim"]), int(params["gen_dim"])),
            discriminator_dim=(int(params["disc_dim"]), int(params["disc_dim"])),
            generator_lr=params["generator_lr"],
            discriminator_lr=params["discriminator_lr"],
            pac=1,
            enable_gpu=False,
            verbose=False,
        )

        ctgan.fit(df_train, discrete_features)
        df_synth = ctgan.sample(len(df_real))

        return evaluate_synthetic_quality(df_real, df_synth, target, is_regression)
    except Exception as err:
        print("Trial failed:")
        traceback.print_exc()
        trial.set_user_attr("error", repr(err))
        # NaN (not -inf): Optuna records the trial as FAILED instead of COMPLETE.
        return float("nan")


# Load the encoded dataset selected in the configuration cell.
encoded_info = get_encoded_dataset(dataset_name, encoding_method)
df_real, X_real, y_real, is_regression = prepare_data(encoded_info["csv_path"], encoded_info["target"])

# Keep only declared discrete columns that exist after encoding, and report the
# ones that are dropped so that a registry/CSV mismatch is visible.
declared_discrete = encoded_info["discrete_features"]
discrete_features = [c for c in declared_discrete if c in df_real.columns]
missing_discrete = [c for c in declared_discrete if c not in df_real.columns]
if missing_discrete:
    print(f"Предупреждение: дискретные колонки отсутствуют в данных и пропущены: {missing_discrete}")

# For classification the target must be a discrete CTGAN column. Otherwise CTGAN
# treats it as continuous: it may sample non-integer labels (numeric target) or
# fail to fit (string target), and every trial's classifier would fail.
target_col = encoded_info["target"]
if not is_regression and target_col not in discrete_features:
    discrete_features.append(target_col)

print(f"Датасет: {dataset_name} | метод: {encoding_method} | строк: {len(df_real)} | фичей: {X_real.shape[1]} | дискретных: {len(discrete_features)}")



Датасет: adult | метод: one_hot_encoding | строк: 10000 | фичей: 118 | дискретных: 113


In [None]:
# NOTE: objective() never calls trial.report()/trial.should_prune(), so this
# pruner is never consulted and no trial is pruned. It only takes effect if
# intermediate reporting is added to the objective.
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=0)
study = optuna.create_study(
    direction="maximize",
    study_name=f"ctgan_{dataset_name}_{encoding_method}",
    pruner=pruner,
)
study.optimize(
    lambda trial: objective(trial, df_real, encoded_info["target"], discrete_features, is_regression),
    n_trials=N_TRIALS,
    show_progress_bar=True,
)

# Only completed trials with a finite score are real results. A crashed trial
# may come back as -inf (COMPLETE) or be marked FAILED. Saving its parameters
# as "best" would be misleading, and -inf would be written as the non-standard
# JSON token -Infinity.
successful_trials = [
    t
    for t in study.trials
    if t.state == optuna.trial.TrialState.COMPLETE and t.value is not None and np.isfinite(t.value)
]
if not successful_trials:
    raise RuntimeError(
        f"Ни одна из {len(study.trials)} попыток не завершилась успешно — результаты не сохранены"
    )
best_trial = max(successful_trials, key=lambda t: t.value)

print(f"Успешных попыток: {len(successful_trials)} из {len(study.trials)}")
print("Лучший результат:", best_trial.value)
print("Параметры:")
print(json.dumps(best_trial.params, indent=2))

result_payload = {
    "dataset": dataset_name,
    "method": encoding_method,
    "score": best_trial.value,
    "epochs": EPOCHS,
    "n_trials": N_TRIALS,
    "n_successful_trials": len(successful_trials),
    "params": best_trial.params,
}

output_path.write_text(json.dumps(result_payload, indent=2), encoding="utf-8")
print(f"Сохранено: {output_path.resolve()}")



[I 2025-12-10 13:04:59,929] A new study created in memory with name: ctgan_adult_one_hot_encoding
  0%|          | 0/120 [00:00<?, ?it/s]

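## Запуск в Google Colab

В Colab выполните ячейки этого раздела до основных ячеек выше: они монтируют Google Drive, клонируют репозиторий, переходят в его корень и устанавливают зависимости из `requirements.txt`.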

In [1]:
# Colab only: mount Google Drive under /content/drive.
# NOTE(review): none of the visible cells read from /content/drive (the repository
# is cloned into the runtime's local disk), so this mount may be optional — confirm.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Clone the project repository into the current working directory; the next cell
# expects it at /content/NIR.
# NOTE(review): `git clone` fails if ./NIR already exists, e.g. when this cell is
# re-run in the same runtime.
import os
!git clone https://github.com/IliaOknesavi/NIR.git

Cloning into 'NIR'...
remote: Enumerating objects: 284, done.
remote: Counting objects: 100% (16/16), done.
remote: Compressing objects: 100% (15/15), done.
remote: Total 284 (delta 0), reused 8 (delta 0), pack-reused 268 (from 1)
Receiving objects: 100% (284/284), 108.07 MiB | 21.61 MiB/s, done.
Resolving deltas: 100% (56/56), done.
Updating files: 100% (142/142), done.


In [3]:
# Work from the cloned repository root: requirements.txt and the relative
# dataset/registry paths used by the notebook are resolved from there.
os.chdir(os.path.join("/content", "NIR"))

In [4]:
# Install the project's dependencies (sdv, sdmetrics, ctgan, openml, optuna, ...)
# from the cloned repository; the working directory must be the repo root.
!pip install -r requirements.txt

Collecting sdv>=1.10.0 (from -r requirements.txt (line 1))
  Downloading sdv-1.30.0-py3-none-any.whl.metadata (14 kB)
Collecting sdmetrics>=0.13.1 (from -r requirements.txt (line 2))
  Downloading sdmetrics-0.24.0-py3-none-any.whl.metadata (9.3 kB)
Collecting ctgan>=0.10.0 (from -r requirements.txt (line 3))
  Downloading ctgan-0.11.1-py3-none-any.whl.metadata (10 kB)
Collecting openml>=0.14.1 (from -r requirements.txt (line 4))
  Downloading openml-0.15.1-py3-none-any.whl.metadata (10 kB)
Collecting category_encoders>=2.6.0 (from -r requirements.txt (line 10))
  Downloading category_encoders-2.9.0-py3-none-any.whl.metadata (7.9 kB)
Collecting optuna>=3.5 (from -r requirements.txt (line 12))
  Downloading optuna-4.6.0-py3-none-any.whl.metadata (17 kB)
Collecting boto3<2.0.0,>=1.28 (from sdv>=1.10.0->-r requirements.txt (line 1))
  Downloading boto3-1.42.6-py3-none-any.whl.metadata (6.8 kB)
Collecting botocore<2.0.0,>=1.31 (from sdv>=1.10.0->-r requirements.txt (line 1))
  Downloading b