In [2]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
from tqdm import tqdm
import shutil
import json
from datetime import datetime
from pathlib import Path
import random
from osgeo import gdal
import glob
import rasterio

In [None]:
# Tile (region) identifiers present in the dataset.
country_list = [
    "BR_A001",
    "NZ_A001",
    "NZ_A002",
    "US_A001",
    "ZA_A001",
]
# Sensor identifiers, presumably WorldView-1/2/3 — TODO confirm.
# NOTE(review): neither list is referenced by the visible cells.
wv_list = [f"WV0{n}" for n in (1, 2, 3)]

In [13]:
# train_dir = '/scratch2/ziyliu/pro_data/train'
# tiles = [tile.name for tile in Path(train_dir).iterdir() if tile.is_dir()]
# print(tiles)

# Prefix every image/mask file name with its tile name
# for tile in tiles:
#     folder_path = os.path.join(train_dir, tile)
#     for image in os.listdir(os.path.join(folder_path, 'image')):
#         image_path = os.path.join(folder_path, 'image', image)
#         new_image_path = os.path.join(folder_path, 'image', tile+'_'+image)
#         os.rename(image_path, new_image_path)
#     for mask in os.listdir(os.path.join(folder_path, 'mask')):
#         mask_path = os.path.join(folder_path, 'mask', mask)
#         new_mask_path = os.path.join(folder_path, 'mask', tile+'_'+mask)
#         os.rename(mask_path, new_mask_path)

['NZ_A001', 'NZ_A002', 'US_A001', 'ZA_A001']


In [4]:
# Randomly pick (image, mask) pairs from the training tiles, copy each image into
# `reprosourcedir`, and warp a mask of the same tile onto that image's grid (same bounds,
# pixel size and CRS) as `<image stem>_mask.tif`. These pairs are the source for the
# val and test blocks cut out later.
# Later cells read the tile id from the first 7 characters of the copied file name,
# so image names are expected to already start with the tile id (e.g. 'NZ_A001_...').

indir = '../pro_data/train/'
reprosourcedir = '../pro_data/repro_resource/'
n_pairs = 30  # number of NEW pairs to produce in this run
os.makedirs(reprosourcedir, exist_ok=True)

pair_filenames = {}
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]

for tile in tiles:
    pair_filenames[tile] = {"image_path": [], "mask_path": []}
    for subdir, key in (('image', 'image_path'), ('mask', 'mask_path')):
        folder = os.path.join(indir, tile, subdir)
        for fname in os.listdir(folder):
            if fname.endswith(".tif"):
                pair_filenames[tile][key].append(os.path.join(folder, fname))

# Only tiles with at least one image and one mask can be sampled
# (random.choice on an empty list raises IndexError).
sampleable_tiles = [t for t in tiles
                    if pair_filenames[t]["image_path"] and pair_filenames[t]["mask_path"]]

# Image basenames not yet copied (by this or an earlier run). The loop stops when this
# is exhausted; otherwise it would spin forever once every image already exists.
remaining = {
    os.path.basename(p)
    for t in sampleable_tiles
    for p in pair_filenames[t]["image_path"]
    if not os.path.exists(os.path.join(reprosourcedir, os.path.basename(p)))
}

i = 0
while i < n_pairs and remaining:
    tile = random.choice(sampleable_tiles)
    img_fname = random.choice(pair_filenames[tile]["image_path"])
    # NOTE(review): the mask is drawn independently of the image within the same tile;
    # this is only a true pair if each tile has a single mask — confirm.
    mask_fname = random.choice(pair_filenames[tile]["mask_path"])
    img_basename = os.path.basename(img_fname)
    if img_basename not in remaining:
        continue
    shutil.copy(img_fname, os.path.join(reprosourcedir, img_basename))
    remaining.discard(img_basename)
    i += 1

    src_img = gdal.Open(img_fname)
    if src_img is None:
        raise RuntimeError(f'cannot open {img_fname}')
    geotransform = src_img.GetGeoTransform()
    projection = src_img.GetProjection()
    xSize, ySize = src_img.RasterXSize, src_img.RasterYSize
    src_img = None  # only the metadata is needed

    # Image footprint; geotransform[5] (pixel height) is negative for north-up rasters,
    # so minY lies below maxY.
    minX = geotransform[0]
    maxY = geotransform[3]
    maxX = minX + geotransform[1] * xSize
    minY = maxY + geotransform[5] * ySize

    stem, ext = os.path.splitext(img_basename)
    output_path = os.path.join(reprosourcedir, f'{stem}_mask{ext}')
    warp_options = gdal.WarpOptions(format='GTiff',
                                    outputBounds=[minX, minY, maxX, maxY],
                                    xRes=geotransform[1], yRes=abs(geotransform[5]),
                                    dstSRS=projection,
                                    resampleAlg=gdal.GRA_NearestNeighbour,
                                    dstNodata=255)  # pixels the mask does not cover are filled with 255

    # Reproject and clip the mask onto the image grid.
    out_mask_ds = gdal.Warp(output_path, mask_fname, options=warp_options)
    if out_mask_ds is None:
        raise RuntimeError(f'gdal.Warp failed for {mask_fname}')
    out_mask_ds = None  # close -> flush to disk

