In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
import sys
import seaborn as sns
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from sklearn.metrics import roc_curve, auc
import matplotlib.ticker as ticker



sys.path.append('..')

# Set the path to your images_reshaped directory.
# Expected layout: <base_path>/<category>/<species_folder>/<image files>.
base_path = '../images_reshaped'

# One dict per image; converted to a DataFrame in one go after the walk.
data = []

# Iterate through the directory structure. Listings are sorted because
# os.listdir order is arbitrary: sorting makes the DataFrame row order (and
# hence the seeded train/test splits built from it later) reproducible across
# machines and filesystems.
for category in ['deadly', 'edible', 'poisonous', 'conditionally_edible']:
    category_path = os.path.join(base_path, category)
    for species_folder in sorted(os.listdir(category_path)):
        species_path = os.path.join(category_path, species_folder)
        if not os.path.isdir(species_path):
            continue
        for image_file in sorted(os.listdir(species_path)):
            if not image_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            image_path = os.path.join(species_path, image_file)

            # Load and preprocess the image.
            # - The context manager closes the file handle immediately.
            # - convert('RGB') guarantees 3 channels, so grayscale / palette /
            #   RGBA files cannot produce arrays of a different shape (which
            #   would break np.stack below and the 3-channel model input).
            # - float32 halves memory compared with the float64 produced by
            #   dividing a uint8 array by 255.0.
            with Image.open(image_path) as img:
                img_array = np.asarray(img.convert('RGB'), dtype=np.float32) / 255.0  # Normalize to [0, 1]

            data.append({
                'image_path': image_path,
                'category': category,
                'species': species_folder,
                'image': img_array
            })

# Create the DataFrame
df = pd.DataFrame(data)

# Binary target: 1 for the 'edible' category, 0 for every other category.
df['edible'] = (df['category'] == 'edible').astype(int)

# Encode categories and species
le_category = LabelEncoder()
le_species = LabelEncoder()
df['category_encoded'] = le_category.fit_transform(df['category'])
df['species_encoded'] = le_species.fit_transform(df['species'])

# Save the DataFrame without the 'image' column
df_save = df.drop(columns=['image'])
df_save.to_pickle('mushroom_metadata.pkl')

# Save the image data separately; row i of the array corresponds to row i of
# the metadata frame saved above.
np.save('mushroom_images.npy', np.stack(df['image'].values))

print("Data preprocessing completed.")
print("Metadata saved as 'mushroom_metadata.pkl'.")
print("Image data saved as 'mushroom_images.npy'.")
print(f"Total images processed: {len(df)}")
print(f"Number of edible mushrooms: {df['edible'].sum()}")
print(f"Number of non-edible mushrooms: {len(df) - df['edible'].sum()}")
print(f"Number of unique species: {df['species'].nunique()}")
print(f"Image shape: {df['image'].iloc[0].shape}")

In [None]:
# Preview the metadata. The 'image' column holds whole pixel arrays whose repr
# only clutters the table, so it is left out of the preview.
df.drop(columns=['image']).head()

In [None]:
images = np.load('mushroom_images.npy')
# The metadata was written with DataFrame.to_pickle, so read it back with
# pandas. np.load(allow_pickle=True) only "worked" by falling back to a raw
# pickle.load. 'edible' is 1 for the edible category and 0 otherwise; .values
# gives a plain array, as in the reload cell further down.
labels = pd.read_pickle('mushroom_metadata.pkl')['edible'].values

X = images
y = labels

# 80/20 train+val vs. test, then 80/20 train vs. val. Both splits are
# stratified on the label and seeded for reproducibility.
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2, stratify=y_train_val, random_state=42)

