In [2]:
# GTE encoding
import numpy as np
import pickle
import math
import os.path as osp
from multiprocessing import Pool    # 多进程
from concurrent.futures import ThreadPoolExecutor   # 多线程

# The .p graph files store vertices as (row, col) coordinates.

# Side length, in pixels, of the square tensor allocated by encode_GTE.
IMG_SIZE = 2048
# Half-width of the square patch stamped around each vertex (patch side = 2*OFFSET + 1).
OFFSET = 1
# Number of angular sectors (edge slots) per vertex; the tensor has 3*MAX_DEGREE + 1 channels.
MAX_DEGREE = 6
# Divisor applied to the (dx, dy) neighbor offsets before they are stored in the tensor.
VECTOR_NORM = 25.0

# Not used in the visible code; presumably worker counts for the Pool /
# ThreadPoolExecutor imported above -- TODO confirm against the rest of the notebook.
NUM_PROCESS = 10
NUM_THREAD = 10
 
# Encode the graph as a 19-channel tensor (3*MAX_DEGREE + 1 with MAX_DEGREE = 6).
# Path templates; `{}` is filled with the tile id via str.format.
graph_pattern = './cityscale/20cities/region_{}_refine_gt_graph.p'
rgb_pattern   = './cityscale/20cities/region_{}_sat.png'

# Directory that encode_GTE writes its region_<tile_id>_GTE.npz files into.
save_dir = './cityscale/GTE'


def encode_GTE(tile_id):
    """Encode one tile's road graph as a dense tensor, save it as .npz and return it.

    The graph is read from ``graph_pattern.format(tile_id)``; it is iterated as a
    mapping from a vertex ``(row, col)`` to an iterable of neighbor ``(row, col)``
    pairs.

    Channel layout of the result (shape ``(IMG_SIZE, IMG_SIZE, 3*MAX_DEGREE + 1)``):
        * channel 0: 1 inside the (2*OFFSET+1)-wide patch around every kept vertex;
        * for each angular sector ``j`` in ``0..MAX_DEGREE-1``:
            - channel ``1+3*j``:   1 if the vertex has a neighbor in that sector,
            - channel ``1+3*j+1``: ``dx / VECTOR_NORM`` with ``dx = nei_col - col``,
            - channel ``1+3*j+2``: ``dy / VECTOR_NORM`` with ``dy = nei_row - row``.

    Vertices and neighbors closer than 16 px to the tile border are skipped. If two
    neighbors fall into the same sector, the one iterated last overwrites the other.

    Args:
        tile_id: value substituted into ``graph_pattern`` and into the output name.

    Returns:
        The encoded numpy array. It is also written to
        ``<save_dir>/region_<tile_id>_GTE.npz`` under the key ``GTE``.
    """
    # Local import: the file only brings in os.path (as osp) at module level.
    import os

    border = 16
    # Angular width of one sector; equals pi/3 for MAX_DEGREE = 6.
    sector_width = 2.0 * math.pi / MAX_DEGREE

    graph_path = graph_pattern.format(tile_id)
    # NOTE(security): pickle.load can execute arbitrary code -- only load trusted graph files.
    with open(graph_path, 'rb') as graph_file:
        graph_adj = pickle.load(graph_file)

    def outside(row, col):
        """True when (row, col) lies in the skipped band along the tile border."""
        return (row < border) or (row > IMG_SIZE - border) or (col < border) or (col > IMG_SIZE - border)

    GTE = np.zeros((IMG_SIZE, IMG_SIZE, 3*MAX_DEGREE+1))
    for v, neis in graph_adj.items():
        r, c = v
        if outside(r, c):
            continue

        # The same square patch around the vertex receives every value written below.
        rows = slice(r-OFFSET, r+OFFSET+1)
        cols = slice(c-OFFSET, c+OFFSET+1)
        GTE[rows, cols, 0] = 1

        for nei in neis:
            nei_r, nei_c = nei
            if outside(nei_r, nei_c):
                continue
            dx, dy = nei_c-c, nei_r-r
            # Shift atan2's range from [-pi, pi] to [0, 2*pi]; the modulo folds the
            # upper end point (angle == 2*pi) back into sector 0.
            d = math.atan2(dy, dx) + math.pi
            j = int(d/sector_width) % MAX_DEGREE  # sector index 0..MAX_DEGREE-1

            GTE[rows, cols, 1+3*j] = 1
            GTE[rows, cols, 1+3*j+1] = dx / VECTOR_NORM
            GTE[rows, cols, 1+3*j+2] = dy / VECTOR_NORM

    # np.savez does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    np.savez(osp.join(save_dir, f'region_{tile_id}_GTE.npz'), GTE=GTE)

    return GTE

In [18]:
# Decode the GTE tensor back into a drawing, to verify the encoding
import cv2 as cv
from skimage import measure
from scipy.ndimage import rotate
import scipy

# Re-declared with the same values as in the encoding cell; the decoder below reads
# MAX_DEGREE and VECTOR_NORM, so keep both copies in sync.
IMG_SIZE = 2048
OFFSET = 1
MAX_DEGREE = 6
VECTOR_NORM = 25.0



