In [None]:
import os
import glob
import ogr
import gdal
import sys
from collections import defaultdict
import numpy as np
import cv2
import matplotlib.pyplot as plt
from math import sqrt
# Global GDAL/OGR setup, done once before any raster or shapefile is opened.
ogr.RegisterAll()  # register every OGR vector driver (used by ogr.Open on the .shp files)
for _option, _value in (
    ('SHAPE_ENCODING', "UTF8"),  # decode shapefile attribute text as UTF-8 (category names are Chinese)
    ("GDAL_FILENAME_IS_UTF8", "YES"),  # pass file paths to GDAL as UTF-8 (directory names are Chinese)
):
    gdal.SetConfigOption(_option, _value)
del _option, _value  # keep the notebook namespace unchanged

In [None]:
def geo2cr(geoTransform, px, py):
    """Map georeferenced coordinates to raster (column, row) coordinates.

    ``geoTransform`` is a GDAL-style 6-element affine transform
    ``(x0, a, b, y0, d, e)`` with ``X = x0 + a*col + b*row`` and
    ``Y = y0 + d*col + e*row``; this inverts it for the point (px, py).

    Args:
        geoTransform: 6-element sequence as returned by ``Dataset.GetGeoTransform()``.
        px: georeferenced X coordinate.
        py: georeferenced Y coordinate.

    Returns:
        (col, row) as floats, each offset by +0.5 so that a later truncation to
        int rounds to the nearest pixel (for non-negative values).

    Raises:
        ZeroDivisionError: if the transform's 2x2 linear part is singular.
    """
    x_origin, col_dx, row_dx, y_origin, col_dy, row_dy = (geoTransform[i] for i in range(6))
    det = col_dx * row_dy - row_dx * col_dy
    dx = px - x_origin
    dy = py - y_origin
    col = (row_dy * dx - row_dx * dy) / det + 0.5
    row = (col_dx * dy - col_dy * dx) / det + 0.5
    return col, row

def walk(input_dir):
    """Find the single GeoTIFF and the single shapefile under ``input_dir``.

    The directory tree is searched recursively.

    Args:
        input_dir: directory to search.

    Returns:
        ``(tif_path, shp_path)``.

    Raises:
        AssertionError: if the tree does not contain exactly one ``.tif`` file
            and exactly one ``.shp`` file.
    """
    tif_files = []
    shp_files = []
    for root, _, files in os.walk(input_dir):
        for file in files:
            # Match the full extension including the dot (as for '.shp'); a bare
            # 'tif' suffix would also match extension-less names such as 'motif'.
            if file.endswith('.tif'):
                tif_files.append(os.path.join(root, file))
            elif file.endswith('.shp'):
                shp_files.append(os.path.join(root, file))
    assert len(tif_files) == 1 and len(shp_files) == 1, '{} should have 1 tif and 1 shp , but get {} tifs and {} shps'.format(input_dir, len(tif_files), len(shp_files))
    return tif_files[0], shp_files[0]

def check_tifs(files, items=('pan', 'ms', 'fusion')):
    """Keep only the scenes whose rasters all exist and can be opened by GDAL.

    Args:
        files: mapping scene name -> dict of role -> file path.
        items: roles in each scene's dict that must point to a readable raster.

    Returns:
        A new dict with only the usable scenes; values are the original dicts.
        For every rejected scene a message naming the first failing role is printed.
    """
    files_new = {}
    for k, v in files.items():
        for item in items:
            tif_file = v[item]
            if not os.path.exists(tif_file):
                print('%s: %s doesn\'t exist' % (k, item))
                break
            dataset = gdal.Open(tif_file)
            if dataset is None:
                print('%s: %s open fail' % (k, item))
                break
            # Drop the reference so GDAL can close the file handle right away.
            dataset = None
        else:
            # Reached only when no `break` fired, i.e. every raster is usable.
            # (Previously the entry was kept unconditionally, so nothing was filtered.)
            files_new[k] = v
    return files_new

In [None]:
# Smallest allowed side (in pixels) of a polygon's axis-aligned box; smaller polygons are rejected later.
min_box_size = 4

# Dataset layout, as read below:
#   <input_dir>/<dir1>/<name>.tif          fusion image, one per scene
#   <input_dir>/<dir2>/<name>/...          exactly one pan .tif and one .shp (found by walk)
#   <input_dir>/<dir2>/<dir3>/<name>.tif   multispectral image
# NOTE(review): absolute, user-specific path — consider making it configurable.
input_dir = '/home/zhoufeipeng/tmp/jinan'
dir1 = '融合图像'  # "fused images"
dir2 = '高分图像'  # "high-resolution images" (holds the pan rasters and shapefiles)
dir3 = '多光谱'  # "multispectral"
fusion_dir = os.path.join(input_dir, dir1)
pan_dir = os.path.join(input_dir, dir2)
ms_dir = os.path.join(pan_dir, dir3)