if i < n_pairs:
    print(f'Only {i} new pairs written: every candidate image already exists in {reprosourcedir}')

In [2]:
# Crop the reprojected pairs: cut the complete *val* block out of every image/mask pair in
# `base_dir`. These whole blocks are predicted end-to-end as the test data during training.
#   - NZ_A002:          columns [0, int(w*0.5)), rows [0, int(h*0.3))  (row int(h*0.3) excluded)
#   - every other tile: columns [0, int(w*0.5)), rows [int(h*0.7), h)
# NOTE(review): the sampling cell writes to '../pro_data/repro_resource/', while this cell
# reads '../pro_data/sate_dataset_V4/repro_resource/' — confirm the data was moved there.
base_dir = '../pro_data/sate_dataset_V4/repro_resource/'
# valdir = '../pro_data/val/'
testdir = '../pro_data/sate_dataset_V4/val_test/'
os.makedirs(testdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(base_dir, '**', '*mask*.tif'), recursive=True))
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

for img_fname, mask_fname in zip(img_filenames, mask_filenames):
    src_img = gdal.Open(img_fname)
    src_mask = gdal.Open(mask_fname)
    if src_img is None or src_mask is None:
        raise RuntimeError(f'cannot open pair {img_fname} / {mask_fname}')
    h, w = src_img.RasterYSize, src_img.RasterXSize
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize), f'size mismatch for {img_fname}'

    # Window as [x_off, y_off, width, height] in pixels.
    tile = os.path.basename(img_fname)[:7]  # file names start with the 7-char tile id, e.g. 'NZ_A002'
    if tile == 'NZ_A002':
        window = [0, 0, int(w * 0.5), int(h * 0.3)]
    else:
        y_off = int(h * 0.7)
        window = [0, y_off, int(w * 0.5), h - y_off]

    file_name = os.path.splitext(os.path.basename(img_fname))[0]
    output_path = os.path.join(testdir, f'{file_name}_val.tif')
    output_mask_path = os.path.join(testdir, f'{file_name}_val_mask.tif')

    # With srcWin, gdal.Translate shifts the geotransform to the window origin and keeps the
    # projection, so no manual SetGeoTransform/SetProjection is needed.
    for src_ds, dst_path in ((src_img, output_path), (src_mask, output_mask_path)):
        out_ds = gdal.Translate(dst_path, src_ds, format='GTiff', srcWin=window)
        if out_ds is None:
            raise RuntimeError(f'gdal.Translate failed for {dst_path}')
        out_ds = None  # close -> flush to disk

    src_img = src_mask = None

In [6]:
# Crop the reprojected pairs: cut the complete *test* block out of every image/mask pair in
# `base_dir`. These whole blocks are predicted end-to-end as the final test set after training.
#   - NZ_A002:          columns [int(w*0.5), w), rows [0, int(h*0.3))
#   - every other tile: columns [int(w*0.5), w), rows [int(h*0.7), h)
base_dir = '../pro_data/sate_dataset_V4/repro_resource/'
# valdir = '../pro_data/val/'
test_testdir = '../pro_data/sate_dataset_V4/test_test/'
os.makedirs(test_testdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(base_dir, '**', '*mask*.tif'), recursive=True))
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

for img_fname, mask_fname in zip(img_filenames, mask_filenames):
    src_img = gdal.Open(img_fname)
    src_mask = gdal.Open(mask_fname)
    if src_img is None or src_mask is None:
        raise RuntimeError(f'cannot open pair {img_fname} / {mask_fname}')
    h, w = src_img.RasterYSize, src_img.RasterXSize
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize), f'size mismatch for {img_fname}'

    # Window as [x_off, y_off, width, height] in pixels; always the right half of the columns.
    tile = os.path.basename(img_fname)[:7]  # file names start with the 7-char tile id, e.g. 'NZ_A002'
    x_off = int(w * 0.5)
    if tile == 'NZ_A002':
        window = [x_off, 0, w - x_off, int(h * 0.3)]
    else:
        y_off = int(h * 0.7)
        window = [x_off, y_off, w - x_off, h - y_off]

    file_name = os.path.splitext(os.path.basename(img_fname))[0]
    output_path = os.path.join(test_testdir, f'{file_name}_test.tif')
    output_mask_path = os.path.join(test_testdir, f'{file_name}_test_mask.tif')

    # With srcWin, gdal.Translate shifts the geotransform to the window origin and keeps the
    # projection, so no manual SetGeoTransform/SetProjection is needed.
    for src_ds, dst_path in ((src_img, output_path), (src_mask, output_mask_path)):
        out_ds = gdal.Translate(dst_path, src_ds, format='GTiff', srcWin=window)
        if out_ds is None:
            raise RuntimeError(f'gdal.Translate failed for {dst_path}')
        out_ds = None  # close -> flush to disk

    src_img = src_mask = None

