In [None]:
import requests
import json
from PIL import Image
from io import BytesIO
import base64
import numpy as np 
import cupy as cp
import sys
from cuvs.neighbors import hnsw,cagra
import os
import time
sys.path.append('/hpc2hdd/home/ysi538/retrieval')
from MDI_RAG_Image2Image_Research.src.utils.encoder import WSI_Image_UNI_Encoder

# Optional smoke test against the image-to-image retrieval HTTP service.
#
# This cell used to end with a dangling `"""`, which is an unterminated string
# literal and made the whole cell (imports included) fail with a SyntaxError.
# The dangling delimiter suggests the snippet had been disabled by wrapping it
# in a string, so it stays off by default behind an explicit flag: a fresh
# "Restart & Run All" must not depend on a service listening on localhost.
RUN_SERVICE_SMOKE_TEST = False

# url = '  /hpc/retrieval/query/image2image_retrieval'
url = 'http://localhost:9876/test'
query_img_path = " metaservice/api/region/openslide/241183-21.tiff/6400/25344/256/256/1"
request_data = {
    'query_img_path': query_img_path,
    'top_k': 20
}

if RUN_SERVICE_SMOKE_TEST:
    # A timeout keeps the notebook from hanging forever on an unresponsive service.
    response = requests.post(url, json=request_data, timeout=30)

    if response.status_code == 200:
        response_data = response.json()
        print("Response JSON:", response_data)
    else:
        print("Request failed with status code:", response.status_code)

In [None]:
def get_combined_regions(result_infos):
    """Group retrieved tiles that touch each other into connected regions.

    Every entry of ``result_infos`` is an underscore-separated string that is
    parsed purely by position: the first token is stored as ``id``, the last
    five tokens as ``x``, ``y``, ``w``, ``h`` and ``level`` (everything after
    the first ``.`` of the final token is dropped), and the tokens in between
    are re-joined with ``_`` as ``name``. All parsed values stay strings.

    Two tiles are linked when ``judge_if_connected`` says so. Components are
    collected with a depth-first traversal in pre-order, always trying
    candidate neighbours in input order.

    Args:
        result_infos: Iterable of info strings as described above.

    Returns:
        list[list[dict]]: One list per connected component that contains more
        than one tile; isolated tiles are discarded.

    Raises:
        IndexError: If an info string has fewer than five ``_``-separated tokens.
    """
    regions = []
    for info in result_infos:
        tokens = info.split("_")
        regions.append({
            "id": tokens[0],
            "name": "_".join(tokens[1:-5]),
            "x": tokens[-5],
            "y": tokens[-4],
            "w": tokens[-3],
            "h": tokens[-2],
            "level": tokens[-1].split(".")[0],
        })

    total = len(regions)
    visited = [False] * total

    def collect_component(start):
        # Iterative depth-first search. An explicit stack replaces recursion so
        # a long chain of adjacent tiles cannot hit Python's recursion limit.
        # Each stack frame keeps its own candidate iterator, which reproduces
        # the visiting order of the recursive formulation exactly.
        visited[start] = True
        component = [regions[start]]
        stack = [(start, iter(range(total)))]
        while stack:
            node, candidates = stack[-1]
            for neighbor in candidates:
                if not visited[neighbor] and judge_if_connected(regions[node], regions[neighbor]):
                    visited[neighbor] = True
                    component.append(regions[neighbor])
                    stack.append((neighbor, iter(range(total))))
                    break
            else:
                # Every candidate of this node has been examined.
                stack.pop()
        return component

    components = []
    for start in range(total):
        if not visited[start]:
            components.append(collect_component(start))

    # Only groups of at least two tiles count as a combined region.
    return [component for component in components if len(component) > 1]