# Scene names are the base names of the fusion images.
fusion_files = sorted(glob.glob(os.path.join(fusion_dir, '*.tif')))
tif_filenames = [os.path.splitext(os.path.basename(fusion_file))[0] for fusion_file in fusion_files]

# One record per scene: {'fusion', 'pan', 'shp', 'ms'} -> file path.
files = {}
for tif_filename, fusion_file in zip(tif_filenames, fusion_files):
    tif_file, shp_file = walk(os.path.join(pan_dir, tif_filename))
    ms_file = os.path.join(ms_dir, '%s.tif' % tif_filename)
    # Include the missing path in the message (it was previously left unformatted).
    assert os.path.exists(ms_file), '{} doesn\'t exist'.format(ms_file)
    files[tif_filename] = {'fusion': fusion_file, 'pan': tif_file, 'shp': shp_file, 'ms': ms_file}

# Drop scenes whose pan/ms/fusion rasters are missing or unreadable.
files = check_tifs(files)

# Show one record as a sanity check.
for k, v in files.items():
    print(k)
    print(v)
    break

In [None]:
# Map the shapefile's '类型' ("type") attribute to short category codes:
# housing, factory building, greenhouse, other 1, other 2. A missing value (None) counts as "other 2".
cate2 = {'住房': 'Hb', '厂房': 'Fb', '大棚': 'Gh', '其他1': 'OI', '其他2': 'OII', None: 'OII'}
file_polys = {}  # scene name -> {'points': [...], 'box': [...], 'rbox': [...]}
files_fail = []  # scenes whose pan raster or shapefile could not be used
polys_fail = defaultdict(int)  # scene name -> number of rejected polygons
num_per_cate = defaultdict(int)  # category code -> number of accepted polygons


def _ring_to_pixels(ring, geoTransform, width, height):
    """Convert a closed OGR ring to an (N, 2) float array of pixel (col, row) coordinates.

    The last vertex is skipped because a closed ring repeats its first vertex.
    Coordinates are clipped to [0, width - 1] x [0, height - 1].
    """
    points = []
    for i in range(ring.GetPointCount() - 1):
        x, y = geo2cr(geoTransform, ring.GetX(i), ring.GetY(i))
        points.append((max(min(x, width - 1), 0), max(min(y, height - 1), 0)))
    return np.array(points).reshape(-1, 2)


def _boxes(points):
    """Return the axis-aligned box [x_min, y_min, w, h] and rotated box [x_c, y_c, w, h, angle]."""
    x_min, y_min = points.min(axis=0)
    x_max, y_max = points.max(axis=0)
    # Inclusive pixel extent: a span covering a single pixel has size 1.
    box = [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1]
    # cv2.minAreaRect accepts int32/float32 points; np.int no longer exists in NumPy >= 1.24.
    # Truncation to int relies on geo2cr's +0.5 offset to round to the nearest pixel.
    (x_c, y_c), (w, h), a = cv2.minAreaRect(points.astype(np.int32))
    rbox = [x_c, y_c, w, h, a]
    return box, rbox


for k, v in files.items():
    print(k)
    file_polys[k] = {'points': [], 'box': [], 'rbox': []}

    pan_tif_file = v['pan']
    dataset = gdal.Open(pan_tif_file)
    if dataset is None:
        print('fail to open {}'.format(pan_tif_file))
        files_fail.append(k)
        continue

    width = dataset.RasterXSize
    height = dataset.RasterYSize
    geoTransform = dataset.GetGeoTransform()
    if geoTransform is None:
        print('{}: geoTransfrom is None'.format(pan_tif_file))
        files_fail.append(k)
        continue

    shp_file = v['shp']
    dataSource = ogr.Open(shp_file)
    if dataSource is None:
        print('fail to open {}'.format(shp_file))
        files_fail.append(k)
        continue
    daLayer = dataSource.GetLayer(0)
    featureCount = daLayer.GetFeatureCount()

    # GetFeatureCount() may move the read cursor, so rewind before iterating.
    daLayer.ResetReading()
    for _ in range(featureCount):
        feature = daLayer.GetNextFeature()
        if feature is None:
            # Defensive: stop if the layer yields fewer features than reported.
            break
        fieldName = feature.GetField('类型')
        # '其它' is a variant spelling of '其他' ("other"); normalise so both map to the same code.
        if isinstance(fieldName, str):
            fieldName = fieldName.replace('其它', '其他')
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print('{}: geometry is None'.format(shp_file))
            polys_fail[k] += 1
            continue
        geometryType = geometry.GetGeometryType()
        if geometryType != ogr.wkbPolygon:
            print('{}: the FID is {}, the type of geometry is {}'.format(shp_file, feature.GetFID(), geometryType))
            polys_fail[k] += 1
            continue
        geometryCount = geometry.GetGeometryCount()
        if geometryCount != 1:
            # Only a single outer ring is supported: reject polygons with holes, and
            # polygons without any ring (there would be no ring 0 to read).
            print('{}: poly has {} rings'.format(shp_file, geometryCount))
            polys_fail[k] += 1
            continue

        ring = geometry.GetGeometryRef(0)
        if ring.GetPointCount() < 4:
            # A closed ring with fewer than 4 points has fewer than 3 distinct vertices.
            print('{} : the num of ring is less than 3'.format(shp_file))
            polys_fail[k] += 1
            continue

        points = _ring_to_pixels(ring, geoTransform, width, height)
        box, rbox = _boxes(points)

        if min(box[2], box[3]) < min_box_size:
            print('{}: the poly is so small, FID:{}, width:{}, height:{}'.format(shp_file, feature.GetFID(), box[2], box[3]))
            polys_fail[k] += 1
            continue
        category = cate2.get(fieldName)
        if category is None:
            # Previously an unexpected category raised KeyError and aborted the whole cell.
            print('{}: unknown category {!r}, FID:{}'.format(shp_file, fieldName, feature.GetFID()))
            polys_fail[k] += 1
            continue
        num_per_cate[category] += 1
        file_polys[k]['box'].append(box)
        file_polys[k]['points'].append(points)
        file_polys[k]['rbox'].append(rbox)

