In [7]:
import pickle
import nest_asyncio
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import pickle

# Allow FastAPI to run inside Jupyter: nest_asyncio patches asyncio so that a
# server can be started while the notebook's own event loop is already running.
nest_asyncio.apply()

In [5]:
# Redefined for the API notebook. The function reads the module-level `df`,
# `node_embeddings` and `entity_idx_to_images` at call time, so the data-loading
# cell must have been run before the first call.
def find_related_images(query_image_id, top_k=5):
    """Return up to ``top_k`` image ids related to ``query_image_id``, best match first.

    The entities appearing as ``subject_id`` or ``object_id`` in the rows of
    ``df`` for the query image are averaged into a single query vector. Every
    row of ``node_embeddings`` is scored against that vector by cosine
    similarity, and the images attached to the ``top_k`` highest-scoring
    entities are collected in descending score order.

    Args:
        query_image_id: Value compared against ``df['image_id']``.
        top_k: Number of top-scoring entities to consider and the maximum
            number of image ids returned. Values <= 0 yield an empty list.

    Returns:
        A list of image ids (never containing ``query_image_id``), ordered
        from most to least similar entity, without duplicates. Empty when the
        query image has no rows in ``df``.

    Notes:
        Assumes ``subject_id`` / ``object_id`` values are integer row indices
        into ``node_embeddings`` and that ``node_embeddings`` is a 2-D array
        convertible with ``numpy.asarray`` — TODO confirm against the code
        that produced ``GAT_api_data.pkl``.
    """
    # Local import: the similarity is computed with NumPy only, so this cell
    # no longer depends on `torch` / `cosine_similarity` having been imported
    # by some earlier (possibly deleted) cell.
    import numpy as np

    if top_k <= 0:
        return []

    # Filter once and reuse the rows for both id columns.
    rows = df[df['image_id'] == query_image_id]
    related_entities = set(rows['subject_id']) | set(rows['object_id'])
    if not related_entities:
        return []

    embeddings = np.asarray(node_embeddings)
    query_vec = embeddings[sorted(related_entities)].mean(axis=0)

    # Cosine similarity of the query vector against every entity. Zero-norm
    # vectors get a divisor of 1 so they score 0 instead of producing NaN.
    norms = np.linalg.norm(embeddings, axis=1)
    norms[norms == 0] = 1.0
    query_norm = np.linalg.norm(query_vec) or 1.0
    scores = (embeddings @ query_vec) / (norms * query_norm)

    top_entity_indices = np.argsort(scores)[::-1][:top_k]

    # Collect images in rank order; a plain set here would discard the
    # ranking and make the final truncation arbitrary.
    ranked_images = []
    seen = {query_image_id}
    for idx in top_entity_indices:
        for image_id in entity_idx_to_images[int(idx)]:
            if image_id not in seen:
                seen.add(image_id)
                ranked_images.append(image_id)
    return ranked_images[:top_k]

In [6]:
import pickle

# NOTE: unpickling can execute arbitrary code, so only load a file that was
# produced by a trusted source.
with open("GAT_api_data.pkl", "rb") as pkl_file:
    saved = pickle.load(pkl_file)

# Expose the three stored artifacts as module-level names; find_related_images
# looks them up as globals when it is called.
df, node_embeddings, entity_idx_to_images = (
    saved[key] for key in ("df", "node_embeddings", "entity_idx_to_images")
)

In [8]:
# Application instance that the route decorators in this cell register against.
app = FastAPI(title="GAT-based Image Search API")

class ImageQuery(BaseModel):
    # Request body accepted by the POST /search endpoint.
    image_id: str  # id of the query image, forwarded to find_related_images
    top_k: int = 5  # forwarded as find_related_images' top_k; 5 when omitted

@app.post("/search", response_model=List[str])
def search_similar_images(query: ImageQuery):
    # Unpack the validated request body and delegate to the embedding lookup;
    # the result is returned as-is and serialized as a list of strings.
    image_id, limit = query.image_id, query.top_k
    return find_related_images(image_id, top_k=limit)

@app.get("/")
def root():
    # Fixed status payload; does not touch the loaded data.
    payload = dict(message="GAT Image Search API is running!")
    return payload

In [None]:
# Serve the app on the loopback interface, port 8000. The call does not return
# while the server is running, so this cell stays busy until it is interrupted.
uvicorn.run(app, host="127.0.0.1", port=8000)

INFO:     Started server process [11128]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     127.0.0.1:52446 - "GET / HTTP/1.1" 200 OK
INFO:     127.0.0.1:52446 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     127.0.0.1:52447 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:52447 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     127.0.0.1:52450 - "POST /search HTTP/1.1" 200 OK
INFO:     127.0.0.1:52450 - "POST /search HTTP/1.1" 200 OK
INFO:     127.0.0.1:52471 - "GET / HTTP/1.1" 200 OK
