In [None]:
from __future__ import annotations

import datetime
import os
import os.path as osp
import textwrap
from typing import Callable, Optional, Type, Union

import lqsvg.experiment.analysis as analysis
import lqsvg.experiment.utils as eutil
import lqsvg.torch.named as nt
import lqsvg.torch.utils as tutil
import matplotlib as mpl
import matplotlib.pyplot as plt

# Format y axis as percent: https://stackoverflow.com/a/36319915/7842251
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import ray
import seaborn as sns
import torch
import torch.nn as nn
from lqsvg.envs import lqr
from lqsvg.envs.lqr.generators import LQGGenerator
from lqsvg.envs.lqr.modules import (
    InitStateDynamics,
    LinearDynamicsModule,
    LQGModule,
    QuadraticReward,
)
from lqsvg.experiment.estimators import DPG, MAAC, AnalyticSVG, MonteCarloSVG
from lqsvg.experiment.plot import default_figsize, plot_surface
from lqsvg.np_util import RNG
from lqsvg.policy.modules import QuadQValue, TVLinearPolicy
from torch import Tensor
from tqdm.auto import tqdm, trange

In [None]:
# Start Ray for this session; the GradientQualityTrial actors defined below run on it.
ray.init()

In [None]:
# Set to True to render and save the LaTeX-style smoke-test figure further down.
DEBUG = False

# Matplotlib setup

