In [1]:
import streamlit as st
import spacy
import pickle
import os
from PIL import Image
from collections import Counter

# Load the spaCy pipeline and the R-GCN lookup artifacts once per server
# process. Streamlit re-executes this whole script on every widget interaction,
# so uncached module-level loading would repeat the slow spacy.load() and the
# unpickling on every rerun.
@st.cache_resource(show_spinner=False)
def _load_resources():
    """Load the NLP pipeline and the pickled R-GCN lookup artifacts.

    Returns:
        A 4-tuple ``(nlp, entity_encoder, entity_idx_to_images, synonym_map)``:
        the ``en_core_web_sm`` spaCy pipeline, the entity label encoder, the
        mapping from encoded entity index to image ids, and the synonym map
        used to canonicalize words before encoding.
    """
    # NOTE: pickle.load can execute arbitrary code; these files must only ever
    # come from a trusted source.
    nlp_model = spacy.load("en_core_web_sm")
    with open("saved_model_R_GCN/entity_encoder.pkl", "rb") as f:
        encoder = pickle.load(f)
    with open("saved_model_R_GCN/entity_idx_to_images.pkl", "rb") as f:
        idx_to_images = pickle.load(f)
    with open("saved_model_R_GCN/synonym_map.pkl", "rb") as f:
        synonyms = pickle.load(f)
    return nlp_model, encoder, idx_to_images, synonyms


nlp, entity_encoder, entity_idx_to_images, synonym_map = _load_resources()

# Folder holding the image files; set COCO_IMAGE_FOLDER to override the
# machine-specific default without editing the code.
image_folder = os.environ.get("COCO_IMAGE_FOLDER", "E:/Download/val2017")

# Build a noun phrase from a token and its modifier children.
def get_full_noun_phrase(token):
    """Return *token*'s text preceded by its direct modifier children.

    Only direct children whose dependency label is ``amod``, ``compound``,
    ``det`` or ``nummod`` are kept, in the order spaCy yields them, and they are
    always placed before the head word. Words are joined with single spaces.
    """
    modifier_labels = ("amod", "compound", "det", "nummod")
    words = [kid.text for kid in token.children if kid.dep_ in modifier_labels]
    words.append(token.text)
    return " ".join(words)

# Extract (subject, verb, object) triplets from a caption.
@st.cache_data
def extract_multiple_triplets(caption):
    """Extract ``(subject, verb_lemma, object_phrase)`` triplets from *caption*.

    For every nominal subject (``nsubj`` or ``nsubjpass``), the subject's head
    token supplies the predicate, as its lemma. Objects are taken from the
    head's direct children:

    * ``dobj`` / ``attr`` children whose part of speech is NOUN or PROPN;
    * the ``pobj`` children of every ``prep`` child.

    Objects are expanded with their modifier children via
    ``get_full_noun_phrase``. Subjects are kept as the bare token text.

    Args:
        caption: Free-text image description.

    Returns:
        A list of unique triplets in order of first occurrence (may be empty).
    """
    doc = nlp(caption)
    triplets = []
    for token in doc:
        if token.dep_ not in ("nsubj", "nsubjpass"):
            continue
        subj = token.text
        head = token.head
        verb = head.lemma_
        for child in head.children:
            if child.dep_ in ("dobj", "attr") and child.pos_ in ("NOUN", "PROPN"):
                triplets.append((subj, verb, get_full_noun_phrase(child)))
            elif child.dep_ == "prep":
                triplets.extend(
                    (subj, verb, get_full_noun_phrase(pobj))
                    for pobj in child.children
                    if pobj.dep_ == "pobj"
                )
    # dict.fromkeys de-duplicates while preserving first-occurrence order. A set
    # would yield an order that changes between interpreter runs (str hashing
    # is randomized), making tie ranking in the image search non-deterministic.
    return list(dict.fromkeys(triplets))

# Find images whose knowledge-graph entities match the caption's triplets.
def _images_for_ids(subj_id, pred_id, obj_id, images_of):
    """Return the candidate image ids for one encoded triplet.

    ``images_of(idx)`` returns the set of image ids linked to entity index
    ``idx``; an id of ``None`` means that word has no encoded entity. Matching
    is as strict as the known ids allow:

    * subject and object known, no predicate: subject ∩ object;
    * subject, object and predicate known: subject ∩ object ∩ predicate if
      non-empty, else subject ∩ object if non-empty, else the union of all three;
    * subject (or object) plus predicate known: their intersection;
    * a single id known: that entity's images; none known: empty set.
    """
    if subj_id is not None and obj_id is not None:
        subj_imgs = images_of(subj_id)
        obj_imgs = images_of(obj_id)
        core_imgs = subj_imgs & obj_imgs
        if pred_id is None:
            return core_imgs
        pred_imgs = images_of(pred_id)
        # Prefer the tightest non-empty match so images sharing both entities
        # are not diluted by the union fallback.
        return (core_imgs & pred_imgs) or core_imgs or (subj_imgs | obj_imgs | pred_imgs)
    if subj_id is not None and pred_id is not None:
        return images_of(subj_id) & images_of(pred_id)
    if obj_id is not None and pred_id is not None:
        return images_of(obj_id) & images_of(pred_id)
    # At most one id is known here; check subject, object, predicate in order.
    for entity_id in (subj_id, obj_id, pred_id):
        if entity_id is not None:
            return images_of(entity_id)
    return set()


