In [9]:
import polars as pl
import datashader as ds
import datashader.transfer_functions as tf
from datetime import timedelta, datetime
import matplotlib.pyplot as plt
from threading import Lock
import numpy as np
import warnings
import xarray as xr

# Global Polars settings for this notebook: make "streaming" the preferred
# engine and override the chunk size used by the streaming engine.
pl.Config.set_engine_affinity("streaming")
pl.Config.set_streaming_chunk_size(200_000)

polars.config.Config

In [10]:
import os

# Process-level environment overrides for this kernel.
# NOTE(review): every path below is an absolute path on one developer's
# machine; the notebook will not behave the same elsewhere until these are
# made configurable.
os.environ["SHELL"] = "/app/bin/host-spawn"
# Replaces (does not extend) PATH with a fixed list of directories.
os.environ["PATH"] = (
    "/var/home/caleb/Projects/stream-plotting/.venv/bin:/home/caleb/.local/share/zinit/plugins/starship---starship:/home/caleb/.local/share/zinit/polaris/bin:/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/caleb/.cargo/bin:/usr/local/bin:/usr/bin:/bin:/usr/local/sbin:/usr/sbin:/sbin"
)
# Presumably asks Polars to write the streaming physical plan to this file in
# Graphviz "dot" format - TODO confirm against the Polars version in use.
os.environ["POLARS_VISUALIZE_PHYSICAL_PLAN"] = (
    "/home/caleb/Projects/stream-plotting/a.dot"
)

In [None]:
all_lf = []
for i in range(1, 1000):
    # Each iteration covers one minute: [minute i-1, minute i], counted from
    # 2001-01-01 00:00.
    window_end = datetime(2001, 1, 1) + timedelta(minutes=i)
    window_start = window_end - timedelta(minutes=1)
    print(window_end)

    # 100M sorted timestamps drawn from a 123 ns grid inside the window.
    time_expr = (
        pl.datetime_range(window_start, window_end, "123ns", eager=False)
        .sample(100_000_000)
        .sort()
        .alias("time")
    )
    # Two independent columns of values in [0, 10) with 1e-4 resolution.
    value_expr = (
        pl.int_range(100000).sample(100_000_000, with_replacement=True) / 10000
    ).alias("value")
    colour_expr = (
        pl.int_range(100000).sample(100_000_000, with_replacement=True) / 10000
    ).alias("colour")

    lf = pl.LazyFrame().with_columns(time_expr).with_columns(value_expr, colour_expr)
    # The sink is disabled, so the frame above is only built lazily and never
    # materialised or written.
    # lf.sink_parquet(f"test/test{i}.parquet")

In [4]:
# pl.scan_parquet("test/").head().collect()

In [11]:
# Lazily scan everything matching "test/*" as a single LazyFrame. This rebinds
# `lf`, replacing the frame left over from the generation loop.
# NOTE(review): presumably these are the files written by the (currently
# commented-out) sink_parquet call in the generation cell - confirm they exist.
lf = pl.scan_parquet("test/*")

