In [None]:
import os
import itertools
import pandas as pd
from pathlib import Path
import shutil
import subprocess
import sys
import h5py
import numpy as np
import cv2
import matplotlib.pyplot as plt
from datetime import datetime

class SignatureMatchingSystem:
    def __init__(self, real_data_dir="real_data", temp_dir="temp_matching", results_dir="matching_results"):
        self.real_data_dir = Path(real_data_dir)
        self.temp_dir = Path(temp_dir)
        self.results_dir = Path(results_dir)
        
        # Create necessary directories
        self.temp_dir.mkdir(exist_ok=True)
        self.results_dir.mkdir(exist_ok=True)
        
        # Results tracking
        self.results_log = []
        
    def get_all_image_pairs(self):
        """Generate all possible pairs of images across all subfolders"""
        all_images = []
        
        # Collect all images with their person ID
        for person_folder in self.real_data_dir.iterdir():
            if person_folder.is_dir():
                person_id = person_folder.name
                for img_file in person_folder.iterdir():
                    if img_file.suffix.lower() in ['.png', '.jpg', '.jpeg']:
                        all_images.append({
                            'person_id': person_id,
                            'image_path': img_file,
                            'image_name': img_file.name,
                            'full_name': f"{person_id}_{img_file.name}"
                        })
        
        # Generate all pairs
        pairs = []
        for img1, img2 in itertools.combinations(all_images, 2):
            pairs.append({
                'img1': img1,
                'img2': img2,
                'same_person': img1['person_id'] == img2['person_id'],
                'pair_name': f"{img1['full_name']}_vs_{img2['full_name']}"
            })
            
        return pairs
    
    def create_pair_batch(self, pairs, batch_size=50):
        """Lazily yield consecutive slices of ``pairs``, each at most ``batch_size`` long."""
        # The last slice is simply shorter when len(pairs) is not a multiple
        # of batch_size.
        yield from (
            pairs[start:start + batch_size]
            for start in range(0, len(pairs), batch_size)
        )
    
    def setup_batch_directory(self, batch_pairs, batch_id):
        """Setup batch directory with proper structure for main.py"""
        batch_temp_dir = self.temp_dir / f"batch_{batch_id}"
        
        # Clean up existing directory
        if batch_temp_dir.exists():
            shutil.rmtree(batch_temp_dir)
        batch_temp_dir.mkdir(exist_ok=True)
        
        # Create 'images' subdirectory as expected by main.py
        images_dir = batch_temp_dir / "images"
        images_dir.mkdir(exist_ok=True)
        
        # Copy images to images subdirectory
        copied_images = set()
        pair_lines = []
        
        for pair in batch_pairs:
            # Copy images if not already copied
            for img_key in ['img1', 'img2']:
                img_info = pair[img_key]
                new_name = img_info['full_name']
                
                if new_name not in copied_images:
                    src_path = img_info['image_path']
                    dst_path = images_dir / new_name
                    shutil.copy2(src_path, dst_path)
                    copied_images.add(new_name)
            
            # Add pair to pairs.txt
            pair_lines.append(f"{pair['img1']['full_name']} {pair['img2']['full_name']}")
        
        # Write pairs.txt in batch directory
        pairs_file = batch_temp_dir / "pairs.txt"
        with open(pairs_file, 'w') as f:
            f.write('\n'.join(pair_lines))
        
        return batch_temp_dir, len(pair_lines)
    
    def run_matching_for_batch(self, batch_temp_dir):
        """Run ``main.py`` as a subprocess with the batch folder as its working directory.

        ``main.py`` is looked up in the current working directory of this
        process. Returns True when the subprocess exits with code 0 and False
        otherwise, including when launching it raises.
        """
        launch_dir = os.getcwd()

        try:
            # Absolute script path, because the child runs with a different cwd.
            main_script = Path(launch_dir) / "main.py"
            cmd = [sys.executable, str(main_script)]
            cmd += ["-d", "."]  # "." is the batch directory (the child's cwd)
            cmd += ["-p", "superpoint+lightglue"]
            cmd += ["-s", "custom_pairs"]
            cmd += ["--pair_file", "pairs.txt"]
            cmd.append("-f")

            print(f"Running command from {batch_temp_dir}: {' '.join(cmd)}")

            result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(batch_temp_dir))
            succeeded = result.returncode == 0

            if succeeded:
                print("Matching completed successfully")
            else:
                print(f"Error running matching: {result.stderr}")
            print(f"stdout: {result.stdout}")

            return succeeded

        except Exception as e:
            print(f"Exception during matching: {e}")
            return False
    
    def load_keypoints_and_scores(self, h5file, image_name):
        """Read the ``keypoints`` and ``scores`` datasets stored under ``image_name``.

        Returns ``(keypoints, scores)``, or ``(None, None)`` when the image has
        no entry in the file or when reading fails for any reason.
        """
        try:
            with h5py.File(h5file, "r") as features:
                if image_name not in features:
                    return None, None
                # [()] reads the whole dataset into memory before the file closes.
                keypoints = features[image_name]["keypoints"][()]
                scores = features[image_name]["scores"][()]
                return keypoints, scores
        except Exception as e:
            print(f"Error loading keypoints for {image_name}: {e}")
            return None, None
    
    def load_matches(self, h5file, image_name1, image_name2):
        """Load the match array for a pair, looking it up under either key order.

        The array is returned so that column 0 refers to ``image_name1`` and
        column 1 to ``image_name2``, which is how the callers in this file
        consume it.

        Returns:
            The match array, or ``None`` when the pair is not stored, the
            stored array is empty, or reading fails.
        """
        try:
            with h5py.File(h5file, "r") as f:
                if image_name1 in f and image_name2 in f[image_name1]:
                    matches = f[image_name1][image_name2][()]
                elif image_name2 in f and image_name1 in f[image_name2]:
                    matches = f[image_name2][image_name1][()]
                    # Stored under (image2, image1): rows are presumably
                    # (index in image2, index in image1) -- TODO confirm
                    # against the matcher's output format. Swap the columns
                    # so callers always get (image1, image2) order.
                    if matches.ndim == 2 and matches.shape[1] == 2:
                        matches = matches[:, ::-1]
                else:
                    return None
            return matches if len(matches) > 0 else None
        except Exception as e:
            print(f"Error loading matches between {image_name1} and {image_name2}: {e}")
            return None
    
    def create_signature_mask(self, image_path, blur_kernel=3, threshold_value=240):
        """Build a binary mask (255 where the blurred image is not above ``threshold_value``).

        The image is read as grayscale, Gaussian-blurred, inverse-thresholded
        and then cleaned with a 3x3 morphological close followed by an open.
        Returns ``None`` when the image cannot be read or any step raises.
        """
        try:
            gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            if gray is None:
                return None

            ksize = (blur_kernel, blur_kernel)
            smoothed = cv2.GaussianBlur(gray, ksize, 0)
            # THRESH_BINARY_INV: pixels <= threshold_value become 255.
            mask = cv2.threshold(smoothed, threshold_value, 255, cv2.THRESH_BINARY_INV)[1]

            structuring = np.ones((3, 3), np.uint8)
            for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
                mask = cv2.morphologyEx(mask, operation, structuring)

            return mask
        except Exception as e:
            print(f"Error creating mask for {image_path}: {e}")
            return None
    
    def filter_keypoints_by_signature(self, keypoints, scores, mask, dilation_radius=10):
        """Keep the keypoints that fall on the mask after dilating it by ``dilation_radius``.

        Returns ``(filtered_keypoints, filtered_scores, valid_indices)`` where
        ``valid_indices`` are positions in the input arrays. Three empty arrays
        are returned when there is no mask, no keypoints, no keypoint on the
        dilated mask, or when an error occurs.
        """
        if mask is None or len(keypoints) == 0:
            return np.array([]), np.array([]), np.array([])

        try:
            diameter = dilation_radius * 2 + 1
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
            dilated_mask = cv2.dilate(mask, kernel, iterations=1)

            hits = []
            for position, (x, y) in enumerate(keypoints):
                col, row = int(round(x)), int(round(y))
                # Keypoints are (x, y) while the mask is indexed [row, col].
                inside = 0 <= col < dilated_mask.shape[1] and 0 <= row < dilated_mask.shape[0]
                if inside and dilated_mask[row, col] > 0:
                    hits.append(position)

            valid_indices = np.array(hits)

            if len(valid_indices) == 0:
                return np.array([]), np.array([]), np.array([])

            return keypoints[valid_indices], scores[valid_indices], valid_indices
        except Exception as e:
            print(f"Error filtering keypoints: {e}")
            return np.array([]), np.array([]), np.array([])
    
    def filter_matches_by_signature_features(self, matches, valid_indices1, valid_indices2):
        """Keep matches whose both endpoints survived filtering, re-indexed to the filtered arrays.

        ``valid_indices1`` / ``valid_indices2`` list the original keypoint
        indices that were kept; each returned row holds positions within
        those lists. An empty array is returned when nothing qualifies.
        """
        nothing_to_do = (
            matches is None
            or len(matches) == 0
            or len(valid_indices1) == 0
            or len(valid_indices2) == 0
        )
        if nothing_to_do:
            return np.array([])

        try:
            # original keypoint index -> position in the filtered array
            remap1 = {orig_idx: new_idx for new_idx, orig_idx in enumerate(valid_indices1)}
            remap2 = {orig_idx: new_idx for new_idx, orig_idx in enumerate(valid_indices2)}

            kept = [
                [remap1[row[0]], remap2[row[1]]]
                for row in matches
                if row[0] in remap1 and row[1] in remap2
            ]

            return np.array(kept)
        except Exception as e:
            print(f"Error filtering matches: {e}")
            return np.array([])
    
    def visualize_matches(self, img1_path, img2_path, keypoints1, keypoints2, 
                         original_matches, filtered_keypoints1, filtered_keypoints2, 
                         filtered_matches, pair_info, save_path):
        """Create and save visualization of matches"""
        try:
            img1 = cv2.imread(str(img1_path))
            img2 = cv2.imread(str(img2_path))
            
            if img1 is None or img2 is None:
                print(f"Could not load images: {img1_path}, {img2_path}")
                return False
            
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16))
            
            # Plot 1: Original matches
            if original_matches is not None and len(original_matches) > 0:
                kpts1_orig = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints1]
                kpts2_orig = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints2]
                
                cv_matches_orig = [
                    cv2.DMatch(_queryIdx=int(i), _trainIdx=int(j), _distance=0) 
                    for i, j in original_matches
                ]
                
                img_matches_orig = cv2.drawMatches(
                    img1, kpts1_orig, img2, kpts2_orig, cv_matches_orig, None,
                    flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
                )
                
                ax1.imshow(cv2.cvtColor(img_matches_orig, cv2.COLOR_BGR2RGB))
                ax1.set_title(f"Original Feature Matches ({len(original_matches)} matches)", fontsize=16)
            else:
                ax1.text(0.5, 0.5, "No original matches found", ha='center', va='center', 
                        transform=ax1.transAxes, fontsize=16)
            ax1.axis("off")
            
            # Plot 2: Signature matches
            if (filtered_matches is not None and len(filtered_matches) > 0 and 
                len(filtered_keypoints1) > 0 and len(filtered_keypoints2) > 0):
                
                kpts1_filt = [cv2.KeyPoint(float(x), float(y), 1) for x, y in filtered_keypoints1]
                kpts2_filt = [cv2.KeyPoint(float(x), float(y), 1) for x, y in filtered_keypoints2]
                
                cv_matches_filt = [
                    cv2.DMatch(_queryIdx=int(i), _trainIdx=int(j), _distance=0) 
                    for i, j in filtered_matches
                ]
                
                img_matches_filt = cv2.drawMatches(
                    img1, kpts1_filt, img2, kpts2_filt, cv_matches_filt, None,
                    flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
                )
                
                ax2.imshow(cv2.cvtColor(img_matches_filt, cv2.COLOR_BGR2RGB))
                ax2.set_title(f"Signature Feature Matches ({len(filtered_matches)} matches)", fontsize=16)
            else:
                ax2.text(0.5, 0.5, "No signature matches found", ha='center', va='center', 
                        transform=ax2.transAxes, fontsize=16)
            ax2.axis("off")
            
            # Add pair information
            same_person_text = "SAME PERSON" if pair_info['same_person'] else "DIFFERENT PERSON"
            fig.suptitle(f"{pair_info['pair_name']} - {same_person_text}", fontsize=18, y=0.98)
            
            plt.tight_layout()
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close()
            
            return True
            
        except Exception as e:
            print(f"Error creating visualization: {e}")
            return False
    
    def find_results_files(self, batch_temp_dir):
        """Locate ``features.h5`` and ``matches.h5`` for a batch.

        ``results_*`` sub-directories of ``batch_temp_dir`` are searched in
        sorted order and the first one that contains both files wins. If none
        is complete, the paths point into the first ``results_*`` directory,
        or into ``batch_temp_dir`` itself when there is no such directory.

        Returns:
            tuple: ``(features_path, matches_path)``. The paths are not
            guaranteed to exist; callers must check.
        """
        batch_temp_dir = Path(batch_temp_dir)

        # glob() order is arbitrary and may include plain files such as
        # "results_x.log"; keep directories only and sort for determinism.
        candidates = sorted(p for p in batch_temp_dir.glob("results_*") if p.is_dir())

        for candidate in candidates:
            features_path = candidate / "features.h5"
            matches_path = candidate / "matches.h5"
            if features_path.exists() and matches_path.exists():
                return features_path, matches_path

        # Nothing complete: report where the files were expected so the
        # caller's existence check can print a useful message.
        base_dir = candidates[0] if candidates else batch_temp_dir
        return base_dir / "features.h5", base_dir / "matches.h5"
    
    def process_batch_results(self, batch_pairs, batch_temp_dir, batch_id):
        """Process results from a batch"""
        features_path, matches_path = self.find_results_files(batch_temp_dir)
        
        if not features_path.exists() or not matches_path.exists():
            print(f"Missing features ({features_path.exists()}) or matches ({matches_path.exists()}) file for batch {batch_id}")
            print(f"Looking in: {batch_temp_dir}")
            # List all files in the directory for debugging
            print("Files found:")
            for file in batch_temp_dir.rglob("*"):
                print(f"  {file}")
            return
        
        batch_results_dir = self.results_dir / f"batch_{batch_id}"
        batch_results_dir.mkdir(exist_ok=True)
        
        batch_log = []
        
        for pair in batch_pairs:
            try:
                img1_name = pair['img1']['full_name']
                img2_name = pair['img2']['full_name']
                
                print(f"Processing pair: {img1_name} vs {img2_name}")
                
                # Load keypoints
                keypoints1, scores1 = self.load_keypoints_and_scores(features_path, img1_name)
                keypoints2, scores2 = self.load_keypoints_and_scores(features_path, img2_name)
                
                if keypoints1 is None or keypoints2 is None:
                    print(f"Failed to load keypoints for {img1_name} or {img2_name}")
                    continue
                
                # Load matches
                original_matches = self.load_matches(matches_path, img1_name, img2_name)
                
                # Create signature masks
                img1_path = batch_temp_dir / "images" / img1_name
                img2_path = batch_temp_dir / "images" / img2_name
                
                mask1 = self.create_signature_mask(img1_path)
                mask2 = self.create_signature_mask(img2_path)
                
                # Filter keypoints by signature regions
                filt_kpts1, filt_scores1, valid_idx1 = self.filter_keypoints_by_signature(
                    keypoints1, scores1, mask1
                )
                filt_kpts2, filt_scores2, valid_idx2 = self.filter_keypoints_by_signature(
                    keypoints2, scores2, mask2
                )
                
                # Filter matches
                filtered_matches = self.filter_matches_by_signature_features(
                    original_matches, valid_idx1, valid_idx2
                )
                
                # Create visualization
                viz_path = batch_results_dir / f"{pair['pair_name']}.png"
                viz_success = self.visualize_matches(
                    img1_path, img2_path, keypoints1, keypoints2,
                    original_matches, filt_kpts1, filt_kpts2, filtered_matches,
                    pair, viz_path
                )
                
                # Log results
                result_entry = {
                    'pair_name': pair['pair_name'],
                    'img1': img1_name,
                    'img2': img2_name,
                    'same_person': pair['same_person'],
                    'person1_id': pair['img1']['person_id'],
                    'person2_id': pair['img2']['person_id'],
                    'original_features_1': len(keypoints1),
                    'original_features_2': len(keypoints2),
                    'signature_features_1': len(filt_kpts1),
                    'signature_features_2': len(filt_kpts2),
                    'original_matches': len(original_matches) if original_matches is not None else 0,
                    'signature_matches': len(filtered_matches) if filtered_matches is not None else 0,
                    'visualization_created': viz_success,
                    'batch_id': batch_id,
                    'timestamp': datetime.now().isoformat()
                }
                
                batch_log.append(result_entry)
                self.results_log.append(result_entry)
                
                print(f"  Original matches: {result_entry['original_matches']}")
                print(f"  Signature matches: {result_entry['signature_matches']}")
                
            except Exception as e:
                print(f"Error processing pair {pair['pair_name']}: {e}")
                continue
        
        # Save batch results
        if batch_log:
            batch_df = pd.DataFrame(batch_log)
            batch_df.to_csv(batch_results_dir / "batch_results.csv", index=False)
    
    def run_comprehensive_matching(self, batch_size=20):
        """Run comprehensive signature matching across all pairs"""
        print("Starting comprehensive signature matching...")
        
        # Get all possible pairs
        all_pairs = self.get_all_image_pairs()
        print(f"Total pairs to process: {len(all_pairs)}")
        
        same_person_pairs = sum(1 for p in all_pairs if p['same_person'])
        different_person_pairs = len(all_pairs) - same_person_pairs
        print(f"Same person pairs: {same_person_pairs}")
        print(f"Different person pairs: {different_person_pairs}")
        
        # Process in batches
        for batch_id, batch_pairs in enumerate(self.create_pair_batch(all_pairs, batch_size)):
            print(f"\nProcessing batch {batch_id + 1} ({len(batch_pairs)} pairs)...")
            
            try:
                # Setup batch directory
                batch_temp_dir, num_pairs = self.setup_batch_directory(batch_pairs, batch_id)
                print(f"Setup completed for {num_pairs} pairs in {batch_temp_dir}")
                
                # Run matching
                success = self.run_matching_for_batch(batch_temp_dir)
                if not success:
                    print(f"Matching failed for batch {batch_id}")
                    continue
                
                # Process results
                self.process_batch_results(batch_pairs, batch_temp_dir, batch_id)
                
                # Clean up temporary files (optional - comment out for debugging)
                # shutil.rmtree(batch_temp_dir)
                
            except Exception as e:
                print(f"Error processing batch {batch_id}: {e}")
                continue
        
        # Save comprehensive results
        if self.results_log:
            comprehensive_df = pd.DataFrame(self.results_log)
            comprehensive_df.to_csv(self.results_dir / "comprehensive_results.csv", index=False)
            
            # Generate summary statistics
            self.generate_summary_report(comprehensive_df)
        else:
            print("No results to save!")
    
    def generate_summary_report(self, df):
        """Generate a summary report of matching results"""
        summary = {
            'total_pairs': len(df),
            'same_person_pairs': len(df[df['same_person'] == True]),
            'different_person_pairs': len(df[df['same_person'] == False]),
            'pairs_with_original_matches': len(df[df['original_matches'] > 0]),
            'pairs_with_signature_matches': len(df[df['signature_matches'] > 0]),
            'avg_original_matches_same_person': df[df['same_person'] == True]['original_matches'].mean(),
            'avg_original_matches_different_person': df[df['same_person'] == False]['original_matches'].mean(),
            'avg_signature_matches_same_person': df[df['same_person'] == True]['signature_matches'].mean(),
            'avg_signature_matches_different_person': df[df['same_person'] == False]['signature_matches'].mean(),
        }
        
        # Save summary
        summary_df = pd.DataFrame([summary])
        summary_df.to_csv(self.results_dir / "summary_report.csv", index=False)
        
        # Print summary
        print("\n" + "="*60)
        print("COMPREHENSIVE MATCHING SUMMARY")
        print("="*60)
        for key, value in summary.items():
            print(f"{key.replace('_', ' ').title()}: {value:.2f}" if isinstance(value, float) else f"{key.replace('_', ' ').title()}: {value}")
        print("="*60)