def judge_if_connected(info1, info2):
    """Tell whether ``info2`` is one of the eight neighbours of ``info1``.

    Both arguments are dicts with the keys ``name``, ``level``, ``x``, ``y``,
    ``w`` and ``h``; the numeric fields are parsed with ``int``. Tiles with a
    different ``name`` or ``level`` are never connected. Otherwise ``info2``
    is connected when its origin is offset from ``info1``'s origin by exactly
    ``±w`` horizontally and/or ``±h`` vertically, where ``w`` and ``h`` are
    taken from ``info1`` only.

    Returns:
        bool: True for an edge or corner neighbour, False otherwise.
    """
    if info1["name"] != info2["name"] or info1["level"] != info2["level"]:
        return False

    x1, width, x2 = int(info1["x"]), int(info1["w"]), int(info2["x"])
    dx = x2 - x1
    beside = dx == width or dx == -width

    y1, y2 = int(info1["y"]), int(info2["y"])
    dy = y2 - y1
    if beside and dy == 0:
        # Left or right neighbour on the same row.
        return True

    height = int(info1["h"])
    stacked = dy == height or dy == -height
    if stacked and dx == 0:
        # Neighbour directly above or below in the same column.
        return True

    # Remaining possibility: one of the four diagonal neighbours.
    return beside and stacked

def request_image(query_img_path, timeout=30):
    """Load the query image from a URL or a local file and return it as RGB.

    Args:
        query_img_path: Either an ``http://`` / ``https://`` URL or a path on
            the local filesystem. Surrounding whitespace is ignored when
            deciding which of the two it is.
        timeout: Seconds to wait for the HTTP server before giving up. Only
            used for URLs.

    Returns:
        PIL.Image.Image: The decoded image converted to mode ``"RGB"``.

    Raises:
        requests.RequestException: On connection problems, timeouts or a
            non-success HTTP status.
    """
    location = query_img_path.strip()
    # Test the scheme prefix instead of `"http" in path`, which would also
    # match local files such as "/data/http_cache/tile.png".
    if location.lower().startswith(("http://", "https://")):
        # NOTE(security): verify=False disables TLS certificate validation. It
        # is kept because the image server may use a certificate the default
        # trust store rejects -- confirm, and re-enable verification if not.
        response = requests.get(location, verify=False, timeout=timeout)
        # Fail on 4xx/5xx here rather than letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(query_img_path).convert("RGB")

def split_img(query_image, other_info):
    """Cut a non-square image into square tiles whose side is the shorter edge.

    A square image is returned unchanged as a one-element list. Otherwise the
    image is tiled column by column (x outer, y inner). When the longer edge
    is not an exact multiple of the shorter one, the final tile is shifted
    back so that it ends at the image border and overlaps its predecessor;
    no crop box ever reaches outside the image.

    Args:
        query_image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``.
        other_info: Currently unused; kept for interface compatibility.

    Returns:
        list: The tiles returned by ``query_image.crop`` (or ``[query_image]``
        for a square input).

    Raises:
        ValueError: If exactly one of the image dimensions is zero.
    """
    width, height = query_image.size
    if width == height:
        return [query_image]

    side = min(width, height)
    if side <= 0:
        raise ValueError("cannot split an image with a zero-sized edge: %r" % ((width, height),))

    def tile_starts(extent):
        # Offsets of tiles that fit entirely inside `extent` ...
        starts = list(range(0, extent - side + 1, side))
        # ... plus one clamped tile covering any leftover strip at the end.
        if starts[-1] + side < extent:
            starts.append(extent - side)
        return starts

    return [
        query_image.crop((left, top, left + side, top + side))
        for left in tile_starts(width)
        for top in tile_starts(height)
    ]

def get_img_other_info(query_path):
    """Placeholder for extra per-image metadata; currently always returns None."""
    return None




def load_index_file(index_file, dim=1024, dtype=np.float32, metric="sqeuclidean"):
    """Load a serialized HNSW index through ``cuvs.neighbors.hnsw.load``.

    Args:
        index_file: Path of the serialized index.
        dim: Dimensionality of the indexed vectors. The default of 1024 is the
            value that used to be hard-coded here (presumably the encoder's
            embedding size -- confirm before changing the encoder).
        dtype: Element type of the indexed vectors.
        metric: Distance metric the index was built with.

    Returns:
        The index object returned by ``hnsw.load``.
    """
    return hnsw.load(index_file, dim, dtype, metric)

def load_info_file(info_file):
    """Read the JSON file holding the per-vector info entries.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default text encoding.

    Args:
        info_file: Path of the JSON file.

    Returns:
        The deserialized JSON content.
    """
    with open(info_file, "r", encoding="utf-8") as handle:
        return json.load(handle)



