# Demo Using the Unwrapped Sprite Datasets

In [None]:
import matplotlib.pyplot as plt
import random

from poke_sprite_dataset.datasets.gen_v_sprites import (
    GenVUnwrappedSprites, 
    ConditionalGenVUnwrappedSprites, 
    int_to_type, 
    int_to_color
)
from torchvision.utils import make_grid

In [None]:
# Directory handed to the datasets as `data_dir`; adjust to your local setup.
DATA_DIR = '/home/kyle/projects/pokemon_data/data'

# Build the same dataset twice -- first with shiny sprites, then without --
# so the two sample counts can be compared.
sprites_shiny, sprites_no_shiny = [
    GenVUnwrappedSprites(data_dir=DATA_DIR, get_shiny=include_shiny)
    for include_shiny in (True, False)
]

print('Samples with Shiny:\t', len(sprites_shiny))
print('Samples without Shiny:\t', len(sprites_no_shiny))

In [None]:
# Draw 16 distinct random sample indices from each dataset.
shiny_sample_idx = random.sample(range(len(sprites_shiny)), 16)
no_shiny_sample_idx = random.sample(range(len(sprites_no_shiny)), 16)

# Keep at most the first four entries along the leading axis of every sample
# (presumably the RGBA channels -- confirm against the dataset) and tile the
# 16 samples into a grid with 4 images per row.
shiny_grid = make_grid(
    [sprites_shiny[i][:4, :, :] for i in shiny_sample_idx], nrow=4
)
no_shiny_grid = make_grid(
    [sprites_no_shiny[i][:4, :, :] for i in no_shiny_sample_idx], nrow=4
)

# Pass figsize to subplots() itself. A separate plt.figure(figsize=...) call
# only opens an extra, empty figure; subplots() then creates its own figure
# at the default size, so the requested size was never applied.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(32, 16))

panels = (
    (shiny_grid, 'Including Shiny'),
    (no_shiny_grid, 'No Shiny'),
)
for ax, (grid, title) in zip(axs, panels):
    # Move the leading axis to the end, the layout imshow expects for images.
    ax.imshow(grid.permute(1, 2, 0))
    ax.set_title(title)
    ax.axis('off')

plt.show()

## Conditional Dataset

In [None]:
# Conditional variant of the dataset, built from the same data directory with
# shiny sprites included.
cond_dataset = ConditionalGenVUnwrappedSprites(
    DATA_DIR,
    get_shiny=True,
)

In [None]:
# Pick 16 distinct random samples to display in a 4x4 grid of axes.
cond_idxs = random.sample(range(len(cond_dataset)), 16)

fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(16, 16))

# axs.flat walks the 4x4 axes array row by row, so the n-th sample lands on
# row n // 4, column n % 4.
for ax, idx in zip(axs.flat, cond_idxs):
    # Each item is a (sprite, metadata dict) pair.
    sprite, cond_data = cond_dataset[idx]

    # Translate the integer codes in the metadata into readable labels.
    color = int_to_color(cond_data['color'])
    types = [int_to_type(t) for t in cond_data['types']]
    title = f"{cond_data['name'].title()}. Color: {color}. Types: {types}"

    # Move the leading axis to the end before handing the sprite to imshow.
    ax.imshow(sprite.permute(1, 2, 0))
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()