In [6]:
################
# # (Split the BRZ image into four parts to make prediction easier)

# mask_fname = '/scratch2/ziyliu/LAMA/lama/BRZ_test/2_7723282967316973_014886554010_01_P002_mask.tif'
# img_fname = '/scratch2/ziyliu/LAMA/lama/BRZ_test/2_7723282967316973_014886554010_01_P002.tif'


# src_img = gdal.Open(img_fname)
# geotransform = src_img.GetGeoTransform()
# projection = src_img.GetProjection()
# h, w = src_img.RasterYSize, src_img.RasterXSize

# src_mask = gdal.Open(mask_fname)
# mask_geotransform = src_mask.GetGeoTransform()
# mask_projection = src_mask.GetProjection()
# mask_h, mask_w = src_mask.RasterYSize, src_mask.RasterXSize

# assert h == mask_h and w == mask_w

# x_off = 0
# y_off = int(h * 0.25) * 3

# new_gt = (geotransform[0] + x_off * geotransform[1], geotransform[1], 0, geotransform[3] + y_off * geotransform[5], 0, geotransform[5])
# new_mask_gt = (mask_geotransform[0] + x_off * mask_geotransform[1], mask_geotransform[1], 0, mask_geotransform[3] + y_off * mask_geotransform[5], 0, mask_geotransform[5])

# output_path = img_fname.replace('.tif', '_22.tif')
# output_mask_path = img_fname.replace('.tif', '_22_mask.tif')

# # 直接截取(0, int(w * 0.5) - 256)，(int(h * 0.7), h - 256)位置的区域
# out_ds = gdal.Translate(output_path, src_img, format='GTiff', srcWin=[x_off, y_off, w, h - 3*int(h * 0.25)]) # 没有取到index为int(h * 0.3)的这一行
# out_mask_ds = gdal.Translate(output_mask_path, src_mask, format='GTiff', srcWin=[x_off, y_off, w, h-3*int(h * 0.25)])    



# out_ds.SetGeoTransform(new_gt)
# out_ds.SetProjection(projection)
# out_mask_ds.SetGeoTransform(new_mask_gt)
# out_mask_ds.SetProjection(mask_projection)

# out_ds = None
# out_mask_ds = None

In [9]:
# Length of one tile-prefixed image stem; this is where the 47-character key length
# used for the per-image statistics comes from.
sample_stem = 'NZ_A001_21OCT20113141-P1BS-014905030010_01_P001'
print('字符串长度: ', len(sample_stem))

字符串长度:  47


In [11]:
# Compute per-image normalisation bounds from band 1: mean ± 3·std of the valid pixels
# (values below 65535), with the lower bound clipped at 0.
# NOTE(review): the original intent was to also save these bounds to disk; they are
# currently only kept in memory.
indir = '../pro_data/train/'
NODATA_VALUE = 65535  # pixels with this value are excluded from the statistics
STEM_LEN = 47         # length of a tile-prefixed stem, e.g. 'NZ_A001_21OCT20113141-P1BS-014905030010_01_P001'
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]
minmax_stats = {}

for tile in tiles:
    image_dir = os.path.join(indir, tile, 'image')
    for image in os.listdir(image_dir):
        if not image.endswith('.tif'):
            continue
        fname = os.path.join(image_dir, image)
        file_name = image[:STEM_LEN]

        with rasterio.open(fname) as src:
            img_array = src.read(1).astype('float32')
        valid = img_array[img_array < NODATA_VALUE]  # computed once, reused for mean and std
        if valid.size == 0:
            # mean/std of an empty array would store NaN bounds (plus a RuntimeWarning).
            print(f'skipping {fname}: no valid pixels')
            continue
        mu = np.mean(valid)
        std = np.std(valid)
        minmax_stats[file_name] = {
            'min_value': np.maximum(0, mu - 3 * std),
            'max_value': mu + 3 * std,
        }

# Kept under the original name because a later cell prints `dict`.
# NOTE: this shadows the built-in `dict` for the rest of the kernel session.
dict = minmax_stats

