In [1]:
import os
import numpy as np 
import pandas as pd

import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input

In [2]:
# Only the column layout of this table (the `file` column followed by the 0/1
# type indicator columns) is used further down, to label prediction columns.
type_encodings_path = "data/type_encodings.csv"
df_train = pd.read_csv(type_encodings_path)
print(df_train.shape)
df_train.head()

(8430, 19)


Unnamed: 0,file,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,0001.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
1,0002.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
2,0003.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
3,0004.png,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0
4,0005.png,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0


In [3]:
# Side length (pixels) each image is resized to before inference, and the
# number of images per prediction batch.
image_size, batch_size = 450, 32

In [4]:
def build_decoder(with_labels=True, target_size=(image_size, image_size)):
    """Return a tf.data map function that loads a PNG path as a model-ready image.

    Args:
        with_labels: if True, the returned function maps ``(path, label)`` to
            ``(image, label)``; otherwise it maps ``path`` to ``image``.
        target_size: ``(height, width)`` handed to ``tf.image.resize``.

    Returns:
        A callable suitable for ``tf.data.Dataset.map``.
    """
    def load_image(path):
        raw_bytes = tf.io.read_file(path)
        # channels=3 forces a 3-channel RGB decode.
        pixels = tf.image.decode_png(raw_bytes, channels=3)
        scaled = tf.cast(pixels, tf.float32) / 255.0
        resized = tf.image.resize(scaled, target_size)
        # NOTE(review): for stock Keras EfficientNet models, preprocess_input is
        # a pass-through and the network expects inputs in [0, 255]. The /255
        # above is kept because model.h5 was presumably trained with this same
        # decoder; confirm before changing either step.
        return preprocess_input(resized)

    if not with_labels:
        return load_image

    def load_image_with_label(path, label):
        return load_image(path), label

    return load_image_with_label

In [5]:
def build_augmenter(with_labels=True):
    """Return a tf.data map function that randomly augments an image.

    The steps run in this fixed order:
    random horizontal flip, saturation factor in [0.8, 1.2],
    brightness delta in [-0.2, 0.2], contrast factor in [0.8, 1.2],
    and hue delta in [-0.2, 0.2].

    Args:
        with_labels: if True, the returned function maps ``(image, label)`` to
            ``(augmented_image, label)``; otherwise it maps ``image`` to
            ``augmented_image``.

    Returns:
        A callable suitable for ``tf.data.Dataset.map``.
    """
    # NOTE(review): nothing clips the result, so brightness/contrast jitter can
    # push values outside the decoder's output range. Confirm this is intended.
    steps = (
        tf.image.random_flip_left_right,
        lambda img: tf.image.random_saturation(img, 0.8, 1.2),
        lambda img: tf.image.random_brightness(img, 0.2),
        lambda img: tf.image.random_contrast(img, 0.8, 1.2),
        lambda img: tf.image.random_hue(img, 0.2),
    )

    def augment_image(image):
        for step in steps:
            image = step(image)
        return image

    if not with_labels:
        return augment_image

    def augment_image_with_label(image, label):
        return augment_image(image), label

    return augment_image_with_label

In [6]:
def build_dataset(paths, labels=None, bsize=32, decode_function=None, augment_function=None, augment=True):
    """Build a batched, prefetched tf.data pipeline over image files.

    Args:
        paths: image file paths, one element per path.
        labels: optional per-path labels. When given, elements are
            ``(path, label)`` pairs and the default decoder/augmenter are
            built in their label-carrying form.
        bsize: batch size.
        decode_function: map function turning a path (or path/label pair)
            into an image (or image/label pair); defaults to ``build_decoder``.
        augment_function: map function applied after decoding; defaults to
            ``build_augmenter``.
        augment: whether to apply ``augment_function`` at all.

    Returns:
        A ``tf.data.Dataset`` of batches. No shuffle is applied, so batches
        follow the order of ``paths``.
    """
    has_labels = labels is not None

    if decode_function is None:
        decode_function = build_decoder(has_labels)
    if augment_function is None:
        augment_function = build_augmenter(has_labels)

    autotune = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if has_labels else paths

    pipeline = tf.data.Dataset.from_tensor_slices(source).map(
        decode_function, num_parallel_calls=autotune
    )
    if augment:
        pipeline = pipeline.map(augment_function, num_parallel_calls=autotune)

    return pipeline.batch(bsize).prefetch(autotune)