def vis_GT_GTE(GTE, keypoint_thr=0.5, edge_thr=0.5, aug=False, rot_angle=90, rot_index=0, tile_id=7):
    """Decode the top-left 512x512 crop of a GTE tensor and draw it on the satellite image.

    Keypoints are drawn as filled circles with color (0, 0, 255) and decoded edges as
    lines with color (0, 255, 0). The picture is written into ``verify_dir``, a
    module-level name that must be defined before this function is called.

    Args:
        GTE: tensor laid out as channel 0 = keypoint map and, for each sector ``j``,
            channels ``1+3*j``, ``1+3*j+1``, ``1+3*j+2`` = edgeness, dx and dy
            (offsets divided by VECTOR_NORM).
        keypoint_thr: channel-0 values above this are treated as keypoint pixels.
        edge_thr: an edge slot is drawn when its edgeness exceeds this value.
        aug: when True, rotate both the crop and the image by
            ``rot_angle * rot_index`` degrees and rotate the stored (dx, dy)
            vectors accordingly before decoding.
        rot_angle: rotation step in degrees.
        rot_index: number of rotation steps.
        tile_id: tile whose satellite image is used as background and whose id
            appears in the output file name (default 7).

    Returns:
        None.

    Raises:
        FileNotFoundError: if the satellite image cannot be read.
        OSError: if the visualization cannot be written.
    """
    # Local import: the file only brings in os.path (as osp) at module level.
    import os

    crop = 512
    # To draw without the satellite background, start from np.zeros((512, 512, 3)) instead.
    rgb_path = rgb_pattern.format(tile_id)
    rgb_image = cv.imread(rgb_path)
    if rgb_image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read satellite image: {rgb_path}')
    vis_output = rgb_image[:crop, :crop, :]
    sub_GTE = GTE[:crop, :crop, :]

    if aug:
        rot_radians = math.radians(rot_angle)
        # Note on this rotation matrix: the encoding uses image coordinates, whose y
        # axis points opposite to the Cartesian one, so a counter-clockwise rotation
        # in Cartesian coordinates corresponds to a clockwise rotation in image
        # coordinates.
        rot_mat = np.array([
                [math.cos(rot_radians), +math.sin(rot_radians)],
                [-math.sin(rot_radians), math.cos(rot_radians)]
        ], dtype=np.float32)
        # Total rotation for rot_index steps; identical for every pixel, so build it once.
        total_rot = np.linalg.matrix_power(rot_mat, rot_index)

        # rotate() returns new arrays, so the caller's GTE is not modified below.
        sub_GTE = rotate(sub_GTE, rot_angle*rot_index, axes=(0, 1), reshape=False)
        vis_output = rotate(vis_output, rot_angle*rot_index, axes=(0, 1), reshape=False)

        # Only transform the keypoint pixels; everything else is zero, so there is
        # nothing to compute there. The rotated vectors are written back into their
        # original sector slots (slots are not re-assigned by direction).
        rotated_keypoint_thr = 0.98
        for r, c in np.column_stack(np.where(sub_GTE[:, :, 0] > rotated_keypoint_thr)):
            # 2 x MAX_DEGREE matrix: row 0 holds every dx, row 1 every dy.
            deltas = np.array([
                [sub_GTE[r, c, 1+3*j+1] for j in range(MAX_DEGREE)],
                [sub_GTE[r, c, 1+3*j+2] for j in range(MAX_DEGREE)],
            ])
            rotated = total_rot @ deltas
            for j in range(MAX_DEGREE):    # stick back
                sub_GTE[r, c, 1+3*j+1], sub_GTE[r, c, 1+3*j+2] = rotated[:, j]

    # Find the center point of each keypoint blob.
    keypoint_map = (sub_GTE[:, :, 0] > keypoint_thr).astype(np.uint8)
    labels = measure.label(keypoint_map, connectivity=2)
    min_area = 4
    for region in measure.regionprops(labels):
        if region.area <= min_area:
            continue
        center_row, center_col = region.centroid   # (row, col) -> x = col, y = row
        center_x, center_y = int(center_col), int(center_row)
        cv.circle(vis_output, (center_x, center_y), radius=2, color=(0, 0, 255), thickness=-1)

        for j in range(MAX_DEGREE):
            edgeness = sub_GTE[center_y, center_x, 1+3*j]
            if edgeness > edge_thr:
                dx, dy = sub_GTE[center_y, center_x, 1+3*j+1], sub_GTE[center_y, center_x, 1+3*j+2]
                # Undo the VECTOR_NORM scaling to recover the pixel offset of the neighbor.
                dst_x, dst_y = int(center_x + VECTOR_NORM*dx), int(center_y + VECTOR_NORM*dy)
                cv.line(vis_output, (center_x, center_y), (dst_x, dst_y), color=(0, 255, 0), thickness=1)

    # cv.imwrite returns False instead of raising when it cannot write,
    # e.g. because the target directory does not exist.
    os.makedirs(verify_dir, exist_ok=True)
    out_path = osp.join(verify_dir, f'region_{tile_id}_vis_{rot_angle*rot_index}_{aug}aug_with_rgb.png')
    if not cv.imwrite(out_path, vis_output):
        raise OSError(f'failed to write visualization: {out_path}')
    
    
    
# Output folder read (as a global) by vis_GT_GTE; it must be bound before the call below.
verify_dir = './cityscale/verify'

# Encode tile 7, then render it rotated by 2 x 90 degrees to check the encoding.
GTE = encode_GTE(7)
vis_GT_GTE(GTE, aug=True, rot_angle=90, rot_index=2)
# Non-augmented variant: vis_GT_GTE(GTE, aug=False)
    
    