In [1]:
import openslide
import numpy as np
import matplotlib.pyplot as plt
import glob
import cv2
import tqdm

In [2]:
slides = glob.glob('/mnt/s3/lhm/HCC/*/*.mrxs')
len(slides)

128

In [3]:
import concurrent.futures
from tqdm import tqdm
from multiprocessing import Pool, Pipe, freeze_support

#=============================================================#
# 接口                                                        #
#-------------------------------------------------------------#
#   multi_process_exec 多进程执行                             #
#   multi_thread_exec  多线程执行                             #
#-------------------------------------------------------------#
# 参数：                                                      #
#   f         (function): 批量执行的函数                      #
#   args_mat  (list)    : 批量执行的参数                      #
#   pool_size (int)     : 进程/线程池的大小                   #
#   desc      (str)     : 进度条的描述文字                    #
#-------------------------------------------------------------#
# 例子：                                                      #
# >>> def Pow(a,n):        ← 定义一个函数（可以有多个参数）   #
# ...     return a**n                                         #
# >>>                                                         #
# >>> args_mat=[[2,1],     ← 批量计算 Pow(2,1)                #
# ...           [2,2],                Pow(2,2)                #
# ...           [2,3],                Pow(2,3)                #
# ...           [2,4],                Pow(2,4)                #
# ...           [2,5],                Pow(2,5)                #
# ...           [2,6]]                Pow(2,6)                #
# >>>                                                         #
# >>> results=multi_thread_exec(Pow,args_mat,desc='计算中')   #
# 计算中: 100%|█████████████| 6/6 [00:00<00:00, 20610.83it/s] #
# >>>                                                         #
# >>> print(results)                                          #
# [2, 4, 8, 16, 32, 64]                                       #
#-------------------------------------------------------------#

