In [13]:
from osgeo import gdal
import numpy as np
import os
# os.environ['PROJ_LIB'] = r'C:\Users\Lenovo\.conda\envs\zph\Library\share\proj'
# os.environ['GDAL_DATA'] = r'C:\Users\Lenovo\.conda\envs\zph\Library\share'
# gdal.PushErrorHandler("CPLQuietErrorHandler")


class ImageProcess:
    """Small helper around a GDAL dataset: read raster metadata/pixels, write GeoTIFFs."""

    def __init__(self, filepath: str):
        """
        Open a raster file read-only with GDAL.
        :param filepath: path of the raster to open
        :raises RuntimeError: if GDAL cannot open the file
        """
        self.filepath = filepath
        self.dataset = gdal.Open(self.filepath, gdal.GA_ReadOnly)
        if self.dataset is None:
            # gdal.Open reports failure by returning None unless GDAL
            # exceptions are enabled; fail here instead of on first use.
            raise RuntimeError(f"GDAL could not open raster: {self.filepath!r}")
        # [bands, width, height, geotransform, projection]; filled by read_img_info()
        self.info = []
        # pixel array returned by the last read_img_data() call
        self.img_data = None
        # initialised to None; not written by any method of this class
        self.data_8bit = None

    def read_img_info(self):
        """
        Read band count, size, geotransform and projection from the dataset.
        :return: list ``[bands, width, height, geotransform, projection]``
        """
        # Band count, width, height
        img_bands = self.dataset.RasterCount
        img_width = self.dataset.RasterXSize
        img_height = self.dataset.RasterYSize
        # Affine geotransform and projection
        img_geotrans = self.dataset.GetGeoTransform()
        img_proj = self.dataset.GetProjection()
        self.info = [img_bands, img_width, img_height, img_geotrans, img_proj]
        return self.info

    def read_img_data(self):
        """
        Read the window ``(0, 0, width, height)`` recorded in ``self.info`` as an array.
        :return: the array returned by ``dataset.ReadAsArray``
        """
        if not self.info:
            # The window size comes from self.info; populate it on demand so
            # this method no longer requires a prior read_img_info() call.
            self.read_img_info()
        self.img_data = self.dataset.ReadAsArray(0, 0, self.info[1], self.info[2])
        return self.img_data

    @staticmethod
    def _gdal_type_name(dtype) -> str:
        """Return the name of the ``gdal.GDT_*`` constant used to store ``dtype``."""
        name = np.dtype(dtype).name
        if name in ('uint8', 'int8'):
            return 'GDT_Byte'
        if name == 'uint16':
            return 'GDT_UInt16'
        if name == 'int16':
            # Signed 16-bit data must not be stored as unsigned.
            return 'GDT_Int16'
        return 'GDT_Float32'

    # Write an image array to a GeoTIFF file
    @staticmethod
    def write_img(filename: str, img_data: np.ndarray, **kwargs):
        """
        Write a 2-D ``(height, width)`` or 3-D ``(bands, height, width)`` array to a GeoTIFF.
        :param filename: output file path
        :param img_data: pixel array to write
        :param kwargs: optional ``img_geotrans`` (geotransform) and ``img_proj`` (projection)
        :raises ValueError: if ``img_data`` is neither 2-D nor 3-D
        :raises RuntimeError: if the output file cannot be created
        """
        img_data = np.asarray(img_data)
        # Normalise to (bands, height, width) so single-band 3-D input is handled too
        if img_data.ndim == 2:
            bands = img_data[np.newaxis, ...]
        elif img_data.ndim == 3:
            bands = img_data
        else:
            raise ValueError(
                f"img_data must be 2-D or 3-D, got {img_data.ndim} dimension(s)")
        img_bands, img_height, img_width = bands.shape
        # Pick the GDAL storage type from the array dtype
        datatype = getattr(gdal, ImageProcess._gdal_type_name(bands.dtype))
        # Create the output file
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(filename, img_width, img_height, img_bands, datatype)
        if outdataset is None:
            raise RuntimeError(f"GDAL could not create output file: {filename!r}")
        # Write the affine geotransform
        if 'img_geotrans' in kwargs:
            outdataset.SetGeoTransform(kwargs['img_geotrans'])
        # Write the projection
        if 'img_proj' in kwargs:
            outdataset.SetProjection(kwargs['img_proj'])
        # Write every band (GDAL band numbers are 1-based)
        for band_no, band in enumerate(bands, start=1):
            outdataset.GetRasterBand(band_no).WriteArray(band)

        outdataset.FlushCache()
        del outdataset


