In [None]:
# Install everything app.py and the launcher cell import.
# python-multipart is required by FastAPI for the Form(...) parameters in app.py.
# %pip (not !pip) guarantees the packages land in this kernel's environment.
# TODO: pin versions once a known-good set is confirmed.
%pip install -q fastapi uvicorn python-multipart pillow requests numpy pandas torch tensorflow sentence-transformers transformers pinecone pinecone-text pyngrok nest_asyncio

# Optionally, install localtunnel if you want to expose your local server to the internet
# (the launcher cell below uses ngrok via pyngrok instead).
!npm install -g localtunnel

In [None]:
%%writefile app.py
import os
import sys
import time
import json
import torch
import numpy as np
from PIL import Image as PILImage
import requests
from io import BytesIO
from pathlib import Path
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from pinecone import Pinecone, ServerlessSpec
import pandas as pd
from pinecone_text.sparse import BM25Encoder
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# Put the parent of the directory containing this file at the front of
# sys.path so modules living there can be imported.
# NOTE(review): when app.py sits in /content this resolves to "/", and nothing
# in this file imports a local module — possibly a leftover; confirm before removing.
PROJECT_ROOT = Path(__file__).absolute().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))

# The FastAPI application served by uvicorn ("app:app").
app = FastAPI()

# Wide-open CORS: every origin, method and header is accepted.
# NOTE(review): per Starlette's CORSMiddleware, a wildcard origin combined with
# allow_credentials=True makes it echo the caller's Origin, so any site can make
# credentialed requests. No endpoint here uses cookies/auth, so
# allow_credentials=False or an explicit origin list is probably safe — confirm
# against the frontend first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Pinecone client. The API key is read from the environment and must never be
# hard-coded in source (set it with e.g. `%env PINECONE_API_KEY=...` in the
# notebook, or export it before `python app.py`).
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
if not PINECONE_API_KEY:
    raise RuntimeError(
        "PINECONE_API_KEY is not set; export it before starting the app"
    )
pc = Pinecone(api_key=PINECONE_API_KEY)

# Index names and serverless placement (cloud/region overridable via env vars).
index_name_hybrid = "text-search"
index_name_image = "image-retrieval"
cloud = os.environ.get('PINECONE_CLOUD') or 'aws'
region = os.environ.get('PINECONE_REGION') or 'us-east-1'
spec = ServerlessSpec(cloud=cloud, region=region)

def _ensure_index(name, dimension, metric='dotproduct'):
    """Return a handle to Pinecone index ``name``, creating it first if missing.

    A newly created index is polled once a second until it reports ready.
    """
    if name not in pc.list_indexes().names():
        pc.create_index(name, dimension=dimension, metric=metric, spec=spec)
        while not pc.describe_index(name).status['ready']:
            time.sleep(1)
    return pc.Index(name)


# Hybrid text index; its dimension must equal model_hybrid's embedding size.
index_hybrid = _ensure_index(index_name_hybrid, dimension=384)
# Image index; its dimension must equal the FashionCLIP image-feature size.
index_image = _ensure_index(index_name_image, dimension=512)

# Torch models in this app run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dense query encoder for hybrid search; its output size must match the
# hybrid index dimension.
HYBRID_MODEL_NAME = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
model_hybrid = SentenceTransformer(HYBRID_MODEL_NAME, device=device)

# Fit the BM25 sparse encoder on the catalogue text ("<title> <designer>").
# NOTE(review): query-side BM25 weights depend on this corpus; keep it in sync
# with the corpus used when the index's sparse vectors were built.
data = pd.read_csv('/content/rtr_combined.csv')
# drop_duplicates() returns a new frame, so the result must be assigned back;
# otherwise duplicate rows inflate BM25's document statistics.
data = data.drop_duplicates()
# fillna('') stops one missing title/designer from turning the whole combined
# string into NaN (a float), which would then be fed to the BM25 tokenizer.
data['title'] = data['title'].fillna('') + ' ' + data['designer'].fillna('')
bm25 = BM25Encoder()
bm25.fit(data['title'])

# FashionCLIP processor + model used to embed query images for the image index.
FASHION_CLIP_MODEL_NAME = "patrickjohncyh/fashion-clip"
processor_image = AutoProcessor.from_pretrained(FASHION_CLIP_MODEL_NAME)
model_image = AutoModelForZeroShotImageClassification.from_pretrained(
    FASHION_CLIP_MODEL_NAME
).to(device)

# Keras image classifier used by the /predict_classification/ endpoint.
classification_model_path = '/content/model_frozen.h5'
classification_model = load_model(classification_model_path)

# class_indices maps each label to its class index; invert it so the argmax
# index produced by the classifier can be turned back into a label.
class_indices_path = '/content/class_indices.json'
with open(class_indices_path, 'r') as f:
    class_indices = json.load(f)
class_labels = {index: label for label, index in class_indices.items()}

