# Data Wrangler for Pokemon Identifier Project

Note: Place any additionally gathered images in the `Tmp` directory (`../Tmp/`) so that they are picked up for processing.

In [None]:
import os
import string
import csv
import re
import requests
import shutil
import random
import tensorflow as tf
import numpy as np
from bs4 import BeautifulSoup
from PIL import Image

## Helper Functions

### Global Values

In [None]:
# ---- Web scraping ---------------------------------------------------------
# When True, the listings in TARGETURLS are crawled and every image found is
# downloaded.
gatherFromWeb = False

# Root pages whose listings are crawled for images.
TARGETURLS = ['https://play.pokemonshowdown.com/sprites/']

# Filled by the crawler with the URL of each image it discovers.
listOfImageURLs = []

# ---- GIF handling ---------------------------------------------------------
# Upper bound on the number of frames extracted from each GIF
# (0 means every frame is extracted).
numFramesExtractGif = 0

# ---- Dataset selection ----------------------------------------------------
# Pokemon generations that are prepared for the final dataset.
generationsToPrepare = list(range(1, 9))

# Fraction (0 to 1) of the images in the scraped directory used for training.
percentToUseForTrain = 0.9

# ---- Directories ----------------------------------------------------------
# Where images are placed before being processed.
tempDirectory = '../Tmp/'

# Root directory of the image datasets.
coreImageDir = "../Datasets/Images/"

# Where processed (scraped) images are placed, one folder per Pokemon.
gatherDirectory = '../Datasets/Images/TmpScraped/'

# Main neural-net data directory and its image folder.
mainInfoDirectory = '../Datasets/Main/'
DIR_MODEL_IMAGES = '../Datasets/Main/Images/'

# ---- Pipeline switches ----------------------------------------------------
# NOTE(review): REPROCESS_GIFS is not read anywhere in the visible cells.
REPROCESS_GIFS = False
# Gates sorting of tempDirectory and copying of verified images.
PROCESS_TEMP_IMAGES = False
# When False, mega / gigantamax / gmax images are rejected during verification.
INCLUDE_ALT_FORMS = False
# Gates creation of the small per-Pokemon test image set.
PREPARE_TEST_DS = True

#### Make directories that will be needed

In [None]:
# Create every working directory the notebook writes to.
# os.makedirs(..., exist_ok=True) replaces the isdir()/mkdir() pairs: it also
# creates missing parent directories and does not fail when the directory
# already exists.
for requiredDir in (
    coreImageDir,
    gatherDirectory,
    tempDirectory,
    mainInfoDirectory,
    os.path.join(mainInfoDirectory, 'Images'),
):
    os.makedirs(requiredDir, exist_ok=True)

### Regex Helpers

In [None]:
# Matches any text containing a forward slash (directory links in a listing).
compiledRE_forwardSlash = re.compile(r'/')
# Match names ending in the literal extension. The dot is escaped so that it
# no longer matches an arbitrary character (e.g. "foogif" is not a GIF name).
compiledRE_gif = re.compile(r'\.gif$')
compiledRE_png = re.compile(r'\.png$')
# Matches a single special character: ! @ # $ or '
compiledRE_special = re.compile(r"[!@#$']")

### String Helpers

In [None]:
def extractFileNameFromPath(path: str, removeExtension: bool) -> str:
    """Return the last '/'-separated component of ``path``.

    Args:
        path: File path or URL that uses '/' as its separator.
        removeExtension: When True, everything from the last '.' of the file
            name onwards is dropped.

    Returns:
        The file name. A name without any '.' is returned unchanged even when
        ``removeExtension`` is True.
    """
    fullName = path[path.rfind('/') + 1:]
    if not removeExtension:
        return fullName

    extensionBeginIndex = fullName.rfind('.')
    if extensionBeginIndex == -1:
        # No extension present: slicing with -1 would chop the last character.
        return fullName
    return fullName[:extensionBeginIndex]

In [None]:
def removeFileNameFromPath(path: str) -> str:
    """Return ``path`` without its last '/'-separated component.

    Args:
        path: File path or URL that uses '/' as its separator.

    Returns:
        Everything before the last '/', or an empty string when ``path``
        contains no '/' at all (a bare file name has no directory part).
    """
    nameBeginIndex = path.rfind('/')
    if nameBeginIndex == -1:
        # Slicing with -1 would return the name minus its last character.
        return ''
    return path[:nameBeginIndex]