In [12]:
# Show the per-image normalisation bounds collected in the statistics cell
# (`dict` is that mapping here, not the built-in type).
print(str(dict))

{'NZ_A001_19OCT06222701-P1BS-014771210010_01_P001': {'min_value': 0.0, 'max_value': 3620.7481689453125}, 'NZ_A001_19OCT06222702-P1BS-014771210010_01_P002': {'min_value': 0.0, 'max_value': 3612.9026489257812}, 'NZ_A001_20APR24222440-P1BS-014421974010_01_P002': {'min_value': 0.0, 'max_value': 3670.6707763671875}, 'NZ_A001_20APR24222441-P1BS-014421974010_01_P003': {'min_value': 0.0, 'max_value': 3683.2294921875}, 'NZ_A001_20FEB11013638-P1BS-014422089010_01_P005': {'min_value': 0.0, 'max_value': 4111.425476074219}, 'NZ_A001_21AUG20222532-P1BS-014422040010_01_P002': {'min_value': 0.0, 'max_value': 3567.2447509765625}, 'NZ_A001_21FEB11013741-P1BS-014422115010_01_P002': {'min_value': 0.0, 'max_value': 4365.789367675781}, 'NZ_A001_21MAY19222720-P1BS-014422073010_01_P002': {'min_value': 0.0, 'max_value': 3710.3179931640625}, 'NZ_A002_20FEB11013638-P1BS-014422089010_01_P005': {'min_value': 0.0, 'max_value': 3303.8785400390625}, 'NZ_A002_20MAY29223553-P1BS-014421981010_01_P002': {'min_value': 0.0

In [5]:
# Randomly pick (image, mask) pairs from the training tiles and warp each mask onto its
# image's grid. Then cut random 256x256 crops from the image's *val* block as val crops:
#   - NZ_A002:     columns [0, int(w/2) - 1), rows [0, int(h*0.3) - 1)
#   - other tiles: columns [0, int(w/2) - 1), rows [int(h*0.7), h)
# A crop is kept only if the image patch contains no 65535 (nodata) pixel and the mask
# patch sum is > 0.
import rasterio
from rasterio.warp import reproject, Resampling

indir = '../pro_data/sate_dataset_V4/train/'
val_indir = '../pro_data/sate_dataset_V4/val_val/'

# if os.path.exists(val_indir):
#     shutil.rmtree(val_indir)
# os.makedirs(val_indir)
# The folder is intentionally not wiped, but it must exist: gdal.Translate will not create it.
os.makedirs(val_indir, exist_ok=True)

# Scratch folder for the full-size reprojected masks; each file is deleted once used.
reprojected_masks = '../pro_data/reprojected_masks'
if os.path.exists(reprojected_masks):
    shutil.rmtree(reprojected_masks)
os.makedirs(reprojected_masks)

pair_filenames = {}
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]

for tile in tiles:
    pair_filenames[tile] = {"image_path": [], "mask_path": []}
    for subdir, key in (('image', 'image_path'), ('mask', 'mask_path')):
        folder = os.path.join(indir, tile, subdir)
        for fname in os.listdir(folder):
            if fname.endswith(".tif"):
                pair_filenames[tile][key].append(os.path.join(folder, fname))

n_crops = 2600           # number of crops to write
patch_size = 256
max_tries = 10000        # random offsets tried per pair before the pair is abandoned
max_failed_pairs = 1000  # consecutive unusable pairs before giving up (avoids an endless loop)
nodata_value = 2**16 - 1

i = 0             # crops written so far; also numbers the output files
attempt = 0       # unique id for each temporary reprojected mask
failed_pairs = 0

while i < n_crops:
    tile = random.choice(tiles)
    img_fname = random.choice(pair_filenames[tile]["image_path"])
    mask_fname = random.choice(pair_filenames[tile]["mask_path"])
    attempt += 1

    src_img = gdal.Open(img_fname)
    if src_img is None:
        raise RuntimeError(f'cannot open {img_fname}')
    geotransform = src_img.GetGeoTransform()
    projection = src_img.GetProjection()
    h, w = src_img.RasterYSize, src_img.RasterXSize

    # Warp the mask onto the image grid (same footprint, pixel size and CRS);
    # pixels the mask does not cover are filled with 255.
    minX = geotransform[0]
    maxY = geotransform[3]
    maxX = minX + geotransform[1] * w
    minY = maxY + geotransform[5] * h
    # Named with `attempt`, not `i`: `i` does not advance when a pair is rejected, so the
    # same mask could otherwise be warped onto an already existing file.
    reprojected_path = os.path.join(reprojected_masks, os.path.basename(mask_fname)).replace(".tif", f"_reprojected{attempt}.tif")
    warp_options = gdal.WarpOptions(format='GTiff',
                                    outputBounds=[minX, minY, maxX, maxY],
                                    xRes=geotransform[1], yRes=abs(geotransform[5]),
                                    dstSRS=projection,
                                    resampleAlg=gdal.GRA_NearestNeighbour,
                                    dstNodata=255)
    warped = gdal.Warp(reprojected_path, mask_fname, options=warp_options)
    if warped is None:
        raise RuntimeError(f'gdal.Warp failed for {mask_fname}')
    warped = None  # flush to disk before reopening

    src_mask = gdal.Open(reprojected_path)
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize)

    img_read_array = src_img.ReadAsArray()
    mask_read_array = src_mask.ReadAsArray()
    nodata_read_mask = img_read_array == nodata_value

    # Inclusive bounds for the crop's top-left corner inside the val block.
    if tile == 'NZ_A002':
        x_min, x_max = 0, int(w / 2) - 1 - patch_size
        y_min, y_max = 0, int(h * 0.3) - 1 - patch_size
    else:
        x_min, x_max = 0, int(w / 2) - 1 - patch_size
        y_min, y_max = int(h * 0.7), h - patch_size

    found = False
    if x_max >= x_min and y_max >= y_min:  # block fits at least one patch (randint would raise otherwise)
        for _ in range(max_tries):
            x_off = random.randint(x_min, x_max)
            y_off = random.randint(y_min, y_max)
            rows = slice(y_off, y_off + patch_size)
            cols = slice(x_off, x_off + patch_size)
            # NOTE(review): 255 is also the warp nodata value, so a patch outside the mask
            # footprint passes the `sum > 0` test too — confirm whether it should be rejected.
            if not nodata_read_mask[rows, cols].any() and np.sum(mask_read_array[rows, cols]) > 0:
                found = True
                break

    if found:
        i += 1
        failed_pairs = 0
        file_name = os.path.splitext(os.path.basename(img_fname))[0]
        output_path = os.path.join(val_indir, f'{file_name}_val_crop00{i}.tif')
        output_mask_path = os.path.join(val_indir, f'{file_name}_val_crop00{i}_mask.tif')
        # With srcWin, gdal.Translate shifts the geotransform to the crop origin and keeps the projection.
        for src_ds, dst_path in ((src_img, output_path), (src_mask, output_mask_path)):
            out_ds = gdal.Translate(dst_path, src_ds, format='GTiff', srcWin=[x_off, y_off, patch_size, patch_size])
            if out_ds is None:
                raise RuntimeError(f'gdal.Translate failed for {dst_path}')
            out_ds = None  # close -> flush to disk
    else:
        failed_pairs += 1
        if failed_pairs >= max_failed_pairs:
            raise RuntimeError(f'{failed_pairs} consecutive pairs had no valid val crop; stopped at {i} crops')

    # Release the rasters and delete the temporary full-size reprojected mask.
    src_img = src_mask = None
    img_read_array = mask_read_array = nodata_read_mask = None
    os.remove(reprojected_path)

