## Loading libraries

In [None]:
import pandas as pd
import numpy as np
import json
from datetime import datetime
from pathlib import Path
from tqdm.auto import tqdm
import os
from itertools import combinations

from utils.text_encodertext_encoder import TextEncoder
from BN_generator import BNGenerator, TextDecoder

## Evaluation Functions

In [2]:
def _cell_proportions(pop_counts, syn_counts):
    """
    Aligns two cell-count Series on their index and converts them to proportions.

    Parameters:
      - pop_counts: Cell counts of the population (Series), or None when the
        population has no usable observations for these cells.
      - syn_counts: Cell counts of the synthetic data (Series), or None.

    Returns:
      - (pop_prop, syn_prop): Two equal-length float arrays over the union of
        cells seen on either side. A side without observations contributes
        zeros for every cell.
    """
    pop_total = 0 if pop_counts is None else int(pop_counts.sum())
    syn_total = 0 if syn_counts is None else int(syn_counts.sum())

    if pop_total == 0 and syn_total == 0:
        # Nothing observed on either side: no cells to compare.
        return np.array([], dtype=float), np.array([], dtype=float)
    if syn_total == 0:
        pop_prop = pop_counts.to_numpy(dtype=float) / pop_total
        return pop_prop, np.zeros_like(pop_prop)
    if pop_total == 0:
        syn_prop = syn_counts.to_numpy(dtype=float) / syn_total
        return np.zeros_like(syn_prop), syn_prop

    # Outer join keeps cells that appear in only one of the two datasets;
    # their missing count on the other side becomes 0.
    tab = pd.merge(
        syn_counts.rename('syn'),
        pop_counts.rename('pop'),
        left_index=True,
        right_index=True,
        how='outer',
    ).fillna(0)
    pop_prop = tab['pop'].to_numpy(dtype=float) / pop_total
    syn_prop = tab['syn'].to_numpy(dtype=float) / syn_total
    return pop_prop, syn_prop


def _pair_counts(frame, col1, col2):
    """Returns the stacked crosstab of two columns, or None if no complete rows exist."""
    pair = frame[[col1, col2]].dropna()
    if pair.empty:
        return None
    # crosstab keeps zero-count cells, so stack() yields the full grid of
    # observed col1 values x observed col2 values.
    return pd.crosstab(pair[col1], pair[col2]).stack()


def _standardized_rmse(pop_parts, syn_parts):
    """Returns RMSE between the concatenated proportion vectors divided by the mean population proportion."""
    if not pop_parts:
        return np.nan
    pop_all = np.concatenate(pop_parts)
    syn_all = np.concatenate(syn_parts)
    if pop_all.size == 0:
        return np.nan
    rmse = np.linalg.norm(pop_all - syn_all) / np.sqrt(len(pop_all))
    ybar = pop_all.mean()
    return rmse / ybar if ybar != 0 else np.nan


def SRMSE(x_population, resamples):
    """
    Calculates Standardized Root Mean Square Error (SRMSE).
    - Computes SRMSE for marginal and bivariate distributions.
    - NA values are dropped per column (marginal) or per column pair
      (bivariate) before counting.
    - A column / pair that has no non-NA rows on one side is treated as
      having proportion 0 in every cell on that side (previously this case
      raised inside np.concatenate because a scalar 0 was appended).

    Parameters:
      - x_population: Original population DataFrame.
      - resamples: Generated synthetic population DataFrame. Must contain
        every column of x_population.

    Returns:
      - [srmse_mar, srmse_bi]: Marginal SRMSE and Bivariate SRMSE. Each is
        NaN when there are no cells to compare or the mean population
        proportion is 0.
    """
    ## Marginal distribution calculation
    pop_marg = []
    syn_marg = []
    for col in x_population.columns:
        pop_counts = x_population[col].dropna().value_counts()
        syn_counts = resamples[col].dropna().value_counts()
        pop_prop, syn_prop = _cell_proportions(pop_counts, syn_counts)
        pop_marg.append(pop_prop)
        syn_marg.append(syn_prop)

    srmse_mar = _standardized_rmse(pop_marg, syn_marg)

    ## Bivariate distribution calculation
    pop_bi = []
    syn_bi = []
    for col1, col2 in combinations(x_population.columns, 2):
        pop_counts = _pair_counts(x_population, col1, col2)
        syn_counts = _pair_counts(resamples, col1, col2)
        pop_prop, syn_prop = _cell_proportions(pop_counts, syn_counts)
        pop_bi.append(pop_prop)
        syn_bi.append(syn_prop)

    srmse_bi = _standardized_rmse(pop_bi, syn_bi)

    return [srmse_mar, srmse_bi]