In [None]:
def generateScrapedPath(file: string):
    """Return the scraped-image folder for the Pokemon named in ``file``.

    The name is resolved through ``compiledData.getProperPokemonName``; when
    no Pokemon is recognised, ``False`` is returned instead of a path.
    """
    pokemonName = compiledData.getProperPokemonName(file)
    if pokemonName is False:
        return False
    return os.path.join(gatherDirectory, pokemonName)

### Class

In [None]:
class DataWrangler:
    """Pokedex lookup table built from CSV datasets.

    Three parallel lists are kept, index-aligned with each other:
    ``uniqueDexIDs`` (Pokedex numbers), ``uniqueDexNames`` (names with every
    '.' removed) and ``pokemonGenerations``.
    """

    # Each entry: (csv path, ID column, name column, generation column).
    completeDatasets = [('../Datasets/GeneralData/TheCompletePokemonDataset/pokemon.csv',32, 30, 39), 
                        ('../Datasets/GeneralData/UpdatedCompletePokemonDataset/pokedex_(Update_04.21).csv', 1, 2, 5)]
    imageWebLocations = [
        ''
    ]

    # Characters ignored by the relaxed name comparison (punctuation + space).
    _namePunctuation = re.compile(r"[!@#$'._ ]")

    def __init__(self):
        """Load every dataset listed in ``completeDatasets``."""
        self.uniqueDexIDs = []
        self.uniqueDexNames = []
        self.pokemonGenerations = []
        self.pokeDictionary = {}
        for file,col_id,col_name,col_gen in DataWrangler.completeDatasets:
            self.populateDataFromFile(file, col_id, col_name, col_gen)

    def populateDataFromFile(self, filePath: str, col_id, col_name, col_gen):
        """Append the Pokemon found in one CSV file to the lookup lists.

        Args:
            filePath: Path of a comma-separated file whose first row is a header.
            col_id: Index of the column holding the Pokedex number.
            col_name: Index of the column holding the Pokemon name.
            col_gen: Index of the column holding the generation number.

        Rows whose Pokedex number is already known are skipped, so the first
        row seen for a number wins. Blank rows are ignored.
        """
        # Set mirror of the list: O(1) duplicate checks without changing the
        # public list attribute.
        seenIDs = set(self.uniqueDexIDs)

        # newline='' is the csv-module recommendation for files given to reader().
        with open(filePath, encoding="utf8", newline='') as file:
            csv_reader = csv.reader(file, delimiter=',')

            # TODO: figure out from the header which columns hold the name and ID.
            next(csv_reader, None)

            for line in csv_reader:
                if not line:
                    continue
                tempID = int(line[col_id])
                if tempID in seenIDs:
                    continue
                seenIDs.add(tempID)
                self.uniqueDexIDs.append(tempID)
                # Every '.' is removed so the name never yields a directory
                # name with a trailing '.'.
                self.uniqueDexNames.append(line[col_name].replace('.', ''))
                self.pokemonGenerations.append(int(line[col_gen]))

    def getProperPokemonName(self, inString: str):
        """Pick the Pokemon that a file name or path most likely refers to.

        A Pokedex name is a candidate when it appears in the lower-cased input
        either literally or after punctuation and spaces are ignored on both
        sides. When several names qualify (e.g. 'natu' inside 'eternatus') the
        longest one wins; ties go to the earlier Pokedex entry.

        Args:
            inString: File name or path to inspect (linear scan of the Pokedex).

        Returns:
            The lower-cased Pokedex name, or ``False`` when nothing matches.
        """
        searchString = inString.lower()
        # Built once, and from the lower-cased input so it is comparable with
        # the lower-cased names.
        cleanPathName = self._namePunctuation.sub('', searchString)

        potentialMatches = []
        for name in self.uniqueDexNames:
            currName = name.lower()
            cleanCurrName = self._namePunctuation.sub('', currName)
            if (currName in searchString
                    or cleanCurrName in searchString
                    or cleanCurrName in cleanPathName):
                potentialMatches.append(currName)

        if not potentialMatches:
            return False
        # max() returns the first of equally long names, i.e. Pokedex order.
        return max(potentialMatches, key=len)

    def getPokemonGeneration(self, pokemonName: str) -> int:
        """Return the generation of ``pokemonName`` (case-insensitive).

        Returns ``False`` when the name is unknown or is not a string, e.g. the
        ``False`` sentinel produced by ``getProperPokemonName``.
        """
        if not isinstance(pokemonName, str):
            return False
        wantedName = pokemonName.lower()
        for name, generation in zip(self.uniqueDexNames, self.pokemonGenerations):
            if name.lower() == wantedName:
                return int(generation)
        return False

