In [1]:
import os
import pandas as pd
from PIL import Image, ImageSequence, GifImagePlugin
import imagehash
from io import BytesIO

import requests
from pathlib import Path

In [2]:
# Directory layout for the scraped images.
# The sub-directory names are relative to DATA_PATH; image/mask paths are
# stored in the generated CSVs relative to DATA_PATH as well.
DATA_PATH = Path("..", "data")
SPRITE_PATH, ARTWORK_PATH, MASK_PATH = Path("sprite"), Path("artwork"), Path("masks")

# CSV consumed by the dataset loader
DATALOADER_CSV_PATH = DATA_PATH.joinpath("data.csv")

In [3]:
# PokeAPI endpoint; a pokemon id is appended to it to form each request URL
api_url = "https://pokeapi.co/api/v2/pokemon/"

# Generations to pull pixel-art sprites from. The art style changes between
# gen ii and gen iii, and there are no pixel-art sprites after gen v,
# so only generations iii-v are used.
version_key = "versions"
roman_numerals = ["iii", "iv", "v"]
generations = ["generation-" + numeral for numeral in roman_numerals]

# Games of each generation, listed in the same order as `generations`
game_names = [
    ["emerald", "firered-leafgreen", "ruby-sapphire"],
    ["platinum", "diamond-pearl", "heartgold-soulsilver"],
    ["black-white"],
]
generations_to_games = dict(zip(generations, game_names))

# Where the animated sprites live inside the "versions" data
animation_gen = "generation-v"
anim_game = "black-white"
anim_key = "animated"

# Fields to extract from each sprite dict and from the official artwork
sprite_keys = ["front_default", "front_shiny", "front_female", "front_shiny_female"]
artwork_keys = ["front_default", "front_shiny"]

# Animated sprites are skipped unless this is True
save_animations = False

In [4]:
# Load the pokemon metadata table; its "id" column drives the sprite downloads
CSV_PATH = Path("../pokemon.csv")
poke_df = (
    pd.read_csv(CSV_PATH)
    # presumably an index column written by an earlier to_csv — dropped here
    .drop(columns=["Unnamed: 0"])
)

poke_df.tail()

Unnamed: 0,id,name,height,weight,ability_1,ability_3,forms_switchable,gender_rate,genera,has_gender_differences,...,flavor_text_17,ability_2,flavor_text_18,flavor_text_19,flavor_text_20,flavor_text_21,flavor_text_22,flavor_text_23,flavor_text_24,flavor_text_25
1274,10267,koraidon,35,3030,orichalcum-pulse,orichalcum-pulse,False,-1,,False,...,,,,,,,,,,
1275,10268,miraidon,28,2400,hadron-engine,hadron-engine,False,-1,,False,...,,,,,,,,,,
1276,10269,miraidon,28,2400,hadron-engine,hadron-engine,False,-1,,False,...,,,,,,,,,,
1277,10270,miraidon,28,2400,hadron-engine,hadron-engine,False,-1,,False,...,,,,,,,,,,
1278,10271,miraidon,28,2400,hadron-engine,hadron-engine,False,-1,,False,...,,,,,,,,,,


In [5]:
# Create the directories the pokemon images and masks are written to.
# mkdir(exist_ok=True) keeps the cell idempotent without a separate
# (racy) existence check, and parents=True also creates DATA_PATH itself.
for subdir in (SPRITE_PATH, ARTWORK_PATH, MASK_PATH):
    (DATA_PATH / subdir).mkdir(parents=True, exist_ok=True)


In [6]:
# NOTE: Remember to drop "RGBA" mode pngs as these are sprite representations of 3d pokemon models

def get_sprite_links(ids: list[int], base_url: str, timeout: float = 30.0):
    """Fetch the raw ``"sprites"`` payload for each pokemon id from the API.

    Args:
        ids: Pokemon ids; each one is appended to ``base_url`` to form the URL.
        base_url: URL prefix the id is appended to.
        timeout: Seconds to wait for each response before giving up on that id.

    Returns:
        Dict mapping each successfully fetched id to the ``"sprites"`` entry of
        its JSON response. Ids whose request fails (network error or a
        non-200 status) are reported with a message and omitted.
    """
    result = {}
    for poke_id in ids:
        print(f"Processing id: {poke_id}")
        try:
            # without a timeout a stalled connection would block forever
            resp = requests.get(base_url + str(poke_id), timeout=timeout)
        except requests.RequestException as err:
            # one network hiccup should not abort a run over many ids
            print(f"Retrieving id {poke_id} failed! ({err})")
            continue

        if resp.status_code == 200:
            result[poke_id] = resp.json()["sprites"]
        else:
            print(f"Retrieving id {poke_id} failed!")

    return result

