In [1]:
import math
import polars as pl

from lpm_query import QueryInference
from lpm_structure_learning import learn_structure
from lpm_fidelity.distances import bivariate_distances_in_data
from lpm_fidelity.counting  import bivariate_empirical_frequencies
from lpm_plot import plot_fidelity, plot_marginal_1d, plot_marginal_2d

In [2]:
# Location of the input parquet file, relative to the notebook's working directory.
DATA_PATH = "../resources/ces-10-cols.parquet"
# Observed data; every later cell compares against or learns from this frame.
df = pl.read_parquet(DATA_PATH)

In [3]:
# Preview the first rows of the observed data.
df.head()

State,Age,Sex,Race,Religion,Party_allegiance,Police_makes_me_feel,Mental_health_status,Policy_support_ban_assault_rifles,Policy_support_allowing_teachers_to_carry_guns
str,str,str,str,str,str,str,str,str,str
"""Maine""","""(49, 55]""","""Male""","""White""","""Roman Catholic""","""Republican""","""(b) Somewhat safe""","""(c) Good""","""No""","""Yes"""
"""Florida""","""(42, 49]""","""Female""","""Black""","""Protestant""","""not sure""","""(b) Somewhat safe""","""(d) Fair""","""Yes""","""Yes"""
"""Idaho""","""(55, 60]""","""Female""","""White""","""Mormon""","""Republican""","""(b) Somewhat safe""","""(a) Excellent""","""No""","""Yes"""
"""New York""","""(49, 55]""","""Male""","""White""","""Roman Catholic""","""Republican""","""(a) Mostly safe""","""(c) Good""","""No""","""Yes"""
"""Florida""","""(35, 42]""","""Male""","""White""","""Protestant""","""Democrat""","""(a) Mostly safe""","""(c) Good""","""Yes""","""No"""


In [4]:
# (number of rows, number of columns) of the observed data.
df.shape

(10000, 10)

In [5]:
# Value passed as `max_clusters`; presumably an upper bound on the number of
# clusters considered during structure learning — confirm against
# lpm_structure_learning.
MAX_CLUSTERS = 30

# learn_structure yields a model spec and a schema; both are needed to build
# the QueryInference object in the next cell.
model_spec, schema = learn_structure(df, max_clusters=MAX_CLUSTERS)

Model index chosen: 22


In [6]:
# Build the query object from the learned model spec and schema. Later cells
# use `qi` to generate rows and to evaluate logpdf / mutual_information.
qi = QueryInference(model_spec, schema)

In [7]:
# Generate as many synthetic rows as there are observed rows, for all observed
# columns, so the comparisons below are made at matching sample sizes.
# df.height replaces the hard-coded 10000 so the two stay in sync if the input
# file changes.
# NOTE(review): generation is presumably random and no seed is set here, so the
# preview below may differ between runs — check whether lpm_query exposes a seed.
df_generated = qi.generate(df.height, df.columns)
df_generated.head()

State,Age,Sex,Race,Religion,Party_allegiance,Police_makes_me_feel,Mental_health_status,Policy_support_ban_assault_rifles,Policy_support_allowing_teachers_to_carry_guns
str,str,str,str,str,str,str,str,str,str
"""Iowa""","""(27, 35]""","""Male""","""White""","""Agnostic""","""independent""","""(d) Mostly unsafe""","""(c) Good""","""Yes""","""No"""
"""Pennsylvania""","""(65, 70]""","""Female""","""White""","""Protestant""","""Republican""","""(b) Somewhat safe""","""(d) Fair""","""No""","""Yes"""
"""North Carolina""","""(27, 35]""","""Female""","""Black""","""Nothing in particular""","""independent""","""(b) Somewhat safe""","""(c) Good""","""Yes""","""No"""
"""California""","""(27, 35]""","""Female""","""Asian""","""Nothing in particular""","""independent""","""(b) Somewhat safe""","""(d) Fair""","""Yes""","""Yes"""
"""California""","""(-inf, 27]""","""Male""","""White""","""Roman Catholic""","""independent""","""(b) Somewhat safe""","""(a) Excellent""","""No""","""Yes"""


In [8]:
# Columns whose one-dimensional marginals are compared between the observed
# and the generated data.
marginal_columns = ["Sex", "Race", "Age"]
plot_marginal_1d(df, df_generated, columns=marginal_columns)

In [9]:
# Key each frame by a display label before counting joint (Sex, Age) frequencies.
labelled_frames = {"Observed": df, "Synthetic": df_generated}
co_occurences = bivariate_empirical_frequencies(labelled_frames, "Sex", "Age")

# Plot the joint frequencies of the same column pair.
plot_marginal_2d(co_occurences, "Sex", "Age")

In [10]:
# Bivariate distances between the observed and the generated data. The result
# is extended with `with_columns` in the next cell before being plotted.
distances = bivariate_distances_in_data(df, df_generated)

In [11]:
# Tag every row with the model name and a 1-based running index in one pass.
# The row count is read before the call; adding columns does not change it.
n_rows = distances.height
distances = distances.with_columns(
    [
        pl.lit("LPM").alias("model"),
        pl.arange(1, n_rows + 1).alias("index"),
    ]
)

In [12]:
# Plot the fidelity summary from the annotated distances table.
# NOTE(review): presumably relies on the "model" and "index" columns added in
# the previous cell — confirm against lpm_plot.
plot_fidelity(distances)

In [13]:
# logpdf returns a log value; exponentiate to read it as a probability for
# Religion == "Mormon".
mormon_logp = qi.logpdf({"Religion": "Mormon"})
math.exp(mormon_logp)

0.014158817857619254

In [14]:
# Same query with a second dict, {"State": "Utah"}; presumably this conditions
# the query on State == "Utah" — confirm against lpm_query.
mormon_given_utah_logp = qi.logpdf({"Religion": "Mormon"}, {"State": "Utah"})
math.exp(mormon_given_utah_logp)

0.27036195317247524

In [15]:
### XXX: fix the y-axis label of the plot above (TODO: confirm which plot — the nearest plotting cell is plot_fidelity)

In [16]:
# Mutual information between the two gun-policy columns, with no extra
# constraint supplied.
qi.mutual_information(
    ["Policy_support_ban_assault_rifles"],
    ["Policy_support_allowing_teachers_to_carry_guns"],
)

Array(0.10250264, dtype=float32)

In [17]:
# Same pair of gun-policy columns, with {"Party_allegiance": "Republican"} as
# the third argument; presumably a conditioning constraint — confirm against
# lpm_query.
qi.mutual_information(
    ["Policy_support_ban_assault_rifles"],
    ["Policy_support_allowing_teachers_to_carry_guns"],
    {"Party_allegiance": "Republican"},
)

Array(0.04284743, dtype=float32)

In [22]:
# Same pair of gun-policy columns, with {"Religion": "Protestant"} as the
# third argument; presumably a conditioning constraint — confirm against
# lpm_query.
qi.mutual_information(
    ["Policy_support_ban_assault_rifles"],
    ["Policy_support_allowing_teachers_to_carry_guns"],
    {"Religion": "Protestant"},
)

Array(0.11588026, dtype=float32)