In [1]:
from osgeo import gdal
import numpy as np
import os
# os.environ['PROJ_LIB'] = r'C:\Users\Lenovo\.conda\envs\zph\Library\share\proj'
# os.environ['GDAL_DATA'] = r'C:\Users\Lenovo\.conda\envs\zph\Library\share'
# gdal.PushErrorHandler("CPLQuietErrorHandler")


class ImageProcess:
    def __init__(self, filepath: str):
        self.filepath = filepath
        self.dataset = gdal.Open(self.filepath, gdal.GA_ReadOnly)
        self.info = []
        self.img_data = None
        self.data_8bit = None

    def read_img_info(self):
        # 获取波段、宽、高
        img_bands = self.dataset.RasterCount
        img_width = self.dataset.RasterXSize
        img_height = self.dataset.RasterYSize
        # 获取仿射矩阵、投影
        img_geotrans = self.dataset.GetGeoTransform()
        img_proj = self.dataset.GetProjection()
        self.info = [img_bands, img_width, img_height, img_geotrans, img_proj]
        return self.info

    def read_img_data(self):
        self.img_data = self.dataset.ReadAsArray(0, 0, self.info[1], self.info[2])
        return self.img_data

    # 影像写入文件
    @staticmethod
    def write_img(filename: str, img_data: np.array, **kwargs):
        # 判断栅格数据的数据类型
        if 'int8' in img_data.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in img_data.dtype.name:
            datatype = gdal.GDT_UInt16
        else:
            datatype = gdal.GDT_Float32
        # 判读数组维数
        if len(img_data.shape) >= 3:
            img_bands, img_height, img_width = img_data.shape
        else:
            img_bands, (img_height, img_width) = 1, img_data.shape
        # 创建文件
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(filename, img_width, img_height, img_bands, datatype)
        # 写入仿射变换参数
        if 'img_geotrans' in kwargs:
            outdataset.SetGeoTransform(kwargs['img_geotrans'])
        # 写入投影
        if 'img_proj' in kwargs:
            outdataset.SetProjection(kwargs['img_proj'])
        # 写入文件
        if img_bands == 1:
            outdataset.GetRasterBand(1).WriteArray(img_data)  # 写入数组数据
        else:
            for i in range(img_bands):
                outdataset.GetRasterBand(i + 1).WriteArray(img_data[i])

        del outdataset


def read_single_band(band_path):
    """
    读取单波段文件
    :param band_path: 单波段文件路径
    :return: 影像对象，影像元信息，影像矩阵
    """
    # 影像读取
    band = ImageProcess(filepath=band_path)
    # 读取影像元信息
    band_info = band.read_img_info()
    print(f"单波段影像元信息：{band_info}")
    # 读取影像矩阵
    band_data = band.read_img_data()
    print(f"单波段矩阵大小：{band_data.shape}")
    return band, band_info, band_data



In [2]:
import math
import numpy as np
from alive_progress import alive_bar
# from module.image import *


def cal_single_band_slice(single_band_data, slice_size=512):
    """
    计算单波段的格网裁剪四角点
    :param single_band_data:单波段原始数据
    :param slice_size: 裁剪大小
    :return: 嵌套列表，每一个块的四角行列号
    """
    single_band_size = single_band_data.shape
    row_num = math.ceil(single_band_size[0] / slice_size)  # 向上取整
    col_num = math.ceil(single_band_size[1] / slice_size)  # 向上取整
    print(f"行列数：{single_band_size}，行分割数量：{row_num}，列分割数量：{col_num}")
    slice_index = []
    for i in range(row_num):
        for j in range(col_num):
            row_min = i * slice_size
            row_max = (i + 1) * slice_size
            if (i + 1) * slice_size > single_band_size[0]:
                row_max = single_band_size[0]
            col_min = j * slice_size
            col_max = (j + 1) * slice_size
            if (j + 1) * slice_size > single_band_size[1]:
                col_max = single_band_size[1]
            slice_index.append([row_min, row_max, col_min, col_max])
    return slice_index