In [7]:
# The model is used only for inference below, so skip restoring its training
# configuration (optimizer / loss / metrics). This also means any custom
# loss or metric objects used at compile time are not needed to load it.
model = tf.keras.models.load_model('model.h5', compile=False)

In [8]:
load_dir = "data/images/test/"

# Full path for every directory entry, in os.listdir order. Predictions come
# back in this same order, because the dataset below applies no shuffle.
# NOTE(review): os.listdir returns every entry, including any non-PNG files or
# subfolders such as .ipynb_checkpoints, and does not guarantee sorted order.
test_paths = [os.path.join(load_dir, entry) for entry in os.listdir(load_dir)]

test_decoder = build_decoder(target_size=(image_size, image_size), with_labels=False)

data_test = build_dataset(
    test_paths,
    decode_function=test_decoder,
    bsize=batch_size,
    augment=False,
)

In [9]:
y_preds = model.predict(data_test, verbose=1)

# Later cells pair prediction rows with test images purely by position, so a
# count mismatch would silently mislabel results.
assert len(y_preds) == len(test_paths), (len(y_preds), len(test_paths))



In [10]:
# One row per test image: its file name followed by one score column per type.
# File names are taken from `test_paths` (the exact list fed to the dataset)
# rather than a second os.listdir call, so rows cannot drift out of alignment
# with the predictions. Building from the prediction array directly also keeps
# the score columns numeric instead of object dtype.
file_column, *type_columns = df_train.columns
df_test = pd.DataFrame(y_preds, columns=type_columns)
df_test.insert(0, file_column, [os.path.basename(path) for path in test_paths])

print(df_test.shape)
df_test.head(9)

(151, 19)


Unnamed: 0,file,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,1-Bulbasaur.png,2.4e-05,0.000104,3.1e-05,6e-06,3.2e-05,7.7e-05,3.7e-05,3.9e-05,0.000164,0.998991,0.000436,0.000508,7.3e-05,0.999847,0.000225,0.000273,9e-06,0.000111
1,10-Caterpie.png,0.999966,5e-06,2.8e-05,7.1e-05,0.000312,0.000981,0.000281,0.001787,5e-06,0.005381,1.1e-05,1e-06,0.000116,0.000434,3.5e-05,0.000314,8e-05,0.000119
2,100-Voltorb.png,0.000103,2.4e-05,0.000366,0.999955,0.000153,2e-06,0.000195,0.000136,4.3e-05,8.9e-05,3.7e-05,3.4e-05,2.9e-05,3.8e-05,0.000215,1.6e-05,0.000239,0.00014
3,101-Electrode.png,1.1e-05,3e-06,6.1e-05,0.999994,5.7e-05,1e-06,7e-06,0.000479,4e-06,0.0,1.1e-05,8e-06,2e-06,1e-06,2.2e-05,3.7e-05,0.0001,4.6e-05
4,102-Exeggcute.png,3e-06,3e-06,0.0,4.1e-05,2.3e-05,1.3e-05,1.2e-05,2e-06,2e-06,0.999115,0.0,0.000356,1.1e-05,0.0,0.999999,6.8e-05,1e-05,4e-06
5,103-Exeggutor.png,2.5e-05,1.6e-05,3e-06,9.9e-05,4.7e-05,1.2e-05,1.9e-05,5.7e-05,9e-06,0.982763,2.2e-05,0.000252,4.8e-05,0.000149,0.99993,0.000127,1.7e-05,2.8e-05
6,104-Cubone.png,0.0,6e-06,1e-05,1e-06,1e-06,3e-06,0.000131,2e-06,3e-06,3e-06,0.999713,9.7e-05,1e-06,0.000947,2.9e-05,6.1e-05,3.7e-05,4e-06
7,105-Marowak.png,8e-06,0.000109,0.000681,4e-06,1e-05,0.00031,7.1e-05,0.000641,2.7e-05,3e-06,0.997944,0.000351,5.3e-05,2.3e-05,0.000158,0.059247,0.000579,0.000119
8,106-Hitmonlee.png,1e-06,3.8e-05,2e-05,1e-06,0.0,0.999991,9.5e-05,3.1e-05,1e-06,0.0,1.2e-05,1e-06,1.2e-05,1e-06,3.1e-05,1.1e-05,3e-06,1.1e-05


