# In [None]:
import torch
import matplotlib.pyplot as plt
from pathlib import Path

def plot_image(images, rescale_method="clamp", name="temp_image"):
    """Save up to 16 images as a 4x4 grid at ``./images/<name>.png``.

    Each image is rescaled according to ``rescale_method``, shifted with
    ``(img + 1) / 2`` (which maps [-1, 1] onto [0, 1]) and permuted with
    ``(1, 2, 0)`` before being drawn, so every image must be a 3-D tensor;
    presumably channel-first (C, H, W) -- confirm against the caller.

    Args:
        images: Iterable of torch tensors. Only the first 16 are drawn; when
            fewer are given the remaining grid cells are left blank.
        rescale_method: ``"tanh"`` applies ``torch.tanh``, ``"clamp"`` clamps
            to [-1, 1], ``"none"`` leaves the values untouched (the
            ``(img + 1) / 2`` shift is still applied afterwards).
        name: File stem of the PNG written into ``./images``.

    Raises:
        ValueError: If ``rescale_method`` is not one of the supported values.
    """
    rescalers = {
        "tanh": torch.tanh,
        "clamp": lambda t: torch.clamp(t, -1.0, 1.0),
        "none": lambda t: t,
    }
    # Validate up front so a bad method fails fast, even for an empty batch
    # and before any figure has been allocated.
    try:
        rescale = rescalers[rescale_method]
    except (KeyError, TypeError):
        raise ValueError("Unsupported rescale method") from None

    # Create the 4x4 grid
    fig, axes = plt.subplots(4, 4, figsize=(6, 6))
    try:
        axes = axes.flatten()

        for img, ax in zip(images, axes):
            img = (rescale(img) + 1) / 2
            # detach/cpu so tensors that require grad or live on another
            # device can be converted to an array by imshow.
            ax.imshow(img.detach().cpu().permute(1, 2, 0))

        # Hide the frame of every cell, including ones that got no image.
        for ax in axes:
            ax.axis("off")

        fig.tight_layout()
        out_dir = Path("./images")
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f"{name}.png")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# NOTE(review): torch.load unpickles the file, so only load batches from a
# trusted source (consider weights_only=True if the installed torch has it).
images = torch.load("./images/image_batch.pth", map_location="cpu")
# Draw the first 16 entries of the loaded batch into ./images/images.png.
plot_image(images[:16], rescale_method="clamp", name="images")