<a href="https://colab.research.google.com/github/SumaiyaZohaRODELA/LLava-Faiss/blob/main/llava%26Faiss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary packages
!pip install transformers accelerate torch torchvision faiss-gpu llava

import os
from glob import glob
import numpy as np
import faiss
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import random
from transformers import AutoProcessor, AutoModel

# Load the LLaVA model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "liuhaotian/LLaVA-7b-delta-v0"  # Replace with desired LLaVA model checkpoint
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()

# Folder containing images
image_folder = "/content/image_dataset"

# Get all image paths. glob() returns files in filesystem-dependent order,
# so sort them. Without sorting, the seeded random.sample() below would still
# pick different files on different machines.
image_files = sorted(glob(os.path.join(image_folder, "*.jpg")))

# Randomly select up to 10 images for display. random.sample() raises
# ValueError when asked for more items than exist, so cap the count at the
# number of images actually found.
random.seed(42)
selected_images = random.sample(image_files, min(10, len(image_files)))

# Display the selected images on a 2 x 5 grid
plt.figure(figsize=(20, 10))
for i, img_path in enumerate(selected_images):
    plt.subplot(2, 5, i + 1)
    # Context manager: release the file handle once matplotlib has the pixels.
    with Image.open(img_path) as img:
        plt.imshow(img)
    plt.axis("off")
plt.show()

# Function to generate LLaVA embeddings
def generate_llava_embeddings(images_path, processor, model, device, pattern="*.jpg"):
    """Embed every matching image in a folder with the model's image encoder.

    Args:
        images_path: Folder to scan (non-recursively).
        processor: Callable turning a list of PIL images into model inputs,
            invoked as ``processor(images=[img], return_tensors="pt")``.
        model: Model exposing ``get_image_features(**inputs)``.
        device: Device the processed inputs are moved to.
        pattern: Glob pattern, relative to ``images_path``, selecting the image
            files. Defaults to ``"*.jpg"``.

    Returns:
        ``(embeddings, image_paths)``. ``embeddings`` is a list of 1-D NumPy
        arrays, one per image, holding that image's features flattened.
        ``image_paths`` is the matching list of file paths, in the same order.
        The paths are sorted, so the position of each image (used as its FAISS
        id downstream) is reproducible across runs and machines.
    """
    image_paths = sorted(glob(os.path.join(images_path, pattern)))
    embeddings = []

    with torch.no_grad():
        for img_path in image_paths:
            # Context manager so the underlying file is closed promptly.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
            inputs = processor(images=[img], return_tensors="pt").to(device)
            features = model.get_image_features(**inputs)
            # .float() first: Tensor.numpy() rejects bfloat16, e.g. when the
            # model is loaded with torch_dtype=torch.bfloat16.
            embeddings.append(features.float().cpu().numpy().flatten())

    return embeddings, image_paths

# Embed every image in the dataset folder; image_paths[i] is the file that
# produced embeddings[i].
embeddings, image_paths = generate_llava_embeddings(
    images_path=image_folder,
    processor=processor,
    model=model,
    device=device,
)

# Create FAISS index
def create_faiss_index(embeddings, image_paths, output_path, normalize=True):
    """Build an inner-product FAISS index over ``embeddings`` and save it.

    Vector ``i`` is stored under id ``i``. ``image_paths[i]`` is written on line
    ``i`` of the sidecar file ``output_path + ".paths"``, so search hits can be
    mapped back to files.

    Args:
        embeddings: Sequence of equal-length 1-D vectors (or a 2-D array).
        image_paths: File path for each embedding, in the same order.
        output_path: Destination file for the serialized index.
        normalize: L2-normalize the stored vectors before indexing (default
            True). With unit-length database vectors, the inner-product score
            ranks results by cosine similarity. Raw inner products would
            instead favour whichever vectors have the largest norms.

    Returns:
        The populated ``faiss.IndexIDMap``.

    Raises:
        ValueError: If ``embeddings`` is empty, or if its length differs from
            ``image_paths`` (ids would then map to the wrong files).
    """
    if len(embeddings) == 0:
        raise ValueError("cannot build a FAISS index from zero embeddings")
    if len(embeddings) != len(image_paths):
        raise ValueError(
            f"got {len(embeddings)} embeddings but {len(image_paths)} image paths"
        )

    # np.array always copies, so the in-place normalization below never touches
    # the caller's data. faiss.normalize_L2 needs a C-contiguous float32 array.
    vectors = np.array(embeddings, dtype=np.float32, order="C")
    if normalize:
        faiss.normalize_L2(vectors)

    flat_index = faiss.IndexFlatIP(vectors.shape[1])  # Inner product for similarity
    index = faiss.IndexIDMap(flat_index)

    # Explicit 64-bit ids (FAISS's id type) rather than the platform-default
    # integer dtype, which is 32-bit on some platforms.
    index.add_with_ids(vectors, np.arange(len(vectors), dtype=np.int64))

    # Save the index
    faiss.write_index(index, output_path)
    print(f"Index created and saved to {output_path}")

    # Save image paths, one per line, in id order
    with open(output_path + ".paths", "w", encoding="utf-8") as f:
        f.writelines(f"{img_path}\n" for img_path in image_paths)

    return index