# Usage
if __name__ == "__main__":
    # Build the matcher on the default folder layout (input images,
    # scratch space, results) and process every image pair in batches of 20.
    matcher = SignatureMatchingSystem("real_data", "temp_matching", "matching_results")
    matcher.run_comprehensive_matching(batch_size=20)

Starting comprehensive signature matching...
Total pairs to process: 378
Same person pairs: 74
Different person pairs: 304

Processing batch 1 (20 pairs)...
Setup completed for 20 pairs in temp_matching/batch_0
Running command from temp_matching/batch_0: /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/venv/bin/python /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/main.py -d . -p superpoint+lightglue -s custom_pairs --pair_file pairs.txt -f
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)

  0%|          | 0/21 [00:00<?, ?it/s]
  5%|▍         | 1/21 [00:00<00:11,  1.80it/s]
 19%|█▉        | 4/21 [00:00<00:02,  7.57it/s]
 43%|████▎     | 9/21 [00:00<00:00, 15.98it/s]
 67%|██████▋   | 14/21 [00:00<00:00, 23.13it/s]
100%|██████████| 21/21 [00:01<00:00, 33.59it/s]
100%|██████████| 21/21 [00:01<00:00, 20.74it/s]

  0%|          | 0/20 [00:00<?, ?it/s]
  5%|▌         | 1/20 [00:00<00:08,  2.37it/s]
 15%|█▌        | 3/20 [00:00<00:03,  5.57it/s]
 25%|██▌

