In [5]:
import requests
import pandas as pd
import os

def download_census_data(api_key, output_file="acs_pums_2023.parquet"):
    """
    Download census data from the API and save to parquet file.
    
    Args:
        api_key: Census API key
        output_file: Path to save the parquet file
    
    Returns:
        pd.DataFrame: The downloaded census data
    """
    url = (
        "https://api.census.gov/data/2023/acs/acs1/pums?"
        "get=AGEP,SEX,SCHL,OCCP,INDP,PINCP,RAC3P,WKHP,ESR,STATE"
        f"&for=state:*&key={api_key}"
    )
    
    print("Fetching census data from API...")
    response = requests.get(url)
    print(f"Status: {response.status_code}")
    print(f"Content-Type: {response.headers.get('Content-Type')}")
    
    if response.status_code != 200:
        print(f"Error: {response.text[:500]}")
        raise SystemExit("Failed to download census data")
    
    data = response.json()
    
    # Convert JSON → DataFrame
    columns = data[0]
    rows = data[1:]
    df = pd.DataFrame(rows, columns=columns)
    
    # Save to parquet
    df.to_parquet(output_file, index=False)
    print(f"✓ Saved {len(df):,} records to: {output_file}")
    
    return df


# Check if raw data already exists
RAW_DATA_FILE = "acs_pums_2023.parquet"

if os.path.exists(RAW_DATA_FILE):
    print(f"✓ Found existing raw data file: {RAW_DATA_FILE}")
    print("Skipping download step...")
    df = pd.read_parquet(RAW_DATA_FILE)
    print(f"Loaded {len(df):,} records from existing file")
else:
    print(f"Raw data file not found. Downloading...")
    API_KEY = "71cbe35ff6e0c0df92b72d8a233c40a85de5b0b9"
    df = download_census_data(API_KEY, RAW_DATA_FILE)

df.head()

✓ Found existing raw data file: acs_pums_2023.parquet
Skipping download step...
Loaded 3,405,809 records from existing file
Loaded 3,405,809 records from existing file


Unnamed: 0,AGEP,SEX,SCHL,OCCP,INDP,PINCP,RAC3P,WKHP,ESR,STATE,state
0,94,2,16,N,N,12200,1,0,6,34,34
1,18,1,19,N,N,0,1,0,6,34,34
2,75,1,21,N,N,108000,29,0,6,8,8
3,44,1,18,N,N,0,2,0,6,12,12
4,31,1,18,4220,7580,5200,2,20,6,39,39


In [7]:
import pandas as pd
import os

