In [None]:
import sys
import pickle
from pathlib import Path

# Add project root to sys.path to enable imports from src
sys.path.append(str(Path.cwd().parent))

import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

from src.data.synthetic_dataset import generate_synthetic_dataset

from utils import set_matplotlib_configuration

In [None]:
# Base font size handed to set_matplotlib_configuration; the figure cells below
# derive title/label sizes as small offsets from it.
FONTSIZE = 8.1
# Directory (relative to the notebook's working directory) for exported figures.
SAVE_FOLDER = Path("images/")

In [None]:
# Make sure the export directory exists before any figure cell saves into it.
SAVE_FOLDER.mkdir(exist_ok=True)
# Apply the project-wide matplotlib style. Only the second returned value (the
# keyword arguments later splatted into savefig) is kept; the first is discarded.
_, SAVEFIG_KWARGS = set_matplotlib_configuration(fontsize=FONTSIZE)

# Example images from the synthetic datasets

In [None]:
def show_row(imgs, axs):
    """Draw a row of images onto a row of axes, hiding ticks and tick labels.

    Args:
        imgs: a single image or a list/tuple of images. Each one is detached and
            converted with ``torchvision.transforms.functional.to_pil_image``
            (presumably CHW torch tensors -- TODO confirm against the dataset).
        axs: a single matplotlib Axes, or a 1-D sequence/array of Axes with at
            least ``len(imgs)`` entries.

    Raises:
        ValueError: if there are fewer axes than images.
    """
    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]
    # atleast_1d lets a lone Axes be treated like a one-element row of axes.
    axs = np.atleast_1d(axs)
    if len(axs) < len(imgs):
        raise ValueError(f"Got {len(imgs)} images but only {len(axs)} axes.")
    for img, ax in zip(imgs, axs):
        ax.imshow(np.asarray(F.to_pil_image(img.detach())))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# One row of `images_per_dataset` samples per value of p; p is passed straight
# through to generate_synthetic_dataset.
images_per_dataset = 8
P_VALUES = [0.0, 0.5, 1.0]

# squeeze=False keeps `axs` 2-D even if P_VALUES is shortened to a single row.
fig, axs = plt.subplots(len(P_VALUES), images_per_dataset, squeeze=False, figsize=(5.1, 1.9))

# NOTE(review): no seed is set here, so the sampled images change between runs.
# TODO seed the generator if generate_synthetic_dataset exposes a way to do so.
for idx, p in enumerate(P_VALUES):
    dataset = generate_synthetic_dataset(images_per_dataset, p, (1024, 1024), base_color="blue", base_shape="circle")
    show_row([d[0] for d in dataset], axs[idx])
    # Raw f-string: "\," is a LaTeX thin space, not a Python escape sequence.
    axs[idx][0].set_ylabel(rf"$p\,$=$\,{p}$", fontsize=FONTSIZE + 1.0)

fig.subplots_adjust(wspace=0.1, hspace=0.01)
fig.savefig(SAVE_FOLDER / "synthetic_examples.pdf", **SAVEFIG_KWARGS)

# Show each method's profiles in 2D

In [None]:
# Matplotlib marker for each supported shape name, plus whether that marker is
# filled. Unfilled markers (here "_") are drawn in the `c` color, and matplotlib
# warns about and ignores an explicit edgecolor for them, so the black outline is
# only requested for filled markers.
_SHAPE_MARKERS = {
    "circle": ("o", True),
    "square": ("s", True),
    "triangle": ("^", True),
    "cross": ("X", True),
    "line": ("_", False),
}


def draw_shape(ax, shape, xy, color, size=100):
    """Scatter a single point on ``ax`` using the marker associated with ``shape``.

    Args:
        ax: matplotlib Axes to draw on.
        shape: one of "circle", "square", "triangle", "cross", "line".
        xy: point coordinates; ``xy[0]`` is x and ``xy[1]`` is y.
        color: a single matplotlib color spec.
        size: marker area, passed as ``s`` to ``Axes.scatter``.

    Raises:
        ValueError: if ``shape`` is not one of the supported shape names.
    """
    try:
        marker, filled = _SHAPE_MARKERS[shape]
    except KeyError:
        raise ValueError(
            f"Unknown shape {shape!r}; expected one of {sorted(_SHAPE_MARKERS)}"
        ) from None
    outline = {"edgecolor": "black"} if filled else {}
    # c=[color] stops a single RGB(A) tuple from being read as per-point values
    # to colormap.
    ax.scatter(xy[0], xy[1], c=[color], s=size, marker=marker, linewidth=0.5, **outline)

In [None]:
# Color name -> hex code. The hex values are taken from seaborn's "colorblind"
# palette so these markers match the colors used in the other plots; they can be
# regenerated with `sns.color_palette("colorblind").as_hex()`.
color_mapping = dict(
    red="#d55e00",
    yellow="#ece133",
    blue="#0173b2",
    green="#029e73",
)

In [None]:
# (results sub-folder, display name) for every method plotted in its own panel.
all_data = [
    ("coleds_clmean", "CoLEDS"),
    ("wd_training", "WDP"),
    ("logit", "LgP"),
    ("esc_clf", "REPA"),
    ("esc_vae", "VAE-E"),
]
RESULTS_DIR = Path("../outputs/synthetic")

# One panel per method, derived from `all_data` so the two cannot drift apart.
fig, axs = plt.subplots(1, len(all_data), figsize=(6., 1.0))

# atleast_1d keeps the zip working if `all_data` has a single entry (lone Axes).
for ax, (folder, algo) in zip(np.atleast_1d(axs), all_data):
    # NOTE: pickle.load can execute arbitrary code -- only load visual.pkl files
    # produced by this project's own pipeline.
    with open(RESULTS_DIR / folder / "visual.pkl", "rb") as file:
        data = pickle.load(file)

    ax.set_title(algo, fontsize=FONTSIZE - 0.8)
    # Every point has its own coordinates, shape name and color name.
    for xy, shape, color_name in zip(data["coords"], data["shapes"], data["colors"]):
        draw_shape(ax, shape, xy, color_mapping[color_name], size=30)

    ax.set_xticklabels([])
    ax.set_yticklabels([])

fig.subplots_adjust(wspace=0.12)
fig.savefig(SAVE_FOLDER / "synthetic_2d.pdf", **SAVEFIG_KWARGS)