# Build the FAISS index from the embeddings and write it to disk, together
# with the ".paths" sidecar listing each indexed image.
OUTPUT_INDEX_PATH = "/content/vector_llava.index"
index = create_faiss_index(
    embeddings=embeddings,
    image_paths=image_paths,
    output_path=OUTPUT_INDEX_PATH,
)

# Load FAISS index
def load_faiss_index(index_path):
    """Load a FAISS index and the image-path list stored next to it.

    Reads the index with ``faiss.read_index`` and the sidecar file
    ``index_path + ".paths"``. The sidecar holds one path per line; line ``i``
    is the file for id ``i``.

    Args:
        index_path: Path of the serialized index.

    Returns:
        ``(index, image_paths)``.

    Raises:
        ValueError: If the number of stored paths differs from the number of
            vectors in the index (e.g. a stale ``.paths`` file), because search
            hits would then be mapped to the wrong files.
    """
    index = faiss.read_index(index_path)
    with open(index_path + ".paths", "r", encoding="utf-8") as f:
        # Remove only the line terminator. str.strip() would also eat spaces
        # that are legitimately part of a file name.
        image_paths = [line.rstrip("\n") for line in f]
    if len(image_paths) != index.ntotal:
        raise ValueError(
            f"{index_path}.paths lists {len(image_paths)} paths but the index "
            f"holds {index.ntotal} vectors"
        )
    print(f"Index loaded from {index_path}")
    return index, image_paths

# Re-read the index and its path list from disk, replacing the in-memory copies.
index, image_paths = load_faiss_index(index_path=OUTPUT_INDEX_PATH)

# Function to retrieve similar images
def retrieve_similar_images(query, processor, model, index, image_paths, top_k=3):
    """Return the indexed images whose embeddings best match ``query``.

    Args:
        query: A PIL image, or a path (``str`` / ``os.PathLike``) to one. A path
            is opened and converted to RGB.
        processor: Callable turning a list of PIL images into model inputs.
        model: Model exposing ``get_image_features(**inputs)``.
        index: FAISS index searched with the query embedding.
        image_paths: ``image_paths[i]`` is the file stored under id ``i``.
        top_k: Maximum number of matches to return.

    Returns:
        ``(query_image, retrieved_paths)``. ``query_image`` is the PIL image
        actually used. ``retrieved_paths`` lists at most ``top_k`` paths, best
        match first; it is shorter when the index holds fewer vectors.

    Note:
        Inputs are moved to the module-level ``device``.
    """
    if isinstance(query, (str, os.PathLike)):  # If query is a path
        with Image.open(query) as raw:
            query = raw.convert("RGB")

    inputs = processor(images=[query], return_tensors="pt").to(device)
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # One flattened row, the same layout generate_llava_embeddings() uses for
    # the indexed vectors. float32 + C order is what FAISS search expects.
    query_features = np.ascontiguousarray(
        features.float().cpu().numpy().reshape(1, -1), dtype=np.float32
    )

    k = min(top_k, index.ntotal)
    if k <= 0:
        return query, []

    distances, indices = index.search(query_features, k)
    # FAISS pads missing results with id -1. Skip those: image_paths[-1] would
    # otherwise silently return the last image as a bogus "match".
    retrieved_images = [image_paths[int(idx)] for idx in indices[0] if idx >= 0]

    return query, retrieved_images

