In [14]:
import pandas as pd
import os
import numpy as np
from sklearn.metrics import confusion_matrix
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import load_img, img_to_array

## Load the trained model

* The central-crop samples were created in this notebook: https://www.kaggle.com/asheniranga/256-256-sorghum-cultivar-pre-process-by-cropping
* The model was trained in this notebook: https://www.kaggle.com/code/asheniranga/sorghum-training#Training

In [2]:
# Load the Keras model saved by the training notebook.
checkpoint_path = '../input/sorghum-training/best_checkpoint.hdf5'
model = load_model(checkpoint_path)

In [3]:
# Print the architecture summary of the loaded model as a sanity check.
model.summary()

In [4]:
# Validation metadata: the generators below read the 'image' and 'cultivar'
# columns from this frame.
valid_meta_csv = '../input/128128-sorghum-cultivar/valid_meta.csv'
df_valid = pd.read_csv(valid_meta_csv)

# Competition sample submission.
# NOTE(review): df_pred is not used by any later visible cell — confirm it is still needed.
sample_submission_csv = '../input/sorghum-id-fgvc-9/sample_submission.csv'
df_pred = pd.read_csv(sample_submission_csv)

In [9]:
# Stream the full validation set from disk in batches of 32, resized to 256x256.
# NOTE(review): no rescaling is applied here, whereas the inference cell divides
# pixel values by 255.0 — confirm which preprocessing matches training.
# NOTE(review): the dataset folder is named "128128" but images are loaded at
# 256x256 — confirm the intended input resolution.
valid_datagen = ImageDataGenerator()
valid_generator = valid_datagen.flow_from_dataframe(
    dataframe=df_valid,
    directory='../input/128128-sorghum-cultivar/train',
    x_col='image',
    y_col='cultivar',
    batch_size=32,
    target_size=(256, 256),
)

In [37]:
# Evaluate the loaded model on the full validation generator; the cell output
# is whatever evaluate() returns for the model's compiled loss/metrics.
model.evaluate(valid_generator)

In [17]:
# Cultivar name -> class index (0..99) used to decode the model's argmax output.
# The names are listed in lexicographic string order, which is why e.g.
# "PI_19770" comes after "PI_197542"; a name's position in the tuple is its index.
# NOTE(review): presumably this is the mapping used during training — verify
# against the training notebook.
class_indices = {
    cultivar: index
    for index, cultivar in enumerate((
        "PI_144134", "PI_145619", "PI_145626", "PI_145633", "PI_146890",
        "PI_152591", "PI_152651", "PI_152694", "PI_152727", "PI_152728",
        "PI_152730", "PI_152733", "PI_152751", "PI_152771", "PI_152816",
        "PI_152828", "PI_152860", "PI_152862", "PI_152923", "PI_152961",
        "PI_152965", "PI_152966", "PI_152967", "PI_152971", "PI_153877",
        "PI_154750", "PI_154844", "PI_154846", "PI_154944", "PI_154987",
        "PI_154988", "PI_155516", "PI_155760", "PI_155885", "PI_156178",
        "PI_156217", "PI_156268", "PI_156326", "PI_156330", "PI_156393",
        "PI_156463", "PI_156487", "PI_156871", "PI_156890", "PI_157030",
        "PI_157035", "PI_157804", "PI_167093", "PI_170787", "PI_175919",
        "PI_176766", "PI_179749", "PI_180348", "PI_181080", "PI_181083",
        "PI_195754", "PI_196049", "PI_196583", "PI_196586", "PI_196598",
        "PI_197542", "PI_19770", "PI_213900", "PI_217691", "PI_218112",
        "PI_221548", "PI_221651", "PI_22913", "PI_229841", "PI_251672",
        "PI_253986", "PI_255239", "PI_255744", "PI_257599", "PI_257600",
        "PI_266927", "PI_267573", "PI_273465", "PI_273969", "PI_276837",
        "PI_297130", "PI_297155", "PI_297171", "PI_302252", "PI_303658",
        "PI_329256", "PI_329286", "PI_329299", "PI_329300", "PI_329301",
        "PI_329310", "PI_329319", "PI_329326", "PI_329333", "PI_329338",
        "PI_329351", "PI_35038", "PI_52606", "PI_63715", "PI_92270",
    ))
}