@st.cache_data
def find_images_by_entities_prioritize_intersection(caption):
    """Rank image filenames by how many caption triplets they match.

    Each triplet from ``extract_multiple_triplets`` is normalized and encoded.
    Its candidate images (see ``_images_for_ids``) each get one vote. Images
    are returned from most to fewest votes; ties keep first-vote order.

    Args:
        caption: Free-text image description.

    Returns:
        A list of zero-padded 12-digit ``.jpg`` filenames (e.g.
        ``000000123456.jpg``). The list is empty if nothing matched.
    """
    triplets = extract_multiple_triplets(caption)
    # Build the label set once; `x in entity_encoder.classes_` rescans the
    # whole label array on every check.
    known_labels = set(entity_encoder.classes_)

    def normalize(word):
        """Map *word* to an encoder label if possible, else its normalized form."""
        if not word:
            return None
        word_lower = word.lower()
        word_norm = synonym_map.get(word_lower, word_lower)
        if word_norm in known_labels:
            return word_norm
        # Crude gerund stripping ("riding" -> "rid") before trying the lemma.
        if word_norm.endswith("ing"):
            root = word_norm[:-3]
            if root in known_labels:
                return root
        # NOTE(review): objects arrive as full noun phrases (with determiners),
        # so only the first token's lemma is tried here. Unless the encoder's
        # labels include such phrases, they rarely match — confirm.
        doc = nlp(word_norm)
        if len(doc) > 0:
            lemma = doc[0].lemma_
            if lemma in known_labels:
                return lemma
        return word_norm

    def get_id(word):
        """Return the encoded index of *word*, or None if it is not a known label."""
        # Checking membership first avoids relying on the encoder raising for
        # unseen labels (previously swallowed by a bare `except:`).
        if word is None or word not in known_labels:
            return None
        return entity_encoder.transform([word])[0]

    def images_of(entity_id):
        return set(entity_idx_to_images.get(entity_id, []))

    image_counter = Counter()
    for subj_raw, pred_raw, obj_raw in triplets:
        subj_id = get_id(normalize(subj_raw))
        pred_id = get_id(normalize(pred_raw))
        obj_id = get_id(normalize(obj_raw))
        image_counter.update(_images_for_ids(subj_id, pred_id, obj_id, images_of))

    return [f"{int(img_id):012}.jpg" for img_id, _ in image_counter.most_common()]

# === Streamlit UI ===
st.set_page_config(
    page_title="Find images",
    layout="wide",
    initial_sidebar_state="auto"
)

st.title("Find related images from the caption")
caption = st.text_input("Enter image description (caption):")

# At most this many results are rendered, in a grid with this many columns.
MAX_SHOWN_IMAGES = 9
GRID_COLUMNS = 3

if caption:
    filenames = find_images_by_entities_prioritize_intersection(caption)
    if not filenames:
        st.warning("No matching image found.")
    else:
        st.success(f"Found {len(filenames)} related images.")
        cols = st.columns(GRID_COLUMNS)
        for i, filename in enumerate(filenames[:MAX_SHOWN_IMAGES]):
            col = cols[i % GRID_COLUMNS]
            image_path = os.path.join(image_folder, filename)
            if not os.path.exists(image_path):
                col.write(f"[Image file not found: {filename}]")
                continue
            # The context manager releases the file handle deterministically,
            # even if rendering raises before the image data is fully read.
            # NOTE(review): `use_column_width` is deprecated in newer Streamlit
            # releases in favor of `use_container_width`; kept for older
            # installs — confirm the installed version before switching.
            with Image.open(image_path) as img:
                col.image(img, caption=filename, use_column_width=True)


2025-06-08 11:22:48.521 
  command:

    streamlit run C:\Users\admin\anaconda3\envs\coco_kg\lib\site-packages\ipykernel_launcher.py [ARGUMENTS]
2025-06-08 11:22:48.535 Session state does not function when running a script without `streamlit run`
