# Plot Imagesets

Sets of visual stimuli used for training, consisting of simple 2D object
shapes formed from n sides and p conformations per side.

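Imagesets (named N<n>P<p> after the number of sides n and conformations per side p):
- N3P2 (n = 3, p = 2)
- N4P2 (n = 4, p = 2)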

In [None]:
import math

import matplotlib.pyplot as plt

from hsnn.utils import io, ImageSet
from hsnn.transforms import Compose, Resize
from hsnn import viz

# Size argument handed to the `Resize` transform when the imagesets are loaded
# in the cells below.
dsize = (128, 128)
# Directory that receives the exported figure PDFs.
SAVE_DIR = io.BASE_DIR / "out/figures/fig4"
# Create the output directory (and any missing parents) up front; an already
# existing directory is not an error.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

def plot_grid(imageset: ImageSet, num_cols=4, title=None, skip_cells=((3, 3),)) -> None:
    """Draw the images of `imageset` on a grid of tick-free grayscale panels.

    Images are placed row by row, `num_cols` per row; the number of rows is
    derived from ``len(imageset)``. Grid cells left over after the last image
    are removed from the figure rather than shown as empty frames.

    Args:
        imageset: Collection of images; must support ``len()`` and iteration.
        num_cols: Number of grid columns.
        title: Optional figure title; nothing is drawn when it is falsy.
        skip_cells: Iterable of ``(row, col)`` grid positions whose panel is
            removed instead of drawn (the image that would land there is not
            shown). The default ``((3, 3),)`` preserves the behaviour this
            function has always had; pass ``()`` to draw every image.
            NOTE(review): the reason for dropping cell (3, 3) is not visible
            here -- confirm whether the default is still wanted.
    """
    num_rows = math.ceil(len(imageset) / num_cols)
    # squeeze=False keeps `axes` two-dimensional even when the grid has a
    # single row or column, so `axes[row, col]` is always valid.
    f, axes = plt.subplots(
        num_rows, num_cols, sharex=True, sharey=True, squeeze=False
    )
    skipped = {tuple(cell) for cell in skip_cells}

    num_images = 0
    for idx, image in enumerate(imageset):
        num_images = idx + 1
        row, col = divmod(idx, num_cols)
        ax: plt.Axes = axes[row, col]
        if (row, col) in skipped:
            f.delaxes(ax)
            continue
        # An empty tick list removes both the tick marks and their labels.
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])
        ax.imshow(image, cmap="gray", vmin=0, vmax=255)

    # Remove the unused cells that pad out the final row.
    for idx in range(num_images, num_rows * num_cols):
        row, col = divmod(idx, num_cols)
        f.delaxes(axes[row, col])

    if title:
        f.suptitle(title)
    f.set_size_inches(8, 8 / 2)
    f.tight_layout()

In [None]:
# N3P2 imageset: loaded from DATA_DIR/n3p2 with a Compose(Resize(dsize))
# transform attached.
resize_transform = Compose(Resize(dsize))
imageset = ImageSet(io.DATA_DIR / "n3p2", transform=resize_transform)

# Rendering options forwarded unchanged to viz.plot_images.
render_opts = dict(
    vmin=0,
    vmax=255,
    cmap="gray",
    plot_ticks=False,
    figsize=(5.5, 1.4),
)
axes = viz.plot_images(imageset, num_rows=1, **render_opts)

# Save whichever figure is current after plotting as a PDF in SAVE_DIR.
f = plt.gcf()
out_path = SAVE_DIR / "fig_n3p2.pdf"
f.savefig(out_path, format="pdf", dpi=300)

In [None]:
# N4P2 imageset: loaded from DATA_DIR/n4p2 with a Compose(Resize(dsize))
# transform attached.
resize_transform = Compose(Resize(dsize))
imageset = ImageSet(io.DATA_DIR / "n4p2", transform=resize_transform)

# Rendering options forwarded unchanged to viz.plot_images; this set is laid
# out over two rows.
render_opts = dict(
    vmin=0,
    vmax=255,
    cmap="gray",
    plot_ticks=False,
    figsize=(5.5, 1.4),
)
axes = viz.plot_images(imageset, num_rows=2, **render_opts)

# Save whichever figure is current after plotting as a PDF in SAVE_DIR.
f = plt.gcf()
out_path = SAVE_DIR / "fig_n4p2.pdf"
f.savefig(out_path, format="pdf", dpi=300)
