In [149]:
import polars as pl
import altair as alt
import lpm_fidelity.counting as lf
import random             # Only needed for the Jupyter notebook test data
import copy               # Only needed for the Jupyter notebook test data

In [150]:
# Small hand-written "observed" data set used to exercise the plotting
# functions below: 12 rows, stored as one list per column. The plot calls at
# the end of the notebook treat vti/vb/ca/cb as numerical columns and
# na/nb/ba/bb as nominal (string) columns. The "synthetic" counterpart is
# derived from this dict by randomly perturbing a copy of these columns.
eobs = {
    # YYYY-MM-DD date strings (not used by any plot call in the visible cells).
    "date": ["2021-01-04","2021-01-05","2021-01-06","2021-01-07","2021-01-08","2021-01-11","2021-01-12","2021-01-13","2021-01-14","2021-01-15","2021-01-19","2021-01-20"],
    # Float columns (jittered by fjitter in the synthetic copy).
    "vti": [190.048508,191.583771,193.307281,196.407532,197.358429,196.179733,196.873093,197.091003,196.902817,195.189224,196.972137,199.458298],
    "vb": [190.040939,192.759628,198.970963,202.255219,202.175842,202.175842,205.192230,203.961853,206.521790,203.763412,206.114990,207.365189],
    # Integer columns (jittered by ijitter in the synthetic copy).
    "ca": [200,50,247,23,139,166,176,110,168,14,229,207],
    "cb": [168,133,171,117,134,132,248,8,161,13,168,133],
    # Single-letter nominal columns (randomly replaced by replace()).
    "na": ['L','O','E','P','U','O','X','E','Q','K','M','C'],
    "nb": ['J','A','P','D','Q','T','M','J','F','R','Q','D'],
    # 'yes'/'no' nominal columns (randomly flipped by flip()).
    "ba": ['yes','yes','yes','yes','yes','yes','no','no','yes','yes','yes','no'],
    "bb": ['no','no','no','no','no','no','no','yes','no','no','yes','no'],
}

def flip(x):
    """Test-data perturbation for yes/no values.

    With probability 0.2 the value is flipped: 'no' becomes 'yes' and any
    other value becomes 'no'. Otherwise <x> is returned unchanged. Exactly
    one random.random() draw is consumed per call.
    """
    if random.random() >= 0.2:
        return x
    return 'no' if x != 'no' else 'yes'

def replace(x):
    """Test-data perturbation for the single-letter nominal columns.

    With probability 0.2, replace <x> by a letter drawn uniformly from the
    union of the values in eobs['na'] and eobs['nb']. Otherwise return <x>
    unchanged.
    """
    if random.random() < 0.2:
        # sorted(), not list(): the iteration order of a set of str depends on
        # hash randomization (PYTHONHASHSEED). Choosing from list(set(...))
        # therefore gives different letters on every kernel restart even when
        # random is seeded. A sorted pool makes the draw reproducible.
        pool = sorted(set(eobs['na'] + eobs['nb']))
        x = random.choice(pool)
    return x

def fjitter(x):
    """Test-data perturbation for float columns.

    With probability 0.2, add a uniform offset drawn from [-3, 3] to <x>.
    Otherwise return <x> unchanged.
    """
    if not random.random() < 0.2:
        return x
    return x + random.uniform(-3, 3)

def ijitter(x):
    """Test-data perturbation for integer columns.

    With probability 0.2, add int(uniform(-3, 3)) to <x>. Otherwise return
    <x> unchanged.

    Note: int() truncates toward zero. The offset is therefore (practically
    always) in -2..2, and 0 is the most likely offset, because every draw in
    (-1, 1) truncates to 0.
    """
    if random.random() >= 0.2:
        return x
    offset = int(random.uniform(-3, 3))
    return x + offset
    
# Seed the stdlib RNG used by flip/replace/fjitter/ijitter so the synthetic
# test data (and therefore every plot below) is identical on Restart & Run All.
TEST_DATA_SEED = 42
random.seed(TEST_DATA_SEED)

esyn = copy.deepcopy(eobs)
odf = pl.DataFrame(eobs)
sdf = pl.DataFrame(esyn)