def label_census_data(input_file, code_ref_file="code_reference.json", output_file="acs_pums_2023_labeled.parquet"):
    """
    Label census data using code reference mappings.
    
    Args:
        input_file: Path to raw parquet file
        code_ref_file: Path to code reference JSON file
        output_file: Path to save labeled parquet file
    
    Returns:
        pd.DataFrame: The labeled census data
    """
    print("--- Starting JSON code mapping for the dataframe ---")
    
    # Load code reference
    with open(code_ref_file, 'r', encoding='utf-8') as f:
        code_ref = pd.read_json(f)
    
    # Helper function to build value map from code reference
    def build_value_map(col_spec):
        """
        Extract the mapping dictionary from the code reference structure.
        col_spec: the column specification from code_ref (e.g., code_ref['SEX'])
        Returns: dict mapping code values to labels
        """
        if 'values' in col_spec and 'item' in col_spec['values']:
            return col_spec['values']['item']
        return {}
    
    # Helper function to map a series using the value map
    def map_series_with_codes(series, value_map):
        """
        Map a pandas Series using a value map dictionary.
        Converts numeric codes to string labels.
        series: pandas Series to map
        value_map: dict mapping codes to labels
        Returns: mapped pandas Series
        """
        # Convert series to string for mapping (handles both int and string keys)
        return series.astype(str).map(value_map).fillna(series.astype(str))
    
    # Load the parquet file
    df = pd.read_parquet(input_file)
    
    # Define the columns we want to decode from code_reference.json
    cols_to_decode = [
        "SEX", "SCHL", "OCCP", "INDP",
        "RAC3P", "ESR", "STATE"
    ]
    
    # Define simple column renames (no value mapping needed)
    cols_to_rename = {
        "AGEP": "Age",
        "PINCP": "Annual Income",
        "WKHP": "Weekly Work Hours"
    }
    
    # Keep track of columns to remove
    cols_to_remove = []
    
    # Loop through and apply mappings from code_reference.json
    for col_name in cols_to_decode:
        # Check if column exists in 'df'
        if col_name in df.columns:
            # Check if column exists in 'code_ref' JSON
            if col_name in code_ref:
                try:
                    # Get the value mapping table (e.g., {"1": "Male", "2": "Female"})
                    vmap = build_value_map(code_ref[col_name])
                    
                    # Get the friendly name for the new column (e.g., "Sex")
                    new_col_label = code_ref[col_name].get('label', f"{col_name}_LABEL")
                    
                    # Apply the mapping function to create the new column
                    df[new_col_label] = map_series_with_codes(df[col_name], vmap)
                    
                    # Mark the original column for removal
                    cols_to_remove.append(col_name)
                    
                    print(f"Success: Mapped '{col_name}' to new column '{new_col_label}'")
                except Exception as e:
                    print(f"Error: Failed to map '{col_name}': {e}")
            else:
                print(f"Skipped: Key '{col_name}' not found in code_ref")
        else:
            print(f"Skipped: Column '{col_name}' not found in DataFrame")
    
    # Rename simple columns (just copy values with new name)
    for old_name, new_name in cols_to_rename.items():
        if old_name in df.columns:
            df[new_name] = df[old_name]
            cols_to_remove.append(old_name)
            print(f"Success: Renamed '{old_name}' to '{new_name}'")
        else:
            print(f"Skipped: Column '{old_name}' not found in DataFrame")
    
    # Remove duplicate "state" column if it exists (lowercase version)
    if 'state' in df.columns:
        cols_to_remove.append('state')
        print(f"Success: Marked duplicate 'state' column for removal")
    
    # Remove the original coded columns
    if cols_to_remove:
        df = df.drop(columns=cols_to_remove)
        print(f"\n--- Removed {len(cols_to_remove)} original coded columns: {cols_to_remove} ---")
    
    # Reorder columns: Age, Annual Income, Weekly Work Hours first, then others
    priority_cols = ["Age", "Annual Income", "Weekly Work Hours"]
    existing_priority = [col for col in priority_cols if col in df.columns]
    other_cols = [col for col in df.columns if col not in priority_cols]
    new_col_order = existing_priority + other_cols
    df = df[new_col_order]
    
    # Save labeled data
    df.to_parquet(output_file, index=False)
    print(f"\n--- Mapping complete. ---")
    print(f"✓ Saved {len(df):,} labeled records to: {output_file}")
    
    return df


# Check if labeled data already exists
LABELED_DATA_FILE = "acs_pums_2023_labeled.parquet"

if os.path.exists(LABELED_DATA_FILE):
    print(f"✓ Found existing labeled data file: {LABELED_DATA_FILE}")
    print("Skipping labeling step...")
    df = pd.read_parquet(LABELED_DATA_FILE)
    print(f"Loaded {len(df):,} labeled records from existing file")
else:
    print(f"Labeled data file not found. Starting labeling process...")
    df = label_census_data(RAW_DATA_FILE, output_file=LABELED_DATA_FILE)

print(f"\nFinal labeled DataFrame preview:")
df.head()

✓ Found existing labeled data file: acs_pums_2023_labeled.parquet
Skipping labeling step...
Loaded 3,405,809 labeled records from existing file

Final labeled DataFrame preview:
Loaded 3,405,809 labeled records from existing file

Final labeled DataFrame preview:


