# In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from functools import lru_cache

class BanknoteDetector:
    """Identify banknotes by matching ORB features against a folder of references.

    Construction loads every image in ``dataset_path`` and precomputes its ORB
    keypoints/descriptors. Each test image is matched against every reference:
    a brute-force Hamming kNN match is filtered with Lowe's ratio test, a RANSAC
    homography is fitted when enough good matches survive, and the reference
    with the most homography inliers is reported as the match.
    """

    # Lower-case filename suffixes treated as images when scanning a folder.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif')
    # Lowe's ratio test: keep a match only if its distance is below this
    # fraction of the second-best candidate's distance.
    RATIO_TEST_THRESHOLD = 0.75
    # A homography is only attempted when MORE than this many good matches exist.
    MIN_GOOD_MATCHES = 10
    # RANSAC reprojection threshold passed to cv2.findHomography.
    RANSAC_REPROJ_THRESHOLD = 5.0

    def __init__(self, dataset_path):
        """Load reference images from ``dataset_path`` and precompute their ORB features.

        Args:
            dataset_path: folder containing the reference banknote images.
        """
        self.reference_images = []       # BGR reference images, sorted by filename
        self.reference_keypoints = []    # ORB keypoints, parallel to reference_images
        self.reference_descriptors = []  # ORB descriptors (None when no keypoints were found)
        self.reference_filenames = []    # bare filenames, parallel to reference_images
        self.orb = cv2.ORB_create(
            nfeatures=1000,
            scaleFactor=1.2,
            nlevels=12,
            edgeThreshold=15,
            firstLevel=0,
            WTA_K=2,
            scoreType=cv2.ORB_HARRIS_SCORE,
            patchSize=31,
            fastThreshold=20,
        )
        self.load_reference_images(dataset_path)
        self.compute_reference_features()

    def preprocess_image(self, image, target_width=1000):
        """Return a normalized grayscale copy of ``image`` for feature extraction.

        Steps: grayscale conversion (only for 3-D input), resize to
        ``target_width`` pixels wide keeping the aspect ratio, 5x5 Gaussian
        blur, then CLAHE contrast equalization.

        Args:
            image: BGR (3-D) or single-channel (2-D) image; assumed 8-bit as
                produced by ``cv2.imread`` (CLAHE needs 8/16-bit input).
            target_width: width in pixels of the returned image.

        Returns:
            2-D grayscale image ``target_width`` pixels wide.
        """
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()

        original_height, original_width = gray.shape[:2]
        scaling_factor = target_width / original_width
        new_height = int(original_height * scaling_factor)
        gray = cv2.resize(gray, (target_width, new_height), interpolation=cv2.INTER_LINEAR)

        gray = cv2.GaussianBlur(gray, (5, 5), 0)

        # Created per call on purpose: this method runs on several worker
        # threads, and a single shared CLAHE object would be used concurrently.
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(16, 16))
        return clahe.apply(gray)

    @classmethod
    def _list_image_files(cls, folder):
        """Return the image filenames in ``folder``, sorted so runs are reproducible."""
        return sorted(
            name for name in os.listdir(folder)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _read_color_image(path):
        """Read ``path`` as a BGR image; returns None if OpenCV cannot decode it."""
        return cv2.imread(path, cv2.IMREAD_COLOR)

    def load_reference_images(self, dataset_path):
        """Append every readable image in ``dataset_path`` to the reference lists.

        Files are visited in sorted filename order, so ``reference_images[0]``
        (the fallback match) is deterministic. Unreadable files are skipped.
        """
        filenames = self._list_image_files(dataset_path)
        paths = [os.path.join(dataset_path, name) for name in filenames]
        # Reads are I/O-bound, so a thread pool overlaps them; map keeps order.
        with ThreadPoolExecutor() as executor:
            images = list(executor.map(self._read_color_image, paths))

        for filename, img in zip(filenames, images):
            if img is not None:
                self.reference_images.append(img)
                self.reference_filenames.append(filename)

    def compute_reference_features(self):
        """Compute ORB keypoints/descriptors for every loaded reference image."""
        def featurize(img):
            return self.orb.detectAndCompute(self.preprocess_image(img), None)

        # NOTE(review): self.orb is shared by the worker threads (as in the
        # original design); confirm concurrent detectAndCompute is safe.
        with ThreadPoolExecutor() as executor:
            for keypoints, descriptors in executor.map(featurize, self.reference_images):
                self.reference_keypoints.append(keypoints)
                self.reference_descriptors.append(descriptors)

    def _match_features(self, test_descriptors, ref_descriptors):
        """Return the kNN (k=2) Hamming matches that pass Lowe's ratio test.

        Not memoized: an ``lru_cache`` on a method pins ``self`` in the cache,
        and the test descriptors differ per image, so it would rarely hit.
        """
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
        candidate_pairs = matcher.knnMatch(test_descriptors, ref_descriptors, k=2)
        # knnMatch may return fewer than two candidates (e.g. a one-row reference).
        return [
            pair[0]
            for pair in candidate_pairs
            if len(pair) == 2
            and pair[0].distance < self.RATIO_TEST_THRESHOLD * pair[1].distance
        ]

    def _homography_inliers(self, test_keypoints, test_descriptors, ref_keypoints, ref_descriptors):
        """Count RANSAC homography inliers between a test image and one reference.

        Returns 0 when the reference has no descriptors, when too few matches
        pass the ratio test, or when OpenCV cannot estimate a homography.
        """
        if ref_descriptors is None or len(ref_descriptors) == 0:
            return 0

        good_matches = self._match_features(test_descriptors, ref_descriptors)
        if len(good_matches) <= self.MIN_GOOD_MATCHES:
            return 0

        src_pts = np.float32([test_keypoints[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([ref_keypoints[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        _, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, self.RANSAC_REPROJ_THRESHOLD)
        # findHomography returns (None, None) when no model can be estimated;
        # np.sum(None) would give None and crash the score comparison.
        if mask is None:
            return 0
        return int(np.count_nonzero(mask))

    def _find_best_reference(self, test_keypoints, test_descriptors):
        """Return the index of the reference with the most inliers (0 if none score > 0)."""
        def score_reference(index):
            return self._homography_inliers(
                test_keypoints,
                test_descriptors,
                self.reference_keypoints[index],
                self.reference_descriptors[index],
            )

        with ThreadPoolExecutor() as executor:
            scores = list(executor.map(score_reference, range(len(self.reference_descriptors))))

        # Strict ">" keeps the earliest reference on ties, and zero scores
        # never beat the default index 0.
        best_index, best_score = 0, 0
        for index, score in enumerate(scores):
            if score > best_score:
                best_index, best_score = index, score
        return best_index

    def detect_banknote(self, test_image):
        """Find the reference image that best matches ``test_image``.

        Args:
            test_image: BGR or grayscale image to identify.

        Returns:
            ``(reference_image, reference_index, reference_filename)``. When the
            test image yields no descriptors or no reference produces a usable
            homography, the first reference (index 0) is returned as fallback.

        Raises:
            ValueError: if no reference images were loaded.
        """
        if not self.reference_images:
            raise ValueError("No reference images loaded; cannot detect a banknote.")

        gray = self.preprocess_image(test_image)
        test_keypoints, test_descriptors = self.orb.detectAndCompute(gray, None)

        best_index = 0
        if test_descriptors is not None and len(test_descriptors) > 0:
            best_index = self._find_best_reference(test_keypoints, test_descriptors)

        return (self.reference_images[best_index],
                best_index,
                self.reference_filenames[best_index])

    def _evaluate_test_image(self, item):
        """Match one ``(filename, image)`` pair and build its result record.

        The match counts as correct when the filename text before the first
        ``_`` is identical for the test image and the matched reference.
        """
        filename, test_image = item
        matched_image, _, matched_filename = self.detect_banknote(test_image)
        matched = filename.split('_')[0] == matched_filename.split('_')[0]
        return {
            'test_image': filename,
            'matched_reference': matched_filename,
            'matched': matched,
            'test_img': test_image,
            'matched_img': matched_image,
        }

    @staticmethod
    def _plot_result(result):
        """Show a test image side by side with its matched reference."""
        plt.figure(figsize=(6, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(cv2.cvtColor(result['test_img'], cv2.COLOR_BGR2RGB))
        plt.title(f"Test: {result['test_image'][:10]}...", fontsize=8)
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(cv2.cvtColor(result['matched_img'], cv2.COLOR_BGR2RGB))
        plt.title(f"Match: {result['matched_reference'][:10]}...", fontsize=8)
        plt.axis('off')

        plt.tight_layout(pad=0.5)
        plt.show()

    def process_test_images(self, test_dataset_path):
        """Classify every readable image in ``test_dataset_path`` and plot each match.

        Returns:
            ``(results_log, accuracy)``. ``results_log`` holds one dict per test
            image, in sorted filename order. ``accuracy`` is the percentage of
            correct matches (0 when there are no readable test images).
        """
        filenames = self._list_image_files(test_dataset_path)
        paths = [os.path.join(test_dataset_path, name) for name in filenames]
        with ThreadPoolExecutor() as executor:
            images = list(executor.map(self._read_color_image, paths))
        test_images = [(name, img) for name, img in zip(filenames, images) if img is not None]

        with ThreadPoolExecutor() as executor:
            results_log = list(executor.map(self._evaluate_test_image, test_images))

        # Plotting stays on the calling thread; matplotlib is not thread-safe.
        for result in results_log:
            self._plot_result(result)

        correct_matches = sum(1 for result in results_log if result['matched'])
        total_images = len(results_log)
        accuracy = (correct_matches / total_images) * 100 if total_images > 0 else 0
        return results_log, accuracy

def detect_banknotes(dataset_path, test_dataset_path):
    """Evaluate a test folder against a reference folder and print a report.

    Args:
        dataset_path: folder holding the reference banknote images.
        test_dataset_path: folder holding the images to identify.
    """
    def print_entry(result):
        # One three-line stanza per test image, followed by a blank line.
        print(f"Test Image: {result['test_image']}")
        print(f"Matched Reference: {result['matched_reference']}")
        print(f"Correct Match: {result['matched']}\n")

    results, accuracy = BanknoteDetector(dataset_path).process_test_images(test_dataset_path)

    print("\nMatching Results:")
    for entry in results:
        print_entry(entry)
    print(f"Overall Accuracy: {accuracy:.2f}%")

if __name__ == '__main__':
    # Default layout: reference images in ./train, images to classify in ./test.
    # Kept as module-level names so later notebook cells can reuse them.
    dataset_path, test_dataset_path = './train', './test'
    detect_banknotes(dataset_path, test_dataset_path)