In [2]:
import os
import logging
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
import requests
from time import sleep

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

options = webdriver.ChromeOptions()
driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)

# Function to get loaded images
def get_loaded_images():
    return driver.find_elements(By.CLASS_NAME, 'CoverImage')

# Function to download images
def download_images(url, download_folder, scroll_pause_time=50):
    driver.get(url)
    
    os.makedirs(download_folder, exist_ok=True)
    previous_image_count = 0

    while True:
        driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        sleep(scroll_pause_time)

        current_images = get_loaded_images()
        current_image_count = len(current_images)

        if current_image_count == previous_image_count:
            break
        else:
            previous_image_count = current_image_count

    logger.info(f"Total images loaded: {current_image_count}")

    for i, image_div in enumerate(current_images):
        try:
            style_attribute = image_div.get_attribute('style')
            background_image_url = style_attribute.split('"')[1]

            response = requests.get(background_image_url)
            with open(os.path.join(download_folder, f"image_{i}.jpg"), 'wb') as img_file:
                img_file.write(response.content)

            logger.info(f"Downloaded image {i+1}")

        except Exception as e:
            logger.error(f"Error downloading image {i+1}: {e}")

# Download images for Galerina marginata
galerina_url = "https://www.inaturalist.org/taxa/154735-Galerina-marginata/browse_photos"
galerina_folder = "/Users/katka/Desktop/non_eatable_mushrooms"
download_images(galerina_url, galerina_folder, scroll_pause_time=50)

# Download images for Psilocybe cubensis
psilocybe_url = "https://www.inaturalist.org/taxa/328244-Psilocybe-cubensis/browse_photos"
psilocybe_folder = "/Users/katka/Desktop/psilocybe_images"
download_images(psilocybe_url, psilocybe_folder, scroll_pause_time=15)

driver.quit()


INFO:WDM:Get LATEST chromedriver version for google-chrome
INFO:WDM:Get LATEST chromedriver version for google-chrome
INFO:WDM:Driver [/Users/katka/.wdm/drivers/chromedriver/mac64/125.0.6422.141/chromedriver-mac-x64/chromedriver] found in cache
INFO:__main__:Total images loaded: 49
INFO:__main__:Downloaded image 1
INFO:__main__:Downloaded image 2
INFO:__main__:Downloaded image 3
INFO:__main__:Downloaded image 4
INFO:__main__:Downloaded image 5
INFO:__main__:Downloaded image 6
INFO:__main__:Downloaded image 7
INFO:__main__:Downloaded image 8
INFO:__main__:Downloaded image 9
INFO:__main__:Downloaded image 10
INFO:__main__:Downloaded image 11
INFO:__main__:Downloaded image 12
INFO:__main__:Downloaded image 13
INFO:__main__:Downloaded image 14
INFO:__main__:Downloaded image 15
INFO:__main__:Downloaded image 16
INFO:__main__:Downloaded image 17
INFO:__main__:Downloaded image 18
INFO:__main__:Downloaded image 19
INFO:__main__:Downloaded image 20
INFO:__main__:Downloaded image 21
INFO:__main_

In [3]:
import os
import random
import shutil
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, GlobalAveragePooling2D
import numpy as np


random.seed(42)
tf.random.set_seed(42)


edible_dir = "/Users/katka/Desktop/eatable_mushrooms"
non_edible_dir = "/Users/katka/Desktop/non_eatable_mushrooms"
new_images_dir = "/Users/katka/Desktop/new_mushrooms"

train_dir = "/Users/katka/Desktop/train"
validation_dir = "/Users/katka/Desktop/validation"


categories = ['edible', 'non_edible']
for category in categories:
    os.makedirs(os.path.join(train_dir, category), exist_ok=True)
    os.makedirs(os.path.join(validation_dir, category), exist_ok=True)

def split_and_move_images(source_dir, train_dest_dir, val_dest_dir):
    images = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]
    random.shuffle(images)
    split_index = int(0.8 * len(images))  
    train_images = images[:split_index]
    validation_images = images[split_index:]

    for img in train_images:
        shutil.copy(os.path.join(source_dir, img), os.path.join(train_dest_dir, img))
    for img in validation_images:
        shutil.copy(os.path.join(source_dir, img), os.path.join(val_dest_dir, img))