def calculate_precision_recall(population_df, generated_df):
    """
    Calculates Precision and Recall.
    - Compares rows as tuples over the population's columns.
    - Rows containing any NA in those columns are ignored on both sides.

    Parameters:
      - population_df: Original population DataFrame.
      - generated_df: Generated synthetic population DataFrame. Must contain
        every column of population_df.

    Returns:
      - dict: Precision, Recall, F1 Score, unique counts, and matching combination info.
        All numbers are plain Python int / float so the result can be
        serialized with json.dumps directly. Precision / recall are 0.0
        when the corresponding side has no complete rows.
    """
    all_cols = population_df.columns.tolist()

    # Drop rows with any NAs, then turn each remaining row into a plain tuple.
    # itertuples is much cheaper than DataFrame.apply(tuple, axis=1) and,
    # unlike apply, behaves sensibly on an empty frame.
    pop_rows = list(
        population_df.dropna(subset=all_cols)[all_cols].itertuples(index=False, name=None)
    )
    gen_rows = list(
        generated_df.dropna(subset=all_cols)[all_cols].itertuples(index=False, name=None)
    )

    pop_set = set(pop_rows)
    gen_set = set(gen_rows)

    # Precision: Fraction of generated data present in the original data
    gen_hits = sum(1 for profile in gen_rows if profile in pop_set)
    precision = round(gen_hits / len(gen_rows), 4) if gen_rows else 0.0

    # Recall: Fraction of original data present in the generated data
    pop_hits = sum(1 for profile in pop_rows if profile in gen_set)
    recall = round(pop_hits / len(pop_rows), 4) if pop_rows else 0.0

    # F1 is derived from the rounded precision / recall, as before.
    f1_score = round(2 * (precision * recall) / (precision + recall), 4) if (precision + recall) > 0 else 0.0

    # Unique matching combinations (types, duplicates removed)
    matching_unique_combinations = pop_set & gen_set

    # Unique combination counts
    unique_combinations = {
        'population': len(pop_set),
        'generated': len(gen_set)
    }

    # Matching combination info
    matching_combinations = {
        'unique_types': len(matching_unique_combinations),  # Number of unique matching types
        # Total generated rows whose combination also exists in the population;
        # this is exactly the numerator of precision.
        'total_count': int(gen_hits)
    }

    return {
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1_score),
        'unique_combinations': unique_combinations,
        'matching_combinations': matching_combinations
    }

