In [72]:
from io import BytesIO
from PIL import Image
from tqdm import tqdm
from datetime import datetime
import openslide
import os, sys, json, asyncio, aiohttp, requests, logging
from openslide.deepzoom import DeepZoomGenerator
import hashlib
import numpy as np
import timm, torch

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import multiprocessing as mp

In [None]:
import os

# Raw WSI files. Their names appear to be "<prefix>_<original name>"; the
# original name is everything after the first underscore ('' if there is none).
wsi_list = os.listdir('/hpc2hdd/home/ysi538/retrieval/caches/wsi_image')
origin_list = [wsi.partition('_')[2] for wsi in wsi_list]



In [2]:
# Sanity check: one stripped name per WSI file.
tuple(map(len, (wsi_list, origin_list)))

(8341, 8341)

In [None]:

# Persist both name lists, one name per line and no trailing newline.
for out_name, names in (('wsi_list.txt', wsi_list), ('origin_list.txt', origin_list)):
    with open(out_name, 'w') as fh:
        fh.write('\n'.join(names))

In [10]:
_embedding_cache_dir = '/hpc2hdd/home/ysi538/retrieval/MDI_RAG_Image2Image_Research/data/embedding_cache'
# Entries currently present in the embedding cache (the folders the next cell renames).
fold_list = os.listdir(_embedding_cache_dir)

In [None]:

# Rename each embedding-cache folder from its prefix-stripped name (an entry of
# origin_list) back to the full WSI file name (the matching entry of wsi_list).
# Safe to re-run: folders with no matching WSI (e.g. already renamed) and
# renames whose destination already exists are skipped instead of raising.
_cache_dir = '/hpc2hdd/home/ysi538/retrieval/MDI_RAG_Image2Image_Research/data/embedding_cache'

# Build the lookup once (O(n)) instead of calling list.index per folder.
# setdefault keeps the FIRST wsi for a duplicated stripped name, matching list.index.
_origin_to_wsi = {}
for _origin, _wsi in zip(origin_list, wsi_list):
    _origin_to_wsi.setdefault(_origin, _wsi)

for fold in fold_list:
    target = _origin_to_wsi.get(fold)
    if target is None:
        print(f"Skipping {fold}: no matching WSI file (already renamed?)")
        continue
    dst = os.path.join(_cache_dir, target)
    if os.path.exists(dst):
        # Never overwrite an existing cache folder.
        print(f"Skipping {fold}: {dst} already exists")
        continue
    os.rename(os.path.join(_cache_dir, fold), dst)

In [73]:
class CustomWSIDataset(Dataset):
    """Map-style dataset over an in-memory sequence of images.

    Item ``idx`` is ``images[idx]`` passed through ``transform``; when
    ``transform`` is falsy (e.g. ``None``) the image is returned unchanged.
    """

    def __init__(self, images, transform):
        self.images, self.transform = images, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        item = self.images[idx]
        return self.transform(item) if self.transform else item
    