Unnamed: 0,Age,Annual Income,Weekly Work Hours,Sex,Educational attainment,Occupation,Industry,Race,Employment status recode,State
0,94,12200,0,Female,Regular high school diploma,N/A (less than 16 years old/NILF who last work...,N/A (less than 16 years old/NILF who last work...,White alone,Not in Labor Force,New Jersey/NJ
1,18,0,0,Male,"1 or more years of college credit, no degree",N/A (less than 16 years old/NILF who last work...,N/A (less than 16 years old/NILF who last work...,White alone,Not in Labor Force,New Jersey/NJ
2,75,108000,0,Male,Bachelor's degree,N/A (less than 16 years old/NILF who last work...,N/A (less than 16 years old/NILF who last work...,White; Some Other Race,Not in Labor Force,Colorado/CO
3,44,0,0,Male,"Some college, but less than 1 year",N/A (less than 16 years old/NILF who last work...,N/A (less than 16 years old/NILF who last work...,Black or African American alone,Not in Labor Force,Florida/FL
4,31,5200,20,Male,"Some college, but less than 1 year",CLN-Janitors And Building Cleaners,PRF-Employment Services,Black or African American alone,Not in Labor Force,Ohio/OH


In [8]:
import pandas as pd
import numpy as np
import scipy.stats as ss
import os

print("--- Starting Synthetic Data Generation ---")

# --- 1. Setup: Load and Rename Data ---
def get_label(key, default):
    """Get label from code_ref dictionary, with fallback to default."""
    try:
        return code_ref[key].get('label', default)
    except KeyError:
        return default

# Map labeled columns to simplified names for processing
rename_map = {
    'Age': 'age',
    'Sex': 'gender',
    'State': 'state',
    'Race': 'race',
    'Annual Income': 'pincp',
    'Weekly Work Hours': 'wkhp',
    'Educational attainment': 'educational_attainment',
    'Occupation': 'occupation',
    'Employment status recode': 'employment_status',
    'Industry': 'industry'
}

# Select and rename columns
cols_to_keep = [col for col in rename_map.keys() if col in df.columns]
d = df[cols_to_keep].rename(columns=rename_map)
print(f"Initial data shape: {d.shape}")

# --- 2. Data Cleaning and Type Conversion ---
# Convert numeric columns
for col in ['age', 'pincp', 'wkhp']:
    if col in d.columns:
        d[col] = pd.to_numeric(d[col], errors='coerce')

# Drop rows with missing critical data
d = d.dropna(subset=['age', 'gender', 'state'])

# Filter age range
d = d[(d.age >= 16) & (d.age <= 90)]
d['age'] = d['age'].astype(int)

# Convert categorical columns
cat_cols = ['gender', 'state', 'race', 'educational_attainment',
            'occupation', 'employment_status', 'industry']
for c in cat_cols:
    if c in d.columns:
        d[c] = d[c].fillna("Not Applicable").astype("category")

print(f"Cleaned data shape: {d.shape}")

# --- 3. Stratified Sampling ---
rng = np.random.default_rng(42)

# Create age bins for stratification
age_bins = [16, 20, 30, 40, 50, 60, 70, 91]
d["age_bin"] = pd.cut(d["age"], bins=age_bins, right=False)

# Stratify by gender and age
strata = ["gender", "age_bin"]
target_n = 1_000_000  # Generate 1 million records

print("Calculating stratum weights...")
strata_weights = d.groupby(strata, observed=True).size()
strata_weights = strata_weights / strata_weights.sum()

print(f"Sampling {target_n:,} records across {len(strata_weights)} strata...")
samples = []

for idx, proportion in strata_weights.items():
    stratum_data = d[(d["gender"] == idx[0]) & (d["age_bin"] == idx[1])]
    
    if len(stratum_data) == 0:
        continue
    
    # Calculate number of samples for this stratum
    n_samples = int(round(proportion * target_n))
    
    if n_samples == 0:
        continue
    
    # Sample with replacement
    sampled = stratum_data.sample(n=n_samples, replace=True, random_state=42)
    samples.append(sampled)

syn = pd.concat(samples, ignore_index=True)
print(f"Generated {len(syn):,} raw synthetic records.")

# --- 4. Reasonableness Check ---
print("\n--- Running Reasonableness Check ---")

# Define illogical conditions
invalid_mask = (
    # Unemployed or not in labor force shouldn't have work hours
    ((syn['employment_status'] == 'Unemployed') & (syn['wkhp'] > 0)) |
    ((syn['employment_status'] == 'Not in labor force') & (syn['wkhp'] > 0)) |
    # Students shouldn't have very high income or excessive work hours
    ((syn['occupation'].str.contains('Student', na=False)) & (syn['pincp'] > 100000)) |
    ((syn['occupation'].str.contains('Student', na=False)) & (syn['wkhp'] > 40))
)

invalid_count = invalid_mask.sum()
print(f"Found {invalid_count:,} illogical records ({100*invalid_count/len(syn):.2f}%)")

# Filter out invalid records
syn_filtered = syn[~invalid_mask].copy()
print(f"Filtered data shape: {syn_filtered.shape}")

# --- 5. Apply Privacy Noise ---
print("\n--- Applying Privacy Noise ---")

# Add jitter to age to prevent exact copy risk
syn_filtered["age"] = (
    syn_filtered["age"].astype(float) + 
    rng.integers(-2, 3, size=len(syn_filtered))
)
syn_filtered["age"] = syn_filtered["age"].clip(16, 90).round().astype(int)

# Shuffle categorical columns to add privacy noise
cols_to_shuffle = ['state', 'race', 'educational_attainment', 'occupation']
for col in cols_to_shuffle:
    if col in syn_filtered.columns:
        syn_filtered[col] = syn_filtered[col].sample(frac=1, random_state=42).values
        print(f"  Shuffled: {col}")

# --- 6. Select Final Columns for Demographics ---
cols_for_output = [
    "age", "gender", "state", "race",
    "educational_attainment", "occupation"
]

# Ensure we only select columns that exist
final_cols = [col for col in cols_for_output if col in syn_filtered.columns]
syn_demo = syn_filtered[final_cols].copy()

print(f"\nFinal synthetic demographics shape: {syn_demo.shape}")
print(f"Columns: {list(syn_demo.columns)}")

# --- 7. Validation Checks ---
print("\n--- Running Validation Checks (Jensen-Shannon Divergence) ---")

def jensen_shannon_divergence(p, q):
    """Calculate Jensen-Shannon divergence between two distributions."""
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    
    # Avoid log(0)
    p = np.where(p == 0, 1e-10, p)
    q = np.where(q == 0, 1e-10, q)
    m = np.where(m == 0, 1e-10, m)
    
    return 0.5 * ss.entropy(p, m) + 0.5 * ss.entropy(q, m)

# Validate distributions for each column
validation_cols = ['age', 'gender', 'state', 'race', 
                   'educational_attainment', 'occupation']

for col in validation_cols:
    if col in d.columns and col in syn_demo.columns:
        real_dist = d[col].value_counts().sort_index()
        syn_dist = syn_demo[col].value_counts().sort_index()
        syn_dist = syn_dist.reindex(real_dist.index, fill_value=0)
        
        js_score = jensen_shannon_divergence(real_dist.values, syn_dist.values)
        print(f"  JS({col:25s}) = {js_score:.4f}")

# Correlation preservation check
print("\n--- Correlation Preservation Check ---")
def calculate_correlation(data, cols):
    """Calculate correlation matrix, handling categorical variables."""
    df_numeric = data[cols].copy()
    for c in cols:
        if str(df_numeric[c].dtype) == "category":
            df_numeric[c] = df_numeric[c].cat.codes
    return df_numeric.corr(numeric_only=True)

cols_for_corr = ["age", "pincp", "wkhp"]
present_cols = [c for c in cols_for_corr if c in d.columns]

if len(present_cols) >= 2:
    print("Real Data Correlation:")
    c_real = calculate_correlation(d, present_cols)
    print(c_real.round(3))
    
    print("\nSynthetic Data Correlation:")
    c_syn = calculate_correlation(syn_filtered, present_cols)
    print(c_syn.round(3))

# --- 8. Save Synthetic Demographics ---
print("\n--- Saving Synthetic Demographics ---")

# Save to current directory
save_path = "synthetic_demographics_1m.parquet"
syn_demo.to_parquet(save_path, index=False)
print(f"✓ Saved {len(syn_demo):,} records to: {save_path}")

# Also save to CSV for easy viewing
csv_path = "synthetic_demographics_1m.csv"
syn_demo.head(1000).to_csv(csv_path, index=False)
print(f"✓ Saved preview (1000 records) to: {csv_path}")

print("\n--- Synthetic Data Generation Complete ---")
print(f"Final dataset preview:")
syn_demo.head()

--- Starting Synthetic Data Generation ---
Initial data shape: (3405809, 10)
Initial data shape: (3405809, 10)
Cleaned data shape: (2818857, 10)
Calculating stratum weights...
Sampling 1,000,000 records across 14 strata...
Cleaned data shape: (2818857, 10)
Calculating stratum weights...
Sampling 1,000,000 records across 14 strata...
Generated 1,000,001 raw synthetic records.

--- Running Reasonableness Check ---
Found 13,442 illogical records (1.34%)
Filtered data shape: (986559, 11)

--- Applying Privacy Noise ---
  Shuffled: state
  Shuffled: race
  Shuffled: educational_attainment
Generated 1,000,001 raw synthetic records.

--- Running Reasonableness Check ---
Found 13,442 illogical records (1.34%)
Filtered data shape: (986559, 11)

--- Applying Privacy Noise ---
  Shuffled: state
  Shuffled: race
  Shuffled: educational_attainment
  Shuffled: occupation

Final synthetic demographics shape: (986559, 6)
Columns: ['age', 'gender', 'state', 'race', 'educational_attainment', 'occupation

Unnamed: 0,age,gender,state,race,educational_attainment,occupation
0,17,Female,Nebraska/NE,White alone,Bachelor's degree,MGR-Financial Managers
1,20,Female,Texas/TX,White alone,Associate's degree,ENT-Producers And Directors
2,20,Female,Massachusetts/MA,White alone,Doctorate degree,EDU-Postsecondary Teachers
3,18,Female,Texas/TX,White alone,Master's degree,EDU-Special Education Teachers
4,18,Female,California/CA,Some Other Race alone,12th grade - no diploma,N/A (less than 16 years old/NILF who last work...
