# Template Matching Techniques
# Author: CuongND

# In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob

# Template matching helpers (colour inputs are converted to grayscale before matching)
def apply_template_matching(img, template, method=cv2.TM_CCOEFF_NORMED):
    """Match ``template`` against ``img`` with a single OpenCV method.

    Both inputs are reduced to one channel first: any array with three
    dimensions is treated as a BGR image and converted to grayscale, while
    other arrays are passed through unchanged.

    Args:
        img: Search image (BGR or grayscale) as a NumPy array.
        template: Patch to look for (BGR or grayscale).
        method: One of the ``cv2.TM_*`` comparison constants.

    Returns:
        dict with keys ``'result'`` (the response map from
        ``cv2.matchTemplate``) and ``'min_val'``, ``'max_val'``,
        ``'min_loc'``, ``'max_loc'`` (the extrema reported by
        ``cv2.minMaxLoc`` for that map).
    """
    def _to_gray(arr):
        # Three-dimensional arrays are assumed to be BGR colour images.
        if len(arr.shape) == 3:
            return cv2.cvtColor(arr, cv2.COLOR_BGR2GRAY)
        return arr

    gray_img = _to_gray(img)
    gray_tpl = _to_gray(template)

    response = cv2.matchTemplate(gray_img, gray_tpl, method)
    lo_val, hi_val, lo_loc, hi_loc = cv2.minMaxLoc(response)

    return dict(
        result=response,
        min_val=lo_val,
        max_val=hi_val,
        min_loc=lo_loc,
        max_loc=hi_loc,
    )

def apply_multi_scale_template_matching(img, template, scales=(1.0,), method=cv2.TM_CCOEFF_NORMED):
    """Run template matching at several template scales and keep the best.

    The template is resized by each factor in ``scales`` and matched against
    ``img`` with ``method`` via ``apply_template_matching``. Scales that would
    produce an empty template, or one larger than the image, are skipped
    because ``cv2.matchTemplate`` cannot evaluate them.

    Args:
        img: Search image (BGR or grayscale).
        template: Template image (BGR or grayscale) at its native size.
        scales: Iterable of resize factors applied to the template.
        method: One of the ``cv2.TM_*`` comparison constants.

    Returns:
        dict for the best scale with keys:
          ``'scale'``          the winning resize factor,
          ``'result'``         its response map,
          ``'max_val'``/``'max_loc'``  raw maximum of that response map,
          ``'best_val'``/``'best_loc'`` score and top-left ``(x, y)`` of the
                               best match (the minimum for ``TM_SQDIFF*``
                               methods, the maximum otherwise),
          ``'template_size'``  ``(h, w)`` of the resized template.

    Raises:
        ValueError: if no scale yields a template that fits inside ``img``.
    """
    # Squared-difference methods measure distance: lower means better.
    lower_is_better = method in (cv2.TM_SQDIFF, cv2.TM_SQDIFF_NORMED)

    h, w = template.shape[:2]
    img_h, img_w = img.shape[:2]

    results = []
    for scale in scales:
        new_w, new_h = int(w * scale), int(h * scale)
        # cv2.resize rejects empty sizes and cv2.matchTemplate rejects a
        # template larger than the image, so skip those scales.
        if new_w < 1 or new_h < 1 or new_w > img_w or new_h > img_h:
            continue
        resized_template = cv2.resize(template, (new_w, new_h))

        res = apply_template_matching(img, resized_template, method)
        if lower_is_better:
            best_val, best_loc = res['min_val'], res['min_loc']
        else:
            best_val, best_loc = res['max_val'], res['max_loc']

        results.append({
            'scale': scale,
            'result': res['result'],
            'max_val': res['max_val'],
            'max_loc': res['max_loc'],
            'best_val': best_val,
            'best_loc': best_loc,
            'template_size': resized_template.shape[:2],
        })

    if not results:
        raise ValueError(
            f"No scale produces a non-empty template that fits inside the "
            f"{img_w}x{img_h} image (template is {w}x{h})"
        )

    # Pick the best match across all evaluated scales.
    if lower_is_better:
        return min(results, key=lambda r: r['best_val'])
    return max(results, key=lambda r: r['best_val'])

