In [1]:
### System
import os
import h5py
from tqdm import tqdm

### Numerical Packages
import numpy as np

### Graph Network Packages
import nmslib
import networkx as nx

### PyTorch / PyG
import torch
import torch_geometric
from torch_geometric import utils
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Hnsw:
    def __init__(self, space='cosinesimil', index_params=None,
                 query_params=None, print_progress=True):
        self.space = space
        self.index_params = index_params
        self.query_params = query_params
        self.print_progress = print_progress

    def fit(self, X):
        index_params = self.index_params
        if index_params is None:
            index_params = {'M': 16, 'post': 0, 'efConstruction': 400}

        query_params = self.query_params
        if query_params is None:
            query_params = {'ef': 90}

        # this is the actual nmslib part, hopefully the syntax should
        # be pretty readable, the documentation also has a more verbiage
        # introduction: https://nmslib.github.io/nmslib/quickstart.html
        index = nmslib.init(space=self.space, method='hnsw')
        index.addDataPointBatch(X)
        index.createIndex(index_params, print_progress=self.print_progress)
        index.setQueryTimeParams(query_params)

        self.index_ = index
        self.index_params_ = index_params
        self.query_params_ = query_params
        return self

    def query(self, vector, topn):
        # the knnQuery returns indices and corresponding distance
        # we will throw the distance away for now
        indices, dist = self.index_.knnQuery(vector, k=topn)
        return indices


In [3]:
def cosine_similarity(vec1, vec2):
    dot_product = np.dot(vec1, vec2)
    norm_vec1 = np.linalg.norm(vec1)
    norm_vec2 = np.linalg.norm(vec2)
    if norm_vec1 == 0 or norm_vec2 == 0:
        return 0
    else:
        return dot_product / (norm_vec1 * norm_vec2)

In [4]:
def pt2graph(wsi_h5, radius=9):
    from torch_geometric.data import Data as geomData
    from itertools import chain
    coords, features = np.array(wsi_h5['coords']), np.array(wsi_h5['features'])
    assert coords.shape[0] == features.shape[0]
    num_patches = coords.shape[0]
    
    model = Hnsw(space='l2')
    model.fit(coords)
    a = np.repeat(range(num_patches), radius-1) # [0, 0, 1, 1, 2, 2, ...]
    b = np.fromiter(chain(*[model.query(coords[v_idx], topn=radius)[1:] for v_idx in range(num_patches)]),dtype=int)
    edge_spatial = torch.Tensor(np.stack([a,b])).type(torch.LongTensor) # np.array([[0, 0, 1, 1, 2, 2, ...], [1, 2, 0, 2, 0, 1, ...]]), shape=(2, num_patches*(radius-1)), edge_spaital[:, i] = [a, b] for edge i, a->b

    edge_weight = []
    for edge in edge_spatial.t().tolist():
        from_idx, to_idx = edge
        from_feat, to_feat = features[from_idx], features[to_idx]
        w = 1 - cosine_similarity(from_feat, to_feat) # [-1, 1] -> [0, 2]
        edge_weight.append(w)
    edge_weight = torch.Tensor(edge_weight).type(torch.FloatTensor)
    
    model = Hnsw(space='l2')
    model.fit(features)
    a = np.repeat(range(num_patches), radius-1)
    b = np.fromiter(chain(*[model.query(features[v_idx], topn=radius)[1:] for v_idx in range(num_patches)]),dtype=int)
    edge_latent = torch.Tensor(np.stack([a,b])).type(torch.LongTensor)

    G = geomData(x = torch.Tensor(features),
                 edge_index = edge_spatial,
                 edge_latent = edge_latent,
                 edge_weight = edge_weight,
                 centroid = torch.Tensor(coords))
    
    g = utils.to_networkx(G, to_undirected=True, edge_attrs=['edge_weight'])
    T = nx.minimum_spanning_tree(g, weight='edge_weight')

    dfs_postorder_index = list(nx.dfs_postorder_nodes(T)) 
    dfs_preorder_index = list(nx.dfs_preorder_nodes(T))

    bfs_levelorder_index = []
    for connected_component in nx.connected_components(T):
        T_sub = T.subgraph(connected_component)
        source_node = random.choice(list(T_sub.nodes()))
        levelorder_index = [source_node] + [t for (s,t) in nx.bfs_edges(T_sub, source_node)]
        bfs_levelorder_index.extend(levelorder_index)
        
    assert len(dfs_postorder_index) == len(dfs_preorder_index) == len(bfs_levelorder_index) == G.num_nodes, f'post_order: {len(dfs_postorder_index)}, pre_order: {len(dfs_preorder_index)}, level_order: {len(bfs_levelorder_index)}, num_nodes: {G.num_nodes}, num_patches: {num_patches}'

    G.dfs_postorder_index = torch.LongTensor(dfs_postorder_index)
    G.dfs_preorder_index = torch.LongTensor(dfs_preorder_index)
    G.bfs_levelorder_index = torch.LongTensor(bfs_levelorder_index)
    
    return G
    


In [5]:
def createDir_h5toPyG(h5_path, save_path):
    pbar = tqdm(os.listdir(h5_path))
    for h5_fname in pbar:
        pbar.set_description('%s - Creating Graph' % (h5_fname[:-3]))
        if os.path.exists( os.path.join(save_path, h5_fname[:-3]+'.pt') ):
            print(str(h5_fname[:-3]) + 'existed')
            continue

        try:
            wsi_h5 = h5py.File(os.path.join(h5_path, h5_fname), "r")
            G = pt2graph(wsi_h5)
            torch.save(G, os.path.join(save_path, h5_fname[:-3]+'.pt'))
            wsi_h5.close()
        except OSError:
            pbar.set_description('%s - Broken H5' % (h5_fname[:12]))
            print(h5_fname, 'Broken')


In [6]:
h5_path = './data/BRACS/BRACS_512_at_level0/h5_files'
save_path = './data/BRACS/BRACS_512_at_level0/PyG_files'
if not os.path.exists(save_path):
    os.makedirs(save_path)
createDir_h5toPyG(h5_path, save_path)

case_radboud_0036 - Creating Graph:   0%|          | 0/508 [00:00<?, ?it/s]
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
case_radboud_0651 - Creating Graph:   0%|          | 1/508 [00:08<1:13:49,  8.74s/it]
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************

0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
***************************************************
case_radboud_0226 - Creating Graph:   0%|          | 2/508 [01:17<6:10:17, 43.91s/it]
0%   10   20   30   40   50   60   70   80   90   100%
|----|----|----|----|----|----|----|----|----|----|
*******