# Functions for hybrid search
def hybrid_score_norm(dense, sparse, alpha: float):
    """Weight a dense/sparse query pair as a convex combination.

    The dense components are scaled by ``alpha`` and the sparse values by
    ``1 - alpha``: ``alpha=1`` keeps only the dense signal, ``alpha=0`` only
    the sparse one.

    Args:
        dense: Sequence of dense-vector components.
        sparse: Mapping with parallel ``'indices'`` and ``'values'`` sequences.
        alpha: Weight in [0, 1] given to the dense vector.

    Returns:
        ``(scaled_dense, scaled_sparse)``: a list, and a dict
        ``{'indices': ..., 'values': [...]}`` sharing the input indices.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or is NaN.
    """
    # A positive range check also rejects NaN, which fails every comparison
    # and would otherwise slip through and yield all-NaN vectors.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hs = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    return [v * alpha for v in dense], hs

def retrieve_by_filter(query, category=None, top_k=20, alpha=0.5):
    """Run a hybrid (BM25 sparse + sentence-embedding dense) query on the text index.

    Args:
        query: Free-text search string.
        category: When not None, only matches whose ``category`` metadata
            equals this value are returned.
        top_k: Number of matches to request.
        alpha: Dense/sparse weighting forwarded to ``hybrid_score_norm``.

    Returns:
        The ids of the matched vectors, in the order the index returned them.
    """
    sparse_query = bm25.encode_queries(query)
    dense_query = model_hybrid.encode(query).tolist()
    weighted_dense, weighted_sparse = hybrid_score_norm(
        dense_query, sparse_query, alpha=alpha
    )

    query_kwargs = {
        'top_k': top_k,
        'vector': weighted_dense,
        'sparse_vector': weighted_sparse,
        'include_metadata': True,
    }
    # Only send a metadata filter when a category was actually requested.
    if category is not None:
        query_kwargs['filter'] = {"category": {"$eq": category}}

    response = index_hybrid.query(**query_kwargs)
    return [match['id'] for match in response['matches']]