def load_images(input_folder, template_path=None, extensions=('.png', '.jpg', '.jpeg', '.bmp')):
    """Load every image under ``input_folder`` and optionally template-match it.

    Images are found recursively and filtered by file extension
    (case-insensitive). Paths are sorted so the result order is stable between
    runs, which keeps ``start_index``-based paging reproducible. Files that
    OpenCV cannot decode are skipped.

    Args:
        input_folder: Root directory to search recursively.
        template_path: Optional path to a template image. When it loads, each
            image is matched with TM_CCOEFF_NORMED, TM_CCORR_NORMED,
            TM_SQDIFF_NORMED and a multi-scale TM_CCOEFF_NORMED search.
        extensions: File suffix (or tuple of suffixes) to accept.

    Returns:
        list of dicts. Every entry has ``'original'`` (RGB image) and ``'name'``
        (file basename). When matching succeeded for that image it also has
        ``'ccoeff'``, ``'ccorr'``, ``'sqdiff'``, ``'multi_scale'`` and
        ``'template'`` (the template as RGB).
    """
    from glob import escape as glob_escape

    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(ext.lower() for ext in extensions)

    # Escape the folder so names containing '[', '*' or '?' are taken literally.
    pattern = os.path.join(glob_escape(input_folder), '**/*.*')
    image_paths = sorted(
        p for p in glob(pattern, recursive=True) if p.lower().endswith(extensions)
    )

    # Load the template once; its RGB copy is shared by all result entries.
    template = None
    template_rgb = None
    if template_path:
        template = cv2.imread(template_path)
        if template is None:
            print(f"Warning: Could not load template from {template_path}")
        else:
            template_rgb = cv2.cvtColor(template, cv2.COLOR_BGR2RGB)

    results = []
    for img_path in image_paths:
        img = cv2.imread(img_path)
        if img is None:
            continue  # not decodable by OpenCV

        entry = {
            'original': cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
            'name': os.path.basename(img_path),
        }

        if template is not None:
            try:
                matches = {
                    'ccoeff': apply_template_matching(img, template, cv2.TM_CCOEFF_NORMED),
                    'ccorr': apply_template_matching(img, template, cv2.TM_CCORR_NORMED),
                    'sqdiff': apply_template_matching(img, template, cv2.TM_SQDIFF_NORMED),
                    'multi_scale': apply_multi_scale_template_matching(
                        img, template,
                        scales=[0.8, 0.9, 1.0, 1.1, 1.2],
                        method=cv2.TM_CCOEFF_NORMED),
                }
            except (cv2.error, ValueError) as err:
                # Usually the template (or a scaled copy) is larger than this
                # image. Keep the image and report the problem rather than
                # aborting the whole folder.
                print(f"Warning: template matching failed for {img_path}: {err}")
            else:
                entry.update(matches)
                entry['template'] = template_rgb

        results.append(entry)
    return results

def display_results(results, num_images=5, start_index=0):
    """Show a page of template-matching results in a five-column grid.

    Columns per row:
      1. the original image,
      2. the template,
      3. the best TM_CCOEFF match (green box),
      4. the best multi-scale match (red box),
      5. the CCOEFF response heat-map.

    Panels whose data is missing from an entry are left blank. The ``'ccorr'``
    and ``'sqdiff'`` results are not plotted.

    Boxes are drawn on copies, so the images stored in ``results`` are never
    modified and the function can be called repeatedly.

    Args:
        results: List of dicts as returned by ``load_images``.
        num_images: Maximum number of rows (images) to show.
        start_index: Index of the first entry to show; clamped to the valid range.
    """
    if not results:
        print("No images to display")
        return

    start_index = max(0, min(start_index, len(results) - 1))
    end_index = min(start_index + num_images, len(results))
    display_images = results[start_index:end_index]
    if not display_images:
        # num_images <= 0 would otherwise create a zero-height figure.
        print("No images to display")
        return

    n_rows = len(display_images)
    n_cols = 5
    plt.figure(figsize=(25, 5 * n_rows))

    for i, result in enumerate(display_images):
        base = n_cols * i

        # Original image
        plt.subplot(n_rows, n_cols, base + 1)
        plt.imshow(result['original'])
        plt.title(f"Original\n{result['name']}")
        plt.axis('off')

        # Template (if available)
        template = result.get('template')
        if template is not None:
            plt.subplot(n_rows, n_cols, base + 2)
            plt.imshow(template)
            plt.title("Template")
            plt.axis('off')

        # CCOEFF best match: max location is the best for this method.
        if 'ccoeff' in result:
            # Fall back to a 50x50 box when no template image is stored.
            box_size = template.shape[:2] if template is not None else (50, 50)
            plt.subplot(n_rows, n_cols, base + 3)
            plt.imshow(_draw_match_box(result['original'], result['ccoeff']['max_loc'],
                                       box_size, (0, 255, 0)))
            plt.title(f"CCOEFF (score: {result['ccoeff']['max_val']:.2f})")
            plt.axis('off')

        # Multi-scale best match; prefer method-aware best_* keys when present.
        if 'multi_scale' in result:
            ms = result['multi_scale']
            loc = ms.get('best_loc', ms['max_loc'])
            score = ms.get('best_val', ms['max_val'])
            plt.subplot(n_rows, n_cols, base + 4)
            plt.imshow(_draw_match_box(result['original'], loc,
                                       ms['template_size'], (255, 0, 0)))
            plt.title(f"Multi-scale (score: {score:.2f})")
            plt.axis('off')

        # Heatmap of the CCOEFF response map
        if 'ccoeff' in result:
            plt.subplot(n_rows, n_cols, base + 5)
            plt.imshow(result['ccoeff']['result'], cmap='hot')
            plt.title("Matching Heatmap")
            plt.colorbar()
            plt.axis('off')

    plt.tight_layout()
    plt.show()


def _draw_match_box(image, top_left, size, color):
    """Return a copy of ``image`` with a 2-px box of ``size`` = (h, w) at ``top_left`` = (x, y)."""
    h, w = size
    x, y = int(top_left[0]), int(top_left[1])
    annotated = image.copy()  # never draw into the caller's array
    cv2.rectangle(annotated, (x, y), (x + int(w), y + int(h)), color, 2)
    return annotated

# Example usage:
# 1. Load and process all images with template matching
# template_path = 'path/to/template.png'  # Specify your template path
# bia4_results = load_images('images\\bo du lieu bia so 4\\train', template_path)
# display_results(bia4_results, num_images=5, start_index=0)

# bia7_results = load_images('images\\bo du lieu bia so 7\\train', template_path)
# display_results(bia7_results, num_images=5, start_index=0)
