In [None]:
from collections import Counter

import pandas as pd
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

# NOTE(review): this cell assumes `df` already exists — uncomment the read_csv
# line below (with the correct path) or load `df` in an earlier cell.
# df = pd.read_csv('your_data.csv')  # Replace with the correct path to your CSV file

# --- Configuration -----------------------------------------------------------
TARGET_COL = 'scientificName'
SEED = 42
N_SPLITS = 5            # cross-validation folds
SMOTE_K_NEIGHBORS = 2   # small k so that rare classes can still be oversampled
# SMOTE needs more than SMOTE_K_NEIGHBORS samples of every class in every CV
# training fold. Stratified N_SPLITS-fold CV leaves a class of size c with at
# least c - ceil(c / N_SPLITS) training samples, so k=2 with 5 folds needs c >= 4
# (with c == 3 some training folds keep only 2 samples and SMOTE raises).
MIN_CLASS_SIZE = 4

# Remove stray leading/trailing spaces from column names (idempotent).
df.columns = df.columns.str.strip()
print("Columns in dataset:", df.columns)

# Fail loudly here: printing and carrying on only surfaces later as a NameError
# for X_res / clf in downstream cells.
if TARGET_COL not in df.columns:
    raise KeyError(
        f"'{TARGET_COL}' column not found in the dataset; "
        f"available columns: {list(df.columns)}"
    )

X = df.drop(columns=[TARGET_COL])  # Features
y = df[TARGET_COL]                 # Target variable
print("Original class distribution:", Counter(y))

# Drop classes too rare for SMOTE + stratified CV (see MIN_CLASS_SIZE above).
y_counts = y.value_counts()
valid_classes = y_counts[y_counts >= MIN_CLASS_SIZE].index
is_valid_class = y.isin(valid_classes)

# Only numeric columns are used as features; identifiers and free-text columns
# are dropped (label-encoding identifiers would let the model memorise rows).
X_filtered = X.loc[is_valid_class].select_dtypes(include=['number'])
y_filtered = y.loc[is_valid_class]

# --- Evaluation --------------------------------------------------------------
# SMOTE is fitted inside each training fold. Oversampling *before* CV places
# synthetic neighbours of test samples into the training folds and inflates
# the scores, so evaluation runs on the real (un-resampled) filtered data.
eval_pipeline = make_pipeline(
    SMOTE(random_state=SEED, k_neighbors=SMOTE_K_NEIGHBORS),
    RandomForestClassifier(class_weight='balanced', random_state=SEED),
)
cv = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

cross_val_scores = cross_val_score(eval_pipeline, X_filtered, y_filtered, cv=cv)
print(f"Cross-validation accuracy scores: {cross_val_scores}")
print(f"Mean cross-validation accuracy: {cross_val_scores.mean()}")

# Out-of-fold predictions: each real sample is predicted by a model that never
# saw it (nor any synthetic sample generated from it).
y_pred = cross_val_predict(eval_pipeline, X_filtered, y_filtered, cv=cv)
print("Classification Report (out-of-fold, original samples only):")
print(classification_report(y_filtered, y_pred, zero_division=0))

# --- Final model -------------------------------------------------------------
# Oversample the full filtered set and fit the model used downstream (e.g. SHAP).
smote = SMOTE(random_state=SEED, k_neighbors=SMOTE_K_NEIGHBORS)
X_res, y_res = smote.fit_resample(X_filtered, y_filtered)
print("Resampled class distribution:", Counter(y_res))

clf = RandomForestClassifier(class_weight='balanced', random_state=SEED)
clf.fit(X_res, y_res)


In [None]:
pip install shap

In [None]:
import matplotlib.pyplot as plt

# Create a much taller figure to accommodate all class labels
plt.figure(figsize=(10, 16))  # Increase height significantly

# Create the SHAP summary plot
shap.summary_plot(shap_values, X_res, plot_type="bar", show=False)

# Get the legend and modify its properties
leg = plt.gca().get_legend()
if leg:
    # Increase the vertical spacing between legend entries
    leg.set_bbox_to_anchor((1.05, 1))  # Move legend further right
    leg._ncol = 1  # Ensure only one column of legend items
    
    # Adjust the spacing between legend items
    for t in leg.get_texts():
        t.set_y(1.5 * t.get_position()[1])  # Increase vertical spacing between text elements

# Save with high resolution
plt.savefig("shap_summary_plot.png", dpi=300, bbox_inches="tight")
plt.show()