In [7]:
# Randomly pick (image, mask) pairs from the training tiles and warp each mask onto its
# image's grid. Then cut random 256x256 crops from the image's *test* block as test crops:
#   - NZ_A002:     columns [int(w/2), w), rows [0, int(h*0.3) - 1)
#   - other tiles: columns [int(w/2), w), rows [int(h*0.7), h)
# A crop is kept only if the image patch contains no 65535 (nodata) pixel and the mask
# patch sum is > 0.

import rasterio
from rasterio.warp import reproject, Resampling

indir = '../pro_data/sate_dataset_V4/train/'
test_indir = '../pro_data/sate_dataset_V4/test_crops/'

if os.path.exists(test_indir):
    shutil.rmtree(test_indir)
os.makedirs(test_indir)

# Scratch folder for the full-size reprojected masks; each file is deleted once used.
reprojected_masks = '../pro_data/reprojected_masks'
if os.path.exists(reprojected_masks):
    shutil.rmtree(reprojected_masks)
os.makedirs(reprojected_masks)

pair_filenames = {}
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]

for tile in tiles:
    pair_filenames[tile] = {"image_path": [], "mask_path": []}
    for subdir, key in (('image', 'image_path'), ('mask', 'mask_path')):
        folder = os.path.join(indir, tile, subdir)
        for fname in os.listdir(folder):
            if fname.endswith(".tif"):
                pair_filenames[tile][key].append(os.path.join(folder, fname))

n_crops = 2600           # number of crops to write
patch_size = 256
max_tries = 10000        # random offsets tried per pair before the pair is abandoned
max_failed_pairs = 1000  # consecutive unusable pairs before giving up (avoids an endless loop)
nodata_value = 2**16 - 1

