In [None]:
import nmslib
import networkx as nx
import os
import torch
import numpy as np
from tqdm import tqdm
import h5py

In [None]:
class Hnsw:
    """Small fit/query wrapper around an nmslib index built with method='hnsw'.

    Parameters
    ----------
    space : str
        Name of the nmslib space passed to ``nmslib.init``.
    index_params : dict or None
        Parameters handed to ``createIndex``. ``None`` selects
        ``{'M': 16, 'post': 0, 'efConstruction': 400}`` at fit time.
    query_params : dict or None
        Parameters handed to ``setQueryTimeParams``. ``None`` selects
        ``{'ef': 90}`` at fit time.
    print_progress : bool
        Forwarded to ``createIndex``.

    Attributes set by ``fit``
    -------------------------
    index_ : the built nmslib index
    index_params_, query_params_ : the parameters actually used
    """

    def __init__(self, space='cosinesimil', index_params=None,
                 query_params=None, print_progress=True):
        self.space, self.print_progress = space, print_progress
        self.index_params, self.query_params = index_params, query_params

    def fit(self, X):
        """Build the index over the data points ``X`` and return ``self``."""
        # Resolve the defaults without touching the values given to __init__.
        build_cfg = (
            {'M': 16, 'post': 0, 'efConstruction': 400}
            if self.index_params is None
            else self.index_params
        )
        search_cfg = {'ef': 90} if self.query_params is None else self.query_params

        # The nmslib quickstart walks through this same sequence of calls:
        # https://nmslib.github.io/nmslib/quickstart.html
        ann = nmslib.init(space=self.space, method='hnsw')
        ann.addDataPointBatch(X)
        ann.createIndex(build_cfg, print_progress=self.print_progress)
        ann.setQueryTimeParams(search_cfg)

        self.index_, self.index_params_, self.query_params_ = ann, build_cfg, search_cfg
        return self

    def query(self, vector, topn):
        """Return the indices of the ``topn`` nearest neighbours of ``vector``."""
        # knnQuery yields (indices, distances); the distances are discarded.
        neighbour_ids, _distances = self.index_.knnQuery(vector, k=topn)
        return neighbour_ids

In [None]:
from torch_geometric.data import Data as geomData
from itertools import chain

def pt2graph(wsi_h5, slidename, radius=9, superpixel_dir=None, graph_dir=None):
    """Build a torch_geometric ``Data`` graph for one slide.

    Two edge sets are attached to the returned object:

    * ``edge_patch`` -- nearest-neighbour edges between patches, found with an
      ``Hnsw`` index in ``'l2'`` space over the patch coordinates and then
      restricted to edges whose two endpoints carry the same superpixel label.
    * ``edge_superpixel`` -- the edges of the graph loaded from
      ``<graph_dir>/<slidename>.pt``, stacked as a ``(2, E)`` tensor.

    Parameters
    ----------
    wsi_h5 : mapping
        Must provide ``'coords'`` and ``'features'`` entries with the same
        first dimension (the caller in this file passes an open ``h5py.File``).
    slidename : str
        File stem used to locate ``<slidename>.pth`` (superpixel records) and
        ``<slidename>.pt`` (superpixel graph).
    radius : int
        Number of neighbours requested per patch. The first hit of every query
        is dropped, so each patch contributes at most ``radius - 1`` edges.
    superpixel_dir, graph_dir : str or None
        Directories holding the ``.pth`` / ``.pt`` files. ``None`` keeps the
        paths that were previously hard-coded in this function.

    Returns
    -------
    torch_geometric.data.Data
        With fields ``x``, ``edge_patch``, ``edge_superpixel``,
        ``superpixel_attri`` and ``centroid``.

    Raises
    ------
    AssertionError
        If ``coords`` and ``features`` disagree on the number of patches.
    OSError
        If one of the per-slide files cannot be opened (raised by ``torch.load``).
    """
    # Imported locally so the function does not depend on the surrounding cell.
    from torch_geometric.data import Data as geomData
    from itertools import chain

    if superpixel_dir is None:
        superpixel_dir = '/data14/yanhe/miccai/super_pixel/slide_superpixel/argo_new/superpixel_num_500'
    if graph_dir is None:
        graph_dir = '/data14/yanhe/miccai/super_pixel/graph_file/argo/superpixel_num_500'

    coords, features = np.array(wsi_h5['coords']), np.array(wsi_h5['features'])
    assert coords.shape[0] == features.shape[0]
    num_patches = coords.shape[0]

    # NOTE: torch.load unpickles the file contents; only point these
    # directories at files you trust.
    superpixel_info = torch.load(os.path.join(superpixel_dir, slidename + '.pth'))
    # One label per record. The labels are later looked up with patch indices,
    # so record i is presumably the superpixel of patch i -- TODO confirm.
    superpixel_attri = torch.LongTensor(
        [superpixel_info[i]['superpixel'] for i in range(len(superpixel_info))]
    )

    # g is presumably a DGL graph (it exposes .edges()) -- TODO confirm.
    g = torch.load(os.path.join(graph_dir, slidename + '.pt'))
    # Built once: the previous per-node loop recomputed the same tensor on
    # every iteration and left it undefined for a graph without nodes.
    graph_edges = g.edges()
    edge_index = torch.stack((graph_edges[0], graph_edges[1]), dim=0)

    model = Hnsw(space='l2')
    model.fit(coords)
    # Drop the first hit of every query (presumably the query patch itself).
    neighbours = [model.query(coords[v_idx], topn=radius)[1:]
                  for v_idx in range(num_patches)]
    # Repeat each source index by the number of hits actually returned, so the
    # two rows stay aligned even when a query yields fewer than `radius` hits.
    hits_per_patch = np.asarray([len(hits) for hits in neighbours], dtype=np.int64)
    a = np.repeat(np.arange(num_patches, dtype=np.int64), hits_per_patch)
    b = np.fromiter(chain.from_iterable(neighbours), dtype=np.int64)
    # Convert the integer array directly; going through a float tensor first
    # would round indices above 2**24.
    edge_spatial = torch.from_numpy(np.stack([a, b])).long()

    # Keep only the spatial edges that stay inside one superpixel.
    superpixel_edge = superpixel_attri[edge_spatial]
    edge_mask = (superpixel_edge[0, :] == superpixel_edge[1, :])
    remain_edge_index = edge_spatial[:, edge_mask]

    G = geomData(x=torch.Tensor(features),
                 edge_patch=remain_edge_index,
                 edge_superpixel=edge_index,
                 superpixel_attri=superpixel_attri,
                 centroid=torch.Tensor(coords))
    return G

