In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
from tqdm import tqdm
import shutil
import json
from datetime import datetime
from pathlib import Path
import random
from osgeo import gdal
import glob
import rasterio

In [None]:
# Region codes and sensor names (neither list is referenced in the cells shown here).
country_list = [
    "BR_A001",
    "NZ_A001",
    "NZ_A002",
    "US_A001",
    "ZA_A001",
]
wv_list = [f"WV0{n}" for n in (1, 2, 3)]

In [5]:
# Randomly draw image/mask pairs from the training tiles, resample each mask onto
# the grid of its image (same bounds, pixel size and projection), and save the pair
# to `reprosourcedir` as "<name>.tif" + "<name>_mask.tif". These pairs are the
# source data for the val and test splits built in the following cells.
indir = '/scratch2/ziyliu/LAMA/lama/sate_dataset/train/'
reprosourcedir = '../pro_data/repro_source/'
cropdir = '../pro_data/repro_crop/'
num_pairs = 30      # number of random draws (with replacement)
sample_seed = 0     # fixed seed so the drawn pairs are reproducible across runs
mask_nodata = 255   # fill value for mask pixels that the source mask does not cover

os.makedirs(reprosourcedir, exist_ok=True)
os.makedirs(cropdir, exist_ok=True)


def _list_tifs(folder):
    """Return the sorted full paths of the '.tif' files directly inside `folder`."""
    return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".tif"))


# Sorting makes the seeded draws independent of directory-listing order.
tiles = sorted(tile.name for tile in Path(indir).iterdir() if tile.is_dir())
pair_filenames = {
    tile: {
        "image_path": _list_tifs(os.path.join(indir, tile, 'image')),
        "mask_path": _list_tifs(os.path.join(indir, tile, 'mask')),
    }
    for tile in tiles
}

# A tile can only supply a pair if it has at least one image and one mask;
# random.choice on an empty list would otherwise raise IndexError mid-loop.
usable_tiles = [t for t in tiles
                if pair_filenames[t]["image_path"] and pair_filenames[t]["mask_path"]]
if not usable_tiles:
    raise RuntimeError(f"no tile under {indir} has both image/*.tif and mask/*.tif")

rng = random.Random(sample_seed)
for _ in range(num_pairs):
    tile = rng.choice(usable_tiles)
    img_fname = rng.choice(pair_filenames[tile]["image_path"])
    # The mask is drawn independently of the image, from the same tile.
    mask_fname = rng.choice(pair_filenames[tile]["mask_path"])
    # NOTE(review): draws are with replacement; re-drawing an image overwrites its
    # earlier pair, so fewer than `num_pairs` distinct pairs may be written.
    img_base = os.path.basename(img_fname)
    shutil.copy(img_fname, os.path.join(reprosourcedir, img_base))

    src_img = gdal.Open(img_fname)
    if src_img is None:
        raise RuntimeError(f"GDAL could not open {img_fname}")
    geotransform = src_img.GetGeoTransform()
    projection = src_img.GetProjection()
    xSize = src_img.RasterXSize
    ySize = src_img.RasterYSize
    src_img = None  # close the source dataset

    # Image footprint. Assumes a north-up geotransform (no rotation terms and
    # geotransform[5] < 0) — TODO confirm for all scenes.
    minX = geotransform[0]
    maxY = geotransform[3]
    maxX = minX + geotransform[1] * xSize
    minY = maxY + geotransform[5] * ySize

    output_path = os.path.join(reprosourcedir, os.path.splitext(img_base)[0] + "_mask.tif")
    warp_options = gdal.WarpOptions(format='GTiff',
                                    outputBounds=[minX, minY, maxX, maxY],
                                    # target resolution (-tr) is given as positive sizes
                                    xRes=abs(geotransform[1]), yRes=abs(geotransform[5]),
                                    dstSRS=projection,
                                    # nearest neighbour so mask values are not interpolated
                                    resampleAlg=gdal.GRA_NearestNeighbour,
                                    # uncovered areas are filled with 255 (not 0)
                                    dstNodata=mask_nodata)

    # Reproject and clip the mask onto the image grid; releasing the returned
    # dataset flushes it to disk.
    out_mask_ds = gdal.Warp(output_path, mask_fname, options=warp_options)
    if out_mask_ds is None:
        raise RuntimeError(f"gdal.Warp failed for {mask_fname}")
    out_mask_ds = None