In [None]:
def plot_with_datashader(
    plots_dict: dict,
    df: pl.DataFrame,
    x: str,
    y: str,
    c: str,
    period: timedelta,
    plot_width=800,
    plot_height=600,
    out_dir=".",
    prefix="plot",
) -> pl.DataFrame:
    """
    Rasterise one batch of a plot group with Datashader and fold it into the
    group's running image; once every row of the group has been seen, save the
    finished image as a PNG.

    The image shows the per-pixel mean of column ``c``. Batches are combined by
    accumulating a per-pixel sum and a per-pixel point count, so the result is
    the true mean over all points of the group regardless of how the rows were
    split into batches.

    Parameters:
    - plots_dict (dict): Mutable state shared between calls. Holds one lock per
      group under ``f"{group_idx}_lock"`` and the accumulators of unfinished
      groups under ``group_idx``.
    - df (pl.DataFrame): Rows of a single group. Must contain ``group_idx``,
      ``len`` (total row count of the group), ``{x}_min``/``{x}_max``,
      ``{y}_min``/``{y}_max`` and the ``x``, ``y``, ``c`` columns.
    - x, y, c (str): Column names for the x axis, y axis and colour value.
    - period (timedelta): Group length; only used in the output filename.
    - plot_width, plot_height (int): Canvas size in pixels.
    - out_dir (str): Output directory for saving plots (created if missing).
    - prefix (str): Prefix for filenames.

    Returns:
    - The input DataFrame, unchanged.
    """
    import os

    plot_idx = df["group_idx"].first()
    period_ns = int(period.total_seconds() * 1e9)

    # Group-wide metadata is repeated on every row, so the first row suffices.
    group_len = df["len"].first()
    x_range = (df[f"{x}_min"].first(), df[f"{x}_max"].first())
    y_range = (df[f"{y}_min"].first(), df[f"{y}_max"].first())

    points = xr.Dataset(
        {
            x: (["points"], df[x].to_numpy(allow_copy=False)),
            y: (["points"], df[y].to_numpy(allow_copy=False)),
            c: (["points"], df[c].to_numpy(allow_copy=False)),
        }
    )
    # Every batch of a group uses the same ranges, so the rasters line up
    # pixel for pixel and can be added together.
    cvs = ds.Canvas(
        plot_width=plot_width,
        plot_height=plot_height,
        x_range=x_range,
        y_range=y_range,
    )
    # One pass over the batch yields both the per-pixel sum and point count.
    batch = cvs.points(points, x, y, ds.summary(total=ds.sum(c), n=ds.count(c)))
    batch_n = batch["n"].values.astype(np.int64)
    # ds.sum leaves untouched pixels as NaN; zero them so they add cleanly.
    batch_total = np.where(batch_n > 0, batch["total"].values, 0.0)

    # The lock protects the accumulators in case batches of the same group
    # are processed concurrently. The completion decision is made under the
    # lock so that exactly one caller renders the group. The lock entry itself
    # is deliberately left in plots_dict.
    with plots_dict.setdefault(f"{plot_idx}_lock", Lock()):
        state = plots_dict.get(plot_idx)
        if state is None:
            state = plots_dict[plot_idx] = {
                "total": np.zeros_like(batch_total),
                "n": np.zeros_like(batch_n),
                "rows": 0,
            }
        state["total"] += batch_total
        state["n"] += batch_n
        state["rows"] += df.height
        finished = state["rows"] == group_len
        if finished:
            # All rows have arrived: detach the state so nobody else touches it.
            del plots_dict[plot_idx]

    if not finished:
        return df

    # Mean per pixel; pixels that never received a point are drawn as 0.0.
    mean = np.divide(
        state["total"],
        state["n"],
        out=np.zeros_like(state["total"]),
        where=state["n"] > 0,
    )
    # Reuse the batch aggregate's coordinates for the final image.
    image = batch["total"].copy(data=mean)

    # Save the image
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots()
    fig.set_figheight(plot_height / 100)
    fig.set_figwidth(plot_width / 100)
    ax.imshow(tf.shade(image).to_pil(), aspect="auto")
    ax.axis("off")
    filename = f"{out_dir}/{prefix}_{period_ns}_{plot_idx}.png"
    fig.savefig(filename, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    return df


def plot_lf(
    lf: pl.LazyFrame, period: timedelta, x: str, y: str, c: str, out_dir="plots"
):
    """
    Build a lazy pipeline that splits ``lf`` into consecutive time windows of
    length ``period`` and renders one Datashader image per window as a side
    effect of executing the returned LazyFrame.

    Parameters:
    - lf (pl.LazyFrame): Input data containing columns ``x``, ``y`` and ``c``.
    - period (timedelta): Length of one window; must be positive.
    - x (str): Name of the temporal column (used via ``.dt.epoch("ns")``).
    - y (str): Name of the column plotted on the y axis.
    - c (str): Name of the column whose per-pixel mean colours the plot.
    - out_dir (str): Directory the images are written to.

    Returns:
    - A LazyFrame with the epoch, ``group_idx``, ``y``, ``c`` columns joined
      with the per-group min/max/len aggregates. Nothing is plotted until it
      is collected or sunk.

    Raises:
    - ValueError: if ``period`` is zero or negative.
    """
    # Exact integer nanoseconds (timedelta resolution is 1 microsecond), so the
    # group index is computed with integer rather than float division.
    period_ns = (period // timedelta(microseconds=1)) * 1_000
    if period_ns <= 0:
        raise ValueError(f"period must be positive, got {period!r}")

    # Add temporary epoch column for plotting, then number the windows
    # starting from the earliest timestamp in the whole frame.
    epoch_name = f"{x}_epoch_tmp"
    epoch = pl.col(epoch_name)
    lf = lf.with_columns(pl.col(x).dt.epoch("ns").alias(epoch_name)).with_columns(
        ((epoch - epoch.min()) // period_ns).cast(pl.UInt32).alias("group_idx")
    )
    # State shared by all batches of this pipeline (accumulators and locks).
    plots = {}

    def custom_plot_fn(df):
        """Plot each group present in the batch and pass the batch through."""
        with warnings.catch_warnings():
            # Silence numpy RuntimeWarnings raised while aggregating batches.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for group_df in df.partition_by("group_idx"):
                plot_with_datashader(
                    plots,
                    group_df,
                    x=epoch_name,
                    y=y,
                    c=c,
                    period=period,
                    out_dir=out_dir,
                    prefix="plot",
                )
        return df

    points = lf.select(epoch_name, "group_idx", y, c)
    # Per-group axis ranges and row count; the row count lets the plotting
    # function detect when a group is complete.
    agg = points.group_by("group_idx").agg(
        epoch.max().alias(f"{epoch_name}_max"),
        epoch.min().alias(f"{epoch_name}_min"),
        pl.col(y).max().alias(f"{y}_max"),
        pl.col(y).min().alias(f"{y}_min"),
        pl.len().alias("len"),
    )
    # Pushdowns are disabled so the optimizer passes the full joined frame,
    # unfiltered and unsliced, through the plotting function.
    return points.join(agg, on="group_idx").map_batches(
        custom_plot_fn,
        predicate_pushdown=False,
        projection_pushdown=False,
        slice_pushdown=False,
        streamable=True,
    )

In [17]:
# Trial run on the first 100M rows with 1-second windows. Sinking the result
# executes the lazy plan, which is what triggers the plotting.
a = plot_lf(
    lf.head(100_000_000),
    period=timedelta(seconds=1),
    x="time",
    y="value",
    c="colour",
)
a.sink_parquet("out.tmp", engine="streaming")

In [None]:
# One pipeline per window length over the full dataset, collected together
# on the streaming engine.
a, b, c, d = (
    plot_lf(lf, period, "time", "value", "colour")
    for period in (
        timedelta(seconds=30),
        timedelta(seconds=10),
        timedelta(seconds=2),
        timedelta(milliseconds=1000),
    )
)
pl.collect_all([a, b, c, d], engine="streaming")

In [None]:
# Print the combined query plan for the four pipelines built in the previous
# cell (requires a, b, c, d to be defined in the kernel).
print(pl.explain_all([a, b, c, d]))