Starting comprehensive signature matching...
Total pairs to process: 378
Same person pairs: 74
Different person pairs: 304

Processing batch 1 (20 pairs)...
Setup completed for 20 pairs in temp_matching/batch_0
Running command from temp_matching/batch_0: /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/venv/bin/python /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/main.py -d . -p superpoint+lightglue -s custom_pairs --pair_file pairs.txt -f
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)

  0%|          | 0/21 [00:00<?, ?it/s]
  5%|▍         | 1/21 [00:00<00:05,  3.59it/s]
 24%|██▍       | 5/21 [00:00<00:01, 15.48it/s]
 48%|████▊     | 10/21 [00:00<00:00, 25.80it/s]
 81%|████████  | 17/21 [00:00<00:00, 38.47it/s]
100%|██████████| 21/21 [00:00<00:00, 32.37it/s]

  0%|          | 0/20 [00:00<?, ?it/s]
  5%|▌         | 1/20 [00:00<00:05,  3.60it/s]
 15%|█▌        | 3/20 [00:00<00:02,  7.21it/s]
 25%|██▌       | 5/20 [00:00<00:01,  8.93it/s]
 35%|███

In [12]:
# comprehensive_signature_matching_fixed.py
import os
import itertools
import pandas as pd
from pathlib import Path
import shutil
import subprocess
import sys
import h5py
import numpy as np
import cv2
import matplotlib.pyplot as plt
from datetime import datetime

