# Complete Pet Re-Identification Pipeline
## Ready for API Integration

This notebook demonstrates the complete pipeline:
1. **Visual Processing**: YOLOv8 detection + DINOv2 embeddings
2. **Metadata Matching**: Location, time, physical attributes
3. **Final Scoring**: Weighted combination for match confidence

### Pipeline Flow
```
Input (Lost Post) ──► Image Processing ──► Embedding Extraction ──┐
                                                                   │
Input (Found Post) ─► Image Processing ──► Embedding Extraction ──┤
                                                                   │
                                                                   ▼
                                                          Matching Engine
                                                                   │
                                                                   ├─► Visual Similarity (50%)
                                                                   ├─► Time Score (20%)
                                                                   ├─► Location Score (20%)
                                                                   ├─► Gender Score (10%)
                                                                   │
                                                                   ▼
                                                          Match Result + Confidence
```

### Current Weighting System
- **Visual Match**: 50% - Image similarity using DINOv2
- **Time**: 20% - Recency with exponential decay
- **Location**: 20% - Distance with exponential decay  
- **Gender**: 10% - Gender matching component
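- **Pet Type**: Must match (cat/dog/bird/rabbit) - applied as a multiplier of 1.0 (types match) or 0.0 (types differ). The detector in this notebook keeps only COCO cats and dogs, so posts for other species will not produce image embeddings.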

In [None]:
# Install dependencies (run once)
!pip install torch torchvision transformers ultralytics opencv-python pillow matplotlib numpy -q

In [None]:
import sys
import torch
import cv2
import numpy as np
from PIL import Image
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Import vision processing
from ultralytics import YOLO
from transformers import AutoImageProcessor, AutoModel

# Import our matching system
from pet_matching_engine import (
    PetPost, MatchResult, PetMatcher, PetMatchingConfig,
    format_match_result
)

print("✓ All modules imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Step 1: Load Models

Load YOLOv8-Seg and DINOv2 models with caching

In [None]:
import shutil

# Setup model cache directory
cache_dir = Path('models_cache')
cache_dir.mkdir(exist_ok=True)

print("Loading models...\n")

# 1. YOLOv8 Segmentation Model
print("[1/2] Loading YOLOv8-Seg...")
yolo_path = cache_dir / 'yolov8x-seg.pt'
if not yolo_path.exists():
    print("  Downloading YOLOv8-Seg (~131MB)...")
    detector = YOLO('yolov8x-seg.pt')
    # By default Ultralytics saves downloaded weights to the working directory;
    # move them into the cache so the next run takes the cached branch
    downloaded = Path('yolov8x-seg.pt')
    if downloaded.exists():
        shutil.move(str(downloaded), str(yolo_path))
else:
    print("  Loading from cache...")
    detector = YOLO(str(yolo_path))
detector.model.to(device)
print("  ✓ YOLOv8-Seg loaded\n")

# 2. DINOv2 Model
print("[2/2] Loading DINOv2-Large...")
model_name = 'facebook/dinov2-large'
processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=str(cache_dir), use_fast=True)
reid_model = AutoModel.from_pretrained(model_name, cache_dir=str(cache_dir))
reid_model = reid_model.to(device)
reid_model.eval()
print("  ✓ DINOv2-Large loaded\n")

print("="*80)
print("All models loaded successfully!")
print("="*80)

## Step 2: Define Processing Functions

Functions for detection, preprocessing, and embedding extraction
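Each post can yield several L2-normalized embeddings (one per detected pet per image), so comparing two posts means comparing two sets of vectors. The engine's actual aggregation lives in `pet_matching_engine` and is not shown here; this sketch assumes one plausible rule, the maximum pairwise cosine similarity, and also mirrors the test-time augmentation (TTA) step of averaging the original and flipped views before normalizing. Plain lists stand in for NumPy arrays, and all helper names are hypothetical.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vec]


def tta_average(original: Sequence[float], flipped: Sequence[float]) -> List[float]:
    """Average the embeddings of the original and flipped views, then normalize."""
    mean = [(a + b) / 2 for a, b in zip(original, flipped)]
    return l2_normalize(mean)


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """For unit-length vectors, cosine similarity reduces to the dot product."""
    return sum(a * b for a, b in zip(u, v))


def post_similarity(lost: List[List[float]], found: List[List[float]]) -> float:
    """Hypothetical aggregation: best score over all lost/found embedding pairs."""
    if not lost or not found:
        return 0.0  # no detections on one side, nothing to compare
    return max(cosine_similarity(u, v) for u in lost for v in found)


# Toy 3-D "embeddings" (DINOv2-Large CLS vectors have 1024 dimensions)
lost = [tta_average([1.0, 0.0, 0.0], [0.8, 0.6, 0.0])]
found = [l2_normalize([1.0, 0.2, 0.0]), l2_normalize([0.0, 0.0, 1.0])]
print(round(post_similarity(lost, found), 4))
```

Taking the maximum rewards the single best-aligned pair of photos, which tolerates posts where only some images show the pet clearly; averaging all pairs would be the more conservative alternative.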

In [None]:
def detect_pets_with_segmentation(image_path: str, conf_threshold: float = 0.3) -> List[Dict]:
    """
    Detect pets using YOLOv8 segmentation
    Returns list of detections with bounding boxes and segmentation masks
    """
    results = detector(image_path, conf=conf_threshold, verbose=False)
    detections = []
    
    for result in results:
        if result.masks is None:
            continue
            
        boxes = result.boxes
        masks = result.masks.data.cpu().numpy()
        
        for idx, (box, mask) in enumerate(zip(boxes, masks)):
            class_id = int(box.cls[0])
            class_name = result.names[class_id]
            
            # Filter for cats (15) and dogs (16)
            if class_id not in [15, 16]:
                continue
            
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            confidence = float(box.conf[0])
            
            detections.append({
                'bbox': (x1, y1, x2, y2),
                'class': class_name,
                'confidence': confidence,
                'mask': mask
            })
    
    return detections


def preprocess_crop_advanced(image: np.ndarray, detection: Dict, 
                            padding_percent: float = 0.15) -> np.ndarray:
    """
    Advanced preprocessing with background removal and enhancement
    """
    x1, y1, x2, y2 = detection['bbox']
    mask = detection['mask']
    
    # Resize mask to image dimensions
    h, w = image.shape[:2]
    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_LINEAR)
    mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
    
    # Create masked image (remove background)
    masked_image = cv2.bitwise_and(image, image, mask=mask_binary)
    
    # Add padding to bounding box
    pad_w = int((x2 - x1) * padding_percent)
    pad_h = int((y2 - y1) * padding_percent)
    
    x1_pad = max(0, x1 - pad_w)
    y1_pad = max(0, y1 - pad_h)
    x2_pad = min(w, x2 + pad_w)
    y2_pad = min(h, y2 + pad_h)
    
    # Crop with padding
    cropped = masked_image[y1_pad:y2_pad, x1_pad:x2_pad]
    
    # CLAHE enhancement on LAB color space
    lab = cv2.cvtColor(cropped, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l)
    enhanced = cv2.merge([l_clahe, a, b])
    enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)
    
    return enhanced


def extract_embedding_with_tta(image_bgr: np.ndarray,
                               processor, model, device) -> np.ndarray:
    """
    Extract a DINOv2 embedding with test-time augmentation (TTA).
    Averages the CLS-token embeddings of the original and horizontally
    flipped image, then L2-normalizes the result.
    """
    model.eval()

    # OpenCV images are BGR; the Hugging Face processor expects RGB
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image_rgb)
    pil_image_flip = pil_image.transpose(Image.FLIP_LEFT_RIGHT)

    # Run both views through the model as a single batch
    inputs = processor(images=[pil_image, pil_image_flip], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        # CLS token for each view: shape (2, hidden_size)
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

    # Average the two views, then normalize to unit length
    final_embedding = cls_embeddings.mean(axis=0)
    final_embedding = final_embedding / np.linalg.norm(final_embedding)

    return final_embedding


def process_pet_images(image_paths: List[str]) -> List[np.ndarray]:
    """
    Complete processing pipeline for multiple images
    Returns list of embeddings
    """
    all_embeddings = []
    
    for img_path in image_paths:
        print(f"  Processing: {Path(img_path).name}")
        
        # Load image
        image = cv2.imread(str(img_path))
        if image is None:
            print(f"    ⚠ Warning: Could not load image")
            continue
        
        # Detect pets
        detections = detect_pets_with_segmentation(str(img_path))
        
        if not detections:
            print(f"    ⚠ Warning: No pets detected")
            continue
        
        print(f"    Found {len(detections)} pet(s)")
        
        # Process each detection
        for idx, det in enumerate(detections):
            # Preprocess
            processed = preprocess_crop_advanced(image, det)
            
            # Extract embedding
            embedding = extract_embedding_with_tta(processed, processor, reid_model, device)
            all_embeddings.append(embedding)
            
            print(f"    ✓ Extracted embedding (detection {idx+1})")
    
    return all_embeddings


print("✓ Processing functions defined")

## Step 3: Complete Matching Function (API-Ready)

This is the main function that will be wrapped in your API endpoint
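Before the pipeline loads any images, an API wrapper would typically reject malformed payloads. A minimal sketch, assuming the required keys are those `match_pet_posts()` reads with `[...]` (rather than `.get()`) and that `gender` follows the docstring's `'male'`/`'female'` values; `validate_post_data` is a hypothetical helper, not part of `pet_matching_engine`.

```python
from datetime import datetime
from typing import Dict, List

# Keys that match_pet_posts() indexes directly and therefore requires
REQUIRED_KEYS = ("id", "pet_type", "latitude", "longitude", "timestamp", "image_paths")


def validate_post_data(post: Dict) -> List[str]:
    """Return a list of problems with a post payload (empty if it looks valid)."""
    errors = [f"missing required field: {key}" for key in REQUIRED_KEYS if key not in post]
    if errors:
        return errors
    if not -90.0 <= post["latitude"] <= 90.0:
        errors.append("latitude must be between -90 and 90")
    if not -180.0 <= post["longitude"] <= 180.0:
        errors.append("longitude must be between -180 and 180")
    if not isinstance(post["timestamp"], datetime):
        errors.append("timestamp must be a datetime")
    if not post["image_paths"]:
        errors.append("image_paths must contain at least one path")
    gender = post.get("gender")  # optional field
    if gender is not None and gender not in ("male", "female"):
        errors.append("gender must be 'male' or 'female'")
    return errors


post = {"id": "LOST_001", "pet_type": "cat", "latitude": 23.8103,
        "longitude": 90.4125, "timestamp": datetime(2025, 10, 10, 14, 30),
        "image_paths": ["images/a.jpg"]}
print(validate_post_data(post))
```

Returning a list of messages (instead of raising on the first problem) lets an endpoint report every issue in a single 4xx response.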

In [None]:
def match_pet_posts(lost_post_data: Dict, found_post_data: Dict, 
                   detailed: bool = True) -> Dict:
    """
    Complete matching pipeline - ready for API integration
    
    Args:
        lost_post_data: Dictionary with lost pet information
            {
                'id': str,
                'pet_type': str,  # 'cat' or 'dog'
                'description': str (optional),
                'latitude': float,
                'longitude': float,
                'timestamp': datetime,
                'image_paths': List[str],  # paths to images
                'contact_info': str (optional),
                'neutered': bool (optional),
                'gender': str (optional)  # 'male' or 'female'
            }
        found_post_data: Dictionary with found pet information (same structure)
        detailed: Include detailed breakdown in response
    
    Returns:
        Dictionary with match result
    """
    print("\n" + "="*80)
    print("STARTING MATCH PIPELINE")
    print("="*80)
    
    # Step 1: Process lost pet images
    print("\n[1/3] Processing LOST pet images...")
    lost_embeddings = process_pet_images(lost_post_data['image_paths'])
    print(f"  ✓ Extracted {len(lost_embeddings)} embeddings from lost pet\n")
    
    # Step 2: Process found pet images
    print("[2/3] Processing FOUND pet images...")
    found_embeddings = process_pet_images(found_post_data['image_paths'])
    print(f"  ✓ Extracted {len(found_embeddings)} embeddings from found pet\n")
    
    # Create PetPost objects
    lost_post = PetPost(
        id=lost_post_data['id'],
        pet_type=lost_post_data['pet_type'],
        description=lost_post_data.get('description'),
        latitude=lost_post_data['latitude'],
        longitude=lost_post_data['longitude'],
        timestamp=lost_post_data['timestamp'],
        # is_lost=True,  # Not part of PetPost class
        neutered=lost_post_data.get('neutered'),
        gender=lost_post_data.get('gender'),
        embeddings=lost_embeddings
    )
    
    found_post = PetPost(
        id=found_post_data['id'],
        pet_type=found_post_data['pet_type'],
        description=found_post_data.get('description'),
        latitude=found_post_data['latitude'],
        longitude=found_post_data['longitude'],
        timestamp=found_post_data['timestamp'],
        # is_lost=False,  # Not part of PetPost class
        neutered=found_post_data.get('neutered'),
        gender=found_post_data.get('gender'),
        embeddings=found_embeddings
    )
    
    # Step 3: Perform matching
    print("[3/3] Computing match score...")
    matcher = PetMatcher()
    result = matcher.match(lost_post, found_post)
    
    # Format result
    output = format_match_result(result, detailed=detailed)
    
    print("\n" + "="*80)
    print("MATCH RESULT")
    print("="*80)
    print(f"Confidence: {result.confidence:.2f}%")
    print(f"Category: {result.match_category.upper()}")
    print(f"Is Match: {result.is_match}")
    
    if detailed:
        print("\nBreakdown:")
        print(f"  Visual Similarity: {result.visual_similarity:.4f}")
        print(f"  Location Score: {result.location_score:.2f}/100 ({result.distance_km:.2f} km)")
        print(f"  Time Score: {result.time_score:.2f}/100 ({result.time_diff_hours:.1f} hrs)")
        print(f"  Metadata Score: {result.metadata_score:.2f}/100")
    
    print("="*80 + "\n")
    
    return output


print("✓ API-ready matching function defined")

## Step 4: Example Usage

Test the complete pipeline with sample data
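To sanity-check the location and time inputs of the example posts, the great-circle distance and elapsed time can be computed directly. This is a sketch assuming the haversine formula with a mean Earth radius of 6371 km; `pet_matching_engine` may compute distance differently.

```python
import math
from datetime import datetime

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumption; the engine may differ)


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance between two (lat, lon) points in kilometres."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


def hours_between(t1: datetime, t2: datetime) -> float:
    """Absolute time difference in hours."""
    return abs((t2 - t1).total_seconds()) / 3600


distance = haversine_km(23.8103, 90.4125, 23.8150, 90.4180)
elapsed = hours_between(datetime(2025, 10, 10, 14, 30), datetime(2025, 10, 12, 9, 15))
print(f"{distance:.2f} km, {elapsed:.2f} h")
```

For the sample posts this gives roughly 0.77 km and 42.75 hours.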

In [None]:
# Example: Lost Pet Post
lost_pet = {
    'id': 'LOST_001',
    'pet_type': 'cat',
    'description': 'Orange tabby with white paws, small scar on left ear',
    'latitude': 23.8103,  # Dhaka coordinates
    'longitude': 90.4125,
    'timestamp': datetime(2025, 10, 10, 14, 30),
    'image_paths': [
        'images/a.jpg',
        'images/b.jpg',
        'images/c.jpg'
    ],
    'contact_info': '+880-123-456789',
    'neutered': True,
    'gender': 'male'
}

# Example: Found Pet Post
found_pet = {
    'id': 'FOUND_001',
    'pet_type': 'cat',
    'description': 'Orange cat with white markings, friendly',
    'latitude': 23.8150,  # ~0.77 km from the lost-pet location
    'longitude': 90.4180,
    'timestamp': datetime(2025, 10, 12, 9, 15),  # 42.75 hours after the lost post
    'image_paths': [
        'images/p.jpg',
        'images/q.jpg'
    ],
    'contact_info': '+880-987-654321',
    'neutered': True,
    'gender': 'male'
}

print("Example data structures created")
print("\nTo run matching, execute:")
print(">>> result = match_pet_posts(lost_pet, found_pet)")

In [None]:
# Run the matching pipeline
# Adjust image paths as needed

result = match_pet_posts(lost_pet, found_pet, detailed=True)

## Step 5: Batch Matching (Find Best Matches)

Match one lost pet against multiple found pets
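`find_matches_for_lost_pet()` delegates ranking to `PetMatcher.find_best_matches()`, whose internals are not shown. A minimal sketch of the filtering and ranking that its `top_k` and `min_confidence` parameters suggest, assuming results are plain `(found_id, confidence)` pairs; `top_k_matches` is a hypothetical helper.

```python
import heapq
from typing import List, Tuple


def top_k_matches(scored: List[Tuple[str, float]], top_k: int = 5,
                  min_confidence: float = 45.0) -> List[Tuple[str, float]]:
    """Drop candidates below min_confidence, then keep the top_k highest scores."""
    eligible = [(fid, conf) for fid, conf in scored if conf >= min_confidence]
    # heapq.nlargest returns items ordered from highest to lowest confidence
    return heapq.nlargest(top_k, eligible, key=lambda item: item[1])


candidates = [("FOUND_001", 82.5), ("FOUND_002", 40.0), ("FOUND_003", 61.2),
              ("FOUND_004", 47.9), ("FOUND_005", 90.1)]
print(top_k_matches(candidates, top_k=3))
```

`heapq.nlargest` avoids fully sorting the candidate list when only a handful of matches are returned.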

In [None]:
def find_matches_for_lost_pet(lost_post_data: Dict, 
                             found_posts_data: List[Dict],
                             top_k: int = 5,
                             min_confidence: float = 45.0) -> List[Dict]:
    """
    Find best matches for a lost pet from multiple found pets
    
    Args:
        lost_post_data: Lost pet information
        found_posts_data: List of found pet information
        top_k: Return top K matches
        min_confidence: Minimum confidence threshold
    
    Returns:
        List of match results sorted by confidence
    """
    print("\n" + "="*80)
    print(f"BATCH MATCHING: 1 LOST PET vs {len(found_posts_data)} FOUND PETS")
    print("="*80)
    
    # Process lost pet once
    print("\nProcessing LOST pet images...")
    lost_embeddings = process_pet_images(lost_post_data['image_paths'])
    
    lost_post = PetPost(
        id=lost_post_data['id'],
        pet_type=lost_post_data['pet_type'],
        description=lost_post_data.get('description'),
        latitude=lost_post_data['latitude'],
        longitude=lost_post_data['longitude'],
        timestamp=lost_post_data['timestamp'],
        # is_lost=True,  # Not part of PetPost class
        neutered=lost_post_data.get('neutered'),
        gender=lost_post_data.get('gender'),
        embeddings=lost_embeddings
    )
    
    # Process each found pet and compute match
    found_posts = []
    
    for i, found_data in enumerate(found_posts_data, 1):
        print(f"\nProcessing FOUND pet {i}/{len(found_posts_data)}...")
        
        found_embeddings = process_pet_images(found_data['image_paths'])
        
        found_post = PetPost(
            id=found_data['id'],
            pet_type=found_data['pet_type'],
            description=found_data.get('description'),
            latitude=found_data['latitude'],
            longitude=found_data['longitude'],
            timestamp=found_data['timestamp'],
            # is_lost=False,  # Not part of PetPost class
            neutered=found_data.get('neutered'),
            gender=found_data.get('gender'),
            embeddings=found_embeddings
        )
        
        found_posts.append(found_post)
    
    # Find best matches
    print("\nComputing matches...")
    matcher = PetMatcher()
    matches = matcher.find_best_matches(
        lost_post, found_posts, 
        top_k=top_k, 
        min_confidence=min_confidence
    )
    
    # Format results
    results = [format_match_result(m, detailed=True) for m in matches]
    
    # Display summary
    print("\n" + "="*80)
    print("MATCH RESULTS SUMMARY")
    print("="*80)
    print(f"Total candidates: {len(found_posts_data)}")
    print(f"Matches found: {len(results)}\n")
    
    for i, r in enumerate(results, 1):
        print(f"{i}. {r['matched_id']} - Confidence: {r['confidence']}% ({r['match_category']})")
    
    print("="*80 + "\n")
    
    return results


print("✓ Batch matching function defined")

## Step 6: Visualization

Visualize match results

In [None]:
def visualize_match_result(result: Dict, 
                          lost_image_path: str, 
                          found_image_path: str):
    """
    Visualize match comparison between two images
    """
    # Load both images first; cv2.imread returns None for unreadable paths
    lost_img = cv2.imread(lost_image_path)
    found_img = cv2.imread(found_image_path)
    if lost_img is None or found_img is None:
        raise FileNotFoundError(f"Could not load {lost_image_path} or {found_image_path}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 7))

    # Lost pet image (convert OpenCV BGR to RGB for matplotlib)
    axes[0].imshow(cv2.cvtColor(lost_img, cv2.COLOR_BGR2RGB))
    axes[0].set_title(f"LOST PET\nID: {result['query_id']}", fontsize=14, fontweight='bold')
    axes[0].axis('off')

    # Found pet image
    axes[1].imshow(cv2.cvtColor(found_img, cv2.COLOR_BGR2RGB))
    axes[1].set_title(f"FOUND PET\nID: {result['matched_id']}", fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    # Match result
    confidence = result['confidence']
    category = result['match_category'].upper()
    is_match = result['is_match']
    
    if is_match and confidence >= 75:
        color = 'green'
        status = '✓ HIGH CONFIDENCE MATCH'
    elif is_match:
        color = 'orange'
        status = '⚠ MEDIUM CONFIDENCE MATCH'
    else:
        color = 'red'
        status = '✗ NO MATCH'
    
    fig.suptitle(
        f"{status}\nConfidence: {confidence:.1f}% | Category: {category}",
        fontsize=16, fontweight='bold', color=color, y=0.98
    )
    
    # Details
    if 'details' in result:
        details = result['details']
        detail_text = (
            f"Visual Similarity: {details['visual_similarity']:.3f}\n"
            f"Distance: {details['distance_km']:.2f} km\n"
            f"Time Difference: {details['time_diff_hours']:.1f} hours"
        )
        fig.text(0.5, 0.02, detail_text, ha='center', fontsize=11, 
                family='monospace', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
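    # Reserve space at the top for the title and at the bottom for the detail box
    plt.tight_layout(rect=(0, 0.08, 1, 0.92))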
    plt.show()


print("✓ Visualization function defined")

## Configuration & Tuning Guide

### Current Weight Configuration

The matching system uses the following weights:
- **Visual (Image Match): 50%** - Most important factor
- **Time: 20%** - Temporal proximity
- **Location: 20%** - Spatial proximity
- **Gender: 10%** - Gender matching component
- **Pet Type Multiplier: × 1.0 or × 0.0** - Must match (cat/dog/bird/rabbit)

**Formula:**
```
weighted_score = (0.50 × visual_score) + (0.20 × time_score) + (0.20 × location_score) + (0.10 × gender_score)
final_score = weighted_score × pet_type_multiplier
```
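The formula translates directly into code. This sketch assumes every component score is already on a 0 to 100 scale (the notebook prints visual similarity as a raw cosine value, so the engine may rescale it first) and that the pet-type multiplier is 1.0 when types match and 0.0 otherwise; `final_score` is a hypothetical name.

```python
from typing import Dict

# Weights from the configuration above
WEIGHTS: Dict[str, float] = {"visual": 0.50, "time": 0.20, "location": 0.20, "gender": 0.10}
assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9, "weights should sum to 1"


def final_score(scores: Dict[str, float], lost_type: str, found_type: str) -> float:
    """Weighted sum of 0-100 component scores, zeroed when pet types differ."""
    weighted = sum(WEIGHTS[name] * scores[name] for name in WEIGHTS)
    pet_type_multiplier = 1.0 if lost_type == found_type else 0.0
    return weighted * pet_type_multiplier


scores = {"visual": 80.0, "time": 90.0, "location": 95.0, "gender": 100.0}
print(final_score(scores, "cat", "cat"))
print(final_score(scores, "cat", "dog"))
```

Because the weights sum to 1, the weighted score stays on the same 0 to 100 scale as its inputs, which keeps it directly comparable to the confidence thresholds.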

### Custom Configuration (Advanced)

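Create a `PetMatchingConfig`, adjust its attributes, and pass it to `PetMatcher` to tune the matcher for different scenarios. Note that `match_pet_posts()` and `find_matches_for_lost_pet()` both construct `PetMatcher()` with its default configuration, so a custom config only takes effect where it is passed in explicitly: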

```python
# Custom matcher with adjusted thresholds
config = PetMatchingConfig()

# Distance thresholds (exponential decay)
config.DISTANCE_PERFECT_MATCH_KM = 10.0  # Perfect match zone
config.DISTANCE_HALF_LIFE_KM = 15.0      # Half-life for decay

# Time thresholds (exponential decay)
config.TIME_PERFECT_MATCH_DAYS = 3.0     # Perfect match zone
config.TIME_HALF_LIFE_DAYS = 14.0        # Half-life for decay
config.TIME_MAX_DAYS = 90                # Hard cutoff

# Confidence thresholds
config.CONFIDENCE_HIGH = 75.0            # High confidence match
config.CONFIDENCE_MEDIUM = 60.0          # Medium confidence
config.CONFIDENCE_LOW = 45.0             # Low confidence

# Weight configuration
config.WEIGHT_VISUAL = 0.50              # Visual similarity weight (50%)
config.WEIGHT_TIME = 0.20                # Time proximity weight (20%)
config.WEIGHT_LOCATION = 0.20            # Location proximity weight (20%)
config.WEIGHT_GENDER = 0.10              # Gender match weight (10%)

matcher = PetMatcher(config)
```
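The distance and time settings describe a flat "perfect match" zone followed by exponential decay. The engine's exact curve is not shown here; this sketch assumes the score is 100 inside the perfect zone, halves every half-life beyond it, and drops to 0 past an optional hard cutoff (used for `TIME_MAX_DAYS`). `decay_score` is a hypothetical helper.

```python
import math
from typing import Optional


def decay_score(value: float, perfect: float, half_life: float,
                max_value: Optional[float] = None) -> float:
    """Score in [0, 100]: 100 within `perfect`, then halves every `half_life` units."""
    if max_value is not None and value > max_value:
        return 0.0  # hard cutoff, e.g. TIME_MAX_DAYS
    excess = max(0.0, value - perfect)
    return 100.0 * math.pow(0.5, excess / half_life)


# Distance: perfect within 10 km, half-life 15 km
print(decay_score(25.0, perfect=10.0, half_life=15.0))
# Time: perfect within 3 days, half-life 14 days, cutoff at 90 days
print(decay_score(17.0, perfect=3.0, half_life=14.0, max_value=90))
print(decay_score(120.0, perfect=3.0, half_life=14.0, max_value=90))
```

With a half-life parameterization, the effect of a threshold change is easy to reason about: one half-life past the perfect zone always yields 50.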

### API Integration Notes

1. **Image Storage**: Save uploaded images with unique IDs
2. **Caching**: Cache embeddings in database to avoid reprocessing
3. **Async Processing**: Use background tasks for batch matching
4. **Rate Limiting**: Limit API calls to prevent abuse
5. **Error Handling**: Validate all inputs, handle missing images gracefully
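For the caching note, one option is to key embeddings by a hash of the image bytes, so re-uploads of the same file are never reprocessed. A minimal in-memory sketch, assuming a caller-supplied `embed_fn` (for example, a wrapper around `process_pet_images()`); `EmbeddingCache` is hypothetical, and a production service would persist the vectors in a database rather than a dict.

```python
import hashlib
from typing import Callable, Dict, List


class EmbeddingCache:
    """Hypothetical in-memory cache mapping image content hashes to embeddings."""

    def __init__(self, embed_fn: Callable[[bytes], List[List[float]]]):
        self._embed_fn = embed_fn
        self._store: Dict[str, List[List[float]]] = {}
        self.misses = 0

    @staticmethod
    def key_for(image_bytes: bytes) -> str:
        # SHA-256 of the raw bytes: identical files share a key regardless of filename
        return hashlib.sha256(image_bytes).hexdigest()

    def get(self, image_bytes: bytes) -> List[List[float]]:
        key = self.key_for(image_bytes)
        if key not in self._store:
            self.misses += 1
            self._store[key] = self._embed_fn(image_bytes)
        return self._store[key]


# Stand-in embedder; in the notebook this would run detection + DINOv2
cache = EmbeddingCache(lambda data: [[float(len(data))]])
cache.get(b"fake-image-1")
cache.get(b"fake-image-1")  # served from the cache
cache.get(b"fake-image-2")
print(cache.misses)
```

Each cached value is a list of embeddings because one image can contain several detected pets.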

### Expected Performance

- **Same Pet (Good Images)**: 75-90% confidence
- **Same Pet (Different Poses)**: 60-75% confidence
- **Similar Looking Pets**: 45-60% confidence
- **Different Pets**: <45% confidence

### Match Categories

- **High (≥ 75%)**: Very likely the same pet; notify the owner immediately
- **Medium (60% to < 75%)**: Probable match; show for manual verification
- **Low (45% to < 60%)**: Possible match; include in extended results
- **No Match (< 45%)**: Unlikely to be the same pet
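The category boundaries reduce to a threshold check over half-open ranges. A sketch assuming the `CONFIDENCE_HIGH`, `CONFIDENCE_MEDIUM`, and `CONFIDENCE_LOW` values shown in the configuration example; the labels returned here are illustrative and may not match the engine's `match_category` strings.

```python
def match_category(confidence: float, high: float = 75.0,
                   medium: float = 60.0, low: float = 45.0) -> str:
    """Map a 0-100 confidence to a category; each lower bound is inclusive."""
    if confidence >= high:
        return "high"
    if confidence >= medium:
        return "medium"
    if confidence >= low:
        return "low"
    return "no_match"


for c in (88.0, 74.9, 60.0, 45.0, 44.99):
    print(c, match_category(c))
```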

### Next Steps for API

1. Wrap `match_pet_posts()` in FastAPI/Flask endpoint
2. Add database integration for caching embeddings
3. Implement batch processing queue
4. Add authentication & rate limiting
5. Deploy with GPU support for fast inference