In [3]:

import os
import json

import numpy as np
import faiss
import pickle

import torch
from tqdm import tqdm
import pickle
import json
import numpy as np
import os
from cuvs.neighbors import hnsw,cagra
import cupy as cp


def get_all_folders(path):
    all_folders = []
    for root, dirs, files in os.walk(path):
        if len(dirs) == 0:
            all_folders.append(root)
    return all_folders

root_path = '/hpc2hdd/home/ysi538/my_cuda_code/TCGA_slide_retrieval/'

In [4]:

all_sites = os.listdir(os.path.join(root_path, 'embed_cache'))
site_folders = []
for site in all_sites[:]:
    total_embeddings = []
    total_patch_info = []
    # print(site)
    site_path = os.path.join(root_path, 'embed_cache', site)
    site_folders = os.listdir(site_path)
    site_folders = [os.path.join(site_path, folder) for folder in site_folders]

    for folder in tqdm(site_folders):
        if os.path.isdir(folder):
            embeddings = torch.load(os.path.join(folder, 'embeddings')).to('cpu').numpy()
            patch_info = json.load(open(os.path.join(folder, 'patch_info.json')))

            total_embeddings.append(embeddings)
            total_patch_info.append(patch_info)
            


    total_embeddings = np.concatenate(total_embeddings, axis=0)
    total_patch_info = [item for sublist in total_patch_info for item in sublist]

    save_path = os.path.join(root_path, "total_for_site", site)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    np.save(os.path.join(save_path, 'total_embeddings.npy'), total_embeddings)
    with open(os.path.join(save_path, 'total_patch_info.json'), 'w') as f:
        json.dump(total_patch_info, f)
    print(f"Saved {site} total embeddings and patch info to {save_path}")
        
            



In [5]:

total_site = os.listdir(os.path.join(root_path, "total_for_site"))
total_site = [os.path.join(root_path, "total_for_site", site) for site in total_site]
for site in tqdm(total_site):

    total_embeddings = np.load(os.path.join(site, 'total_embeddings.npy'))

    with open(os.path.join(site, 'total_patch_info.json'), 'r') as f:
        total_patch_info = json.load(f)
    
    print(f"before total_embeddings shape: {total_embeddings.shape}")
    print(f"before total_patch_info length: {len(total_patch_info)}")

    foreground_mask = np.array([patch['is_foreground'] for patch in total_patch_info], dtype=bool)
    print(f"foreground_mask shape: {foreground_mask.shape}")
    
    total_embeddings = total_embeddings[foreground_mask]
    total_patch_info = [patch for patch in total_patch_info if patch['is_foreground'] == 1]

    print(f"after total_embeddings shape: {total_embeddings.shape}")
    print(f"after total_patch_info length: {len(total_patch_info)}")
    #  
    save_path = os.path.join(root_path, "total_for_site_foreground", site.split('/')[-1])
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if os.path.exists(os.path.join(save_path, 'total_embeddings.npy')):
        os.remove(os.path.join(save_path, 'total_embeddings.npy'))
    np.save(os.path.join(save_path, 'total_embeddings.npy'), total_embeddings)

    if os.path.exists(os.path.join(save_path, 'total_patch_info.json')):
        os.remove(os.path.join(save_path, 'total_patch_info.json'))
    with open(os.path.join(save_path, 'total_patch_info.json'), 'w') as f:
        json.dump(total_patch_info, f)
    

In [6]:
type(total_patch_info[0]['is_foreground'])