def search(index, encoder, infos_list, query_path, top_k=20):
    """Retrieve the nearest indexed tiles for a query image.

    The query image is loaded, cut into square tiles with ``split_img``, each
    tile is embedded with ``encoder.encode_image``, and all embeddings are sent
    to ``cagra.search`` in one batch. The hits of all tiles are flattened into
    a single list and sorted by ascending distance (ties fall back to the
    neighbour id, then the info entry).

    Note that the merged result is *not* cut back to ``top_k``: ``top_k`` is
    only the per-query count passed to ``cagra.search``, so a query split into
    several tiles presumably yields ``top_k`` hits per tile.

    Args:
        index: CAGRA index to search.
        encoder: Object with an ``encode_image(image)`` method returning the
            embedding of one image (assumed to be a 1-D vector -- TODO confirm).
        infos_list: Sequence mapping a neighbour id to its info entry.
        query_path: URL or local path understood by ``request_image``.
        top_k: Number of neighbours requested per query tile.

    Returns:
        tuple: ``(distances, neighbors, result_infos)``, three equally long
        tuples ordered by ascending distance. All three are empty when the
        search produced no hits.
    """
    query_image = request_image(query_path)
    img_other_info = get_img_other_info(query_path)
    split_imgs = split_img(query_image, img_other_info)

    split_embeddings = cp.array(
        [cp.array(encoder.encode_image(tile)).astype('float32') for tile in split_imgs]
    )
    # cagra.search is given a 2-D batch; promote a lone vector to one row.
    if split_embeddings.ndim == 1:
        split_embeddings = split_embeddings.reshape(1, -1)

    distances, neighbors = cagra.search(cagra.SearchParams(), index, split_embeddings, top_k)

    # Flatten the per-tile result rows into plain Python lists.
    neighbors = cp.asarray(neighbors).flatten().tolist()
    distances = cp.asarray(distances).flatten().tolist()
    result_infos = [infos_list[neighbor] for neighbor in neighbors]

    ranked = sorted(zip(distances, neighbors, result_infos))
    if not ranked:
        # zip(*[]) yields nothing to unpack; return empty results instead of raising.
        return (), (), ()

    total_distances, total_neighbors, result_infos = zip(*ranked)
    return total_distances, total_neighbors, result_infos






In [3]:
# Region used as the retrieval query in the cells below.
query_img_path = " metaservice/api/region/openslide/241183-21.tiff/6400/25344/512/256/1"

# Encoder that turns an image into the embedding used for the search.
image_encoder = WSI_Image_UNI_Encoder()

# Directory holding the prebuilt index, info and embedding files.
# NOTE(review): absolute, user-specific path -- consider making it configurable.
ssd_dir = "/hpc2ssd/JH_DATA/spooler/ysi538/"

# NOTE(review): `hnsw_index` is not referenced by any cell shown here; the
# search below runs against the CAGRA index instead. Confirm it is still needed.
hnsw_index = load_index_file(os.path.join(ssd_dir, "cupy_index_batch_0.bin"))
infos_list = load_info_file(os.path.join(ssd_dir, "cupy_infos_batch_0.json"))



cuda


In [None]:
# Load the stored embeddings, cast them to float32 on the host and copy the
# result onto the GPU as a float32 CuPy array.
# NOTE(review): the detour through NumPy looks redundant if cp.load already
# yields a device array -- confirm it is not a deliberate workaround before
# simplifying.
embeddings = cp.array(
    cp.asnumpy(cp.load(ssd_dir + "cupy_embeddings_batch_0.npy")).astype(np.float32),
    dtype=cp.float32,
)

# Build the CAGRA index (squared-Euclidean metric, NN-descent graph build).
build_params = cagra.IndexParams(metric="sqeuclidean", build_algo='nn_descent')
cuda_index = cagra.build(build_params, embeddings)

[I] [15:53:42.278367] optimizing graph
[I] [15:54:07.391248] Graph optimized, creating index


In [8]:


# Run the query against the CAGRA index, asking for 20 neighbours per tile.
total_distances, total_neighbors, result_infos = search(
    cuda_index,
    image_encoder,
    infos_list,
    query_img_path,
    top_k=20,
)

# Merge hits that are adjacent on the same slide and level into regions.
combined_regions = get_combined_regions(result_infos)



In [None]:
# Display the groups of adjacent hits computed in the previous cell. To look at
# the raw ranking instead, display `total_distances, total_neighbors, result_infos`.
combined_regions