# Module-level lookup instance used by the helper functions in this notebook.
# Constructing it reads every CSV listed in DataWrangler.completeDatasets.
compiledData = DataWrangler()

### Image Helpers

In [None]:
def analyseImage(path):
    """Inspect an image file and report how its frames are stored.

    Every frame is visited in order. As soon as a frame's first tile has an
    update region whose last two values differ from the image size, the image
    is classified as 'partial' and scanning stops; otherwise it is 'full'.

    Args:
        path: Path of the image to inspect.

    Returns:
        dict with keys 'size' (the image size) and 'mode' ('full' or 'partial').
    """
    # The context manager closes the file handle, which was previously leaked.
    with Image.open(path) as im:
        results = {
            'size' : im.size,
            'mode' : 'full'}
        try:
            while True:
                if im.tile:
                    update_region = im.tile[0][1]
                    if update_region[2:] != im.size:
                        results['mode'] = 'partial'
                        break
                im.seek(im.tell() + 1)
        except EOFError:
            # seek() past the last frame: every frame has been checked.
            pass
    return results

In [None]:
#split a given gif into separate images -- will return paths to all new files
def gifToImages(pathToGif: str, destinationPath: str):
    """Extract frames of a GIF as RGB PNG files inside ``destinationPath``.

    Frame selection: when ``numFramesExtractGif`` is non-zero and smaller than
    the GIF's frame count, that many evenly spaced frame indices (np.linspace
    over 0..n_frames-1, truncated to integers) are extracted; otherwise every
    frame is. Each selected frame is pasted onto a black RGBA canvas, using
    the frame's own RGBA conversion as the paste mask, and saved as
    ``<gif name without extension>-<frame index>.png``.

    Returns:
        * ``None`` (implicitly) when ``pathToGif`` is not an existing file;
        * ``[]`` when ``analyseImage`` raises;
        * otherwise the list of output paths. A path is appended *before* the
          frame is saved, so after a failure (reported only as 'err') the last
          entry may name a file that was never written.

    NOTE(review): ``mode`` holds the dict returned by ``analyseImage``, so the
    ``mode == 'partial'`` comparison below can never be true, and
    ``last_frame`` is never assigned anywhere in this function. The partial
    frame compositing branch is therefore dead code as written.
    NOTE(review): ``trans_color`` and ``isFirstFrame`` are assigned but unused.
    """

    #only process paths that point at an existing file
    if (os.path.isfile(pathToGif)):
        createdFilePaths = []
        
        trans_color = (255, 255, 255)
        try:
            mode = analyseImage(pathToGif)
        except:
            # NOTE(review): bare except -- any failure while analysing skips the gif
            print(f"Failed to check {pathToGif} skipping")
            return []
        
        try:
            with Image.open(pathToGif) as openGif:
                numFrames = openGif.n_frames
                numToExtract = 0

                #check if the number of frames in a given gif is more than the max number defined to get
                if numFramesExtractGif != 0 and numFrames > numFramesExtractGif:
                    numToExtract = numFramesExtractGif
                else:
                    numToExtract = numFrames

                #evenly spaced frame positions from the first to the last frame
                framesToGet = np.linspace(0, openGif.n_frames - 1, numToExtract)
                isFirstFrame = True
                #palette of the first frame, reused for frames that report none
                palette = openGif.getpalette()

                for frameNumber in framesToGet.astype(np.int64):
                    #(an earlier, commented-out per-pixel transparency replacement was removed here)
                    openGif.seek(frameNumber)

                    if not openGif.getpalette():
                        openGif.putpalette(palette)
                
                    #black canvas: pixels left uncovered by the masked paste below stay black
                    new_frame = Image.new('RGBA', openGif.size, "BLACK")
                    
                    #NOTE(review): never true, see docstring (mode is a dict; last_frame is undefined)
                    if mode == 'partial':
                        new_frame.paste(last_frame)

                    new_frame.paste(openGif, (0,0), openGif.convert('RGBA'))
                    new_frame.n_frames = 1
                    fileName = f'{extractFileNameFromPath(pathToGif, True)}-{frameNumber}.png'
                    finalFullPath = os.path.join(destinationPath, fileName)
                    createdFilePaths.append(finalFullPath)
                    new_frame.convert('RGB').save(finalFullPath)    
        except:
            # NOTE(review): bare except hides the cause; frames saved so far are still returned
            print('err')
        return createdFilePaths