# (column, perturbation function, return dtype) applied to the synthetic copy.
# Each column gets its own with_columns call, in this order, so the sequence of
# random draws is fixed. Batching all expressions into one with_columns call
# could let polars evaluate them concurrently and interleave the draws.
perturbations = [
    ('vti', fjitter, pl.Float64),
    ('vb', fjitter, pl.Float64),
    ('ca', ijitter, pl.Int64),
    ('cb', ijitter, pl.Int64),
    ('ba', flip, pl.String),
    ('bb', flip, pl.String),
    ('na', replace, pl.String),
    ('nb', replace, pl.String),
]
for column, perturb, dtype in perturbations:
    sdf = sdf.with_columns(pl.col(column).map_elements(perturb, return_dtype=dtype))

In [151]:
# Plot palette shared by all comparison plots: observed data in black,
# synthetic data in orange.
COLOR_OBSERVED, COLOR_SYNTHETIC = "#000000", "#f28e2b"

In [152]:
def plot_comparisons_2d_categorical_categorical(observed: pl.DataFrame,synthetic: pl.DataFrame,pairs: list):
    """
    Generate Altair bubble plots comparing the joint frequencies of pairs of
    nominal (string) columns in the observed and synthetic data.

    Each pair of column names in <pairs> yields one row with two side-by-side
    panes: the observed data on the left (COLOR_OBSERVED) and the synthetic
    data on the right (COLOR_SYNTHETIC). The first column of a pair is on the
    X axis and the second on the Y axis. Bubble size encodes the
    "Normalized frequency" returned by lf.bivariate_empirical_frequencies.
    The rows are stacked vertically under a common '2-D Marginals' title.

    Raises AssertionError if <pairs> is empty, if a pair does not hold exactly
    two column names, or if a column is missing from, or not a string column
    in, either data frame.
    """

    def plot_one_comparison(observed_df: pl.DataFrame,synthetic_df: pl.DataFrame,columns: list):
        """
        Build one row: observed bubble plot (with Y axis) next to the
        synthetic bubble plot (Y axis hidden) for the two given columns.
        """
        assert len(columns) == 2, "2d plot requires selection of exactly two columns"
        dfs = {"observed": observed_df, "synthetic": synthetic_df}
        bef = lf.bivariate_empirical_frequencies(dfs,columns[0],columns[1])
        oplot = chart_categorical(bef,columns,COLOR_OBSERVED).transform_filter(alt.datum.Source == "observed")
        splot = chart_categorical(bef,columns,COLOR_SYNTHETIC,show_y_axis=False).transform_filter(alt.datum.Source == "synthetic")
        # Vega-Lite resolves x/y scales of concatenated views independently by
        # default. The synthetic pane has no Y axis of its own, so if its
        # categories differed from the observed ones its rows would not line
        # up with the visible axis. Sharing y puts both panes on one
        # category scale (the union of both category sets).
        return alt.hconcat(oplot,splot,bounds='flush',spacing=0).resolve_scale(color="independent",y="shared")

    def chart_categorical(df: pl.DataFrame,columns: list,color: str,show_y_axis: bool = True):
        """
        Build a single chart pane that does a 2D bubble plot of two columns,
        both with categorical data. The first column (columns[0]) is on the X
        axis, drawn at the top, and the second (columns[1]) is on the Y axis.
        The Y axis is suppressed when <show_y_axis> is False. Points are
        plotted in the given color and sized by the "Normalized frequency"
        column of <df>.
        """
        y_options = {} if show_y_axis else {"axis": None}
        return alt.Chart(df.to_pandas()).mark_circle().encode(
            x = alt.X(f"{columns[0]}:N",axis={"orient": "top"}),
            y = alt.Y(f"{columns[1]}:N",**y_options),
            color = alt.ColorValue(color),
            size=alt.Size('Normalized frequency:Q',legend=None)
        )

    assert len(pairs) > 0, "At least one pair of columns is required"
    # dict.fromkeys de-duplicates while keeping first-seen order, so the first
    # failing column is reported deterministically (set order of str varies
    # between interpreter runs).
    for c in dict.fromkeys(col for pair in pairs for col in pair):
        assert c in observed.columns, f"Column '{c}' not in observed data"
        assert c in synthetic.columns, f"Column '{c}' not in synthetic data"
        assert observed[c].dtype == pl.String, f"Column '{c}' in observed data must be string datatype"
        assert synthetic[c].dtype == pl.String, f"Column '{c}' in synthetic data must be string datatype"
    plots = [plot_one_comparison(observed,synthetic,p) for p in pairs]
    return (alt.vconcat(*plots,title=alt.TitleParams('2-D Marginals',subtitle='nominal-nominal')).resolve_scale(color="independent"))

