In [1]:
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np

import warnings
# Drop every warning routed through the `warnings` module for the rest of the
# kernel session, to keep cell output readable.
# NOTE(review): this also hides deprecation/numerical warnings from TF and NumPy;
# consider narrowing it with `category=` or `module=` once the noisy source is known.
warnings.filterwarnings(action='ignore')

# TensorFlow's Python-side logger: let only ERROR-level records (and above) through.
tf.get_logger().setLevel(level="ERROR")




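### Mushroom species recognized by the model

The order of this list must match the model's output vector: `get_predictions` maps output index `i` to `mushroom_species[i]`.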

In [2]:
# Class labels of the CNN. get_predictions() maps model output index i to
# mushroom_species[i], so this order must match the model's output layer exactly.
# Do not re-sort, add or remove entries without re-exporting the model as well.
mushroom_species = [
    'Agaricus augustus', 'Agaricus xanthodermus',
    'Amanita amerirubescens', 'Amanita augusta', 'Amanita brunnescens', 'Amanita calyptroderma',
    'Amanita citrina', 'Amanita flavoconia', 'Amanita muscaria', 'Amanita pantherina',
    'Amanita persicina', 'Amanita phalloides', 'Amanita rubescens', 'Amanita velosa',
    'Apioperdon pyriforme',
    'Armillaria borealis', 'Armillaria mellea', 'Armillaria tabescens',
    'Artomyces pyxidatus', 'Bjerkandera adusta', 'Bolbitius titubans',
    'Boletus edulis', 'Boletus pallidus', 'Boletus reticulatus', 'Boletus rex-veris',
    'Calocera viscosa', 'Calycina citrina',
    'Cantharellus californicus', 'Cantharellus cibarius', 'Cantharellus cinnabarinus',
    'Cerioporus squamosus', 'Cetraria islandica', 'Chlorociboria aeruginascens',
    'Chlorophyllum brunneum', 'Chlorophyllum molybdites', 'Chondrostereum purpureum',
    'Cladonia fimbriata', 'Cladonia rangiferina', 'Cladonia stellaris',
    'Clitocybe nebularis', 'Clitocybe nuda', 'Coltricia perennis',
    'Coprinellus disseminatus', 'Coprinellus micaceus',
    'Coprinopsis atramentaria', 'Coprinopsis lagopus', 'Coprinus comatus',
    'Crucibulum laeve', 'Cryptoporus volvatus',
    'Daedaleopsis confragosa', 'Daedaleopsis tricolor', 'Entoloma abortivum',
    'Evernia mesomorpha', 'Evernia prunastri', 'Flammulina velutipes', 'Fomes fomentarius',
    'Fomitopsis betulina', 'Fomitopsis mounceae', 'Fomitopsis pinicola', 'Galerina marginata',
    'Ganoderma applanatum', 'Ganoderma curtisii', 'Ganoderma oregonense', 'Ganoderma tsugae',
    'Gliophorus psittacinus', 'Gloeophyllum sepiarium', 'Graphis scripta', 'Grifola frondosa',
    'Gymnopilus luteofolius', 'Gyromitra esculenta', 'Gyromitra gigas', 'Gyromitra infula',
    'Hericium coralloides', 'Hericium erinaceus', 'Hygrophoropsis aurantiaca',
    'Hypholoma fasciculare', 'Hypholoma lateritium', 'Hypogymnia physodes',
    'Hypomyces lactifluorum', 'Imleria badia', 'Inonotus obliquus', 'Ischnoderma resinosum',
    'Kuehneromyces mutabilis', 'Laccaria ochropurpurea',
    'Lactarius deliciosus', 'Lactarius torminosus', 'Lactarius turpis',
    'Laetiporus sulphureus',
    'Leccinum albostipitatum', 'Leccinum aurantiacum', 'Leccinum scabrum', 'Leccinum versipelle',
    'Lepista nuda', 'Leratiomyces ceres', 'Leucoagaricus americanus', 'Leucoagaricus leucothites',
    'Lobaria pulmonaria', 'Lycogala epidendrum', 'Lycoperdon perlatum', 'Lycoperdon pyriforme',
    'Macrolepiota procera', 'Merulius tremellosus', 'Mutinus ravenelii',
    'Mycena haematopus', 'Mycena leaiana', 'Nectria cinnabarina',
    'Omphalotus illudens', 'Omphalotus olivascens', 'Panaeolus papilionaceus', 'Panellus stipticus',
    'Parmelia sulcata', 'Paxillus involutus', 'Peltigera aphthosa', 'Peltigera praetextata',
    'Phaeolus schweinitzii', 'Phaeophyscia orbicularis', 'Phallus impudicus',
    'Phellinus igniarius', 'Phellinus tremulae', 'Phlebia radiata', 'Phlebia tremellosa',
    'Pholiota aurivella', 'Pholiota squarrosa', 'Phyllotopsis nidulans', 'Physcia adscendens',
    'Platismatia glauca', 'Pleurotus ostreatus', 'Pleurotus pulmonarius',
    'Psathyrella candolleana', 'Pseudevernia furfuracea', 'Pseudohydnum gelatinosum',
    'Psilocybe azurescens', 'Psilocybe caerulescens', 'Psilocybe cubensis',
    'Psilocybe cyanescens', 'Psilocybe ovoideocystidiata', 'Psilocybe pelliculosa',
    'Retiboletus ornatipes', 'Rhytisma acerinum', 'Sarcomyxa serotina',
    'Sarcoscypha austriaca', 'Sarcosoma globosum', 'Schizophyllum commune',
    'Stereum hirsutum', 'Stereum ostrea', 'Stropharia aeruginosa', 'Stropharia ambigua',
    'Suillus americanus', 'Suillus granulatus', 'Suillus grevillei', 'Suillus luteus',
    'Suillus spraguei', 'Tapinella atrotomentosa',
    'Trametes betulina', 'Trametes gibbosa', 'Trametes hirsuta', 'Trametes ochracea',
    'Trametes versicolor', 'Tremella mesenterica', 'Trichaptum biforme',
    'Tricholoma murrillianum', 'Tricholomopsis rutilans',
    'Tylopilus felleus', 'Tylopilus rubrobrunneus', 'Urnula craterium', 'Verpa bohemica',
    'Volvopluteus gloiocephalus', 'Vulpicida pinastri', 'Xanthoria parietina',
]

# Duplicate labels would make two output indices indistinguishable in the prediction dicts.
assert len(set(mushroom_species)) == len(mushroom_species), "duplicate class labels"

### Loading the CNN Model

In [3]:
# Relative to the notebook's working directory; keep the file name spelled exactly
# as it is on disk ("clasifier").
MODEL_PATH = "../App/Model/mushroomCNNclasifier.h5"

model_loaded = load_model(MODEL_PATH)
print("Model loaded successfully!")

Model loaded successfully!


### Getting top-k predictions from the CNN model

In [4]:
def get_predictions(img_path: str, top_k: int = 3, model=None, class_names=None) -> dict:
    """Classify one image and return its ``top_k`` highest-scoring species.

    Args:
        img_path: Path to an image file; it is loaded as RGB and resized to 224x224.
        top_k: How many species to return. Values above the number of classes are
            clamped to it; values <= 0 return an empty dict.
        model: Object exposing ``predict``; defaults to the notebook's ``model_loaded``.
        class_names: Labels indexed like the model's output vector; defaults to
            ``mushroom_species``.

    Returns:
        Dict of species -> score (Python ``float``), inserted in descending score order.

    Raises:
        ValueError: If the model returns a different number of scores than there are
            class names (label list and model out of sync).
    """
    if model is None:
        model = model_loaded
    if class_names is None:
        class_names = mushroom_species

    # Load + resize, then add a leading batch axis of size 1 for predict().
    # NOTE(review): no pixel rescaling is applied here; presumably the model contains
    # its own preprocessing layer. Confirm against the training pipeline.
    img = image.load_img(img_path, target_size=(224, 224))
    batch = np.expand_dims(image.img_to_array(img), axis=0)

    # verbose=0 suppresses Keras' per-call progress bar in the cell output.
    scores = np.asarray(model.predict(batch, verbose=0)).ravel()
    if scores.size != len(class_names):
        raise ValueError(
            f"model returned {scores.size} scores but {len(class_names)} class names are defined"
        )

    # Clamp k: argpartition rejects a kth outside the array, and k == 0 would make the
    # `[-k:]` slice below return every class instead of none.
    k = min(top_k, scores.size)
    if k <= 0:
        return {}

    # argpartition isolates the k largest scores without a full sort; then order only those k.
    top_idx = np.argpartition(scores, -k)[-k:]
    top_idx = top_idx[np.argsort(scores[top_idx])[::-1]]

    # Plain floats (not NumPy scalars) keep the result JSON-serializable.
    return {class_names[i]: float(scores[i]) for i in top_idx}

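### Combining predictions from multiple photos

When several photos of the same specimen are taken, the per-photo top-k dictionaries are averaged into one ensemble score per species.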

In [None]:
def calculate_predictions(predictions):
    """Average per-photo prediction dicts into one ensemble score per species.

    Each element of ``predictions`` maps species name -> confidence, e.g. the output
    of ``get_predictions`` for one photo of the same specimen. A species missing from
    a photo's dict counts as 0.0 for that photo, so its ensemble score is
    ``sum(its confidences) / number_of_photos``. With top-k inputs this is a lower
    bound for species that fell just outside the top-k in some photos.

    Args:
        predictions: Iterable of ``{species: confidence}`` dicts, one per photo.

    Returns:
        Dict of species -> mean confidence (Python ``float``), ordered from most to
        least confident. Ties are broken alphabetically so the order is reproducible
        across runs. Empty input returns ``{}``.
    """
    # Materialize once: the input is scanned once per species below, which would
    # silently exhaust a generator after the first pass.
    predictions = list(predictions)
    if not predictions:
        return {}

    all_species = {name for pred in predictions for name in pred}

    ensemble = {
        name: float(np.mean([pred.get(name, 0.0) for pred in predictions]))
        for name in all_species
    }

    # A set's iteration order depends on string hash randomization, so sort explicitly
    # to get a deterministic, most-confident-first result.
    return dict(sorted(ensemble.items(), key=lambda item: (-item[1], item[0])))

In [None]:
# Hand-made top-3 outputs standing in for three photos of the same specimen.
# NOTE: "Boletus aereus" and "Agaricus campestris" are not in `mushroom_species`,
# so the real model could never return them; they only exercise the averaging logic.
predictions1 = {
    "Boletus edulis": 0.87,
    "Boletus aereus": 0.10,
    "Boletus reticulatus": 0.01
}

predictions2 = {
    "Boletus edulis": 0.95,
    "Boletus aereus": 0.04,
    "Agaricus campestris": 0.01
}

predictions3 = {
    "Boletus edulis": 0.93,
    "Agaricus campestris": 0.07,
    "Boletus reticulatus": 0.01
}
predictions = [predictions1, predictions2, predictions3]
pred = calculate_predictions(predictions)

# Every photo ranks B. edulis first, so the ensemble must as well.
assert max(pred, key=pred.get) == "Boletus edulis"

# Show the ensemble as the cell's output.
pred

In [None]:
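# NOTE: Archived, not executed. An earlier variant of calculate_predictions that also
# sorted the ensemble and experimented with boosting the top species' confidence by 10%
# when every photo agreed on it. If that boost is ever re-enabled, clip the boosted
# value so the confidence cannot exceed 1.0.
# def calculate_predictions(predictions):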
#
#     num_pred = len(predictions)
#
#     all_species = set()
#     for pred in predictions:
#         all_species.update(pred.keys())
#
#     ensemble = {}
#     for species in all_species:
#         confidences = [pred.get(species, 0.0) for pred in predictions]
#         ensemble[species] = np.mean(confidences)
#
#     ensemble_sorted = sorted(ensemble.items(), key=lambda x: x[1], reverse=True)
#     primary_species, primary_confidence = ensemble_sorted[0]
#
#     # top_species = [max(pred, key=pred.get) for pred in predictions]
#     # agreement_count = top_species.count(primary_species)
#     # agreement_ratio = agreement_count / num_pred
#     #
#     # if agreement_ratio == 1.0:
#     #     primary_confidence = primary_confidence * 1.1
#
#     ensemble[primary_species] = primary_confidence
#     return ensemble
#
# predictions1 = {
#     "Boletus edulis": 0.87,
#     "Boletus aereus": 0.10,
#     "Boletus reticulatus": 0.01
# }
#
# predictions2 = {
#     "Boletus edulis": 0.95,
#     "Boletus aereus": 0.04,
#     "Agaricus campestris": 0.01
# }
#
# predictions3 = {
#     "Boletus edulis": 0.93,
#     "Agaricus campestris": 0.07,
#     "Boletus reticulatus": 0.01
# }
# predictions = [predictions1, predictions2, predictions3]
# pred = calculate_predictions(predictions)