In [11]:
# Score above which a type counts as predicted. This is multi-label: every type
# above the threshold is kept, not just the highest-scoring one.
type_threshold = 0.66
type_prefix = "type_"

# Pick score columns by prefix, not "everything except file", so the cell is
# idempotent: re-running it on a frame that already has `predicted_types`
# yields the same layout.
type_columns = [col for col in df_test.columns if col.startswith(type_prefix)]
type_scores = df_test[type_columns].infer_objects()

# Comma-joined type names (prefix stripped) of every score above the threshold;
# an empty string when none qualify.
above_threshold = type_scores.astype(float).to_numpy() > type_threshold
predicted_types = [
    ",".join(col[len(type_prefix):] for col, hit in zip(type_columns, row) if hit)
    for row in above_threshold
]

# Layout: file, predicted_types, then the individual type scores.
df_test = pd.concat(
    [
        df_test[["file"]],
        pd.Series(predicted_types, index=df_test.index, name="predicted_types"),
        type_scores,
    ],
    axis=1,
)

print(df_test.shape)
df_test.head(9)

(151, 20)


Unnamed: 0,file,predicted_types,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,1-Bulbasaur.png,"Grass,Poison",2.439883e-05,0.000104,3.052376e-05,5.56704e-06,3.220753e-05,7.726545e-05,3.7e-05,3.9e-05,0.0001639128,0.9989911,0.0004363358,0.000508,7.3e-05,0.9998471,0.000225,0.000273,9e-06,0.000111
1,10-Caterpie.png,Bug,0.9999656,5e-06,2.752723e-05,7.072623e-05,0.0003121793,0.0009814799,0.000281,0.001787,5.232124e-06,0.005381286,1.11185e-05,1e-06,0.000116,0.0004335344,3.5e-05,0.000314,8e-05,0.000119
2,100-Voltorb.png,Electric,0.0001028526,2.4e-05,0.0003656447,0.9999551,0.0001533926,1.997462e-06,0.000195,0.000136,4.318224e-05,8.920381e-05,3.70747e-05,3.4e-05,2.9e-05,3.780896e-05,0.000215,1.6e-05,0.000239,0.00014
3,101-Electrode.png,Electric,1.14057e-05,3e-06,6.064797e-05,0.9999937,5.746422e-05,5.570287e-07,7e-06,0.000479,4.348149e-06,4.955473e-07,1.056937e-05,8e-06,2e-06,1.068524e-06,2.2e-05,3.7e-05,0.0001,4.6e-05
4,102-Exeggcute.png,"Grass,Psychic",2.586301e-06,3e-06,1.012435e-07,4.058241e-05,2.305597e-05,1.270052e-05,1.2e-05,2e-06,2.47104e-06,0.999115,3.930783e-07,0.000356,1.1e-05,2.898599e-07,0.999999,6.8e-05,1e-05,4e-06
5,103-Exeggutor.png,"Grass,Psychic",2.529558e-05,1.6e-05,3.119356e-06,9.925936e-05,4.731451e-05,1.179043e-05,1.9e-05,5.7e-05,9.147397e-06,0.9827628,2.159144e-05,0.000252,4.8e-05,0.0001490116,0.99993,0.000127,1.7e-05,2.8e-05
6,104-Cubone.png,Ground,1.048079e-07,6e-06,9.726838e-06,9.268185e-07,8.066089e-07,3.12324e-06,0.000131,2e-06,2.860408e-06,2.569639e-06,0.9997128,9.7e-05,1e-06,0.000947088,2.9e-05,6.1e-05,3.7e-05,4e-06
7,105-Marowak.png,Ground,7.707275e-06,0.000109,0.0006809533,4.103809e-06,1.018737e-05,0.0003097951,7.1e-05,0.000641,2.72677e-05,3.237711e-06,0.9979439,0.000351,5.3e-05,2.251297e-05,0.000158,0.059247,0.000579,0.000119
8,106-Hitmonlee.png,Fighting,5.148481e-07,3.8e-05,2.032861e-05,5.018714e-07,4.544455e-07,0.9999912,9.5e-05,3.1e-05,6.592793e-07,4.032373e-07,1.16053e-05,1e-06,1.2e-05,1.011798e-06,3.1e-05,1.1e-05,3e-06,1.1e-05


In [12]:
# Persist the per-image predictions; the positional index is not written.
predictions_path = "data/predictions.csv"
df_test.to_csv(predictions_path, index=False)