# Visualize results
def visualize_results(query, retrieved_images):
    """Plot the query next to its retrieved matches in a single row.

    Args:
        query: A PIL image, drawn as "Query Image". Any other value (e.g. a
            path or a text query) is rendered as a centred text label instead.
        retrieved_images: Paths of the matched images, best match first; each
            is drawn as "Match 1", "Match 2", ...
    """
    n_columns = len(retrieved_images) + 1
    plt.figure(figsize=(12, 5))

    # Display the query image (or a text placeholder for non-image queries)
    plt.subplot(1, n_columns, 1)
    if isinstance(query, Image.Image):
        plt.imshow(query)
        plt.title("Query Image")
    else:
        plt.text(0.5, 0.5, f"Query:\n\n '{query}'", fontsize=16, ha="center", va="center")
    plt.axis("off")

    # Display retrieved images
    for rank, img_path in enumerate(retrieved_images, start=1):
        plt.subplot(1, n_columns, rank + 1)
        # Context manager: close the file once matplotlib has read the pixels.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.title(f"Match {rank}")
        plt.axis("off")

    plt.show()

# Example query and retrieval: fetch the 3 closest indexed images to one
# query image and plot them. Point query_image_path at your own image.
query_image_path = "/content/image_dataset/example.jpg"
query, retrieved_images = retrieve_similar_images(
    query=query_image_path,
    processor=processor,
    model=model,
    index=index,
    image_paths=image_paths,
    top_k=3,
)
visualize_results(query=query, retrieved_images=retrieved_images)


ERROR: Could not find a version that satisfies the requirement faiss-gpu (from versions: none)
ERROR: No matching distribution found for faiss-gpu

ModuleNotFoundError: No module named 'faiss'

In [None]:
# Install necessary packages
!pip install transformers accelerate torch torchvision faiss-gpu llava

import os
from glob import glob
import numpy as np
import faiss
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import random
from transformers import AutoProcessor, AutoModel

In [None]:
# Load the LLaVA model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the checkpoint name says "delta", which suggests weights meant to
# be merged into a base LLaMA model rather than loaded directly. Confirm that
# AutoProcessor/AutoModel can load it and that the resulting model exposes the
# get_image_features() method used below.
model_name = "liuhaotian/LLaVA-7b-delta-v0"  # Replace with desired LLaVA model checkpoint
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()  # inference only: disables dropout etc.

# Folder containing images
image_folder = "/content/image_dataset"

# Get all image paths. glob() returns files in filesystem-dependent order,
# so sort them. Without sorting, the seeded random.sample() used for display
# would still pick different files on different machines.
image_files = sorted(glob(os.path.join(image_folder, "*.jpg")))


In [None]:
# Randomly select up to 10 images for display. random.sample() raises
# ValueError when asked for more items than exist, so cap the count at the
# number of images actually found.
random.seed(42)
selected_images = random.sample(image_files, min(10, len(image_files)))

# Display the selected images on a 2 x 5 grid
plt.figure(figsize=(20, 10))
for i, img_path in enumerate(selected_images):
    plt.subplot(2, 5, i + 1)
    # Context manager: release the file handle once matplotlib has the pixels.
    with Image.open(img_path) as img:
        plt.imshow(img)
    plt.axis("off")
plt.show()

# Function to generate LLaVA embeddings
def generate_llava_embeddings(images_path, processor, model, device, pattern="*.jpg"):
    """Embed every matching image in a folder with the model's image encoder.

    Args:
        images_path: Folder to scan (non-recursively).
        processor: Callable turning a list of PIL images into model inputs,
            invoked as ``processor(images=[img], return_tensors="pt")``.
        model: Model exposing ``get_image_features(**inputs)``.
        device: Device the processed inputs are moved to.
        pattern: Glob pattern, relative to ``images_path``, selecting the image
            files. Defaults to ``"*.jpg"``.

    Returns:
        ``(embeddings, image_paths)``. ``embeddings`` is a list of 1-D NumPy
        arrays, one per image, holding that image's features flattened.
        ``image_paths`` is the matching list of file paths, in the same order.
        The paths are sorted, so the position of each image (used as its FAISS
        id downstream) is reproducible across runs and machines.
    """
    image_paths = sorted(glob(os.path.join(images_path, pattern)))
    embeddings = []

    with torch.no_grad():
        for img_path in image_paths:
            # Context manager so the underlying file is closed promptly.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
            inputs = processor(images=[img], return_tensors="pt").to(device)
            features = model.get_image_features(**inputs)
            # .float() first: Tensor.numpy() rejects bfloat16, e.g. when the
            # model is loaded with torch_dtype=torch.bfloat16.
            embeddings.append(features.float().cpu().numpy().flatten())

    return embeddings, image_paths

In [None]:


# Embed every image in the dataset folder; image_paths[i] is the file that
# produced embeddings[i].
embeddings, image_paths = generate_llava_embeddings(
    images_path=image_folder,
    processor=processor,
    model=model,
    device=device,
)
