In [None]:
import os
import pickle
import warnings

import numpy as np
import torch
from torchvision import transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

In [None]:
# Trained Keras models saved on Google Drive (Colab mount); one sub-folder per model.
caption_model_path = '/'.join(('/content/drive/MyDrive/Colab Notebooks', 'Caption Generation', 'model.keras'))
ocr_model_path = '/'.join(('/content/drive/MyDrive/Colab Notebooks', 'OCR', 'model.keras'))

In [None]:
# Deserialize both trained Keras models (caption first, then OCR).
caption_model, ocr_model = (load_model(path) for path in (caption_model_path, ocr_model_path))

In [None]:
# Print the caption model's layer table (default formatting).
caption_model.summary(line_length=None)

In [None]:
# Print the OCR model's layer table (default formatting).
ocr_model.summary(line_length=None)

In [None]:
features_directory = '/content/drive/MyDrive/Colab Notebooks/Caption Generation'

# Precomputed image features stored in the same Drive folder as the caption model.
# NOTE(review): pickle.load can execute arbitrary code — only load a file you created/trust.
pickle_file_path = os.path.join(features_directory, 'img_features.pkl')
with open(pickle_file_path, mode='rb') as file:
    loaded_features = pickle.load(file)

In [None]:
# FIXME: this Tokenizer is never fitted, so its word_index is empty and caption
# decoding cannot map predicted indices back to words. Load the tokenizer that was
# fitted when the caption model was trained instead of creating a fresh one.
tokenizer = Tokenizer()
# NOTE(review): presumably must equal the padded caption length used at training time.
max_length = 35

In [None]:
# VGG16 (default ImageNet weights) reused as a fixed feature extractor: the rebuilt
# model exposes the second-to-last layer's activations instead of class predictions.
vgg_model = VGG16()
vgg_model = Model(inputs=vgg_model.inputs, outputs=vgg_model.layers[-2].output)

In [None]:
image_path = '/content/pngtree-real-life-hot-air-balloon-flying-in-the-sky-image_2622657.jpg'
# Read the sample image at VGG16's 224x224 input size as a float pixel array.
image = img_to_array(load_img(image_path, target_size=(224, 224)))
# Prepend a batch axis, (H, W, C) -> (1, H, W, C), then apply VGG16's own preprocessing.
image = preprocess_input(image.reshape((1,) + image.shape))
# Penultimate-layer VGG16 features for this single image.
feature = vgg_model.predict(image, verbose=0)

In [None]:
# Dataset Loading
def load_images_from_folder(folder_path, extensions=('.jpg', '.jpeg', '.png')):
    """Return the paths of the image files directly inside ``folder_path``.

    Extension matching is case-insensitive (so ``photo.JPG`` is included),
    sub-directories are skipped even if their name looks like an image, and the
    result is sorted because ``os.listdir`` order is arbitrary.

    Args:
        folder_path: directory to scan (not recursive).
        extensions: accepted file suffixes; compared in lower case.

    Returns:
        Sorted list of ``os.path.join(folder_path, name)`` paths.

    Raises:
        OSError: propagated from ``os.listdir`` (e.g. the folder does not exist).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    image_paths = []
    for name in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, name)
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            image_paths.append(path)
    return image_paths

def idx_to_word(integer, tokenizer):
    """Reverse-lookup a token index in ``tokenizer.word_index``.

    Returns the first word (in ``word_index`` iteration order) whose index equals
    ``integer``, or ``None`` when no word has that index.
    """
    matches = (word for word, index in tokenizer.word_index.items() if index == integer)
    return next(matches, None)

# generate caption for an image
def _image_to_vgg_features(image, feature_extractor):
    """Convert a file path or PIL image into features from ``feature_extractor``.

    Any other value is assumed to already be a feature array and is returned
    unchanged. NOTE(review): e.g. an entry of ``loaded_features`` — confirm its
    shape matches what the caption model expects.
    """
    if isinstance(image, (str, os.PathLike)):
        pixels = img_to_array(load_img(image, target_size=(224, 224)))
    elif isinstance(image, Image.Image):
        # load_img resizes with nearest-neighbour by default; mirror it for PIL input.
        pixels = img_to_array(image.convert('RGB').resize((224, 224), Image.NEAREST))
    else:
        return image
    # Add the batch axis, (H, W, C) -> (1, H, W, C), then apply VGG16 preprocessing.
    batch = preprocess_input(pixels.reshape((1,) + pixels.shape))
    return feature_extractor.predict(batch, verbose=0)


# generate caption for an image
def caption_prediction(model, image, tokenizer, max_length, feature_extractor=None):
    """Greedily decode a caption for ``image`` with the trained caption model.

    Args:
        model: caption model, called as ``model.predict([features, sequence])``.
        image: path to an image file, a PIL image, or precomputed image features.
        tokenizer: Keras ``Tokenizer`` used to encode the partial caption and to map
            predicted indices back to words (must be the one fitted at training time).
        max_length: padded sequence length and maximum number of decoding steps.
        feature_extractor: model that turns a preprocessed image batch into
            features; defaults to the notebook-level ``vgg_model``.

    Returns:
        The generated text. It still starts with the ``'start'`` tag and ends with
        ``'end'`` if the model produced it.
    """
    if feature_extractor is None:
        feature_extractor = vgg_model
    # Use the image the caller passed in (not a notebook-global path).
    features = _image_to_vgg_features(image, feature_extractor)

    in_text = 'start'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length, padding='post')
        yhat = model.predict([features, sequence], verbose=0)
        # Greedy decoding: take the single most probable next token.
        word = idx_to_word(int(np.argmax(yhat)), tokenizer)
        # Unknown index (e.g. padding 0) means decoding cannot continue.
        if word is None:
            break
        in_text += " " + word
        if word == 'end':
            break
    return in_text

# Text Recognition
def recognize_text(model, image):
    """Run a text-recognition model on ``image`` in inference mode.

    Objects exposing a Keras-style ``predict`` method (such as the ``ocr_model``
    loaded with ``tensorflow.keras.models.load_model``) are run via
    ``predict`` on a NumPy array, because ``torch.no_grad`` has no effect on Keras
    models and they do not take torch tensors. Everything else (e.g. a
    ``torch.nn.Module``) is called directly under ``torch.no_grad()``.

    Args:
        model: a torch model/callable, or an object with ``predict(x, verbose=0)``.
        image: model input batch (torch tensor or array-like).
            NOTE(review): the expected shape/normalisation depends on how the OCR
            model was trained (e.g. NCHW for torch vs NHWC for Keras) — confirm.

    Returns:
        The model's raw output. This is NOT decoded text; converting it to a
        string with the OCR model's character set/decoding is still TODO.
    """
    predict = getattr(model, 'predict', None)
    if not isinstance(model, torch.nn.Module) and callable(predict):
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        return predict(np.asarray(image), verbose=0)

    with torch.no_grad():
        return model(image)

# Text Prompt Matching
def _strip_caption_tags(caption):
    """Drop the decoder's leading 'start' and trailing 'end' tags from a caption."""
    words = caption.split()
    if words and words[0] == 'start':
        words = words[1:]
    if words and words[-1] == 'end':
        words = words[:-1]
    return ' '.join(words)