In [7]:
total_site = os.listdir(os.path.join(root_path, "total_for_site_foreground"))
total_site = [os.path.join(root_path, "total_for_site_foreground", site) for site in total_site]
for site in tqdm(total_site):

    total_embeddings = np.load(os.path.join(site, 'total_embeddings.npy'))

    with open(os.path.join(site, 'total_patch_info.json'), 'r') as f:
        total_patch_info = json.load(f)

    total_embeddings = cp.array(total_embeddings, dtype=cp.float32)


    print(f"site: {site}")
    print(f"total embeddings shape: {total_embeddings.shape}")
    print(f"total patch info length: {len(total_patch_info)}")
    
    build_params = cagra.IndexParams(metric="sqeuclidean",build_algo = 'nn_descent')
    cuda_index = cagra.build(build_params, total_embeddings)

    query_embeddings = cp.array(total_embeddings, dtype=cp.float32)
    k = 512
    distances, neighbors = cagra.search(cagra.SearchParams(itopk_size = k),cuda_index, query_embeddings, k)

    result_path = os.path.join(site, 'total_query_results.pkl')
    if os.path.exists(result_path):
          os.remove(result_path)


    with open(os.path.join(site, 'total_query_results.pkl'), 'wb') as f:
        pickle.dump((distances, neighbors), f)

    del cuda_index
    del total_embeddings
    del query_embeddings
    del distances
    del neighbors


In [None]:

import os
import pickle
import json
import cupy as cp
from tqdm import tqdm

def calculate_map5(query_results, ground_truth, neighbors, total_patch_info):
    ap_scores = []
    for i in range(len(query_results)):

        true_label = ground_truth[i]

        neighbor_indices = neighbors[i][:5]

        unique_neighbors = []
        seen_sub_ids = {total_patch_info[i]['sub_id']} 
        
        for idx in neighbor_indices:
            neighbor_idx = int(idx)
            patch_info = total_patch_info[neighbor_idx]
            if patch_info['sub_id'] not in seen_sub_ids:
                seen_sub_ids.add(patch_info['sub_id'])
                unique_neighbors.append(patch_info['subtype'])
                if len(unique_neighbors) == 5:
                    break

        if len(unique_neighbors) < 5:
            ap_scores.append(0)
            continue

        ap = 0.0
        correct_count = 0
        for k in range(5):
            if unique_neighbors[k] == true_label:
                correct_count += 1
                precision_at_k = correct_count / (k + 1)
                ap += precision_at_k
        
        if correct_count > 0:
            ap /= min(5, correct_count)
        else:
            ap = 0
        
        ap_scores.append(ap)

    map5 = sum(ap_scores) / len(ap_scores) if ap_scores else 0
    return map5