class WSIUNIEncoder():
    """Patch image encoder built on a timm ViT-L/16 with locally stored weights.

    ``num_classes=0`` removes the classifier head, so ``embed_model(batch)``
    returns one feature vector per image. Weights are read from
    ``pytorch_model.bin`` under ``local_dir`` (a UNI / DINOv2 checkpoint, going
    by the directory name).

    Attributes:
        embed_model: the ViT in eval mode, on ``self._device``.
        transform: per-image preprocessing (resize to 224x224, ``ToTensor``,
            normalisation with the standard ImageNet mean/std).
        _device: device string returned by ``infer_torch_device``.
    """

    def __init__(self, **kwargs):
        # **kwargs is accepted for call-site compatibility but is not used.
        self.embed_model = timm.create_model(
            "vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True
        )
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

        local_dir = "/hpc2hdd/home/ysi538/retrieval/checkpoints/vit_large_patch16_224.dinov2.uni_mass100k/"
        self._device = self.infer_torch_device()
        print(self._device)
        # NOTE(review): torch.load unpickles the file; only point it at trusted
        # checkpoints (newer torch versions accept weights_only=True for state dicts).
        state_dict = torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu")
        self.embed_model.load_state_dict(state_dict, strict=True)
        self.embed_model = self.embed_model.to(self._device)
        self.embed_model.eval()

    def infer_torch_device(self):
        """Return the preferred device string: "cuda", then "mps", else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps does not exist on older torch builds; look it up
        # defensively so those versions fall back to CPU instead of raising.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def encode_wsi_patch(self, wsi_name, dataloader):
        """Embed every batch from ``dataloader`` without gradient tracking.

        Args:
            wsi_name: label shown in the progress bar only.
            dataloader: iterable yielding batched image tensors (presumably the
                DataLoader over CustomWSIDataset built by Embedding_loader).

        Returns:
            Nested Python lists, one embedding per image in iteration order
            (batches concatenated along dim 0), or ``[]`` if nothing was yielded.
        """
        embeddings = []
        with torch.no_grad():
            for images in tqdm(dataloader, desc=f"WSI name: {wsi_name}", ascii=True):
                # non_blocking overlaps the copy with compute for pinned-memory
                # batches; the .cpu() below synchronises before results are used.
                images = images.to(self._device, non_blocking=True)
                embeddings.append(self.embed_model(images).cpu())

        if not embeddings:
            return []
        # Every chunk is already on the CPU, so no extra .cpu() is needed here.
        return torch.cat(embeddings, dim=0).tolist()

class Embedding_loader():
    """Encode in-memory WSI patch images and cache the embeddings as JSON.

    For each WSI, two files are written under ``<cache_path>/<wsi_name>/``:
    ``patch_info.json`` (``["patch_0", "patch_1", ...]``, one label per image)
    and ``embeddings.json`` (one embedding per image, same order). WSIs whose
    name is in ``loaded_embeddings`` are skipped.
    """

    def __init__(self):
        self.wsi_patch_encoder = WSIUNIEncoder()
        self.cache_path = "/hpc2hdd/home/ysi538/retrieval/MDI_RAG_Image2Image_Research/data/embedding_cache"
        # Cache entries present at construction; extended as new WSIs are saved.
        self.loaded_embeddings = os.listdir(self.cache_path)

    def _build_dataloader(self, images):
        """Wrap ``images`` in an ordered DataLoader applying the encoder's transform."""
        wsi_dataset = CustomWSIDataset(images, self.wsi_patch_encoder.transform)
        return DataLoader(wsi_dataset, batch_size=16, shuffle=False, num_workers=16, pin_memory=True)

    def _encode_and_save(self, wsi_name, patch_infos, dataloader):
        """Encode one WSI and write its JSON cache files.

        An empty ``patch_infos`` is what ``loading_worker`` sends for a WSI that
        is already cached. Nothing is written in that case; previously the
        existing cache files were overwritten with empty lists.
        """
        if not patch_infos:
            return
        patch_embeddings = self.wsi_patch_encoder.encode_wsi_patch(wsi_name, dataloader)

        dir_path = os.path.join(self.cache_path, wsi_name)
        os.makedirs(dir_path, exist_ok=True)
        with open(os.path.join(dir_path, "patch_info.json"), 'w') as file:
            json.dump(patch_infos, file)
        with open(os.path.join(dir_path, "embeddings.json"), 'w') as file:
            json.dump(patch_embeddings, file)
        self.loaded_embeddings.append(wsi_name)

    async def load_images(self, images):
        """Async hook for obtaining the images; currently returns them unchanged."""
        return images

    async def loading_wsi_image(self, wsi_name, images):
        """Build ``(patch_infos, dataloader)`` for one WSI's images (async API)."""
        patch_infos = [f"patch_{i}" for i in range(len(images))]
        images = await self.load_images(images)
        return patch_infos, self._build_dataloader(images)

    def loading_worker(self, input_queue, output_queue):
        """Queue worker: map ``(wsi_name, images)`` items to encoder inputs.

        Consumes ``input_queue`` until a ``None`` sentinel. It emits
        ``(wsi_name, [], [])`` for cached WSIs and
        ``(wsi_name, patch_infos, dataloader)`` otherwise. Kept for callers that
        drive their own queues; ``main`` does not use it.
        """
        while True:
            item = input_queue.get()
            if item is None:
                break

            wsi_name, images = item
            if wsi_name in self.loaded_embeddings:
                print(f"WSI {wsi_name} cached.")
                output_queue.put((wsi_name, [], []))
            else:
                # Synchronous path: asyncio.run() raises if the calling thread
                # already runs an event loop (e.g. a Jupyter kernel).
                patch_infos = [f"patch_{i}" for i in range(len(images))]
                output_queue.put((wsi_name, patch_infos, self._build_dataloader(images)))

    def encoding_worker(self, input_queue):
        """Queue worker: encode and save items until a ``None`` sentinel arrives."""
        while True:
            item = input_queue.get()
            if item is None:
                break
            wsi_name, patch_infos, dataloader = item
            self._encode_and_save(wsi_name, patch_infos, dataloader)

    def main(self, wsi_data_list):
        """Encode and cache every ``(wsi_name, images)`` pair in ``wsi_data_list``.

        Runs in the calling process. The earlier design encoded inside a forked
        ``mp.Process`` after the model had been moved to the GPU here. CUDA
        cannot be re-initialised in a forked child, and pickling every image
        through two queues gained nothing because the images are already in
        memory. Per-image transforms still run in the DataLoader's worker
        processes.
        """
        for wsi_name, images in wsi_data_list:
            if wsi_name in self.loaded_embeddings:
                print(f"WSI {wsi_name} cached.")
                continue
            patch_infos = [f"patch_{i}" for i in range(len(images))]
            self._encode_and_save(wsi_name, patch_infos, self._build_dataloader(images))