# gifToImages('./Tmp/abomasnow-mega.gif', generateScrapedPath('./Tmp/abomasnow-mega.gif'))

In [None]:
#apply any formatting that is needed for the given image and place into correct directory
def processImage(pathToImage: str, isWebPath: bool, overrideDestinationPath: str=None):
    """Normalise one image and store the result in the destination directory.

    * Paths matching ``compiledRE_png`` are flattened onto a black background
      (transparent pixels become black) and saved as RGB under the same name.
    * Paths matching ``compiledRE_gif`` are handed to ``gifToImages``.
    * Any other path is ignored.

    Args:
        pathToImage: Path of the source image.
        isWebPath: Accepted for compatibility; not used by this function.
        overrideDestinationPath: Target directory; ``gatherDirectory`` when None.

    The destination directory (including missing parents) is created for both
    the PNG and the GIF branch. A PNG that Pillow cannot identify is reported
    and skipped.
    """
    if overrideDestinationPath is None:
        destinationPath = gatherDirectory
    else:
        destinationPath = overrideDestinationPath

    if compiledRE_png.search(pathToImage):
        os.makedirs(destinationPath, exist_ok=True)
        try:
            with Image.open(pathToImage) as image:
                #replace alpha channel with black: paste onto a black canvas
                #using the image's own RGBA conversion as the mask
                new_frame = Image.new('RGBA', image.size, "BLACK")
                new_frame.paste(image, (0,0), image.convert('RGBA'))

                fileName = os.path.basename(pathToImage)
                finalPath = os.path.join(destinationPath, fileName)
                new_frame.convert('RGB').save(finalPath)
        except Image.UnidentifiedImageError:
            print(f"Unable to ID image {pathToImage}")

    elif compiledRE_gif.search(pathToImage):
        #gifToImages writes straight into destinationPath, so it must exist
        os.makedirs(destinationPath, exist_ok=True)
        gifToImages(pathToImage, destinationPath)

# processImage('./Datasets/Images/1300-big-front-gifs/001-bulbasaur-s.gif', False)

### Web Helpers

In [None]:
def downloadImage(imageURL: str):
    """Download one image into ``tempDirectory`` unless it is already there.

    The local file name is ``<parent directory on the server>--<file name>``,
    so images with the same name from different remote folders do not
    overwrite each other.

    Failed requests (network error, timeout, non-success HTTP status) are
    reported and skipped so that an error page is never stored as an image
    and one bad URL does not abort a long download run. Re-running picks up
    the files that are still missing.
    """
    #create filename for new file - get file name from URL along with parent directory on remote server (combine)
    nameBeginIndex = imageURL.rfind('/')
    pathWithoutName = imageURL[:nameBeginIndex]
    extendedDirIndex = pathWithoutName.rfind('/')
    fileName = pathWithoutName[extendedDirIndex+1:] + '--' + imageURL[nameBeginIndex+1:]

    fullNewFilePath = os.path.join(tempDirectory, fileName)

    if os.path.isfile(fullNewFilePath):
        return

    try:
        #timeout keeps a stalled connection from hanging the whole run
        read = requests.get(imageURL, timeout=60)
    except requests.RequestException as err:
        print(f"Failed to download {imageURL}: {err}")
        return

    if not read.ok:
        print(f"Failed to download {imageURL}: HTTP {read.status_code}")
        return

    with open(fullNewFilePath, 'wb') as f:
        f.write(read.content)
            
# downloadImage('https://play.pokemonshowdown.com/sprites/ani-back/ferroseed.gif')

### Supporting methods for image search