def evaluate_synthetic_population(population_csv, generated_csv):
    """
    Evaluates synthetic population data.
    - Reads CSVs; assumes NA handling happens within metric functions or beforehand.
    - Restricts both datasets to a common, identically ordered set of columns.
    - Prints a human-readable report while computing the metrics.

    Parameters:
      - population_csv: Path to the original population data CSV.
      - generated_csv: Path to the generated synthetic population data CSV.

    Returns:
      - dict with keys 'srmse_marginal', 'srmse_bivariate', 'precision',
        'recall', 'f1_score', 'unique_combinations', 'matching_combinations'.
    """
    print(f"Loading data from {population_csv} and {generated_csv}...")

    # Load CSVs: Assuming default NA handling for now.
    population_df = pd.read_csv(population_csv)
    generated_df = pd.read_csv(generated_csv)

    # Ensure columns match. Only the *set* of names is compared here; column
    # order is normalized in both branches below.
    if set(population_df.columns) != set(generated_df.columns):
        print("Warning: Column names or order don't match between datasets")
        print(f"Population columns: {population_df.columns.tolist()}")
        print(f"Generated columns: {generated_df.columns.tolist()}")
        common_cols = sorted(list(set(population_df.columns) & set(generated_df.columns))) # Sort for consistency
        population_df = population_df[common_cols]
        generated_df = generated_df[common_cols]
        print(f"Using common columns (sorted): {common_cols}")
    else:
        # Ensure consistent column order even if names match
        generated_df = generated_df[population_df.columns]

    # Data Summary
    print("\n=== Data Summary ===")
    print(f"Population data: {len(population_df):,} rows, {len(population_df.columns)} columns")
    print(f"Generated data: {len(generated_df):,} rows, {len(generated_df.columns)} columns")

    # Calculate SRMSE
    print("\nCalculating SRMSE...")
    srmse_results = SRMSE(population_df, generated_df)
    print(f"SRMSE for marginal distributions: {srmse_results[0]:.4f}")
    print(f"SRMSE for bivariate distributions: {srmse_results[1]:.4f}")

    # Calculate Precision and Recall
    print("\nCalculating precision and recall...")
    pr_metrics = calculate_precision_recall(population_df, generated_df)
    unique_info = pr_metrics['unique_combinations']
    match_info = pr_metrics['matching_combinations']

    print("\n=== Metrics ===")
    print(f"Precision: {pr_metrics['precision']:.4f}")
    print(f"Recall: {pr_metrics['recall']:.4f}")
    print(f"F1 Score: {pr_metrics['f1_score']:.4f}")

    print("\n=== Unique Combinations Analysis ===")
    print(f"Population unique combinations: {unique_info['population']:,}")
    print(f"Generated unique combinations: {unique_info['generated']:,}")

    print("\n=== Matching Combinations Analysis ===")
    # Only complete (NA-free) generated rows take part in the comparison.
    # Computed once: the full-frame dropna is expensive on large data and was
    # previously repeated three times.
    complete_generated_rows = len(generated_df.dropna(subset=generated_df.columns))

    # Precision Perspective
    print(f"Precision Perspective) Total complete generated rows: {complete_generated_rows:,}")
    print(f"Precision Perspective) Number of rows matching original: {match_info['total_count']:,}")
    if complete_generated_rows > 0:
        precision_matching_ratio = match_info['total_count'] / complete_generated_rows * 100
    else:
        precision_matching_ratio = 0
    print(f"Precision Perspective) Ratio of matching rows: {precision_matching_ratio:.2f}%")

    # Recall Perspective
    print(f"Recall Perspective) Total unique generated combinations: {unique_info['generated']:,}")
    print(f"Recall Perspective) Number of unique combinations matching original: {match_info['unique_types']:,}")
    if unique_info['generated'] > 0:
        recall_matching_ratio = match_info['unique_types'] / unique_info['generated'] * 100
    else:
        recall_matching_ratio = 0
    print(f"Recall Perspective) Ratio of matching unique combinations: {recall_matching_ratio:.2f}%")

    return {
        'srmse_marginal': srmse_results[0],
        'srmse_bivariate': srmse_results[1],
        'precision': pr_metrics['precision'],
        'recall': pr_metrics['recall'],
        'f1_score': pr_metrics['f1_score'],
        'unique_combinations': unique_info,
        'matching_combinations': match_info
    }

## Main Inference Execution

In [4]:
import os
import json
import time
import io
import contextlib
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from tqdm import tqdm


# ═══════════════════════════════════════════════════════════════════════════
# Helper Utilities
# ═══════════════════════════════════════════════════════════════════════════