i = 0             # crops written so far; also numbers the output files
attempt = 0       # unique id for each temporary reprojected mask
failed_pairs = 0

while i < n_crops:
    tile = random.choice(tiles)
    img_fname = random.choice(pair_filenames[tile]["image_path"])
    mask_fname = random.choice(pair_filenames[tile]["mask_path"])
    attempt += 1

    src_img = gdal.Open(img_fname)
    if src_img is None:
        raise RuntimeError(f'cannot open {img_fname}')
    geotransform = src_img.GetGeoTransform()
    projection = src_img.GetProjection()
    h, w = src_img.RasterYSize, src_img.RasterXSize

    # Warp the mask onto the image grid (same footprint, pixel size and CRS);
    # pixels the mask does not cover are filled with 255.
    minX = geotransform[0]
    maxY = geotransform[3]
    maxX = minX + geotransform[1] * w
    minY = maxY + geotransform[5] * h
    # Named with `attempt`, not `i`: `i` does not advance when a pair is rejected, so the
    # same mask could otherwise be warped onto an already existing file.
    reprojected_path = os.path.join(reprojected_masks, os.path.basename(mask_fname)).replace(".tif", f"_reprojected{attempt}.tif")
    warp_options = gdal.WarpOptions(format='GTiff',
                                    outputBounds=[minX, minY, maxX, maxY],
                                    xRes=geotransform[1], yRes=abs(geotransform[5]),
                                    dstSRS=projection,
                                    resampleAlg=gdal.GRA_NearestNeighbour,
                                    dstNodata=255)
    warped = gdal.Warp(reprojected_path, mask_fname, options=warp_options)
    if warped is None:
        raise RuntimeError(f'gdal.Warp failed for {mask_fname}')
    warped = None  # flush to disk before reopening

    src_mask = gdal.Open(reprojected_path)
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize)

    img_read_array = src_img.ReadAsArray()
    mask_read_array = src_mask.ReadAsArray()
    nodata_read_mask = img_read_array == nodata_value

    # Inclusive bounds for the crop's top-left corner inside the test block.
    if tile == 'NZ_A002':
        x_min, x_max = int(w / 2), w - patch_size
        y_min, y_max = 0, int(h * 0.3) - 1 - patch_size
    else:
        x_min, x_max = int(w / 2), w - patch_size
        y_min, y_max = int(h * 0.7), h - patch_size

    found = False
    if x_max >= x_min and y_max >= y_min:  # block fits at least one patch (randint would raise otherwise)
        for _ in range(max_tries):
            x_off = random.randint(x_min, x_max)
            y_off = random.randint(y_min, y_max)
            rows = slice(y_off, y_off + patch_size)
            cols = slice(x_off, x_off + patch_size)
            # NOTE(review): 255 is also the warp nodata value, so a patch outside the mask
            # footprint passes the `sum > 0` test too — confirm whether it should be rejected.
            if not nodata_read_mask[rows, cols].any() and np.sum(mask_read_array[rows, cols]) > 0:
                found = True
                break

    if found:
        i += 1
        failed_pairs = 0
        file_name = os.path.splitext(os.path.basename(img_fname))[0]
        output_path = os.path.join(test_indir, f'{file_name}_test_crop00{i}.tif')
        output_mask_path = os.path.join(test_indir, f'{file_name}_test_crop00{i}_mask.tif')
        # With srcWin, gdal.Translate shifts the geotransform to the crop origin and keeps the projection.
        for src_ds, dst_path in ((src_img, output_path), (src_mask, output_mask_path)):
            out_ds = gdal.Translate(dst_path, src_ds, format='GTiff', srcWin=[x_off, y_off, patch_size, patch_size])
            if out_ds is None:
                raise RuntimeError(f'gdal.Translate failed for {dst_path}')
            out_ds = None  # close -> flush to disk
    else:
        failed_pairs += 1
        if failed_pairs >= max_failed_pairs:
            raise RuntimeError(f'{failed_pairs} consecutive pairs had no valid test crop; stopped at {i} crops')

    # Release the rasters and delete the temporary full-size reprojected mask.
    src_img = src_mask = None
    img_read_array = mask_read_array = nodata_read_mask = None
    os.remove(reprojected_path)

In [8]:
# Randomly pick (image, dilated-mask) pairs from the training tiles and warp each mask onto
# its image's grid. Then cut random 256x256 crops from the image's *train* block as train
# crops, used to monitor overfitting:
#   - NZ_A002:     all columns, rows [int(h*0.3), h)
#   - other tiles: all columns, rows [0, int(h*0.7) - 1)
# A crop is kept only if the image patch contains no 65535 (nodata) pixel and the mask
# patch sum is > 0.

import rasterio
from rasterio.warp import reproject, Resampling

