In [2]:
import os
import numpy as np 
import pandas as pd

import tensorflow as tf
from tensorflow.keras.applications.efficientnet import preprocess_input

In [3]:
# Training label table: a `file` column followed by one `type_<Type>` column per type,
# holding 0/1 indicators in the rows displayed below.
df_train = pd.read_csv(filepath_or_buffer="data/type_encodings.csv")
print(df_train.shape)
df_train.head()

(8430, 19)


Unnamed: 0,file,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,0001.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
1,0002.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
2,0003.png,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0
3,0004.png,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0
4,0005.png,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0


In [4]:
# Square side length (pixels) every image is resized to, and images per tf.data batch.
image_size, batch_size = 450, 32

In [5]:
def build_decoder(with_labels=True, target_size=(image_size, image_size)):
    """Return a tf.data map function that turns a PNG file path into an image tensor.

    Args:
        with_labels: when True the returned function maps ``(path, label)`` to
            ``(image, label)``; when False it maps ``path`` to ``image``.
        target_size: ``(height, width)`` handed to ``tf.image.resize``.

    Returns:
        The decoding function described above.
    """
    def decode(path):
        raw_bytes = tf.io.read_file(path)
        # channels=3 forces an RGB result regardless of the PNG's own channel count.
        pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
        # NOTE(review): pixels are scaled to [0, 1] *before* preprocess_input. Keras'
        # EfficientNet preprocess_input is documented as a pass-through and the stock
        # EfficientNet models expect [0, 255] input. Left as-is because model.h5 was
        # presumably trained with this exact pipeline — confirm before changing.
        resized = tf.image.resize(pixels / 255.0, target_size)
        return preprocess_input(resized)

    if not with_labels:
        return decode

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels

In [6]:
def build_augmenter(with_labels=True):
    """Return a tf.data map function applying random flip and colour jitter.

    Args:
        with_labels: when True the returned function maps ``(image, label)`` to
            ``(augmented_image, label)``; when False it maps ``image`` to
            ``augmented_image``.

    Returns:
        The augmentation function described above.
    """
    def augment(image):
        # Applied strictly in this order; each step consumes the previous step's output.
        # NOTE(review): no clipping is done after brightness/contrast, so values may
        # leave the decoder's [0, 1] range before random_hue — confirm this is intended.
        steps = (
            tf.image.random_flip_left_right,
            lambda img: tf.image.random_saturation(img, 0.8, 1.2),
            lambda img: tf.image.random_brightness(img, 0.2),
            lambda img: tf.image.random_contrast(img, 0.8, 1.2),
            lambda img: tf.image.random_hue(img, 0.2),
        )
        for step in steps:
            image = step(image)
        return image

    if not with_labels:
        return augment

    def augment_with_labels(image, label):
        return augment(image), label

    return augment_with_labels

In [7]:
def build_dataset(paths, labels=None, bsize=32, decode_function=None, augment_function=None, augment=True):
    """Build a batched, prefetched tf.data pipeline over image files.

    Args:
        paths: image file paths; element order is preserved through the pipeline.
        labels: optional labels aligned with ``paths``; when given, elements are
            ``(image, label)`` pairs.
        bsize: batch size.
        decode_function: path -> image map function; defaults to ``build_decoder``.
        augment_function: image -> image map function; defaults to ``build_augmenter``.
        augment: whether to apply ``augment_function`` after decoding.

    Returns:
        A ``tf.data.Dataset`` of batches.
    """
    has_labels = labels is not None
    decode_function = build_decoder(has_labels) if decode_function is None else decode_function
    augment_function = build_augmenter(has_labels) if augment_function is None else augment_function

    autotune = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if has_labels else paths

    pipeline = tf.data.Dataset.from_tensor_slices(source)
    pipeline = pipeline.map(decode_function, num_parallel_calls=autotune)
    if augment:
        pipeline = pipeline.map(augment_function, num_parallel_calls=autotune)

    return pipeline.batch(bsize).prefetch(autotune)

In [8]:
# Restore the saved Keras model (HDF5 file) used for inference below.
model = tf.keras.models.load_model(filepath='model.h5')

In [9]:
load_dir = "data/images/test/"