def estimate_tokens(texts: List[str]) -> int:
    """Approximate the token count of *texts* as the number of whitespace-separated words."""
    word_total = 0
    for text in texts:
        word_total += len(text.split())
    return word_total


def append_and_flush_json(path: Path, texts: List[str]):
    """Rewrite *path* with the full list of texts as pretty-printed JSON.

    Despite the name, nothing is appended: the complete list is written to a
    sibling ``.tmp`` file first and then moved over *path*, so an interrupted
    write does not leave a half-written target file behind.
    """
    staging_path = path.with_suffix(".tmp")
    payload = json.dumps(texts, ensure_ascii=False, indent=2)
    staging_path.write_text(payload, encoding="utf-8")
    staging_path.replace(path)


# ═══════════════════════════════════════════════════════════════════════════
# Configuration (style kept identical)
# ═══════════════════════════════════════════════════════════════════════════

# Evaluation parameters
# Every (epoch, temperature) pair below is generated and evaluated once.
# Each epoch is loaded from "saved_models/LLM-BN_distilgpt2_epoch{epoch}".
epochs_to_evaluate = [35]
temperatures = [0.70]
# Number of samples requested per generator.generate() call.
batch_size = 1000

# Reference population. It is loaded here only to determine how many synthetic
# rows to produce; the evaluation step re-reads the CSV from this path.
h_population_path = "h_population.csv"
print("Loading data...")
h_population = pd.read_csv(h_population_path).astype(str)
target_n_rows = len(h_population)

# Text passed as `condition` to generator.generate() for every sample.
detailed_prompt = "Education status is"

# Root folder for all outputs; one sub-folder per (epoch, temperature) pair.
base_output_dir = Path("generated_data/LLM-BN_distilgpt2")
base_output_dir.mkdir(parents=True, exist_ok=True)

# One metrics dict per evaluated combination, in processing order.
all_metrics = []
# Converts generated texts into a DataFrame (via decode_texts).
decoder = TextDecoder()

# Progress counters used only for the "Combination i/N" banner.
total_combinations = len(epochs_to_evaluate) * len(temperatures)
current_combination = 0

# Running totals across ALL combinations (wall-clock seconds spent inside
# generator.generate() and whitespace-token estimate of its output).
global_inference_time = 0.0
global_tokens = 0