In [43]:
# Evaluate on a random subset of (at most) 500 validation rows.
# A fixed random_state keeps the subset — and therefore the reported metrics —
# reproducible across runs; min() avoids a ValueError when the validation set
# has fewer than 500 rows.
sample = df_valid.sample(n=min(500, len(df_valid)), random_state=42)

# Pin the label order to class_indices. With only 500 rows a cultivar can be
# missing from the sample; letting Keras infer the classes from the sampled rows
# would then produce fewer classes / different label indices than the mapping
# used to decode the model's outputs.
# NOTE(review): no rescaling is applied here, whereas the inference cell divides
# pixel values by 255.0 — confirm which preprocessing matches training.
test_generator = ImageDataGenerator().flow_from_dataframe(
    dataframe=sample,
    directory='../input/128128-sorghum-cultivar/train',
    x_col='image',
    y_col='cultivar',
    classes=list(class_indices),
    batch_size=32,
    target_size=(256, 256),
)

In [33]:
# Sanity-check the shape of one image batch (element 0 of the yielded batch).
# The builtin next() is used instead of the iterator's .next() method, which is
# not available on the iterator in newer Keras releases.
next(test_generator)[0].shape

In [36]:
# Evaluate the loaded model on the sampled subset of the validation data.
model.evaluate(test_generator)

In [None]:
# Toy example: render a confusion matrix as an annotated heatmap.
# NOTE(review): y_true / y_pred below are placeholder data, not model output.
# NOTE(review): `plt` and `sn` are not imported in any visible cell — presumably
# matplotlib.pyplot and seaborn; they must be added to the import cell before
# this cell can run.
y_true = ["honda", "chevrolet", "honda", "toyota", "toyota", "chevrolet"]
y_pred = ["honda", "chevrolet", "honda", "toyota", "toyota", "honda"]

# Use the union of true and predicted labels, and pass it explicitly, so the
# matrix rows/columns always line up with the DataFrame axes — even when a
# class shows up only in the predictions.
labels = np.unique(np.concatenate([y_true, y_pred]))
data = confusion_matrix(y_true, y_pred, labels=labels)

df_cm = pd.DataFrame(data, columns=labels, index=labels)
df_cm.index.name = 'Actual'
df_cm.columns.name = 'Predicted'

plt.figure(figsize=(10, 7))
sn.set(font_scale=1.4)  # axis label size
sn.heatmap(df_cm, cmap="Blues", annot=True, annot_kws={"size": 16})  # annotation font size

## Inference

In [18]:
# Predict a cultivar for every image in the test folder and collect
# [filename, cultivar] rows for the submission file.
TEST_DIR = '../input/sorghum-cultivar-identification-512512/test/'

# Invert the label map once instead of rebuilding two lists for every image.
index_to_label = {index: label for label, index in class_indices.items()}

# List the directory a single time: sorted because os.listdir order is
# arbitrary, and hoisted so the progress total is not recomputed per image.
test_files = sorted(os.listdir(TEST_DIR))
n_files = len(test_files)

test_preds = []

for i, file in enumerate(test_files, start=1):
    # NOTE(review): pixels are scaled to [0, 1] here, but the validation
    # generators use ImageDataGenerator() with no rescale — confirm which
    # preprocessing matches training.
    img = img_to_array(load_img(os.path.join(TEST_DIR, file), target_size=(256, 256))) / 255.0
    img_arr = np.expand_dims(img, axis=0)  # add the batch dimension

    # verbose=0 silences Keras' per-call progress output, which would otherwise
    # overwrite the single-line counter printed below.
    preds = int(np.argmax(model.predict(img_arr, verbose=0)[0]))

    label = index_to_label[preds]

    test_preds.append([file, label])

    print(f'{i}/{n_files}', end='\r')

In [19]:
# Turn the collected [filename, cultivar] rows into a frame and write it out
# as the submission CSV (without the index column).
submission_columns = ['filename', 'cultivar']
test_preds = pd.DataFrame(data=test_preds, columns=submission_columns)
test_preds.to_csv(path_or_buf='submission_11.csv', index=False)

In [22]:
# Display the predictions ordered by filename; the sorted copy is not assigned,
# so test_preds itself is left unchanged.
test_preds.sort_values(by='filename')