def read_multi_bands(image_path):
    """
    Open a multi-band raster, then load its metadata and pixel array.
    :param image_path: path of the multi-band raster file
    :return: tuple ``(reader object, metadata list, pixel array)``
    """
    reader = ImageProcess(filepath=image_path)

    # Metadata is read before the pixels and echoed to stdout
    meta = reader.read_img_info()
    print(f"多波段影像元信息：{meta}")

    # Pixel array is read next and its shape is echoed to stdout
    pixels = reader.read_img_data()
    print(f"多波段矩阵大小：{pixels.shape}")

    return reader, meta, pixels



In [14]:
import math
import numpy as np
from alive_progress import alive_bar



def cal_single_band_slice(single_band_data, slice_size=512):
    """
    Compute the row/column bounds of every tile of a regular grid over one band.
    :param single_band_data: single-band array; only its ``shape`` is used
    :param slice_size: tile edge length in pixels
    :return: list of ``[row_min, row_max, col_min, col_max]``, ordered row by row;
             tiles on the last row/column are clipped to the array size
    """
    band_shape = single_band_data.shape
    total_rows, total_cols = band_shape[0], band_shape[1]
    # Round up so that a partial tile at the edge still gets its own entry
    row_num = math.ceil(total_rows / slice_size)
    col_num = math.ceil(total_cols / slice_size)
    print(f"行列数：{band_shape}，行分割数量：{row_num}，列分割数量：{col_num}")

    # Bounds along each axis, with the upper bound clamped to the array extent
    row_spans = [(r * slice_size, min((r + 1) * slice_size, total_rows))
                 for r in range(row_num)]
    col_spans = [(c * slice_size, min((c + 1) * slice_size, total_cols))
                 for c in range(col_num)]

    return [[row_lo, row_hi, col_lo, col_hi]
            for row_lo, row_hi in row_spans
            for col_lo, col_hi in col_spans]





def multi_bands_slice(multi_bands_data, index=(0, 512, 0, 512), slice_size=512, edge_fill=False):
    """
    Cut one tile out of a ``(bands, rows, cols)`` array.
    :param multi_bands_data: source multi-band array
    :param index: tile bounds ``[row_min, row_max, col_min, col_max]``
    :param slice_size: tile edge length in pixels
    :param edge_fill: if true, a tile smaller than ``slice_size`` is zero-padded
                      to ``(bands, slice_size, slice_size)``
    :return: a new array with the dtype of ``multi_bands_data``
    """
    row_min, row_max, col_min, col_max = index[0], index[1], index[2], index[3]
    tile = multi_bands_data[:, row_min:row_max, col_min:col_max]

    if edge_fill and tile.shape[1:] != (slice_size, slice_size):
        # Zero-initialised buffer: the padded area must hold a defined value,
        # which an uninitialised (np.empty) buffer does not guarantee.
        padded = np.zeros((multi_bands_data.shape[0], slice_size, slice_size),
                          dtype=multi_bands_data.dtype)
        # Offset inside the padded tile; 0 when the bounds are grid-aligned
        row_off = row_min % slice_size
        col_off = col_min % slice_size
        padded[:, row_off:row_off + tile.shape[1],
               col_off:col_off + tile.shape[2]] = tile
        return padded

    # astype returns a copy, so callers never receive a view of the source array
    return tile.astype(multi_bands_data.dtype)


def slice_conbine(slice_all, slice_index):
    """
    Stitch 2-D tiles back into one array.
    :param slice_all: sequence of tile arrays, in the same order as ``slice_index``
    :param slice_index: list of ``[row_min, row_max, col_min, col_max]`` bounds;
                        the last entry determines the size of the output
    :return: zero-initialised array (NumPy default dtype) with every tile
             copied into its bounds
    """
    last_bounds = slice_index[-1]
    mosaic = np.zeros(shape=(last_bounds[1], last_bounds[3]))
    for position, bounds in enumerate(slice_index):
        block = slice_all[position]
        top, bottom = bounds[0], bounds[1]
        left, right = bounds[2], bounds[3]
        mosaic[top:bottom, left:right] = block
    return mosaic