class SignatureMatchingSystem:
    def __init__(self, real_data_dir="real_data", temp_dir="temp_matching", results_dir="matching_results"):
        self.real_data_dir = Path(real_data_dir)
        self.temp_dir = Path(temp_dir)
        self.results_dir = Path(results_dir)
        
        # Create necessary directories
        self.temp_dir.mkdir(exist_ok=True)
        self.results_dir.mkdir(exist_ok=True)
        
        # Results tracking
        self.results_log = []
        
    def get_all_image_pairs(self):
        """Generate all possible pairs of images across all subfolders"""
        all_images = []
        
        # Collect all images with their person ID
        for person_folder in self.real_data_dir.iterdir():
            if person_folder.is_dir():
                person_id = person_folder.name
                for img_file in person_folder.iterdir():
                    if img_file.suffix.lower() in ['.png', '.jpg', '.jpeg']:
                        all_images.append({
                            'person_id': person_id,
                            'image_path': img_file,
                            'image_name': img_file.name,
                            'full_name': f"{person_id}_{img_file.name}"
                        })
        
        # Generate all pairs
        pairs = []
        for img1, img2 in itertools.combinations(all_images, 2):
            pairs.append({
                'img1': img1,
                'img2': img2,
                'same_person': img1['person_id'] == img2['person_id'],
                'pair_name': f"{img1['full_name']}_vs_{img2['full_name']}"
            })
            
        return pairs
    
    def create_pair_batch(self, pairs, batch_size=50):
        """Lazily yield consecutive slices of ``pairs``, each at most ``batch_size`` long."""
        # The last slice is simply shorter when len(pairs) is not a multiple
        # of batch_size.
        yield from (
            pairs[start:start + batch_size]
            for start in range(0, len(pairs), batch_size)
        )
    
    def setup_batch_directory(self, batch_pairs, batch_id):
        """Setup batch directory with proper structure for main.py"""
        batch_temp_dir = self.temp_dir / f"batch_{batch_id}"
        
        # Clean up existing directory
        if batch_temp_dir.exists():
            shutil.rmtree(batch_temp_dir)
        batch_temp_dir.mkdir(exist_ok=True)
        
        # Create 'images' subdirectory as expected by main.py
        images_dir = batch_temp_dir / "images"
        images_dir.mkdir(exist_ok=True)
        
        # Copy images to images subdirectory
        copied_images = set()
        pair_lines = []
        
        for pair in batch_pairs:
            # Copy images if not already copied
            for img_key in ['img1', 'img2']:
                img_info = pair[img_key]
                new_name = img_info['full_name']
                
                if new_name not in copied_images:
                    src_path = img_info['image_path']
                    dst_path = images_dir / new_name
                    shutil.copy2(src_path, dst_path)
                    copied_images.add(new_name)
            
            # Add pair to pairs.txt
            pair_lines.append(f"{pair['img1']['full_name']} {pair['img2']['full_name']}")
        
        # Write pairs.txt in batch directory
        pairs_file = batch_temp_dir / "pairs.txt"
        with open(pairs_file, 'w') as f:
            f.write('\n'.join(pair_lines))
        
        return batch_temp_dir, len(pair_lines)
    
    def run_matching_for_batch(self, batch_temp_dir):
        """Run ``main.py`` as a subprocess with the batch folder as its working directory.

        ``main.py`` is looked up in the current working directory of this
        process. Returns True when the subprocess exits with code 0 and False
        otherwise, including when launching it raises.
        """
        launch_dir = os.getcwd()

        try:
            # Absolute script path, because the child runs with a different cwd.
            main_script = Path(launch_dir) / "main.py"
            cmd = [sys.executable, str(main_script)]
            cmd += ["-d", "."]  # "." is the batch directory (the child's cwd)
            cmd += ["-p", "superpoint+lightglue"]
            cmd += ["-s", "custom_pairs"]
            cmd += ["--pair_file", "pairs.txt"]
            cmd.append("-f")

            print(f"Running command from {batch_temp_dir}: {' '.join(cmd)}")

            result = subprocess.run(cmd, capture_output=True, text=True, cwd=str(batch_temp_dir))
            succeeded = result.returncode == 0

            if succeeded:
                print("Matching completed successfully")
            else:
                print(f"Error running matching: {result.stderr}")
            print(f"stdout: {result.stdout}")

            return succeeded

        except Exception as e:
            print(f"Exception during matching: {e}")
            return False
    
    def load_keypoints_and_scores(self, h5file, image_name):
        """Read the ``keypoints`` and ``scores`` datasets stored under ``image_name``.

        Returns ``(keypoints, scores)``, or ``(None, None)`` when the image has
        no entry in the file or when reading fails for any reason.
        """
        try:
            with h5py.File(h5file, "r") as features:
                if image_name not in features:
                    return None, None
                # [()] reads the whole dataset into memory before the file closes.
                keypoints = features[image_name]["keypoints"][()]
                scores = features[image_name]["scores"][()]
                return keypoints, scores
        except Exception as e:
            print(f"Error loading keypoints for {image_name}: {e}")
            return None, None
    
    def load_matches(self, h5file, image_name1, image_name2):
        """Load the match array for a pair, looking it up under either key order.

        The array is returned so that column 0 refers to ``image_name1`` and
        column 1 to ``image_name2``.

        Returns:
            The match array, or ``None`` when the pair is not stored, the
            stored array is empty, or reading fails.
        """
        try:
            with h5py.File(h5file, "r") as f:
                if image_name1 in f and image_name2 in f[image_name1]:
                    matches = f[image_name1][image_name2][()]
                elif image_name2 in f and image_name1 in f[image_name2]:
                    matches = f[image_name2][image_name1][()]
                    # Stored under (image2, image1): rows are presumably
                    # (index in image2, index in image1) -- TODO confirm
                    # against the matcher's output format. Swap the columns
                    # so callers always get (image1, image2) order.
                    if matches.ndim == 2 and matches.shape[1] == 2:
                        matches = matches[:, ::-1]
                else:
                    return None
            return matches if len(matches) > 0 else None
        except Exception as e:
            print(f"Error loading matches between {image_name1} and {image_name2}: {e}")
            return None
    
    def create_signature_mask(self, image_path, blur_kernel=3, threshold_value=240):
        """Build a binary mask (255 where the blurred image is not above ``threshold_value``).

        The image is read as grayscale, Gaussian-blurred, inverse-thresholded
        and then cleaned with a 3x3 morphological close followed by an open.
        Returns ``None`` when the image cannot be read or any step raises.
        """
        try:
            gray = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            if gray is None:
                return None

            ksize = (blur_kernel, blur_kernel)
            smoothed = cv2.GaussianBlur(gray, ksize, 0)
            # THRESH_BINARY_INV: pixels <= threshold_value become 255.
            mask = cv2.threshold(smoothed, threshold_value, 255, cv2.THRESH_BINARY_INV)[1]

            structuring = np.ones((3, 3), np.uint8)
            for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
                mask = cv2.morphologyEx(mask, operation, structuring)

            return mask
        except Exception as e:
            print(f"Error creating mask for {image_path}: {e}")
            return None
    
    def filter_keypoints_by_signature(self, keypoints, scores, mask, dilation_radius=10):
        """Keep the keypoints that fall on the mask after dilating it by ``dilation_radius``.

        Returns ``(filtered_keypoints, filtered_scores, valid_indices)`` where
        ``valid_indices`` are positions in the input arrays. Three empty arrays
        are returned when there is no mask, no keypoints, no keypoint on the
        dilated mask, or when an error occurs.
        """
        if mask is None or len(keypoints) == 0:
            return np.array([]), np.array([]), np.array([])

        try:
            diameter = dilation_radius * 2 + 1
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
            dilated_mask = cv2.dilate(mask, kernel, iterations=1)

            hits = []
            for position, (x, y) in enumerate(keypoints):
                col, row = int(round(x)), int(round(y))
                # Keypoints are (x, y) while the mask is indexed [row, col].
                inside = 0 <= col < dilated_mask.shape[1] and 0 <= row < dilated_mask.shape[0]
                if inside and dilated_mask[row, col] > 0:
                    hits.append(position)

            valid_indices = np.array(hits)

            if len(valid_indices) == 0:
                return np.array([]), np.array([]), np.array([])

            return keypoints[valid_indices], scores[valid_indices], valid_indices
        except Exception as e:
            print(f"Error filtering keypoints: {e}")
            return np.array([]), np.array([]), np.array([])
    
    def filter_matches_by_signature_features(self, matches, valid_indices1, valid_indices2):
        """Restrict matches to surviving keypoints and re-index them.

        A match is kept only when both of its endpoints appear in the
        respective ``valid_indices`` list. Kept matches are rewritten to use
        positions within those lists, i.e. indices into the filtered keypoint
        arrays.

        Args:
            matches: Sequence of rows whose first two entries are keypoint
                indices in image 1 and image 2, or ``None``.
            valid_indices1: Original indices of the keypoints kept in image 1.
            valid_indices2: Original indices of the keypoints kept in image 2.

        Returns:
            Array of ``[new_index1, new_index2]`` rows; an empty array when any
            input is empty, nothing survives, or an error occurs (it is
            printed).
        """
        nothing_to_do = (
            matches is None
            or len(matches) == 0
            or len(valid_indices1) == 0
            or len(valid_indices2) == 0
        )
        if nothing_to_do:
            return np.array([])

        try:
            # Original keypoint index -> position in the filtered array.
            remap1 = dict(zip(valid_indices1, range(len(valid_indices1))))
            remap2 = dict(zip(valid_indices2, range(len(valid_indices2))))

            surviving = [
                [remap1[row[0]], remap2[row[1]]]
                for row in matches
                if row[0] in remap1 and row[1] in remap2
            ]
            return np.array(surviving)
        except Exception as e:
            print(f"Error filtering matches: {e}")
            return np.array([])
    
    def visualize_matches(self, img1_path, img2_path, keypoints1, keypoints2, 
                         original_matches, filtered_keypoints1, filtered_keypoints2, 
                         filtered_matches, pair_info, save_path):
        """Render a two-panel comparison figure for an image pair and save it.

        The top panel shows ``original_matches`` drawn between ``keypoints1``
        and ``keypoints2``; the bottom panel shows ``filtered_matches`` drawn
        between the filtered keypoints. A panel with nothing to draw shows a
        placeholder message instead.

        Args:
            img1_path: Path of the first image.
            img2_path: Path of the second image.
            keypoints1: ``(x, y)`` keypoints of the first image.
            keypoints2: ``(x, y)`` keypoints of the second image.
            original_matches: Index pairs into ``keypoints1``/``keypoints2``,
                or ``None``.
            filtered_keypoints1: Filtered ``(x, y)`` keypoints of image 1.
            filtered_keypoints2: Filtered ``(x, y)`` keypoints of image 2.
            filtered_matches: Index pairs into the filtered keypoints, or
                ``None``.
            pair_info: Mapping providing ``'same_person'`` and ``'pair_name'``
                for the figure title.
            save_path: Destination image file; missing parent directories are
                created.

        Returns:
            ``True`` when the figure was written, ``False`` when an image could
            not be loaded or any step failed (the traceback is printed).
        """
        # Tracked outside the try so the figure can be released on every exit
        # path; otherwise a failure midway leaves it open in pyplot's registry.
        fig = None
        try:
            print(f"    Loading images: {img1_path}, {img2_path}")
            img1 = cv2.imread(str(img1_path))
            img2 = cv2.imread(str(img2_path))
            
            if img1 is None or img2 is None:
                print(f"    ERROR: Could not load images: {img1_path}, {img2_path}")
                return False
            
            print(f"    Images loaded successfully. Shape1: {img1.shape}, Shape2: {img2.shape}")
            
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16))
            
            # Plot 1: Original matches
            if original_matches is not None and len(original_matches) > 0:
                print(f"    Creating original matches visualization ({len(original_matches)} matches)")
                self._draw_match_panel(
                    ax1, img1, img2, keypoints1, keypoints2, original_matches,
                    f"Original Feature Matches ({len(original_matches)} matches)",
                )
            else:
                print(f"    No original matches to display")
                self._draw_empty_panel(ax1, "No original matches found")
            ax1.axis("off")
            
            # Plot 2: Signature matches
            if (filtered_matches is not None and len(filtered_matches) > 0 and 
                len(filtered_keypoints1) > 0 and len(filtered_keypoints2) > 0):
                print(f"    Creating signature matches visualization ({len(filtered_matches)} matches)")
                self._draw_match_panel(
                    ax2, img1, img2, filtered_keypoints1, filtered_keypoints2,
                    filtered_matches,
                    f"Signature Feature Matches ({len(filtered_matches)} matches)",
                )
            else:
                print(f"    No signature matches to display")
                self._draw_empty_panel(ax2, "No signature matches found")
            ax2.axis("off")
            
            # Add pair information
            same_person_text = "SAME PERSON" if pair_info['same_person'] else "DIFFERENT PERSON"
            fig.suptitle(f"{pair_info['pair_name']} - {same_person_text}", fontsize=18, y=0.98)
            
            fig.tight_layout()
            
            # Accept plain strings as well as Path objects, then make sure the
            # destination directory exists.
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            
            print(f"    Saving visualization to: {save_path}")
            # Save through the figure itself rather than pyplot's "current"
            # figure so the right canvas is always written.
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            
            print(f"    Visualization saved successfully!")
            return True
            
        except Exception as e:
            print(f"    ERROR creating visualization: {e}")
            import traceback
            traceback.print_exc()
            return False
        finally:
            if fig is not None:
                plt.close(fig)

    def _draw_match_panel(self, ax, img1, img2, keypoints1, keypoints2, matches, title):
        """Draw ``matches`` between two images onto ``ax`` and set its title."""
        cv_keypoints1 = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints1]
        cv_keypoints2 = [cv2.KeyPoint(float(x), float(y), 1) for x, y in keypoints2]
        cv_matches = [
            cv2.DMatch(_queryIdx=int(i), _trainIdx=int(j), _distance=0)
            for i, j in matches
        ]
        canvas = cv2.drawMatches(
            img1, cv_keypoints1, img2, cv_keypoints2, cv_matches, None,
            flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
        )
        # The images were read with cv2.imread (BGR); convert for imshow.
        ax.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
        ax.set_title(title, fontsize=16)

    def _draw_empty_panel(self, ax, message):
        """Show a centered placeholder ``message`` on ``ax``."""
        ax.text(0.5, 0.5, message, ha='center', va='center',
                transform=ax.transAxes, fontsize=16)
    
    def find_results_files(self, batch_temp_dir):
        """Locate ``features.h5`` and ``matches.h5`` for a batch.

        Directories named ``results_*`` inside ``batch_temp_dir`` are examined
        in sorted order, followed by ``batch_temp_dir`` itself; the first
        location that contains both files wins. Non-directory entries matching
        ``results_*`` are ignored. If no location holds both files, paths
        inside the first ``results_*`` directory (or inside ``batch_temp_dir``
        when there is none) are returned so the caller can report what is
        missing.

        Args:
            batch_temp_dir: Working directory of the batch.

        Returns:
            ``(features_path, matches_path)``; the files are not guaranteed to
            exist.
        """
        batch_temp_dir = Path(batch_temp_dir)

        # glob() order is filesystem-dependent and can include plain files, so
        # keep directories only and sort for a deterministic choice.
        result_dirs = sorted(
            entry for entry in batch_temp_dir.glob("results_*") if entry.is_dir()
        )
        search_order = result_dirs + [batch_temp_dir]

        for directory in search_order:
            features_path = directory / "features.h5"
            matches_path = directory / "matches.h5"
            if features_path.is_file() and matches_path.is_file():
                return features_path, matches_path

        fallback = search_order[0]
        return fallback / "features.h5", fallback / "matches.h5"
    
    def process_batch_results(self, batch_pairs, batch_temp_dir, batch_id):
        """Process results from a batch"""
        features_path, matches_path = self.find_results_files(batch_temp_dir)
        
        if not features_path.exists() or not matches_path.exists():
            print(f"Missing features ({features_path.exists()}) or matches ({matches_path.exists()}) file for batch {batch_id}")
            print(f"Looking in: {batch_temp_dir}")
            # List all files in the directory for debugging
            print("Files found:")
            for file in batch_temp_dir.rglob("*"):
                print(f"  {file}")
            return
        
        batch_results_dir = self.results_dir / f"batch_{batch_id}"
        batch_results_dir.mkdir(exist_ok=True)
        
        batch_log = []
        
        for pair in batch_pairs:
            try:
                img1_name = pair['img1']['full_name']
                img2_name = pair['img2']['full_name']
                
                print(f"Processing pair: {img1_name} vs {img2_name}")
                
                # Load keypoints
                keypoints1, scores1 = self.load_keypoints_and_scores(features_path, img1_name)
                keypoints2, scores2 = self.load_keypoints_and_scores(features_path, img2_name)
                
                if keypoints1 is None or keypoints2 is None:
                    print(f"Failed to load keypoints for {img1_name} or {img2_name}")
                    continue
                
                # Load matches
                original_matches = self.load_matches(matches_path, img1_name, img2_name)
                
                # Create signature masks
                img1_path = batch_temp_dir / "images" / img1_name
                img2_path = batch_temp_dir / "images" / img2_name
                
                mask1 = self.create_signature_mask(img1_path)
                mask2 = self.create_signature_mask(img2_path)
                
                # Filter keypoints by signature regions
                filt_kpts1, filt_scores1, valid_idx1 = self.filter_keypoints_by_signature(
                    keypoints1, scores1, mask1
                )
                filt_kpts2, filt_scores2, valid_idx2 = self.filter_keypoints_by_signature(
                    keypoints2, scores2, mask2
                )
                
                # Filter matches
                filtered_matches = self.filter_matches_by_signature_features(
                    original_matches, valid_idx1, valid_idx2
                )
                
                # Create visualization
                viz_path = batch_results_dir / f"{pair['pair_name']}.png"
                print(f"  Creating visualization: {viz_path}")
                viz_success = self.visualize_matches(
                    img1_path, img2_path, keypoints1, keypoints2,
                    original_matches, filt_kpts1, filt_kpts2, filtered_matches,
                    pair, viz_path
                )
                print(f"  Visualization success: {viz_success}")
                
                # Log results
                result_entry = {
                    'pair_name': pair['pair_name'],
                    'img1': img1_name,
                    'img2': img2_name,
                    'same_person': pair['same_person'],
                    'person1_id': pair['img1']['person_id'],
                    'person2_id': pair['img2']['person_id'],
                    'original_features_1': len(keypoints1),
                    'original_features_2': len(keypoints2),
                    'signature_features_1': len(filt_kpts1),
                    'signature_features_2': len(filt_kpts2),
                    'original_matches': len(original_matches) if original_matches is not None else 0,
                    'signature_matches': len(filtered_matches) if filtered_matches is not None else 0,
                    'visualization_created': viz_success,
                    'visualization_path': str(viz_path) if viz_success else None,
                    'batch_id': batch_id,
                    'timestamp': datetime.now().isoformat()
                }
                
                batch_log.append(result_entry)
                self.results_log.append(result_entry)
                
                print(f"  Original matches: {result_entry['original_matches']}")
                print(f"  Signature matches: {result_entry['signature_matches']}")
                print(f"  Visualization saved: {viz_path if viz_success else 'FAILED'}")
                
            except Exception as e:
                print(f"Error processing pair {pair['pair_name']}: {e}")
                continue
        
        # Save batch results
        if batch_log:
            batch_df = pd.DataFrame(batch_log)
            batch_df.to_csv(batch_results_dir / "batch_results.csv", index=False)
    
    def run_comprehensive_matching(self, batch_size=20):
        """Match every image pair batch by batch and write the overall results.

        All pairs are enumerated and split into batches. For each batch the
        working directory is prepared, the external matcher is run, and the
        results are processed. A batch whose matching fails or that raises is
        reported and skipped. Finally the accumulated log is written to
        ``comprehensive_results.csv`` in the results directory and a summary
        report is generated.

        Args:
            batch_size: Number of pairs per batch.
        """
        print("Starting comprehensive signature matching...")
        
        all_pairs = self.get_all_image_pairs()
        print(f"Total pairs to process: {len(all_pairs)}")
        
        same_person_pairs = len([entry for entry in all_pairs if entry['same_person']])
        different_person_pairs = len(all_pairs) - same_person_pairs
        print(f"Same person pairs: {same_person_pairs}")
        print(f"Different person pairs: {different_person_pairs}")
        
        batches = self.create_pair_batch(all_pairs, batch_size)
        for batch_id, batch_pairs in enumerate(batches):
            print(f"\nProcessing batch {batch_id + 1} ({len(batch_pairs)} pairs)...")
            
            try:
                batch_temp_dir, num_pairs = self.setup_batch_directory(batch_pairs, batch_id)
                print(f"Setup completed for {num_pairs} pairs in {batch_temp_dir}")
                
                if self.run_matching_for_batch(batch_temp_dir):
                    self.process_batch_results(batch_pairs, batch_temp_dir, batch_id)
                    # The batch's temporary directory is intentionally kept
                    # for debugging; it is not removed here.
                else:
                    print(f"Matching failed for batch {batch_id}")
            except Exception as e:
                print(f"Error processing batch {batch_id}: {e}")
        
        if not self.results_log:
            print("No results to save!")
            return
        
        # Persist every recorded pair, then derive the summary from it.
        comprehensive_df = pd.DataFrame(self.results_log)
        comprehensive_df.to_csv(self.results_dir / "comprehensive_results.csv", index=False)
        self.generate_summary_report(comprehensive_df)
    
    def generate_summary_report(self, df):
        """Generate a summary report of matching results"""
        summary = {
            'total_pairs': len(df),
            'same_person_pairs': len(df[df['same_person'] == True]),
            'different_person_pairs': len(df[df['same_person'] == False]),
            'pairs_with_original_matches': len(df[df['original_matches'] > 0]),
            'pairs_with_signature_matches': len(df[df['signature_matches'] > 0]),
            'avg_original_matches_same_person': df[df['same_person'] == True]['original_matches'].mean(),
            'avg_original_matches_different_person': df[df['same_person'] == False]['original_matches'].mean(),
            'avg_signature_matches_same_person': df[df['same_person'] == True]['signature_matches'].mean(),
            'avg_signature_matches_different_person': df[df['same_person'] == False]['signature_matches'].mean(),
        }
        
        # Save summary
        summary_df = pd.DataFrame([summary])
        summary_df.to_csv(self.results_dir / "summary_report.csv", index=False)
        
        # Print summary
        print("\n" + "="*60)
        print("COMPREHENSIVE MATCHING SUMMARY")
        print("="*60)
        for key, value in summary.items():
            print(f"{key.replace('_', ' ').title()}: {value:.2f}" if isinstance(value, float) else f"{key.replace('_', ' ').title()}: {value}")
        print("="*60)


