In [1]:
import os
import shutil
import random
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
import cv2
from tqdm.auto import tqdm
from collections import deque, defaultdict
import joblib


def get_nodes(im: np.array):
    """Mark the endpoint and junction pixels of a binary skeleton.

    Every pixel receives ``100 * im[pixel] + sum(im over its 8 neighbours)``,
    computed on a zero-padded accumulator so border pixels simply have fewer
    neighbours. A foreground pixel with exactly one foreground neighbour
    (score 101) or with three or more (score >= 103) is marked 255; every
    other pixel is 0.

    Args:
        im: 2-D image expected to contain only 0/1 values (the weighting
            scheme relies on that).

    Returns:
        Integer array of the same shape as ``im`` with values 0 or 255.
    """
    height, width = im.shape
    acc = np.zeros((height + 2, width + 2), dtype=int)

    # Drop a shifted copy of the image at each of the 9 offsets; the centre
    # copy is weighted by 100 so the pixel's own value is separable from the
    # neighbour count.
    for r, c in [(r, c) for r in range(3) for c in range(3)]:
        weight = 100 if (r, c) == (1, 1) else 1
        acc[r:r + height, c:c + width] += im * weight

    core = acc[1:height + 1, 1:width + 1]
    is_junction = core >= 103   # on-pixel with >= 3 on-neighbours
    is_endpoint = core == 101   # on-pixel with exactly 1 on-neighbour
    return is_junction * 255 + is_endpoint * 255


def get_nodepoints(im_nodes: np.array):
    """Group node pixels (value 255) into 8-connected clusters.

    The image is scanned row by row; each unvisited 255-pixel seeds a
    breadth-first flood fill over 8-connected 255-pixels. Clusters are
    numbered 0, 1, 2, ... in the order their first pixel is met.

    Args:
        im_nodes: 2-D array where node pixels equal 255 (as produced by
            ``get_nodes``).

    Returns:
        ``defaultdict(list)`` mapping cluster id -> list of ``[h, w]`` pixel
        coordinates, in BFS discovery order.
    """
    height, width = im_nodes.shape
    clusters = defaultdict(list)
    visited = np.zeros_like(im_nodes, dtype=int)
    neighbour_offsets = [
        (dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)
    ]

    label = 0
    for y in range(height):
        for x in range(width):
            if visited[y, x] == 1:
                continue
            visited[y, x] = 1
            if im_nodes[y, x] != 255:
                continue

            members = [[y, x]]
            frontier = deque([[y, x]])
            while frontier:
                cy, cx = frontier.popleft()
                for dy, dx in neighbour_offsets:
                    ny, nx = cy + dy, cx + dx
                    # Bounds first: negative indices would wrap around in numpy.
                    inside = 0 <= ny < height and 0 <= nx < width
                    if inside and visited[ny, nx] == 0 and im_nodes[ny, nx] == 255:
                        members.append([ny, nx])
                        visited[ny, nx] = 1
                        frontier.append([ny, nx])
            clusters[label] = members
            label += 1
    return clusters


def get_connectnodes(im: np.array, nodes: dict):
    """Find, for every node cluster, which other node clusters it is wired to.

    From the pixels of one cluster, a breadth-first search walks along
    foreground pixels (``im == 1``) with 8-connectivity. When the walk steps
    onto a pixel belonging to another cluster, that cluster's id is recorded
    and the walk does not continue through it; ordinary foreground pixels are
    expanded further.

    Args:
        im: 2-D binary skeleton image, foreground pixels equal to 1.
        nodes: mapping ``cluster_id -> [[h, w], ...]`` as produced by
            ``get_nodepoints``; ids are expected to be ``0 .. len(nodes) - 1``
            because they index the output list.

    Returns:
        List whose entry ``i`` holds the ids of clusters reached from cluster
        ``i``. An id can appear more than once when it is reached through
        several distinct pixels or paths.
    """
    H, W = im.shape

    # Label each node pixel with its cluster id; -1 means "not a node pixel".
    im_nodes_id = np.zeros_like(im, dtype=int) - 1
    for key, pixels in nodes.items():
        for h, w in pixels:
            im_nodes_id[h, w] = key

    offsets = [(dh, dw) for dh in (-1, 0, 1) for dw in (-1, 0, 1) if (dh, dw) != (0, 0)]

    nodes_output = [[] for _ in range(len(nodes))]
    for key, pixels in nodes.items():
        start = [tuple(p) for p in pixels]
        que = deque(start)
        # A set keeps membership tests O(1) (a list made the trace quadratic).
        visited = set(start)
        while que:
            h, w = que.popleft()
            for dh, dw in offsets:
                nh, nw = h + dh, w + dw
                # Row 0 is a valid row, so the lower bound is inclusive.
                if not (0 <= nh < H and 0 <= nw < W):
                    continue
                if im[nh, nw] != 1 or (nh, nw) in visited:
                    continue
                visited.add((nh, nw))
                if im_nodes_id[nh, nw] != -1:
                    nodes_output[key].append(im_nodes_id[nh, nw])
                else:
                    que.append((nh, nw))
    return nodes_output


