# SIFT object recognition

In [None]:
import cv2, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
from tqdm.auto import tqdm

# Default size for every matplotlib figure rendered in this notebook.
plt.rcParams.update({'figure.figsize': (7, 4)})
print('OpenCV', cv2.__version__)

# === paths ===
# Inputs: folder with reference ("key") images and the video to scan.
keys_dir = Path('dataset/keys')
video_path = Path('dataset/video.mp4')
# Output folder for the annotated video; created up-front if missing.
output_dir = Path('results')
output_dir.mkdir(parents=True, exist_ok=True)


OpenCV 4.11.0


In [None]:
# === hyper-parameters ===
# Resize factor applied before feature extraction
# (0.5 -> width of about 640 px for a 1280x720 source).
FRAME_SCALE = 0.5

# Keyword arguments forwarded verbatim to the SIFT detector factory.
CUDA_SIFT_PARAMS = {
    'nfeatures': 800,
    'contrastThreshold': 0.04,
    'edgeThreshold': 10,
    'nOctaveLayers': 3,
}

# Ratio-test threshold: best distance must be below RATIO_TEST * second-best.
RATIO_TEST = 0.75
# A detection needs at least this many RANSAC inliers ...
MIN_INLIERS = 10
# ... and at least this fraction of inliers among the filtered matches.
MIN_INLIER_RATIO = 0.25


In [None]:
def preprocess_and_upload(img_bgr, blur=False):
    """Downscale an image, convert it to grayscale and upload it to the GPU.

    Parameters
    ----------
    img_bgr : image passed to ``cv2.resize`` and then converted with
        ``cv2.COLOR_BGR2GRAY`` (so it is expected to be a BGR image).
    blur : bool, default False
        When true, a 3x3 Gaussian blur is applied to the grayscale image
        before the upload.

    Returns
    -------
    (gpu_mat, shape)
        ``gpu_mat`` is a ``cv2.cuda_GpuMat`` holding the grayscale image and
        ``shape`` is that image's ``.shape`` -- i.e. the size in the *scaled*
        coordinate space, not the size of ``img_bgr``.
    """
    resized = cv2.resize(
        img_bgr,
        None,
        fx=FRAME_SCALE,
        fy=FRAME_SCALE,
        interpolation=cv2.INTER_AREA,
    )
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    if blur:
        gray = cv2.GaussianBlur(gray, (3, 3), 0)

    gpu_mat = cv2.cuda_GpuMat()
    gpu_mat.upload(gray)
    return gpu_mat, gray.shape


In [None]:
def create_detector():
    """Build the SIFT detector configured by ``CUDA_SIFT_PARAMS``.

    NOTE(review): stock OpenCV builds may not expose ``cv2.cuda.SIFT_create``;
    confirm that the OpenCV build used by this notebook provides it.
    """
    detector = cv2.cuda.SIFT_create(**CUDA_SIFT_PARAMS)
    return detector

def create_flann_matcher():
    """Build a FLANN-based descriptor matcher using a KD-tree index.

    The index uses 5 trees and every search performs 30 checks.
    """
    flann_index_kdtree = 1  # FLANN algorithm id for the KD-tree index
    return cv2.FlannBasedMatcher(
        {'algorithm': flann_index_kdtree, 'trees': 5},
        {'checks': 30},
    )


In [None]:
def load_key_images(detector, folder):
    """Load every reference image in ``folder`` and extract its features.

    Files are visited in sorted order; only ``.jpg``, ``.jpeg``, ``.png`` and
    ``.bmp`` (case-insensitive) are considered.  Each image goes through
    ``preprocess_and_upload`` and the same asynchronous detect/convert/download
    sequence, so keypoint coordinates live in the *scaled* image space.

    Parameters
    ----------
    detector : object exposing ``detectAndComputeAsync`` and ``convert``.
    folder : directory containing the key images (``Path`` or ``str``).

    Returns
    -------
    list of dict
        One entry per usable image with keys ``'name'`` (file stem), ``'kp'``
        (keypoints), ``'des'`` (descriptors) and ``'shape'`` (scaled shape).
        Files that cannot be decoded, or that yield no keypoints or
        descriptors, are skipped.
    """
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}
    db = []
    for path in sorted(Path(folder).iterdir()):
        if path.suffix.lower() not in allowed_suffixes:
            continue

        img_bgr = cv2.imread(str(path))
        if img_bgr is None:
            # cv2.imread signals failure with None instead of raising;
            # without this guard the resize inside preprocessing crashes.
            print('Skipping unreadable image:', path.name)
            continue

        gpu, shape = preprocess_and_upload(img_bgr)
        kp_gpu, des_gpu = detector.detectAndComputeAsync(gpu, None)
        kp = detector.convert(kp_gpu)
        if not kp:
            # No features -> nothing to match against and nothing to download.
            print('Skipping image without keypoints:', path.name)
            continue

        des = des_gpu.download()
        if des is None or len(des) == 0:
            continue

        db.append({'name': path.stem,
                   'kp': kp,
                   'des': des,
                   'shape': shape})
    print('Loaded', len(db), 'key images')
    return db


