# Raster Reclassification Script

## Overview
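This notebook contains Python code for reclassifying pixel values in raster datasets with the GDAL library. The `reclassify_raster` function changes pixel values according to a reclassification dictionary, which can be written by hand or read from a CSV file. Two versions of the function are provided. One reads each band into memory at once (for small rasters), and the other processes the raster in blocks (for large rasters). Both cells define the same name, so the batch loop at the end uses whichever version was run last.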

## Function: `reclassify_raster`

### Purpose
The `reclassify_raster` function reads an input raster file, reclassifies its pixel values according to a user-defined mapping, and writes the reclassified values to an output raster file. This is useful for tasks such as data simplification, feature extraction, or changing classifications for further analysis.

### Parameters
- **`input_path` (str)**: Path to the input raster file (e.g., a GeoTIFF).
- **`output_path` (str)**: Path where the output raster will be saved.
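- **`reclass_dict` (dict)**: A dictionary mapping original pixel values to new values.
  - Keys can be single original pixel values or inclusive ranges given as `(low, high)` tuples. Values are the new reclassified values.
  - Masks are computed against the original values, so when keys overlap, the entry that comes later in the dictionary wins.
- **`nodata_value` (int or float)**: The NoData value recorded in the metadata of each output band. Input pixels are not converted to it automatically; map them explicitly in `reclass_dict` if needed.
- **`data_type` (int)**: The data type of the output raster, given as a GDAL data type constant (e.g., `gdal.GDT_Byte`, `gdal.GDT_Int16`). It must be able to hold every new value and `nodata_value`.
- **`block_size` (int, optional)**: Width and height, in pixels, of the blocks processed at a time by the block version. Defaults to 512.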
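The matching rules above can be illustrated without GDAL on a small NumPy array. `apply_reclass` is a hypothetical helper written for this sketch, not part of the notebook. It assumes a floating-point output type and, going beyond the key types listed above, also treats a `np.nan` key as matching NaN pixels, because a plain `data == np.nan` comparison never matches anything.

```python
import numpy as np

def apply_reclass(data, reclass_dict, out_dtype=np.float32):
    """Return a reclassified copy of `data` (hypothetical helper for illustration).

    Masks are always computed against the ORIGINAL values, so when two rules
    overlap, the one that appears later in the dict wins.
    """
    result = data.astype(out_dtype)  # new array in the output dtype
    for key, value in reclass_dict.items():
        if isinstance(key, tuple):
            low, high = key
            mask = (data >= low) & (data <= high)  # inclusive range
        elif isinstance(key, float) and np.isnan(key):
            mask = np.isnan(data)  # NaN never compares equal, so use isnan
        else:
            mask = data == key
        result[mask] = value
    return result

data = np.array([[1, 5, 150], [np.nan, 100, 7]], dtype=np.float32)
rules = {(1, 100): 22, 1: 12, np.nan: 0}
print(apply_reclass(data, rules))
```

With the range listed first, pixel value 1 ends up as 12 because the single-value rule comes later; swapping the two entries would give 22 instead.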

### Returns
- **None**: The function writes the reclassified raster to the output path and does not return any value.

### Notes
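- The block version reads and writes each band in windows of `block_size` x `block_size` pixels, so memory use depends on the block size rather than on the full raster size.
- Each output band receives `nodata_value` as its NoData value; pixels flagged as NoData in the input are not converted unless `reclass_dict` maps them.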
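If input NoData pixels should end up as the output NoData value instead of being reclassified, one option is to mask them first. The sketch below uses a hypothetical helper, `reclass_preserving_nodata`, and assumes `in_nodata` is what the input band's `GetNoDataValue()` returns (`None` when the band declares no NoData value). For brevity it handles single-value keys only, not ranges.

```python
import numpy as np

def reclass_preserving_nodata(data, reclass_dict, in_nodata, out_nodata):
    """Hypothetical helper: reclassify `data`, then send input NoData pixels to `out_nodata`."""
    result = data.astype(np.float64)
    if in_nodata is None:
        nodata_mask = np.zeros(data.shape, dtype=bool)  # band declares no NoData value
    elif np.isnan(in_nodata):
        nodata_mask = np.isnan(data)  # a NaN NoData value must be detected with isnan
    else:
        nodata_mask = data == in_nodata
    for key, value in reclass_dict.items():
        # Single-value keys only; NoData pixels are excluded from every rule
        result[(data == key) & ~nodata_mask] = value
    result[nodata_mask] = out_nodata
    return result

data = np.array([1, 2, -9999, 3])
print(reclass_preserving_nodata(data, {1: 10, -9999: 5}, in_nodata=-9999, out_nodata=1000))
```

Here the `-9999: 5` rule is ignored because those pixels are NoData, and they receive `out_nodata` instead.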

## Author
- Rubén Crespo Ceballos

In [1]:
from osgeo import gdal, gdalconst
import os
import numpy as np
import csv

In [None]:
# For small rasters: each band is read into memory in one piece
from osgeo import gdal_array  # maps GDAL type codes to NumPy dtypes

def reclassify_raster(input_path, output_path, reclass_dict, nodata_value, data_type):
    """
    Reclassifies the pixel values of a raster dataset based on a provided reclassification dictionary.

    This function reads a raster file from the specified input path, reclassifies its pixel values
    according to the given dictionary, and writes the reclassified values to a new GeoTIFF
    at the specified output path. It preserves the original raster's projection and geotransform.

    Parameters:
        input_path (str): Path to the input raster file (e.g., a GeoTIFF).
        output_path (str): Path where the output raster will be saved.
        reclass_dict (dict): A dictionary mapping original pixel values to new values.
                             Keys can be single values, np.nan, or inclusive (low, high) tuples;
                             values are the reclassified values.
        nodata_value (int or float): NoData value written to the metadata of each output band.
        data_type (int): The data type for the output raster. This should be a GDAL data type constant
                         (e.g., gdal.GDT_Byte, gdal.GDT_Int16).

    Returns:
        None: The function writes the reclassified raster to the output path and does not return any value.

    Notes:
        - The function assumes that the input raster is small enough to be processed in memory.
        - Pixels not matched by any key keep their original value (cast to the output type).
        - When keys overlap, the entry that comes later in the dictionary wins.
    """

    # Open the input raster
    input_ds = gdal.Open(input_path, gdalconst.GA_ReadOnly)
    if input_ds is None:
        print(f"Error: Unable to open input raster {input_path}")
        return

    # Get raster information
    rows = input_ds.RasterYSize
    cols = input_ds.RasterXSize
    bands = input_ds.RasterCount

    # Create the output raster. Careful: data_type must be able to hold every new value and
    # nodata_value (e.g. GDT_Byte cannot store NaN, negative values, or values above 255)
    driver = gdal.GetDriverByName('GTiff')
    output_ds = driver.Create(output_path, cols, rows, bands, data_type,
                              options=['COMPRESS=DEFLATE', 'TILED=YES'])
    if output_ds is None:
        print(f"Error: Unable to create output raster {output_path}")
        return
    output_ds.SetProjection(input_ds.GetProjection())
    output_ds.SetGeoTransform(input_ds.GetGeoTransform())

    # NumPy dtype matching the requested GDAL output type
    out_dtype = gdal_array.GDALTypeCodeToNumericTypeCode(data_type)

    for band in range(1, bands + 1):
        input_band = input_ds.GetRasterBand(band)
        output_band = output_ds.GetRasterBand(band)

        # Record nodata_value as the output band's NoData value
        output_band.SetNoDataValue(nodata_value)

        # Read the whole band into an array
        data = input_band.ReadAsArray()

        # Work on a copy in the output dtype so NaN or fractional new values are not rejected
        # or truncated by an integer input dtype; masks are computed on the original data
        reclassified_data = data.astype(out_dtype)

        for key, value in reclass_dict.items():
            if isinstance(key, tuple):  # inclusive range (low, high)
                low, high = key
                mask = (data >= low) & (data <= high)
            elif isinstance(key, float) and np.isnan(key):  # NaN never compares equal
                mask = np.isnan(data)
            else:
                mask = data == key
            reclassified_data[mask] = value

        # Write the reclassified data to the output band
        output_band.WriteArray(reclassified_data)

    # Flush to disk and close datasets
    output_ds.FlushCache()
    input_ds = None
    output_ds = None
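Why the dtype of the working array matters: an integer band copied as-is (as `np.copy(data)` would do) keeps its integer dtype, so new values that are NaN or fractional cannot be stored correctly. The NumPy-only sketch below demonstrates this; the small Int32 array is an assumed stand-in for a band read with `ReadAsArray()`.

```python
import numpy as np

# Stand-in for a band read with ReadAsArray() from an Int32 raster
data = np.array([1, 2, 3], dtype=np.int32)

# Copying keeps int32, so NaN cannot be stored...
same_dtype = np.copy(data)
nan_rejected = False
try:
    same_dtype[data == 2] = np.nan
except ValueError:
    nan_rejected = True

# ...and fractional values are silently truncated
same_dtype[data == 3] = 2.7

# Casting to the (floating-point) output type first avoids both problems
as_float = data.astype(np.float32)
as_float[data == 2] = np.nan
as_float[data == 3] = 2.7

print(nan_rejected, same_dtype, as_float)
```

Casting to the output type up front also means a mismatch between `reclass_dict` and `data_type` surfaces during reclassification rather than when GDAL converts the array on write.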

In [None]:
# For big rasters: each band is processed in windows of block_size x block_size pixels
from osgeo import gdal_array  # maps GDAL type codes to NumPy dtypes

def reclassify_raster(input_path, output_path, reclass_dict, nodata_value, data_type, block_size=512):
    """
    Reclassifies the pixel values of a raster dataset in blocks based on a provided reclassification dictionary.

    This function reads a raster file from the specified input path, reclassifies its pixel values
    according to the given dictionary, and writes the reclassified values to a new GeoTIFF
    at the specified output path. It processes the raster in blocks, so only one block
    is held in memory at a time.

    Parameters:
        input_path (str): Path to the input raster file (e.g., a GeoTIFF).
        output_path (str): Path where the output raster will be saved.
        reclass_dict (dict): A dictionary mapping original pixel values to new values.
                             Keys can be single values, np.nan, or inclusive (low, high) tuples;
                             values are the reclassified values. When keys overlap, the entry
                             that comes later in the dictionary wins.
        nodata_value (int or float): NoData value written to the metadata of each output band.
        data_type (int): The data type for the output raster. This should be a GDAL data type constant
                         (e.g., gdal.GDT_Byte, gdal.GDT_Int16).
        block_size (int, optional): Width and height, in pixels, of the blocks to process.
                                    Default is 512.

    Returns:
        None: The function writes the reclassified raster to the output path and does not return any value.

    Notes:
        - Pixels not matched by any key keep their original value (cast to the output type).
        - Input NoData pixels are not converted to nodata_value unless reclass_dict maps them.
    """
    # Open the input raster
    input_ds = gdal.Open(input_path, gdalconst.GA_ReadOnly)
    if input_ds is None:
        print(f"Error: Unable to open input raster {input_path}")
        return

    # Get raster information
    rows = input_ds.RasterYSize
    cols = input_ds.RasterXSize
    bands = input_ds.RasterCount

    # Create the output raster (COPY_SRC_OVERVIEWS was dropped: it only applies to CreateCopy())
    driver = gdal.GetDriverByName('GTiff')
    output_ds = driver.Create(output_path, cols, rows, bands, data_type,
                              options=['COMPRESS=DEFLATE', 'TILED=YES'])
    if output_ds is None:
        print(f"Error: Unable to create output raster {output_path}")
        return
    output_ds.SetProjection(input_ds.GetProjection())
    output_ds.SetGeoTransform(input_ds.GetGeoTransform())

    # NumPy dtype matching the requested GDAL output type
    out_dtype = gdal_array.GDALTypeCodeToNumericTypeCode(data_type)

    for band in range(1, bands + 1):
        input_band = input_ds.GetRasterBand(band)
        output_band = output_ds.GetRasterBand(band)

        # Record nodata_value as the output band's NoData value
        output_band.SetNoDataValue(nodata_value)

        # Process the raster in blocks; i is the row offset and j the column offset
        for i in range(0, rows, block_size):
            for j in range(0, cols, block_size):
                # Edge blocks are clipped to the raster extent
                block_rows = min(block_size, rows - i)
                block_cols = min(block_size, cols - j)

                # ReadAsArray(xoff, yoff, xsize, ysize): the column offset comes first
                data = input_band.ReadAsArray(j, i, block_cols, block_rows)

                # Copy in the output dtype; masks are computed on the original block
                reclassified_block = data.astype(out_dtype)

                for key, value in reclass_dict.items():
                    if isinstance(key, tuple):  # inclusive range (low, high)
                        low, high = key
                        mask = (data >= low) & (data <= high)
                    elif isinstance(key, float) and np.isnan(key):  # NaN never compares equal
                        mask = np.isnan(data)
                    else:
                        mask = data == key
                    reclassified_block[mask] = value

                # Write the reclassified block back at the same offset
                output_band.WriteArray(reclassified_block, j, i)

    # Flush to disk and close datasets
    output_ds.FlushCache()
    input_ds = None
    output_ds = None
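The block loop walks the raster row by row and clips the last row and column of blocks to the raster edge. The generator below, `block_windows` (a hypothetical name, not part of the notebook), reproduces only that offset arithmetic so it can be checked without opening a raster. Each tuple is in the `(xoff, yoff, xsize, ysize)` order that `Band.ReadAsArray` expects. A possible refinement, not used in the notebook, is to derive `block_size` from the input band's own tiling via `Band.GetBlockSize()`.

```python
def block_windows(rows, cols, block_size=512):
    """Yield (xoff, yoff, xsize, ysize) tuples covering a rows x cols raster."""
    for yoff in range(0, rows, block_size):
        ysize = min(block_size, rows - yoff)  # last row of blocks may be shorter
        for xoff in range(0, cols, block_size):
            xsize = min(block_size, cols - xoff)  # last column of blocks may be narrower
            yield xoff, yoff, xsize, ysize

windows = list(block_windows(1000, 1200, 512))
print(len(windows))  # 2 rows of blocks x 3 columns of blocks
```

Summing `xsize * ysize` over all windows gives exactly `rows * cols`, which confirms that the windows tile the raster without gaps or overlap.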

In [None]:
def parse_value(value):
    """Convert a string to an int, a float, or np.nan as appropriate.

    The string is first converted to a float; whole numbers such as "12" or "12.0"
    are returned as int. The string "nan" (in any case) returns np.nan.

    Parameters:
        value (str): The string to parse.

    Returns:
        int, float, or np.nan: The parsed number.

    Raises:
        ValueError: If the string is neither a valid number nor "nan".
    """
    value = value.strip()
    if value.lower() == 'nan':
        return np.nan
    try:
        num = float(value)
    except ValueError:
        raise ValueError(f"{value!r} is not a valid number") from None
    # Whole numbers become int; everything else stays float
    return int(num) if num.is_integer() else num

def csv_to_dict(file_path):
    """
    Reads a two-column CSV file without headers and returns a reclassification dictionary.

    The first column holds the original value: a single number, "nan", or a range written
    as "low-high". The second column holds the new value. Because "-" is the range
    separator, negative numbers are not supported in the first column.

    Parameters:
        file_path (str): Path to the CSV file.

    Returns:
        dict: Keys are numbers, np.nan, or (low, high) tuples; values are numbers or np.nan.
    """
    result_dict = {}

    # Open the CSV file
    with open(file_path, mode='r', newline='') as file:
        csv_reader = csv.reader(file)

        for row in csv_reader:
            # Skip blank lines
            if not row or not row[0].strip():
                continue

            # Process the key (first column): "low-high" becomes a (low, high) tuple
            key_str = row[0]
            if '-' in key_str:
                key = tuple(parse_value(x) for x in key_str.split('-'))
            else:
                key = parse_value(key_str)

            # Process the value (second column)
            value = parse_value(row[1])

            # Later lines overwrite earlier ones with the same key
            result_dict[key] = value

    return result_dict
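Since `csv_to_dict` splits keys on every `-`, a key such as `-5` or `-10--1` cannot be read. One way around this, sketched here as an alternative rather than the notebook's method, is to recognise ranges with a regular expression that allows an optional leading minus on each bound. `RANGE_RE`, `parse_number`, and `parse_key` are hypothetical names introduced for this example.

```python
import math
import re

# Hypothetical pattern: "low-high" where either bound may be negative, e.g. "1-100" or "-10--1"
RANGE_RE = re.compile(r"^\s*(-?\d+(?:\.\d+)?)\s*-\s*(-?\d+(?:\.\d+)?)\s*$")

def parse_number(text):
    """Parse 'nan', an integer, or a float; raise ValueError otherwise."""
    text = text.strip()
    if text.lower() == "nan":
        return math.nan
    num = float(text)  # raises ValueError for invalid input
    return int(num) if num.is_integer() else num

def parse_key(text):
    """Return a (low, high) tuple for ranges, otherwise a single number."""
    match = RANGE_RE.match(text)
    if match:
        return parse_number(match.group(1)), parse_number(match.group(2))
    return parse_number(text)

for raw in ["1-100", "-5", "-10--1", "12.0", "nan"]:
    print(repr(raw), "->", parse_key(raw))
```

A lone `-5` does not match the range pattern, because the pattern requires a second number after the separator, so it falls through to `parse_number`.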

In [None]:
if __name__ == "__main__":
    input_folder = r"Y:\z_resources\im-nca-colombia\klab_ouputs\17.10.24\ecosystem_type_model\ecosystem_type_final_maps\mec"
    output_folder = r"Y:\z_resources\im-nca-colombia\klab_ouputs\17.10.24\ecosystem_type_model\ecosystem_type_final_maps\mec"
    reclass_file = r""  # path to the two-column reclassification CSV (set before running)
    
    """
    Specify the recoding rules in a CSV file with two entries per line and no header.
    Class values in the file that do not occur in the image are skipped.
    Class values in the image that are not listed keep their original value.
    If an original value appears on more than one line, the last line listed is used.
    When a single value and a range overlap, the rule listed later wins.

    Recoding rules:
    The first entry is the original value and the second is the new value:
    - the original value can be a single int or float, nan, or an inclusive range written as low-high;
    - the new value can be a single int or float, or nan (which requires a floating-point data_type).
    For example, to recode the range 1 to 100 to 22 (except class 1, which becomes 12)
    and NaN pixels to 0, enter one rule per line:
    1-100, 22
    1, 12
    nan, 0
    """

    # Open the csv to get the values
    reclass_dict = csv_to_dict(reclass_file)
    # Or do it manually
    # reclass_dict = {111: 1, 522: 43, 999: np.nan, np.nan: 0, (1, 100): 2}

    # NoData value recorded in the output bands
    nodata_value = 1000
    # Output data type
    data_type = gdal.GDT_Float32
    """
    Options:
    gdal.GDT_Byte,
    gdal.GDT_Int16,
    gdal.GDT_UInt16,
    gdal.GDT_Int32,
    gdal.GDT_UInt32,
    gdal.GDT_Float32,
    gdal.GDT_Float64
    """
    
    # Loop through each GeoTIFF in the input folder
    for file_name in os.listdir(input_folder):
        # Skip non-GeoTIFF files and outputs from a previous run, which land in the same folder here
        if not file_name.lower().endswith((".tif", ".tiff")) or file_name.endswith("_reclass.tif"):
            continue
        input_path = os.path.join(input_folder, file_name)
        output_name = os.path.splitext(file_name)[0] + "_reclass.tif"
        output_path = os.path.join(output_folder, output_name)

        reclassify_raster(input_path, output_path, reclass_dict, nodata_value, data_type)
        print("finished: " + output_name)
    print("Reclassification complete.")

finished: et_mec_colombia_2011_epsg3116_reclass.tif
finished: et_mec_colombia_2020_epsg3116_reclass.tif
Reclassification complete.
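Before running the batch loop, it can help to check that every new value and `nodata_value` fit the chosen `data_type`. The sketch below is a hypothetical, NumPy-only check. The `GDAL_TO_NUMPY` table is an assumption written out by hand for the types listed in the options above; in a GDAL session, `gdal_array.GDALTypeCodeToNumericTypeCode` converts from the type constant directly. For instance, `nodata_value = 1000` would not fit `gdal.GDT_Byte`, whose range is 0 to 255.

```python
import math
import numpy as np

# Assumed correspondence between GDAL type names and NumPy dtypes
GDAL_TO_NUMPY = {
    "Byte": np.uint8,
    "Int16": np.int16,
    "UInt16": np.uint16,
    "Int32": np.int32,
    "UInt32": np.uint32,
    "Float32": np.float32,
    "Float64": np.float64,
}

def value_fits(value, type_name):
    """Rough, hypothetical check that `value` can be written to the output type."""
    dtype = np.dtype(GDAL_TO_NUMPY[type_name])
    if np.issubdtype(dtype, np.floating):
        return True  # NaN and ordinary floats are accepted (Float32 may round them)
    if isinstance(value, float) and (math.isnan(value) or not value.is_integer()):
        return False  # integer types cannot hold NaN or fractional values
    info = np.iinfo(dtype)
    return info.min <= value <= info.max

print(value_fits(1000, "Byte"), value_fits(1000, "Float32"))
```

In the notebook this could be applied as `all(value_fits(v, "Float32") for v in list(reclass_dict.values()) + [nodata_value])` before starting the loop.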