def extract_pertinent_links(link_dict):
    """Reduce each raw ``sprites`` payload to the URLs this notebook uses.

    For every id, the returned dict holds, in this order:
      * "default":   the top-level sprite URLs for ``sprite_keys``
      * "artwork":   the official-artwork URLs for ``artwork_keys``
      * one entry per generation in ``generations``, mapping each game of
        ``generations_to_games`` to its sprite URLs (only the keys that game has)
      * "animation": the animated sprite URLs of ``animation_gen``/``anim_game``

    Values are copied as-is, including empty/None entries.
    """
    def pick(source, keys):
        return {k: source[k] for k in keys}

    pertinent = {}
    for poke_id, sprites in link_dict.items():
        entry = {
            "default": pick(sprites, sprite_keys),
            "artwork": pick(sprites["other"]["official-artwork"], artwork_keys),
        }

        versions = sprites["versions"]
        for gen in generations:
            gen_data = versions[gen]
            per_game = {}
            for game in generations_to_games[gen]:
                game_data = gen_data[game]
                # older games lack some keys (e.g. female variants)
                per_game[game] = {k: game_data[k] for k in sprite_keys if k in game_data}
            entry[gen] = per_game

        entry["animation"] = pick(versions[animation_gen][anim_game][anim_key], sprite_keys)
        pertinent[poke_id] = entry

    return pertinent


In [7]:
# Query the API once per pokemon id in the metadata table
poke_ids = poke_df["id"].tolist()
sprite_links = get_sprite_links(poke_ids, api_url)


Processing id: 1
Processing id: 2
Processing id: 3
Processing id: 4
Processing id: 5
Processing id: 6
Processing id: 7
Processing id: 8
Processing id: 9
Processing id: 10
Processing id: 11
Processing id: 12
Processing id: 13
Processing id: 14
Processing id: 15
Processing id: 16
Processing id: 17
Processing id: 18
Processing id: 19
Processing id: 20
Processing id: 21
Processing id: 22
Processing id: 23
Processing id: 24
Processing id: 25
Processing id: 26
Processing id: 27
Processing id: 28
Processing id: 29
Processing id: 30
Processing id: 31
Processing id: 32
Processing id: 33
Processing id: 34
Processing id: 35
Processing id: 36
Processing id: 37
Processing id: 38
Processing id: 39
Processing id: 40
Processing id: 41
Processing id: 42
Processing id: 43
Processing id: 44
Processing id: 45
Processing id: 46
Processing id: 47
Processing id: 48
Processing id: 49
Processing id: 50
Processing id: 51
Processing id: 52
Processing id: 53
Processing id: 54
Processing id: 55
Processing id: 56
P

In [8]:
def flatten_links(link_dict, id: int) -> dict[str, str]:
    """Flatten one pokemon's nested link dict into ``{name: url}``.

    Names are ``"{id}_{section}_{style}"`` for flat sections (e.g. default,
    artwork, animation) and ``"{id}_{generation}_{game}_{style}"`` for
    generation sections, whose URLs are nested one level deeper by game.
    Entries with an empty/None URL are dropped.
    """
    flat: dict[str, str] = {}
    for section, contents in link_dict.items():
        if section in generations:
            # generation sections are keyed by game before style
            groups = ((f"{section}_{game}", contents[game])
                      for game in generations_to_games[section])
        else:
            groups = ((section, contents),)

        for prefix, styles in groups:
            for style, url in styles.items():
                if url:
                    flat[f"{id}_{prefix}_{style}"] = url
    return flat


def gif_to_images(anim):
    """Return every frame of an animated image as an independent RGBA image."""
    frames = []
    for frame in ImageSequence.Iterator(anim):
        # copy so each frame survives the iterator seeking to the next one
        frames.append(frame.convert("RGBA").copy())
    return frames


def create_mask(im: Image.Image) -> Image.Image:
    """Return an RGBA silhouette of ``im``.

    Every pixel with non-zero alpha becomes opaque black ``(0, 0, 0, 255)``;
    fully transparent pixels are left unchanged. ``im`` itself is not modified.
    """
    mask = im.copy()
    if mask.mode != "RGBA":
        mask = mask.convert("RGBA")

    # Binary 0/255 stencil from the alpha band: 255 wherever the pixel is
    # not fully transparent. Filling black through it runs in Pillow's C code
    # instead of a per-pixel Python loop.
    opaque = mask.getchannel("A").point(lambda a: 255 if a > 0 else 0)
    mask.paste((0, 0, 0, 255), (0, 0, mask.width, mask.height), opaque)

    return mask