In [None]:
def read_region_with_rotate(slide, x, y, level, w, h, angle = 0):
    """Read a ``w`` x ``h`` patch at ``level``, rotated by ``angle`` degrees about its centre.

    Follows the OpenSlide ``read_region`` convention: ``(x, y)`` is the
    top-left of the unrotated patch in level-0 pixels, and ``(w, h)`` is its
    size in ``level`` pixels. A larger axis-aligned window that contains the
    rotated patch is read, rotated with ``Image.rotate(angle, expand=True)``,
    and centre-cropped back to ``(w, h)``.

    Args:
        slide: an OpenSlide-like object (``read_region``, ``level_downsamples``).
        x, y: level-0 coordinates of the unrotated patch's top-left corner.
        level: pyramid level to read from.
        w, h: output size in ``level`` pixels.
        angle: rotation in degrees (default 0, i.e. a plain read).

    Returns:
        The cropped image of size ``(w, h)``.
    """
    # Locations are in level-0 pixels but sizes are in `level` pixels, so any
    # offset derived from w/h must be scaled by the level's downsample factor.
    # Without this, patches at level > 0 with a non-zero angle were mis-centred.
    downsample = slide.level_downsamples[level]
    center_x = x + w * downsample / 2
    center_y = y + h * downsample / 2

    radians = np.deg2rad(angle)
    cos_r = abs(np.cos(radians))
    sin_r = abs(np.sin(radians))
    # Axis-aligned bounding box (in level pixels) of the rotated w x h rectangle.
    new_w = int(np.ceil(w * cos_r + h * sin_r))
    new_h = int(np.ceil(w * sin_r + h * cos_r))

    new_x = int(center_x - new_w * downsample / 2)
    new_y = int(center_y - new_h * downsample / 2)

    region = slide.read_region((new_x, new_y), level, (new_w, new_h))
    region = region.rotate(angle, expand=True)

    # Centre-crop the rotated window back to the requested size.
    region_x = int(region.size[0] / 2 - w / 2)
    region_y = int(region.size[1] / 2 - h / 2)
    return region.crop((region_x, region_y, region_x + w, region_y + h))

In [75]:

def file_md5(fileName):
    """Return the hexadecimal MD5 digest of the file at ``fileName``.

    The file is streamed in 1 MiB chunks, so memory use does not grow with file size.
    """
    digest = hashlib.md5()
    chunk_size = 1 << 20
    with open(fileName, "rb") as handle:
        # iter() with a b"" sentinel stops at end of file.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def get_slide_info(wsi_name, wsi_dir='/hpc2hdd/home/ysi538/retrieval/caches/wsi_image/'):
    """Open a WSI with OpenSlide and return it together with its properties.

    Args:
        wsi_name: slide file name inside ``wsi_dir``.
        wsi_dir: directory containing the slide files (defaults to the cluster cache).

    Returns:
        ``(slide, slide.properties)`` on success. On failure a dict is returned
        instead: ``{"error": "No such file"}`` if the file is missing, or
        ``{"type": "Openslide", "error": <message>}`` if it cannot be opened.
        Callers must check for a dict before unpacking.
    """
    filepath = os.path.join(wsi_dir, wsi_name)
    if not os.path.isfile(filepath):
        msg = {"error": "No such file"}
        print(msg)
        return msg
    print(f"Loading {filepath}")
    try:
        slide = openslide.OpenSlide(filepath)
    except Exception as error:
        # Was BaseException, which also swallowed KeyboardInterrupt/SystemExit.
        msg = {"type": "Openslide", "error": str(error)}
        print(msg)
        return msg
    return slide, slide.properties


