# In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import euclidean
import numpy as np
from PIL import Image
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.models import Model
from tensorflow.keras.applications.efficientnet import preprocess_input

# Feature extractor: ImageNet-pretrained EfficientNetB3 with the classifier
# head removed and global average pooling applied, so each 300x300 RGB input
# is mapped to a single feature vector.
base_model = EfficientNetB3(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(300, 300, 3),
)
model = Model(inputs=base_model.input, outputs=base_model.output)

# Reference statistics read from .npy files in the working directory.
# NOTE(review): presumably produced by a separate fitting step over soil
# images using this same extractor — confirm against that code.
mean_vec, inv_cov_matrix, soil_features = (
    np.load(stats_file)
    for stats_file in ("mean_vec.npy", "inv_cov_matrix.npy", "soil_features.npy")
)

# Compute Mahalanobis distance
def mahalanobis_distance(x, mean_vec, inv_cov_matrix):
    """Return the Mahalanobis distance between ``x`` and ``mean_vec``.

    Computes ``sqrt(delta @ inv_cov_matrix @ delta.T)`` with
    ``delta = x - mean_vec``.

    Args:
        x: Feature vector to measure.
        mean_vec: Centre of the reference distribution, same length as ``x``.
        inv_cov_matrix: Inverse (or pseudo-inverse) of the reference
            covariance matrix.

    Returns:
        The distance as a non-negative NumPy float (NaN only if the inputs
        themselves contain NaN).
    """
    delta = x - mean_vec
    squared = np.dot(np.dot(delta, inv_cov_matrix), delta.T)
    # Floating-point round-off in an estimated or pseudo-inverted covariance
    # can push the quadratic form marginally below zero; without the clamp
    # np.sqrt would return NaN, which then poisons percentile thresholds and
    # every comparison made against this distance.
    return np.sqrt(np.maximum(squared, 0.0))

# Derive one acceptance threshold per metric from the reference soil features.
# The two distances are bounded above by their 95th percentile; cosine
# similarity is bounded below by its 5th percentile, because a lower
# similarity is the anomalous direction.
mahal_distances = np.array([
    mahalanobis_distance(feat, mean_vec, inv_cov_matrix)
    for feat in soil_features
])
euclid_dists = np.array([
    euclidean(feat, mean_vec)
    for feat in soil_features
])
# cosine_similarity works on 2-D inputs: present the mean as a single row and
# flatten the resulting (1, n_samples) matrix back to one value per sample.
cosine_similarities = cosine_similarity(
    mean_vec.reshape(1, -1), soil_features
).ravel()

mahal_thresh = np.percentile(mahal_distances, q=95)
cosine_thresh = np.percentile(cosine_similarities, q=5)
euclid_thresh = np.percentile(euclid_dists, q=95)

# Inference function
def classify_image(img_path):
    """Label an image as soil or non-soil from its distance to the soil mean.

    The image is embedded with the module-level ``model`` and compared with
    ``mean_vec`` using Mahalanobis distance, cosine similarity and Euclidean
    distance. It is labelled ``"SOIL"`` only when all three measures fall on
    the inlier side of their module-level thresholds.

    Args:
        img_path: Image location passed straight to ``PIL.Image.open``.

    Returns:
        A ``(label, metrics)`` tuple: ``label`` is ``"SOIL"`` or
        ``"NON-SOIL"``; ``metrics`` is a dict with the keys
        ``"mahalanobis"``, ``"cosine_similarity"`` and ``"euclidean"``.
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the image object happens to be garbage-collected.
    # convert() and resize() return new images, so the result outlives it.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB").resize((300, 300))

    # NOTE(review): Keras documents EfficientNet as expecting raw [0, 255]
    # pixels (rescaling is built into the model and preprocess_input is a
    # pass-through), so this division looks suspect. It is kept because the
    # saved soil statistics were presumably extracted with the same scaling —
    # confirm against the feature-extraction code and change both together.
    img_array = np.array(img) / 255.0
    img_array = preprocess_input(img_array)
    img_array = np.expand_dims(img_array, axis=0)  # add the batch axis
    feature = model.predict(img_array, verbose=0).squeeze()

    m_dist = mahalanobis_distance(feature, mean_vec, inv_cov_matrix)
    c_sim = cosine_similarity([feature], [mean_vec])[0][0]
    e_dist = euclidean(feature, mean_vec)

    # All three tests must agree; any NaN metric compares False and therefore
    # yields "NON-SOIL".
    is_soil = (
        m_dist < mahal_thresh
        and c_sim > cosine_thresh
        and e_dist < euclid_thresh
    )
    label = "SOIL" if is_soil else "NON-SOIL"
    metrics = {
        "mahalanobis": m_dist,
        "cosine_similarity": c_sim,
        "euclidean": e_dist,
    }
    return label, metrics

# Example run: classify one sample image, then show the predicted label and
# the raw metric values behind the decision.
label, metrics = classify_image(img_path="test_images/sample.jpg")
print(f"Predicted Label: {label}")
print("Metrics:", metrics)
