In [None]:
!pip install open_clip_torch faiss-cpu pyngrok torchscale

In [None]:
import os, io, gc, time, asyncio
from queue import Empty, Queue
from threading import Lock, Thread
from typing import List

import torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel
from pyngrok import ngrok
from torchvision import transforms

# nest_asyncio allows uvicorn.run() to be started from inside an already-running
# event loop (the notebook kernel). It stays optional and best-effort, but failures
# are reported instead of being silently swallowed by a bare `except`.
try:
    import nest_asyncio
except ImportError:
    print("nest_asyncio is not installed; skipping event-loop patch.")
else:
    try:
        nest_asyncio.apply()
    except Exception as exc:  # e.g. a loop implementation nest_asyncio cannot patch
        print(f"nest_asyncio.apply() failed, continuing without it: {exc!r}")


# =========================
# CONFIG
# =========================
# open_clip model / pretrained tags for the main CLIP encoder (image + text endpoints).
CLIP_MODEL_NAME = "ViT-H-14-378-quickgelu"
CLIP_PRETRAINED = "dfn5b"

# open_clip model / pretrained tags for the bigG text encoder.
BIGG_MODEL_NAME = "ViT-bigG-14"
BIGG_PRETRAINED = "laion2b_s39b_b160k"

# Hugging Face repo that hosts the BEiT3 checkpoint and SentencePiece model.
BEIT3_REPO = "Quintu/beit3"

# Device placement: CLIP H/14 and BEiT3 share the first GPU, bigG gets the second.
CLIP_DEVICE = "cuda:0"
BEIT3_DEVICE = "cuda:0"
BIGG_DEVICE = "cuda:1"

PORT = 7000       # port uvicorn listens on (and ngrok tunnels to)
MAX_BATCH = 24    # maximum number of queued requests a worker merges into one batch
MAX_WAIT = 0.01   # seconds a worker keeps collecting requests for a batch

# SECURITY: never commit the ngrok auth token. It is read from the environment;
# an empty value disables the public tunnel (the launcher checks `if NGROK_AUTH_TOKEN:`).
# The token that used to be hardcoded here must be treated as leaked and rotated.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN", "")


# =========================
# LOAD CLIP H/14
# =========================
import open_clip

# create_model_and_transforms returns three values: the model, an unused middle
# element, and the transform kept as `clip_preprocess` for the image endpoint.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    model_name=CLIP_MODEL_NAME, pretrained=CLIP_PRETRAINED, device=CLIP_DEVICE
)
clip_model.eval()  # inference only

print(f"‚úÖ CLIP {CLIP_MODEL_NAME} loaded on {CLIP_DEVICE}")

# =========================
# LOAD CLIP bigG
# =========================
print(f"üöÄ Loading CLIP {BIGG_MODEL_NAME} ({BIGG_PRETRAINED}) on {BIGG_DEVICE}...")
# Only the model is kept; both transforms returned alongside it are discarded
# because bigG serves text embeddings only.
bigg_model, _, _ = open_clip.create_model_and_transforms(
    BIGG_MODEL_NAME, pretrained=BIGG_PRETRAINED, device=BIGG_DEVICE
)
# Cast the weights to half precision on the target GPU, then switch to eval mode.
bigg_model = bigg_model.to(device=BIGG_DEVICE, dtype=torch.float16)
bigg_model.eval()
bigg_tokenizer = open_clip.get_tokenizer(BIGG_MODEL_NAME)
print("‚úÖ CLIP bigG ready.\n")


# =========================
# LOAD BEiT3 (TEXT ONLY)
# =========================
from huggingface_hub import hf_hub_download
from transformers import XLMRobertaTokenizer
from timm.models.layers import trunc_normal_


class BEiT3QdrantSearcher:
    """Text encoder built on a BEiT3 retrieval checkpoint.

    Only the text branch is used: `encode_text` maps strings to normalised
    embedding vectors. `load(device)` must be called once before encoding.
    """

    class _BEiT3Wrapper(nn.Module):
        """Holds the torchscale BEiT3 backbone and applies the default weight init."""

        def __init__(self, args):
            super().__init__()
            # Imported lazily so torchscale is only required when a model is built.
            from torchscale.model.BEiT3 import BEiT3
            self.beit3 = BEiT3(args)
            self.apply(self._init_weights)

        def _init_weights(self, m):
            """Truncated-normal init (std 0.02) for Linear weights, zeros for their biases."""
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    class _BEiT3ForRetrieval(_BEiT3Wrapper):
        """Backbone plus the bias-free linear projection applied to the text token output."""

        def __init__(self, args):
            super().__init__(args)
            d = args.encoder_embed_dim
            self.language_head = nn.Linear(d, d, bias=False)
            self._init_weights(self.language_head)

        @torch.no_grad()
        def encode_text(self, ids, mask):
            """Encode token ids and return normalised embeddings.

            Args:
                ids: token id tensor passed to the backbone as `textual_tokens`.
                mask: boolean tensor passed as `text_padding_position`
                    (True where a position is padding).
            """
            out = self.beit3(
                textual_tokens=ids,
                visual_tokens=None,
                text_padding_position=mask,
            )
            # Project the first token of the encoder output, then normalise along the last dim.
            x = self.language_head(out["encoder_out"][:, 0])
            return F.normalize(x, dim=-1)

    @staticmethod
    def _cfg():
        """Build the torchscale EncoderConfig used to instantiate the backbone."""
        from torchscale.architecture.config import EncoderConfig

        return EncoderConfig(
            img_size=384,
            patch_size=16,
            vocab_size=64010,
            multiway=True,
            normalize_output=True,
            no_output_layer=True,
            encoder_embed_dim=1024,
            encoder_layers=24,
            encoder_attention_heads=16,
            encoder_ffn_embed_dim=4096,
        )

    def __init__(self, repo=BEIT3_REPO):
        """Remember the Hugging Face repo id; nothing is downloaded until `load`."""
        self.repo = repo
        self.tokenizer = None
        self.model = None

    def load(self, device):
        """Download the checkpoint and SentencePiece model, then build the encoder on `device`."""
        ckpt_path = hf_hub_download(self.repo, "beit3_large_patch16_384_coco_retrieval.pth")
        spm_path = hf_hub_download(self.repo, "beit3.spm")

        self.tokenizer = XLMRobertaTokenizer(spm_path, use_fast=False)
        args = self._cfg()
        self.model = self._BEiT3ForRetrieval(args).to(device)

        # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
        # Only point `repo` at a checkpoint source you trust.
        state = torch.load(ckpt_path, map_location="cpu")
        # strict=False tolerates checkpoint entries this text-only model does not define,
        # but parameters absent from the checkpoint would silently keep their random init,
        # so report them.
        load_result = self.model.load_state_dict(state.get("model", state), strict=False)
        if load_result.missing_keys:
            print(
                f"WARNING: {len(load_result.missing_keys)} BEiT3 parameter(s) were not found "
                f"in the checkpoint and keep their random init, e.g. {load_result.missing_keys[:5]}"
            )
        self.model.eval()

        print(f"‚úÖ BEiT3 loaded on {device}")

    @torch.no_grad()
    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Tokenise `texts` and return one embedding row per text as a NumPy array.

        Args:
            texts: strings to embed; each is truncated to at most 128 tokens.

        Returns:
            Array with one row per input text. An empty list yields a (0, dim) float32 array.

        Raises:
            RuntimeError: if `load()` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("BEiT3QdrantSearcher.load() must be called before encode_text()")
        # The tokenizer cannot pad an empty batch, so short-circuit it here.
        if isinstance(texts, (list, tuple)) and len(texts) == 0:
            return np.empty((0, self.model.language_head.out_features), dtype=np.float32)

        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        # Run on whichever device the model weights live on.
        ids = enc["input_ids"].to(self.model.language_head.weight.device)
        # attention_mask is 1 for real tokens; the model expects True at padding positions.
        mask = (enc["attention_mask"] == 0).bool().to(ids.device)
        out = self.model.encode_text(ids, mask)
        return out.cpu().numpy()


# Build the BEiT3 text encoder and place its weights and tokenizer on the configured device.
beit3 = BEiT3QdrantSearcher(repo=BEIT3_REPO)
beit3.load(device=BEIT3_DEVICE)


# =========================
# QUEUES + WORKERS (AUTO-BATCH FOR SINGLE REQUESTS)
# =========================
# One unbounded FIFO per single-request endpoint; entries are (input, callback) pairs.
clip_img_q, clip_txt_q, beit3_txt_q, bigg_txt_q = (Queue() for _ in range(4))


def run_worker(q: Queue, fn):
    """Start a daemon thread that batches requests from `q` and runs `fn` on each batch.

    Each queue entry is an `(item, callback)` pair. The worker blocks for the first
    entry, then keeps collecting entries until either `MAX_BATCH` items are gathered
    or `MAX_WAIT` seconds have elapsed, calls `fn(batch)` once, and hands row `i` of
    the result to callback `i`.

    Error handling: if `fn` raises, or returns a different number of rows than the
    batch size, every callback of that batch receives the exception instance instead
    of a row, so waiting requests are released rather than hanging. A failing
    callback is logged and skipped. In both cases the worker keeps running.

    Args:
        q: queue of `(item, callback)` tuples.
        fn: callable taking a list of items and returning one result row per item.
    """
    def w():
        while True:
            x, cb = q.get()  # block until there is at least one request
            batch, cbs = [x], [cb]

            # Wait (up to MAX_WAIT in total) for more requests to fill the batch.
            # A timed get() is used because get_nowait() would give up the moment the
            # queue is empty and the wait window would never actually be honoured.
            deadline = time.monotonic() + MAX_WAIT
            while len(batch) < MAX_BATCH:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    x, cb = q.get(timeout=remaining)
                except Empty:
                    break
                batch.append(x)
                cbs.append(cb)

            try:
                rows = list(fn(batch))  # e.g. np.array (B, dim) -> B rows
                if len(rows) != len(cbs):
                    raise RuntimeError(
                        f"worker fn returned {len(rows)} rows for a batch of {len(cbs)}"
                    )
            except Exception as exc:
                # Release every waiting caller with the error instead of killing the thread.
                rows = [exc] * len(cbs)

            # split row-wise
            for o, cb in zip(rows, cbs):
                try:
                    cb(o)
                except Exception as cb_exc:
                    # e.g. the event loop behind the callback is already closed
                    print(f"run_worker: callback failed: {cb_exc!r}")

    Thread(target=w, daemon=True).start()


# The worker threads start with autograd enabled (grad mode is per-thread), so each
# encoder is wrapped in torch.no_grad() to avoid building a graph on every forward pass.

@torch.no_grad()
def _clip_encode_images(B):
    """B is list[tensor (1,C,H,W)] -> normalised CLIP image features as a NumPy array."""
    feats = clip_model.encode_image(torch.cat(B).to(CLIP_DEVICE))
    return F.normalize(feats, dim=-1).detach().cpu().numpy()


@torch.no_grad()
def _clip_encode_tokens(B):
    """B is list[token_tensor (1,seq)] -> normalised CLIP text features as a NumPy array."""
    feats = clip_model.encode_text(torch.cat(B).to(CLIP_DEVICE))
    return F.normalize(feats, dim=-1).detach().cpu().numpy()


@torch.no_grad()
def _bigg_encode_texts(B):
    """B is list[str] -> normalised bigG text features as a NumPy array."""
    feats = bigg_model.encode_text(bigg_tokenizer(B).to(BIGG_DEVICE))
    return F.normalize(feats, dim=-1).detach().cpu().numpy()


# Image: B is list[tensor (1,C,H,W)]
run_worker(clip_img_q, _clip_encode_images)

# CLIP text (SINGLE): B is list[token_tensor (1,seq)]
run_worker(clip_txt_q, _clip_encode_tokens)

# BEiT3 text (SINGLE): B is list[str]; beit3.encode_text already runs under no_grad
run_worker(
    beit3_txt_q,
    lambda B: beit3.encode_text(B),
)

# bigG text (SINGLE): B is list[str]
run_worker(bigg_txt_q, _bigg_encode_texts)


# =========================
# FASTAPI APP
# =========================
# Single application object; every endpoint below registers itself on it.
app = FastAPI()


# =========================
# REQUEST/RESPONSE MODELS
# =========================
# Request body of the single-text endpoints.
class TextReq(BaseModel):
    text: str  # the text to embed


# Request body of the */batch text endpoints.
class BatchTextReq(BaseModel):
    texts: List[str]  # texts to embed; an empty list yields an empty response


# Response of the single-request endpoints.
class EmbeddingResponse(BaseModel):
    model: str  # label of the encoder that produced the vector, e.g. "clip-text"
    embedding: List[float]  # the embedding vector as a flat list
    dimension: int  # length of `embedding`


# Response of the batch endpoints.
class BatchEmbeddingResponse(BaseModel):
    model: str  # label of the encoder that produced the vectors
    embeddings: List[List[float]]  # one vector per input text, in input order
    dimension: int  # length of each vector (0 when there are no embeddings)
    count: int  # number of vectors returned


# =========================
# SINGLE REQUEST ENDPOINTS
# =========================
async def _embed_via_queue(q: Queue, item, model_name: str) -> EmbeddingResponse:
    """Hand `item` to the worker behind `q`, await its row and wrap it in a response.

    Args:
        q: queue consumed by a batching worker; entries are `(item, callback)` pairs.
        item: the worker input for this request (tensor or string, depending on the queue).
        model_name: label placed in the `model` field of the response.

    Raises:
        HTTPException: status 500 if the worker delivered an exception instead of a vector.
    """
    loop = asyncio.get_running_loop()
    fut = loop.create_future()

    def _resolve(result):
        # The request may have been cancelled (client gone) before the worker finished;
        # setting a result on a finished future would raise InvalidStateError.
        if not fut.done():
            fut.set_result(result)

    # The callback runs in the worker thread, so hop back onto the event loop thread.
    q.put((item, lambda r: loop.call_soon_threadsafe(_resolve, r)))
    v = await fut  # 1D array (dim,)

    if isinstance(v, BaseException):
        raise HTTPException(status_code=500, detail=f"{model_name} inference failed: {v}")

    v = v.astype("float32")
    return EmbeddingResponse(
        model=model_name,
        embedding=v.tolist(),
        dimension=len(v),
    )


@app.post("/embedding/clip/image", response_model=EmbeddingResponse)
async def clip_img(file: UploadFile = File(...)):
    """Embed one uploaded image with CLIP; responds 400 if the upload is not a readable image."""
    raw = await file.read()
    try:
        pil = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {exc}") from exc
    x = clip_preprocess(pil).unsqueeze(0)  # (1,C,H,W)
    return await _embed_via_queue(clip_img_q, x, "clip-image")


@app.post("/embedding/clip/text", response_model=EmbeddingResponse)
async def clip_text(req: TextReq):
    """Embed one text with CLIP."""
    # tokenize single text -> (1,seq); kept on CPU because the worker moves the
    # concatenated batch to the GPU in one copy.
    tok = open_clip.tokenize([req.text])
    return await _embed_via_queue(clip_txt_q, tok, "clip-text")


@app.post("/embedding/beit3/text", response_model=EmbeddingResponse)
async def beit3_text(req: TextReq):
    """Embed one text with BEiT3."""
    return await _embed_via_queue(beit3_txt_q, req.text, "beit3-text")


@app.post("/embedding/bigg/text", response_model=EmbeddingResponse)
async def bigg_text(req: TextReq):
    """Embed one text with CLIP bigG."""
    return await _embed_via_queue(bigg_txt_q, req.text, "bigg-text")


# =========================
# BATCH ENDPOINTS
#  (DO NOT USE THE QUEUE, CALL THE MODEL DIRECTLY)
# =========================
# Serialises batch inference: requests run in executor threads, and without the lock
# several large batches could run at once and stack up GPU memory.
_batch_infer_lock = Lock()


def _encode_texts_chunked(texts: List[str], encode_chunk) -> np.ndarray:
    """Encode `texts` in slices of at most MAX_BATCH and return one float32 (B,dim) array.

    Runs in an executor thread. Grad mode is per-thread, so no_grad is entered here.
    """
    with _batch_infer_lock, torch.no_grad():
        parts = [
            np.asarray(encode_chunk(texts[start:start + MAX_BATCH]))
            for start in range(0, len(texts), MAX_BATCH)
        ]
    return np.concatenate(parts, axis=0).astype("float32")


async def _batch_text_response(model_name: str, texts: List[str], encode_chunk) -> BatchEmbeddingResponse:
    """Shared body of the batch endpoints.

    Args:
        model_name: label placed in the `model` field of the response.
        texts: texts to embed; an empty list returns an empty response without touching the model.
        encode_chunk: callable mapping a list of texts to a (len, dim) array.
    """
    if not texts:
        return BatchEmbeddingResponse(
            model=model_name,
            embeddings=[],
            dimension=0,
            count=0,
        )

    # Inference is blocking; running it on the event loop thread would stall every other
    # request (including the futures of the single-request endpoints) until it finishes.
    loop = asyncio.get_running_loop()
    v = await loop.run_in_executor(None, _encode_texts_chunked, texts, encode_chunk)  # (B,dim)

    embeddings = v.tolist()
    dim = v.shape[1] if v.size > 0 else 0
    return BatchEmbeddingResponse(
        model=model_name,
        embeddings=embeddings,
        dimension=dim,
        count=len(embeddings),
    )


def _clip_text_chunk(chunk: List[str]) -> np.ndarray:
    """Tokenise and encode one chunk of texts with CLIP, returning normalised features."""
    tok = open_clip.tokenize(chunk).to(CLIP_DEVICE)
    feats = F.normalize(clip_model.encode_text(tok), dim=-1)
    return feats.detach().cpu().numpy()


def _bigg_text_chunk(chunk: List[str]) -> np.ndarray:
    """Tokenise and encode one chunk of texts with CLIP bigG, returning normalised features."""
    tok = bigg_tokenizer(chunk).to(BIGG_DEVICE)
    feats = F.normalize(bigg_model.encode_text(tok), dim=-1)
    return feats.detach().cpu().numpy()


@app.post("/embedding/clip/text/batch", response_model=BatchEmbeddingResponse)
async def clip_text_batch(req: BatchTextReq):
    """
    Batch text embedding for CLIP.
    Calls the model directly (no queue) to avoid conflicts with the auto-batch workers.
    """
    return await _batch_text_response("clip-text", req.texts, _clip_text_chunk)


@app.post("/embedding/beit3/text/batch", response_model=BatchEmbeddingResponse)
async def beit3_text_batch(req: BatchTextReq):
    """
    Batch text embedding for BEiT3.
    Uses the encoder's own encode_text, which accepts a list of texts.
    """
    return await _batch_text_response(
        "beit3-text", req.texts, lambda chunk: beit3.encode_text(chunk)
    )


@app.post("/embedding/bigg/text/batch", response_model=BatchEmbeddingResponse)
async def bigg_text_batch(req: BatchTextReq):
    """
    Batch text embedding for CLIP bigG.
    Calls the model directly (no queue).
    """
    return await _batch_text_response("bigg-text", req.texts, _bigg_text_chunk)


# =========================
# HEALTH CHECK
# =========================
# Static liveness payload: reports the configured model names, their device
# placement and the auto-batch size. It does not run any model.
@app.get("/health")
async def health_check():
    model_names = {
        "clip": CLIP_MODEL_NAME,
        "beit3": "BEiT3 Large",
        "bigg": BIGG_MODEL_NAME,
    }
    placement = {
        "clip": CLIP_DEVICE,
        "beit3": BEIT3_DEVICE,
        "bigg": BIGG_DEVICE,
    }
    payload = {"status": "healthy"}
    payload["models"] = model_names
    payload["devices"] = placement
    payload["batch_support"] = True
    payload["max_batch_size"] = MAX_BATCH
    return payload


# =========================
# MAIN
# =========================
if __name__ == "__main__":
    # Startup banner, emitted as one write (same text, one line per entry).
    print(
        "\n".join(
            [
                "=" * 80,
                "üöÄ Video Retrieval Model Server",
                "=" * 80,
                "üì¶ Models loaded:",
                f"   - CLIP: {CLIP_MODEL_NAME} on {CLIP_DEVICE}",
                f"   - BEiT3: Large on {BEIT3_DEVICE}",
                f"   - CLIP bigG: {BIGG_MODEL_NAME} on {BIGG_DEVICE}",
                f"‚ö° Auto-batch (queue) for SINGLE endpoints, MAX_BATCH = {MAX_BATCH}",
                f"üåê Port: {PORT}",
                "=" * 80,
            ]
        )
    )

    # Open a public ngrok tunnel only when a token is configured.
    if NGROK_AUTH_TOKEN:
        ngrok.set_auth_token(NGROK_AUTH_TOKEN)
        public_url = ngrok.connect(PORT).public_url
        print(f"üîó Public URL: {public_url}", "=" * 80, sep="\n")

    # Blocks until the server is stopped.
    uvicorn.run(app, host="0.0.0.0", port=PORT)


In [None]:
# Debug helper (uncomment to run): list running Python processes to find a stale server PID.
# !ps aux | grep python

In [None]:
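# Debug helper (uncomment to run): force-kill a stale server process.
# 48 is a PID from an earlier session; replace it with the PID reported by `ps`.
# !kill -9 48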