# Multimodal Inference with Custom CNN + Text Encoder (Fashion MNIST)

This notebook runs **image-based** and **text-based** search over the Fashion MNIST dataset using your custom-trained dual-encoder model.

## Features:
- Query by **image** → Get similar fashion items
- Query by **text** → Get matching fashion items
- Dataset: Fashion MNIST (10 fashion categories)
- Intel CPU optimized with OpenVINO
- FAISS-based fast similarity search
- Grayscale image support

## 1. Import Libraries

In [None]:
import sys
import numpy as np
import pickle
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torchvision.transforms as transforms
from openvino.runtime import Core
import faiss

# Add datasets path
sys.path.append(r'e:\Projects\AI Based\RecTrio\datasets\fashion_mnist')
from V1.training.custom_cnn.text_descriptions import FASHION_DESCRIPTIONS, CLASSES

print("✓ Libraries imported")
print(f"  Fashion categories: {len(CLASSES)}")

## 2. Configuration

In [None]:
# Paths
MODEL_DIR = Path(r'e:\Projects\AI Based\RecTrio\V1\models\fashion_cnn')
DATASET_PATH = Path(r'e:\Projects\AI Based\RecTrio\datasets\fashion_mnist\processed\train')
VECTOR_DB_DIR = Path(r'e:\Projects\AI Based\RecTrio\V1\models\fashion_cnn\vector_db')

# Create directories if they don't exist
MODEL_DIR.mkdir(parents=True, exist_ok=True)
VECTOR_DB_DIR.mkdir(parents=True, exist_ok=True)

# Model files
IMAGE_ENCODER_PATH = MODEL_DIR / 'image_encoder.xml'
TEXT_ENCODER_PATH = MODEL_DIR / 'text_encoder.xml'
VOCAB_PATH = MODEL_DIR / 'vocabulary.pkl'

# Vector database files
EMBEDDINGS_FILE = VECTOR_DB_DIR / 'embeddings.npy'
METADATA_FILE = VECTOR_DB_DIR / 'metadata.pkl'
FAISS_INDEX_FILE = VECTOR_DB_DIR / 'faiss_index.bin'

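# Configuration from training
IMAGE_SIZE = 224
# ASSUMPTION: placeholder value. MAX_TEXT_LENGTH must match the sequence length
# the text encoder was trained and exported with (compare the input shape printed in section 4).
MAX_TEXT_LENGTH = 32
# Default number of results returned by a search
TOP_K = 10

print("✓ All directories verified/created")
print(f"  Model directory: {MODEL_DIR}")
print(f"  Vector DB directory: {VECTOR_DB_DIR}")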
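
The loading cells assume the exported model and vocabulary files already exist under `MODEL_DIR`. A minimal sketch of an up-front check, assuming a hypothetical helper `find_missing` (not part of the original notebook) that reports absent files before OpenVINO or `pickle` fail with a less direct error:

```python
from pathlib import Path
import tempfile


def find_missing(paths):
    """Return the paths (as Path objects) that do not exist on disk."""
    return [Path(p) for p in paths if not Path(p).exists()]


# Self-contained demo in a temporary directory (a stand-in for MODEL_DIR)
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    (model_dir / "image_encoder.xml").touch()  # pretend this one was exported

    required = [
        model_dir / "image_encoder.xml",
        model_dir / "text_encoder.xml",
        model_dir / "vocabulary.pkl",
    ]
    missing_names = [p.name for p in find_missing(required)]
    print("Missing files:", missing_names)
```

In the notebook this could be called as `find_missing([IMAGE_ENCODER_PATH, TEXT_ENCODER_PATH, VOCAB_PATH])`. An OpenVINO IR `.xml` file is normally accompanied by a `.bin` weights file, so those may be worth checking as well.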

## 3. Load Vocabulary

In [None]:
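print("Loading vocabulary...")
# Unpickling needs the vocabulary class used at training time to be importable.
# Only load pickle files from sources you trust.
with open(VOCAB_PATH, 'rb') as f:
    vocab = pickle.load(f)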

print(f"✓ Vocabulary loaded: {vocab.n_words} words")
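
The pickled object is expected to expose `n_words` and an `encode(text, max_length)` method, but the class itself is not shown in this notebook. A hypothetical sketch of what such a vocabulary might look like, assuming whitespace tokenization and reserved padding/unknown indices (the real training code may differ):

```python
class SimpleVocabulary:
    """Hypothetical stand-in for the pickled vocabulary object."""

    PAD, UNK = 0, 1  # assumed reserved indices

    def __init__(self, sentences):
        self.word2idx = {"<pad>": self.PAD, "<unk>": self.UNK}
        for sentence in sentences:
            for word in sentence.lower().split():
                # Assign the next free index to each unseen word
                self.word2idx.setdefault(word, len(self.word2idx))

    @property
    def n_words(self):
        return len(self.word2idx)

    def encode(self, text, max_length):
        """Map words to indices, then truncate or right-pad to max_length."""
        ids = [self.word2idx.get(w, self.UNK) for w in text.lower().split()]
        ids = ids[:max_length]
        return ids + [self.PAD] * (max_length - len(ids))


demo_vocab = SimpleVocabulary(["a casual t-shirt", "a warm coat"])
encoded = demo_vocab.encode("a warm hat", 5)
print(encoded)  # "hat" is unknown, and the last two slots are padding
```

Whatever the real implementation does, the encoded sequence length has to match what the text encoder expects as input.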

## 4. Load OpenVINO Models (Intel CPU Optimized)

In [None]:
print("Initializing OpenVINO Runtime...")
core = Core()

# Load image encoder
print("Loading image encoder...")
image_model = core.compile_model(str(IMAGE_ENCODER_PATH), "CPU")
image_input = image_model.input(0)
image_output = image_model.output(0)
print(f"  Input shape: {image_input.partial_shape}")
print(f"  Output shape: {image_output.partial_shape}")

# Load text encoder
print("Loading text encoder...")
text_model = core.compile_model(str(TEXT_ENCODER_PATH), "CPU")
text_input = text_model.input(0)
text_output = text_model.output(0)
print(f"  Input shape: {text_input.partial_shape}")
print(f"  Output shape: {text_output.partial_shape}")

print("\n✓ OpenVINO models loaded on Intel CPU")

## 5. Define Image Preprocessing

In [None]:
preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("✓ Image preprocessing defined")
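
`ToTensor()` scales 8-bit pixel values to the range [0, 1], and `Normalize` then applies `(x - mean) / std` per channel. A small worked example of that arithmetic for a single pixel, using the first-channel statistics from the cell above (`normalize_pixel` is a hypothetical helper written only for illustration):

```python
def normalize_pixel(value, mean, std):
    """Mimic ToTensor (divide by 255) followed by Normalize for one pixel."""
    scaled = value / 255.0
    return (scaled - mean) / std


black = normalize_pixel(0, mean=0.485, std=0.229)
white = normalize_pixel(255, mean=0.485, std=0.229)
print(f"black -> {black:.3f}, white -> {white:.3f}")
```

Because the grayscale image is replicated into three identical channels, each channel is shifted and scaled slightly differently by these ImageNet-style statistics. The values should match whatever was used at training time.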

## 6. Embedding Functions

In [None]:
def get_image_embedding(image_path):
    """Generate embedding for a Fashion MNIST image using OpenVINO"""
    # Load and preprocess image (Fashion MNIST is grayscale)
    image = Image.open(image_path).convert('L')  # Convert to grayscale
    image = image.convert('RGB')  # Convert to 3-channel for model compatibility
    image_tensor = preprocess(image).unsqueeze(0)
    
    # Run inference
    result = image_model([image_tensor.numpy()])[image_output]
    embedding = result[0]
    
    # Normalize
    embedding = embedding / np.linalg.norm(embedding)
    return embedding.astype('float32')


def get_text_embedding(text):
    """Generate embedding for text using OpenVINO"""
    # Encode text
    text_indices = vocab.encode(text, MAX_TEXT_LENGTH)
    text_tensor = np.array([text_indices], dtype=np.int64)
    
    # Run inference
    result = text_model([text_tensor])[text_output]
    embedding = result[0]
    
    # Normalize
    embedding = embedding / np.linalg.norm(embedding)
    return embedding.astype('float32')

print("✓ Embedding functions defined")
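
Both functions L2-normalize their output so that the inner product of two embeddings equals their cosine similarity. A dependency-free sketch of that identity on toy vectors (the helper names are illustrative, and the `eps` guard is an addition that the notebook functions do not have):

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit length; eps guards against division by zero."""
    norm = math.sqrt(dot(vec, vec))
    return [x / max(norm, eps) for x in vec]


def cosine(a, b):
    """Cosine similarity computed directly from unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a, b = [3.0, 4.0], [4.0, 3.0]
ip_of_normalized = dot(l2_normalize(a), l2_normalize(b))
print(ip_of_normalized, cosine(a, b))  # the two values agree
```

Without a guard like `eps`, an all-zero encoder output would produce NaN values after division by a zero norm.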

## 7. Build or Load Embeddings Database

In [None]:
if EMBEDDINGS_FILE.exists() and METADATA_FILE.exists():
    print("✓ Loading existing embeddings database...")
    
    embeddings = np.load(EMBEDDINGS_FILE)
    with open(METADATA_FILE, 'rb') as f:
        metadata = pickle.load(f)
    
    image_paths = metadata['image_paths']
    
    print(f"✓ Loaded {len(embeddings)} embeddings")
    print(f"  Embedding shape: {embeddings.shape}")
    
else:
    print("Building embeddings database from scratch...")
    
    # Collect all images
    image_paths = []
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp'}
    
    for class_name in CLASSES:
        class_dir = DATASET_PATH / class_name
        if class_dir.exists():
            for img_path in class_dir.iterdir():
                if img_path.suffix.lower() in valid_extensions:
                    image_paths.append(str(img_path))
    
    print(f"Found {len(image_paths)} images")
    
    # Generate embeddings
    embeddings = []
    valid_paths = []
    
    print("Generating embeddings...")
    for img_path in tqdm(image_paths):
        try:
            embedding = get_image_embedding(img_path)
            embeddings.append(embedding)
            valid_paths.append(img_path)
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
    
    embeddings = np.array(embeddings).astype('float32')
    image_paths = valid_paths
    
    print(f"Generated {len(embeddings)} embeddings")
    
    # Save embeddings
    np.save(EMBEDDINGS_FILE, embeddings)
    print(f"✓ Embeddings saved to {EMBEDDINGS_FILE}")
    
    # Save metadata
    metadata = {
        'image_paths': image_paths,
        'total_images': len(image_paths),
        'embedding_dim': embeddings.shape[1]
    }
    with open(METADATA_FILE, 'wb') as f:
        pickle.dump(metadata, f)
    print(f"✓ Metadata saved to {METADATA_FILE}")
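
The branch above reuses `embeddings.npy` and `metadata.pkl` whenever both files exist, so images added to or removed from `DATASET_PATH` afterwards are not picked up. A minimal sketch of one way to detect that drift, assuming a hypothetical helper `diff_cached_paths` that compares the cached path list with a fresh directory listing:

```python
def diff_cached_paths(cached_paths, current_paths):
    """Return (new_on_disk, missing_from_disk) as sorted lists."""
    cached, current = set(cached_paths), set(current_paths)
    return sorted(current - cached), sorted(cached - current)


# Toy path lists standing in for metadata['image_paths'] and a fresh scan
cached = ["coat/001.png", "coat/002.png", "bag/001.png"]
on_disk = ["coat/001.png", "bag/001.png", "bag/002.png"]
added, removed = diff_cached_paths(cached, on_disk)
print("new on disk:", added, "| missing from disk:", removed)
```

If either list is non-empty, deleting the cached files (including `faiss_index.bin`) forces a rebuild on the next run.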

## 8. Build or Load FAISS Index

In [None]:
if FAISS_INDEX_FILE.exists():
    print("✓ Loading existing FAISS index...")
    index = faiss.read_index(str(FAISS_INDEX_FILE))
    print(f"✓ FAISS index loaded with {index.ntotal} vectors")
else:
    print("Building FAISS index...")
    dimension = embeddings.shape[1]
    # Inner product; equals cosine similarity because all embeddings are L2-normalized
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    
    # Save index
    faiss.write_index(index, str(FAISS_INDEX_FILE))
    print(f"✓ FAISS index built with {index.ntotal} vectors")
    print(f"✓ Index saved to {FAISS_INDEX_FILE}")
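
`IndexFlatIP` is an exact (non-approximate) index: a search scores the query against every stored vector by inner product and returns the highest-scoring entries. A dependency-free sketch of that computation, for intuition only (`brute_force_search` is a hypothetical helper; FAISS performs the same scoring with optimized routines):

```python
def brute_force_search(query, vectors, top_k):
    """Exact top-k search by inner product, highest score first."""
    scores = [sum(q * v for q, v in zip(query, vec)) for vec in vectors]
    order = sorted(range(len(vectors)), key=lambda i: scores[i], reverse=True)
    top = order[:top_k]
    return [scores[i] for i in top], top


database = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]  # toy unit-length "embeddings"
top_scores, top_ids = brute_force_search([0.6, 0.8], database, top_k=2)
print(top_ids, top_scores)
```

The query is identical to the third stored vector, so that entry ranks first with a score of about 1.0.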

## 9. Search Function

In [None]:
def search_similar_images(query_embedding, top_k=TOP_K):
    """Search for similar images using FAISS"""
    query_embedding = query_embedding.reshape(1, -1)
    distances, indices = index.search(query_embedding, top_k)
    
    results = []
    for idx, dist in zip(indices[0], distances[0]):
        if idx < 0:  # FAISS pads with -1 when fewer than top_k vectors are available
            continue
        results.append({
            'path': image_paths[idx],
            'similarity': float(dist),
            'class': Path(image_paths[idx]).parent.name
        })
    
    return results

print("✓ Search function defined")

## 10. Visualization Function

In [None]:
def display_results(results, query_info=None):
    """Display search results in a grid"""
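    n_results = len(results)
    if n_results == 0:
        print("No results to display")
        return
    
    cols = 5
    rows = (n_results + cols - 1) // cols
    
    # squeeze=False always returns a 2D array of axes, so flatten() works for any grid size
    fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
    axes = axes.flatten()
    
    for idx, result in enumerate(results):
        img = Image.open(result['path'])
        axes[idx].imshow(img, cmap='gray')  # Fashion MNIST images are grayscale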
        axes[idx].axis('off')
        title = f"{result['class']}\nSim: {result['similarity']:.3f}"
        axes[idx].set_title(title, fontsize=10)
    
    # Hide extra subplots
    for idx in range(n_results, len(axes)):
        axes[idx].axis('off')
    
    if query_info:
        fig.suptitle(f"Query: {query_info}", fontsize=14, fontweight='bold', y=1.00)
    
    plt.tight_layout()
    plt.show()

print("✓ Visualization function defined")

## 11. Image-Based Search Example

In [None]:
# Example: Search by image from Fashion MNIST
query_image_path = r"e:\Projects\AI Based\RecTrio\datasets\fashion_mnist\processed\train\tshirt\00001.png"

print(f"Searching for fashion items similar to: {Path(query_image_path).name}")

# Get embedding
query_embedding = get_image_embedding(query_image_path)

# Search
results = search_similar_images(query_embedding, top_k=10)

# Display results
print(f"\nTop {len(results)} similar fashion items:")
for i, result in enumerate(results, 1):
    print(f"{i}. [{result['class']}] {Path(result['path']).name} - Similarity: {result['similarity']:.4f}")

display_results(results, query_info=f"Image: {Path(query_image_path).name}")

## 12. Text-Based Search Example

In [None]:
# Example: Search by text
query_text = "a casual t-shirt with short sleeves"

print(f"Searching for: '{query_text}'")

# Get embedding
query_embedding = get_text_embedding(query_text)

# Search
results = search_similar_images(query_embedding, top_k=10)

# Display results
print(f"\nTop {len(results)} matching fashion items:")
for i, result in enumerate(results, 1):
    print(f"{i}. [{result['class']}] {Path(result['path']).name} - Similarity: {result['similarity']:.4f}")

display_results(results, query_info=f"Text: '{query_text}'")

## 13. Try Different Text Queries

In [None]:
# Test with different fashion descriptions
test_queries = [
    "a warm winter coat with long sleeves",
    "comfortable running sneakers",
    "an elegant dress for women",
    "casual trousers for everyday wear",
    "open-toed summer sandals"
]

for query_text in test_queries:
    print(f"\n{'='*60}")
    print(f"Query: '{query_text}'")
    print('='*60)
    
    query_embedding = get_text_embedding(query_text)
    results = search_similar_images(query_embedding, top_k=5)
    
    for i, result in enumerate(results, 1):
        print(f"{i}. [{result['class']}] {Path(result['path']).name} - Sim: {result['similarity']:.4f}")
    
    display_results(results, query_info=f"Text: '{query_text}'")
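
Reading the printed lists by eye gets tedious over several queries. A minimal sketch of a per-query summary, assuming you supply the class you expect for each query yourself (`class_hit_rate` is a hypothetical helper, and the class names depend on the folder names under `DATASET_PATH`):

```python
from collections import Counter


def class_hit_rate(results, expected_class):
    """Fraction of results whose 'class' field equals expected_class."""
    if not results:
        return 0.0
    hits = sum(1 for r in results if r["class"] == expected_class)
    return hits / len(results)


# Toy results in the same shape that search_similar_images returns
demo_results = [
    {"path": "a.png", "similarity": 0.91, "class": "coat"},
    {"path": "b.png", "similarity": 0.88, "class": "coat"},
    {"path": "c.png", "similarity": 0.85, "class": "pullover"},
    {"path": "d.png", "similarity": 0.80, "class": "coat"},
]
rate = class_hit_rate(demo_results, "coat")
top_class, top_count = Counter(r["class"] for r in demo_results).most_common(1)[0]
print(f"hit rate: {rate:.2f}, most common class: {top_class} ({top_count})")
```

This is only a rough sanity check on toy data, not an evaluation of the model.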

## 14. Interactive Search

In [None]:
def interactive_search():
    """Interactive search interface"""
    print("\n" + "="*60)
    print("MULTIMODAL IMAGE SEARCH")
    print("="*60)
    print("Choose search type:")
    print("1. Search by image")
    print("2. Search by text description")
    
    choice = input("\nEnter choice (1 or 2): ").strip()
    
    if choice == "1":
        img_path = input("Enter image path: ").strip()
        if not Path(img_path).exists():
            print("❌ Image not found!")
            return
        
        print(f"\nSearching for images similar to: {Path(img_path).name}")
        query_embedding = get_image_embedding(img_path)
        results = search_similar_images(query_embedding, top_k=10)
        
        print(f"\nTop {len(results)} results:")
        for i, result in enumerate(results, 1):
            print(f"{i}. [{result['class']}] {Path(result['path']).name} - Sim: {result['similarity']:.4f}")
        
        display_results(results, query_info=f"Image: {Path(img_path).name}")
        
    elif choice == "2":
        text_query = input("Enter text description: ").strip()
        if not text_query:
            print("❌ Empty query!")
            return
        
        print(f"\nSearching for: '{text_query}'")
        query_embedding = get_text_embedding(text_query)
        results = search_similar_images(query_embedding, top_k=10)
        
        print(f"\nTop {len(results)} results:")
        for i, result in enumerate(results, 1):
            print(f"{i}. [{result['class']}] {Path(result['path']).name} - Sim: {result['similarity']:.4f}")
        
        display_results(results, query_info=f"Text: '{text_query}'")
    else:
        print("❌ Invalid choice!")

# Run interactive search
interactive_search()

## Summary

### What This Notebook Does:
1. ✅ Loads Intel CPU-optimized OpenVINO models
2. ✅ Generates/loads embeddings for the Fashion MNIST training images under `DATASET_PATH`
3. ✅ Builds FAISS index for fast similarity search
4. ✅ Supports **image query** → get similar fashion items
5. ✅ Supports **text query** → get matching fashion items

### Dataset:
- **Fashion MNIST**: 60,000 training images across 10 fashion categories
- **Categories**: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot
- **Image Format**: Grayscale (converted to 3-channel for model)

### Model Architecture:
- **Image Encoder**: Custom CNN (Intel CPU optimized, grayscale support)
- **Text Encoder**: LSTM-based encoder (Intel CPU optimized)
- **Shared Embedding Space**: 256-dimensional

### Performance:
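- Inference runs on Intel CPUs through the OpenVINO runtime
- Similarity search uses an exact FAISS inner-product index (`IndexFlatIP`)
- Image and text queries share one index because both encoders map into the same embedding space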