In [None]:
import itertools
from pathlib import Path

import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from elements.classes import ElementDataset

In [None]:
# All figures produced by this notebook are written into this folder,
# resolved relative to the kernel's current working directory.
save_dir = Path(".") / "example_datasets"
save_dir.mkdir(parents=False, exist_ok=True)

In [None]:
# Hand-picked class definitions for the "simple_small" example. Each dict names
# the shape / color / texture a sample must show for that class; None presumably
# leaves that attribute unconstrained (TODO confirm against ElementDataset).
class_configs = [
    {"shape": None, "color": None, "texture": "solid"},
    {"shape": None, "color": "red", "texture": "solid"},
    {"shape": None, "color": "blue", "texture": "stripes_diagonal"},
    {"shape": None, "color": "green", "texture": "spots_polka"},
    {"shape": "circle", "color": None, "texture": "solid"},
    {"shape": "circle", "color": None, "texture": "spots_polka"},
    {"shape": "triangle", "color": "green", "texture": None},
    {"shape": "square", "color": "blue", "texture": None},
    {"shape": "triangle", "color": "red", "texture": "stripes_diagonal"},
    {"shape": "triangle", "color": "blue", "texture": "stripes_diagonal"},
    {"shape": "square", "color": "green", "texture": "spots_polka"},
    # NOTE(review): this entry previously used "magenta", which is not in
    # `allowed_colors`, so the class could presumably never be positive.
    # "red" is a placeholder from the allowed palette — confirm the intended color.
    {"shape": "plus", "color": "red", "texture": "spots_polka"},
]

allowed_shapes = ['square', 'circle', 'triangle', 'plus']
allowed_colors = ['red', 'green', 'blue']
allowed_textures = ["solid", "spots_polka", "stripes_diagonal"]

allowed = {
    "shapes": allowed_shapes,
    "colors": allowed_colors,
    "textures": allowed_textures
}

# Fail fast if a class asks for an attribute value outside the allowed
# vocabularies (config keys are singular, `allowed` keys are plural).
_config_to_allowed_key = {"shape": "shapes", "color": "colors", "texture": "textures"}
invalid_config_values = [
    (idx, attr, value)
    for idx, config in enumerate(class_configs)
    for attr, value in config.items()
    if value is not None and value not in allowed[_config_to_allowed_key[attr]]
]
assert not invalid_config_values, (
    f"class configs use values outside `allowed`: {invalid_config_values}"
)

# NOTE(review): the positional args 1000, 224, 4, 64, 16, 42, 123 are opaque
# here — consider passing them by keyword once ElementDataset's signature is confirmed.
dataset = ElementDataset(allowed, class_configs, 1000, 224, 4, 64, 16, 42, 123)

In [None]:
# Display the dataset's `config` attribute (rendered as this cell's output).
dataset.config

In [None]:
# Preview the first 15 samples of the small dataset, titling each panel with
# the sample's `class_labels` joined as text.
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
axes = axes.flatten()
for i, ax in enumerate(axes):
    item = dataset.get_item(i)
    ax.imshow(item.img)
    ax.set_title(", ".join(str(label) for label in item.class_labels))
    ax.axis("off")
# Re-flow the grid so multi-label titles don't overlap neighbouring panels.
plt.tight_layout()
plt.savefig(save_dir / "simple_small.png")
plt.show()

In [None]:
# Iterate over the whole dataset once (DataLoader default: no shuffling, batch
# size 32) and tally the stacked label matrix: column sums give images per
# class, row sums give classes per image.
dataloader = DataLoader(dataset, 32)
labels = []
for sample in dataloader:
    # assumes sample[1] is the (batch, n_classes) multi-hot label batch — TODO confirm
    labels.append(sample[1])
labels = np.concatenate(labels)
print("No. imgs per class")
print(labels.sum(axis=0))
print("No. classes per image")
vals, counts = np.unique(labels.sum(axis=1), return_counts=True)
# ":.0f" without the sign-space flag, so entries print as "1: 30", not " 1: 30".
print(", ".join(f"{val:.0f}: {count}" for val, count in zip(vals, counts)))

In [None]:
# Attribute vocabularies for the all-combinations dataset (same values as the
# small example). Kept as lists because the next cell appends None to each.
allowed_shapes = [
    "square",
    "circle",
    "triangle",
    "plus",
]
allowed_colors = ["red", "green", "blue"]
allowed_textures = [
    "solid",
    "spots_polka",
    "stripes_diagonal",
]

In [None]:
# Every (shape, color, texture) triple, with None added as an extra option per
# attribute (presumably "any value" — TODO confirm in ElementDataset).
shape_options = allowed_shapes + [None]
color_options = allowed_colors + [None]
texture_options = allowed_textures + [None]
class_configs = list(itertools.product(shape_options, color_options, texture_options))
print(len(class_configs))
# Keep only triples with at most one None, i.e. at least two attributes fixed.
class_configs = [combo for combo in class_configs if combo.count(None) < 2]
print(len(class_configs))

In [None]:
# Turn each (shape, color, texture) tuple into the dict form passed to
# ElementDataset (same keys as the hand-written configs earlier). Entries that
# are already dicts are passed through unchanged, so re-running this cell is
# safe instead of failing with KeyError on `v[0]`.
_config_keys = ("shape", "color", "texture")
class_configs = [
    config if isinstance(config, dict) else dict(zip(_config_keys, config))
    for config in class_configs
]
class_configs

In [None]:
# Bundle the attribute vocabularies and build the all-combinations dataset.
allowed = dict(
    shapes=allowed_shapes,
    colors=allowed_colors,
    textures=allowed_textures,
)
# NOTE(review): these positional arguments are opaque here — consider passing
# them by keyword once ElementDataset's signature is confirmed.
dataset_args = (1000, 224, 4, 64, 16, 42, 123)
dataset = ElementDataset(allowed, class_configs, *dataset_args)

In [None]:
# Display the all-combinations dataset's `config` attribute (rendered as this cell's output).
dataset.config

In [None]:
# Preview the first 15 samples of the all-combinations dataset; each panel is
# titled with the number of entries in that sample's `class_labels`.
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
for idx, ax in enumerate(axes.ravel()):
    sample_item = dataset.get_item(idx)
    n_labels = len(sample_item.class_labels)
    ax.imshow(sample_item.img)
    ax.set_title(f"No. classes {n_labels}")
    ax.axis("off")
plt.tight_layout()
plt.savefig(save_dir / "simple_all.png")
plt.show()

In [None]:
# Iterate over the whole dataset once (DataLoader default: no shuffling, batch
# size 32) and tally the stacked label matrix: column sums give images per
# class, row sums give classes per image.
dataloader = DataLoader(dataset, 32)
labels = []
for sample in dataloader:
    # assumes sample[1] is the (batch, n_classes) multi-hot label batch — TODO confirm
    labels.append(sample[1])
labels = np.concatenate(labels)
print("No. imgs per class")
print(labels.sum(axis=0))
print("No. classes per image")
vals, counts = np.unique(labels.sum(axis=1), return_counts=True)
# ":.0f" without the sign-space flag, so entries print as "1: 30", not " 1: 30".
print(", ".join(f"{val:.0f}: {count}" for val, count in zip(vals, counts)))