In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications import resnet
import matplotlib.pyplot as plt
from pathlib import Path
import os

# Notebook banner: the title framed above and below by a 60-character rule.
print("=" * 60, "OUTFIT MATCHER - Re-PolyVore Dataset", "=" * 60, sep="\n")

# Try loading the model with error handling
def load_embedding_model():
    """Load the image-embedding model, trying several strategies in order.

    Each strategy is attempted only if the previous one raised:

    1. Load 'siameese_embedding.h5' as a complete model.
    2. Load 'siamese_model.h5' (which needs the custom ``DistanceLayer``)
       and take the layer at index 3 as the embedding sub-model.
    3. Load 'siameese_embedding.h5' again through ``tf.keras.models``.
    4. Rebuild a ResNet50-based architecture and load the weights stored
       in 'siameese_embedding.h5' into it.

    Progress and the reason for every failed attempt are printed.

    Returns:
        The first model object that loads successfully.

    Raises:
        RuntimeError: if all four attempts fail.
    """

    # Option 1: Try loading siameese_embedding.h5
    try:
        print("\nAttempt 1: Loading 'siameese_embedding.h5'...")
        model = load_model('siameese_embedding.h5', compile=False)
        print(" Successfully loaded siameese_embedding.h5")
        return model
    except Exception as e:
        print(f" Failed: {e}")

    # Option 2: Try loading siamese_model.h5 and extract embedding
    try:
        print("\nAttempt 2: Loading 'siamese_model.h5' and extracting embedding...")
        full_model = load_model('siamese_model.h5', compile=False,
                               custom_objects={'DistanceLayer': DistanceLayer})

        # The embedding sub-model is taken by position in the siamese network.
        embedding_model = full_model.layers[3]  # Adjust index if needed
        print(" Successfully extracted embedding from siamese_model.h5")
        return embedding_model
    except Exception as e:
        # The f-prefix is required here; without it the literal text "{e}"
        # is printed and the actual failure reason is lost.
        print(f" Failed: {e}")

    # Option 3: Try loading with keras
    # NOTE(review): `load_model` used in attempt 1 is imported from
    # tensorflow.keras.models, so this looks like the same call repeated —
    # confirm whether this attempt is still needed.
    try:
        print("\nAttempt 3: Loading with tf.keras...")
        model = tf.keras.models.load_model('siameese_embedding.h5', compile=False)
        print(" Successfully loaded with tf.keras")
        return model
    except Exception as e:
        print(f" Failed: {e}")

    # Option 4: Rebuild the model architecture
    try:
        print("\nAttempt 4: Rebuilding model architecture and loading weights...")
        from tensorflow.keras import layers, Model

        # Rebuild embedding model architecture (same as training)
        base_cnn = resnet.ResNet50(
            weights="imagenet",
            input_shape=(200, 200, 3),
            include_top=False
        )

        flatten = layers.Flatten()(base_cnn.output)
        dense1 = layers.Dense(512, activation="relu")(flatten)
        dense1 = layers.BatchNormalization()(dense1)
        dense2 = layers.Dense(256, activation="relu")(dense1)
        dense2 = layers.BatchNormalization()(dense2)
        output = layers.Dense(256)(dense2)

        embedding = Model(base_cnn.input, output, name="Embedding")

        # Try to load weights
        embedding.load_weights('siameese_embedding.h5')
        print("Successfully rebuilt model and loaded weights")
        return embedding
    except Exception as e:
        print(f" Failed: {e}")

    # RuntimeError is a subclass of Exception, so existing
    # `except Exception` handlers keep working.
    raise RuntimeError(" All loading attempts failed. Please check your model files.")

# Custom layer needed for siamese_model.h5
class DistanceLayer(tf.keras.layers.Layer):
    """Keras layer returning anchor-positive and anchor-negative distances.

    Each distance is the sum over the last axis of the squared element-wise
    difference between the anchor and the other input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, anchor, positive, negative):
        """Return the tuple ``(anchor-positive distance, anchor-negative distance)``."""

        def squared_distance_to_anchor(other):
            # Sum of squared differences along the last axis.
            return tf.reduce_sum(tf.square(anchor - other), -1)

        return (
            squared_distance_to_anchor(positive),
            squared_distance_to_anchor(negative),
        )

# Load the model
# Runs when the cell executes; load_embedding_model() raises if every
# loading attempt fails, which stops the cell here.
embedding_model = load_embedding_model()

# Size passed to tf.image.resize in preprocess_image; same 200x200 spatial
# size as the input_shape used when load_embedding_model rebuilds ResNet50.
target_shape = (200, 200)

def preprocess_image(filename):
    """Read an image file and return it as a resized float32 tensor.

    The file contents are decoded with ``tf.image.decode_jpeg`` into three
    channels, converted to ``tf.float32`` via ``tf.image.convert_image_dtype``
    and resized to the module-level ``target_shape``.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, target_shape)

