# US testing notebook

This notebook tests the `reweight` function against PolicyEngine US data.

In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

In [3]:
from policyengine_us.data.datasets.cps.enhanced_cps.loss import generate_model_variables
# Build the model inputs from the "cps_2021" dataset. The second argument
# (2025) is presumably the target year -- confirm against
# `generate_model_variables`.
model_variables = generate_model_variables("cps_2021", 2025)

# Unpack the six returned objects in the order the function yields them.
(household_weights, weight_adjustment, values_df,
 targets, targets_array, equivalisation_factors_array) = model_variables

In [4]:
from reweight import reweight

In [9]:
# Cast each model input to a float32 tensor before handing it to `reweight`.
initial_weights, sim_matrix, targets_tensor = (
    torch.tensor(raw, dtype=torch.float32)
    for raw in (household_weights, values_df.to_numpy(), targets_array)
)

# Run the reweighting routine for 1,000 epochs, starting from the original
# household weights.
final_weights = reweight(
    initial_weights,
    sim_matrix,
    targets,
    targets_tensor,
    epochs=1_000,
)

Epoch 100, Loss: 0.36243075132369995
Epoch 200, Loss: 0.3186039328575134
Epoch 300, Loss: 0.2849416434764862
Epoch 400, Loss: 0.2582869231700897
Epoch 500, Loss: 0.23667685687541962
Epoch 600, Loss: 0.21886183321475983
Epoch 700, Loss: 0.20400400459766388
Epoch 800, Loss: 0.191498264670372
Epoch 900, Loss: 0.18087948858737946
Epoch 1000, Loss: 0.17178186774253845


In [10]:
def nonzero_proportion(tensor):
    """Return the fraction of elements in ``tensor`` that are non-zero.

    Args:
        tensor: A ``torch.Tensor`` of any shape.

    Returns:
        A Python float between 0.0 and 1.0. An empty tensor yields ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    total = tensor.numel()
    if total == 0:
        # No elements means no non-zero elements; avoid dividing by zero.
        return 0.0
    return torch.count_nonzero(tensor).item() / total

# Share of non-zero weights before and after reweighting, one value per line.
for weights in (initial_weights, final_weights):
    print(nonzero_proportion(weights))

1.0
1.0


In [11]:
import plotly.express as px
import pandas as pd

# Compare the weight distributions before and after reweighting.
# `.detach().cpu()` keeps the NumPy conversion safe if `reweight` returns a
# tensor that still tracks gradients or lives on an accelerator; for a plain
# CPU tensor it changes nothing.
df = pd.DataFrame({
    "initial_weights": initial_weights.numpy(),
    "final_weights": final_weights.detach().cpu().numpy(),
})

# Wide-form histogram: one overlaid series per column. In wide-form mode
# Plotly Express names the melted columns "value" and "variable", so those
# are the keys to relabel.
px.histogram(
    df,
    x=["initial_weights", "final_weights"],
    title="Weights before and after reweighting",
    labels={"value": "Weight", "variable": "Weight set"},
)

In [12]:
# Names of every column in the values frame, as a plain Python list.
values_df.columns.tolist()

['employment income (IRS SOI)',
 'self-employment income (IRS SOI)',
 'partnership/S-corp income (IRS SOI)',
 'farm income (IRS SOI)',
 'farm rental income (IRS SOI)',
 'short-term capital gains (IRS SOI)',
 'long-term capital gains (IRS SOI)',
 'taxable interest income (IRS SOI)',
 'tax-exempt interest income (IRS SOI)',
 'rental income (IRS SOI)',
 'qualified dividend income (IRS SOI)',
 'non-qualified dividend income (IRS SOI)',
 'taxable pension income (IRS SOI)',
 'Social Security (IRS SOI)',
 'Alimony income (IRS SOI)',
 'Federal income tax (CBO)',
 'SNAP allotment (CBO)',
 'Social Security (CBO)',
 'SSI (CBO)',
 'unemployment compensation (CBO)',
 'SNAP allotment participants',
 'SSI participants',
 'Social Security participants',
 'U.S. population',
 '0 to 5 and male population',
 '0 to 5 and female population',
 '10 to 15 and male population',
 '10 to 15 and female population',
 '20 to 25 and male population',
 '20 to 25 and female population',
 '30 to 35 and male population',

In [13]:
# Inspect the employment-income column (all rows).
values_df.loc[:, "employment income (IRS SOI)"]

0         55956.421875
1         60767.156250
2         75958.945312
3        151917.886719
4             0.000000
             ...      
59143     39245.453125
59144    425870.145508
59145    268390.800781
59146     31649.560547
59147     39245.453125
Name: employment income (IRS SOI), Length: 59148, dtype: float64

In [14]:
# Employment-income column as a NumPy array, used as the x values of the
# income histograms.
income_values = values_df["employment income (IRS SOI)"].to_numpy()

In [15]:
import plotly.graph_objects as go
import numpy as np

# Weighted histogram of employment income using the original household
# weights: rows are binned by `income_values`, and each bar is the sum of the
# `household_weights` that fall in that bin.
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=income_values,
        y=household_weights,
        histfunc="sum",
        nbinsx=200,  # requested number of income bins
    )
)

# Title, axis labels, and a small gap between neighbouring bars.
fig.update_layout(
    title={"text": "Income Distribution"},
    xaxis={"title": {"text": "Income"}},
    yaxis={"title": {"text": "Number of People"}},
    bargap=0.1,
)

fig.show()


In [16]:
# NumPy view of the reweighted weights for plotting. `.detach().cpu()` keeps
# the conversion safe if `reweight` returns a tensor that still tracks
# gradients or lives on an accelerator; for a plain CPU tensor it changes
# nothing.
finishing_weights = final_weights.detach().cpu().numpy()

# Weighted histogram of employment income using the final weights: rows are
# binned by `income_values`, and each bar is the sum of the
# `finishing_weights` that fall in that bin.
fig = go.Figure(data=[go.Histogram(
    x=income_values,
    histfunc='sum',
    y=finishing_weights,
    nbinsx=200,  # requested number of income bins
)])

# The title names the weight set so this figure can be told apart from the
# otherwise identical plot built from the original weights.
fig.update_layout(
    title='Income Distribution (final weights)',
    xaxis_title='Income',
    yaxis_title='Number of People',
    bargap=0.1  # Adds a small gap between bars
)

# Show the plot
fig.show()

#NOTE: Could try again with "employment income budgetary impact (UK)"