split_and_move_images(edible_dir, os.path.join(train_dir, 'edible'), os.path.join(validation_dir, 'edible'))

split_and_move_images(non_edible_dir, os.path.join(train_dir, 'non_edible'), os.path.join(validation_dir, 'non_edible'))

img_size = (224, 224)
batch_size = 32

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary'
)

validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary'
)


train_edible_count = len(os.listdir(os.path.join(train_dir, 'edible')))
train_non_edible_count = len(os.listdir(os.path.join(train_dir, 'non_edible')))
val_edible_count = len(os.listdir(os.path.join(validation_dir, 'edible')))
val_non_edible_count = len(os.listdir(os.path.join(validation_dir, 'non_edible')))

print(f"Training set - Edible: {train_edible_count}, Non-edible: {train_non_edible_count}")
print(f"Validation set - Edible: {val_edible_count}, Non-edible: {val_non_edible_count}")

base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                               include_top=False,
                                               weights='imagenet')

base_model.trainable = False


model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(1, activation='sigmoid')
])


model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss='binary_crossentropy',
              metrics=['accuracy'])


history = model.fit(
    train_generator,
    epochs=10,
    validation_data=validation_generator
)


evaluation = model.evaluate(validation_generator)
print("Validation Accuracy:", evaluation[1])

model.save("/Users/katka/mushroom_classifier_model.h5")


def predict_new_images(new_images_dir, model_path, img_size):
    new_images = os.listdir(new_images_dir)
    model = load_model(model_path)
    
    for image_file in new_images:
        image_path = os.path.join(new_images_dir, image_file)
        img = tf.keras.preprocessing.image.load_img(image_path, target_size=img_size)
        img_array = tf.keras.preprocessing.image.img_to_array(img)
        img_array = np.expand_dims(img_array, 0)  


        predictions = model.predict(img_array)
        predicted_class = "eatable" if predictions[0][0] < 0.5 else "non-eatable"
        
        print(f"Image: {image_file}, Predicted Class: {predicted_class}")


predict_new_images("/Users/katka/Desktop/new_mushrooms", "/Users/katka/mushroom_classifier_model.h5", (224, 224))


2024-06-02 10:01:02.235715: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:numexpr.utils:NumExpr defaulting to 4 threads.


Found 2321 images belonging to 2 classes.
Found 697 images belonging to 2 classes.
Training set - Edible: 1294, Non-edible: 1031
Validation set - Edible: 324, Non-edible: 374
Epoch 1/10


  self._warn_if_super_not_called()


[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m168s[0m 2s/step - accuracy: 0.4871 - loss: 0.8558 - val_accuracy: 0.5854 - val_loss: 0.6608
Epoch 2/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m190s[0m 3s/step - accuracy: 0.5903 - loss: 0.6700 - val_accuracy: 0.7016 - val_loss: 0.5763
Epoch 3/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m166s[0m 2s/step - accuracy: 0.6776 - loss: 0.5931 - val_accuracy: 0.7575 - val_loss: 0.5231
Epoch 4/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m163s[0m 2s/step - accuracy: 0.7506 - loss: 0.5301 - val_accuracy: 0.8006 - val_loss: 0.4777
Epoch 5/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 2s/step - accuracy: 0.7700 - loss: 0.5046 - val_accuracy: 0.8164 - val_loss: 0.4460
Epoch 6/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 2s/step - accuracy: 0.8045 - loss: 0.4558 - val_accuracy: 0.8336 - val_loss: 0.4189
Epoch 7/10
[1m73/73[0m [32m━━━━━━━━━━━━━━━



Validation Accuracy: 0.8751793503761292




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
Image: images (2).jpeg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step
Image: 3-s2.0-B9780128002124000741-f74-01-9780128002124.jpg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step
Image: images.jpeg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 89ms/step
Image: images (1).jpeg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step
Image: 408678.jpeg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step
Image: unnamed.jpeg, Predicted Class: eatable
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
Image: 4e6ed0dc-4af088c5-8d88-cc13ba79.jpeg, Predicted Class: eatable