def get_embedding(image_path):
    """Return the embedding vector the model produces for one image file.

    The image is loaded with ``preprocess_image``, given a leading batch
    axis, passed through ``resnet.preprocess_input`` and then through
    ``embedding_model.predict``; the single row of the prediction is returned.
    """
    pixels = preprocess_image(image_path)

    # Add the batch axis and materialise a fresh float32 array, so the ResNet
    # preprocessing step works on its own copy of the data.
    batch = np.array(np.expand_dims(pixels, axis=0), dtype=np.float32, copy=True)

    # NOTE(review): preprocess_image already converts pixels with
    # convert_image_dtype before this ResNet preprocessing — confirm this
    # matches the pipeline the model was trained with.
    model_input = resnet.preprocess_input(batch)

    return embedding_model.predict(model_input, verbose=0)[0]

def compute_cosine_similarity(emb1, emb2):
    """Compute the cosine similarity between two embedding vectors.

    Args:
        emb1: First embedding (array-like).
        emb2: Second embedding (array-like), same length as ``emb1``.

    Returns:
        ``dot(emb1, emb2) / (||emb1|| * ||emb2||)``. When either vector has
        zero norm the similarity is undefined; 0.0 is returned instead of
        NaN so callers can still sort results by similarity.
    """
    denominator = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    if denominator == 0:
        # Avoid a 0/0 division (NaN plus a RuntimeWarning) for all-zero embeddings.
        return 0.0
    return np.dot(emb1, emb2) / denominator

def load_wardrobe(base_path, categories, limit_per_category=None):
    """Collect image file paths for each requested category folder.

    Args:
        base_path: Directory that contains one sub-folder per category.
        categories: Iterable of category (sub-folder) names to load.
        limit_per_category: Optional maximum number of files kept per
            category. ``None`` (or any falsy value) keeps every file.

    Returns:
        Dict mapping each category that exists under ``base_path`` to a
        sorted list of path strings for its ``*.jpg``, ``*.jpeg`` and
        ``*.png`` files. Categories whose folder is missing are reported
        and left out.
    """
    wardrobe = {}

    print(f"\nLoading wardrobe from '{base_path}'...")

    for category in categories:
        category_path = os.path.join(base_path, category)

        if not os.path.exists(category_path):
            print(f"⚠ Category '{category}' not found at {category_path}")
            continue

        # Get all image files
        image_files = []
        for ext in ['*.jpg', '*.jpeg', '*.png']:
            image_files.extend(Path(category_path).glob(ext))

        # Glob results come back in filesystem order; sort so that the
        # truncation below (and any "first item" picked by callers) is the
        # same on every run and machine.
        image_files.sort()

        # Limit if specified
        if limit_per_category and len(image_files) > limit_per_category:
            image_files = image_files[:limit_per_category]

        wardrobe[category] = [str(f) for f in image_files]
        print(f" Loaded {len(wardrobe[category])} items from '{category}'")

    return wardrobe

def find_best_match(query_path, category_items, top_k=5):
    """Find the ``top_k`` items most similar to the query image.

    Args:
        query_path: Path of the query image.
        category_items: Sequence of candidate image paths.
        top_k: Maximum number of matches to return.

    Returns:
        List of ``{'path': ..., 'similarity': ...}`` dicts sorted by
        descending similarity, at most ``top_k`` long. Candidates whose
        embedding or similarity computation raises are skipped; the number
        skipped and the first error are printed so failures are visible.

    Raises:
        Whatever ``get_embedding`` raises for the query image itself.
    """
    query_emb = get_embedding(query_path)

    results = []
    total = len(category_items)
    skipped = 0
    first_failure = None

    print(f"Processing {total} items...")
    for idx, item_path in enumerate(category_items):
        if (idx + 1) % 500 == 0:
            print(f"  Processed {idx + 1}/{total} items...")

        try:
            item_emb = get_embedding(item_path)
            similarity = compute_cosine_similarity(query_emb, item_emb)
        except Exception as e:
            # Best effort: one unreadable item must not abort the search,
            # but remember it so the problem is reported below.
            skipped += 1
            if first_failure is None:
                first_failure = (item_path, e)
            continue

        results.append({'path': item_path, 'similarity': similarity})

    if skipped:
        failed_path, failed_error = first_failure
        print(f"⚠ Skipped {skipped}/{total} items that could not be processed "
              f"(first: {failed_path}: {failed_error})")

    # Sort by similarity
    results.sort(key=lambda x: x['similarity'], reverse=True)
    return results[:top_k]