# ═══════════════════════════════════════════════════════════════════════════
# Main Loop
# ═══════════════════════════════════════════════════════════════════════════
# For every (epoch, temperature) pair: generate target_n_rows clean synthetic
# rows (resumable via the JSON text dumps), save them, and evaluate them
# against the reference population.
for epoch in epochs_to_evaluate:
    model_path = f"saved_models/LLM-BN_distilgpt2_epoch{epoch}"
    print("\n" + "="*50)
    print(f"===== Model Load: epoch {epoch} =====")
    print("="*50)
    generator = BNGenerator(model_name="distilgpt2")

    try:
        generator.load_model(model_path)
        print(f"Model loaded successfully: {model_path}")
    except Exception as e:
        # A missing / broken checkpoint skips all temperatures for this epoch.
        print(f"Failed to load model: {model_path}\nError: {e}")
        continue

    for temp in temperatures:
        current_combination += 1
        print("\n" + "-"*50)
        print(f"----- Combination {current_combination}/{total_combinations}: epoch{epoch}, temperature={temp} -----")
        print("-"*50)

        output_dir = base_output_dir / f"epoch{epoch}_temp{temp:.2f}"
        output_dir.mkdir(parents=True, exist_ok=True)
        texts_path_1 = output_dir / "generated_texts_initial.json"
        csv_path_1 = output_dir / "generated_data_initial.csv"
        final_csv_path = output_dir / "generated_data_final.csv"
        metrics_path = output_dir / "evaluation_metrics.json"

        # A complete final CSV from an earlier run means there is nothing to do
        # for this combination; reuse its stored metrics if they exist.
        if final_csv_path.exists():
            try:
                if len(pd.read_csv(final_csv_path)) >= target_n_rows:
                    print("Existing final dataset found. Skipping generation.")
                    if metrics_path.exists():
                        all_metrics.append(json.loads(metrics_path.read_text(encoding="utf-8")))
                    continue
            except Exception:
                print("Existing CSV unreadable; regenerating.")

        # ── Stage 1: Initial Generation ────────────────────────────────────
        print(f"\nStarting initial sample generation... Target: {target_n_rows} rows")
        generated_texts_1: List[str] = []
        if texts_path_1.exists():
            generated_texts_1 = json.loads(texts_path_1.read_text(encoding="utf-8"))
            print(f"Loaded {len(generated_texts_1)} texts from previous run — resuming.")

        with tqdm(total=target_n_rows, desc="Generating (initial)", dynamic_ncols=True) as bar:
            # Start the bar at the number of texts recovered from disk.
            bar.update(len(generated_texts_1))
            while len(generated_texts_1) < target_n_rows:
                cur = min(batch_size, target_n_rows - len(generated_texts_1))
                start = time.perf_counter()
                # Swallow anything the generator prints to stdout so it does
                # not interleave with the progress bar.
                with io.StringIO() as buf, contextlib.redirect_stdout(buf):
                    batch = generator.generate(
                        n_samples=cur,
                        condition=detailed_prompt,
                        max_length=512,
                        temperature=temp,
                        batch_size=cur,
                    )
                elapsed = time.perf_counter() - start
                global_inference_time += elapsed
                global_tokens += estimate_tokens(batch)

                generated_texts_1.extend(batch)
                # Persist after every batch so an interrupted run can resume.
                append_and_flush_json(texts_path_1, generated_texts_1)
                bar.update(cur)

        print("\nFirst generated text sample (initial):\n" + "-"*30)
        if generated_texts_1:
            print(generated_texts_1[0])
        else:
            print("No text generated.")
        print("-"*30)

        print("Decoding initial texts...")
        generated_df_1 = decoder.decode_texts(generated_texts_1)
        generated_df_1.to_csv(csv_path_1, index=False)
        print(f"Initial CSV saved: {csv_path_1}")

        if not generated_df_1.empty:
            missing_counts_1 = generated_df_1.isnull().sum()
            print("\nMissing value analysis (initial):")
            print(missing_counts_1[missing_counts_1 > 0].to_string())
            print(f"Total missing values: {generated_df_1.isnull().sum().sum()}")
            print(f"Rows with missing values: {generated_df_1.isnull().any(axis=1).sum()}")
        else:
            print("Initial DataFrame is empty.")

        # Only rows with every attribute decoded count towards the target.
        df_clean_1 = generated_df_1.dropna()
        df_final = df_clean_1.copy()

        # ── Stage 2: Additional Generation ─────────────────────────────────
        if len(df_final) < target_n_rows:
            needed = target_n_rows - len(df_final)
            print(f"\nNeed {needed} more clean rows.")
            texts_path_2 = output_dir / "generated_texts_additional.json"
            extra_texts: List[str] = []
            if texts_path_2.exists():
                extra_texts = json.loads(texts_path_2.read_text(encoding="utf-8"))
                print(f"Loaded {len(extra_texts)} extra texts from previous run — resuming.")
                # Resume fix: the recovered texts used to be loaded but never
                # decoded, so their clean rows were thrown away and generated
                # again. Count them towards the target before generating more.
                if extra_texts:
                    resumed_df = decoder.decode_texts(extra_texts).dropna()
                    df_final = pd.concat([df_final, resumed_df], ignore_index=True)
                    print(f"Clean rows accumulated: {len(df_final)} / {target_n_rows}")

            while len(df_final) < target_n_rows:
                cur_needed = target_n_rows - len(df_final)
                # Over-generate by 20% (but at least one full batch) because
                # some of the new texts will again decode to incomplete rows.
                n_gen = max(int(cur_needed * 1.2), batch_size)
                print(f"Generating {n_gen} additional samples…")
                start = time.perf_counter()
                with io.StringIO() as buf, contextlib.redirect_stdout(buf):
                    new_batch = generator.generate(
                        n_samples=n_gen,
                        condition=detailed_prompt,
                        max_length=512,
                        temperature=temp,
                        batch_size=batch_size,
                    )
                elapsed = time.perf_counter() - start
                global_inference_time += elapsed
                global_tokens += estimate_tokens(new_batch)

                extra_texts.extend(new_batch)
                append_and_flush_json(texts_path_2, extra_texts)

                new_df = decoder.decode_texts(new_batch).dropna()
                df_final = pd.concat([df_final, new_df], ignore_index=True)
                print(f"Clean rows accumulated: {len(df_final)} / {target_n_rows}")

            # Trim the overshoot so the final dataset has exactly the target size.
            df_final = df_final.head(target_n_rows)
        else:
            print("\nNo additional generation needed.")
            df_final = df_final.head(target_n_rows)

        final_rows = len(df_final)
        print(f"\nFinal dataset prepared. Rows: {final_rows} (Target: {target_n_rows})")
        df_final.to_csv(final_csv_path, index=False)
        print(f"Final CSV saved: {final_csv_path}")

        print("\nEvaluating…")
        try:
            metrics = evaluate_synthetic_population(h_population_path, final_csv_path)
            metrics.update({
                "epoch": epoch,
                "temperature": temp,
                "target_data_rows": target_n_rows,
                "final_data_rows": final_rows,
                "initial_generated_rows": len(generated_df_1),
                "initial_clean_rows": len(df_clean_1),
                "initial_missing_rate": (len(generated_df_1) - len(df_clean_1)) / len(generated_df_1) if len(generated_df_1) else 0,
                # Timing / token figures are cumulative over all combinations
                # processed so far in this run, not per combination.
                "cumulative_inference_time_sec": round(global_inference_time, 2),
                "cumulative_token_count": global_tokens,
                "throughput_tokens_per_sec": round(global_tokens / global_inference_time, 2) if global_inference_time else np.nan,
            })
            all_metrics.append(metrics)
            metrics_path.write_text(json.dumps(metrics, indent=2, default=str), encoding="utf-8")
            print(f"Metrics saved: {metrics_path}")

            print("\nEvaluation Summary:")
            print(f"Epoch {epoch}, Temp {temp} — F1: {metrics['f1_score']:.4f}, SRMSE_M: {metrics['srmse_marginal']:.4f}, SRMSE_B: {metrics['srmse_bivariate']:.4f}")
            print(f"Throughput: {metrics['throughput_tokens_per_sec']} tok/s over {metrics['cumulative_inference_time_sec']} s")
        except Exception as e:
            # Evaluation problems must not abort the remaining combinations;
            # the generated data is already on disk at this point.
            print(f"Evaluation failed: {e}")
            import traceback; traceback.print_exc()