indir = '../pro_data/sate_dataset_V4/train/'
train_indir = '../pro_data/sate_dataset_V4/train_crops/'

if os.path.exists(train_indir):
    shutil.rmtree(train_indir)
os.makedirs(train_indir)

pair_filenames = {}
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]

for tile in tiles:
    pair_filenames[tile] = {"image_path": [], "mask_path": []}
    for subdir, key in (('image', 'image_path'), ('mask_dilation3', 'mask_path')):
        folder = os.path.join(indir, tile, subdir)
        for fname in os.listdir(folder):
            if fname.endswith(".tif"):
                pair_filenames[tile][key].append(os.path.join(folder, fname))

n_crops = 800            # number of crops to write
patch_size = 256
max_tries = 10000        # random offsets tried per pair before the pair is abandoned
max_failed_pairs = 1000  # consecutive unusable pairs before giving up (avoids an endless loop)
nodata_value = 2**16 - 1

# Scratch folder for the full-size reprojected masks; each file is deleted once used.
reprojected_masks = '../pro_data/reprojected_masks'
if os.path.exists(reprojected_masks):
    shutil.rmtree(reprojected_masks)
os.makedirs(reprojected_masks)

i = 0             # crops written so far; also numbers the output files
attempt = 0       # unique id for each temporary reprojected mask
failed_pairs = 0

while i < n_crops:
    tile = random.choice(tiles)
    img_fname = random.choice(pair_filenames[tile]["image_path"])
    mask_fname = random.choice(pair_filenames[tile]["mask_path"])
    attempt += 1

    src_img = gdal.Open(img_fname)
    if src_img is None:
        raise RuntimeError(f'cannot open {img_fname}')
    geotransform = src_img.GetGeoTransform()
    projection = src_img.GetProjection()
    h, w = src_img.RasterYSize, src_img.RasterXSize

    # Warp the mask onto the image grid (same footprint, pixel size and CRS);
    # pixels the mask does not cover are filled with 255.
    minX = geotransform[0]
    maxY = geotransform[3]
    maxX = minX + geotransform[1] * w
    minY = maxY + geotransform[5] * h
    # Named with `attempt`, not `i`: `i` does not advance when a pair is rejected, so the
    # same mask could otherwise be warped onto an already existing file.
    reprojected_path = os.path.join(reprojected_masks, os.path.basename(mask_fname)).replace(".tif", f"_reprojected{attempt}.tif")
    warp_options = gdal.WarpOptions(format='GTiff',
                                    outputBounds=[minX, minY, maxX, maxY],
                                    xRes=geotransform[1], yRes=abs(geotransform[5]),
                                    dstSRS=projection,
                                    resampleAlg=gdal.GRA_NearestNeighbour,
                                    dstNodata=255)
    warped = gdal.Warp(reprojected_path, mask_fname, options=warp_options)
    if warped is None:
        raise RuntimeError(f'gdal.Warp failed for {mask_fname}')
    warped = None  # flush to disk before reopening

    src_mask = gdal.Open(reprojected_path)
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize)

    img_read_array = src_img.ReadAsArray()
    mask_read_array = src_mask.ReadAsArray()
    nodata_read_mask = img_read_array == nodata_value

    # Inclusive bounds for the crop's top-left corner inside the train block.
    if tile == 'NZ_A002':
        x_min, x_max = 0, w - patch_size
        y_min, y_max = int(h * 0.3), h - patch_size
    else:
        x_min, x_max = 0, w - patch_size
        y_min, y_max = 0, int(h * 0.7) - 1 - patch_size

    found = False
    if x_max >= x_min and y_max >= y_min:  # block fits at least one patch (randint would raise otherwise)
        for _ in range(max_tries):
            x_off = random.randint(x_min, x_max)
            y_off = random.randint(y_min, y_max)
            rows = slice(y_off, y_off + patch_size)
            cols = slice(x_off, x_off + patch_size)
            # NOTE(review): 255 is also the warp nodata value, so a patch outside the mask
            # footprint passes the `sum > 0` test too — confirm whether it should be rejected.
            if not nodata_read_mask[rows, cols].any() and np.sum(mask_read_array[rows, cols]) > 0:
                found = True
                break

    if found:
        i += 1
        failed_pairs = 0
        file_name = os.path.splitext(os.path.basename(img_fname))[0]
        output_path = os.path.join(train_indir, f'{file_name}_train_crop00{i}.tif')
        output_mask_path = os.path.join(train_indir, f'{file_name}_train_crop00{i}_mask.tif')
        # With srcWin, gdal.Translate shifts the geotransform to the crop origin and keeps the projection.
        for src_ds, dst_path in ((src_img, output_path), (src_mask, output_mask_path)):
            out_ds = gdal.Translate(dst_path, src_ds, format='GTiff', srcWin=[x_off, y_off, patch_size, patch_size])
            if out_ds is None:
                raise RuntimeError(f'gdal.Translate failed for {dst_path}')
            out_ds = None  # close -> flush to disk
    else:
        failed_pairs += 1
        if failed_pairs >= max_failed_pairs:
            raise RuntimeError(f'{failed_pairs} consecutive pairs had no valid train crop; stopped at {i} crops')

    # Release the rasters and delete the temporary full-size reprojected mask.
    src_img = src_mask = None
    img_read_array = mask_read_array = nodata_read_mask = None
    os.remove(reprojected_path)