# Entry point: build the matcher over the default directories and process
# every image pair in batches of 20.
if __name__ == "__main__":
    # Arguments, in order: real data directory, temporary working directory,
    # results directory.
    matcher = SignatureMatchingSystem("real_data", "temp_matching", "matching_results")
    matcher.run_comprehensive_matching(batch_size=20)

Starting comprehensive signature matching...
Total pairs to process: 378
Same person pairs: 74
Different person pairs: 304

Processing batch 1 (20 pairs)...
Setup completed for 20 pairs in temp_matching/batch_0
Running command from temp_matching/batch_0: /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/venv/bin/python /home/kshitiz/Documents/Deep_Image_Matching/deep-image-matching/main.py -d . -p superpoint+lightglue -s custom_pairs --pair_file pairs.txt -f
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)

  0%|          | 0/21 [00:00<?, ?it/s]
  5%|▍         | 1/21 [00:00<00:06,  3.16it/s]
 24%|██▍       | 5/21 [00:00<00:01, 14.34it/s]
 48%|████▊     | 10/21 [00:00<00:00, 25.03it/s]
 71%|███████▏  | 15/21 [00:00<00:00, 32.01it/s]
100%|██████████| 21/21 [00:00<00:00, 39.74it/s]
100%|██████████| 21/21 [00:00<00:00, 28.55it/s]

  0%|          | 0/20 [00:00<?, ?it/s]
  5%|▌         | 1/20 [00:00<00:05,  3.43it/s]
 15%|█▌        | 3/20 [00:00<00:02,  6.87it/s]
 25%|██