# ═══════════════════════════════════════════════════════════════════════════
# Final Summary
# ═══════════════════════════════════════════════════════════════════════════
# Collect every per-combination metrics dict into one table, show it, and
# persist it both as JSON (full detail) and CSV (flat summary).
metrics_table = pd.DataFrame(all_metrics)

if len(all_metrics) > 0:
    print("\nSummary of results collected so far:")
    ordered_table = metrics_table.sort_values(by=["epoch", "temperature"])
    print(ordered_table.to_string())

summary_json_path = base_output_dir / "all_evaluation_metrics.json"
summary_json_text = json.dumps(all_metrics, indent=2, default=str)
summary_json_path.write_text(summary_json_text, encoding="utf-8")

# The CSV keeps the rows in processing order (not sorted).
summary_csv_path = base_output_dir / "evaluation_summary.csv"
metrics_table.to_csv(summary_csv_path, index=False)

print("\n===== Evaluation Complete =====")
print(f"Summary saved to: {summary_csv_path}")

Loading data...

===== Model Load: epoch 35 =====
Using device: cuda
Model loaded from saved_models/LLM-BN_distilgpt2_epoch35
Model loaded successfully: saved_models/LLM-BN_distilgpt2_epoch35

--------------------------------------------------
----- Combination 1/1: epoch35, temperature=0.7 -----
--------------------------------------------------