In [None]:
## Deprecated code below

In [4]:
# DEPRECATED: slide a patch_size x patch_size window (30 % overlap) over every whole-block
# image/mask pair found in `testdir` and write each window as an image/mask crop into `valdir`.
# The last row/column of windows is clamped to the raster edge so the whole block is covered.
# Earlier configuration (val blocks instead of test blocks):
# testdir = '../pro_data/val_test/'
# valdir = '../pro_data/val_all/'

testdir = '../pro_data/test_test/'
valdir = '../pro_data/test_all/'

if not os.path.exists(valdir):
    os.makedirs(valdir)

mask_filenames = sorted(glob.glob(os.path.join(testdir, '**', '*mask*.tif'), recursive=True))
img_filenames = [mask_name.rsplit('_mask', 1)[0] + '.tif' for mask_name in mask_filenames]

patch_size = 256
overlap = 0.3

for num, (img_fname, mask_fname) in enumerate(zip(img_filenames, mask_filenames)):
    src_img, src_mask = gdal.Open(img_fname), gdal.Open(mask_fname)
    geotransform, projection = src_img.GetGeoTransform(), src_img.GetProjection()
    mask_geotransform, mask_projection = src_mask.GetGeoTransform(), src_mask.GetProjection()
    h, w = src_img.RasterYSize, src_img.RasterXSize
    assert (h, w) == (src_mask.RasterYSize, src_mask.RasterXSize)

    stride = int(patch_size * (1 - overlap))
    # Windows per axis = ceil((size - patch_size) / stride) + 1, so the last one can reach the edge.
    n_rows = (h - patch_size) // stride + 1 + int((h - patch_size) % stride != 0)
    n_cols = (w - patch_size) // stride + 1 + int((w - patch_size) % stride != 0)
    row_starts = [min(r * stride, h - patch_size) for r in range(n_rows)]
    col_starts = [min(c * stride, w - patch_size) for c in range(n_cols)]

    file_name = os.path.basename(img_fname).split('.')[0]
    index = 0
    for y_off in row_starts:
        for x_off in col_starts:
            crop_gt = (geotransform[0] + x_off * geotransform[1], geotransform[1], 0,
                       geotransform[3] + y_off * geotransform[5], 0, geotransform[5])
            crop_mask_gt = (mask_geotransform[0] + x_off * mask_geotransform[1], mask_geotransform[1], 0,
                            mask_geotransform[3] + y_off * mask_geotransform[5], 0, mask_geotransform[5])
            output_path = os.path.join(valdir, f'{file_name}_crop00{index}.tif')
            output_mask_path = os.path.join(valdir, f'{file_name}_crop00{index}_mask00{index}.tif')
            window = [x_off, y_off, patch_size, patch_size]

            out_ds = gdal.Translate(output_path, src_img, format='GTiff', srcWin=window)
            out_mask_ds = gdal.Translate(output_mask_path, src_mask, format='GTiff', srcWin=window)
            out_ds.SetGeoTransform(crop_gt)
            out_ds.SetProjection(projection)
            out_mask_ds.SetGeoTransform(crop_mask_gt)
            out_mask_ds.SetProjection(mask_projection)
            out_ds = out_mask_ds = None  # close -> flush to disk

            index += 1

In [5]:
## Randomly keep 70 % of all val patch pairs (3000+ pairs) in `valdir_all` as the val set.
valdir_all = '../pro_data/val_all/'
valdir = '../pro_data/val_val/'
val_fraction = 0.7

os.makedirs(valdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(valdir_all, '**', '*mask*.tif'), recursive=True))
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

num_val = len(img_filenames)
val_length = int(num_val * val_fraction)
val_indices = random.sample(range(num_val), val_length)
# A set makes the membership test O(1); testing against the list made the loop O(n^2).
val_index_set = set(val_indices)

for i, (img_fname, mask_fname) in enumerate(zip(img_filenames, mask_filenames)):
    if i in val_index_set:
        for src_path in (img_fname, mask_fname):
            shutil.copy(src_path, os.path.join(valdir, os.path.basename(src_path)))