def single_band_slice(single_band_data, index=[0, 1000, 0, 1000], slice_size=1000, edge_fill=False):
    """
    依据四角坐标，切分单波段影像
    :param single_band_data:原始矩阵数据
    :param index: 四角坐标
    :param slice_size: 分块大小
    :param edge_fill: 是否进行边缘填充
    :return: 切分好的单波段矩阵
    """
    if edge_fill:
        if (index[1] - index[0] != slice_size) or (index[3] - index[2] != slice_size):
            result = np.empty(shape=(slice_size, slice_size))
            new_row_min = index[0] % slice_size
            new_row_max = new_row_min + (index[1] - index[0])
            new_col_min = index[2] % slice_size
            new_col_max = new_col_min + (index[3] - index[2])
            result[new_row_min:new_row_max, new_col_min:new_col_max] = single_band_data[index[0]:index[1],
                                                                       index[2]:index[3]]
        else:
            result = single_band_data[index[0]:index[1], index[2]:index[3]]
    else:
        result = single_band_data[index[0]:index[1], index[2]:index[3]]
    return result.astype(single_band_data.dtype)



def slice_conbine(slice_all, slice_index):
    """
    将分块矩阵进行合并
    :param slice_all: 所有的分块矩阵列表
    :param slice_index: 分块的四角坐标
    :return: 合并的矩阵
    """
    combine_data = np.zeros(shape=(slice_index[-1][1], slice_index[-1][3]))
    # print(combine_data.shape)
    for i, slice_element in enumerate(slice_index):
        combine_data[slice_element[0]:slice_element[1], slice_element[2]:slice_element[3]] = slice_all[i]
    return combine_data


def coordtransf(Xpixel, Ypixel, GeoTransform):
    """
    像素坐标和地理坐标仿射变换
    :param Xpixel: 左上角行号
    :param Ypixel: 左上角列号
    :param GeoTransform: 原始仿射矩阵
    :return: 新的仿射矩阵
    """
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + Ypixel * GeoTransform[2]
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + Ypixel * GeoTransform[5]
    slice_geotrans = (XGeo, GeoTransform[1], GeoTransform[2], YGeo, GeoTransform[4], GeoTransform[5])
    return slice_geotrans


def single_band_grid_slice(band_path, band_slice_dir, slice_size, edge_fill=False):
    """
    单波段格网裁剪
    :param band_path: 原始单波段影像
    :param band_slice_dir: 裁剪保存文件夹
    :param slice_size: 裁剪大小
    :return:
    """
    band, band_info, band_data = read_single_band(band_path)
    # 计算分块的四角行列号
    slice_index = cal_single_band_slice(band_data, slice_size=slice_size)
    # 执行裁剪
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            slice_data = single_band_slice(band_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)  # 裁剪单波段影像
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], band_info[3])  # 转换仿射坐标
            band.write_img(band_slice_dir + r'\single_grid_slice_' + str(i) + '.tif', slice_data,
                           img_geotrans=slice_geotrans, img_proj=band_info[4])  # 写入文件
            bar()
        print('单波段格网裁剪完成')

In [3]:
image_path = "raw_material\data\landuse.tif"
image_slice_dir = "dataset_result\imagry_classification"
slice_size = 512
edge_fill = True
slice = single_band_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=edge_fill)


# configure the training data folder
folder1 = os.path.join(root, 'dataset')
if not os.path.exists(folder1): os.mkdir(folder1)
foldertraining = os.path.join(folder1, 'trainning')
if not os.path.exists(foldertraining): os.mkdir(foldertraining)
train_label_folder = os.path.join(foldertraining, 'labels')
if not os.path.exists(train_label_folder): os.mkdir(train_label_folder)

AttributeError: 'NoneType' object has no attribute 'RasterCount'

In [10]:

root = "raw_material"
folder1 = os.path.join(root, 'dataset')
print(folder1)
# if not os.path.exists(folder1): os.mkdir(folder1)



raw_material\dataset


In [11]:
conda activate tensorflow



Note: you may need to restart the kernel to use updated packages.