def coordtransf(Xpixel, Ypixel, GeoTransform):
    """
    Shift the origin of an affine geotransform to a given pixel offset.
    :param Xpixel: pixel offset multiplied by coefficients 1 and 4
                   (the visible caller passes the tile's first column)
    :param Ypixel: pixel offset multiplied by coefficients 2 and 5
                   (the visible caller passes the tile's first row)
    :param GeoTransform: six-coefficient geotransform of the source image
    :return: tuple of six coefficients: the new origin with the other four unchanged
    """
    origin_x, x_step_a, x_step_b, origin_y, y_step_a, y_step_b = (
        GeoTransform[k] for k in range(6)
    )
    # New origin = old origin moved by the pixel offsets
    shifted_x = origin_x + x_step_a * Xpixel + Ypixel * x_step_b
    shifted_y = origin_y + y_step_a * Xpixel + Ypixel * y_step_b
    return (shifted_x, x_step_a, x_step_b, shifted_y, y_step_a, y_step_b)


def multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=False):
    """
    Cut a multi-band raster into a regular grid of tiles and write each tile as a GeoTIFF.
    :param image_path: path of the source multi-band raster
    :param image_slice_dir: directory receiving the tiles; created if missing
    :param slice_size: tile edge length in pixels
    :param edge_fill: passed to ``multi_bands_slice`` to control padding of edge tiles
    :return: None; tiles are written as ``multi_grid_slice_<i>.tif``
    """
    image, image_info, image_data = read_multi_bands(image_path)

    # Make sure the output directory exists before any tile is written
    if image_slice_dir:
        os.makedirs(image_slice_dir, exist_ok=True)

    # Bounds of every tile, computed from the first band
    slice_index = cal_single_band_slice(image_data[0, :, :], slice_size=slice_size)
    # Cut and write each tile
    with alive_bar(len(slice_index), force_tty=True) as bar:
        for i, slice_element in enumerate(slice_index):
            slice_data = multi_bands_slice(image_data, index=slice_element, slice_size=slice_size,
                                           edge_fill=edge_fill)
            # Geotransform whose origin is this tile's first column / first row
            slice_geotrans = coordtransf(slice_element[2], slice_element[0], image_info[3])
            # os.path.join uses the platform separator instead of a literal backslash
            tile_path = os.path.join(image_slice_dir, 'multi_grid_slice_' + str(i) + '.tif')
            image.write_img(tile_path, slice_data,
                            img_geotrans=slice_geotrans, img_proj=image_info[4])
            bar()
        print('多波段格网裁剪完成')


In [15]:
# Cut the source image into 512-pixel tiles, padding the edge tiles.
image_path = "E:\\unet_complete_version\\raw_material\\data\\naip.tif"
# Backslash escaped explicitly: "\i" is an invalid escape sequence.
image_slice_dir = "dataset_result\\imagry"
slice_size = 512
edge_fill = True
# multi_bands_grid_slice has no return statement, so its result is not bound;
# binding it to `slice` would also shadow the builtin of that name.
multi_bands_grid_slice(image_path, image_slice_dir, slice_size, edge_fill=edge_fill)

多波段影像元信息：[3, 5079, 5079, (2710573.1336033433, 1.9692670323847825, 0.0, 264944.4413223281, 0.0, -1.9692670323847825), 'PROJCS["NAD83 / Pennsylvania South (ftUS)",GEOGCS["NAD83",DATUM["North American Datum 1983",SPHEROID["GRS 1980",6378137,298.257222101004]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]]],PROJECTION["Lambert_Conformal_Conic_2SP"],PARAMETER["latitude_of_origin",39.3333333333333],PARAMETER["central_meridian",-77.75],PARAMETER["standard_parallel_1",40.9666666666667],PARAMETER["standard_parallel_2",39.9333333333333],PARAMETER["false_easting",1968500],PARAMETER["false_northing",0],UNIT["US survey foot",0.304800609601219,AUTHORITY["EPSG","9003"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]']
多波段矩阵大小：(3, 5079, 5079)
行列数：(5079, 5079)，行分割数量：10，列分割数量：10
on 100: 多波段格网裁剪完成                                                               
|████████████████████████████████████████| 100/100 [100%] in 1.9s (52.33/s)     


In [None]:
import os
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from pathlib import Path

# Augmentation pipeline: horizontal flip (p=0.5), rotation with
# degrees=(-90, 90), then vertical flip (p=0.5).
# NOTE(review): not referenced by the code visible in this file.
transform_aug = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=(-90, 90)),
        transforms.RandomVerticalFlip(p=0.5),
    ]
)