ToBatch = lambda arr, size: [arr[i * size:(i + 1) * size] for i in range((size - 1 + len(arr)) // size)]


def batch_exec(f, args_batch, w):
    results = []
    for i, args in enumerate(args_batch):
        try:
            if isinstance(args, (list, tuple, dict)):
                ans = f(*args)
            else:
                ans = f(args)
            results.append(ans)
        except Exception as e:
            print(e)
            results.append(None)
        w.send(1)
    return results


def multi_process_exec(f, args_mat, pool_size=5, desc=None):
    if len(args_mat) == 0: return []
    batch_size = max(1, int(len(args_mat) / 4 / pool_size))
    results = []
    args_batches = ToBatch(args_mat, batch_size)
    with tqdm(total=len(args_mat), desc=desc) as pbar:
        with Pool(processes=pool_size) as pool:
            r, w = Pipe(duplex=False)
            pool_rets = []
            for i, args_batch in enumerate(args_batches):
                pool_rets.append(pool.apply_async(batch_exec, (f, args_batch, w)))
            cnt = 0
            while cnt < len(args_mat):
                try:
                    msg = r.recv()
                    pbar.update(1)
                    cnt += 1
                except EOFError:
                    print('EOFError')
                    break
            for ret in pool_rets:
                for r in ret.get():
                    results.append(r)
    return results


def multi_thread_exec(f, args_mat, pool_size=5, desc=None):
    if len(args_mat) == 0: return []
    results = [None for _ in range(len(args_mat))]
    with tqdm(total=len(args_mat), desc=desc) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as executor:
            futures = {executor.submit(f, *args): i for i, args in enumerate(args_mat)}
            for future in concurrent.futures.as_completed(futures):
                i = futures[future]
                ret = future.result()
                results[i] = ret
                pbar.update(1)
    return results

In [4]:
import multiprocessing
import os
import time
from itertools import repeat

import networkx as nx
import numpy as np
import torch
from skimage.segmentation import slic, mark_boundaries
from torch_geometric.data import Data

# tqdm for progress bar
from tqdm.auto import tqdm
import time


def convert_numpy_img_to_superpixel_graph(img, code, polygons, slic_kwargs={}):
    # img  = cv2.resize(img,(img.shape[1]//3,img.shape[0]//3))
    height = img.shape[0]
    width = img.shape[1]
    n = 1024
    hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_white = np.array([0, 0, 220])  # 假设threshold是一个你选择的值
    upper_white = np.array([180, 255, 255])
    lower_black = np.array([0, 0, 0])
    upper_black = np.array([180, 255, 50])
    # 创建掩模来分离白色和黑色背景
    mask1 = cv2.inRange(hsv_image, lower_white, upper_white)
    mask2 = cv2.inRange(hsv_image, lower_black, upper_black)
    # 使用 cv2.bitwise_or 来合并掩模
    mask = 255 - cv2.bitwise_or(mask1, mask2)
    kernel = np.ones((10, 10), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    mask = cv2.erode(mask, kernel, iterations=2)
    #缩小图像
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    max_area = 0
    max_contour = None
    xm = 0
    ym = 0
    height, width = img.shape[:2]
    # 遍历轮廓，找到面积最大的外接矩形
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        area = w * h
        if area > max_area:
            max_area = area
            max_contour = contour

    if max_contour is not None:
        xm, ym, w, h = cv2.boundingRect(max_contour)
        mask = mask[ym:ym + h, xm:xm + w]
        hsv_image = hsv_image[ym:ym + h, xm:xm + w, :]
        img_org = img.copy()
        img = img[ym:ym + h, xm:xm + w, :]
        height, width = h, w
    mask = mask / 255
    segments = slic(img, n_segments=n, slic_zero=True, compactness=10, start_label=0,
                    enforce_connectivity=True, convert2lab=True, sigma=0.7,
                    mask=mask, **slic_kwargs)
    np.save(f"/mnt/s3/lhm/HCC_seg_1/{code}.npy", segments)
    num_of_nodes = np.max(segments) + 1
    nodes = {node: {"rgb_list": [], "r": [], "g": [], "b": [], } for node in range(num_of_nodes)}
    # get rgb values and positions
    for y in range(height):
        for x in range(width):
            node = segments[y, x]
            if node < 0:
                continue
            rgb = img[y, x, :]
            nodes[node]["r"].append(rgb[2])
            nodes[node]["g"].append(rgb[1])
            nodes[node]["b"].append(rgb[0])
    for node in nodes:
        r_bin = np.bincount(nodes[node]["r"])
        r_bin = np.pad(r_bin, (0, 256 - len(r_bin)), 'constant', constant_values=(0, 0))
        g_bin = np.bincount(nodes[node]["g"])
        g_bin = np.pad(g_bin, (0, 256 - len(g_bin)), 'constant', constant_values=(0, 0))
        b_bin = np.bincount(nodes[node]["b"])
        b_bin = np.pad(b_bin, (0, 256 - len(b_bin)), 'constant', constant_values=(0, 0))
        nodes[node]["rgb_list"] = np.stack([r_bin, g_bin, b_bin]).ravel()
    G = nx.Graph()
    # compute node positions
    segments_ids = np.unique(segments)
    segments_ids = np.delete(segments_ids, np.where(segments_ids == -1))
    pos = np.array([np.mean(np.nonzero(segments == i), axis=1) for i in segments_ids])
    pos[:, 0] += ym
    pos[:, 1] += xm
    pos = pos * 64
    pos = pos.astype(int)
    #pos[0]为height_y pos[1]为width_x
    for node in nodes:
        feature = nodes[node]['rgb_list']
        label = False
        for p in polygons:
            p = np.array(p, dtype=np.int32)
            label = cv2.pointPolygonTest(p, (int(pos[node][1]), int(pos[node][0])), True) > 0
        G.add_node(node, features=feature, label=label)
    # add edges
    vs_right = np.vstack([segments[:, :-1].ravel(), segments[:, 1:].ravel()])
    vs_below = np.vstack([segments[:-1, :].ravel(), segments[1:, :].ravel()])
    bneighbors = np.unique(np.hstack([vs_right, vs_below]), axis=1)
    for i in range(bneighbors.shape[1]):
        if bneighbors[0, i] == -1 or bneighbors[1, i] == -1:
            continue
        if bneighbors[0, i] != bneighbors[1, i]:
            G.add_edge(bneighbors[0, i], bneighbors[1, i])
    # add self loops
    for node in nodes:
        G.add_edge(node, node)

    # get edge_index
    m = len(G.edges)
    edge_index = np.zeros([2 * m, 2]).astype(np.int64)
    for e, (s, t) in enumerate(G.edges):
        edge_index[e, 0] = s
        edge_index[e, 1] = t
        edge_index[m + e, 0] = t
        edge_index[m + e, 1] = s
    # get features
    num_of_features = 768
    x = np.zeros([1024, num_of_features]).astype(np.float32)
    y = np.zeros(1024).astype(np.float32)
    for node in G.nodes:
        if node >= 1024:
            continue
        x[node] = G.nodes[node]["features"]
        y[node] = G.nodes[node]["label"]
    return x, y, edge_index, pos, [ym, xm], [height, width]

In [6]:
from utils.utils_single import read_points_from_xml


def process_img(i):
    img = cv2.imread(i)
    name = i.split('/')[-1].split('.')[0]
    label = i.split('/')[-2]
    if 'M' in name and name[-1] == '-':
        name = name[:-1]
    polygons = []
    if os.path.exists('/mnt/s3/lhm/HCC' + '/xml/' + f'{name}_Annotations.xml'):
        polygons = read_points_from_xml(liver_name=f'{name}_Annotations.xml', scale=1,
                                        xml_path='/mnt/s3/lhm/HCC' + '/xml/',
                                        dataset='HCC_LOWER')
    x, y, edge_index, pos, offset, size = convert_numpy_img_to_superpixel_graph(img, name, polygons)
    res = dict({})
    res['x'] = x
    res['edge'] = edge_index
    res['pos'] = pos
    res['code'] = name
    res['y'] = label
    res['nodey'] = y
    res['offset'] = offset
    res['size'] = size
    np.save(f'/mnt/s3/lhm/HCC_level_1/{name}.npy', res, allow_pickle=True)

# process_img('/mnt/storage/lhm/HCC_thumb/0/6M01.png')

In [61]:
args_mat = glob.glob('/mnt/s3/lhm/HCC_thumb/*/*.png')
for i in args_mat:
    process_img(i)

(4267, 1830, 3)


KeyboardInterrupt: 

In [7]:
args_mat = glob.glob('/mnt/s3/lhm/HCC_thumb/*/*.png')
multi_process_exec(process_img, args_mat, 32, desc='processing')

processing:   0%|          | 0/128 [00:00<?, ?it/s]



[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [64]:
from collections import Counter

a = np.load('/mnt/s3/lhm/HCC_level_1/201251654.npy', allow_pickle=True).item()
print(a)
print(Counter(a['nodey']))

{'x': array([[ 0.,  0.,  0., ..., 16., 15., 26.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  1.,  0.,  0.],
       ...,
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32), 'edge': array([[   0,   11],
       [   0,   12],
       [   0,   20],
       ...,
       [1019, 1019],
       [1020, 1020],
       [1021, 1021]]), 'pos': array([[ 65287,  61301],
       [ 65847,  93817],
       [ 66340,  91014],
       ...,
       [176245,  90067],
       [177351,  89224],
       [177408,  80576]]), 'code': '201251654', 'y': '2', 'nodey': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'offset': [988, 570], 'size': (4267, 1830)}
Counter({1.0: 716, 0.0: 308})


In [46]:
process_img('/mnt/s3/lhm/HCC_thumb/0/6M01.png')

{'x': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), 'edge': array([[   0,    1],
       [   0,    3],
       [   0,    5],
       ...,
       [1023, 1022],
       [1022, 1022],
       [1023, 1023]]), 'pos': array([[ 59897,  77674],
       [ 60763,  74870],
       [ 61719,  71800],
       ...,
       [156520,  20019],
       [157256,  17705],
       [157927,  15318]]), 'code': '6M0', 'y': '0', 'nodey': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), 'offset': [908, 140]}