def visualize_matches(query_path, matches, title="Outfit Matches"):
    """Show the query image followed by one panel per match in a single row.

    Each match panel is titled with its rank and its similarity formatted
    to three decimals; the figure gets ``title`` as its super-title.
    """
    panel_count = len(matches) + 1
    # squeeze=False always yields a 2-D axes array, so the "no matches"
    # case needs no special handling.
    fig, axes = plt.subplots(1, panel_count, figsize=(3 * panel_count, 3), squeeze=False)
    row = axes[0]

    # First panel: the query image.
    query_ax = row[0]
    query_ax.imshow(preprocess_image(query_path))
    query_ax.set_title("Your Item", fontweight='bold')
    query_ax.axis('off')

    # Remaining panels: one per match, in the order given.
    for rank, (ax, match) in enumerate(zip(row[1:], matches), start=1):
        ax.imshow(preprocess_image(match['path']))
        ax.set_title(f"Match {rank}\n{match['similarity']:.3f}")
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

def _select_categories(available_categories):
    """Ask which categories to load and return those that exist in the dataset."""
    print("\nChoose wardrobe loading option:")
    print("1. Load specific categories")
    print("2. Load all categories")
    print("3. Load common categories (top, pants, shoes, bag, dress)")

    choice = input("\nEnter your choice (1/2/3): ").strip()

    if choice == '1':
        cat_input = input("Enter category names (comma-separated, e.g., top,pants,shoes): ")
        requested = [c.strip() for c in cat_input.split(',')]
    elif choice == '2':
        requested = available_categories
    else:
        requested = ['top', 'pants', 'shoes', 'bag', 'dress']

    # Filter valid categories
    return [c for c in requested if c in available_categories]


def _ask_limit_per_category(base_path, selected_categories):
    """Return a per-category item cap chosen by the user, or None for no cap.

    The question is only asked when the selected categories hold more than
    5000 images in total.
    """
    # Count the same file types that load_wardrobe loads, not only *.jpg.
    patterns = ('*.jpg', '*.jpeg', '*.png')
    total_items = 0
    for cat in selected_categories:
        cat_path = Path(os.path.join(base_path, cat))
        for pattern in patterns:
            total_items += len(list(cat_path.glob(pattern)))

    if total_items <= 5000:
        return None

    print(f"\n Large dataset detected ({total_items} total images)")
    print("Options:")
    print("1. Load ALL images (thorough but slower)")
    print("2. Load limited images (faster, e.g., 100, 500, 1000)")

    limit_choice = input("\nLoad all images? (1=all, 2=limited): ").strip()
    if limit_choice == '2':
        raw_limit = input("Enter max items per category (e.g., 500): ").strip()
        try:
            limit = int(raw_limit)
        except ValueError:
            limit = 0
        if limit > 0:
            return limit
        # Non-numeric or non-positive input: fall back to loading everything
        # instead of crashing with ValueError.
        print(f" Invalid limit '{raw_limit}'. Loading ALL images from each category...")
        return None

    print("Loading ALL images from each category...")
    return None


