In [85]:
import polars as pl
import altair as alt
import lpm_fidelity
import random             # Only needed for the test data
import copy               # Only needed for the test data

In [86]:
# This is just some test data to play with.
# `eobs` plays the role of the observed data; `esyn` starts as an exact copy of
# it and is then nudged so the two data sets do not overlap perfectly in plots.
TEST_DATA_SEED = 0                         # Fixed so the notebook reproduces on Restart & Run All
test_rng = random.Random(TEST_DATA_SEED)   # Private generator: leaves the global `random` state alone
eobs = {
    "date": ["2021-01-04","2021-01-05","2021-01-06","2021-01-07","2021-01-08","2021-01-11","2021-01-12","2021-01-13","2021-01-14","2021-01-15","2021-01-19","2021-01-20"],
    "vti": [190.048508,191.583771,193.307281,196.407532,197.358429,196.179733,196.873093,197.091003,196.902817,195.189224,196.972137,199.458298],
    "vb": [190.040939,192.759628,198.970963,202.255219,202.175842,202.175842,205.192230,203.961853,206.521790,203.763412,206.114990,207.365189],
    "ca": [200,50,247,23,139,166,176,110,168,14,229,207],
    "cb": [168,133,171,117,134,132,248,8,161,13,168,133],
    "na": ['C3','D2','B0','D3','F0','D2','F3','B0','E0','C2','D0','A2'],
    "nb": ['C1','A0','D3','A3','E0','E3','D0','H1','B1','E1','E0','A3'],
    "ba": ['yes','yes','yes','yes','yes','yes','no','no','yes','yes','yes','no'],
    "bb": ['no','no','no','yes','no','no','no','yes','yes','no','yes','no'],
}
esyn = copy.deepcopy(eobs)
odf = pl.DataFrame(eobs)
sdf = pl.DataFrame(esyn)
# Each numeric column is shifted by ONE random constant (the same offset for
# every row), not by per-row noise. The integer columns get an integer offset
# so that their dtype stays integral.
sdf = sdf.with_columns(pl.col('vti')+test_rng.uniform(-2,2))      # Just shift it a bit
sdf = sdf.with_columns(pl.col('vb')+test_rng.uniform(-2,2))       # Just shift it a bit
sdf = sdf.with_columns(pl.col('ca')+int(test_rng.uniform(-6,6)))  # Just shift it a bit
sdf = sdf.with_columns(pl.col('cb')+int(test_rng.uniform(-6,6)))  # Just shift it a bit

In [87]:
# Point colors used by plot_comparison_2d to tell the two data sources apart.
# Observed data is drawn in black (#000000).
COLOR_OBSERVED = "#000000"
# Synthetic data is drawn in orange (#f28e2b).
COLOR_SYNTHETIC = "#f28e2b"

In [88]:
def collapse(df: pl.DataFrame,column: str) -> str:
    """
    Collapse the type of the named Polars DataFrame column to a one-letter
    code: "N" for real numerical types (Float64 or Int64) or "C" for
    categorical (String) data.

    Parameters:
        df:     DataFrame that contains the column.
        column: Name of the column whose dtype is inspected.

    Returns:
        "N" or "C".

    Raises:
        TypeError: if the column's dtype is none of Float64, Int64 or String.
    """
    ctype = df[column].dtype
    if ctype in (pl.Float64,pl.Int64):
        return "N"
    if ctype == pl.String:
        return "C"
    # Raise explicitly instead of asserting on a variable that was never
    # bound: that path produced an UnboundLocalError and hid this message.
    raise TypeError(
        f"Datatype of column {column} must be real number or string (got {ctype})"
    )