In [None]:
#recursively search through a provided URL to find pngs and gifs
def browseForImages(currRoot, _visited=None):
    """Crawl a directory listing and collect image URLs in ``listOfImageURLs``.

    Anchors whose text matches ``compiledRE_png`` / ``compiledRE_gif`` are
    recorded as ``currRoot + text``; anchors whose text contains '/' are
    treated as sub-directories and crawled recursively.

    Args:
        currRoot: URL of the listing page; link text is appended to it as-is.
        _visited: Internal set of URLs already crawled (callers omit it).

    URLs containing 'afd' are skipped. Pages that cannot be fetched are
    reported and skipped. Parent links ('..') and already visited URLs are not
    followed, so the recursion cannot loop.
    """
    if _visited is None:
        _visited = set()

    #avoid april fools day images on pokemon showdown
    if "afd" in currRoot or currRoot in _visited:
        return
    _visited.add(currRoot)

    try:
        page = requests.get(currRoot, timeout=60)
    except requests.RequestException as err:
        print(f"Failed to browse {currRoot}: {err}")
        return
    if not page.ok:
        print(f"Failed to browse {currRoot}: HTTP {page.status_code}")
        return

    soup = BeautifulSoup(page.content, "html.parser")

    #record pngs first, then gifs
    for pattern in (compiledRE_png, compiledRE_gif):
        for image in soup.find_all("a", text=pattern):
            listOfImageURLs.append(currRoot + image.text)

    #navigate through all of the possible directories
    for each in soup.find_all("a", text=compiledRE_forwardSlash):
        if each.text.startswith('..'):
            #never climb back up to the parent listing
            continue
        browseForImages(currRoot + each.text, _visited)

## Gather data from internet resources

#### Gather image paths into list and then download images as needed

In [None]:
# Crawl and download only when scraping is switched on via gatherFromWeb.
if gatherFromWeb is True:
    #gather target URLs for images -- browseForImages fills listOfImageURLs
    for target in TARGETURLS:
        browseForImages(target)

    #go through and download images as needed (existing files are not fetched again)
    for url in listOfImageURLs: 
        downloadImage(url)

Note: this crawl-and-download step took 130 minutes to complete.

## Sort data gathered into useable dataset for testing and training

In [None]:
def searchForFiles(currentDir):
    """Walk ``currentDir`` recursively and process every file found.

    A path that is not a directory is treated as an image file: its Pokemon
    name is resolved, and when that Pokemon's generation is listed in
    ``generationsToPrepare`` the image is processed into
    ``gatherDirectory/<pokemon name>`` (created on demand).
    """
    if not os.path.isdir(currentDir):
        #leaf: this has to be a file -- copy to core dataset
        pokemonName = compiledData.getProperPokemonName(currentDir)
        if pokemonName is False:
            return
        if compiledData.getPokemonGeneration(pokemonName) not in generationsToPrepare:
            return

        datasetPath = os.path.join(gatherDirectory, pokemonName)
        if not os.path.isdir(datasetPath):
            os.mkdir(datasetPath)
        processImage(currentDir, False, datasetPath)
        return

    #directory: descend into every entry
    for entry in os.listdir(currentDir):
        searchForFiles(os.path.join(currentDir, entry))

# Sort everything below tempDirectory into per-Pokemon folders when enabled.
if PROCESS_TEMP_IMAGES is True:
    searchForFiles(tempDirectory)

## Verify Data

## Test Images to ensure proper format

In [None]:
def testFile(filePath, printErrors=True) -> bool:
    if os.path.getsize(filePath) == 0 or os.path.isdir(filePath):
        if printErrors:
            print(filePath + " is zero length or is directory, ignoring")
        return False
    elif "afd" in filePath:
        if printErrors:
            print(filePath + " this is garbage file, removing")
        return False
    elif "digimon" in filePath:
        if printErrors:
            print(f"{filePath} is a digimon, ignoring")
        return False
    elif "meganium" not in filePath and "yanmega" not in filePath:
        if INCLUDE_ALT_FORMS is False and ("mega" in filePath or "gigantamax" in filePath or "gmax" in filePath): 
            if printErrors:
                print(f"{filePath} is alt form, ignoring")
            return False
    else:
        #attempt to open file to confirm that it is a valid file
        tmp = Image.open(filePath)
        tmp.load()
        if tmp.format != 'PNG':
            if printErrors:
                print(file + " is not correct format, ignoring")
            return False
        if tmp.n_frames > 1:
            if printErrors:
                print(file + "too many frames, ignoring")
        tmp.close()

        #ensure all images are encoded in the correct format 
        with open(filePath, 'rb') as imageFile:
            if imageFile.read().startswith(b'RIFF'):
                if printErrors:
                    print(file + " isnt right type, ignoring")
                return False
    return True