# os.listdir makes no ordering guarantee. The dataset, and therefore the predictions,
# follows exactly this list's order.
# NOTE(review): every entry is decoded with decode_png, so this presumably assumes the
# directory holds only .png files — TODO confirm.
test_paths = [os.path.join(load_dir, entry) for entry in os.listdir(load_dir)]

# Inference pipeline: decode + resize only, no random augmentation.
test_decoder = build_decoder(target_size=(image_size, image_size), with_labels=False)
data_test = build_dataset(
    test_paths,
    bsize=batch_size,
    decode_function=test_decoder,
    augment=False,
)

In [10]:
# One prediction row per image, in the same order as data_test (i.e. test_paths).
y_preds = model.predict(x=data_test, verbose=1)



In [35]:
# One row per test image: the file name, then one probability column per type
# (column names taken positionally from df_train, whose first column is the file name).
# File names come from test_paths, the exact list fed to the model, rather than from a
# second os.listdir call, so a row can never be paired with the wrong file.
# Building from the ndarray also keeps the probability columns numeric instead of the
# object dtype produced by transposing a mixed str/float frame.
file_column, *type_columns = df_train.columns
assert y_preds.shape == (len(test_paths), len(type_columns)), y_preds.shape

df_test = pd.DataFrame(y_preds, columns=type_columns)
df_test.insert(0, file_column, [os.path.basename(path) for path in test_paths])

print(df_test.shape)
df_test.head(9)

(151, 19)


Unnamed: 0,file,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,0001.png,0.000264,9.6e-05,5.9e-05,4.5e-05,4.5e-05,6.1e-05,2.7e-05,3.8e-05,0.000549,0.999961,3.4e-05,0.000135,3.2e-05,0.999417,5.1e-05,0.000131,1.8e-05,3e-05
1,0002.png,0.000267,0.000376,3.4e-05,0.000149,2.5e-05,6.5e-05,3.1e-05,2.9e-05,0.000473,0.99998,1.5e-05,0.00027,1.5e-05,0.988147,0.000161,6.6e-05,1.4e-05,3e-05
2,0003.png,0.000196,4e-05,3.5e-05,1.2e-05,6.3e-05,2.1e-05,5e-06,2.9e-05,7.6e-05,0.999998,4e-06,0.000143,3e-06,0.997504,6e-05,1e-05,1e-05,2e-06
3,0004.png,5e-06,0.000204,1.4e-05,2.5e-05,3e-06,0.001325,0.999938,2.9e-05,2e-06,1.4e-05,2.2e-05,0.000102,2e-06,6e-06,0.003179,2e-06,1e-06,0.000364
4,0005.png,2.7e-05,0.002298,0.000224,0.000469,2.5e-05,0.000643,0.999985,0.001148,9e-06,1.1e-05,0.001242,7.2e-05,4.2e-05,3e-06,0.000207,4.2e-05,1e-05,2.1e-05
5,0006.png,1e-06,0.000517,0.001545,3e-06,8e-06,0.000289,0.999997,0.989674,1e-06,2e-06,0.000108,7.1e-05,6e-06,3e-06,7.4e-05,0.000137,2e-06,3e-06
6,0007.png,7e-05,6e-06,1e-05,0.0,2e-06,0.000588,6e-06,3.1e-05,1e-06,2e-05,7.6e-05,7.1e-05,1.2e-05,0.000365,2.7e-05,0.000161,1e-06,0.999984
7,0008.png,0.003601,1.6e-05,5.1e-05,7.5e-05,8e-05,7.4e-05,4e-06,0.000298,6e-06,2.2e-05,0.000128,2.1e-05,5.3e-05,5.3e-05,2.5e-05,0.000772,4.5e-05,0.999944
8,0009.png,0.000137,2.6e-05,7.2e-05,3e-06,2.4e-05,0.000251,5e-06,8.7e-05,6e-06,4.5e-05,0.00504,0.001137,0.000193,1.6e-05,0.000369,0.47689,7.4e-05,0.998082


In [36]:
# A type is predicted when its probability is strictly above this cut-off.
# NOTE(review): 0.66 is carried over unchanged; how it was chosen is not recorded here.
PREDICTION_THRESHOLD = 0.66