def get_linkingnodes(connect_edges: list):
    """Restrict the node graph to its largest connected component(s).

    Components are found by breadth-first search following the adjacency
    lists. Every node whose component size equals the maximum is kept, so if
    several components tie for largest, all of them are retained.

    Args:
        connect_edges: adjacency list; entry ``i`` lists node ids adjacent to
            node ``i`` (as produced by ``get_connectnodes``).

    Returns:
        ``(nodes_output, kept)`` where ``nodes_output[i]`` is the adjacency of
        node ``i`` filtered to kept nodes (empty for dropped nodes), and
        ``kept`` is the set of retained node ids. An empty graph yields
        ``([], set())``.
    """
    n = len(connect_edges)
    if n == 0:
        # No nodes (e.g. blank image or a closed loop): max() would raise.
        return [], set()

    visited = [False] * n
    component_size = [0] * n
    for start in range(n):
        if visited[start]:
            continue
        visited[start] = True
        que = deque([start])
        component = [start]
        while que:
            x = que.popleft()
            for y in connect_edges[x]:
                if visited[y]:
                    continue
                visited[y] = True
                que.append(y)
                component.append(y)
        for member in component:
            component_size[member] = len(component)

    largest = max(component_size)
    kept = {i for i, size in enumerate(component_size) if size == largest}

    nodes_output = [
        [y for y in connect_edges[x] if y in kept] if x in kept else []
        for x in range(n)
    ]
    return nodes_output, kept


def get_network(path):
    """Build the node graph of one skeleton image and pickle it to disk.

    Pipeline: threshold the grayscale image into a 0/1 skeleton, mark
    endpoint/junction pixels (``get_nodes``), group them into node clusters
    (``get_nodepoints``), trace which clusters are connected along the
    skeleton (``get_connectnodes``) and keep only the largest connected
    component(s) (``get_linkingnodes``).

    Two files are written, both with ``pickle`` (the ``.txt`` / ``.json``
    suffixes do not reflect the binary pickle content):
      * ``LINKOUTPUT_DIR/<id>.txt``  - adjacency list of the kept component(s);
      * ``POSOUTPUT_DIR/<id>.json`` - mapping ``cluster id -> [[h, w], ...]``.

    Args:
        path: path to the skeleton image file.

    Returns:
        ``None`` on success. If any step fails, the image identifier ``<id>``
        is returned instead so the caller can collect the failed images.
    """
    # <id> = file name without its first 9 and last 4 characters
    # (presumably a fixed prefix and a 4-char extension such as ".png" - confirm).
    new_path = os.path.basename(path)[9:][:-4]

    try:
        im = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
        if im is None:
            # cv2.imread does not raise on unreadable/non-image files.
            raise ValueError(f'cannot read image: {path}')
        im = (im > 100).astype(int)  # binarize: foreground -> 1
        im_nodes = get_nodes(im)
        nodes = get_nodepoints(im_nodes)
        connect_edges = get_connectnodes(im, nodes)
        connect_edges_linking, _ = get_linkingnodes(connect_edges)

        with open(os.path.join(LINKOUTPUT_DIR, new_path + '.txt'), 'wb') as f:
            pickle.dump(connect_edges_linking, f)

        with open(os.path.join(POSOUTPUT_DIR, new_path + '.json'), 'wb') as f:
            pickle.dump(nodes, f)
        return None

    except Exception:
        # Best effort over a large batch: report the failing image rather than
        # aborting the whole joblib run. Catching Exception (not a bare except)
        # still lets KeyboardInterrupt / SystemExit propagate.
        return new_path

In [2]:
BASE_DIR = '../'
IMAGE_DIR = os.path.join(BASE_DIR, 'data/interm/skeleton-non-treated-dataset')
OUTPUT_DIR = os.path.join(BASE_DIR, 'data/processed/network-non-treated-dataset')
LINKOUTPUT_DIR = os.path.join(OUTPUT_DIR, 'node-link')
POSOUTPUT_DIR = os.path.join(OUTPUT_DIR, 'node-position')

# get_network writes into these folders; create them so the writes do not
# fail on a fresh checkout.
os.makedirs(LINKOUTPUT_DIR, exist_ok=True)
os.makedirs(POSOUTPUT_DIR, exist_ok=True)

# Sorted so the processing order (and order of image_results) is reproducible.
image_paths = sorted(glob(os.path.join(IMAGE_DIR, '*')))

# tqdm advances as tasks are handed to joblib (dispatch), which can run
# slightly ahead of actual completion.
image_results = joblib.Parallel(n_jobs=-1)(
    joblib.delayed(get_network)(path) for path in tqdm(image_paths)
)

# Any non-None result from get_network is treated as a failure report.
failed_images = [r for r in image_results if r is not None]
print(f'{len(failed_images)} of {len(image_paths)} images failed')

  0%|          | 0/1 [00:00<?, ?it/s]