total_site = os.listdir(os.path.join(root_path, "total_for_site_foreground"))
total_site = [os.path.join(root_path, "total_for_site_foreground", site) for site in total_site]
dismiss = 0
for site in tqdm(total_site[:]):

    with open(os.path.join(site, 'total_query_results.pkl'), 'rb') as f:
        distances, neighbors = pickle.load(f)

    with open(os.path.join(site, 'total_patch_info.json'), 'r') as f:
        total_patch_info = json.load(f)

    query_results = []
    distances = cp.asarray(distances)
    neighbors = cp.asarray(neighbors)
    
    print(f"neighbors shape: {neighbors.shape}")
    print(f"distances shape: {distances.shape}")
    print(f"total_patch_info length: {len(total_patch_info)}")
    

    ground_truth = [patch['subtype'] for patch in total_patch_info]
    map5 = calculate_map5(query_results, ground_truth, neighbors, total_patch_info)

    for i in range(neighbors.shape[0]):
        neighbors_patch_info = []
        for j in range(neighbors.shape[1]):
            neighbor_index = int(neighbors[i][j])
            patch_info = total_patch_info[neighbor_index]
            if patch_info['sub_id'] != total_patch_info[i]['sub_id']:
                neighbors_patch_info.append(patch_info)
                if len(neighbors_patch_info) == 5:
                    break
        
        if len(neighbors_patch_info) < 5:
            query_results.append(total_patch_info[i]['subtype'])
            dismiss += 1
        else:
            neighbors_labels = [patch['subtype'] for patch in neighbors_patch_info]
            label = max(set(neighbors_labels), key=neighbors_labels.count)
            query_results.append(label)

    correct_count = {subtype: 0 for subtype in set(ground_truth)}
    incorrect_count = {subtype: 0 for subtype in set(ground_truth)}
    for i in range(len(query_results)):
        if query_results[i] == ground_truth[i]:
            correct_count[ground_truth[i]] += 1
        else:
            incorrect_count[ground_truth[i]] += 1
    
    total_count = {subtype: correct_count[subtype] + incorrect_count[subtype] for subtype in set(ground_truth)}

    with open(os.path.join(site, 'total_query_results.json'), 'w') as f:
        json.dump({
            'correct_count': correct_count, 
            'incorrect_count': incorrect_count, 
            'total_count': total_count,
            'map5': map5
        }, f)

    site_name = site.split('/')[-1]
    print(f"\nSite: {site_name}")
    print(f"MAP@5: {map5:.4f}")
    
    for subtype in set(ground_truth):
        if total_count[subtype] != 0:
            accuracy = correct_count[subtype] / total_count[subtype]
        else:
            accuracy = 0
        print(f"Subtype: {subtype}, Correct: {correct_count[subtype]}, Incorrect: {incorrect_count[subtype]}, Total: {total_count[subtype]}, Accuracy: {accuracy:.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

neighbors shape: (1778525, 512)
distances shape: (1778525, 512)
total_patch_info length: 1778525


 10%|█         | 1/10 [30:15<4:32:23, 1816.00s/it]

site: brain, subtype: LGG, correct: 581514, incorrect: 163314, total: 744828, accuracy: 0.7807359551466916
site: brain, subtype: GBM, correct: 871046, incorrect: 162651, total: 1033697, accuracy: 0.8426511830836309
neighbors shape: (600261, 512)
distances shape: (600261, 512)
total_patch_info length: 600261


 20%|██        | 2/10 [41:12<2:31:11, 1133.93s/it]

site: liver, subtype: CHOL, correct: 12996, incorrect: 28531, total: 41527, accuracy: 0.31295301851807256
site: liver, subtype: PAAD, correct: 171193, incorrect: 33109, total: 204302, accuracy: 0.8379408914254388
site: liver, subtype: LIHC, correct: 321102, incorrect: 33330, total: 354432, accuracy: 0.9059622156013001
neighbors shape: (855157, 512)
distances shape: (855157, 512)
total_patch_info length: 855157


 30%|███       | 3/10 [57:28<2:03:53, 1061.88s/it]

site: endocrine, subtype: PCPG, correct: 129947, incorrect: 43721, total: 173668, accuracy: 0.7482495335928323
site: endocrine, subtype: THCA, correct: 435138, incorrect: 36015, total: 471153, accuracy: 0.9235598627197534
site: endocrine, subtype: ACC, correct: 162967, incorrect: 47369, total: 210336, accuracy: 0.7747936634717785
neighbors shape: (1616260, 512)
distances shape: (1616260, 512)
total_patch_info length: 1616260


 40%|████      | 4/10 [1:09:48<1:33:27, 934.64s/it]

site: gastrointestinal, subtype: COAD, correct: 392005, incorrect: 236919, total: 628924, accuracy: 0.6232947065146186
site: gastrointestinal, subtype: STAD, correct: 433029, incorrect: 196610, total: 629639, accuracy: 0.6877417059616701
site: gastrointestinal, subtype: ESCA, correct: 61334, incorrect: 79039, total: 140373, accuracy: 0.43693587798223305
site: gastrointestinal, subtype: READ, correct: 46251, incorrect: 171073, total: 217324, accuracy: 0.21282048922346358
neighbors shape: (2090337, 512)
distances shape: (2090337, 512)
total_patch_info length: 2090337


In [19]:
patch_info

{'position': [112, 592], 'subtype': 'LGG', 'level': 'top', 'is_foreground': 1}

In [8]:
i,neighbors[i],len(total_patch_info)

(0,
 array([  82023, 1061050, 3083301, 1801675, 1142086], dtype=uint32),
 1778525)

In [4]:
len(query_results)

0

In [5]:
len(ground_truth)

NameError: name 'ground_truth' is not defined