From df31a328fa554864ff25378edfdcfb2be1686edf Mon Sep 17 00:00:00 2001 From: Eric Pham-Hung Date: Mon, 20 Oct 2025 13:14:17 -0700 Subject: [PATCH] adding extrinsic eval notebook Signed-off-by: Eric Pham-Hung --- ..._synthesizer_101_with_extrinsic_eval.ipynb | 990 ++++++++++++++++++ 1 file changed, 990 insertions(+) create mode 100644 nemo/NeMo-Safe-Synthesizer/intro/safe_synthesizer_101_with_extrinsic_eval.ipynb diff --git a/nemo/NeMo-Safe-Synthesizer/intro/safe_synthesizer_101_with_extrinsic_eval.ipynb b/nemo/NeMo-Safe-Synthesizer/intro/safe_synthesizer_101_with_extrinsic_eval.ipynb new file mode 100644 index 000000000..8ea636697 --- /dev/null +++ b/nemo/NeMo-Safe-Synthesizer/intro/safe_synthesizer_101_with_extrinsic_eval.ipynb @@ -0,0 +1,990 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "630e3e17", + "metadata": {}, + "source": [ + "# ๐ŸŽ›๏ธ NeMo Safe Synthesizer 101: The Basics\n", + "\n", + "> โš ๏ธ **Warning**: NeMo Safe Synthesizer is in Early Access and not recommended for production use.\n", + "\n", + "
\n", + "\n", + "In this notebook, we demonstrate how to create a synthetic version of a tabular dataset using the NeMo Microservices Python SDK. The notebook should take about 20 minutes to run.\n", + "\n", + "After completing this notebook, you'll be able to:\n", + "- Use the NeMo Microservices SDK to interact with Safe Synthesizer\n", + "- Create novel synthetic data that follows the statistical properties of your input dataset\n", + "- Access an evaluation report on synthetic data quality and privacy\n" + ] + }, + { + "cell_type": "markdown", + "id": "8be84f5d", + "metadata": {}, + "source": [ + "#### ๐Ÿ’พ Install dependencies\n", + "\n", + "**IMPORTANT** ๐Ÿ‘‰ Ensure you have a NeMo Microservices Platform deployment available. Follow the quickstart or Helm chart instructions in your environment's setup guide. You may need to restart your kernel after installing dependencies.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9f5d6f5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/shadeform/GenerativeAIExamples/nemo/NeMo-Safe-Synthesizer/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "from nemo_microservices import NeMoMicroservices\n", + "from nemo_microservices.beta.safe_synthesizer.builder import SafeSynthesizerBuilder\n", + "\n", + "import logging\n", + "logging.basicConfig(level=logging.WARNING)\n", + "logging.getLogger(\"httpx\").setLevel(logging.WARNING)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb2807", + "metadata": {}, + "source": [ + "### โš™๏ธ Initialize the NeMo Safe Synthesizer Client\n", + "\n", + "- The Python SDK provides a wrapper around the NeMo Microservices Platform APIs.\n", + "- `http://localhost:8080` is the default url for the client's `base_url` in the quickstart.\n", + "- If using a managed or remote deployment, ensure correct base URLs and tokens.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8c15ab93", + "metadata": {}, + "outputs": [], + "source": [ + "client = NeMoMicroservices(\n", + " base_url=\"http://localhost:8080\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "74d72ef7", + "metadata": {}, + "source": [ + "NeMo DataStore is launched as one of the services, and we'll use it to manage our storage. so we'll set the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ab037a3a", + "metadata": {}, + "outputs": [], + "source": [ + "datastore_config = {\n", + " \"endpoint\": \"http://localhost:3000/v1/hf\",\n", + " \"token\": \"\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "2d66c819", + "metadata": {}, + "source": [ + "## ๐Ÿ“ฅ Load input data\n", + "\n", + "Safe Synthesizer learns the patterns and correlations in your input dataset to produce synthetic data with similar properties. For this tutorial, we will use a small public sample dataset. Replace it with your own data if desired.\n", + "\n", + "The sample dataset used here is a set of women's clothing reviews, including age, product category, rating, and review text. Some of the reviews contain Personally Identifiable Information (PII), such as height, weight, age, and location." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "daa955b6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/shadeform/GenerativeAIExamples/nemo/NeMo-Safe-Synthesizer/.venv/bin/python: No module named pip\n", + "/bin/bash: line 1: uv: command not found\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install kagglehub || uv pip install kagglehub" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7204f213", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading from https://www.kaggle.com/api/v1/datasets/download/nicapotato/womens-ecommerce-clothing-reviews?dataset_version_number=1...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 2.79M/2.79M [00:00<00:00, 70.2MB/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting files...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Clothing IDAgeTitleReview TextRatingRecommended INDPositive Feedback CountDivision NameDepartment NameClass Name
076733NaNAbsolutely wonderful - silky and sexy and comf...410InitmatesIntimateIntimates
1108034NaNLove this dress! it's sooo pretty. i happene...514GeneralDressesDresses
2107760Some major design flawsI had such high hopes for this dress and reall...300GeneralDressesDresses
3104950My favorite buy!I love, love, love this jumpsuit. it's fun, fl...510General PetiteBottomsPants
484747Flattering shirtThis shirt is very flattering to all due to th...516GeneralTopsBlouses
\n", + "
" + ], + "text/plain": [ + " Clothing ID Age Title \\\n", + "0 767 33 NaN \n", + "1 1080 34 NaN \n", + "2 1077 60 Some major design flaws \n", + "3 1049 50 My favorite buy! \n", + "4 847 47 Flattering shirt \n", + "\n", + " Review Text Rating Recommended IND \\\n", + "0 Absolutely wonderful - silky and sexy and comf... 4 1 \n", + "1 Love this dress! it's sooo pretty. i happene... 5 1 \n", + "2 I had such high hopes for this dress and reall... 3 0 \n", + "3 I love, love, love this jumpsuit. it's fun, fl... 5 1 \n", + "4 This shirt is very flattering to all due to th... 5 1 \n", + "\n", + " Positive Feedback Count Division Name Department Name Class Name \n", + "0 0 Initmates Intimate Intimates \n", + "1 4 General Dresses Dresses \n", + "2 0 General Dresses Dresses \n", + "3 0 General Petite Bottoms Pants \n", + "4 6 General Tops Blouses " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import kagglehub\n", + "import pandas as pd\n", + "\n", + "# Download latest version\n", + "path = kagglehub.dataset_download(\"nicapotato/womens-ecommerce-clothing-reviews\")\n", + "df = pd.read_csv(f\"{path}/Womens Clothing E-Commerce Reviews.csv\", index_col=0)\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "87d72c68", + "metadata": {}, + "source": [ + "## ๐Ÿ—๏ธ Create a Safe Synthesizer job\n", + "\n", + "The `SafeSynthesizerBuilder` provides a fluent interface to configure and submit jobs.\n", + "\n", + "The following code creates and submits a job:\n", + "- `SafeSynthesizerBuilder(client)`: initialize with the NeMo Microservices client.\n", + "- `.from_data_source(df)`: set the input data source.\n", + "- `.with_datastore(datastore_config)`: configure model artifact storage.\n", + "- `.with_replace_pii()`: enable automatic replacement of PII.\n", + "- `.synthesize()`: train and generate synthetic data.\n", + "- `.create_job()`: submit the job to the platform.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85d9de56", + "metadata": {}, + "outputs": [], + "source": [ + "job = (\n", + " SafeSynthesizerBuilder(client)\n", + " .from_data_source(df)\n", + " .with_datastore(datastore_config)\n", + " .with_replace_pii()\n", + " .synthesize()\n", + " .create_job()\n", + ")\n", + "\n", + "print(f\"job_id = {job.job_id}\")\n", + "job.wait_for_completion()\n", + "\n", + "print(f\"Job finished with status {job.fetch_status()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa2eacb2", + "metadata": {}, + "outputs": [], + "source": [ + "# If your notebook shuts down, it's okay, your job is still running on the microservices platform.\n", + "# You can get the same job object and interact with it again by uncommenting the following code\n", + "# snippet, and modifying it with the job id from the previous cell output.\n", + "\n", + "# from nemo_microservices.beta.safe_synthesizer.sdk.job import SafeSynthesizerJob\n", + "# job = SafeSynthesizerJob(job_id=\"\", client=client)" + ] + }, + { + "cell_type": "markdown", + "id": "285d4a9d", + "metadata": {}, + "source": [ + "## ๐Ÿ‘€ View synthetic data\n", + "\n", + "After the job completes, fetch the generated synthetic dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7f25574a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Clothing IDAgeTitleReview TextRatingRecommended INDPositive Feedback CountDivision NameDepartment NameClass Name
067039Beautiful in person!This top is just beautiful in person! the colo...510GeneralTopsKnits
118649Pretty; not flatteringI love the color of this shirt, and it's very ...410GeneralTopsBlouses
238246Fabulous wool crazy sweaterThis sweater is gorgeous. the fabric goes wit...512GeneralJacketsJackets
318638Pretty fall topI really liked this, all things considered. i ...415GeneralTopsBlouses
461937NaNI ordered the pink color. it's a very pretty t...511GeneralTopsKnits
.................................
99522856White pajamas my favorite pair since when i wa...Wish they had more colors. i ordered two pair ...510General PetiteBottomsJackets
99624755Great casual shirt for any occasion!I love this shirt!! not only is it a great cas...512GeneralTopsKnits
99790264NaNNaN510General PetiteDressesDresses
998661047Amazing color and patternI love the fit of this skirt. i usually wear a...510GeneralBottomsSkirts
99939242Not happy with the finished productGreat color and style but the cut is not exact...410GeneralTopsBlouses
\n", + "

1000 rows ร— 10 columns

\n", + "
" + ], + "text/plain": [ + " Clothing ID Age Title \\\n", + "0 670 39 Beautiful in person! \n", + "1 186 49 Pretty; not flattering \n", + "2 382 46 Fabulous wool crazy sweater \n", + "3 186 38 Pretty fall top \n", + "4 619 37 NaN \n", + ".. ... ... ... \n", + "995 228 56 White pajamas my favorite pair since when i wa... \n", + "996 247 55 Great casual shirt for any occasion! \n", + "997 902 64 NaN \n", + "998 6610 47 Amazing color and pattern \n", + "999 392 42 Not happy with the finished product \n", + "\n", + " Review Text Rating \\\n", + "0 This top is just beautiful in person! the colo... 5 \n", + "1 I love the color of this shirt, and it's very ... 4 \n", + "2 This sweater is gorgeous. the fabric goes wit... 5 \n", + "3 I really liked this, all things considered. i ... 4 \n", + "4 I ordered the pink color. it's a very pretty t... 5 \n", + ".. ... ... \n", + "995 Wish they had more colors. i ordered two pair ... 5 \n", + "996 I love this shirt!! not only is it a great cas... 5 \n", + "997 NaN 5 \n", + "998 I love the fit of this skirt. i usually wear a... 5 \n", + "999 Great color and style but the cut is not exact... 4 \n", + "\n", + " Recommended IND Positive Feedback Count Division Name Department Name \\\n", + "0 1 0 General Tops \n", + "1 1 0 General Tops \n", + "2 1 2 General Jackets \n", + "3 1 5 General Tops \n", + "4 1 1 General Tops \n", + ".. ... ... ... ... \n", + "995 1 0 General Petite Bottoms \n", + "996 1 2 General Tops \n", + "997 1 0 General Petite Dresses \n", + "998 1 0 General Bottoms \n", + "999 1 0 General Tops \n", + "\n", + " Class Name \n", + "0 Knits \n", + "1 Blouses \n", + "2 Jackets \n", + "3 Blouses \n", + "4 Knits \n", + ".. ... \n", + "995 Jackets \n", + "996 Knits \n", + "997 Dresses \n", + "998 Skirts \n", + "999 Blouses \n", + "\n", + "[1000 rows x 10 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Fetch the synthetic data created by the job\n", + "synthetic_df = job.fetch_data()\n", + "synthetic_df\n" + ] + }, + { + "cell_type": "markdown", + "id": "2b25f152", + "metadata": {}, + "source": [ + "## ๐Ÿ“Š View evaluation report\n", + "\n", + "An evaluation comparing the synthetic data to the input data is performed automatically. 
You can:\n", + "\n", + "- **Inspect key scores**: overall synthetic data quality and privacy.\n", + "- **Download the full HTML report**: includes charts and detailed metrics.\n", + "- **Display the report inline**: useful when viewing in notebook environments.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7b691127", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Synthetic data quality score (0-10, higher is better): 8.9\n", + "Data privacy score (0-10, higher is better): 8.5\n" + ] + } + ], + "source": [ + "# Print selected information from the job summary\n", + "summary = job.fetch_summary()\n", + "print(\n", + " f\"Synthetic data quality score (0-10, higher is better): {summary.synthetic_data_quality_score}\"\n", + ")\n", + "print(f\"Data privacy score (0-10, higher is better): {summary.data_privacy_score}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "39e62ea9", + "metadata": {}, + "outputs": [], + "source": [ + "# Download the full evaluation report to your local machine\n", + "job.save_report(\"evaluation_report.html\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "45f7e22b", + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch and display the full evaluation report inline\n", + "# job.display_report_in_notebook()" + ] + }, + { + "cell_type": "markdown", + "id": "dd1e4925-3620-4b31-bc17-16f74d10fbb5", + "metadata": {}, + "source": [ + "## ๐Ÿงช Extrinsic Evaluation \n", + "\n", + "This section details the **extrinsic evaluation** process, where the quality of the synthetic data is assessed based on how well a model trained on it performs on a real-world task. This comparison is critical for validating the synthetic data's utility.\n", + "\n", + "- **Train Benchmark Model**: A model is trained on a small, fixed subset of the **original data** to establish a performance baseline.\n", + "- **Train Synthetic Model**: A second model, using the same structure, is trained on the **entire synthetic dataset**.\n", + "- **Compare Performance**: Both models are evaluated against the same **fixed holdout test set** ($\\mathbf{X_{test}, y_{test}}$).\n", + "- **Inspect Key Metrics**: The comparison focuses on key metrics like **ROC AUC** and **F1-Score** to determine if the synthetic model performs comparably to the benchmark." + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "37b6df30-6627-4a40-8604-e905ada571b7", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.metrics import classification_report, accuracy_score, roc_auc_score\n", + "from sklearn.base import clone\n", + "\n", + "# --- 1. Define Features (X) and Target (y) ---\n", + "\n", + "# Separate features (X) from the binary target variable (y).\n", + "X = df.drop('Recommended IND', axis=1)\n", + "y = df['Recommended IND']\n", + "\n", + "# Fill any missing text values with an empty string to prevent TfidfVectorizer errors.\n", + "X['Review Text'] = X['Review Text'].fillna('')\n", + "X['Title'] = X['Title'].fillna('')\n", + "\n", + "# --- 2. 
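OPTIONAL PRE-SPLIT SANITY CHECK (added sketch) ---\n",
+    "\n",
+    "# A minimal, illustrative check (not part of the original workflow): confirm there are\n",
+    "# enough rows for a non-overlapping 1000/500 split and peek at the overall class balance.\n",
+    "assert len(X) >= 1500, 'Need at least 1500 rows for a 1000 train / 500 test split'\n",
+    "print(y.value_counts(normalize=True).round(3))\n",
+    "\n",
+    "# --- 2. 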
Fixed-Size Non-Overlapping Data Split (1000 Train, 500 Test) ---\n", + "\n", + "# Sample 1000 random indices for the training set.\n", + "train_indices = X.sample(n=1000, random_state=1).index\n", + "X_train = X.loc[train_indices]\n", + "y_train = y.loc[train_indices]\n", + "\n", + "# Create a temporary DataFrame containing only the remaining, unused rows.\n", + "remaining_X = X.drop(train_indices)\n", + "\n", + "# Sample 500 random indices from the remaining data for the holdout test set.\n", + "test_indices = remaining_X.sample(n=500, random_state=1).index\n", + "X_test = X.loc[test_indices]\n", + "y_test = y.loc[test_indices]\n", + "\n", + "# --- 3. Define Feature Types for Preprocessing ---\n", + "\n", + "# List the columns for each type of transformation.\n", + "text_features = ['Review Text']\n", + "numerical_features = ['Age', 'Rating', 'Positive Feedback Count']\n", + "categorical_features = ['Division Name', 'Department Name', 'Class Name']\n", + "\n", + "# --- 4. Define Preprocessing Steps ---\n", + "\n", + "# Text: Convert text to numerical vectors (max 5000 terms).\n", + "text_transformer = TfidfVectorizer(stop_words='english', max_features=5000)\n", + "# Numerical: Standardize numerical features (mean=0, std=1).\n", + "numerical_transformer = StandardScaler()\n", + "# Categorical: One-hot encode categories, ignoring unseen categories in the test set.\n", + "categorical_transformer = OneHotEncoder(handle_unknown='ignore')\n", + "\n", + "# --- 5. Create Column Transformer (The Preprocessor) ---\n", + "\n", + "# Apply different transformations to the respective column groups.\n", + "preprocessor = ColumnTransformer(\n", + " transformers=[\n", + " # TfidfVectorizer only takes one column input.\n", + " ('text', text_transformer, text_features[0]),\n", + " ('num', numerical_transformer, numerical_features),\n", + " ('cat', categorical_transformer, categorical_features)\n", + " ],\n", + " # Exclude non-feature columns (like 'Clothing ID' and 'Title') from the model input.\n", + " remainder='drop'\n", + ")\n", + "\n", + "# --- 6. Define the Classifier and Full Pipeline ---\n", + "\n", + "# Choose the classification algorithm (Logistic Regression for recommendation prediction).\n", + "model = LogisticRegression(solver='liblinear', random_state=42)\n", + "\n", + "# Build the final Pipeline: Preprocessing followed by Classification.\n", + "full_pipeline = Pipeline(steps=[\n", + " ('preprocessor', preprocessor),\n", + " ('classifier', model)\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "ee747c80-d42f-4ec5-b27b-2b2462436b92", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Training Benchmark Model on Original Data (1000 rows) ---\n", + "Benchmark training and evaluation complete.\n" + ] + } + ], + "source": [ + "# Assuming the full_pipeline structure (preprocessor + classifier) has been defined.\n", + "\n", + "# --- 1. TRAIN BENCHMARK MODEL ON ORIGINAL DATA ---\n", + "\n", + "# Assign the pipeline to a specific name for the original model.\n", + "# NOTE: In a real comparison, you should clone the pipeline here to ensure the original structure is used without modification.\n", + "original_pipeline = full_pipeline \n", + "print(\"\\n--- Training Benchmark Model on Original Data (1000 rows) ---\")\n", + "# Train the pipeline using the fixed, 1000-row original training set.\n", + "original_pipeline.fit(X_train, y_train)\n", + "\n", + "# --- 2. 
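OPTIONAL TRAIN-SET FIT CHECK (added sketch) ---\n",
+    "\n",
+    "# A minimal, illustrative check (not part of the original workflow): the benchmark's\n",
+    "# accuracy on its own training data gives a rough sense of overfitting relative to the\n",
+    "# held-out metrics computed below.\n",
+    "train_accuracy = accuracy_score(y_train, original_pipeline.predict(X_train))\n",
+    "print(f'Benchmark accuracy on its own training data: {train_accuracy:.4f}')\n",
+    "\n",
+    "# --- 2. 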
EVALUATE BENCHMARK MODEL ON FIXED TEST SET ---\n", + "\n", + "# Predict class labels on the 500-row fixed holdout test set.\n", + "y_pred_original = original_pipeline.predict(X_test)\n", + "# Predict probabilities and extract the probability for the positive class (Recommended=1), needed for ROC AUC.\n", + "y_prob_original = original_pipeline.predict_proba(X_test)[:, 1]\n", + "\n", + "# --- 3. STORE BENCHMARK RESULTS ---\n", + "\n", + "# Initialize a dictionary to store all model comparison results.\n", + "results = {}\n", + "results['Original'] = {\n", + " # Calculate and store standard metrics using the fixed test set results.\n", + " 'Accuracy': accuracy_score(y_test, y_pred_original),\n", + " 'ROC AUC': roc_auc_score(y_test, y_prob_original),\n", + " # Store the full classification report as a dictionary for detailed metric access (Precision, Recall, F1).\n", + " 'Classification Report': classification_report(y_test, y_pred_original, output_dict=True)\n", + "}\n", + "print(\"Benchmark training and evaluation complete.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "cf3f1d59-8c46-4d84-b813-a4adf88a3422", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Training Model on Synthetic Data ---\n", + "Synthetic training and evaluation complete.\n" + ] + } + ], + "source": [ + "# Prepare the synthetic data features and target\n", + "# Extract features (X_synthetic) and fill missing text values.\n", + "X_synthetic = synthetic_df.drop('Recommended IND', axis=1).fillna({'Review Text': '', 'Title': ''})\n", + "# Extract the target variable (y_synthetic).\n", + "y_synthetic = synthetic_df['Recommended IND']\n", + "\n", + "# Clone the original pipeline structure (preprocessor and classifier) to ensure a clean, unfitted start.\n", + "synthetic_pipeline = clone(full_pipeline) \n", + "\n", + "# ----------------- TRAIN ON SYNTHETIC DATA -----------------\n", + "print(\"\\n--- Training Model on Synthetic Data ---\")\n", + "# Fit the pipeline using the entire synthetic dataset.\n", + "synthetic_pipeline.fit(X_synthetic, y_synthetic)\n", + "\n", + "# --- 2. EVALUATE SYNTHETIC MODEL ON FIXED TEST SET ---\n", + "\n", + "# Make predictions using the synthetic-trained model against the fixed 500-row holdout set (X_test).\n", + "y_pred_synthetic = synthetic_pipeline.predict(X_test)\n", + "# Extract the probability for the positive class (Recommended=1) for ROC AUC calculation.\n", + "y_prob_synthetic = synthetic_pipeline.predict_proba(X_test)[:, 1]\n", + "\n", + "# --- 3. 
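OPTIONAL TARGET-BALANCE CHECK (added sketch) ---\n",
+    "\n",
+    "# A minimal, illustrative check (not part of the original workflow): compare the class\n",
+    "# balance of the synthetic target with the original training target before reading the\n",
+    "# side-by-side metrics.\n",
+    "print('Original train target balance:', y_train.value_counts(normalize=True).round(3).to_dict())\n",
+    "print('Synthetic target balance:', y_synthetic.value_counts(normalize=True).round(3).to_dict())\n",
+    "\n",
+    "# --- 3. 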
STORE SYNTHETIC RESULTS ---\n", + "\n", + "# Store the performance metrics in the results dictionary for side-by-side comparison.\n", + "results['Synthetic'] = {\n", + " 'Accuracy': accuracy_score(y_test, y_pred_synthetic),\n", + " 'ROC AUC': roc_auc_score(y_test, y_prob_synthetic),\n", + " # Store the detailed classification report as a dictionary.\n", + " 'Classification Report': classification_report(y_test, y_pred_synthetic, output_dict=True)\n", + "}\n", + "print(\"Synthetic training and evaluation complete.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "d83e681e-aac2-44d0-83cb-1d93002a725d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "============================================================\n", + " SIDE-BY-SIDE MODEL COMPARISON\n", + " (Tested on 500-Row Holdout Set)\n", + "============================================================\n", + "| | Original (Benchmark) | Synthetic |\n", + "|:--------------------|-----------------------:|------------:|\n", + "| Train Size | 1000.0000 | 1000.0000 |\n", + "| Accuracy | 0.9480 | 0.9380 |\n", + "| ROC AUC Score | 0.9776 | 0.9821 |\n", + "| Precision (Class 1) | 0.9640 | 0.9505 |\n", + "| Recall (Class 1) | 0.9734 | 0.9758 |\n", + "\n", + "============================================================\n", + "Key Finding:\n", + "The Synthetic Model performs AS WELL OR BETTER than the Original Benchmark.\n" + ] + } + ], + "source": [ + "print(\"\\n\" + \"=\"*60)\n", + "print(\" SIDE-BY-SIDE MODEL COMPARISON\")\n", + "print(\" (Tested on 500-Row Holdout Set)\")\n", + "print(\"=\"*60)\n", + "\n", + "# --- 1. PREPARE COMPARISON DATA ---\n", + "\n", + "# Compile key performance metrics from both the Original and Synthetic models into a single dictionary.\n", + "summary_data = {\n", + " 'Model': ['Original (Benchmark)', 'Synthetic'],\n", + " # Record the size of the training data used for each model.\n", + " 'Train Size': [len(X_train), len(X_synthetic)],\n", + " # Gather Accuracy and ROC AUC score for overall performance.\n", + " 'Accuracy': [results['Original']['Accuracy'], results['Synthetic']['Accuracy']],\n", + " 'ROC AUC Score': [results['Original']['ROC AUC'], results['Synthetic']['ROC AUC']],\n", + " # Include key classification metrics (Precision and Recall) for the positive class (Recommended=1).\n", + " 'Precision (Class 1)': [results['Original']['Classification Report']['1']['precision'], results['Synthetic']['Classification Report']['1']['precision']],\n", + " 'Recall (Class 1)': [results['Original']['Classification Report']['1']['recall'], results['Synthetic']['Classification Report']['1']['recall']],\n", + "}\n", + "\n", + "# --- 2. DISPLAY SUMMARY TABLE ---\n", + "\n", + "# Convert the summary dictionary to a pandas DataFrame for structured display.\n", + "summary_df = pd.DataFrame(summary_data).set_index('Model').T\n", + "# Label the row index of the transposed DataFrame as 'Metric'.\n", + "summary_df.columns.name = 'Metric'\n", + "\n", + "# Print the comparison table in Markdown format for clean terminal output, formatting metrics to 4 decimal places.\n", + "print(summary_df.to_markdown(floatfmt=\".4f\"))\n", + "\n", + "print(\"\\n\" + \"=\"*60)\n", + "\n", + "# --- 3. 
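OPTIONAL F1-SCORE COMPARISON (added sketch) ---\n",
+    "\n",
+    "# The extrinsic-evaluation description above mentions F1-score; it is available in the\n",
+    "# stored classification reports, so this illustrative addition surfaces it next to the table.\n",
+    "f1_original = results['Original']['Classification Report']['1']['f1-score']\n",
+    "f1_synthetic = results['Synthetic']['Classification Report']['1']['f1-score']\n",
+    "print(f'F1 (Class 1) - Original: {f1_original:.4f} | Synthetic: {f1_synthetic:.4f}')\n",
+    "\n",
+    "# --- 3. 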
INTERPRETATION ---\n", + "\n", + "# Provide a simple, text-based conclusion based on the most robust metric (ROC AUC).\n", + "print(\"Key Finding:\")\n", + "# Compare ROC AUC scores to determine if the synthetic data generalized as well as the original data.\n", + "if results['Synthetic']['ROC AUC'] >= results['Original']['ROC AUC']:\n", + " print(\"The Synthetic Model performs AS WELL OR BETTER than the Original Benchmark.\")\n", + "else:\n", + " print(\"The Synthetic Model's performance is slightly lower than the Original Benchmark.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9b55961-ddac-4d91-aa4d-9646fb72c7be", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "My Virtual Env", + "language": "python", + "name": "myenv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}