In [None]:
def build_model(learning_rate=0.001, dense_units=224, input_shape=(224, 224, 3),
                rescale_inputs=True):
    """Build and compile a binary classifier on top of a frozen ResNet50V2.

    The ImageNet-pretrained ResNet50V2 backbone (without its classification
    head) is frozen. Only the new head is trained:
    global average pooling -> ReLU Dense -> single sigmoid unit.

    Args:
        learning_rate: Learning rate of the Adam optimizer.
        dense_units: Width of the hidden ReLU layer of the head.
        input_shape: (height, width, channels) of the input images.
        rescale_inputs: If True (default), a Rescaling layer maps inputs from
            [0, 1] (how the preprocessing cell normalizes pixels) to [-1, 1].
            That is the range ResNet50V2's ImageNet weights expect
            (cf. ``tf.keras.applications.resnet_v2.preprocess_input``).
            Pass False if the inputs are already scaled to [-1, 1].

    Returns:
        A compiled ``tf.keras.Model`` with one sigmoid output, compiled with
        binary cross-entropy loss and an accuracy metric.
    """
    base_model = ResNet50V2(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freeze the whole backbone; only the head below is trainable.
    base_model.trainable = False

    inputs = tf.keras.Input(shape=input_shape)
    x = inputs
    if rescale_inputs:
        # x * 2 - 1: [0, 1] -> [-1, 1]. NOTE: tf.keras.layers.Rescaling needs TF >= 2.6.
        x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(x)
    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(dense_units, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=output)

    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    return model

# Seed TensorFlow and NumPy so the head's weight initialisation and the
# per-epoch shuffling in model.fit are repeatable across runs.
# (GPU kernels may still introduce some non-determinism.)
tf.random.set_seed(42)
np.random.seed(42)

model = build_model(learning_rate=0.001, dense_units=224)

callbacks = [
    # Stop after 5 epochs without val_loss improvement and roll back to the best epoch.
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    # Halve the learning rate after 3 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
]

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=32,
    callbacks=callbacks,
)


In [None]:
# Score the trained model on the held-out test split, then persist it to disk.
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print("Test accuracy: {:.4f}".format(test_accuracy))

model_path = 'mushroom_classification_model_v2.h5'
model.save(model_path)

In [None]:
# Side-by-side learning curves: accuracy (left) and loss (right), for the
# training and validation sets, one point per epoch.
fig, (ax_accuracy, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    (ax_accuracy, 'accuracy', 'val_accuracy', 'Model Accuracy Over Epochs', 'Accuracy', 'lower right'),
    (ax_loss, 'loss', 'val_loss', 'Model Loss Over Epochs', 'Loss', 'upper right'),
]

for ax, train_key, val_key, title, ylabel, legend_loc in panels:
    ax.plot(history.history[train_key], marker='o')
    ax.plot(history.history[val_key], marker='o')
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

fig.tight_layout()
plt.show()


In [None]:
# Reload the model from the HDF5 file written after training, so the analysis
# cells below can run without retraining.
saved_model_file = 'mushroom_classification_model_v2.h5'
model = load_model(saved_model_file)
print("Model loaded successfully.")


In [None]:
# Reload metadata and pixel data from disk. Both were written from the same
# DataFrame, so row i of one corresponds to row i of the other.
df_metadata = pd.read_pickle('mushroom_metadata.pkl')
images = np.load('mushroom_images.npy')

# Binary target as a plain NumPy array.
y = df_metadata['edible'].to_numpy()


In [None]:
# Repeat the train/val/test split on row indices instead of pixel arrays.
# It uses the same test_size, stratification and random_state as the split
# used for training, so it reproduces that partition. The indices let each
# test image be mapped back to its metadata row.
indices = np.arange(images.shape[0])
split_kwargs = dict(test_size=0.2, random_state=42)

idx_train_val, idx_test, y_train_val, y_test = train_test_split(
    indices, y, stratify=y, **split_kwargs)
idx_train, idx_val, y_train, y_val = train_test_split(
    idx_train_val, y_train_val, stratify=y_train_val, **split_kwargs)

X_train, X_val, X_test = (images[subset] for subset in (idx_train, idx_val, idx_test))


In [None]:
# Model output (one sigmoid score per test image) flattened to 1-D.
y_pred_probs = np.ravel(model.predict(X_test))

# Hard labels: 1 where the score exceeds 0.5, else 0.
y_pred = np.where(y_pred_probs > 0.5, 1, 0)


In [None]:
# Metadata rows of the test images, in the same order as X_test / y_test.
df_test = df_metadata.iloc[idx_test].reset_index(drop=True)

# New frame with ground truth, hard prediction and predicted score per image.
df_results = df_test.assign(
    y_true=y_test,
    y_pred=y_pred,
    y_pred_prob=y_pred_probs,
)


In [None]:
# ROC curve of the predicted scores against the true test labels, with its AUC.
fpr, tpr, thresholds = roc_curve(y_test, y_pred_probs)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2,
        label=f'ROC curve (area = {roc_auc:.2f})')