# testFile('../Datasets/Main/Images/Train/gyarados/pokemon--gyarados.png')

In [None]:
# # listOfPokemonDirs = os.listdir('../Datasets/Main/Images/Train/')
# listOfAllImages = list(paths.list_images(gatherDirectory))
# verifiedFiles = []

# for file in listOfAllImages: 
#     pokemonName =  compiledData.getProperPokemonName(file)
#     gen = compiledData.getPokemonGeneration(pokemonName)
#     if gen in generationsToPrepare: 
#         if testFile(file) is True: 
#             verifiedFiles.append(file)

# #copy verified files to core model directory
# for file in verifiedFiles: 
#     pokemonName = compiledData.getProperPokemonName(file)
#     finalPokemonDir = os.path.join(DIR_MODE_IMAGES, pokemonName)
#     if os.path.isdir(finalPokemonDir) is False: 
#         os.mkdir(finalPokemonDir)
#     shutil.copy2(file, finalPokemonDir)

In [None]:
# Copy the verified images of every targeted Pokemon from the scraped
# directory into the model image directory (one folder per Pokemon).
if PROCESS_TEMP_IMAGES is True:
    listOfPokemonDirs = os.listdir(gatherDirectory)
    for pokemonDir in listOfPokemonDirs:
        #check if the pokemon is in the generation of targeted pokemon
        pokemonName = compiledData.getProperPokemonName(pokemonDir)
        if pokemonName is False:
            #folder name matches no known pokemon; nothing to look up
            continue
        gen = compiledData.getPokemonGeneration(pokemonName)
        if gen not in generationsToPrepare:
            continue

        pathPokemonDir = os.path.join(gatherDirectory, pokemonDir)
        if not os.path.isdir(pathPokemonDir):
            continue
        fileList = os.listdir(pathPokemonDir)

        #decide which images to copy
        verifiedList = []
        for file in fileList:
            fullPath = os.path.join(pathPokemonDir, file)
            if testFile(fullPath) is True:
                verifiedList.append(file)

        if not verifiedList:
            continue

        #copy only the images that passed verification (previously every
        #file in the folder was copied and verifiedList was ignored)
        finalPokemonDir = os.path.join(DIR_MODEL_IMAGES, pokemonName)
        os.makedirs(finalPokemonDir, exist_ok=True)
        for file in verifiedList:
            currPath = os.path.join(pathPokemonDir, file)
            shutil.copy2(currPath, os.path.join(finalPokemonDir, file))

## Data For Model Verification and Testing

### Prepare Test DS

In [None]:
# Build a small test set: one verified image per Pokemon folder, copied flat
# into testDS.
prepareTestDS = True
testDS = "./testImages"
if PREPARE_TEST_DS is True:
    #sorted so the seeded shuffle below picks the same image on every machine
    #(os.listdir returns entries in arbitrary order)
    listOfDirs = sorted(os.listdir(DIR_MODEL_IMAGES))
    os.makedirs(testDS, exist_ok=True)
    seed = 32
    for pokemonDirName in listOfDirs:
        imageDir = os.path.join(DIR_MODEL_IMAGES, pokemonDirName)
        if not os.path.isdir(imageDir):
            continue
        images = sorted(os.listdir(imageDir))
        random.seed(seed)
        random.shuffle(images)

        #take the first image, in shuffled order, that passes verification;
        #an empty folder or one without any valid image is reported, not fatal
        for counter, imageName in enumerate(images):
            print(counter)
            sourcePath = os.path.join(imageDir, imageName)
            destinationPath = os.path.join(testDS, imageName)
            if (testFile(sourcePath, False)):
                shutil.copy2(sourcePath, destinationPath)
                break
        else:
            print(f"No valid test image found in {imageDir}")

### Main Dataset Stats

<!-- os.path.listdir -->