In [13]:
# Crop a fixed window out of every re-projected image/mask pair: columns
# [0, int(0.5 * w)) and rows [int(0.7 * h), h), i.e. the left half of the bottom
# 30 % of the scene. These whole crops are the "val_test" inputs used for
# full-image prediction on the val region during training.
base_dir = '../pro_data/repro_source/'
testdir = '../pro_data/val_test/'
os.makedirs(testdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(base_dir, '**', '*mask*.tif'), recursive=True))
# Each mask "<name>_mask.tif" pairs with the image "<name>.tif" next to it.
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

for img_fname, mask_fname in zip(img_filenames, mask_filenames):
    src_img = gdal.Open(img_fname)
    src_mask = gdal.Open(mask_fname)
    if src_img is None or src_mask is None:
        raise RuntimeError(f"GDAL could not open {img_fname} / {mask_fname}")
    h, w = src_img.RasterYSize, src_img.RasterXSize
    mask_h, mask_w = src_mask.RasterYSize, src_mask.RasterXSize
    assert h == mask_h and w == mask_w, f"size mismatch: {img_fname} vs {mask_fname}"

    # Window = [x_off, y_off, x_size, y_size] in pixels.
    x_off = 0
    y_off = int(h * 0.7)
    src_win = [x_off, y_off, int(w * 0.5), h - y_off]

    file_name = os.path.splitext(os.path.basename(img_fname))[0]
    output_path = os.path.join(testdir, f'{file_name}_crop.tif')
    output_mask_path = os.path.join(testdir, f'{file_name}_crop_mask.tif')

    # gdal.Translate with srcWin already shifts the geotransform origin to the
    # window (including any rotation terms) and keeps the projection, so no
    # manual SetGeoTransform/SetProjection is needed.
    out_ds = gdal.Translate(output_path, src_img, format='GTiff', srcWin=src_win)
    out_mask_ds = gdal.Translate(output_mask_path, src_mask, format='GTiff', srcWin=src_win)
    if out_ds is None or out_mask_ds is None:
        raise RuntimeError(f"gdal.Translate failed for {img_fname}")

    # Release every dataset so the outputs are flushed and the inputs closed.
    out_ds = None
    out_mask_ds = None
    src_img = None
    src_mask = None


In [3]:
# Crop a fixed window out of every re-projected image/mask pair: columns
# [int(0.5 * w), w) and rows [int(0.7 * h), h), i.e. the right half of the bottom
# 30 % of the scene (complementary to the val crop). These whole crops are the
# "test_test" inputs used for full-image prediction in the final test.
base_dir = '../pro_data/repro_source/'
test_testdir = '../pro_data/test_test/'
os.makedirs(test_testdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(base_dir, '**', '*mask*.tif'), recursive=True))
# Each mask "<name>_mask.tif" pairs with the image "<name>.tif" next to it.
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

for img_fname, mask_fname in zip(img_filenames, mask_filenames):
    src_img = gdal.Open(img_fname)
    src_mask = gdal.Open(mask_fname)
    if src_img is None or src_mask is None:
        raise RuntimeError(f"GDAL could not open {img_fname} / {mask_fname}")
    h, w = src_img.RasterYSize, src_img.RasterXSize
    mask_h, mask_w = src_mask.RasterYSize, src_mask.RasterXSize
    assert h == mask_h and w == mask_w, f"size mismatch: {img_fname} vs {mask_fname}"

    # Window = [x_off, y_off, x_size, y_size] in pixels; extends to the right/bottom edges.
    x_off = int(w * 0.5)
    y_off = int(h * 0.7)
    src_win = [x_off, y_off, w - x_off, h - y_off]

    file_name = os.path.splitext(os.path.basename(img_fname))[0]
    output_path = os.path.join(test_testdir, f'{file_name}_test_crop.tif')
    output_mask_path = os.path.join(test_testdir, f'{file_name}_test_crop_mask.tif')

    # gdal.Translate with srcWin already shifts the geotransform origin to the
    # window (including any rotation terms) and keeps the projection, so no
    # manual SetGeoTransform/SetProjection is needed.
    out_ds = gdal.Translate(output_path, src_img, format='GTiff', srcWin=src_win)
    out_mask_ds = gdal.Translate(output_mask_path, src_mask, format='GTiff', srcWin=src_win)
    if out_ds is None or out_mask_ds is None:
        raise RuntimeError(f"gdal.Translate failed for {img_fname}")

    # Release every dataset so the outputs are flushed and the inputs closed.
    out_ds = None
    out_mask_ds = None
    src_img = None
    src_mask = None