In [153]:
def plot_comparisons_2d_numerical_numerical(observed: pl.DataFrame,synthetic: pl.DataFrame,pairs: list):
    """
    Generate Altair scatterplots with the synthetic data on top of the
    observed data so people can judge the quality of the synthetic data.
    Each plot is defined by a pair of column names that will define the X
    and Y axes. The <pairs> argument is a list of such pairs of column
    names.

    Observed points are drawn in COLOR_OBSERVED and synthetic points in
    COLOR_SYNTHETIC. The per-pair plots are stacked vertically under a common
    '2-D Marginals' title.

    Any numeric polars dtype is accepted. The two data frames may also store
    the same column with different numeric dtypes (e.g. Int64 vs Float64).

    Raises AssertionError if <pairs> is empty, if a pair does not hold exactly
    two column names, or if a column is missing from, or not numeric in,
    either data frame.
    """

    def tag_source(df: pl.DataFrame,columns: list,source: str):
        """
        Select <columns> from <df> as Float64 and add a constant
        "plot_data_source" column holding <source>.
        """
        # Casting to one common dtype lets pl.concat stack the observed and
        # synthetic frames even when their dtypes differ for a column; a
        # plain vertical concat of mismatched schemas raises an error.
        return df.select(pl.col(columns).cast(pl.Float64)).with_columns(
            pl.lit(source).alias("plot_data_source")
        )

    def plot_one_comparison(observed_df: pl.DataFrame,synthetic_df: pl.DataFrame,columns: list):
        """
        Generate one plot that layers a scatterplot of the synthetic data
        on top of a scatterplot of the observed data. The <columns> argument
        provides the columns to be graphed on the X and Y axes.
        """
        assert len(columns) == 2, "2d plot requires selection of exactly two columns"
        combined = pl.concat([
            tag_source(observed_df,columns,"observed"),
            tag_source(synthetic_df,columns,"synthetic"),
        ])
        oplot = chart_numerical(combined,columns,COLOR_OBSERVED).transform_filter(alt.datum.plot_data_source == "observed")
        splot = chart_numerical(combined,columns,COLOR_SYNTHETIC).transform_filter(alt.datum.plot_data_source == "synthetic")
        # Layering (+) shares scales, and the later layer (synthetic) is drawn on top.
        return oplot+splot

    def chart_numerical(df: pl.DataFrame,columns: list,color: str):
        """
        Build a single chart pane that does a 2D scatterplot of two columns,
        both with numerical data. The first column (columns[0]) is on the X
        axis and the second on the Y axis. Points are plotted in the given
        color. Neither axis is forced to include zero.
        """
        return alt.Chart(df.to_pandas()).mark_point(filled=True).encode(
            x = alt.X(f"{columns[0]}:Q",scale = alt.Scale(zero=False)),
            y = alt.Y(f"{columns[1]}:Q",scale = alt.Scale(zero=False)),
            color = alt.ColorValue(color)
        )

    assert len(pairs) > 0, "At least one pair of columns is required"
    # dict.fromkeys de-duplicates while keeping first-seen order, so the first
    # failing column is reported deterministically.
    for c in dict.fromkeys(col for pair in pairs for col in pair):
        assert c in observed.columns, f"Column '{c}' not in observed data"
        assert c in synthetic.columns, f"Column '{c}' not in synthetic data"
        assert observed[c].dtype.is_numeric(), f"Column '{c}' in observed data must be numerical datatype"
        assert synthetic[c].dtype.is_numeric(), f"Column '{c}' in synthetic data must be numerical datatype"
    plots = [plot_one_comparison(observed,synthetic,p) for p in pairs]
    return (alt.vconcat(*plots,title=alt.TitleParams('2-D Marginals',subtitle='numerical-numerical')).resolve_scale(color="independent"))

In [154]:
# Column pairs to compare: first element on X, second on Y.
numerical_pairs = [['vti', 'vb'], ['vb', 'vti'], ['ca', 'cb'], ['ca', 'vb']]
categorical_pairs = [['na', 'nb'], ['ba', 'bb'], ['na', 'bb']]

nn = plot_comparisons_2d_numerical_numerical(odf, sdf, numerical_pairs)
cc = plot_comparisons_2d_categorical_categorical(odf, sdf, categorical_pairs)

# Stack the numerical and categorical comparisons (same as `nn & cc`).
alt.vconcat(nn, cc)