def find_most_relevant_image(prompt, image_paths, caption_model, ocr_model,
                             ocr_to_text=None, sentence_transformer=None):
    """Score each image against ``prompt`` and return the best match.

    Each image's generated caption (and, when an OCR decoder is supplied, its
    recognized text) is embedded with a SentenceTransformer. The image's score is
    the highest cosine similarity to the prompt embedding.

    Args:
        prompt: free-text query.
        image_paths: non-empty sequence of image file paths.
        caption_model: model passed to ``caption_prediction``.
        ocr_model: text-recognition model, handed to ``ocr_to_text``.
        ocr_to_text: optional callable ``(ocr_model, pil_image) -> str`` that runs
            OCR and decodes the result to a string. Decoding the OCR model's raw
            output is not defined in this notebook. When this is ``None``, the OCR
            signal is skipped, with a warning if ``ocr_model`` was given.
        sentence_transformer: optional preloaded SentenceTransformer. By default,
            ``'all-MiniLM-L6-v2'`` is loaded on each call.

    Returns:
        ``(score, image_path)`` for the highest-scoring image. Ties go to the
        earliest path.

    Raises:
        ValueError: if ``image_paths`` is empty.
    """
    if not image_paths:
        raise ValueError("find_most_relevant_image: image_paths is empty")

    if sentence_transformer is None:
        sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')  # Sentence embeddings for similarity
    use_ocr = ocr_model is not None and ocr_to_text is not None
    if ocr_model is not None and ocr_to_text is None:
        warnings.warn("No OCR decoder (ocr_to_text) supplied; ranking by caption similarity only.")

    prompt_embedding = sentence_transformer.encode([prompt])

    def similarity(text):
        return cosine_similarity(prompt_embedding, sentence_transformer.encode([text]))[0][0]

    relevance_scores = []
    for image_path in image_paths:
        # Caption the image at this path; tags are removed so they do not skew similarity.
        caption = caption_prediction(caption_model, image_path, tokenizer, max_length)
        scores = [similarity(_strip_caption_tags(caption))]

        if use_ocr:
            with Image.open(image_path) as img:
                recognized_text = ocr_to_text(ocr_model, img.convert('RGB'))
            scores.append(similarity(recognized_text))

        relevance_scores.append((max(scores), image_path))

    # max() returns the first maximal entry, matching a stable descending sort's head.
    return max(relevance_scores, key=lambda pair: pair[0])

# Run Retrieval
def main(folder_path="/path/to/your/dataset"):
    """Ask for a text prompt and display the image in ``folder_path`` that best matches it.

    Args:
        folder_path: directory of candidate images. The default is a placeholder;
            pass a real folder (or edit the default) before running.
    """
    image_paths = load_images_from_folder(folder_path)
    if not image_paths:
        print(f"No images found in {folder_path!r}; nothing to search.")
        return

    text_prompt = input("Enter text prompt: ")

    # Use the OCR model loaded earlier in the notebook.
    most_relevant_score, most_relevant_image_path = find_most_relevant_image(
        text_prompt, image_paths, caption_model, ocr_model
    )

    print(f"Most Relevant Image: {most_relevant_image_path} (Score: {most_relevant_score:.4f})")
    # Display the most relevant image; the context manager closes the file handle.
    with Image.open(most_relevant_image_path) as image:
        image.show()


# Run the main function (inside a notebook kernel __name__ is "__main__", so this runs when the cell executes).
if __name__ == "__main__":
    main()
