In [None]:
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import repaint_lib
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from skimage.morphology import erosion, disk
from dataset import get_dataset
from tqdm import tqdm


# Root output directory; the export loop writes into repaint/<idx>/features/.
# (A single-argument os.path.join is a no-op, so a plain string suffices.)
repaint_dir = "repaint"


# Seaborn default theme with the compact "paper" context for saved figures
# (equivalent to set_theme() followed by set_context("paper")).
sns.set_theme(context="paper")

In [None]:
# Export, for every validation sample that has the "Smiling" attribute set,
# the feature map and one eroded mask per value in `s` as PNG and PDF under
# repaint/<idx>/features/.
gamma = 5
s = repaint_lib.get_s(gamma=gamma)

# mu = 0 / std = 1, so the normalization leaves pixel values unchanged.
dataset = get_dataset(
    "val", mu=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], return_feature=True, return_idx=True
)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Loop-invariant values, computed once instead of on every iteration.
attr_name = "Smiling"
attr_idx = dataset.attr_names.index(attr_name)
footprint = disk(2)  # structuring element used to erode every mask


def _save_gray_image(image, out_dir, stem):
    """Render ``image`` in grayscale with no axes and save it twice.

    Writes ``<out_dir>/<stem>.png`` and then ``<out_dir>/<stem>.pdf``.
    The figure is always closed so figures do not accumulate in the loop.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(image, cmap="gray")
        ax.axis("off")
        for ext in ("png", "pdf"):
            fig.savefig(os.path.join(out_dir, f"{stem}.{ext}"), bbox_inches="tight")
    finally:
        plt.close(fig)


for _data in tqdm(dataloader):
    idx, x0, attr, feature = _data

    # .item() requires batch_size=1, which the DataLoader above guarantees.
    # NOTE(review): assumes attributes are encoded as 0/1; with a -1/1
    # encoding `not target` would never skip a sample. Confirm in get_dataset.
    target = attr[:, attr_idx].item()
    if not target:
        continue
    idx = idx.item()

    idx_repaint_dir = os.path.join(repaint_dir, str(idx))
    idx_feature_dir = os.path.join(idx_repaint_dir, "features")
    os.makedirs(idx_feature_dir, exist_ok=True)

    _save_gray_image(feature.squeeze(), idx_feature_dir, "feature")

    # Iterate the mask parameters in reverse order of get_s's output.
    for _s in s[::-1]:
        m = repaint_lib.get_mask(feature, _s)
        m_eroded = erosion(m.squeeze(), footprint)
        _save_gray_image(m_eroded, idx_feature_dir, f"mask_{_s}")


# x, attr, feature = data

# s = [0,]
# m = repaint_lib.get_mask(feature, s)

# footprint = disk(2)
# m_eroded = torch.stack([torch.tensor(erosion(_m.squeeze(), footprint)) for _m in m]).unsqueeze(1)
# print(m.size())

# data = torch.cat([x, feature.repeat(1, 3, 1, 1) / 4, m.repeat(1, 3, 1, 1), m_eroded.repeat(1, 3, 1, 1)], dim=0)

# attr_name = "Smiling"
# attr_idx = dataset.attr_names.index(attr_name)
# target = attr[:, attr_idx]
# print(target)

# _, ax = plt.subplots(figsize=(16, 9))
# im = make_grid(data, nrow=6)
# ax.imshow(im.permute(1, 2, 0))
# ax.axis("off")
# plt.show()