# Dynamic Textile Matching with TTA

This notebook implements a production-ready textile matching system with:
- Automatic ID assignment for any images dropped in source folder
- Incremental embedding generation (only processes new images)
- TTA (Test-Time Augmentation) for robust matching
- Smart caching and mapping system

## 1. Setup and Configuration

In [None]:
import os
# Set HuggingFace token from environment variable or prompt user
# Informational check only: report whether a HuggingFace token is exposed via
# the HF_TOKEN environment variable. Nothing is read from disk or stored here.
if 'HF_TOKEN' in os.environ:
    print("✅ HF_TOKEN found in environment")
else:
    print("\n".join([
        "⚠️  HF_TOKEN not found in environment variables",
        "Please set it using one of these methods:",
        "  1. Export before running notebook: export HF_TOKEN='your_token'",
        "  2. Create .env file with: HF_TOKEN=your_token",
        "  3. Run: huggingface-cli login",
    ]))
    # Optional interactive fallback (the token would not be saved in the notebook):
    # os.environ['HF_TOKEN'] = input("Enter HF token: ")

import torch
import numpy as np
from pathlib import Path
import pickle
import json
import hashlib
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import torch.nn.functional as F
import time
from datetime import datetime
import shutil

# DINOv3 imports
from transformers import AutoImageProcessor, AutoModel

# Configuration
# Every working directory lives under one base directory.
BASE_DIR = Path("/workspace/textile_matching")
SOURCE_IMAGES_DIR = BASE_DIR / "source_images"  # Drop images here
EMBEDDINGS_DIR = BASE_DIR / "embeddings_tta"    # TTA embeddings stored here
MAPPINGS_DIR = BASE_DIR / "mappings"            # ID mappings stored here
QUERIES_DIR = BASE_DIR / "queries"              # Query images

# Create directories if they don't exist.
# parents=True also creates BASE_DIR (and any missing ancestors); without it
# mkdir raises FileNotFoundError on a fresh workspace where BASE_DIR is absent.
SOURCE_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
MAPPINGS_DIR.mkdir(parents=True, exist_ok=True)
QUERIES_DIR.mkdir(parents=True, exist_ok=True)

print(f"📁 Source images directory: {SOURCE_IMAGES_DIR}")
print(f"📁 Embeddings directory: {EMBEDDINGS_DIR}")
print(f"📁 Mappings directory: {MAPPINGS_DIR}")
print(f"📁 Queries directory: {QUERIES_DIR}")

## 2. Load DINOv3 Model

In [3]:
# Checkpoint identifier passed to both the image processor and the model loader.
MODEL_NAME = "facebook/dinov3-vitl16-pretrain-lvd1689m"

print("Loading DINOv3 model...")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# device_map="auto" leaves weight placement to the loader; eval() switches the
# module to inference mode and returns the same module.
model = AutoModel.from_pretrained(MODEL_NAME, device_map="auto").eval()

# Reported device is derived from CUDA availability, not read back from the model.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"✅ Model loaded: {MODEL_NAME}")
print(f"🖥️ Using device: {device}")

Loading DINOv3 model...


Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 567.10it/s]


✅ Model loaded: facebook/dinov3-vitl16-pretrain-lvd1689m
🖥️ Using device: cuda


## 3. TTA Feature Extraction Functions

In [None]:
def extract_dinov3_embedding_single(image, model, processor):
    """
    Extract the DINOv3 CLS token embedding from a PIL Image.

    The image is handed to ``processor`` with resizing and center-cropping
    disabled, so it reaches the model at the resolution the caller supplies.

    Args:
        image: Image to embed (callers in this notebook pass RGB PIL images).
        model: Callable model whose output exposes ``last_hidden_state``.
        processor: Callable returning a mapping of tensors accepted by ``model``.

    Returns:
        CPU tensor: the token at index 0 of ``last_hidden_state`` with
        singleton dimensions squeezed out.
    """
    inputs = processor(images=image, return_tensors="pt", do_resize=False, do_center_crop=False)

    # Move inputs to wherever the model's weights actually live rather than
    # assuming CUDA: a CPU-resident model on a CUDA-capable machine would
    # otherwise hit a device mismatch.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = None
    if target_device is None or target_device.type == "meta":
        # No usable parameter device (parameter-less model, or weights that are
        # offloaded and shown as "meta"): fall back to the availability check.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(target_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Index 0 along the token axis, then drop the singleton batch dimension.
    cls_token = outputs.last_hidden_state[:, 0, :].squeeze()
    return cls_token.cpu()


def extract_dinov3_embedding_tta(image_path, model, processor,
                                 rotations=(0, 90, 180, 270)):
    """
    Extract DINOv3 embedding with Test-Time Augmentation (rotation).

    The image is embedded once per angle in ``rotations`` and the resulting
    embeddings are averaged element-wise.

    Args:
        image_path: Path of the image file to embed.
        model: Model forwarded to ``extract_dinov3_embedding_single``.
        processor: Processor forwarded to ``extract_dinov3_embedding_single``.
        rotations: Iterable of rotation angles in degrees. An angle of 0 uses
            the image unchanged.

    Returns:
        Tensor: mean of the per-rotation embeddings.

    Raises:
        ValueError: If ``rotations`` is empty.
    """
    # Materialise once: accepts any iterable and avoids a shared mutable default.
    rotations = tuple(rotations)
    if not rotations:
        raise ValueError("rotations must contain at least one angle")

    # convert() yields an in-memory image, so the file can be closed right away.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    embeddings = []
    for angle in rotations:
        # expand=True grows the canvas so the rotated image is not cropped.
        rotated_img = image if angle == 0 else image.rotate(-angle, expand=True)
        embeddings.append(extract_dinov3_embedding_single(rotated_img, model, processor))

    # Average embeddings across all rotations
    return torch.stack(embeddings).mean(dim=0)

## 4. Database Management System

In [None]:
class TextileDatabase:
    """
    File-backed index of reference images and their TTA embeddings.

    Images found in ``source_dir`` are assigned sequential IDs of the form
    ``textile_NNN``. The assignment is persisted in
    ``<mappings_dir>/image_mapping.json`` and each embedding is stored as
    ``<embeddings_dir>/<textile_id>.pkl``.

    An image is identified by its path relative to ``source_dir``; file
    contents are not compared, so overwriting a file under the same name keeps
    its existing ID and embedding.
    """

    # Lower-case suffixes treated as images when scanning the source directory.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff'})

    def __init__(self, source_dir, embeddings_dir, mappings_dir, model, processor):
        """
        Args:
            source_dir: Directory holding the reference images.
            embeddings_dir: Directory where embedding pickles are stored.
            mappings_dir: Directory holding ``image_mapping.json``.
            model: Model passed to ``extract_dinov3_embedding_tta``.
            processor: Processor passed to ``extract_dinov3_embedding_tta``.
        """
        self.source_dir = Path(source_dir)
        self.embeddings_dir = Path(embeddings_dir)
        self.mappings_dir = Path(mappings_dir)
        self.mapping_file = self.mappings_dir / "image_mapping.json"
        self.model = model
        self.processor = processor

        # Load or initialize mapping
        self.mapping = self.load_mapping()

    def load_mapping(self):
        """Load the existing mapping, or return a fresh empty one if no file exists."""
        if self.mapping_file.exists():
            with open(self.mapping_file, 'r') as f:
                return json.load(f)
        return {"images": {}, "next_id": 1}

    def save_mapping(self):
        """Save mapping to JSON file.

        The data is written to a sibling temporary file and then moved over the
        real file, so an interrupted write does not leave a truncated mapping.
        """
        self.mappings_dir.mkdir(parents=True, exist_ok=True)
        tmp_path = self.mapping_file.with_name(self.mapping_file.name + ".tmp")
        with open(tmp_path, 'w') as f:
            json.dump(self.mapping, f, indent=2)
        tmp_path.replace(self.mapping_file)

    def get_textile_id(self, image_path, save=True):
        """Get or assign textile ID for an image.

        Args:
            image_path: Path of an image located under ``source_dir``.
            save: When a new ID is assigned, persist the mapping immediately.
                Bulk callers pass ``False`` and call ``save_mapping`` once.

        Returns:
            The textile ID string for the image.

        Raises:
            ValueError: If ``image_path`` is not located under ``source_dir``
                (raised by ``Path.relative_to``).
        """
        image_path = Path(image_path)
        # Use relative path as key
        rel_path = str(image_path.relative_to(self.source_dir))

        entry = self.mapping["images"].get(rel_path)
        if entry is not None:
            return entry["id"]

        # Assign new ID
        textile_id = f"textile_{self.mapping['next_id']:03d}"
        self.mapping["images"][rel_path] = {
            "id": textile_id,
            "original_name": image_path.name,
            "added_date": datetime.now().isoformat(),
            "file_size": image_path.stat().st_size
        }
        self.mapping["next_id"] += 1
        if save:
            self.save_mapping()
        return textile_id

    def scan_and_update(self):
        """Scan source directory and update database.

        Only files directly inside ``source_dir`` are considered; the suffix
        comparison is case-insensitive, so ``.JPG`` and ``.Jpg`` both match.

        Returns:
            Tuple ``(new_images, existing_images)``; each is a list of
            ``(image_path, textile_id)`` pairs, split by whether the embedding
            file for that ID already exists.
        """
        if self.source_dir.is_dir():
            all_images = sorted(
                path for path in self.source_dir.iterdir()
                if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
            )
        else:
            # A missing source directory is treated as an empty one.
            all_images = []

        print(f"\n📷 Found {len(all_images)} images in source directory")

        # Track new and existing
        new_images = []
        existing_images = []
        next_id_before = self.mapping["next_id"]

        for img_path in all_images:
            # Defer persistence: one write after the loop instead of one per image.
            textile_id = self.get_textile_id(img_path, save=False)
            embedding_path = self.embeddings_dir / f"{textile_id}.pkl"

            if embedding_path.exists():
                existing_images.append((img_path, textile_id))
            else:
                new_images.append((img_path, textile_id))

        if self.mapping["next_id"] != next_id_before:
            self.save_mapping()

        print(f"✅ {len(existing_images)} images already have embeddings")
        print(f"🆕 {len(new_images)} new images need embeddings")

        return new_images, existing_images

    def generate_embeddings_for_new_images(self, new_images):
        """Generate TTA embeddings only for new images.

        Args:
            new_images: List of ``(image_path, textile_id)`` pairs, as returned
                by ``scan_and_update``.
        """
        if not new_images:
            print("No new images to process.")
            return

        print(f"\n🔄 Generating TTA embeddings for {len(new_images)} new images...")
        start_time = time.time()
        self.embeddings_dir.mkdir(parents=True, exist_ok=True)

        for img_path, textile_id in tqdm(new_images, desc="Processing"):
            # Generate TTA embedding
            embedding = extract_dinov3_embedding_tta(img_path, self.model, self.processor)

            # Write to a temporary name first: a partially written "<id>.pkl"
            # would otherwise be treated as an existing embedding on the next
            # scan. The ".tmp" suffix keeps it out of the "*.pkl" glob.
            embedding_path = self.embeddings_dir / f"{textile_id}.pkl"
            tmp_path = embedding_path.with_name(embedding_path.name + ".tmp")
            with open(tmp_path, 'wb') as f:
                pickle.dump(embedding, f)
            tmp_path.replace(embedding_path)

            print(f"  ✓ {textile_id} <- {Path(img_path).name}")

        elapsed = time.time() - start_time
        print(f"\n✅ Completed in {elapsed:.1f}s ({elapsed/len(new_images):.2f}s per image)")

    def load_all_embeddings(self):
        """Load all embeddings from disk.

        Returns:
            Dict mapping textile ID to a dict with keys ``'embedding'``,
            ``'original_name'`` and ``'rel_path'``. Embedding files whose ID
            has no entry in the mapping are skipped without being unpickled.
        """
        # Create reverse mapping from ID to original name
        id_to_info = {
            info["id"]: {
                "original_name": info["original_name"],
                "rel_path": rel_path
            }
            for rel_path, info in self.mapping["images"].items()
        }

        # Load embeddings
        embedding_files = sorted(self.embeddings_dir.glob("*.pkl"))
        print(f"\n📦 Loading {len(embedding_files)} embeddings...")

        embeddings = {}
        unmapped_count = 0
        for emb_path in tqdm(embedding_files):
            textile_id = emb_path.stem
            info = id_to_info.get(textile_id)
            if info is None:
                unmapped_count += 1
                continue

            # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
            # embeddings_dir at files produced by this notebook.
            with open(emb_path, 'rb') as f:
                embedding = pickle.load(f)

            # Store with both ID and original name for reference
            embeddings[textile_id] = {
                'embedding': embedding,
                'original_name': info['original_name'],
                'rel_path': info['rel_path']
            }

        if unmapped_count:
            print(f"⚠️  Skipped {unmapped_count} embedding file(s) with no mapping entry")
        print(f"✅ Loaded {len(embeddings)} embeddings")
        return embeddings

    def update_database(self):
        """Main function to update the database.

        Scans the source directory, generates embeddings for images that lack
        one, then loads every stored embedding.

        Returns:
            The dict produced by ``load_all_embeddings``.
        """
        print("="*60)
        print("DATABASE UPDATE")
        print("="*60)

        # Scan for new images
        new_images, existing_images = self.scan_and_update()

        # Generate embeddings for new images
        if new_images:
            self.generate_embeddings_for_new_images(new_images)

        # Load all embeddings
        embeddings = self.load_all_embeddings()

        return embeddings

## 5. Initialize Database and Process Images

In [None]:
# Build the database object from the directories and model configured above.
db = TextileDatabase(
    SOURCE_IMAGES_DIR,
    EMBEDDINGS_DIR,
    MAPPINGS_DIR,
    model,
    processor,
)

# Scan the source folder, embed any images that lack an embedding, and keep the
# loaded embeddings for the search cells below.
database_embeddings = db.update_database()

## 6. Search Functions

In [None]:
def compute_similarities(query_embedding, database_embeddings):
    """
    Rank database entries by cosine similarity to a query embedding.

    Args:
        query_embedding: 1-D tensor for the query image.
        database_embeddings: Dict mapping textile ID to a dict with keys
            'embedding', 'original_name' and 'rel_path'.

    Returns:
        List of dicts with keys 'textile_id', 'original_name', 'similarity'
        and 'rel_path', ordered from most to least similar.
    """
    def _as_unit_row(vector):
        # Shape (1, D), L2-normalised along the feature axis.
        return F.normalize(vector.unsqueeze(0), p=2, dim=-1)

    query_unit = _as_unit_row(query_embedding)

    scored = [
        {
            'textile_id': textile_id,
            'original_name': entry['original_name'],
            # Dot product of two unit vectors == cosine similarity.
            'similarity': torch.matmul(query_unit, _as_unit_row(entry['embedding']).T).item(),
            'rel_path': entry['rel_path'],
        }
        for textile_id, entry in database_embeddings.items()
    ]

    # Highest similarity first; ties keep database iteration order.
    return sorted(scored, key=lambda row: row['similarity'], reverse=True)


def search_textile(query_path, database_embeddings, model, processor, top_k=5):
    """
    Find the database entries most similar to a query image.

    Args:
        query_path: Path of the query image.
        database_embeddings: Dict of stored embeddings, as accepted by
            ``compute_similarities``.
        model: Model used to embed the query.
        processor: Processor used to embed the query.
        top_k: Number of leading results to return.

    Returns:
        The first ``top_k`` entries of the ranked similarity list.
    """
    print(f"\n🔍 Searching for: {query_path}")
    print("Extracting query embedding with TTA...")

    # Embed the query with the same TTA routine, then rank the whole database.
    ranked = compute_similarities(
        extract_dinov3_embedding_tta(query_path, model, processor),
        database_embeddings,
    )
    return ranked[:top_k]

## 7. Run Query Search

In [None]:
# Query image to search for; point this at any image file.
QUERY_IMAGE = QUERIES_DIR / "query_image.jpeg"  # Change to "query_2.png" for second query

# Alternative: directly specify path
# QUERY_IMAGE = Path("/workspace/query_2.png")

if QUERY_IMAGE.exists():
    print("="*60)
    print("TEXTILE SEARCH")
    print("="*60)

    # Retrieve the ten best matches for the query.
    matches = search_textile(QUERY_IMAGE, database_embeddings, model, processor, top_k=10)

    # One line per match: rank, ID, similarity, original filename.
    print("\n📊 Top 10 Matches:")
    print("-"*50)
    for i, match in enumerate(matches, start=1):
        print(f"{i:2d}. {match['textile_id']:12s} | Sim: {match['similarity']:.4f} | {match['original_name']}")
else:
    # Nothing to search with: show what is available instead.
    print(f"❌ Query image not found: {QUERY_IMAGE}")
    print(f"Available queries: {list(QUERIES_DIR.glob('*'))}")

## 8. Visualize Results

In [None]:
def visualize_search_results(query_path, matches, source_dir, top_k=5):
    """
    Plot the query image next to its best matches.

    The figure has two rows and ``top_k + 1`` columns: the query spans both
    rows of the first column; each further column shows a match image on top
    and its (truncated) original filename below. Matches whose image file no
    longer exists leave their column empty.

    Args:
        query_path: Path of the query image.
        matches: Ranked match dicts with 'rel_path', 'textile_id',
            'similarity' and 'original_name'.
        source_dir: Path object of the directory that 'rel_path' is relative to.
        top_k: Number of matches to display.

    Returns:
        The matplotlib figure (after it has been shown).
    """
    fig = plt.figure(figsize=(20, 8))
    grid = gridspec.GridSpec(2, top_k+1, figure=fig, hspace=0.3, wspace=0.2)

    # Query image: first column, both rows.
    query_axis = fig.add_subplot(grid[:, 0])
    query_axis.imshow(Image.open(query_path))
    query_axis.set_title(f"Query Image\n{Path(query_path).name}",
                         fontsize=12, fontweight='bold')
    query_axis.axis('off')

    # One column per match, starting right of the query.
    for column, match in enumerate(matches[:top_k], start=1):
        match_file = source_dir / match['rel_path']
        if not match_file.exists():
            continue

        picture = Image.open(match_file)

        # Top row - image
        image_axis = fig.add_subplot(grid[0, column])
        image_axis.imshow(picture)
        image_axis.set_title(f"#{column} {match['textile_id']}\nSim: {match['similarity']:.3f}",
                             fontsize=10)
        image_axis.axis('off')

        # Bottom row - filename (first 20 characters)
        label_axis = fig.add_subplot(grid[1, column])
        label_axis.axis('off')
        label_axis.text(0.5, 0.5, match['original_name'][:20],
                        ha='center', va='center', fontsize=9,
                        wrap=True, rotation=0)

    plt.suptitle("Textile Search Results (TTA Embeddings)",
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    return fig

# Plot only when the search cell ran and produced at least one match.
if locals().get('matches'):
    fig = visualize_search_results(QUERY_IMAGE, matches, SOURCE_IMAGES_DIR, top_k=5)

## 9. Database Statistics

In [None]:
# Print a summary of the loaded database and its ID mapping.
print("="*60)
print("DATABASE STATISTICS")
print("="*60)

# Length of the first stored embedding, or a placeholder for an empty database.
if database_embeddings:
    embedding_dim = next(iter(database_embeddings.values()))['embedding'].shape[0]
else:
    embedding_dim = 'N/A'

print("\n📊 Summary:")
print(f"  • Total images in database: {len(database_embeddings)}")
print(f"  • Embeddings directory: {EMBEDDINGS_DIR}")
print(f"  • Source images directory: {SOURCE_IMAGES_DIR}")
print("  • TTA rotations used: [0°, 90°, 180°, 270°]")
print(f"  • Embedding dimension: {embedding_dim}")

# Mapping details
print("\n📋 Mapping Statistics:")
print(f"  • Next available ID: textile_{db.mapping['next_id']:03d}")
print(f"  • Mapping file: {db.mapping_file}")

# Five most recent additions, newest first (ISO timestamps sort chronologically).
if db.mapping["images"]:
    print("\n🆕 Recently Added Images:")
    recent = sorted(
        ((info["added_date"], rel_path, info["id"])
         for rel_path, info in db.mapping["images"].items()),
        reverse=True,
    )[:5]

    for date, rel_path, textile_id in recent:
        print(f"  • {textile_id}: {Path(rel_path).name} (added {date[:10]})")

print("\n" + "="*60)

## 10. Quick Add New Image Function

In [None]:
def add_new_image_to_database(image_path, db):
    """
    Convenience function to add a single new image to the database.

    The image is copied into ``db.source_dir`` under its own filename (unless a
    file with that name is already there) and ``db.update_database()`` is run.

    Args:
        image_path: Path of the image to add.
        db: Database object exposing ``source_dir`` and ``update_database()``.

    Returns:
        The value returned by ``db.update_database()``, or ``None`` when
        ``image_path`` does not exist.
    """
    image_path = Path(image_path)

    if not image_path.exists():
        print(f"❌ Image not found: {image_path}")
        return None

    # Copy to source directory if not already there
    dest_path = db.source_dir / image_path.name
    if not dest_path.exists():
        shutil.copy2(image_path, dest_path)
        print(f"✅ Copied {image_path.name} to source directory")
    elif not dest_path.samefile(image_path):
        # A different file already occupies this name. It is left untouched, but
        # the caller is told so the image is not assumed to have been added.
        print(f"⚠️  {image_path.name} already exists in source directory; keeping existing file")

    # Update database
    print("Updating database...")
    updated_embeddings = db.update_database()

    return updated_embeddings

# Example: register /workspace/new_dataset_case.jpeg when that file is present;
# otherwise print a reminder of how to add images manually.
new_image_path = Path("/workspace/new_dataset_case.jpeg")
if not new_image_path.exists():
    print(f"\n💡 To add a new image, place it in: {SOURCE_IMAGES_DIR}")
    print("   Then run: database_embeddings = db.update_database()")
else:
    print(f"\n🆕 Adding new image: {new_image_path}")
    database_embeddings = add_new_image_to_database(new_image_path, db)

## 11. Test with Multiple Queries

In [None]:
# Function to test multiple queries
def test_multiple_queries(query_paths, database_embeddings, model, processor, top_k=3):
    """
    Run a search for several query images and print the best matches of each.

    Args:
        query_paths: Iterable of query image paths; paths that do not exist are
            reported and skipped.
        database_embeddings: Dict of stored embeddings passed to ``search_textile``.
        model: Model passed to ``search_textile``.
        processor: Processor passed to ``search_textile``.
        top_k: Number of matches to retrieve and print per query (default 3).

    Returns:
        Dict mapping each existing query's filename to its list of matches.
        Results are keyed by filename only, so two queries sharing a filename
        in different directories overwrite each other.
    """
    results = {}

    for query_path in query_paths:
        query_path = Path(query_path)
        if not query_path.exists():
            print(f"❌ Query not found: {query_path}")
            continue

        print(f"\n🔍 Testing query: {query_path.name}")
        matches = search_textile(query_path, database_embeddings, model, processor, top_k=top_k)

        print(f"Top {top_k} matches:")
        for i, match in enumerate(matches, 1):
            print(f"  {i}. {match['textile_id']}: {match['similarity']:.4f} ({match['original_name']})")

        results[query_path.name] = matches

    return results

# Candidate query images; extend this list with further paths as needed.
test_queries = [
    QUERIES_DIR / name
    for name in ("query_image.jpeg", "query_2.png")
]

# Keep only the candidates that are present on disk.
existing_queries = list(filter(lambda q: q.exists(), test_queries))

if not existing_queries:
    print("No test queries found. Add images to:", QUERIES_DIR)
else:
    print("="*60)
    print("MULTIPLE QUERY TEST")
    print("="*60)
    test_results = test_multiple_queries(existing_queries, database_embeddings, model, processor)

## 12. Instructions for Use

### How to use this notebook:

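1. **Add reference images**: Drop any textile images into `/workspace/textile_matching/source_images/`
   - Supported extensions: `.jpg`, `.jpeg`, `.png`, `.bmp`, `.tiff` (only files directly in this folder are scanned, not subfolders)
   - Any filename is fine - the system assigns IDs automatically
   - Images are never renamed, just tracked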

2. **Run the notebook**: It will:
   - Detect new images automatically
   - Generate TTA embeddings only for new images
   - Load existing embeddings from cache

3. **Search with a query**: 
   - Change `QUERY_IMAGE` in section 7 to your query path
   - Or use the `search_textile()` function directly

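4. **Add new images anytime**:
   - Drop new images in source folder
   - Run `database_embeddings = db.update_database()`
   - Only new images are processed
   - An image counts as "new" when its filename has no stored embedding yet; overwriting an existing file under the same name does not regenerate its embedding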

### Features:
- ✅ Incremental updates (only process new images)
- ✅ TTA for robust matching
- ✅ Persistent ID mapping
- ✅ Original filenames preserved
- ✅ Fast loading from cache