class CustomDataset(Dataset):
    """Dataset pairing every image file in one directory with a mask file in another."""

    # (width, height) every image and mask is resized to
    IMAGE_SIZE = (512, 512)
    # mask value mapped to 1; every other value becomes 0
    TARGET_CLASS = 5

    def __init__(self, img_dir, msk_dir, pytorch=True, transforms=None):
        """
        Collect the image/mask path pairs.
        :param img_dir: directory containing the image files (``str`` or ``Path``)
        :param msk_dir: directory containing the mask files (``str`` or ``Path``)
        :param pytorch: stored on the instance; not read by this class
        :param transforms: stored on the instance; not applied by this class
        """
        super().__init__()

        img_dir = Path(img_dir)
        msk_dir = Path(msk_dir)
        # iterdir() yields entries in arbitrary order; sort so that sample
        # indices are reproducible between runs.
        self.files = [self.combine_files(f, img_dir, msk_dir)
                      for f in sorted(img_dir.iterdir()) if not f.is_dir()]
        self.pytorch = pytorch
        self.transforms = transforms

    def combine_files(self, r_file: Path, img_dir, msk_dir):
        """
        Build the path pair for one image file.
        :param r_file: path of the image file
        :param img_dir: unused
        :param msk_dir: directory containing the mask files
        :return: dict with keys ``'image'`` and ``'mask'``; the mask name is the
                 image name with ``'naip_tiles'`` replaced by ``'lu_masks'``
        """
        files = {'image': r_file,
                 'mask': msk_dir/r_file.name.replace('naip_tiles', 'lu_masks')}
        return files

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.files)

    def open_as_array(self, idx, augment=False):
        """
        Load one image, resize it and scale it to ``[0, 1]``.
        :param idx: sample index
        :param augment: placeholder; no augmentation is implemented
        :return: float array in channel-first layout
        :raises FileNotFoundError: if OpenCV cannot read the image file
        """
        image_path = str(self.files[idx]['image'])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file
            raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")

        image = cv2.resize(image, self.IMAGE_SIZE)
        # Channel-last -> channel-first, then scale 8-bit values to [0, 1]
        image = image.transpose((2, 0, 1)) / 255.0

        return image

    def open_mask(self, idx, add_dims=False, augment=False):
        """
        Load one mask, resize it and binarise it against ``TARGET_CLASS``.
        :param idx: sample index
        :param add_dims: if true, prepend an axis of length 1
        :param augment: placeholder; no augmentation is implemented
        :return: integer array of 0/1 values
        :raises FileNotFoundError: if OpenCV cannot read the mask file
        """
        mask_path = str(self.files[idx]['mask'])
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"cv2.imread could not read mask: {mask_path}")

        # Nearest-neighbour keeps label values intact; interpolating between
        # two class ids would produce values that belong to neither.
        mask = cv2.resize(mask, self.IMAGE_SIZE, interpolation=cv2.INTER_NEAREST)
        mask = np.where(mask == self.TARGET_CLASS, 1, 0)

        if add_dims:
            mask = np.expand_dims(mask, 0)

        return mask

    def __getitem__(self, idx):
        """
        Return one sample.
        :param idx: sample index
        :return: tuple ``(image tensor, float32; mask tensor, int64)``
        """
        image = self.open_as_array(idx, augment=True)
        mask = self.open_mask(idx, add_dims=False, augment=True)

        return torch.tensor(image, dtype=torch.float32), torch.tensor(mask, dtype=torch.int64)

    def __repr__(self):
        """Return a short description including the number of files."""
        # len(self), not self.__len(): a double-underscore name without trailing
        # underscores is name-mangled inside a class body and does not exist.
        s = 'Dataset class with {} files'.format(len(self))
        return s

# Path components passed separately so the platform's separator is used
base_path = Path('raw_material', 'overall data', 'dataset', 'trainning')
data = CustomDataset(base_path/'imgs', base_path/'labels')
print(len(data))

# Hold out a fixed number of samples for validation and train on the rest,
# so the split sizes always add up to len(data).
VALID_SIZE = 20
train_ds, valid_ds = torch.utils.data.random_split(data, (len(data) - VALID_SIZE, VALID_SIZE))
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=1, shuffle=True)

print('train_dl is:', len(train_dl))
print('valid_dl is:', len(valid_dl))
# Each batch is the collated (image, mask) pair; the loader never yields None
for batch in valid_dl:
    print(batch)