In [1]:
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

from ipywidgets import interact, IntText

from poke_sprite_dataset.datasets.full_art_dataset import full_art_dataset
from poke_sprite_dataset.datasets.helpers import unwrap_embedding

# Unconditional Full-Art Dataset

In [2]:
dataset = full_art_dataset(
    '/home/kyle/projects/pokemon_data/data',
)

# Convert each image to a tensor, then bicubic-resize it to a fixed 250x250.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        (250, 250),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
])

def transform(examples):
    """Map a batch of raw examples to preprocessed image tensors.

    Each entry of ``examples["image"]`` is converted to RGBA (presumably a PIL
    image, given the ``.convert`` call) and then passed through ``preprocess``.
    Returns a dict with a single ``"images"`` list.
    """
    rgba_images = (raw.convert("RGBA") for raw in examples["image"])
    return {"images": list(map(preprocess, rgba_images))}

dataset.set_transform(transform)

In [3]:
def show_image(idx: int):
    """Display the preprocessed image stored at position ``idx`` of ``dataset``.

    An out-of-range index prints a message and falls back to index 0. The
    figure title shows the index that was actually rendered. If the dataset is
    empty, nothing is plotted.
    """
    idx = int(idx)
    n_examples = len(dataset)
    if n_examples == 0:
        print("Dataset is empty")
        return
    if not 0 <= idx < n_examples:
        print(f"Index {idx} is out of bounds")
        idx = 0
    image = dataset[idx]["images"]
    # Bicubic resizing of a float tensor can overshoot [0, 1]. Clamp so imshow
    # neither clips with a warning nor distorts the alpha channel.
    # permute moves the channel axis from first to last, as matplotlib expects.
    plt.imshow(image.clamp(0, 1).permute(1, 2, 0))
    plt.title(f"Index {idx}")
    plt.axis('off')
    plt.show()

print(f"Index min 0, max {len(dataset) - 1}")
interact(show_image, idx=IntText(value=0, description='Index:'));

interactive(children=(IntText(value=0, description='Index:'), Output()), _dom_classes=('widget-interact',))

# Conditional Full-Art Dataset

In [4]:
dataset = full_art_dataset(
    '/home/kyle/projects/pokemon_data/data',
    conditional=True,
)

# Same preprocessing as the unconditional cell: tensor conversion, then a
# bicubic resize to a fixed 250x250.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(
        (250, 250),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
])

def transform(examples):
    """Map a batch of conditional examples to preprocessed images plus labels.

    Each entry of ``examples["image"]`` is RGBA-converted and run through
    ``preprocess``. ``examples["condition"]`` is passed through unchanged.
    """
    batch = {"images": [preprocess(raw.convert("RGBA")) for raw in examples["image"]]}
    batch["condition"] = examples["condition"]
    return batch

dataset.set_transform(transform)

In [5]:
def show_image(idx: int):
    """Display the image at position ``idx`` of ``dataset`` and print its condition.

    Prints the raw ``condition`` value and its ``unwrap_embedding`` form, then
    plots the image. An out-of-range index prints a message and falls back to
    index 0. If the dataset is empty, nothing is shown.
    """
    idx = int(idx)
    n_examples = len(dataset)
    if n_examples == 0:
        print("Dataset is empty")
        return
    if not 0 <= idx < n_examples:
        print(f"Index {idx} is out of bounds")
        idx = 0
    # Index once. Transforms registered via set_transform run on access
    # (assuming a Hugging Face `datasets.Dataset`), so each dataset[idx] call
    # would re-decode and re-resize the image.
    example = dataset[idx]
    image = example["images"]
    label = example["condition"]
    print(label)
    print(unwrap_embedding(label))
    # Bicubic resizing of a float tensor can overshoot [0, 1]. Clamp so imshow
    # does not clip with a warning. permute moves channels last for matplotlib.
    plt.imshow(image.clamp(0, 1).permute(1, 2, 0))
    plt.title(f"Index {idx}")
    plt.axis('off')
    plt.show()

print(f"Index min 0, max {len(dataset) - 1}")
interact(show_image, idx=IntText(value=0, description='Index:'));

Index min 0, max 1011


interactive(children=(IntText(value=0, description='Index:'), Output()), _dom_classes=('widget-interact',))