Starting initial sample generation... Target: 1066319 rows
Loaded 2000 texts from previous run — resuming.


Generating (initial):   0%|                                                                | 0/1066319 [00:00<?, ?it/s]
[Aal Progress:   0%|                                                                         | 0/1000 [00:00<?, ?it/s]
[Aal Progress (Batch 1/1):   0%|                                                             | 0/1000 [00:00<?, ?it/s]
Total Progress (Batch 1/1): 100%|██████████████████████████████████████████████████| 1000/1000 [00:22<00:00, 45.24it/s]
Generating (initial):   0%|▏                                                 | 3000/1066319 [00:22<2:10:43, 135.57it/s]
[Aal Progress:   0%|                                                                         | 0/1000 [00:00<?, ?it/s]
[Aal Progress (Batch 1/1):   0%|                                                             | 0/1000 [00:00<?, ?it/s]
Total Progress (Batch 1/1): 100%|██████████████████████████████████████████████████| 1000/1000 [00:17<00:00, 56.00it/s]
Generating (initial):   0%|▏            


First generated text sample (initial):
------------------------------
Education status is Not student, Age group is [45,50), Kid in household is No, Work type is Manager/Office, Major travel mode is Car, Major departure time is Peak, Work days is 5 days, Gender is Male, Driver license is Yes, Number of household members is 4, Household monthly income level is 5M-10M KRW, Home type is Apartment, Car ownership of household is Yes.
------------------------------
Decoding initial texts...
Initial CSV saved: generated_data\LLM-BN_distilgpt2\epoch35_temp0.70\generated_data_initial.csv

Missing value analysis (initial):
Age           16
Gender         6
Homeincome    22
Hometype      26
CarOwn        26
Driver        15
Workdays      35
Worktype      18
Student        3
NumHH          3
KidinHH        7
ComMode        8
ComTime       27
Total missing values: 212
Rows with missing values: 197


Total Progress (Batch 1/1):   0%|                                                             | 0/1000 [00:00<?, ?it/s]


Need 197 more clean rows.
Generating 1000 additional samples…


Total Progress (Batch 1/1): 100%|██████████████████████████████████████████████████| 1000/1000 [00:12<00:00, 80.81it/s]


Clean rows accumulated: 1067122 / 1066319

Final dataset prepared. Rows: 1066319 (Target: 1066319)
Final CSV saved: generated_data\LLM-BN_distilgpt2\epoch35_temp0.70\generated_data_final.csv

Evaluating…
Loading data from h_population.csv and generated_data\LLM-BN_distilgpt2\epoch35_temp0.70\generated_data_final.csv...

=== Data Summary ===
Population data: 1,066,319 rows, 13 columns
Generated data: 1,066,319 rows, 13 columns

Calculating SRMSE...
SRMSE for marginal distributions: 0.2489
SRMSE for bivariate distributions: 0.6039

Calculating precision and recall...

=== Metrics ===
Precision: 0.9527
Recall: 0.7602
F1 Score: 0.8456

=== Unique Combinations Analysis ===
Population unique combinations: 264,005
Generated unique combinations: 120,541

=== Matching Combinations Analysis ===
Precision Perspective) Total complete generated rows: 1,066,319
Precision Perspective) Number of rows matching original: 1,015,888
Precision Perspective) Ratio of matching rows: 95.27%
Recall Perspective)