def loading_wsi(wsi_name, slice_size=(256, 256), start_level=1):
    """Tile pyramid levels ``start_level`` .. last of a WSI into RGB patches.

    Args:
        wsi_name: slide file name, resolved by ``get_slide_info``.
        slice_size: ``(width, height)`` of every tile in pixels at its level.
        start_level: first pyramid level to tile (default 1, i.e. level 0 is skipped).

    Returns:
        ``(regions, captions)``: parallel lists of RGB PIL images and captions
        ``"{wsi_name}_{x}_{y}_{width}_{height}_{level}.png"``, where ``(x, y)``
        is the level-0 location passed to ``read_region`` for that tile.

    Raises:
        RuntimeError: if ``get_slide_info`` reports that the slide cannot be opened.
    """
    slide_result = get_slide_info(wsi_name)
    if isinstance(slide_result, dict):
        # get_slide_info reports failure by returning a dict; unpacking it
        # below would raise a confusing ValueError/AttributeError instead.
        raise RuntimeError(f"Could not open WSI {wsi_name}: {slide_result}")
    slide, slide_info = slide_result
    print(slide_info)

    tile_w, tile_h = slice_size
    regions = []
    captions = []
    try:
        num_level = int(slide_info.get('openslide.level-count', 1))
        tiles = []
        for level in range(start_level, num_level):
            width = int(slide_info.get(f"openslide.level[{level}].width"))
            height = int(slide_info.get(f"openslide.level[{level}].height"))
            # read_region takes its location in level-0 pixels, but this grid
            # walks level-local pixels. Scaling by the downsample makes the tiles
            # adjacent and covering the whole level, instead of overlapping in
            # the slide's top-left corner.
            downsample = slide.level_downsamples[level]
            for y in range(0, height, tile_h):
                for x in range(0, width, tile_w):
                    tiles.append((int(x * downsample), int(y * downsample), level))

        for x0, y0, level in tqdm(tiles):
            region = slide.read_region((x0, y0), level, (tile_w, tile_h))
            # OpenSlide returns RGBA images (per its API docs); the encoder's
            # 3-channel Normalize and the ViT need RGB input.
            regions.append(region.convert("RGB"))
            captions.append(wsi_name + f"_{x0}_{y0}_{tile_w}_{tile_h}_{level}.png")
    finally:
        # The tiles are independent images, so the slide handle can be released.
        slide.close()

    return regions, captions


In [None]:


file_path = "/hpc2hdd/home/ysi538/retrieval/MDI_RAG_Image2Image_Research/data/wsi_names.json"
with open(file_path, 'r', encoding='utf-8') as f:
    wsi_name_list = json.load(f)

# Build the loader (and load the ViT weights) once, not once per WSI.
emb_loader = Embedding_loader()
for wsi_name in wsi_name_list:
    # Check the cache BEFORE reading every tile of the slide into memory.
    if os.path.exists(os.path.join(emb_loader.cache_path, wsi_name)):
        print(f"WSI {wsi_name} cached.")
        continue
    regions, captions = loading_wsi(wsi_name)
    # Encode this WSI's regions and write them to the embedding cache.
    wsi_data_list = [(wsi_name, regions)]
    emb_loader.main(wsi_data_list)
            
                

In [None]:
# URL of the query image. The original value was left blank (a SyntaxError);
# supply it via the QUERY_IMG_PATH environment variable or edit this line.
query_img_path = os.environ.get("QUERY_IMG_PATH", "")
if not query_img_path:
    raise ValueError("Set query_img_path (or the QUERY_IMG_PATH environment variable) to the query image URL.")
# Without a timeout, requests.get can block forever on an unresponsive server.
response = requests.get(query_img_path, timeout=30)
response.raise_for_status()
img = response.content
    