In [None]:
# Directory that receives the summary figures; created on first run.
out_img_dir = './jinan_summery_figure'
if not os.path.isdir(out_img_dir):
    os.makedirs(out_img_dir)

In [None]:
# Accepted polygons per category code, plus overall success/failure counts.
print(num_per_cate)
labels = list(num_per_cate.keys())
plt.bar(labels, list(num_per_cate.values()))
plt.title('Polygons per category')
plt.xlabel('category')
plt.ylabel('count')
# Saved into out_img_dir like the other summary figures.
plt.savefig(os.path.join(out_img_dir, 'cate.png'))
plt.show()
print('error with the num of files:', len(files_fail))
print('error with the num of polys:', sum(polys_fail.values()))
num_building = sum(len(polys['points']) for polys in file_polys.values())
print('the num of polys:', num_building)

In [None]:
# Orientation of every rotated box; 90 degrees is added when the box is taller
# than wide, presumably so the angle refers to the long side.
# NOTE(review): cv2.minAreaRect's angle range changed around OpenCV 4.5.1 —
# confirm this normalisation matches the installed version.
angles = []
for polys in file_polys.values():
    for rbox in polys['rbox']:
        w, h, angle = rbox[-3:]
        angles.append(angle + 90 if w < h else angle)
plt.hist(angles, rwidth=0.9)
plt.savefig(os.path.join(out_img_dir, 'angle.png'))
plt.show()

In [None]:
import bisect

# Bucket edges for the box size sqrt(w * h): [0, 32), [32, 96), [96, 1000), [1000, inf).
sizes_range = [0, 32, 96, 1000]
sizes_num = [0] * len(sizes_range)  # one counter per bucket
sizes = []  # sqrt(w * h) of every accepted box, in scene order
for polys in file_polys.values():
    for box in polys['box']:
        size = sqrt(box[2] * box[3])
        sizes.append(size)
        # bisect_right gives the 1-based bucket index for non-negative sizes.
        sizes_num[bisect.bisect_right(sizes_range, size) - 1] += 1
print(sizes_num)
plt.hist(sizes, bins=8, rwidth=0.8)
plt.savefig(os.path.join(out_img_dir, 'instance_size.png'))
plt.show()

In [None]:
def _aspect_ratios(whs, max_ratio=15):
    """Return the long/short side ratio of each (w, h) pair, dropping ratios above max_ratio.

    A zero-length short side (cv2.minAreaRect can return one for collinear
    points) is treated as an infinite ratio and dropped instead of raising
    ZeroDivisionError.
    """
    ratios = []
    for w, h in whs:
        big, small = max(w, h), min(w, h)
        if small == 0 or big / small > max_ratio:
            continue
        ratios.append(big / small)
    return ratios


# Aspect ratios of axis-aligned and rotated boxes; ratios above 15 are treated as outliers.
bbox_ratios = []
rbox_ratios = []
for polys in file_polys.values():
    bbox_ratios.extend(_aspect_ratios((box[2], box[3]) for box in polys['box']))
    rbox_ratios.extend(_aspect_ratios((rbox[2], rbox[3]) for rbox in polys['rbox']))

plt.hist(bbox_ratios, bins=15, rwidth=0.8)
plt.title('Axis-aligned box aspect ratio')
plt.xlabel('long side / short side')
plt.savefig(os.path.join(out_img_dir, 'bbox_ratio.png'))
plt.show()
plt.hist(rbox_ratios, bins=15, rwidth=0.8)
plt.title('Rotated box aspect ratio')
plt.xlabel('long side / short side')
# Saved into out_img_dir like the other summary figures.
plt.savefig(os.path.join(out_img_dir, 'rbox_ratio.png'))
plt.show()