# GET API for text-based search
@app.get("/search/")
def search(
    query: str,
    category: Optional[str] = None,
    top_k: int = Query(20, ge=1),
    alpha: float = Query(0.5, ge=0, le=1)
):
    """Hybrid text search; returns ``{"result": [ids...]}``.

    Declared with plain ``def`` (not ``async def``) because the BM25/model
    encoding and the Pinecone query are blocking calls. FastAPI runs sync path
    operations in its threadpool instead of stalling the event loop for every
    other request.
    """
    try:
        results = retrieve_by_filter(
            query=query,
            category=category,
            top_k=top_k,
            alpha=alpha
        )
        return JSONResponse(content={"result": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Seconds to wait on the remote host (applies to both connect and read).
IMAGE_FETCH_TIMEOUT = 10


def _load_image_from_url(image_url: str) -> PILImage.Image:
    """Download ``image_url`` and decode it as an RGB PIL image.

    Raises:
        HTTPException: 400 if the URL cannot be fetched (network error,
            timeout, non-2xx status) or the payload is not a decodable image.
    """
    # NOTE(review): this fetches an arbitrary client-supplied URL from the
    # server (SSRF risk, e.g. internal/metadata endpoints). Consider an
    # allow-list of image hosts if the service is exposed publicly.
    try:
        response = requests.get(image_url, timeout=IMAGE_FETCH_TIMEOUT)
        response.raise_for_status()
        # convert('RGB') forces a full decode and normalises RGBA / palette /
        # greyscale inputs to the three channels both models consume.
        return PILImage.open(BytesIO(response.content)).convert('RGB')
    except (requests.RequestException, OSError, PILImage.DecompressionBombError) as e:
        raise HTTPException(status_code=400, detail=f"Could not load image from URL: {e}")


# POST API for image-based search
@app.post("/image_search/")
def image_search(
    image_url: str = Form(...)
):
    """Return ``{"result": [ids...]}`` for the 20 nearest images in the image index.

    Plain ``def`` so FastAPI runs the blocking download, inference and
    Pinecone query in its threadpool rather than on the event loop.
    """
    # Fetch outside the try so a 400 for bad input is not rewrapped as a 500.
    pil_image = _load_image_from_url(image_url)
    try:
        inputs = processor_image(images=pil_image, return_tensors="pt").to(device)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            image_features = model_image.get_image_features(**inputs)
        embedding = image_features.cpu().numpy()[0].tolist()

        result = index_image.query(
            top_k=20,
            vector=embedding,
            include_metadata=True,
        )

        # Keep only the part of each vector id before the first '_'
        # (presumably the item id). NOTE(review): several vectors of the same
        # item will produce repeated ids in the response.
        ids = [match['id'].split('_')[0] for match in result['matches']]
        return JSONResponse(content={"result": ids})

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# POST API for image classification
@app.post("/predict_classification/")
def predict_classification(
    image_url: str = Form(...)
):
    """Classify the image at ``image_url``; returns ``{"predicted_label": ...}``.

    Plain ``def`` for the same reason as ``image_search``: the work is blocking.
    """
    pil_image = _load_image_from_url(image_url)
    try:
        # Target size and preprocessing must match the model's training setup.
        img_array = image.img_to_array(pil_image.resize((224, 224)))
        img_array = preprocess_input(np.expand_dims(img_array, axis=0))

        # Batch of one; verbose=0 keeps a progress bar out of the server log.
        predictions = classification_model.predict(img_array, verbose=0)
        predicted_index = int(np.argmax(predictions[0]))

        # Map the predicted class index back to its label.
        predicted_label = class_labels[predicted_index]

        return JSONResponse(content={"predicted_label": predicted_label})
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# # POST API for image-based search
# @app.post("/image_search/")
# async def image_search(
#     image_file: UploadFile = File(...)
# ):
#     try:
#         image = PILImage.open(image_file.file)
#         inputs = processor_image(images=image, return_tensors="pt").to(device)
#         image_features = model_image.get_image_features(**inputs)
#         image_features = image_features.detach().cpu().numpy()[0].tolist()

#         # search
#         result = index_image.query(
#             top_k=20,
#             vector=image_features,
#             include_metadata=True,
#         )

#         # Extract only the part before the '_'
#         ids = [d['id'].split('_')[0] for d in result['matches']]
#         return JSONResponse(content={"result": ids})

#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))

# # POST API for image classification
# @app.post("/predict_classification/")
# async def predict_classification(
#     image_file: UploadFile = File(...)
# ):
#     try:
#         # Read and preprocess the image
#         img = PILImage.open(image_file.file).resize((224, 224))  # Ensure target size matches model input size
#         img_array = image.img_to_array(img)
#         img_array = np.expand_dims(img_array, axis=0)
#         img_array = preprocess_input(img_array)  # Ensure preprocessing matches training phase

#         # Make a prediction
#         predictions = classification_model.predict(img_array)
#         predicted_class = np.argmax(predictions, axis=1)

#         # Map the predicted class to its label
#         predicted_label = class_labels[predicted_class[0]]

#         return JSONResponse(content={"predicted_label": predicted_label})
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    # reload=True only works when uvicorn receives an import string: it
    # re-imports the module in a child process. Given the `app` object it
    # refuses to start. The string assumes the process is launched from the
    # directory that contains this file.
    # NOTE: the reloader child re-runs this module's top-level setup (model
    # loading, index checks), so expect that work to happen twice.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=8080, reload=True)


In [None]:
# Launch app.py (written by the %%writefile cell above into the current working
# directory) behind an ngrok tunnel and register the tunnel URL with the backend.
# Run this cell from the directory that contains app.py.

import os
import signal
import subprocess

import nest_asyncio
import requests
import uvicorn
from pyngrok import ngrok

PORT = 8000
BACKEND_NGROK_ENDPOINT = "https://amoure-backend-awu2hmc2hq-et.a.run.app/ml/ngrok"
REGISTER_TIMEOUT_S = 30  # seconds

# The notebook kernel already runs an asyncio event loop; nest_asyncio lets
# uvicorn.run() start its own loop inside it.
nest_asyncio.apply()


def free_port(port: int) -> None:
    """SIGKILL processes, other than this kernel, that LISTEN on TCP ``port``.

    Only listening sockets are matched, so processes merely connected to the
    port (e.g. a tunnel agent forwarding traffic to it) are left alone.
    """
    try:
        output = subprocess.check_output(
            ["lsof", "-t", f"-iTCP:{port}", "-sTCP:LISTEN"], text=True
        )
    except subprocess.CalledProcessError:
        # lsof exits non-zero when nothing matches.
        print(f"No existing process on port {port}")
        return
    except FileNotFoundError:
        print(f"lsof is not installed; not checking port {port}")
        return

    # lsof -t prints one PID per line; there may be several.
    for pid in sorted({int(token) for token in output.split()}):
        if pid == os.getpid():
            # Presumably a previous run of this cell is still serving from this
            # kernel; killing it would kill the notebook itself.
            print(f"Port {port} is held by this kernel (PID {pid}); restart the runtime to free it")
            continue
        os.kill(pid, signal.SIGKILL)
        print(f"Killed existing process on port {port} with PID {pid}")


def run_uvicorn():
    """Serve ``app`` from app.py on PORT; blocks until the server stops."""
    uvicorn.run("app:app", host="0.0.0.0", port=PORT)


free_port(PORT)

# Create an ngrok tunnel to the local server.
public_url = ngrok.connect(PORT).public_url
print(f"ngrok tunnel available at: {public_url}")

# Tell the backend where this server can be reached.
response = requests.post(
    BACKEND_NGROK_ENDPOINT,
    json={"url": public_url, "type": "DEFAULT"},
    timeout=REGISTER_TIMEOUT_S,
)
print(f"Response from server: {response.status_code} - {response.text}")

# Start the Uvicorn server in this kernel.
run_uvicorn()