In [6]:
def good_matches(matcher, d_query, d_train):
    """Return the matches that survive the nearest-neighbour ratio test.

    Parameters
    ----------
    matcher : object exposing ``knnMatch(query, train, k=2)``.
    d_query, d_train : descriptor arrays (or ``None``).

    Returns
    -------
    list
        The best match of every pair whose distance is below
        ``RATIO_TEST`` times the distance of the second-best match.
        An empty list is returned when either descriptor set is missing,
        when the query set is empty, or when the train set has fewer than
        two rows -- the ratio test needs two neighbours, and asking the
        matcher for ``k=2`` neighbours in that situation is expected to fail.
    """
    if d_query is None or d_train is None:
        return []
    if len(d_query) == 0 or len(d_train) < 2:
        return []

    knn = matcher.knnMatch(d_query, d_train, k=2)
    good = []
    for pair in knn:
        # The matcher may return fewer than two neighbours for a query row.
        if len(pair) != 2:
            continue
        m, n = pair
        if m.distance < RATIO_TEST * n.distance:
            good.append(m)
    return good


In [7]:
def spatial_consistency_filter(matches, kp2, radius=100, min_neighbors=3):
    """Keep only matches whose train keypoint lies in a dense neighbourhood.

    A match is kept when at least ``min_neighbors`` *other* matches have their
    train keypoint closer than ``radius`` pixels to its own train keypoint.

    Parameters
    ----------
    matches : sequence of objects with a ``trainIdx`` attribute.
    kp2 : sequence of keypoints (objects with a ``pt`` attribute) indexed by
        ``trainIdx``.
    radius : distance threshold in pixels (strict ``<``).
    min_neighbors : required number of other matches inside the radius.

    Returns
    -------
    list
        The surviving matches, in their original order.
    """
    if not matches:
        return []

    # Collect the train-side coordinates once instead of rebuilding arrays
    # for every pair of matches.
    pts = np.array([kp2[m.trainIdx].pt for m in matches],
                   dtype=np.float64).reshape(-1, 2)

    # Full pairwise distance matrix; entry (i, j) is the distance between
    # the train keypoints of match i and match j.
    dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)

    # Each row counts the match itself as well (distance 0), hence the
    # strict ">" below to require `min_neighbors` *other* points.
    counts = (dist < radius).sum(axis=1)
    return [m for m, c in zip(matches, counts) if c > min_neighbors]


In [8]:
def annotate(frame, box, label):
    """Draw a closed polygon and a text label onto ``frame`` in place.

    Parameters
    ----------
    frame : image that is modified in place.
    box : polygon vertices; converted to ``int32`` before drawing.
    label : text written at the fixed position (10, 30).
    """
    colour = (255, 0, 0)  # blue, assuming the frame is in BGR order
    # Line width scales with the frame width but never drops below 2 px.
    line_width = max(2, int(0.004 * frame.shape[1]))
    outline = np.int32(box)

    cv2.polylines(frame, [outline], True, colour, line_width, cv2.LINE_AA)
    cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                colour, 2, cv2.LINE_AA)


