In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

from plnn.dataset import get_dataloaders

In [2]:
# Destination for every figure written by this notebook; created up front so
# later cells can save into it without checking.
OUTDIR = "out/dataset_images"
os.makedirs(OUTDIR, exist_ok=True)

# One seeded generator is shared by all later cells, so a fresh
# "Restart & Run All" reproduces the same random draws.
SEED = 131241245
rng = np.random.default_rng(SEED)

In [5]:
# Names of the dataset directories (under ../data/training_data/) to render.
# Every entry must be unique: each name maps to a fixed set of output files,
# so a repeated name would only re-render and overwrite the same figures
# while consuming extra draws from the shared random generator.
datdir_names = [
    "data_phi1_1a",
    "data_phi1_1b",
    "data_phi1_1c",
    "data_phi1_2a",
    "data_phi1_2b",
    "data_phi1_2c",
    "data_phi1_3a",
    "data_phi1_3b",
    "data_phi1_3c",
    "data_phi1_4a",
    "data_phi1_4b",
    "data_phi1_4c",
    "data_phi2_1a",
    "data_phi2_1b",
    # "data_phi2_1c",  # NOTE(review): deliberately disabled? confirm or delete
    "data_phiq_1a",
    "data_phiq_2a",
]


In [6]:
# For every dataset directory and each of its train/valid/test splits, draw
# all data items as a faint black scatter, overlay a few randomly chosen
# items in colour, and save the figure as a PDF in OUTDIR.
FIGSIZE = None  # passed straight to plt.subplots; None uses matplotlib's default
NIDXS_HIGHLIGHT = 2  # upper bound on the number of items highlighted per figure

# Style shared by every background scatter call. These layers are rasterized,
# unlike the highlighted layers drawn further down.
BACKGROUND_STYLE = dict(color='k', markersize=1, alpha=0.5, rasterized=True)

for datdir_name in datdir_names:

    datdir = f"../data/training_data/{datdir_name}"

    datdir_train = f"{datdir}/training"
    datdir_valid = f"{datdir}/validation"
    datdir_test = f"{datdir}/testing"

    # Each split directory provides an nsims.txt, parsed as integer data and
    # forwarded unchanged to get_dataloaders.
    nsims_train = np.genfromtxt(f"{datdir_train}/nsims.txt", dtype=int)
    nsims_valid = np.genfromtxt(f"{datdir_valid}/nsims.txt", dtype=int)
    nsims_test = np.genfromtxt(f"{datdir_test}/nsims.txt", dtype=int)

    # Only the three dataset objects are used here; the first three return
    # values are discarded. Shuffling is disabled for all splits.
    _, _, _, train_dset, valid_dset, test_dset = get_dataloaders(
        datdir_train, datdir_valid, nsims_train, nsims_valid,
        shuffle_train=False,
        shuffle_valid=False,
        return_datasets=True,
        include_test_data=True,
        datdir_test=datdir_test, nsims_test=nsims_test, shuffle_test=False,
        batch_size_test=1,
        seed=rng.integers(2**32)
    )

    for dset, dset_name in zip(
        [train_dset, valid_dset, test_dset], ['train', 'valid', 'test']
    ):
        fig, ax = plt.subplots(1, 1, figsize=FIGSIZE)

        # Background layer: columns 0 and 1 of both arrays of every item.
        # NOTE(review): d[0][1] and d[1] are presumably the two point clouds
        # of one observation -- confirm against plnn.dataset.
        for d in dset:
            x0 = d[0][1]
            x1 = d[1]
            ax.plot(x0[:, 0], x0[:, 1], '.', **BACKGROUND_STYLE)
            ax.plot(x1[:, 0], x1[:, 1], '.', **BACKGROUND_STYLE)

        # Sampling without replacement raises ValueError if more indices are
        # requested than the split contains, so cap the request at len(dset).
        n_highlight = min(NIDXS_HIGHLIGHT, len(dset))
        dataidxs = np.sort(
            rng.choice(len(dset), size=n_highlight, replace=False)
        )
        for dataidx in dataidxs:
            x0 = dset[dataidx][0][1]
            x1 = dset[dataidx][1]
            # The label is attached to the line but is only displayed if a
            # legend is added to the axes; none is drawn here.
            l, = ax.plot(
                x0[:, 0], x0[:, 1], '.', markersize=1, alpha=0.5,
                label=f"obs {dataidx} (n={len(x0)})",
                rasterized=False,
            )
            # Reuse the automatically assigned colour so both arrays of one
            # item match; the second array is drawn fully opaque.
            ax.plot(
                x1[:, 0], x1[:, 1], '.', markersize=1, alpha=1,
                color=l.get_color(),
                rasterized=False,
            )

        ax.set_title("")

        # Save and close through the explicit figure handle so the loop never
        # depends on pyplot's notion of the "current" figure, and so figures
        # do not accumulate in memory across iterations.
        fig.savefig(
            f"{OUTDIR}/data_{datdir_name}_{dset_name}.pdf", bbox_inches='tight'
        )
        plt.close(fig)
    