In [3]:
import os
import contextlib
import joblib
from tqdm import tqdm


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Temporarily route joblib batch completions into a tqdm progress bar.

    While the ``with`` block is active, ``joblib.parallel.BatchCompletionCallBack``
    is replaced by a subclass that advances ``tqdm_object`` by the size of each
    completed batch before delegating to the original callback. On exit the
    original callback class is restored and the progress bar is closed, whether
    or not the block raised.

    Args:
        tqdm_object: The tqdm progress bar to update.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """
    original_callback_cls = joblib.parallel.BatchCompletionCallBack

    class _ProgressReportingCallback(original_callback_cls):
        def __call__(self, *args, **kwargs):
            # Advance the bar first, then let joblib do its normal bookkeeping.
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = _ProgressReportingCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey-patch so later Parallel calls are unaffected.
        joblib.parallel.BatchCompletionCallBack = original_callback_cls
        tqdm_object.close()

def get_paths(filename, old_dir, new_dir, new_ext):
    """Build the input and output paths for a single file.

    Args:
        filename: Name of the file inside ``old_dir``.
        old_dir: Directory that holds the input file.
        new_dir: Directory the output file is placed in.
        new_ext: Extension appended to the file's stem in place of its current
            one. ``None`` keeps ``filename`` unchanged.

    Returns:
        A ``(old_path, new_path)`` tuple.
    """
    stem = os.path.splitext(filename)[0]
    target_name = filename if new_ext is None else stem + new_ext
    return os.path.join(old_dir, filename), os.path.join(new_dir, target_name)

def dir_map(old_dir: str, new_dir: str, func, old_ext: str, new_ext: str = None, n_jobs: int = -1) -> None:
    """
    Apply `func` in parallel to every file in `old_dir` whose name ends with `old_ext`, writing results to `new_dir`.

    For each matching file, `func(old_path, new_path)` is called, where the two paths are built by `get_paths`.
    If `new_ext` is given, the output file name uses that extension; otherwise it keeps the input file name.
    `new_dir` is created if it does not exist. Progress is reported through a tqdm bar.

    Args:
        old_dir (str): The input directory containing the files to be processed.
        new_dir (str): The output directory where the processed files will be written.
        func (callable): A function that takes the path to an input file and the path to an output file as arguments,
            reads the contents of the input file, processes the contents, and writes the result to the output file.
        old_ext (str): The suffix that input file names must end with (matched with `str.endswith`).
        new_ext (str, optional): The extension to use for the output files. If not specified, the output files will
            have the same name as the input files.
        n_jobs (int, optional): The number of parallel jobs passed to `joblib.Parallel`. -1 uses all available
            CPU cores.
    """
    # exist_ok avoids the check-then-create race when several runs target the same directory.
    os.makedirs(new_dir, exist_ok=True)

    # Keep regular files only (a sub-directory whose name happens to end with the suffix
    # is not something `func` can process) and sort for a reproducible processing order.
    files = sorted(
        filename
        for filename in os.listdir(old_dir)
        if filename.endswith(old_ext) and os.path.isfile(os.path.join(old_dir, filename))
    )

    with tqdm_joblib(tqdm(desc="Processing files", total=len(files))):
        joblib.Parallel(n_jobs=n_jobs, backend='multiprocessing')(
            joblib.delayed(func)(*get_paths(filename, old_dir, new_dir, new_ext))
            for filename in files
        )

In [22]:
import os
import numpy as np
from PIL import Image
import rasterio


def extract_RGB(input_file, output_file=None):
    """Convert the first 3 channels of a (16-bit) GeoTIFF into an 8-bit RGB PNG.

    Each channel is contrast-stretched independently: pixel values are clipped
    to the channel's 2nd-98th percentile range (percentiles are computed with
    the values 0 and 65535 excluded as no-data) and then rescaled linearly to
    [0, 255].

    The file is skipped (nothing is written) when more than 20% of the samples
    are no-data, or when any channel has identical 2nd and 98th percentiles
    and therefore cannot be stretched.

    Args:
        input_file (str): The path to the input GeoTIFF file.
        output_file (str, optional): The path to the output file. Its extension,
            if any, is replaced by ``.png``. If None, the output is written next
            to the input file with the same base name and a ``.png`` extension.
            Defaults to None.

    Returns:
        None.
    """
    # Filters are installed here (not at module level) so they also apply when
    # the function runs inside a worker process.
    import warnings
    warnings.filterwarnings('ignore', message='Dataset has no geotransform')
    warnings.filterwarnings('ignore', message='All-NaN slice encountered')

    # Read the first 3 channels of the input file.
    with rasterio.open(input_file) as src:
        data = src.read(indexes=[1, 2, 3])

    # `raw` keeps every pixel for the stretch; `valid` masks no-data for the statistics.
    raw = data.astype('float32')
    valid = raw.copy()
    valid[(valid == 0) | (valid == 65535.0)] = np.nan

    # NoData percentage threshold
    if np.count_nonzero(np.isnan(valid)) / valid.size > 0.2:
        return

    # Per-channel 2nd / 98th percentiles, ignoring no-data.
    low, high = np.nanpercentile(valid, (2, 98), axis=(1, 2))
    span = high - low
    # Decide up front whether every channel can be stretched (also rejects NaN spans).
    if not np.all(span > 0):
        return

    # Reshape to (channels, 1, 1) so the bounds broadcast over each channel's pixels.
    low = low[:, np.newaxis, np.newaxis]
    high = high[:, np.newaxis, np.newaxis]
    span = span[:, np.newaxis, np.newaxis]

    # Do the stretch in floating point. Writing intermediate results back into
    # the integer source array truncates fractional percentiles and lets
    # slightly negative values wrap around when cast to an unsigned type.
    stretched = (np.clip(raw, low, high) - low) / span * 255
    rgb = np.clip(stretched, 0, 255).astype(np.uint8)

    # Create a PIL image from the data (channels-last layout).
    image = Image.fromarray(np.transpose(rgb, [1, 2, 0]))

    # The output is always a PNG: swap whatever extension the chosen path has.
    base_path = input_file if output_file is None else output_file
    output_file = os.path.splitext(base_path)[0] + '.png'

    # Save the image to the output file
    image.save(output_file)
!rm ./Dataset/Guochan_HR/*.png

In [23]:
# Remove any PNGs already present in the output directory before regenerating them.
!rm -f ./Dataset/Guochan_HR/*.png
# Run extract_RGB over every file in the chip directory whose name ends with 'tif'.
# No new_ext is passed, so dir_map hands over the original file name; extract_RGB
# itself switches the output extension to .png.
dir_map("../超分重建数据/高分_切片去坐标/image_chips", "./Dataset/Guochan_HR", extract_RGB, old_ext='tif')
# Count the PNGs that were written. extract_RGB returns early without writing for
# some inputs, so this can be lower than the number of files processed.
!ls -al ./Dataset/Guochan_HR/*.png | wc -l

Processing files: 100%|██████████| 2813/2813 [00:27<00:00, 100.99it/s]


2462


In [3]:
import rasterio
from rasterio.enums import Resampling

# Module-level value; raster_rescale does not read it, because its own
# `scale_factor` parameter shadows this name inside the function.
scale_factor = 1/2


def raster_rescale(in_file, out_file, scale_factor):
    """Resample a raster by ``scale_factor`` and write the result as a GeoTIFF.

    All bands are read at the target resolution using cubic resampling, the
    affine transform is scaled to match the new pixel size, and the remaining
    metadata is copied from the source dataset.

    Args:
        in_file (str): Path of the raster to read.
        out_file (str): Path of the GeoTIFF to write. Its parent directory is
            created if it does not exist.
        scale_factor (float): Multiplier applied to the source height and width
            (e.g. ``1/4`` produces an image a quarter as tall and wide).

    Raises:
        ValueError: If the scaled height or width would be smaller than 1 pixel.
    """
    import os  # local import: the cell defining this function does not import os

    with rasterio.open(in_file) as src:
        # Compute the target size once and reuse it for the read, the transform and the profile.
        height = int(src.height * scale_factor)
        width = int(src.width * scale_factor)
        if height < 1 or width < 1:
            raise ValueError(
                f"scale_factor={scale_factor!r} would produce an empty raster "
                f"({width}x{height}) from {in_file!r}"
            )

        # resample data to target shape
        data = src.read(
            out_shape=(src.count, height, width),
            resampling=Resampling.cubic,
        )

        # scale image transform so each output pixel covers the right ground extent
        transform = src.transform * src.transform.scale(
            src.width / width,
            src.height / height,
        )

        profile = src.meta.copy()
        profile.update({
            'driver': 'GTiff',
            'height': height,
            'width': width,
            'transform': transform,
        })

    # The source is closed at this point; everything needed is held in memory.
    out_dir = os.path.dirname(out_file)
    if out_dir:
        # Opening a dataset for writing does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    with rasterio.open(out_file, 'w', **profile) as dst:
        dst.write(data)


In [4]:
from glob import glob
# Write a 1/4-scale copy of every .tif under RS_LR_x1; the destination is the
# same path with 'RS_LR_x1' replaced by 'RS_LR_x4'.
for in_file in glob('./Input/RS_LR_x1/*.tif'):
    out_file = in_file.replace('RS_LR_x1', 'RS_LR_x4')
    raster_rescale(in_file, out_file, 1/4)

  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


In [None]:
for in_file in glob('./Input/RS_LR_x4/*.tif'):
    out_file = in_file.replace('.tif', '.png')
    !convert {in_file} {out_file}