In [89]:
def chart_numericals(df: pl.DataFrame,columns: list,color: str):
    """
    Scatterplot pane for a pair of numerical columns.

    columns[0] is drawn along the X axis and columns[1] along the Y axis,
    both encoded as quantitative fields whose scales are not forced to
    include zero. Every point is a filled marker in the given color.
    """
    # Altair is handed a pandas frame converted from the Polars input.
    points = alt.Chart(df.to_pandas()).mark_point(filled=True)
    x_axis = alt.X(f"{columns[0]}:Q",scale = alt.Scale(zero=False))
    y_axis = alt.Y(f"{columns[1]}:Q",scale = alt.Scale(zero=False))
    fill = alt.ColorValue(color)
    return points.encode(x = x_axis,y = y_axis,color = fill)

In [90]:
def chart_categoricals(df: pl.DataFrame,columns: list,color: str):
    """
    Build a single chart pane that does a 2D scatterplot of two columns,
    both with categorical data. The first column (columns[0]) is on the X
    axis and the second on the Y axis. Points are plotted in the given
    color.

    Both fields are encoded as nominal (":N"). A quantitative (":Q")
    encoding would ask Vega-Lite to treat the category strings as numbers,
    which is wrong for categorical data.
    """
    # No `zero=False` scale here: that option only applies to quantitative
    # axes and has no meaning for the discrete scale of a nominal field.
    return alt.Chart(df.to_pandas()).mark_point(filled=True).encode(
        x = alt.X(f"{columns[0]}:N"),
        y = alt.Y(f"{columns[1]}:N"),
        color = alt.ColorValue(color)
    )

In [91]:
def plot_comparison_2d(observed: pl.DataFrame,synthetic: pl.DataFrame,columns: list):
    """
    """
    assert len(columns) == 2, "2d plot requires selection of exactly two columns"
    observed = observed[columns].with_columns(pl.lit("observed").alias("plot_data_source"))
    synthetic = synthetic[columns].with_columns(pl.lit("synthetic").alias("plot_data_source"))
    combined = pl.concat([observed,synthetic])
    types = collapse(combined,columns[0])+collapse(combined,columns[1])
    match types:
        case "NN":
            oplot = chart_numericals(combined,columns,COLOR_OBSERVED).transform_filter(alt.datum.plot_data_source == "observed")
        case "CC":
            oplot = chart_categoricals(combined,columns,COLOR_OBSERVED).transform_filter(alt.datum.plot_data_source == "observed")
    match stypes:
        case "NN":
            splot = chart_numerical_numerical(combined,columns,COLOR_SYNTHETIC).transform_filter(alt.datum.plot_data_source == "synthetic")
        case "CC":
            splot = chart_categoricals(combined,columns,COLOR_SYNTHETIC).transform_filter(alt.datum.plot_data_source == "synthetic")
    return oplot+splot

In [92]:
def plot_comparisons_2d(observed: pl.DataFrame,synthetic: pl.DataFrame,pairs: list):
    """
    Stack one observed-vs-synthetic comparison pane per column pair.

    Parameters:
        observed:  DataFrame holding the observed data.
        synthetic: DataFrame holding the synthetic data.
        pairs:     Non-empty list of [x_column, y_column] pairs; each pair
                   becomes one pane built by plot_comparison_2d.

    Returns:
        A vertical concatenation of the panes, titled '2-D Marginals',
        with color scales resolved independently per pane.

    Raises:
        AssertionError: if `pairs` is empty, or if a referenced column is
            missing from either DataFrame.
    """
    # A single pair is a valid request; only an empty list gives nothing to plot.
    assert len(pairs) >= 1, "At least one pair of columns is required"
    # dict.fromkeys de-duplicates while keeping first-seen order, so the first
    # missing column reported is the same on every run (a set would not be).
    for c in dict.fromkeys(x for xs in pairs for x in xs):
        assert c in observed.columns, f"Column '{c}' not in observed data"
        assert c in synthetic.columns, f"Column '{c}' not in synthetic data"
    plots = [plot_comparison_2d(observed,synthetic,p) for p in pairs]
    return (alt.vconcat(*plots).resolve_scale(color="independent").properties(title='2-D Marginals'))

In [93]:
# Compare observed and synthetic test data for several numerical column pairs.
plot_comparisons_2d(odf,sdf,[['vti','vb'],['vb','vti'],['ca','cb'],['ca','vb']])