# Dashed diagonal: chance-level reference.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([-0.05, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()


In [None]:
# One point per test image: its predicted score, colored by its true label.
fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(np.arange(len(y_test)), y_pred_probs,
                    c=y_test, cmap='bwr', alpha=0.7)
ax.set_xlabel('Sample Index')
ax.set_ylabel('Predicted Probability of Edibility')
ax.set_title('Predicted Probabilities Colored by True Label')
fig.colorbar(points, ax=ax, label='True Label (0=Non-Edible, 1=Edible)')
plt.show()


In [None]:
# Distribution of predicted scores within each of the four original categories.
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(data=df_results, x='category', y='y_pred_prob', ax=ax)
ax.set_title('Predicted Probabilities per Category')
ax.set_ylabel('Predicted Probability of Edibility')
ax.set_xlabel('Category')
ax.tick_params(axis='x', labelrotation=45)
plt.show()


In [None]:
categories = ['deadly', 'edible', 'poisonous', 'conditionally_edible']

# Column layout of every per-species summary table (also used for empty ones).
_SPECIES_PERF_COLUMNS = ['species', 'total_samples', 'correct_predictions',
                         'incorrect_predictions', 'accuracy']


def summarize_species_performance(df_category):
    """Summarize prediction quality per species for a slice of test results.

    Args:
        df_category: Rows of ``df_results`` (one per test image) with at least
            the ``species``, ``y_true`` and ``y_pred`` columns.

    Returns:
        A DataFrame with one row per species. It has integer ``total_samples``,
        ``correct_predictions`` and ``incorrect_predictions`` columns plus a
        float ``accuracy``. Rows are sorted by ``incorrect_predictions``
        (descending); ties keep species-name order. When ``df_category`` is
        empty, an empty frame with the same columns is returned, so callers
        can still filter and sort on them.
    """
    if df_category.empty:
        return pd.DataFrame(columns=_SPECIES_PERF_COLUMNS)

    # 1 where the prediction matches the label, else 0. Counting via integer
    # sums keeps the count columns integral; the former groupby.apply
    # returning a mixed Series made them all float64.
    correct = (df_category['y_true'] == df_category['y_pred']).astype(int)
    summary = (
        df_category
        .assign(_correct=correct, _incorrect=1 - correct)
        .groupby('species', as_index=False)
        .agg(
            total_samples=('_correct', 'size'),
            correct_predictions=('_correct', 'sum'),
            incorrect_predictions=('_incorrect', 'sum'),
            accuracy=('_correct', 'mean'),
        )
    )
    # Stable sort so ties stay in (alphabetical) species order.
    return summary.sort_values('incorrect_predictions', ascending=False, kind='mergesort')


category_species_performance = {}

for category in categories:
    df_category = df_results[df_results['category'] == category]
    species_performance = summarize_species_performance(df_category)
    category_species_performance[category] = species_performance

print("Calculated per-species performance for each category.")



In [None]:
for category in categories:
    species_performance = category_species_performance[category]

    # Keep only species with at least one error, as a new sorted frame.
    # The original sorted the boolean-filtered slice in place, which mutates a
    # temporary and goes through pandas' SettingWithCopyWarning path.
    species_with_errors = (
        species_performance[species_performance['incorrect_predictions'] > 0]
        .sort_values('incorrect_predictions', ascending=False)
    )

    if species_with_errors.empty:
        print(f"No misclassifications in category: {category}")
        continue  # Skip to the next category if none

    fig, ax = plt.subplots(figsize=(12, 6))
    # NOTE(review): seaborn >= 0.13 warns when `palette` is given without `hue`;
    # consider hue='species', legend=False once the seaborn version is pinned.
    sns.barplot(
        x='incorrect_predictions',
        y='species',
        data=species_with_errors,
        palette='viridis',
        order=species_with_errors['species'],
        ax=ax,
    )
    ax.set_title(f'Misclassifications Per Species in Category: {category}, for species that had incorrect predictions')
    ax.set_xlabel('Number of Incorrect Predictions')
    ax.set_ylabel('Species')
    # Error counts are whole numbers; avoid fractional x ticks.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    fig.tight_layout()
    plt.show()
    # Release the figure explicitly: this loop creates one figure per category.
    plt.close(fig)