# def create_anim_frame_mask(frame: Image):
#     mask = frame.copy()

#     for x in range(mask.width):
#         for y in range(mask.height):
#             if 


def add_background(im: Image.Image, target_size: tuple[int, int] = (96, 96)):
    """Center ``im`` on a white RGB canvas of ``target_size``.

    If ``im`` has an alpha band, it is used as the paste mask, so transparent
    pixels show the white background. Images without alpha are pasted opaquely.

    Args:
        im: Image no larger than ``target_size`` in either dimension.
        target_size: (width, height) of the output canvas.

    Returns:
        The new RGB image, or None (after printing a warning) if ``im`` is
        larger than ``target_size``.
    """
    target_w, target_h = target_size
    width, height = im.size
    if width > target_w or height > target_h:
        print("WOOPS! input image larger than target size!")
        return None

    bg = Image.new("RGB", target_size, (255, 255, 255))
    # Floor-divided centering offset; always >= 0 here, so nothing is cropped
    # (a same-size image lands exactly at (0, 0)).
    offset = ((target_w - width) // 2, (target_h - height) // 2)
    paste_mask = im if "A" in im.getbands() else None
    bg.paste(im, offset, paste_mask)
    return bg


def get_image(url: str, timeout: float = 30.0) -> "Image.Image | None":
    """Download ``url`` and decode it with Pillow.

    Returns None instead of raising when the request fails, the server does
    not answer with status 200, or the payload cannot be decoded (some sprite
    PNGs have malformed identification strings that Pillow rejects).
    """
    try:
        resp = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if resp.status_code != 200:
        return None

    try:
        image = Image.open(BytesIO(resp.content))
        # Image.open is lazy; decode the (first) frame now so truncated or
        # corrupt data is caught here rather than later in the caller.
        image.load()
    except OSError:  # includes PIL.UnidentifiedImageError
        return None

    return image


def _save_artwork(key: str, url: str) -> None:
    """Download one official-artwork image and save it, on white, under DATA_PATH / ARTWORK_PATH."""
    im = get_image(url)
    if im is None:
        print(f"Could not load {key}, skipping")
        return
    im = im.convert("RGBA")
    add_background(im, im.size).save(DATA_PATH / ARTWORK_PATH / Path(f"{key}.png"))


def _animation_entries(key: str, url: str, used_sprites: set) -> list:
    """Return (frame, mask, frame_path, mask_path) for each new frame of an animated sprite."""
    gif = get_image(url)
    if gif is None:
        return []

    entries = []
    for ix, frame in enumerate(gif_to_images(gif)):
        # skip frames whose perceptual hash matches a sprite already kept for this id
        im_hash = str(imagehash.phash(frame))
        if im_hash in used_sprites:
            continue
        used_sprites.add(im_hash)

        mask = create_mask(frame)
        frame_bg, mask_bg = add_background(frame), add_background(mask)
        if frame_bg is None or mask_bg is None:
            # frame does not fit on the canvas
            continue
        entries.append((
            frame_bg,
            mask_bg,
            SPRITE_PATH / Path(f"{key}_{ix}.png"),
            MASK_PATH / Path(f"{key}_{ix}_mask.png"),
        ))
    return entries


def _sprite_entries(key: str, url: str, used_sprites: set, mask_paths: dict) -> list:
    """Return a one-item list (image, mask, image_path, mask_path) for a static sprite, or [] if unusable/duplicate."""
    # get_image returns None for some pngs with malformed identification strings
    im = get_image(url)
    if im is None:
        return []
    im = im.convert("RGBA")

    mask = create_mask(im)
    im, mask = add_background(im), add_background(mask)
    if im is None or mask is None:
        return []

    # keep only sprites with a perceptual hash not seen yet for this id
    im_hash = str(imagehash.phash(im))
    if im_hash in used_sprites:
        return []
    used_sprites.add(im_hash)

    # sprites whose masks hash identically (e.g. normal vs. shiny) share one mask path
    mask_path = mask_paths.setdefault(str(imagehash.phash(mask)), MASK_PATH / Path(f"{key}_mask.png"))
    return [(im, mask, SPRITE_PATH / Path(f"{key}.png"), mask_path)]


def process_and_save_images(links):
    """Download, deduplicate and save every image referenced in ``links``.

    Args:
        links: ``{pokemon_id: {name: url}}`` as produced by ``flatten_links``.
            Names containing "artwork" are saved to DATA_PATH / ARTWORK_PATH
            and not recorded in the result; names containing "animation" are
            split into frames (only when ``save_animations`` is True);
            everything else is treated as a static sprite.

    Returns:
        DataFrame with one row per saved sprite/frame, with columns "image" and
        "mask" (POSIX paths relative to DATA_PATH) and "id".
    """
    row_list = []

    for poke_id, data in links.items():
        print(f"Processing id: {poke_id}")
        used_sprites = set()  # phashes of sprites kept for this id
        mask_paths = {}       # mask phash -> mask path shared by matching sprites
        saved_masks = set()   # mask paths already written to disk

        for key, url in data.items():
            if "artwork" in key:
                _save_artwork(key, url)
                continue
            if "animation" in key:
                if not save_animations:
                    continue
                entries = _animation_entries(key, url, used_sprites)
            else:
                entries = _sprite_entries(key, url, used_sprites, mask_paths)

            for im, mask, im_path, mask_path in entries:
                im.save(DATA_PATH / im_path)
                # write a shared mask only once instead of overwriting it per sprite
                if mask_path not in saved_masks:
                    mask.convert("L").save(DATA_PATH / mask_path)
                    saved_masks.add(mask_path)
                row_list.append({"image": im_path.as_posix(), "mask": mask_path.as_posix(), "id": poke_id})

    return pd.DataFrame(row_list)
        


In [9]:
# Reduce the raw API payloads to the URLs we need, flatten them per pokemon,
# then download/process every image and collect the rows for the dataset CSV.
links = extract_pertinent_links(sprite_links)
flattened_links = {
    poke_id: flatten_links(poke_links, poke_id)
    for poke_id, poke_links in links.items()
}
image_df = process_and_save_images(flattened_links)


Processing id: 1
Processing id: 2
Processing id: 3
Processing id: 4
Processing id: 5
Processing id: 6
Processing id: 7
Processing id: 8
Processing id: 9
Processing id: 10
Processing id: 11
Processing id: 12
Processing id: 13
Processing id: 14
Processing id: 15
Processing id: 16
Processing id: 17
Processing id: 18
Processing id: 19
Processing id: 20
Processing id: 21
Processing id: 22
Processing id: 23
Processing id: 24
Processing id: 25
Processing id: 26
Processing id: 27
Processing id: 28
Processing id: 29
Processing id: 30
Processing id: 31
Processing id: 32
Processing id: 33
Processing id: 34
Processing id: 35
Processing id: 36
Processing id: 37
Processing id: 38
Processing id: 39
Processing id: 40
Processing id: 41
Processing id: 42
Processing id: 43
Processing id: 44
Processing id: 45
Processing id: 46
Processing id: 47
Processing id: 48
Processing id: 49
Processing id: 50
Processing id: 51
Processing id: 52
Processing id: 53
Processing id: 54
Processing id: 55
Processing id: 56
P

In [10]:
(len(image_df), len(image_df.columns))  # (rows, columns) of the image/mask table

(5001, 3)

In [11]:
# Attach per-pokemon metadata to every image row for the PyTorch dataset.
# validate="many_to_one" raises if an id appears more than once in the
# metadata, which would otherwise silently duplicate image rows.
cols_to_keep = ["id", "type_1", "type_2", "is_legendary", "is_mythical", "generation", "egg_group_0", "egg_group_1"]
filtered_poke_df = poke_df[cols_to_keep]
final_df = image_df.merge(filtered_poke_df, on="id", how="left", validate="many_to_one")
assert len(final_df) == len(image_df), "merge changed the number of image rows"
final_df.shape

(5001, 10)

In [12]:
# Persist the image/mask/metadata table for the dataset loader (without the index)
final_df.to_csv(path_or_buf=DATALOADER_CSV_PATH, index=False)

In [13]:
# CSV for the mask dataset: several sprites can share one mask file,
# so each mask path is listed only once.
mask_csv_path = "masks.csv"
mask_df = final_df.loc[:, ["mask"]].drop_duplicates()
print(f"Mask df shape: {mask_df.shape}")
mask_df.to_csv(DATA_PATH / mask_csv_path, index=False)


Mask df shape: (2784, 1)