In [9]:
def process_video():
    """Detect the key images in every video frame and write an annotated copy.

    For each frame the function extracts features on a downscaled grayscale
    copy (same pipeline as ``load_key_images``), matches them against every
    key image, validates each candidate with a RANSAC homography and keeps
    the candidate with the most inliers.  The winning key image's outline is
    drawn on the full-resolution frame, which is then written to
    ``output_dir/'sift.mp4'``.

    Returns
    -------
    collections.defaultdict
        Maps key-image name to the accumulated number of good (filtered)
        matches over all frames in which that key image won.

    Raises
    ------
    RuntimeError
        If the input video cannot be opened.
    """
    det     = create_detector()
    matcher = create_flann_matcher()
    key_db  = load_key_images(det, keys_dir)   # prints its own "Loaded N key images"

    stats  = defaultdict(int)
    cap    = cv2.VideoCapture(str(video_path))
    writer = None
    pbar   = None

    try:
        if not cap.isOpened():
            raise RuntimeError(f'Cannot open video: {video_path}')

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or None
        fps   = cap.get(cv2.CAP_PROP_FPS) or 30
        w     = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h     = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        writer = cv2.VideoWriter(str(output_dir/'sift.mp4'),
                                 cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        pbar = tqdm(total=total, desc='video', unit='frame')

        # Features are extracted on a FRAME_SCALE-reduced image, so every
        # coordinate must be scaled back before drawing on the original frame.
        inv_scale = 1.0 / FRAME_SCALE

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Same GPU pipeline as used for the key images.
            gpu_f, _ = preprocess_and_upload(frame, blur=False)
            kp_gpu, des_gpu = det.detectAndComputeAsync(gpu_f, None)
            kp_f  = det.convert(kp_gpu)
            # Skip the download when nothing was detected.
            des_f = des_gpu.download() if kp_f else None

            best = None
            best_matches = []
            best_H = None
            best_inliers = 0

            for db in key_db:
                gm = good_matches(matcher, db['des'], des_f)
                gm = spatial_consistency_filter(gm, kp_f)  # extra spatial filter
                if len(gm) < MIN_INLIERS:
                    continue

                src = np.float32([db['kp'][m.queryIdx].pt for m in gm]).reshape(-1, 1, 2)
                dst = np.float32([kp_f[m.trainIdx].pt for m in gm]).reshape(-1, 1, 2)
                H, mask = cv2.findHomography(src, dst, cv2.RANSAC, 5.0)
                if H is None:
                    continue

                inliers = int(mask.ravel().sum())
                if inliers < MIN_INLIERS or inliers / len(gm) < MIN_INLIER_RATIO:
                    continue

                # Rank candidates by inlier count against the best inlier
                # count so far (not against a match count).
                if inliers > best_inliers:
                    best = db
                    best_matches = gm
                    best_H = H
                    best_inliers = inliers

            if best is not None:
                h0, w0 = best['shape']
                corners = np.float32([[0, 0], [w0, 0], [w0, h0], [0, h0]]).reshape(-1, 1, 2)
                # Homography maps scaled key image -> scaled frame; rescale
                # the result to the full-resolution frame that is written out.
                box = cv2.perspectiveTransform(corners, best_H) * inv_scale
                annotate(frame, box, f"{best['name']} ({len(best_matches)})")
                stats[best['name']] += len(best_matches)

            writer.write(frame)
            pbar.update(1)
    finally:
        # Always finalise the output file and free the capture, even when
        # the run is interrupted from the notebook.
        if pbar is not None:
            pbar.close()
        cap.release()
        if writer is not None:
            writer.release()

    return stats


In [10]:
# Run the whole pipeline; `stats` maps each key-image name to the accumulated
# number of good matches over the frames in which that image was selected.
stats = process_video()

Loaded 2 key images


video:   0%|          | 0/763 [00:00<?, ?frame/s]

KeyboardInterrupt: 

In [None]:
# Bar chart of the accumulated match counts; skipped when nothing was detected.
if stats:
    names = list(stats)
    counts = list(stats.values())

    plt.bar(names, counts)
    plt.title('Good matches per key image (SIFT)')
    plt.ylabel('Cumulative matches')
    # Slanted, right-aligned tick labels stay readable for long names.
    plt.xticks(rotation=45, ha='right')
    plt.show()


In [None]:
# Re-open the rendered video and show its first frame as a visual sanity check.
cap = cv2.VideoCapture(str(output_dir/'sift.mp4'))
ret, fr = cap.read()
cap.release()

if ret:
    # matplotlib expects RGB channel order, OpenCV delivers BGR.
    fr_rgb = cv2.cvtColor(fr, cv2.COLOR_BGR2RGB)
    plt.imshow(fr_rgb)
    plt.axis('off')
    plt.title('Annotated frame sample (SIFT)')
    plt.show()