In [1]:
# Length of a sample scene ID; the statistics cell below keys images by image[:39].
sample_scene_id = '21OCT20113141-P1BS-014905030010_01_P001'
print('字符串长度: ', len(sample_scene_id))

字符串长度:  39


In [None]:
# Compute a per-image stretch range for every training image: mean ± 3·std of the
# valid pixels, with the lower bound clipped at 0. Results are kept in `dict`,
# keyed by the first `scene_id_length` characters of the file name.
indir = '/scratch2/ziyliu/LAMA/lama/sate_dataset/train/'
nodata_value = 65535   # pixels >= this value are excluded (presumably the fill value — confirm)
scene_id_length = 39   # e.g. len('21OCT20113141-P1BS-014905030010_01_P001')
tiles = [tile.name for tile in Path(indir).iterdir() if tile.is_dir()]
# NOTE(review): `dict` shadows the built-in type. The name is kept only because a
# later cell prints it; rename it (e.g. `stretch_stats`) in both places.
dict = {}

for tile in tiles:
    image_dir = os.path.join(indir, tile, 'image')
    for image in os.listdir(image_dir):
        # Only read .tif files, matching the filter used when collecting pairs.
        if not image.endswith(".tif"):
            continue
        fname = os.path.join(image_dir, image)
        file_name = image[:scene_id_length]

        with rasterio.open(fname) as src:
            img_array = src.read(1).astype('float32')

        valid = img_array[img_array < nodata_value]
        if valid.size == 0:
            # mean/std of an empty array are NaN (with a RuntimeWarning); skip the image.
            continue
        mu = np.mean(valid)
        std = np.std(valid)
        min_value = np.maximum(0, mu - 3 * std)
        max_value = mu + 3 * std

        dict[file_name] = {
            'min_value': min_value,
            'max_value': max_value
        }

In [6]:
# Dump the collected per-image stretch ranges (plain-text repr of the whole mapping).
print(f"{dict}")

{'20FEB11013638-P1BS-014422089010_01_P005': {'min_value': 0.0, 'max_value': 3303.8785400390625}, '21FEB11013741-P1BS-014422115010_01_P002': {'min_value': 0.0, 'max_value': 3431.257568359375}, '20APR24222440-P1BS-014421974010_01_P002': {'min_value': 0.0, 'max_value': 2946.7777709960938}, '20APR24222441-P1BS-014421974010_01_P003': {'min_value': 0.0, 'max_value': 3683.2294921875}, '21AUG20222532-P1BS-014422040010_01_P002': {'min_value': 0.0, 'max_value': 2821.8384399414062}, '19OCT06222701-P1BS-014771210010_01_P001': {'min_value': 0.0, 'max_value': 3620.7481689453125}, '19OCT06222702-P1BS-014771210010_01_P002': {'min_value': 0.0, 'max_value': 2774.9259033203125}, '21MAY19222720-P1BS-014422073010_01_P002': {'min_value': 0.0, 'max_value': 3039.9217529296875}, '20SEP23222626-P1BS-014421992010_01_P002': {'min_value': 0.0, 'max_value': 3081.2487182617188}, '21JAN26221637-P1BS-014421999010_01_P002': {'min_value': 0.0, 'max_value': 3417.5930786132812}, '21FEB16224204-P1BS-014422001010_01_P002': 

In [4]:
# Cut every whole-scene crop in `testdir` into overlapping patch_size x patch_size
# tiles with a sliding window and write each image/mask tile pair to `valdir`.
# Switch the two paths to choose the split:
#   val : testdir = '../pro_data/val_test/',  valdir = '../pro_data/val_all/'
#   test: testdir = '../pro_data/test_test/', valdir = '../pro_data/test_all/'
testdir = '../pro_data/test_test/'
valdir = '../pro_data/test_all/'
os.makedirs(valdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(testdir, '**', '*mask*.tif'), recursive=True))
# Each mask "<name>_mask.tif" pairs with the image "<name>.tif" next to it.
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