LaTeX presets ([ref](https://jwalton.info/Embed-Publication-Matplotlib-Latex/))

In [None]:
from functools import partial

from lqsvg.experiment.plot import STYLE_PATH, available_styles, latex_size, save_pdf_tight, create_latex_style

# Create the LaTeX matplotlib style sheet (presumably the "tex" style used below — TODO confirm).
create_latex_style()
# Figure-size helper bound to the thesis text width; callers unpack it as (width, height).
set_size = partial(latex_size, width="thesis")

In [None]:
# Print the style sheet names reported by lqsvg's available_styles().
print(*available_styles())

In [None]:
# Show where the style sheets live on disk.
print(f"Your style sheets are located at: {STYLE_PATH}")

In [None]:
# Using seaborn's style. Matplotlib 3.6 renamed its bundled seaborn sheets to
# "seaborn-v0_8*" and later releases dropped the old name, so prefer the new one.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
# With LaTeX fonts (applied last so it overrides overlapping seaborn settings)
plt.style.use("tex")

In [None]:
def figpath(name: str) -> str:
    """Return the path of the PDF file a figure called ``name`` is saved to."""
    return osp.join("images", name + ".pdf")


def savefig(fig, name):
    """Save ``fig`` with ``save_pdf_tight`` at ``figpath(name)``.

    Creates the ``images`` directory first so saving also works on a fresh
    checkout. Returns whatever ``save_pdf_tight`` returns.
    """
    path = figpath(name)
    # NOTE(review): save_pdf_tight may already create parent dirs; this guard is harmless.
    os.makedirs(osp.dirname(path), exist_ok=True)
    return save_pdf_tight(fig, path)

In [None]:
# Smoke test of the LaTeX style: plots sin(theta) and saves it via savefig as "example".
if DEBUG:
    # Test
    x = np.linspace(0, 2 * np.pi, 100)
    # Initialise figure instance
    fig, ax = plt.subplots(1, 1, figsize=set_size())

    # Plot
    ax.plot(x, np.sin(x))
    ax.set_xlim(0, 2 * np.pi)
    ax.set_xlabel(r"$\theta$")
    ax.set_ylabel(r"$\sin (\theta)$")

    savefig(fig, "example")

In [None]:
# Remove the smoke-test PDF so it does not end up next to the real figures.
if osp.exists(figpath("example")):
    os.remove(figpath("example"))

# Experimental setup

## Biased/Unbiased (DPG/MAAC) estimators

In [None]:
class KStepModules(nn.Module):
    """Bundle of the modules a K-step SVG estimator is built from.

    Assigning ``nn.Module`` instances as attributes registers them as
    submodules, so the policy, transition model, reward and Q-value are
    tracked together (parameters, ``.to()``, ``state_dict``...).
    """

    def __init__(
        self,
        policy: TVLinearPolicy,
        transition: LinearDynamicsModule,
        reward: QuadraticReward,
        qvalue: QuadQValue,
    ):
        super().__init__()
        self.policy = policy
        self.transition = transition
        self.reward = reward
        self.qvalue = qvalue


class KStepEstimator(nn.Module):
    """K-step stochastic value gradient estimator over a fixed set of starting states.

    Wraps an estimator class (``kind``, e.g. DPG or MAAC) built from the modules
    in ``modules`` and feeds it starting observations drawn from ``obs``.
    ``n_steps`` is forwarded to the wrapped estimator on every call; it starts
    at 0 and callers set it directly.
    """

    def __init__(
        self,
        modules: KStepModules,
        obs: Tensor,
        kind: Union[Type[DPG], Type[MAAC]],
    ):
        super().__init__()
        self.nn = modules
        # Named tensor of candidate starting states, indexed along dim "B"
        self.state_dataset = obs
        self.kind = kind
        self.estimator = self.kind(
            modules.policy, modules.transition, modules.reward, modules.qvalue
        )
        self.n_steps: int = 0

    def sample_starting_obs(self, samples: int) -> Tensor:
        """Return ``samples`` starting observations from the dataset.

        If ``samples`` equals the dataset size, the whole dataset is returned
        as-is; otherwise indices are drawn uniformly *with replacement*.
        """
        dataset_size = self.state_dataset.size("B")
        if samples == dataset_size:
            return self.state_dataset
        idxs = torch.randint(low=0, high=dataset_size, size=(samples,))
        return nt.index_select(self.state_dataset, dim="B", index=idxs)

    def surrogate(self, samples: int = 1) -> Tensor:
        """Surrogate objective of the wrapped estimator on ``samples`` starting states."""
        obs = self.sample_starting_obs(samples)
        return self.estimator.surrogate(obs, n_steps=self.n_steps)

    def forward(self, samples: int = 1) -> tuple[Tensor, lqr.Linear]:
        """Run the wrapped estimator on ``samples`` starting states.

        Returns the estimator's output; callers use index 1 as the SVG estimate.
        """
        obs = self.sample_starting_obs(samples)
        return self.estimator(obs, n_steps=self.n_steps)

    def delta_to_surrogate(
        self, samples: int, n_step: int, update_q: bool = False
    ) -> Callable[[np.ndarray], np.ndarray]:
        """Build a function mapping a flat policy-parameter perturbation to a surrogate value.

        The returned ``f_delta(delta)`` splits ``delta`` into (K, k) increments
        shaped like the current policy's standard form, builds the perturbed
        policy and a fresh estimator with ``n_step`` steps, and evaluates its
        surrogate on ``samples`` starting states without tracking gradients.
        With ``update_q=True`` the Q-function is rebuilt for the perturbed
        policy via ``QuadQValue.from_policy``; otherwise the current one is reused.
        """
        K_0, k_0 = self.nn.policy.standard_form()

        @torch.no_grad()
        def f_delta(delta: np.ndarray) -> np.ndarray:
            vector = nt.vector(tutil.as_float_tensor(delta))
            delta_K, delta_k = tutil.vector_to_tensors(vector, (K_0, k_0))
            policy = TVLinearPolicy.from_existing((K_0 + delta_K, k_0 + delta_k))
            if update_q:
                qvalue = QuadQValue.from_policy(
                    policy.standard_form(),
                    self.nn.transition.standard_form(),
                    self.nn.reward.standard_form(),
                )
            else:
                qvalue = self.nn.qvalue
            modules = KStepModules(policy, self.nn.transition, self.nn.reward, qvalue)
            estimator = KStepEstimator(modules, self.state_dataset, self.kind)
            estimator.n_steps = n_step
            surrogate = estimator.surrogate(samples)
            return surrogate.numpy()

        return f_delta

## Environment and policy generation

LQG parameters: 
1. $|S| = 2, |A| = 2, |H| = 20$
2. Stationary dynamics and cost
3. Passive dynamics eigenvalues $|\lambda_i| \sim \mathcal{U}(0.5, 1.5)$
4. Controllable
5. Transition covariance $\Sigma = I$
6. Cost $s^\intercal C_{ss} s + a^\intercal C_{aa} a$

In [None]:
def make_generator(
    seed: Union[int, np.random.Generator],
    *,
    n_state: int = 2,
    n_ctrl: int = 2,
    horizon: int = 20,
    passive_eigval_range: tuple[float, float] = (0.5, 1.5),
) -> LQGGenerator:
    """Build the LQG generator used by the experiments.

    The defaults reproduce the experiment setup: 2 state and 2 control
    dimensions, horizon 20, stationary and controllable dynamics, and passive
    eigenvalue magnitudes in [0.5, 1.5].

    Args:
        seed: Forwarded as ``rng`` to ``LQGGenerator``. ``Trial`` passes a
            ``np.random.Generator`` here, so this is not restricted to ints.
        n_state: State dimension.
        n_ctrl: Control dimension.
        horizon: Episode horizon.
        passive_eigval_range: Range of the passive dynamics' eigenvalue magnitudes.
    """
    return LQGGenerator(
        n_state=n_state,
        n_ctrl=n_ctrl,
        horizon=horizon,
        stationary=True,
        passive_eigval_range=passive_eigval_range,
        controllable=True,
        rng=seed,
    )

In [None]:
def grad_estimates(
    estimator,
    sample_sizes: list[int],
    estimates_per_sample_size: int = 10,
    pbar: bool = False,
) -> list[list[lqr.Linear]]:
    """Collect repeated SVG estimates for each requested sample size.

    Returns one inner list per entry of ``sample_sizes``. Each inner list holds
    ``estimates_per_sample_size`` gradients: the second element of
    ``estimator(samples=size)``. ``pbar`` toggles a tqdm progress bar.
    """
    sizes = tqdm(
        sample_sizes, desc="Computing SVG by sample size", leave=False, disable=not pbar
    )
    return [
        [estimator(samples=size)[1] for _ in range(estimates_per_sample_size)]
        for size in sizes
    ]

In [None]:
class Trial:
    """One environment-policy pair and the DPG/MAAC estimators built on it.

    From ``seed`` it generates an LQG, a policy stabilized for its dynamics,
    the policy's Q-function, and a dataset of starting states sampled by
    rolling the policy out in the LQG.
    """

    def __init__(self, seed: int, total_states: int = 1000):
        # Keep the integer seed so results can be tagged with it
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self.generator = make_generator(self.rng)
        self.lqg, self.policy, self.qvalue = self.make_modules(self.generator)

        with tutil.default_generator_seed(seed):
            self.states = self.starting_states(self.policy, self.lqg, total_states)

        # Both estimators share the same modules and starting-state dataset
        modules = KStepModules(
            self.policy, self.lqg.trans, self.lqg.reward, self.qvalue
        )
        self.estimator = {
            "dpg": KStepEstimator(modules, self.states, kind=DPG),
            "maac": KStepEstimator(modules, self.states, kind=MAAC),
        }

    def make_modules(
        self, generator: LQGGenerator
    ) -> tuple[LQGModule, TVLinearPolicy, QuadQValue]:
        """Sample an LQG and build (model, stabilized policy, Q-value) modules for it.

        The Q-value is built by ``QuadQValue.from_policy`` from the policy,
        dynamics and cost.
        """
        with nt.suppress_named_tensor_warning():
            dynamics, cost, init = generator()
        model = LQGModule.from_existing(dynamics, cost, init)
        policy = TVLinearPolicy(model.n_state, model.n_ctrl, model.horizon)
        policy.stabilize_(dynamics, rng=self.rng)
        qvalue = QuadQValue.from_policy(policy.standard_form(), dynamics, cost)
        return model, policy, qvalue

    @staticmethod
    def starting_states(policy: TVLinearPolicy, model: LQGModule, num: int) -> Tensor:
        """Sample starting states by rolling out ``policy`` in ``model``.

        Draws ``num // model.horizon`` trajectories without gradients and
        flattens their time ("H") and batch ("B1") dims into one dim "B".

        Raises:
            ValueError: If ``num`` is smaller than the horizon, which would
                yield zero trajectories.
        """
        n_trajs = num // model.horizon
        if n_trajs < 1:
            raise ValueError(
                f"Need at least horizon={model.horizon} states to sample one "
                f"trajectory, got num={num}"
            )
        rollout_module = MonteCarloSVG(policy, model)
        with torch.no_grad():
            obs, _, _, _, _ = rollout_module.rsample_trajectory(torch.Size([n_trajs]))
        obs = obs.flatten(["H", "B1"], "B")
        return obs

# Experiments

## Gradient estimation for fixed policies

In [None]:
@ray.remote
class GradientQualityTrial(Trial):
    """Helper to collect gradient quality metrics for a specific environment-policy pair.

    Environment-policy pair determined by seed. Runs as a Ray actor.
    """

    def __init__(self, seed: int, total_states: int = 1000):
        super().__init__(seed, total_states)
        # Kept here so result rows are tagged with the integer seed
        self.seed = seed
        _, self.true_svg = AnalyticSVG(self.policy, self.lqg)()

    def gradient_quality_vs_samples(
        self, sample_sizes: np.ndarray, n_step: int, estimates: int
    ) -> pd.DataFrame:
        """Gradient quality metrics for each sample size and estimator at K = ``n_step``.

        Returns one row per (sample size, estimator), with the metric columns
        below plus ``K`` (= ``n_step``) and ``trial`` (= the seed).
        """
        rows = [
            row
            for size in sample_sizes
            for row in self.gradient_estimation_data(size, n_step, estimates)
        ]
        columns = [
            "#Samples",
            "Avg. cos sim with true grad",
            "Avg. pairwise cos sim",
            "Norm",
            "Estimator",
        ]
        data = pd.DataFrame(rows, columns=columns)
        data["K"] = n_step
        # Tag with the seed itself: self.generator.rng derives from the
        # np.random.Generator handed to make_generator, not the integer seed.
        data["trial"] = self.seed
        return data

    def gradient_estimation_data(
        self, samples: int, n_step: int, estimates: int
    ) -> list:
        """One row per estimator: [samples, accuracy, empirical variance, mean norm, name].

        Accuracy compares the estimates against the analytic SVG ``self.true_svg``.
        """
        rows = []
        for name, estimator in self.estimator.items():
            svgs = self.grad_estimates(estimator, samples, n_step, estimates)
            cossim = analysis.gradient_accuracy(svgs, self.true_svg)
            empvar = analysis.empirical_variance(svgs)
            norm = np.mean([eutil.linear_feedback_norm(s).numpy() for s in svgs])
            rows.append([samples, cossim, empvar, norm, name])

        return rows

    @staticmethod
    def grad_estimates(
        estimator, samples: int, n_step: int, estimates: int = 10
    ) -> list[lqr.Linear]:
        """Draw ``estimates`` SVG estimates with ``estimator.n_steps`` temporarily set to ``n_step``.

        The previous ``n_steps`` is restored even if an estimate raises.
        """
        previous = estimator.n_steps
        estimator.n_steps = n_step
        try:
            return [estimator(samples)[1] for _ in range(estimates)]
        finally:
            estimator.n_steps = previous

In [None]:
class GradientQualityComparison:
    """Helper to collect gradient quality metrics over several environment-policy pairs.

    Uses Ray to distribute computation via trial Actors: one
    ``GradientQualityTrial`` per seed, created on the first ``collect`` call
    and reused by later ones.
    """

    def __init__(self, seeds: list[int], total_states: int = 1000):
        self.seeds = seeds
        self.total_states = total_states
        self.actors = None
        self.data = None

    def collect(
        self, sample_sizes: np.ndarray, n_steps: np.ndarray, estimates: int = 10
    ):
        """Gather metrics for every (K, seed) pair into ``self.data``.

        Args:
            sample_sizes: Sample sizes each trial evaluates.
            n_steps: Values of K to evaluate.
            estimates: SVG estimates per sample size and estimator.
        """
        if self.actors is None:
            self.actors = [
                GradientQualityTrial.remote(seed=s, total_states=self.total_states)
                for s in self.seeds
            ]

        refs = [
            a.gradient_quality_vs_samples.remote(
                sample_sizes, n_step=k, estimates=estimates
            )
            for k in n_steps
            for a in self.actors
        ]
        pending = list(refs)
        with tqdm(desc="Collecting", total=len(refs)) as pbar:
            while pending:
                done, pending = ray.wait(pending)
                pbar.update(len(done))
        # Fetch in submission order (not completion order) so the row order of
        # the concatenated frame is reproducible.
        self.data = pd.concat(ray.get(refs), ignore_index=True)

    def save(self, name: str = "") -> str:
        """Write ``self.data`` to ``local/[<name>-]<ClassName>-<timestamp>.csv``.

        Creates ``local/`` if needed and returns the path written.
        """
        assert self.data is not None
        now = datetime.datetime.now().isoformat(timespec="minutes")
        fname = f"{type(self).__name__}-{now}.csv"
        prfx = name + "-" if name else ""
        directory = "local"
        os.makedirs(directory, exist_ok=True)
        path = osp.join(directory, prfx + fname)
        self.data.to_csv(path, index=False)
        return path

    def load(self, path: str):
        """Load previously saved results from ``path`` (only if nothing is loaded yet)."""
        assert self.data is None
        self.data = pd.read_csv(path, index_col=False)

### Bias/variance vs. sample size vs. K steps

In [None]:
# comp = GradientQualityComparison(seeds=np.arange(10), total_states=1000)
# comp.collect(
#     sample_sizes=np.linspace(1, comp.total_states, 200, dtype=int), n_steps=[2]
# )
# comp.save("BiasVarianceNorm-K=2")

In [None]:
def fgrid_size(rows: int, cols: int, **setsize_kws) -> dict:
    """Kwargs to set FacetGrid size based on number of rows and cols.

    Splits the figure size returned by ``set_size`` evenly among the panels and
    returns seaborn's per-panel ``height`` and ``aspect`` (width / height).
    """
    w, h = set_size(subplots=(rows, cols), **setsize_kws)
    return dict(height=h / rows, aspect=(w / cols) / (h / rows))


def _facet_rows(data, cols: int) -> int:
    """Rows needed to wrap one facet per unique "K" value into ``cols`` columns."""
    return -(-len(data["K"].unique()) // cols)


def _plot_gradient_metric(data, y: str, ylabel: str):
    """Line plot of ``y`` vs. "#Samples" per estimator, one facet per K, with 95% CIs."""
    cols = 2
    rows = _facet_rows(data, cols)
    print(f"rows: {rows}, cols: {cols}")

    fgrid = sns.relplot(
        kind="line",
        data=data,
        x="#Samples",
        y=y,
        hue="Estimator",
        col="K",
        col_wrap=cols,
        ci=95,
        facet_kws=dict(legend_out=False),
        **fgrid_size(rows, cols, fraction=1),
    )
    # Raw string: "\#" is LaTeX's escaped hash; in a plain literal it is an
    # invalid Python escape sequence (DeprecationWarning / SyntaxWarning).
    fgrid.set_xlabels(r"\#Samples")
    fgrid.set_ylabels(ylabel)
    fgrid.tight_layout()
    return fgrid


def plot_gradient_acc(data):
    """Plot gradient accuracy (avg. cosine similarity with the true gradient) vs. sample size."""
    return _plot_gradient_metric(data, "Avg. cos sim with true grad", "Accuracy")


def plot_gradient_prc(data):
    """Plot gradient precision (the "Avg. pairwise cos sim" column) vs. sample size."""
    return _plot_gradient_metric(data, "Avg. pairwise cos sim", "Precision")

In [None]:
# Results split over two runs; per the file names, one holds K = 0, 4, 8 and the other K = 2.
BIAS_VAR_NORM_DATA = pd.concat(
    [
        pd.read_csv(csv_path, index_col=False)
        for csv_path in (
            "results/BiasVarianceNorm-K=0,4,8-GradientQualityComparison-2021-05-28T12:04.csv",
            "results/BiasVarianceNorm-K=2-GradientQualityComparison-2021-05-31T11:04.csv",
        )
    ],
    ignore_index=True,
)

In [None]:
# Number of result rows per sample size in the combined data
BIAS_VAR_NORM_DATA["#Samples"].value_counts().sort_index()

In [None]:
# Accuracy (cosine similarity with the analytic gradient) vs. sample size, one panel per K
savefig(plot_gradient_acc(BIAS_VAR_NORM_DATA), "GradientAcc")

In [None]:
# Precision (the "Avg. pairwise cos sim" column) vs. sample size, one panel per K
savefig(plot_gradient_prc(BIAS_VAR_NORM_DATA), "GradientPrc")

### Gradient signal strength

In [None]:
def plot_gradient_norm(data, agg: bool = True):
    """Plot the gradient norm vs. sample size per estimator, one panel per K.

    Bands show one standard deviation. Each panel's y-range is clipped to the
    central 95% of that K's norms so outliers do not flatten the curves.

    Args:
        data: Results with "#Samples", "Norm", "Estimator" and "K" columns.
        agg: Currently unused; kept so existing calls keep working.
    """
    cols = 2
    rows = -(-len(data["K"].unique()) // cols)
    print(f"rows: {rows}, cols: {cols}")

    fgrid = sns.relplot(
        kind="line",
        data=data,
        x="#Samples",
        y="Norm",
        hue="Estimator",
        col="K",
        col_wrap=cols,
        ci="sd",
        facet_kws=dict(legend_out=False, sharey=False),
        **fgrid_size(rows, cols, fraction=1),
    )
    for nstep, ax in fgrid.axes_dict.items():
        norms = data.loc[data["K"] == nstep, "Norm"]
        ax.set_ylim(np.percentile(norms, [2.5, 97.5]))
    # Raw string: "\#" is LaTeX's escaped hash, not a valid Python escape
    fgrid.set_xlabels(r"\#Samples")
    fgrid.set_ylabels("Gradient norm")
    fgrid.tight_layout()
    return fgrid

In [None]:
# Gradient norm vs. sample size; note `agg` has no effect on plot_gradient_norm's output.
savefig(plot_gradient_norm(BIAS_VAR_NORM_DATA, agg=False), "GradientNorm")

### Estimator bias at convergence

In [None]:
# comp = GradientQualityComparison(seeds=np.arange(10), total_states=50000)
# comp.collect([comp.total_states], np.arange(1, 21))
# comp.save("Convergence")

In [None]:
# Presumably produced by the commented-out "Convergence" collection above
# (50000 states, K = 1..20). Note this timestamp uses "." where other files use ":".
CONVERGENCE_DATA = pd.read_csv(
    "results/Convergence-GradientQualityComparison-2021-05-25T15.20.csv", index_col=False
)

In [None]:
CONVERGENCE_DATA.info()

In [None]:
def plot_acc_vs_nstep(data, name: str = "convergence"):
    """Scatter gradient accuracy against K at the largest sample size, then save the figure.

    Args:
        data: Results with "#Samples", "K", "Avg. cos sim with true grad" and
            "Estimator" columns.
        name: Figure name passed to ``savefig``.

    Returns:
        Whatever ``savefig`` returns.
    """
    data = data[data["#Samples"] == data["#Samples"].max()]
    fig, ax = plt.subplots(1, 1, figsize=set_size(fraction=0.7))
    sns.scatterplot(
        ax=ax,
        data=data,
        x="K",
        y="Avg. cos sim with true grad",
        hue="Estimator",
        style="Estimator",
        s=10,  # Marker size https://stackoverflow.com/a/52785672/7842251
        alpha=0.75,
    )
    ax.set_xlabel("$K$")
    ax.set_ylabel("Accuracy")
    # unique() keeps order of appearance, and rows may arrive unordered (e.g. in
    # Ray completion order), so sort before taking every other K as a tick.
    ax.set_xticks(np.sort(data["K"].unique())[::2])
    ax.legend(markerscale=0.75)
    return savefig(fig, name)

In [None]:
# Accuracy vs. K at the largest sample size; saved via savefig inside the function.
plot_acc_vs_nstep(CONVERGENCE_DATA)

## Impact of gradient quality on policy optimization

In [None]:
from typing import Iterable, TypeVar

from raylab.utils.exp_data import load_exps_data

# Generic value type for dict_exclude.
T = TypeVar("T")

In [None]:
def dict_exclude(mapping: dict[str, T], keys: Iterable[str]) -> dict[str, T]:
    """Return a copy of ``mapping`` without ``keys``, preserving its insertion order.

    Keys not present in ``mapping`` are ignored. (Iterating a set difference, as
    before, made the order depend on string hashing, which varies per process.)
    """
    excluded = set(keys)
    return {k: v for k, v in mapping.items() if k not in excluded}


def push_config_to_dataframe(exp_data):
    """Append each experiment config value as a constant column to its progress frame.

    W&B bookkeeping entries ("wandb_tags", "wandb_dir") are dropped first.
    """
    progress = exp_data.progress
    params = dict_exclude(exp_data.params, ["wandb_tags", "wandb_dir"])
    return pd.concat((progress, pd.DataFrame(params, index=progress.index)), axis=1)


def read_optimization_data(path: str) -> pd.DataFrame:
    """Load every experiment under ``path`` and stack progress + config into one frame."""
    exps_data = load_exps_data(path)
    dfs = [push_config_to_dataframe(e) for e in exps_data]
    return pd.concat(dfs, ignore_index=True)

In [None]:
# Policy optimization runs (progress + config columns) from one experiment directory
OPTIMIZATION_DATA = read_optimization_data("results/Experiment_2021-06-07_09-25-12/")
OPTIMIZATION_DATA.head()

In [None]:
OPTIMIZATION_DATA.info()

In [None]:
# Fixed estimator order for seaborn's hue_order so estimators get consistent colors across figures
HUE_ORDER = ["dpg", "maac"]

In [None]:
def learning_curves(data, ci="sd"):
    """Plot cost (negated true value) vs. training iteration, one panel per env seed.

    Panels are wrapped two per row, titles are cleared, and each panel's upper
    y-limit is its seed's 97.5th cost percentile. The input frame is not modified.
    """
    frame = data.assign(
        Iteration=data["training_iteration"],
        Cost=-data["true_value"],
        seed=data["env_seed"],
    )

    n_cols = 2
    n_rows = -(-len(frame["seed"].unique()) // n_cols)
    print(f"rows: {n_rows}, cols: {n_cols}")
    width, height = set_size(fraction=1, subplots=(n_rows, n_cols))
    print(f"width: {width}, height: {height}")
    panel_height = height / n_rows

    fgrid = sns.relplot(
        data=frame,
        kind="line",
        x="Iteration",
        y="Cost",
        hue="estimator",
        hue_order=HUE_ORDER,
        ci=ci,
        col="seed",
        col_wrap=n_cols,
        height=panel_height,
        aspect=(width / n_cols) / panel_height,
        facet_kws={"sharey": False, "sharex": True, "legend_out": False},
    )
    for ax in fgrid.axes.flat:
        ax.set_title(None)
    for seed, ax in fgrid.axes_dict.items():
        seed_costs = frame.loc[frame["seed"] == seed, "Cost"]
        ax.set_ylim(top=np.percentile(seed_costs, 97.5))
    fgrid.tight_layout()
    return fgrid

### Unnormalized gradients

In [None]:
# Learning curves for runs without SVG normalization, excluding seeds 2 and 7.
# NOTE(review): the reason these seeds are excluded is not recorded — document it.
savefig(
    learning_curves(
        OPTIMIZATION_DATA.query("normalize_svg == False & env_seed not in [2, 7]")
    ),
    "unnormalized_svg_optimization",
)

In [None]:
# Shorter variant: four more seeds excluded, leaving fewer panels.
savefig(
    learning_curves(
        OPTIMIZATION_DATA.query(
            "normalize_svg == False & env_seed not in [2, 7, 5, 6, 8, 9]"
        )
    ),
    "unnormalized_svg_optimization_short",
)

### Normalized gradients

In [None]:
# Learning curves for runs with SVG normalization, excluding seeds 2 and 7.
# NOTE(review): the reason these seeds are excluded is not recorded — document it.
savefig(
    learning_curves(
        OPTIMIZATION_DATA.query("normalize_svg == True & env_seed not in [2, 7]")
    ),
    "normalized_svg_optimization",
)

In [None]:
# Shorter variant: four more seeds excluded, leaving fewer panels.
savefig(
    learning_curves(
        OPTIMIZATION_DATA.query(
            "normalize_svg == True & env_seed not in [2, 7, 5, 6, 8, 9]"
        )
    ),
    "normalized_svg_optimization_short",
)

## Suboptimality Gap

### SGD (clip gradient norm)

In [None]:
def read_suboptimality_data(logdir: str) -> pd.DataFrame:
    """Load optimization logs from ``logdir`` and add plotting columns.

    Adds "Iteration", "Cost" and "Optimal" (negated true / optimal values),
    "Suboptimality" (relative gap to the optimal cost, in percent, i.e. already
    multiplied by 100), "Grad Norm" and "Dim".
    """
    raw = read_optimization_data(logdir)
    cost = -raw["true_value"]
    optimal = -raw["optimal_value"]
    return raw.assign(
        Iteration=raw["training_iteration"],
        Cost=cost,
        Optimal=optimal,
        Suboptimality=100 * (cost - optimal) / optimal,
        **{"Grad Norm": raw["grad_norm"], "Dim": raw["env_dim"]},
    )

In [None]:
# Suboptimality-gap runs (SGD with gradient-norm clipping, per the section title)
NEW_SGD_DATA = read_suboptimality_data("results/SuboptimalityGap_2021-06-12_15-59-42/")

In [None]:
# Mean elapsed time per estimator over logged rows with >= 300 s elapsed and Dim == 10.
# Select the column before aggregating: a frame-wide `.mean()` raises on non-numeric
# columns in pandas >= 2.0 (numeric_only defaults to False) and does needless work.
NEW_SGD_DATA.query("time_total_s >= 300 & Dim == 10").groupby("estimator")[
    "time_total_s"
].mean()

In [None]:
def progress_summary(data: pd.DataFrame) -> pd.DataFrame:
    """Median suboptimality per estimator after 1, 3 and 5 minutes, one column per Dim.

    Rows are binned by whole elapsed minutes (``time_total_s // 60``), and
    values are rounded to two decimals. The input frame is not modified (it
    previously gained a "Time (min)" column as a side effect).
    """
    binned = data.assign(**{"Time (min)": (data["time_total_s"] // 60).astype(int)})
    binned = binned.query("`Time (min)` in [1, 3, 5]")
    table = pd.pivot_table(
        binned,
        values="Suboptimality",
        index=["estimator", "Time (min)"],
        columns="Dim",
        # String alias: same result as np.median without pandas' FutureWarning
        aggfunc="median",
    )
    return table.round(decimals=2)

In [None]:
# Median suboptimality table (rich display)
progress_summary(NEW_SGD_DATA)  # .to_html()

In [None]:
# Same table as raw HTML, e.g. for pasting into a document
print(progress_summary(NEW_SGD_DATA).to_html())

In [None]:
def suboptimality_curves(data, y="Suboptimality", ci="sd"):
    """Plot ``y`` against wall-clock time, one facet per (Dim, seed), hue by estimator.

    Y tick labels are formatted as percentages of values already in percent
    units: ``read_suboptimality_data`` stores "Suboptimality" multiplied by 100,
    so the formatter uses ``xmax=100`` (with ``xmax=1`` a 5% gap would be
    labeled "500.0%"). A custom ``y`` column should use the same units.

    Args:
        data: Needs "time_total_s", "estimator", "seed" and "Dim" columns plus
            ``y``. NOTE(review): read_suboptimality_data does not add "seed";
            it presumably comes from the logged config — confirm.
        y: Column to plot.
        ci: Error band spec forwarded to ``sns.relplot``.
    """
    fgrid = sns.relplot(
        data=data,
        kind="line",
        x="time_total_s",
        y=y,
        hue="estimator",
        hue_order=HUE_ORDER,
        ci=ci,
        col="seed",
        row="Dim",
        facet_kws={"sharey": False, "legend_out": False, "margin_titles": True},
    )
    for ax in fgrid.axes.flatten():
        ax.yaxis.set_major_formatter(
            # Raw string: "\%" is LaTeX's escaped percent sign, used verbatim with is_latex=True
            mtick.PercentFormatter(xmax=100.0, decimals=1, symbol=r"\%", is_latex=True)
        )
    fgrid.set_xlabels("Wall clock time (sec)")
    fgrid.tight_layout()
    return fgrid

In [None]:
# Suboptimality vs. wall-clock time for every (Dim, seed) pair
savefig(suboptimality_curves(NEW_SGD_DATA), "SuboptimalitySGD_clipped")

In [None]:
# Inspect the optimizer-related config columns of these runs
NEW_SGD_DATA[["K", "B", "optimizer", "learning_rate", "clip_grad_norm"]].head()

## Optimization surface

In [None]:
# Stop "Run All" here so the optimization-surface experiments below are not run.
# An explicit raise is used because `assert` statements are stripped under `python -O`.
raise RuntimeError("Stopping before the optimization-surface experiments below")

In [None]:
class OptimizationSurfaceComparison(Trial):
    """Compare the true return surface with an estimator's surrogate around the current policy."""

    def plot_real_vs_surrogate(
        self,
        estimator: str,
        samples: int,
        n_step: int,
        seed: Optional[int] = None,
        update_q: bool = False,
    ):
        """Plot 3D surfaces of the true return and the surrogate along the SVG direction.

        One SVG estimate from ``estimator`` ("dpg" or "maac") with ``samples``
        starting states and ``n_step`` steps defines one axis; ``analysis.optimization_surface``
        pairs it with a random direction.

        Args:
            estimator: Key into ``self.estimator``.
            samples: Starting states used for the SVG estimate and surrogate.
            n_step: K for the estimator. This also changes the shared
                estimator's ``n_steps`` for later calls.
            seed: Seed for torch's default generator and the surfaces' random
                direction; drawn from ``self.rng`` when None.
            update_q: Forwarded to ``delta_to_surrogate``. If True, the
                Q-function is rebuilt for every perturbed policy.
        """
        sns.reset_orig()
        estim = self.estimator[estimator]
        estim.n_steps = n_step

        # Explicit None check: `seed or ...` silently replaced a seed of 0
        if seed is None:
            seed = self.rng.integers(np.iinfo(int).max)
        with tutil.default_generator_seed(seed):
            _, svg = estim(samples)
            direction = tutil.tensors_to_vector(svg).numpy()

            real_XYZ = analysis.optimization_surface(
                self.delta_to_return(),
                direction=direction,
                max_scaling=3.0,
                steps=20,
                rng=seed,
            )
            surrogate_XYZ = analysis.optimization_surface(
                estim.delta_to_surrogate(samples, n_step, update_q=update_q),
                direction=direction,
                max_scaling=3.0,
                steps=20,
                rng=seed,
            )

        fig = plt.figure(figsize=default_figsize(2, 4))
        ax1 = fig.add_subplot(1, 2, 1, projection="3d")
        ax2 = fig.add_subplot(1, 2, 2, projection="3d")
        plot_surface(*real_XYZ, ax=ax1, invert_xaxis=True)
        plot_surface(*surrogate_XYZ, ax=ax2, invert_xaxis=True)
        ax1.set_xlabel("Random direction")
        ax2.set_xlabel("Random direction")
        ax1.set_ylabel(f"SVG ({estimator}) direction")
        ax2.set_ylabel(f"SVG ({estimator}) direction")
        ax1.set_zlabel("Policy return")
        ax2.set_zlabel("Surrogate value")
        fig.suptitle(f"NStep: {n_step}")
        plt.show()

    def delta_to_return(self) -> Callable[[np.ndarray], np.ndarray]:
        """Function mapping a policy-parameter perturbation to the true return (via ``analysis``)."""
        policy = self.policy.standard_form()
        dynamics, cost, init = self.lqg.standard_form()
        return analysis.delta_to_return(policy, dynamics, cost, init)

In [None]:
# Environment-policy pair from seed 4; compare DPG and MAAC surfaces for K = 0.
comparator = OptimizationSurfaceComparison(4)
comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=0, seed=4)
comparator.plot_real_vs_surrogate("maac", samples=200, n_step=0, seed=4)

In [None]:
# Same pair, K = 4
comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=4, seed=4)
comparator.plot_real_vs_surrogate("maac", samples=200, n_step=4, seed=4)

In [None]:
# Same pair, K = 10
comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=10, seed=4)
comparator.plot_real_vs_surrogate("maac", samples=200, n_step=10, seed=4)

In [None]:
# Same pair, K = 20
comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=20, seed=4)
comparator.plot_real_vs_surrogate("maac", samples=200, n_step=20, seed=4)

In [None]:
# NOTE: these calls pass `update_q`, which plot_real_vs_surrogate must accept as a parameter.
# comparator = OptimizationSurfaceComparison(4)
# comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=0, update_q=True)
# comparator.plot_real_vs_surrogate("maac", samples=200, n_step=0, update_q=True)
# comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=3, update_q=True)
# comparator.plot_real_vs_surrogate("maac", samples=200, n_step=3, update_q=True)

In [None]:
# A different environment-policy pair (seed 42), K = 4
comparator = OptimizationSurfaceComparison(42)
comparator.plot_real_vs_surrogate("dpg", samples=200, n_step=4, seed=42)
comparator.plot_real_vs_surrogate("maac", samples=200, n_step=4, seed=42)

In [None]:
# Seed 42 again with more starting states (500)
comparator.plot_real_vs_surrogate("dpg", samples=500, n_step=4, seed=42)
comparator.plot_real_vs_surrogate("maac", samples=500, n_step=4, seed=42)