In [None]:
def createDir_h5toPyG(h5_path, save_path, skip_slides=None):
    """Convert every file in ``h5_path`` into a graph saved under ``save_path``.

    Each directory entry is opened with h5py, turned into a graph by
    ``pt2graph`` and written with ``torch.save`` as ``<slidename>.pt``, where
    ``slidename`` is the file name without its extension.

    Parameters
    ----------
    h5_path : str
        Directory whose entries are the per-slide h5 files.
    save_path : str
        Existing directory that receives the ``.pt`` files.
    skip_slides : iterable of str or None
        Slide names to leave out. ``None`` keeps the four slides that were
        previously excluded by a hard-coded condition.

    Notes
    -----
    Any ``OSError`` raised while handling a slide is caught and reported as
    'Broken' so the loop can continue. That covers an unreadable h5 file and
    also I/O failures inside ``pt2graph`` or ``torch.save``. Other exceptions
    propagate to the caller.
    """
    if skip_slides is None:
        skip_slides = ('ZS6Y1A01554_HE520', 'ZS6Y1A03883_HE208',
                       'ZS6Y1A07240_HE400', 'ZS6Y1A08318_HE155')
    skip_slides = frozenset(skip_slides)

    pbar = tqdm(os.listdir(h5_path))
    for h5_fname in pbar:
        # splitext instead of [:-3]: same result for '*.h5', and no mangled
        # name for any other extension.
        slidename = os.path.splitext(h5_fname)[0]
        pbar.set_description('%s - Creating Graph' % slidename)

        if slidename in skip_slides:
            # Excluded slides are not opened at all.
            continue

        try:
            # The context manager closes the file even when pt2graph raises;
            # the previous explicit close() was skipped on any exception.
            with h5py.File(os.path.join(h5_path, h5_fname), "r") as wsi_h5:
                G = pt2graph(wsi_h5, slidename)
            torch.save(G, os.path.join(save_path, slidename + '.pt'))
        except OSError:
            pbar.set_description('%s - Broken H5' % slidename)
            print(h5_fname, 'Broken')

In [None]:
# Input directory: every entry in it is opened with h5py by createDir_h5toPyG.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
h5_path = '/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files'
# Output directory: one '<slidename>.pt' graph file is written here per processed slide.
save_path = '/data14/yanhe/miccai/graph_file/argo/superpixel_num_500/'
# Create the output directory first; exist_ok=True means an existing directory is not an error.
os.makedirs(save_path, exist_ok=True)
# Build and save the graphs for all slides (progress is reported through tqdm).
createDir_h5toPyG(h5_path, save_path)