patch_size = 256
overlap = 0.3
stride = int(patch_size * (1 - overlap))  # 179 px for the values above


def _window_starts(length, patch_size, stride):
    """Return the window start offsets along one axis of size `length`.

    Windows advance by `stride`; if the last regular window stops short of the
    edge, one extra window is placed flush with the edge so every pixel is
    covered. Returns [] when `length < patch_size` (no full window fits).
    """
    if length < patch_size:
        return []
    starts = list(range(0, length - patch_size + 1, stride))
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


for img_fname, mask_fname in zip(img_filenames, mask_filenames):
    src_img = gdal.Open(img_fname)
    src_mask = gdal.Open(mask_fname)
    if src_img is None or src_mask is None:
        raise RuntimeError(f"GDAL could not open {img_fname} / {mask_fname}")
    h, w = src_img.RasterYSize, src_img.RasterXSize
    mask_h, mask_w = src_mask.RasterYSize, src_mask.RasterXSize
    assert h == mask_h and w == mask_w, f"size mismatch: {img_fname} vs {mask_fname}"

    row_starts = _window_starts(h, patch_size, stride)
    col_starts = _window_starts(w, patch_size, stride)
    if not row_starts or not col_starts:
        # A scene smaller than one patch would give negative srcWin offsets.
        print(f'skip {img_fname}: {w}x{h} is smaller than {patch_size}x{patch_size}')
        src_img = None
        src_mask = None
        continue

    file_name = os.path.splitext(os.path.basename(img_fname))[0]
    index = 0  # row-major patch counter, restarted for every scene
    for y_off in row_starts:
        for x_off in col_starts:
            # "00" is a literal prefix (not zero-padding); names are kept as-is
            # because the split cell pairs images/masks by this pattern.
            output_path = os.path.join(valdir, f'{file_name}_crop00{index}.tif')
            output_mask_path = os.path.join(valdir, f'{file_name}_crop00{index}_mask00{index}.tif')

            # gdal.Translate with srcWin sets the patch geotransform (including any
            # rotation terms) and keeps the projection, so no manual override is needed.
            src_win = [x_off, y_off, patch_size, patch_size]
            out_ds = gdal.Translate(output_path, src_img, format='GTiff', srcWin=src_win)
            out_mask_ds = gdal.Translate(output_mask_path, src_mask, format='GTiff', srcWin=src_win)
            if out_ds is None or out_mask_ds is None:
                raise RuntimeError(f"gdal.Translate failed for {img_fname} patch {index}")
            out_ds = None
            out_mask_ds = None

            index += 1

    src_img = None
    src_mask = None

In [5]:
# Randomly keep `val_fraction` (70 %) of the sliding-window val patch pairs
# (image + mask) as the validation set, copying them from `valdir_all` to `valdir`.
valdir_all = '../pro_data/val_all/'
valdir = '../pro_data/val_val/'
val_fraction = 0.7
split_seed = 0  # fixed seed so the split is reproducible across runs

os.makedirs(valdir, exist_ok=True)

mask_filenames = sorted(glob.glob(os.path.join(valdir_all, '**', '*mask*.tif'), recursive=True))
# Each mask "<name>_mask<k>.tif" pairs with the image "<name>.tif" next to it.
img_filenames = [fname.rsplit('_mask', 1)[0] + '.tif' for fname in mask_filenames]

num_val = len(img_filenames)
val_length = int(num_val * val_fraction)
val_indices = random.Random(split_seed).sample(range(num_val), val_length)
# Set for O(1) membership; scanning the list for every file was O(n^2).
val_index_set = set(val_indices)

for i, (img_fname, mask_fname) in enumerate(zip(img_filenames, mask_filenames)):
    if i in val_index_set:
        shutil.copy(img_fname, os.path.join(valdir, os.path.basename(img_fname)))
        shutil.copy(mask_fname, os.path.join(valdir, os.path.basename(mask_fname)))