In [None]:
# ╔══════════ 0 · Imports ───────────────────────────────────────────╗
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
import torch

from sdv.single_table import CTGANSynthesizer
from sdv.metadata import Metadata

# ╔══════════ 1 · Fixed-epoch trainer (800 epochs) ──────────────────╗
def train_ctgan_800(
    df_slice: pd.DataFrame,
    *,
    service_code: str,
    total_epochs: int = 800,     # ← fixed budget
    step: int = 50,              # kept for backward compatibility; unused
    outdir: Path = Path("checkpoints_3/models"),
) -> Path:
    """Train one CTGAN on ``df_slice`` for ``total_epochs`` epochs and save it.

    The synthesizer is built once and fitted once with ``epochs=total_epochs``.
    An earlier version built a *new* synthesizer (hard-coded to 1200 epochs)
    inside a ``step``-sized loop; every iteration therefore retrained from
    scratch and discarded the previous model, so the real cost was
    ``ceil(total_epochs / step) * 1200`` epochs and the saved ``_ep800`` model
    had actually been trained for 1200.

    Args:
        df_slice: Training rows; metadata is auto-detected from this frame.
        service_code: Label used in the output filename. Path separators in
            it are replaced with ``_`` so the file lands directly in
            ``outdir``.
        total_epochs: Number of training epochs (must be >= 1).
        step: Ignored. Retained only so existing callers that pass it keep
            working; training progress is reported by the synthesizer itself
            because it is created with ``verbose=True``.
        outdir: Directory the pickled model is written to (created if
            missing).

    Returns:
        Path of the saved model file, ``<outdir>/<service_code>_ep<N>.pkl``.

    Raises:
        ValueError: If ``total_epochs`` is smaller than 1.
    """
    if total_epochs < 1:
        raise ValueError(f"total_epochs must be >= 1, got {total_epochs}")

    outdir = Path(outdir)  # also accept a plain string path
    outdir.mkdir(parents=True, exist_ok=True)
    metadata = Metadata.detect_from_dataframes({"table": df_slice})

    synth = CTGANSynthesizer(
        metadata,
        epochs=total_epochs,  # honour the requested budget, not a fixed 1200
        batch_size=1024,
        pac=8,
        embedding_dim=128,
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=2e-4,
        discriminator_lr=2e-4,
        cuda=True,
        verbose=True,
    )
    # Single fit: a chunked loop over fresh synthesizers cannot resume
    # training, it can only repeat it.
    synth.fit(df_slice)

    # A "/" or "\" in the code would otherwise point into a sub-directory
    # that was never created, making the save fail after hours of training.
    safe_code = service_code.replace("/", "_").replace("\\", "_")
    path = outdir / f"{safe_code}_ep{total_epochs}.pkl"
    synth.save(path)
    return path

# ╔══════════ 2 · Data loading & loop over top-5 codes ──────────────╗
# Input table and the number of most frequent service codes to model.
DATA_PATH = "filtered_data/df_v2_filtered.pkl"
TOP_K     = 5

df = pd.read_pickle(DATA_PATH)

# value_counts() sorts by descending frequency, so the first TOP_K index
# entries are the TOP_K most common descriptions.
code_counts = df["service_code_description"].value_counts()
codes = code_counts.index[:TOP_K].tolist()

# Train and save one model per selected code, each on its own copy of the rows.
for code in codes:
    row_mask = df["service_code_description"] == code
    slice_df = df.loc[row_mask].copy()
    n_rows = len(slice_df)
    print(f"\n🛠️  Training CTGAN for {code} — rows: {n_rows:,}")
    model_path = train_ctgan_800(slice_df, service_code=code)
    print(f"✅  Saved model → {model_path}")

# optional GPU tidy-up: release cached CUDA memory once all models are saved
if torch.cuda.is_available():
    torch.cuda.empty_cache()



🛠️  Training CTGAN for ASSIST MEMBER OF THE PUBLIC — rows: 169,406


CTGAN [ASSIST MEMBER OF THE PUBLIC]:   0%|          | 0/800 [00:00<?, ?epoch/s]


We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.


Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. Call result.infer_objects(copy=False) instead. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`