# Probability columns are every column except the file name. A previous
# `predicted_types` is also excluded so that re-running this cell gives the same result.
type_columns = [col for col in df_test.columns if col not in ("file", "predicted_types")]
type_names = [col[len("type_"):] for col in type_columns]  # "type_Grass" -> "Grass"

probabilities = df_test[type_columns].astype(float)
is_predicted = (probabilities > PREDICTION_THRESHOLD).to_numpy()

# Comma-joined type names per row, in column order; "" when no type clears the cut-off.
predicted_types = [
    ",".join(name for name, hit in zip(type_names, row) if hit)
    for row in is_predicted
]

# Column order: file, predicted_types, then the per-type probabilities.
df_test = (
    pd.DataFrame({"file": df_test["file"], "predicted_types": predicted_types})
    .join(probabilities)
    .reset_index(drop=True)
)

print(df_test.shape)
df_test.head(9)

(151, 20)


Unnamed: 0,file,predicted_types,type_Bug,type_Dark,type_Dragon,type_Electric,type_Fairy,type_Fighting,type_Fire,type_Flying,type_Ghost,type_Grass,type_Ground,type_Ice,type_Normal,type_Poison,type_Psychic,type_Rock,type_Steel,type_Water
0,0001.png,"Grass,Poison",0.000264,9.6e-05,5.9e-05,4.536115e-05,4.5e-05,6.1e-05,2.7e-05,3.8e-05,0.000549376,0.999961,3.4e-05,0.000135,3.2e-05,0.999417,5.1e-05,0.000131,1.79313e-05,3e-05
1,0002.png,"Grass,Poison",0.000267,0.000376,3.4e-05,0.0001489222,2.5e-05,6.5e-05,3.1e-05,2.9e-05,0.0004727244,0.99998,1.5e-05,0.00027,1.5e-05,0.988147,0.000161,6.6e-05,1.44174e-05,3e-05
2,0003.png,"Grass,Poison",0.000196,4e-05,3.5e-05,1.194831e-05,6.3e-05,2.1e-05,5e-06,2.9e-05,7.607859e-05,0.999998,4e-06,0.000143,3e-06,0.997504,6e-05,1e-05,9.729082e-06,2e-06
3,0004.png,Fire,5e-06,0.000204,1.4e-05,2.524184e-05,3e-06,0.001325,0.999938,2.9e-05,2.018253e-06,1.4e-05,2.2e-05,0.000102,2e-06,6e-06,0.003179,2e-06,7.130916e-07,0.000364
4,0005.png,Fire,2.7e-05,0.002298,0.000224,0.0004691482,2.5e-05,0.000643,0.999985,0.001148,9.429445e-06,1.1e-05,0.001242,7.2e-05,4.2e-05,3e-06,0.000207,4.2e-05,1.026122e-05,2.1e-05
5,0006.png,"Fire,Flying",1e-06,0.000517,0.001545,3.31172e-06,8e-06,0.000289,0.999997,0.989674,1.208868e-06,2e-06,0.000108,7.1e-05,6e-06,3e-06,7.4e-05,0.000137,1.621698e-06,3e-06
6,0007.png,Water,7e-05,6e-06,1e-05,2.593269e-07,2e-06,0.000588,6e-06,3.1e-05,5.730649e-07,2e-05,7.6e-05,7.1e-05,1.2e-05,0.000365,2.7e-05,0.000161,5.019169e-07,0.999984
7,0008.png,Water,0.003601,1.6e-05,5.1e-05,7.499216e-05,8e-05,7.4e-05,4e-06,0.000298,5.753751e-06,2.2e-05,0.000128,2.1e-05,5.3e-05,5.3e-05,2.5e-05,0.000772,4.495028e-05,0.999944
8,0009.png,Water,0.000137,2.6e-05,7.2e-05,3.092974e-06,2.4e-05,0.000251,5e-06,8.7e-05,6.003422e-06,4.5e-05,0.00504,0.001137,0.000193,1.6e-05,0.000369,0.47689,7.369825e-05,0.998082


In [32]:
# NOTE(review): in the saved run this cell (In [32]) executed before the cells that built
# df_test (In [35], In [36]); use Restart & Run All so the CSV matches the frame shown above.
df_test.to_csv(path_or_buf="data/predictions.csv", index=False)