def main():
    """Run the interactive outfit matcher.

    Prompts for the categories to load, an optional per-category cap, a
    query image and a matching mode, then prints and visualises either the
    top 5 matches from one category or the best match from every loaded
    category. Returns ``None``; returns early (after printing the reason)
    when the dataset folder is missing, no valid category was selected, no
    images were found, or the requested target category is not loaded.
    """

    # Get available categories
    base_path = 'Re-PolyVore'

    if not os.path.exists(base_path):
        print(f"\n Error: '{base_path}' folder not found!")
        print("Please make sure Re-PolyVore dataset is in the current directory.")
        return

    available_categories = sorted([d for d in os.listdir(base_path)
                                  if os.path.isdir(os.path.join(base_path, d))])

    print(f"\nAvailable categories:")
    for i, cat in enumerate(available_categories, 1):
        print(f"{i:2d}. {cat}")

    # Category selection
    selected_categories = _select_categories(available_categories)
    if not selected_categories:
        print("\n Error: None of the selected categories exist in the dataset.")
        return

    # Check dataset size
    limit_per_category = _ask_limit_per_category(base_path, selected_categories)

    # Load wardrobe
    wardrobe = load_wardrobe(base_path, selected_categories, limit_per_category)
    print(f"\n✓ Wardrobe loaded with {len(wardrobe)} categories")

    # The first item of the first non-empty category doubles as the sample
    # query; without any item there is nothing to match against.
    sample_path = next((items[0] for items in wardrobe.values() if items), None)
    if sample_path is None:
        print("\n Error: No items were found in the selected categories.")
        return

    # Get query image
    print("\n" + "="*60)
    query_path = input("Enter path to your clothing item (or press Enter for sample): ").strip()

    if not query_path or not os.path.exists(query_path):
        if query_path:
            # Tell the user their path was rejected instead of silently
            # switching to the sample.
            print(f" '{query_path}' not found.")
        query_path = sample_path
        print(f"Using sample: {query_path}")

    # Matching options
    print(f"\nCategories in your wardrobe:")
    for i, (cat, items) in enumerate(wardrobe.items(), 1):
        print(f"{i}. {cat} ({len(items)} items)")

    print("\nMatching Options:")
    print("1. Match with specific category (see top 5 matches)")
    print("2. Get complete outfit (best match from each category)")

    match_choice = input("\nEnter your choice (1 or 2): ").strip()

    if match_choice == '1':
        # Match with specific category
        target_cat = input("Enter category to match with: ").strip()

        if target_cat not in wardrobe:
            print(f" Category '{target_cat}' not loaded in wardrobe")
            return

        print(f"\nSearching through {len(wardrobe[target_cat])} {target_cat} items...")
        matches = find_best_match(query_path, wardrobe[target_cat], top_k=5)

        print(f"\n Found {len(matches)} matches")
        for i, match in enumerate(matches, 1):
            print(f"{i}. {os.path.basename(match['path'])} - Similarity: {match['similarity']:.4f}")

        visualize_matches(query_path, matches, title=f"Best {target_cat} matches")

    else:
        # Complete outfit
        print(f"\nFinding best matches from: {', '.join(wardrobe.keys())}")

        all_matches = {}
        for category, items in wardrobe.items():
            print(f"\nSearching {category}...")
            matches = find_best_match(query_path, items, top_k=1)
            if matches:
                all_matches[category] = matches[0]

        # Display results
        print(f"\n{'='*60}")
        print("COMPLETE OUTFIT")
        print(f"{'='*60}\n")

        for category, match in all_matches.items():
            print(f"{category.upper()}: {os.path.basename(match['path'])}")
            print(f"  Similarity: {match['similarity']:.4f}\n")

        # Visualize each category match
        for category, match in all_matches.items():
            visualize_matches(query_path, [match], title=f"Best {category} match")

# Start the interactive matcher only when this code runs as the main module.
if __name__ == "__main__":
    main()

OUTFIT MATCHER - Re-PolyVore Dataset

Attempt 1: Loading 'siameese_embedding.h5'...
✓ Successfully loaded siameese_embedding.h5

Available categories:
 1. bag
 2. bracelet
 3. brooch
 4. dress
 5. earrings
 6. eyewear
 7. gloves
 8. hairwear
 9. hats
10. jumpsuit
11. legwear
12. necklace
13. neckwear
14. outwear
15. pants
16. rings
17. shoes
18. skirt
19. top
20. watches

Choose wardrobe loading option:
1. Load specific categories
2. Load all categories
3. Load common categories (top, pants, shoes, bag, dress)


In [None]:
import zipfile
import os

# Archive to unpack and the directory it is unpacked into.
zip_path = '/content/drive/MyDrive/Colab Notebooks/Re-PolyVore.zip'
extract_path = '/content'

# Ensure the destination directory is present before unpacking.
os.makedirs(extract_path, exist_ok=True)

# Unpack every member of the archive into the destination directory.
with zipfile.ZipFile(zip_path) as zip_ref:
    zip_ref.extractall(path=extract_path)

print("Extraction completed!")


Extraction completed!
