# In [3]:  (Jupyter cell prompt from the pasted notebook)
import os
from typing import List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kurtosis, skew
from skimage.filters import threshold_otsu
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import StandardScaler

# scikit-image 0.19 renamed the GLCM helpers from the "grey" to the "gray"
# spelling, and later releases no longer provide the old names, so importing
# greycomatrix fails on a current install. Prefer the new names and fall back
# to the old spelling only for old scikit-image versions.
try:
    from skimage.feature import graycomatrix, graycoprops
except ImportError:  # scikit-image < 0.19
    try:
        from skimage.feature.texture import greycomatrix as graycomatrix
        from skimage.feature.texture import greycoprops as graycoprops
    except ImportError:
        from skimage.feature import greycomatrix as graycomatrix
        from skimage.feature import greycoprops as graycoprops
# Keep the old names bound so code written against them keeps working.
greycomatrix, greycoprops = graycomatrix, graycoprops

# --------------------------
# Feature extraction utils
# --------------------------

def read_image_rgb(path: str):
    """Load the image stored at ``path`` and return it as an RGB uint8 array.

    OpenCV decodes into BGR channel order, so the channels are swapped to RGB
    before returning.

    Raises:
        FileNotFoundError: when ``cv2.imread`` cannot load the file (it signals
            failure by returning None, e.g. missing file or undecodable data).
    """
    image_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if image_bgr is not None:
        return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Cannot read image: {path}")

def extract_hsv_histogram(img_rgb: np.ndarray, bins=(16,8,8)) -> np.ndarray:
    """
    Return a flattened 3-D HSV colour histogram, L1-normalised to sum to 1.

    bins: (H_bins, S_bins, V_bins). Hue is binned over OpenCV's 8-bit hue
    range [0, 180); saturation and value over [0, 256). The result holds
    H_bins * S_bins * V_bins float32 entries. If every count is zero the
    histogram is returned unnormalised (all zeros).
    """
    hue_range, sat_range, val_range = (0, 180), (0, 256), (0, 256)
    hsv_image = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
    counts = cv2.calcHist(
        [hsv_image], [0, 1, 2], None, bins,
        [*hue_range, *sat_range, *val_range],
    )
    flat = counts.flatten().astype('float32')
    total = flat.sum()
    return flat / total if total > 0 else flat

def extract_hsv_stats(img_rgb: np.ndarray) -> np.ndarray:
    """
    Compute mean, std, skewness and excess kurtosis for each HSV channel.

    Returns 12 float32 values ordered channel by channel:
    [H_mean, H_std, H_skew, H_kurt, S_mean, ..., V_kurt].

    Skewness and kurtosis are undefined for a channel with zero variance,
    e.g. H and S of a grayscale or solid-colour image, and scipy returns NaN
    there. A NaN in any feature vector later makes ``pairwise_distances``
    reject the whole matrix, so those moments are reported as 0 instead.
    """
    # float64 keeps the 3rd/4th-order moments numerically stable over many pixels.
    hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV).astype('float64')
    stats = []
    for ch in range(3):
        vals = hsv[:, :, ch].ravel()
        std = vals.std(ddof=0)
        if std > 0:
            ch_skew = skew(vals)
            ch_kurt = kurtosis(vals)  # scipy default (Fisher): normal -> 0.0
        else:
            ch_skew = ch_kurt = 0.0
        # NOTE: hue is circular (0 and 179 are both red), so its mean/std are
        # only a rough summary; kept as-is for compatibility.
        stats.extend([vals.mean(), std, ch_skew, ch_kurt])
    feats = np.array(stats, dtype='float32')
    # scipy may still return NaN for *nearly* constant channels, so sanitise
    # as a last resort to keep every vector finite.
    return np.nan_to_num(feats, nan=0.0, posinf=0.0, neginf=0.0)

def extract_glcm_features(img_rgb: np.ndarray, distances=(1,), levels=16) -> np.ndarray:
    """
    Compute GLCM texture features averaged over all distances and the four
    angles 0, 45, 90 and 135 degrees.

    The grayscale image is quantised to ``levels`` grey levels (0..levels-1)
    before building a symmetric, normalised co-occurrence matrix.

    Args:
        img_rgb: RGB image (8-bit intensities assumed for the quantisation).
        distances: pixel offset(s) for the GLCM; an int or a sequence of ints.
        levels: number of grey levels after quantisation, 2..256.

    Returns:
        float32 array [contrast, correlation, energy, homogeneity].

    Raises:
        ValueError: if ``levels`` is outside 2..256 (quantised values are
            stored as uint8, so larger values would silently wrap).
    """
    if not 2 <= levels <= 256:
        raise ValueError(f"levels must be in [2, 256], got {levels}")
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    # Non-uint8 input is cast directly, i.e. it is assumed to already be 0..255.
    if gray.dtype != np.uint8:
        gray = gray.astype('uint8')
    # Map 0..255 onto 0..levels-1 (graycomatrix requires values < levels).
    q = (gray.astype('float32') * (levels - 1) / 255.0).round().astype('uint8')
    angles = [0, np.pi / 4, np.pi / 2, 3 * np.pi / 4]
    # symmetric + normed makes the matrix a probability table, so the
    # properties are comparable across image sizes.
    glcm = graycomatrix(q, distances=np.atleast_1d(distances), angles=angles,
                        levels=levels, symmetric=True, normed=True)
    # graycoprops returns shape (len(distances), len(angles)); average it all.
    props = [graycoprops(glcm, prop).mean()
             for prop in ('contrast', 'correlation', 'energy', 'homogeneity')]
    return np.array(props, dtype='float32')

def extract_shape_features(img_rgb: np.ndarray) -> np.ndarray:
    """
    Compute shape descriptors of the largest foreground blob.

    The grayscale image is binarised with Otsu's threshold (pixels brighter
    than the threshold become foreground), external contours are extracted,
    and the contour with the largest area is described by:

    - area          contour area in pixels
    - perimeter     closed contour length in pixels
    - aspect_ratio  bounding-box width / height
    - extent        area / bounding-box area
    - solidity      area / convex-hull area (0 if the hull area is 0)
    - circularity   4*pi*area / perimeter^2 (0 if the perimeter is 0)

    Returns 6 float32 values, or 6 zeros when no contour is found (e.g. a
    blank image). Area and perimeter are absolute pixel measures, so they
    depend on image resolution.
    """
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    try:
        th = threshold_otsu(gray)
        _, bw = cv2.threshold(gray, int(th), 255, cv2.THRESH_BINARY)
    except Exception:
        # If scikit-image's Otsu fails on this image, let OpenCV compute the
        # Otsu threshold itself (the 0 passed as thresh is ignored).
        _, bw = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # OpenCV 3.x returns (image, contours, hierarchy), 4.x returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return np.zeros(6, dtype='float32')

    c = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(c)
    perimeter = cv2.arcLength(c, True)
    _, _, w, h = cv2.boundingRect(c)
    bbox_area = w * h if w * h > 0 else 1.0
    aspect_ratio = float(w) / float(h) if h > 0 else 0.0
    extent = float(area) / float(bbox_area)
    hull_area = cv2.contourArea(cv2.convexHull(c))
    solidity = float(area) / hull_area if hull_area > 0 else 0.0
    circularity = 4.0 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0.0

    return np.array([area, perimeter, aspect_ratio, extent, solidity, circularity],
                    dtype='float32')

def extract_features_from_path(path: str, config: str = 'C') -> np.ndarray:
    """
    Extract the feature vector for the image at ``path``.

    config (case-insensitive):
        'A' -> colour only: HSV histogram (16*8*8 = 1024) + HSV stats (12)
        'B' -> texture only: GLCM features (4)
        'C' -> combined: colour + GLCM + shape (6)

    Returns a 1-D float32 vector with the parts concatenated in the order
    colour, GLCM, shape.

    Raises:
        ValueError: for an unknown ``config``. It is rejected before the image
            is read, so a typo is reported directly rather than showing up as
            an empty vector that callers silently skip.
        FileNotFoundError: if the image cannot be read.
    """
    cfg = config.upper() if isinstance(config, str) else config
    if cfg not in ('A', 'B', 'C'):
        raise ValueError(f"Unknown config {config!r}; expected 'A', 'B' or 'C'")

    img = read_image_rgb(path)
    feats = []
    if cfg in ('A', 'C'):
        feats.append(extract_hsv_histogram(img, bins=(16, 8, 8)))
        feats.append(extract_hsv_stats(img))
    if cfg in ('B', 'C'):
        feats.append(extract_glcm_features(img, distances=[1], levels=16))
    if cfg == 'C':
        feats.append(extract_shape_features(img))
    # cfg is validated above, so feats always holds at least one part.
    return np.concatenate(feats).astype('float32')

# --------------------------
# Indexing and searching
# --------------------------

def build_index(db_folder: str, config: str = 'C') -> Tuple[List[str], np.ndarray, StandardScaler]:
    """
    Index every regular file in ``db_folder``, visited in sorted name order.

    Each file goes through ``extract_features_from_path(path, config=config)``.
    Files whose extraction raises are reported with a "Skipping" message;
    files that yield an empty vector are dropped without a message.

    Returns:
        (paths, X_scaled, scaler): ``paths[i]`` corresponds to row ``i`` of
        ``X_scaled`` (n_samples x n_features), the database features
        standardised by ``scaler``, a StandardScaler fitted on them.

    Raises:
        RuntimeError: if no file produced a feature vector.
    """
    joined = (os.path.join(db_folder, name) for name in sorted(os.listdir(db_folder)))
    candidates = [path for path in joined if os.path.isfile(path)]

    kept_paths, vectors = [], []
    for path in candidates:
        try:
            vec = extract_features_from_path(path, config=config)
        except Exception as e:
            print(f"Skipping {path}: {e}")
            continue
        if vec.size:
            vectors.append(vec)
            kept_paths.append(path)

    if not vectors:
        raise RuntimeError("No features extracted from db_folder. Check images and paths.")

    # Fit the scaler on the database only; queries reuse it via transform().
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(np.vstack(vectors))
    return kept_paths, X_scaled, scaler

def search_query(query_path: str, db_paths: List[str], X_scaled: np.ndarray, scaler: StandardScaler,
                 config: str = 'C', metric: str = 'cosine', topk: int = 5) -> List[Tuple[str, float]]:
    """
    Rank database images by distance to the query image.

    The query uses the same feature extraction (``config``) and the same
    fitted ``scaler`` as the database. Distances to every row of ``X_scaled``
    come from ``pairwise_distances`` with ``metric`` (e.g. 'cosine' or
    'euclidean').

    Returns:
        At most ``topk`` (path, distance) pairs, smallest distance first.

    Raises:
        RuntimeError: if the query produces an empty feature vector.
    """
    query_vec = extract_features_from_path(query_path, config=config)
    if not query_vec.size:
        raise RuntimeError("Query feature vector is empty.")

    # scaler.transform expects a 2-D (1, n_features) batch.
    query_row = scaler.transform(query_vec[np.newaxis, :])
    distances = pairwise_distances(query_row, X_scaled, metric=metric)[0]
    order = np.argsort(distances)  # ascending: nearest first
    return [(db_paths[i], float(distances[i])) for i in order[:topk]]

# --------------------------
# Utility: extract label (for evaluation)
# --------------------------

def filename_label(path: str) -> str:
    """
    Derive a class label from a file name (used for Precision@k evaluation).

    The directory and extension are stripped. If the remaining stem contains
    an underscore, the label is the text before the first underscore;
    otherwise, if it contains a dash, the text before the first dash;
    otherwise the whole stem. The underscore rule wins even when a dash
    appears earlier ("cat-01_a.jpg" -> "cat-01"). Adapt this to your
    dataset's naming scheme.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    for separator in ('_', '-'):
        head, found, _tail = stem.partition(separator)
        if found:
            return head
    return stem

def precision_at_k(results: List[Tuple[str, float]], query_path: str, k: int = 5) -> float:
    """
    Return the fraction of the first ``k`` results whose label matches the query's.

    Labels come from ``filename_label``. The denominator is always ``k``
    (standard Precision@k), so if fewer than ``k`` results are supplied the
    missing slots count as misses.

    Raises:
        ValueError: if ``k`` is not positive. k=0 would divide by zero, and a
            negative k would slice from the end and yield a negative score.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")
    qlabel = filename_label(query_path)
    hits = sum(filename_label(p) == qlabel for p, _ in results[:k])
    return hits / k

# --------------------------
# Visualization
# --------------------------

def show_search_results(query_path: str, results: List[Tuple[str, float]], figsize=(14,4)):
    """
    Plot the query image followed by its ranked results in one row.

    Panel 1 shows the query, titled "Query". Panel r+1 shows the rank-r
    result, titled with its rank and distance (4 decimals). Ends with
    ``plt.show()``.
    """
    n_cols = len(results) + 1
    plt.figure(figsize=figsize)
    # Prepend the query as a pseudo-entry so every panel is drawn by one loop.
    entries = [(query_path, None)] + list(results)
    for col, (img_path, dist) in enumerate(entries, start=1):
        plt.subplot(1, n_cols, col)
        plt.imshow(read_image_rgb(img_path))
        plt.axis('off')
        plt.title("Query" if col == 1 else f"Rank {col - 1}\nDist: {dist:.4f}")
    plt.tight_layout()
    plt.show()

# --------------------------
# Example usage
# --------------------------

if __name__ == "__main__":
    # PUT YOUR PATHS HERE
    DB_FOLDER = "db"       # folder with database images
    QUERY_FOLDER = "query" # folder with query images

    # Choose configuration: 'A' (color), 'B' (GLCM), 'C' (combined).
    CONFIG = 'C'
    METRIC = 'cosine'  # or 'euclidean'
    TOPK = 5           # results per query; also the k used for Precision@k

    # Build index (extract features for DB)
    print("Building index...")
    db_paths, X_scaled, scaler = build_index(DB_FOLDER, config=CONFIG)
    print(f"Indexed {len(db_paths)} images, feature dim = {X_scaled.shape[1]}")

    # Loop through queries, display results and collect Precision@k.
    precisions = []
    for qfn in sorted(os.listdir(QUERY_FOLDER)):
        qpath = os.path.join(QUERY_FOLDER, qfn)
        if not os.path.isfile(qpath):
            continue
        print(f"\nSearching for query: {qpath}")
        try:
            results = search_query(qpath, db_paths, X_scaled, scaler,
                                   config=CONFIG, metric=METRIC, topk=TOPK)
        except Exception as e:
            # Same policy as build_index: a stray non-image file (desktop.ini,
            # .DS_Store, ...) should not abort the whole evaluation run.
            print(f"Skipping {qpath}: {e}")
            continue
        for rank, (p, d) in enumerate(results, start=1):
            print(f"Rank {rank}: {os.path.basename(p)}  Dist={d:.6f}")
        show_search_results(qpath, results)
        # Labels are derived from file names via filename_label.
        p_at_k = precision_at_k(results, qpath, k=TOPK)
        precisions.append(p_at_k)
        print(f"Precision@{TOPK} = {p_at_k:.2f}")

    if precisions:
        print(f"\nMean Precision@{TOPK} over {len(precisions)} queries = "
              f"{np.mean(precisions):.2f}")


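# Error seen when running this cell:
# ImportError: cannot import name 'greycomatrix' from 'skimage.feature' (C:\Users\magnu\PyCharmMiscProject\.venv\Lib\site-packages\skimage\feature\__init__.py)
# Cause: scikit-image 0.19 renamed greycomatrix/greycoprops to graycomatrix/graycoprops,
# and the installed version no longer provides the old "grey" names.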